# Attention Rewriter
## Current Workflow
```python=
for computation in module:
    for instruction in computation:
        found = find_fwd(instruction)
        if found:
            fwd = fuse_fwd()
            if training:
                find_bwd(fwd)
```
fwd pattern matching
```python=
def find_fwd(instruction):
    # Use the current instruction as bmm2 and walk bottom-up through the define chain
    if find_bmm1_bmm2(instruction):
        # BMM1 - BMM2
        return True
    softmax = find_softmax_dropout_bmm2(instruction)
    if not softmax:
        return False
    if find_scale_bias(softmax):
        # BMM1 - [scale] - [bias] - softmax - [dropout] - BMM2
        return True
    if find_mask(softmax):
        # BMM1 - [scale] - [bias] - mask - softmax - [dropout] - BMM2
        return True
    return False
```
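To make the bottom-up walk concrete, here is a toy sketch. It assumes a hypothetical mini IR (`Instr`, `make`) rather than XLA's HLO classes, treats every op between the two BMMs as following operand 0, and folds the masked and unmasked branches into one regular expression over the op names found between BMM1 and BMM2.

```python
import re
from dataclasses import dataclass, field

@dataclass(eq=False)  # eq=False keeps identity-based comparison and hashing
class Instr:
    # Hypothetical toy IR node, not XLA's HloInstruction.
    opcode: str
    inputs: list = field(default_factory=list, repr=False)
    users: list = field(default_factory=list, repr=False)

def make(opcode, *inputs):
    """Create a node and register it as a user of each of its operands."""
    instr = Instr(opcode, list(inputs))
    for operand in inputs:
        operand.users.append(instr)
    return instr

CHAIN_OPS = {"scale", "bias", "mask", "softmax", "dropout"}
# Top-down order between BMM1 and BMM2; softmax is the only mandatory op.
FWD_PATTERN = re.compile(r"(scale )?(bias )?(mask )?softmax( dropout)?")

def find_fwd(bmm2):
    """Return matched op names (top-down) if `bmm2` ends an attention pattern."""
    if bmm2.opcode != "dot":
        return None
    chain = []
    node = bmm2.inputs[0]
    # Walk bottom-up along operand 0 while we are still inside the chain.
    while node.opcode in CHAIN_OPS:
        chain.append(node.opcode)
        node = node.inputs[0]
    if node.opcode != "dot":
        return None  # the define chain does not start at a BMM1
    chain.reverse()  # now ordered BMM1 -> ... -> BMM2
    if chain and not FWD_PATTERN.fullmatch(" ".join(chain)):
        return None  # ops are present but in an unsupported order
    return ["bmm1", *chain, "bmm2"]

q, k, v = make("param"), make("param"), make("param")
bmm1 = make("dot", q, k)
probs = make("dropout", make("softmax", make("scale", bmm1)))
print(find_fwd(make("dot", probs, v)))
# ['bmm1', 'scale', 'softmax', 'dropout', 'bmm2']
```

A real matcher would also check shapes, constants, and the non-chain operands of ops like bias and mask, which this sketch ignores.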
bwd pattern matching
```python=
def find_bwd(fwd):
    # Find bmm1_grad1, bmm1_grad2, bmm2_grad1, bmm2_grad2
    bmm_arr = find_4_bwd_bmm(fwd)
    if not bmm_arr:
        return False
    if find_bmm1_bmm2_bwd(bmm_arr):
        # BMM1 - BMM2 Backward
        return True
    bmm1_grad1 = bmm_arr[0]
    # Use bmm1_grad1 to find the optional bwd ops
    if find_scale_bias_mask_softmax_dropout(bmm1_grad1):
        # BMM1 - [scale] - [bias] - [mask] - softmax - [dropout] - BMM2 Backward
        return True
    return False
```
find the 4 bwd bmms using fwd
```python=
def find_4_bwd_bmm(fwd):
    P = fwd.outputs[1]
    bmm2_grad1 = P.users[0]
    V = fwd.inputs[2]
    bmm2_grad2 = V.users[1]
    Q = fwd.inputs[0]
    bmm1_grad1 = Q.users[1]
    # bmm1_grad2 shares the same input (dS) with bmm1_grad1
    dS = bmm1_grad1.inputs[0]
    bmm1_grad2 = dS.users[1]
    return [bmm1_grad1, bmm1_grad2, bmm2_grad1, bmm2_grad2]
```
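The lookup above relies on positional user indices, for example that the fwd call is always the first user of `Q` and `V`. The toy graph below uses a hypothetical mini IR (`Instr`, `make`) instead of XLA's classes and wires the tensors so that those index assumptions hold, then checks that the four gemms come back in the order `find_bwd` expects.

```python
from dataclasses import dataclass, field

@dataclass(eq=False)
class Instr:
    # Hypothetical toy IR node; `name` is only used for readable checks.
    name: str
    inputs: list = field(default_factory=list, repr=False)
    users: list = field(default_factory=list, repr=False)
    outputs: list = field(default_factory=list, repr=False)

def make(name, *inputs):
    """Create a node and register it as a user of each of its operands."""
    instr = Instr(name, list(inputs))
    for operand in inputs:
        operand.users.append(instr)
    return instr

def find_4_bwd_bmm(fwd):
    P = fwd.outputs[1]
    bmm2_grad1 = P.users[0]    # assumes bmm2_grad1 is P's first user
    V = fwd.inputs[2]
    bmm2_grad2 = V.users[1]    # assumes V.users[0] is the fwd call
    Q = fwd.inputs[0]
    bmm1_grad1 = Q.users[1]    # assumes Q.users[0] is the fwd call
    dS = bmm1_grad1.inputs[0]  # bmm1_grad2 shares dS with bmm1_grad1
    bmm1_grad2 = dS.users[1]
    return [bmm1_grad1, bmm1_grad2, bmm2_grad1, bmm2_grad2]

# Forward: the fused call reads Q, K, V and exposes O and P as outputs.
Q, K, V = make("Q"), make("K"), make("V")
fwd = make("fwd", Q, K, V)
fwd.outputs = [make("O", fwd), make("P", fwd)]
# Backward: created in the order the positional lookups expect.
dO = make("dO")
make("bmm2_grad1", dO, fwd.outputs[1])
dS = make("dS", make("bmm2_grad2", dO, V))
make("bmm1_grad1", dS, Q)
make("bmm1_grad2", dS, K)

print([g.name for g in find_4_bwd_bmm(fwd)])
# ['bmm1_grad1', 'bmm1_grad2', 'bmm2_grad1', 'bmm2_grad2']
```

If the users are ordered differently, the same code returns the wrong instructions without raising any error.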

### Issues
* Needs P, V, and Q from the fwd call to find the 4 bmms in the bwd
    * Fails for the current pattern match as well if a loop surrounds fwd and bwd
* Some bwd patterns can be matched as fwd
    * BMM1 - [optional] - BMM2 and BMM1 - [optional] - BMM2GRAD2
        * Addressed by making sure every bmm from the fwd is pattern matched only once
    * BMM1 - BMM2 and BMM2GRAD1 - BMM1GRAD1
        * `IsFirstFwdMatmul` checks whether a gemm is a fwd gemm by checking whether any of its operands comes from a fwd call.
            * Fails for the current pattern match as well if a loop surrounds fwd and bwd
        * Currently we require softmax to be present in the bwd, so this is not an issue
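Both mitigations can be sketched on a toy graph. The names `Instr`, `make`, `claim`, and `is_first_fwd_matmul` are hypothetical and only approximate the description above: a gemm counts as a first fwd matmul only if none of its direct operands is an output of a fused fwd call, and a set of claimed gemms ensures each fwd bmm is matched at most once. The sketch also assumes fwd-call outputs are read through `get-tuple-element` nodes.

```python
from dataclasses import dataclass, field

@dataclass(eq=False)  # identity-based hashing so nodes can live in a set
class Instr:
    # Hypothetical toy IR node, not XLA's HloInstruction.
    opcode: str
    inputs: list = field(default_factory=list, repr=False)
    users: list = field(default_factory=list, repr=False)

def make(opcode, *inputs):
    """Create a node and register it as a user of each of its operands."""
    instr = Instr(opcode, list(inputs))
    for operand in inputs:
        operand.users.append(instr)
    return instr

def comes_from_fwd_call(operand):
    # Assumption: fused fwd outputs are read through get-tuple-element.
    return (operand.opcode == "get-tuple-element"
            and operand.inputs[0].opcode == "fwd-call")

def is_first_fwd_matmul(gemm):
    """Reject dots fed directly by a fwd-call output (e.g. P): those are bwd gemms."""
    return gemm.opcode == "dot" and not any(
        comes_from_fwd_call(op) for op in gemm.inputs)

def claim(gemms, claimed):
    """Mark gemms as matched; refuse if any already belongs to another pattern."""
    if any(g in claimed for g in gemms):
        return False
    claimed.update(gemms)
    return True

q, k, v, d_o = (make("param") for _ in range(4))
fwd = make("fwd-call", q, k, v)
p = make("get-tuple-element", fwd)
bmm2_grad1 = make("dot", p, d_o)  # reads P straight from the fwd call
fresh_bmm1 = make("dot", make("param"), make("param"))  # an unfused block
print(is_first_fwd_matmul(fresh_bmm1), is_first_fwd_matmul(bmm2_grad1))  # True False

claimed = set()
fresh_bmm2 = make("dot", fresh_bmm1, make("param"))
print(claim([fresh_bmm1, fresh_bmm2], claimed), claim([fresh_bmm1, bmm2_grad1], claimed))  # True False
```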
## New Workflow
```python=
for computation in module:
    for instruction in computation:
        find_fwd(instruction)
    for instruction in computation:
        find_bwd(instruction)
```
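A runnable sketch of the two-pass driver, with `find_fwd` and `find_bwd` passed in as callables so the ordering can be checked in isolation. The name `rewrite_module`, the returned match lists, and iterating over a `list(...)` snapshot (so fusion could edit the computation mid-loop) are illustrative assumptions rather than details of the actual pass.

```python
def rewrite_module(module, find_fwd, find_bwd):
    """Match every fwd pattern in a computation before any bwd pattern in it."""
    fwd_hits, bwd_hits = [], []
    for computation in module:
        # Pass 1: fuse all forward patterns first.
        for instruction in list(computation):
            if find_fwd(instruction):
                fwd_hits.append(instruction)
        # Pass 2: bwd matching starts only after every fwd has been tried.
        for instruction in list(computation):
            if find_bwd(instruction):
                bwd_hits.append(instruction)
    return fwd_hits, bwd_hits

calls = []  # records (pass, instruction) to show the ordering

def fake_fwd(instr):
    calls.append(("fwd", instr))
    return instr.startswith("bmm2")

def fake_bwd(instr):
    calls.append(("bwd", instr))
    return instr.startswith("bmm1_grad1")

print(rewrite_module([["bmm1_grad1", "bmm2", "add"]], fake_fwd, fake_bwd))
# (['bmm2'], ['bmm1_grad1'])
```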
`find_fwd` is unchanged.
The new `find_bwd` is defined as:
```python=
def find_bwd(instruction):
    # Use the current instruction as bmm1_grad1 and walk bottom-up through the define chain
    bmm1_grad1 = instruction
    # bmm1_grad2 shares the same input (dS) with bmm1_grad1
    dS = bmm1_grad1.inputs[0]
    bmm1_grad2 = dS.users[1]
    # Use bmm1_grad1 to find the optional bwd ops
    bmm2_grad1 = find_scale_bias_mask_softmax_dropout(bmm1_grad1)
    if bmm2_grad1:
        # BMM1 - [scale] - [bias] - [mask] - softmax - [dropout] - BMM2 Backward
        # bmm2_grad2 shares the same input (dO) with bmm2_grad1
        dO = bmm2_grad1.inputs[0]
        bmm2_grad2 = dO.users[1]
        return True
    return False
```
### Changes
* Instead of finding the 4 bwd bmms first, we treat the current instruction as bmm1_grad1 and walk bottom-up. This is the same logic as the original `find_bwd`.
* After `find_scale_bias_mask_softmax_dropout`, we have found 3 bmms. Finding `bmm2_grad2` needs extra logic: since we already have `bmm2_grad1`, we can use the input they share.
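The shared-input lookup can also be sketched without positional indices such as `users[1]`: pick the one other `dot` that reads the shared tensor. The helper names `sibling_gemm` and `find_remaining_bwd_gemms` and the toy IR are hypothetical, and the sketch assumes operand 0 of `bmm1_grad1` is dS and operand 0 of `bmm2_grad1` is dO.

```python
from dataclasses import dataclass, field

@dataclass(eq=False)
class Instr:
    # Hypothetical toy IR node, not XLA's HloInstruction.
    opcode: str
    inputs: list = field(default_factory=list, repr=False)
    users: list = field(default_factory=list, repr=False)

def make(opcode, *inputs):
    """Create a node and register it as a user of each of its operands."""
    instr = Instr(opcode, list(inputs))
    for operand in inputs:
        operand.users.append(instr)
    return instr

def sibling_gemm(shared, known):
    """Return the single dot other than `known` that reads `shared`, else None."""
    others = [u for u in shared.users if u is not known and u.opcode == "dot"]
    return others[0] if len(others) == 1 else None

def find_remaining_bwd_gemms(bmm1_grad1, bmm2_grad1):
    dS = bmm1_grad1.inputs[0]  # assumed operand position of dS
    dO = bmm2_grad1.inputs[0]  # assumed operand position of dO
    return sibling_gemm(dS, bmm1_grad1), sibling_gemm(dO, bmm2_grad1)

q, k, v, p, d_s, d_o = (make("param") for _ in range(6))
make("reduce", d_o)  # an extra non-gemm user listed first
bmm2_grad1 = make("dot", d_o, p)
bmm2_grad2 = make("dot", d_o, v)
bmm1_grad1 = make("dot", d_s, q)
bmm1_grad2 = make("dot", d_s, k)

print(d_o.users[1] is bmm2_grad1)  # True: users[1] would pick the wrong gemm
found1, found2 = find_remaining_bwd_gemms(bmm1_grad1, bmm2_grad1)
print(found1 is bmm1_grad2, found2 is bmm2_grad2)  # True True
```

Returning `None` on an ambiguous match lets the caller reject the pattern instead of fusing the wrong gemm; that is a design choice of this sketch.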
### Summary
* Change the main loop to find all fwd patterns first, then all bwd patterns
* Remove `find_4_bwd_bmm`
* Add `find_bmm2_grad1`
### Issues
* Flash attention needs `fwd_output` from the fwd call