# Reduction Lowering

## Current blocker

1)
```python=
X2YDim = ...
a: Field[[Y], ...]
b: Field[[X, X2YDim], ...]

@fieldop
def foo(a, b):
    return nbh_sum(a(X2Y) + b, axis=X2YDim)
```

1b) Make a transform pass that creates
```python=
X2YDim = ...
a: Field[[Y], ...]
b: Field[[X, X2YDim], ...]

@fieldop
def foo(a, b):
    a_shifted = a(X2Y)
    return nbh_sum(a_shifted + b, axis=X2YDim)
```

1c) Shifts inside a reduction are syntactic sugar that we can add later; for now, write this by hand
```python=
X2YDim = ...
a: Field[[Y], ...]
b: Field[[X, X2YDim], ...]

@fieldop
def foo(a, b):
    a_tmp = a(X2Y)
    return nbh_sum(a_tmp + b, axis=X2YDim)
```

2)
```python=
reduce(
    (lambda accum, a_, b_: plus(accum, plus(a_, b_))),
    accum_init
)(
    shift(X2Y)(a), b
)
```
Shifts have to be collected and moved into the arguments.

3)
```python=
reduce(
    (lambda accum, expr: plus(accum, expr)),
    accum_init
)(
    # either: shift applied to the argument
    lift(lambda a, b: plus(deref(a), deref(b)))(shift(X2Y)(a), b)
    # or: shift applied inside the lambda
    lift(lambda a, b: plus(deref(shift(X2Y)(a)), deref(b)))(a, b)
)

# shift(0) applied to the variant with the shift inside the lambda
shift(0)(lift(lambda a, b: plus(deref(shift(X2Y)(a)), deref(b)))(a, b))
lift(lambda a, b: plus(deref(shift(X2Y)(a)), deref(b)))(shift(0)(a), shift(0)(b))
lift(lambda a, b: plus(deref(shift(X2Y)(shift(0)(a))), deref(b)))(a, b)

# shift(0) applied to the variant with the shift in the argument
shift(0)(lift(lambda a, b: plus(deref(a), deref(b)))(shift(X2Y)(a), b))
lift(lambda a, b: plus(deref(a), deref(b)))(shift(0)(shift(X2Y)(a)), b)
```

4) Reduce expression taken directly from the current lifting
```python=
lift(lambda a, b: plus(deref(lift(lambda a_: shift(X2Y)(a_))(a)), deref(b)))(a, b)
lift(lambda a, b: plus(deref(shift(X2Y)(a)), deref(b)))(a, b)

visit_Name("a") -> "a"
visit_Call("a(X2Y)") -> shift(X2Y)(a)
visit_Name("b") -> "b"
# current
visit_BinaryOp("plus(a(X2Y), b)") -> lift(lambda...: plus(deref(visit("a(X2Y)")), deref(visit("b"))))(a, b)
# proposed: keep shifts outside
visit_BinaryOp("plus(a(X2Y), b)") -> lift(lambda a, b: plus(deref(a), deref(b)))(visit("a(X2Y)"), visit("b"))
```

```python
def sten(inp):
    return deref(shift(A)(inp))

def foo(inp):
    return deref(shift(B)(lift(sten)(inp)))

# the same foo with the lift resolved
def foo(inp):
    return sten(shift(B)(inp))

# the same foo with sten inlined
def foo(inp):
    return deref(shift(A)(shift(B)(inp)))
```

```
SSA pass:

a, b = ...
a = a + 1

to

a = ...[0]
b = ...[1]
a2 = a + 1
```

## Conclusion

* For now: See how far we get with 3) / 4) and fall back on 1c) if we don't get far enough.
* Future: Probably do it cleanly in a separate project and answer all the questions about partial shifts and shift ordering.
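
As a sanity check on the `sten`/`foo` rewrite chain and on the "current" versus "proposed" `visit_BinaryOp` forms from 4), here is a minimal toy model. Everything in it is an assumption made for illustration: `FieldIt` and `LiftedIt` are hypothetical classes, offsets are plain integers on a 1D field (so `X2Y`, `A` and `B` are stand-in numbers, not connectivities), and a shift of a lifted iterator is simply forwarded to all of its arguments. It is not the real iterator implementation.

```python
from dataclasses import dataclass
from typing import Callable, Sequence, Tuple


@dataclass(frozen=True)
class FieldIt:
    """Hypothetical iterator: a 1D field plus a current position."""
    data: Sequence[int]
    pos: int = 0

    def shift(self, offset: int) -> "FieldIt":
        return FieldIt(self.data, self.pos + offset)

    def deref(self) -> int:
        return self.data[self.pos]


@dataclass(frozen=True)
class LiftedIt:
    """Hypothetical lazy iterator created by lift(stencil)(*args)."""
    stencil: Callable[..., int]
    args: Tuple

    def shift(self, offset: int) -> "LiftedIt":
        # Assumption: shifting a lifted iterator shifts every argument.
        return LiftedIt(self.stencil, tuple(a.shift(offset) for a in self.args))

    def deref(self) -> int:
        return self.stencil(*self.args)


def shift(offset):
    return lambda it: it.shift(offset)


def deref(it):
    return it.deref()


def lift(stencil):
    return lambda *its: LiftedIt(stencil, its)


def plus(x, y):
    return x + y


A, B, X2Y = 1, 2, 1  # stand-in integer offsets


def sten(inp):
    return deref(shift(A)(inp))


def foo_lifted(inp):
    return deref(shift(B)(lift(sten)(inp)))


def foo_inlined(inp):
    return sten(shift(B)(inp))


def foo_flat(inp):
    return deref(shift(A)(shift(B)(inp)))


def current(a, b):  # shift inside the lambda
    return lift(lambda a, b: plus(deref(shift(X2Y)(a)), deref(b)))(a, b)


def proposed(a, b):  # shift kept outside, in the arguments
    return lift(lambda a, b: plus(deref(a), deref(b)))(shift(X2Y)(a), b)


a = FieldIt([10, 11, 12, 13, 14, 15])
b = FieldIt([100, 200, 300, 400, 500, 600])

# The three foo variants agree, and so do the two lifting strategies,
# both unshifted and under an additional outer shift.
assert foo_lifted(a) == foo_inlined(a) == foo_flat(a)
assert deref(current(a, b)) == deref(proposed(a, b))
assert deref(shift(2)(current(a, b))) == deref(shift(2)(proposed(a, b)))
```

The variants agree here only because integer offsets on a 1D field commute and every argument accepts every shift. With connectivities and partial shifts that need not hold, so this sketch does not answer the shift-ordering question left open in the conclusion.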