# Attention Rewriter

## Current Workflow

```python=
for computation in module:
    for instruction in computation:
        found = find_fwd(instruction)
        if found:
            fwd = fuse_fwd()
            if training:
                find_bwd(fwd)
```

fwd pattern matching

```python=
def find_fwd(instruction):
    # Use current instruction as bmm2 and go bottom up through the define chain
    if find_bmm1_bmm2(instruction):
        # BMM1 - BMM2
        return True
    softmax = find_softmax_dropout_bmm2(instruction)
    if not softmax:
        return False
    if find_scale_bias(softmax):
        # BMM1 - [scale] - [bias] - softmax - [dropout] - BMM2
        return True
    if find_mask(softmax):
        # BMM1 - [scale] - [bias] - mask - softmax - [dropout] - BMM2
        return True
    return False
```

bwd pattern matching

```python=
def find_bwd(fwd):
    # Find bmm1_grad1, bmm1_grad2, bmm2_grad1, bmm2_grad2
    bmm_arr = find_4_bwd_bmm(fwd)
    if not bmm_arr:
        return False
    if find_bmm1_bmm2_bwd(bmm_arr):
        # BMM1 - BMM2 Backward
        return True
    bmm1_grad1 = bmm_arr[0]
    if find_scale_bias_mask_softmax_dropout(bmm1_grad1):
        # Use bmm1_grad1 to find the optional bwd ops
        # BMM1 - [scale] - [bias] - [mask] - softmax - [dropout] - BMM2 Backward
        return True
    return False
```

find the 4 bwd bmms using fwd

```python=
def find_4_bwd_bmm(fwd):
    P = fwd.outputs[1]
    bmm2_grad1 = P.users[0]
    V = fwd.inputs[2]
    bmm2_grad2 = V.users[1]
    Q = fwd.inputs[0]
    bmm1_grad1 = Q.users[1]
    # bmm1_grad2 shares the same input (dS) with bmm1_grad1
    dS = bmm1_grad1.inputs[0]
    bmm1_grad2 = dS.users[1]
    return [bmm1_grad1, bmm1_grad2, bmm2_grad1, bmm2_grad2]
```

![](https://hackmd.io/_uploads/BJ_5_sBMT.png)

### Issues

* Need P, V, Q from the fwd to find the 4 bmms in the bwd
    * The current pattern match will also fail if there is a loop surrounding the fwd and bwd
* Some bwd patterns can be matched as fwd
    * BMM1 - [optional] - BMM2 vs. BMM1 - [optional] - BMM2GRAD2
        * Addressed by making sure every bmm from the fwd is pattern matched only once
    * BMM1 - BMM2 vs. BMM2GRAD1 - BMM1GRAD1
        * `IsFirstFwdMatmul` checks whether a gemm is a fwd gemm by checking whether any of its operands come from the fwd call
        * This will also fail if there is a loop surrounding the fwd and bwd
        * Currently we require that softmax be present in the bwd, so this is not an issue

## New Workflow

```python=
for computation in module:
    for instruction in computation:
        find_fwd(instruction)
    for instruction in computation:
        find_bwd(instruction)
```

`find_fwd` is the same; the new `find_bwd` is defined as

```python=
def find_bwd(instruction):
    # Use current instruction as bmm1_grad1 and go bottom up through the define chain
    # bmm1_grad2 shares the same input (dS) with bmm1_grad1
    bmm1_grad1 = instruction
    dS = bmm1_grad1.inputs[0]
    bmm1_grad2 = dS.users[1]
    bmm2_grad1 = find_scale_bias_mask_softmax_dropout(bmm1_grad1)
    if bmm2_grad1:
        # Used bmm1_grad1 to find the optional bwd ops
        # BMM1 - [scale] - [bias] - [mask] - softmax - [dropout] - BMM2 Backward
        # bmm2_grad2 shares the same input (dO) with bmm2_grad1
        dO = bmm2_grad1.inputs[0]
        bmm2_grad2 = dO.users[1]
        return True
    return False
```

### Changes

* Instead of finding the 4 bmms in the bwd first, we treat the current instruction as bmm1_grad1 and go bottom up. This is the same logic as the original `find_bwd`.
* After `find_scale_bias_mask_softmax_dropout` we have found 3 bmms, so extra logic is needed to find `bmm2_grad2`. Since we already have `bmm2_grad1`, we can use their shared input dO (see the sketch at the end of this note).

### Summary

* Change the main loop to find all fwd patterns first, then the bwd patterns
* Remove `find_4_bwd_bmm`
* Add `find_bmm2_grad1`

### Issues

* Need fwd_output from the fwd for flash attention
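### Example: the new `find_bwd` on a toy IR

To make the def-use walk concrete, here is a minimal, runnable sketch on a toy IR. The `Instruction` class and the `find_scale_bias_mask_softmax_dropout` stand-in are simplified assumptions for illustration, not the real HLO instruction API; in particular, the real matcher also checks opcodes, shapes, and the optional scale/bias/mask ops.

```python=
# Toy IR node: `inputs` are operands (define chain), `users` are consumers.
class Instruction:
    def __init__(self, op, inputs=()):
        self.op = op
        self.inputs = list(inputs)
        self.users = []
        for operand in self.inputs:
            operand.users.append(self)

    def __repr__(self):
        return self.op


def find_scale_bias_mask_softmax_dropout(bmm1_grad1):
    # Stand-in for the real matcher: walk the define chain up from dS
    # through the optional ops until bmm2_grad1 is reached.
    node = bmm1_grad1.inputs[0]  # dS
    while node.inputs:
        node = node.inputs[0]
        if node.op == "bmm2_grad1":
            return node
    return None


def find_bwd(instruction):
    # Treat the current instruction as bmm1_grad1 and go bottom up.
    if instruction.op != "bmm1_grad1":
        return False
    bmm1_grad1 = instruction
    dS = bmm1_grad1.inputs[0]
    # bmm1_grad2 is the other user of dS (the pseudocode's dS.users[1])
    bmm1_grad2 = next(u for u in dS.users if u is not bmm1_grad1)
    bmm2_grad1 = find_scale_bias_mask_softmax_dropout(bmm1_grad1)
    if bmm2_grad1 is None:
        return False
    # bmm2_grad2 shares the input dO with bmm2_grad1
    dO = bmm2_grad1.inputs[0]
    bmm2_grad2 = next(u for u in dO.users if u is not bmm2_grad1)
    print("matched:", bmm1_grad1, bmm1_grad2, bmm2_grad1, bmm2_grad2)
    return True
```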
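Continuing the sketch, a hypothetical bwd graph shows the walk end to end: dS fans out to bmm1_grad1/bmm1_grad2, and dO fans out to bmm2_grad1/bmm2_grad2, so both gradient pairs are recovered from their shared inputs.

```python=
# Hypothetical bwd graph wired per the pattern above.
dO = Instruction("dO")
bmm2_grad1 = Instruction("bmm2_grad1", [dO])
bmm2_grad2 = Instruction("bmm2_grad2", [dO])
softmax_grad = Instruction("softmax_grad", [bmm2_grad1])
dS = Instruction("dS", [softmax_grad])
bmm1_grad1 = Instruction("bmm1_grad1", [dS])
bmm1_grad2 = Instruction("bmm1_grad2", [dS])

assert find_bwd(bmm1_grad1)
# prints: matched: bmm1_grad1 bmm1_grad2 bmm2_grad1 bmm2_grad2
```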