# Reduction Lowering
## Current blocker
1)
```python=
X2YDim = ...
a: Field[[Y], ...]
b: Field[[X, X2YDim], ...]
@fieldop
def foo(a, b):
    return nbh_sum(a(X2Y) + b, axis=X2YDim)
```
1b) Add a transform pass that hoists the shift into a temporary, producing
```python=
X2YDim = ...
a: Field[[Y], ...]
b: Field[[X, X2YDim], ...]
@fieldop
def foo(a, b):
    a_shifted = a(X2Y)
    return nbh_sum(a_shifted + b, axis=X2YDim)
```
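A minimal sketch of what such a hoisting pass could look like, written against Python's standard `ast` module rather than the actual frontend. The names `HoistShifts` and `hoist_shifts` are hypothetical, and the pass assumes that any call whose callee is a known field name is a shift.

```python
import ast


class HoistShifts(ast.NodeTransformer):
    """Replace every `field(offset)` call by a fresh temporary (hypothetical pass)."""

    def __init__(self, field_names):
        self.field_names = set(field_names)
        self.pending = []   # assignments to emit before the current statement
        self.counts = {}    # field name -> number of temporaries created so far

    def visit_Call(self, node):
        self.generic_visit(node)
        if isinstance(node.func, ast.Name) and node.func.id in self.field_names:
            field = node.func.id
            idx = self.counts.get(field, 0)
            self.counts[field] = idx + 1
            tmp = f"{field}_shifted" if idx == 0 else f"{field}_shifted_{idx}"
            self.pending.append(
                ast.Assign(targets=[ast.Name(id=tmp, ctx=ast.Store())], value=node)
            )
            return ast.Name(id=tmp, ctx=ast.Load())
        return node


def hoist_shifts(source, field_names):
    """Hoist shifts out of every statement of the first function in `source`."""
    tree = ast.parse(source)
    func = tree.body[0]
    hoister = HoistShifts(field_names)
    new_body = []
    for stmt in func.body:
        stmt = hoister.visit(stmt)
        new_body.extend(hoister.pending)  # temporaries go right before their use
        hoister.pending = []
        new_body.append(stmt)
    func.body = new_body
    ast.fix_missing_locations(tree)
    return ast.unparse(tree)  # requires Python 3.9+


src = """
def foo(a, b):
    return nbh_sum(a(X2Y) + b, axis=X2YDim)
"""
print(hoist_shifts(src, ["a", "b"]))
```

Each shift gets its own temporary, so the sketch does not try to deduplicate identical shifts.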
1c) Shifts inside a reduction are syntactic sugar that we can support later; for now, write the shifted field by hand:
```python=
X2YDim = ...
a: Field[[Y], ...]
b: Field[[X, X2YDim], ...]
@fieldop
def foo(a, b):
    a_tmp = a(X2Y)
    return nbh_sum(a_tmp + b, axis=X2YDim)
```
2)
```python=
reduce(
    (lambda accum, a_, b_: plus(accum, plus(a_, b_))),
    accum_init
)(
    shift(X2Y)(a),
    b
)
```
Shifts have to be collected and moved into the arguments of the `reduce` call.
3)
```python=
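reduce(
    (lambda accum, expr: plus(accum, expr)),
    accum_init
)(
    # shift kept outside the lifted stencil:
    lift(lambda a, b: plus(deref(a), deref(b)))(shift(X2Y)(a), b)
    # alternative, with the shift inside the lifted stencil:
    # lift(lambda a, b: plus(deref(shift(X2Y)(a)), deref(b)))(a, b)
)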
# candidate placements of an additional shift(0):
shift(0)(lift(lambda a, b: plus(deref(shift(X2Y)(a)), deref(b)))(a, b))
lift(lambda a, b: plus(deref(shift(X2Y)(a)), deref(b)))(shift(0)(a), shift(0)(b))
lift(lambda a, b: plus(deref(shift(X2Y)(shift(0)(a))), deref(b)))(a, b)
shift(0)(lift(lambda a, b: plus(deref(a), deref(b)))(shift(X2Y)(a), b))
lift(lambda a, b: plus(deref(a), deref(b)))(shift(0)(shift(X2Y)(a)), b)
```
4) Reduce expression obtained directly from the current lifting
```python=
lift(lambda a, b: plus(deref(lift(lambda a_: shift(X2Y)(a_))(a)), deref(b)))(a, b)
lift(lambda a, b: plus(deref(shift(X2Y)(a)), deref(b)))(a, b)
visit_Name("a") -> "a"
visit_Call("a(X2Y)") -> shift(X2Y)(a)
visit_Name("b") -> "b"
# current
visit_BinaryOp("plus(a(X2Y), b)") -> lift(lambda...: plus(deref(visit("a(X2Y)")), deref(visit("b"))))(a, b)
# proposed: keep shifts outside
visit_BinaryOp("plus(a(X2Y), b)") -> lift(lambda a, b: plus(deref(a), deref(b)))(visit("a(X2Y)"), visit("b"))
```
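The difference between the two `visit_BinaryOp` variants can be made concrete with a toy lowering that emits the iterator expression as a string. This is only an illustration under simplifying assumptions: `lower_leaf` and `lower_binop` are hypothetical helpers, only `+` is handled, and each operand is assumed to map to exactly one lambda parameter, in order.

```python
import ast


def lower_leaf(node):
    """Lower a field name or a field shift such as `a(X2Y)` (toy version)."""
    if isinstance(node, ast.Name):
        return node.id
    if isinstance(node, ast.Call) and isinstance(node.func, ast.Name):
        offset = node.args[0].id
        return f"shift({offset})({node.func.id})"
    raise NotImplementedError(ast.dump(node))


def lower_binop(node, params, keep_shifts_outside):
    """Lower `left + right` into a lifted stencil, in one of the two styles."""
    assert isinstance(node, ast.BinOp) and isinstance(node.op, ast.Add)
    operands = [lower_leaf(node.left), lower_leaf(node.right)]
    sig = ", ".join(params)
    if keep_shifts_outside:
        # proposed: the stencil only derefs, shifted operands become arguments
        body = f"plus(deref({params[0]}), deref({params[1]}))"
        args = ", ".join(operands)
    else:
        # current: operands are lowered inside the stencil, arguments stay plain
        body = f"plus(deref({operands[0]}), deref({operands[1]}))"
        args = sig
    return f"lift(lambda {sig}: {body})({args})"


expr = ast.parse("a(X2Y) + b", mode="eval").body
current = lower_binop(expr, ("a", "b"), keep_shifts_outside=False)
proposed = lower_binop(expr, ("a", "b"), keep_shifts_outside=True)
print(current)
print(proposed)
```

With the shifts kept outside, the lifted lambda no longer mentions `X2Y`, which is what lets a later step collect the shifts from the argument list.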
```python
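def sten(inp):
    return deref(shift(A)(inp))

# foo derefs the lifted stencil at a position shifted by B
def foo(inp):
    return deref(shift(B)(lift(sten)(inp)))

# the same foo with the shift moved onto the stencil argument
def foo(inp):
    return sten(shift(B)(inp))

# the same foo with sten inlined
def foo(inp):
    return deref(shift(A)(shift(B)(inp)))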
```
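To check that the three forms of `foo` agree, here is a small executable model of `deref`, `shift` and `lift`. It is an assumption-laden toy and not the real iterator implementation: iterators are a read function plus an integer position on a 1D field, offsets `A` and `B` are plain integers, and the three variants are renamed so they can coexist.

```python
class It:
    """Toy iterator: a read function plus a current integer position."""

    def __init__(self, read, pos=0):
        self.read = read
        self.pos = pos


def deref(it):
    return it.read(it.pos)


def shift(offset):
    # Shifting moves the position and leaves the underlying field untouched.
    return lambda it: It(it.read, it.pos + offset)


def lift(stencil):
    # Reading the lifted iterator at position p runs the stencil on `it` moved to p.
    return lambda it: It(lambda p: stencil(It(it.read, p)), it.pos)


A, B = 1, 2  # hypothetical integer offsets


def sten(inp):
    return deref(shift(A)(inp))


def foo_lifted(inp):
    return deref(shift(B)(lift(sten)(inp)))


def foo_shift_on_arg(inp):
    return sten(shift(B)(inp))


def foo_inlined(inp):
    return deref(shift(A)(shift(B)(inp)))


data = [10, 20, 30, 40, 50]
inp = It(data.__getitem__, 0)
results = [f(inp) for f in (foo_lifted, foo_shift_on_arg, foo_inlined)]
print(results)  # all three read data[A + B]
```

In this model the order of `shift(A)` and `shift(B)` does not matter because integer offsets commute; real connectivity offsets such as `X2Y` need not behave that way.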
```
SSA pass:
a, b = ...
a = a + 1
to
a = ...[0]
b = ...[1]
a2 = a + 1
```
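The SSA pass above could be sketched as follows for straight-line assignments, again using the standard `ast` module. `to_ssa` is a hypothetical name; the sketch assumes single-target assignments to plain names or flat tuples, duplicates the right-hand side when unpacking a tuple (as in the example above), and does not guard against a generated name such as `a2` clashing with an existing variable.

```python
import ast


def to_ssa(source):
    """Rename re-assigned variables so that every name is assigned exactly once."""
    tree = ast.parse(source)
    version = {}  # original name -> number of assignments seen so far
    current = {}  # original name -> SSA name currently visible
    out = []

    def rename_loads(node):
        # Point every read at the latest SSA name of the variable.
        for n in ast.walk(node):
            if isinstance(n, ast.Name) and isinstance(n.ctx, ast.Load) and n.id in current:
                n.id = current[n.id]
        return node

    def fresh(name):
        version[name] = version.get(name, 0) + 1
        new = name if version[name] == 1 else f"{name}{version[name]}"
        current[name] = new
        return new

    def assign(name, value):
        target = ast.Name(id=fresh(name), ctx=ast.Store())
        out.append(ast.Assign(targets=[target], value=value))

    for stmt in tree.body:
        assert isinstance(stmt, ast.Assign) and len(stmt.targets) == 1
        value = rename_loads(stmt.value)  # reads are renamed before the new definition
        target = stmt.targets[0]
        if isinstance(target, ast.Tuple):
            # a, b = rhs  becomes  a = rhs[0]; b = rhs[1]
            for i, elt in enumerate(target.elts):
                item = ast.Subscript(value=value, slice=ast.Constant(value=i), ctx=ast.Load())
                assign(elt.id, item)
        else:
            assign(target.id, value)

    module = ast.Module(body=out, type_ignores=[])
    ast.fix_missing_locations(module)
    return ast.unparse(module)  # requires Python 3.9+


ssa = to_ssa("a, b = f()\na = a + 1\nc = a * b")
print(ssa)
```

Here `f()` stands in for the elided right-hand side.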
## Conclusion
* For now: See how far we get with 3) / 4) and fall back on 1c) if we don't get far enough.
* Future: Probably do it cleanly in a separate project and answer all the questions about partial shifts and shift ordering.