###### tags: `frontend`

# Frontend proposals

Let's sketch some examples for the different execution models and possibly beautified frontends.

List of proposals:

**Raw iterator view**: as implemented in the prototype.

Rough idea: specify the computations directly with the parallel model constructs.

Benefits:
- Full flexibility to express any computation and pattern supported by the model, without the extra work of designing a DSL.
- It works as an IR serialization mechanism.

Problems:
- Not at all user friendly
- Verbose

**Classic GT4Py (or "simplified" iterator view)**: implicit `lift` and `deref` on top of the iterator view.

Rough idea:
- `deref` is inserted when a value is expected but an iterator is received, e.g. when an iterator is passed to a builtin that operates on values (arithmetic operations, `if`, etc.).
- `lift` is inserted when an iterator is expected but a value is received, e.g. when a shift is applied, the producer of that value needs to be lifted.
- "Beautified" shift with the current GT4Py look-and-feel.
- Python `if` is translated to the iterator view ternary.

Note that, because `lift` and `deref` are implicit, we can express less than the iterator view allows.

Benefits:
- It is usually conceptually simpler to write and reason about code written from a local perspective (the operations to perform at each point of the domain).

Problems:
- Doesn't allow direct (embedded) execution.
- Weird objects that are iterators and values at the same time.

**Field view**: use Python array notation (e.g. [NumPy](https://numpy.org/doc/stable/user/quickstart.html#the-basics), [Python array API standard](https://data-apis.org/array-api/latest/), ...).

Rough idea: use classic Python array notation to describe operations between fields. This works very well for element-wise operations, but not so well for more complex operations like conditionals.
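To make the element-wise vs. conditional contrast concrete, here is a minimal sketch in plain NumPy. It is not the proposed field-view frontend, and the helper names `laplacian` and `limit_flux` are only illustrative. It shows three things:
- Arithmetic maps directly onto array expressions.
- Neighbor access becomes slicing, which shrinks the output.
- The pointwise `if` of the iterator view has to become an `np.where` selection.

```python
import numpy as np


def laplacian(inp: np.ndarray) -> np.ndarray:
    """5-point Laplacian on the interior of a 2D field.

    Neighbor access is expressed with slicing, so the result is two points
    smaller than `inp` along each axis (the boundary is simply dropped).
    """
    center = inp[1:-1, 1:-1]
    return -4.0 * center + (
        inp[2:, 1:-1] + inp[:-2, 1:-1] + inp[1:-1, 2:] + inp[1:-1, :-2]
    )


def limit_flux(flux: np.ndarray, grad: np.ndarray) -> np.ndarray:
    """Zero the flux wherever `flux * grad > 0`, pointwise.

    A Python `if` on whole arrays fails (the truth value of a multi-element
    array is ambiguous), so the conditional becomes a masked selection.
    """
    return np.where(flux * grad > 0.0, 0.0, flux)


i, j = np.meshgrid(np.arange(5.0), np.arange(6.0), indexing="ij")
lap = laplacian(i**2 + j**2)  # discrete Laplacian of x^2 + y^2 is 4 everywhere
print(lap.shape)  # (3, 4)
print(limit_flux(np.array([1.0, -1.0, 2.0]), np.array([1.0, 1.0, -1.0])))
```

Shrinking the output is only one way to handle the boundary. Padding or masking undefined neighbors would need extra code, which is part of why the shifting syntax needs more thought.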
The syntax for shifting/remapping fields could be defined using the `[]` indexing (`__getitem__()`) operator. This needs some thought, since it can easily get verbose when dealing with boundaries or undefined neighbors.

Benefits:
- Direct execution in Python can be implemented very efficiently without much effort by relying on existing high-performance array libraries, both on CPUs (e.g. NumPy) and on accelerators (e.g. CuPy, JAX, PyTorch, ...).
- Very familiar syntax for scientists with a Python background, since it matches the syntax of most libraries in the scientific Python ecosystem.

Problems:
- More verbose syntax for non-element-wise operations like conditionals, shifting/remapping, ...
- Conceptually harder to think about typical stencil point operations at the global field level.

## Horizontal diffusion

### Raw iterator view

```python=
I = offset("I")
J = offset("J")


@fundef
def laplacian(inp):
    return -4.0 * deref(inp) + (
        deref(shift(I, 1)(inp))
        + deref(shift(I, -1)(inp))
        + deref(shift(J, 1)(inp))
        + deref(shift(J, -1)(inp))
    )


@fundef
def flux(d):
    def flux_impl(inp):
        lap = lift(laplacian)(inp)
        flux = deref(lap) - deref(shift(d, 1)(lap))
        return if_(flux * (deref(shift(d, 1)(inp)) - deref(inp)) > 0.0, 0.0, flux)

    return flux_impl


@fundef
def hdiff_sten(inp, coeff):
    flx = lift(flux(I))(inp)
    fly = lift(flux(J))(inp)
    return deref(inp) - (
        deref(coeff)
        * (
            deref(flx)
            - deref(shift(I, -1)(flx))
            + deref(fly)
            - deref(shift(J, -1)(fly))
        )
    )


@fendef
def hdiff(inp, coeff, out, x, y):
    closure(
        domain(named_range(IDim, 0, x), named_range(JDim, 0, y)),
        hdiff_sten,
        [out],
        [inp, coeff],
    )
```

### Classic GT4Py

```python=
I = offset("I")
J = offset("J")


@fundef
def laplacian(inp):
    return -4.0 * inp + (inp[I + 1] + inp[I - 1] + inp[J + 1] + inp[J - 1])


@fundef
def flux(d):
    def flux_impl(inp):
        lap = laplacian(inp)
        flux = lap - lap[d + 1]
        # same limiter as the raw iterator view: compare against the d+1 neighbor
        if flux * (inp[d + 1] - inp) > 0.0:
            return 0.0
        else:
            return flux

    return flux_impl


@fundef
def hdiff_sten(inp, coeff):
    flx = flux(I)(inp)
    fly = flux(J)(inp)
    return inp - (coeff * (flx - flx[I - 1] + fly - fly[J - 1]))


@fendef
def hdiff(inp, coeff, out, x, y):
    closure(
        domain(named_range(IDim, 0, x), named_range(JDim, 0, y)),
        hdiff_sten,
        [out],
        [inp, coeff],
    )
```

### Field view **WIP**

```python=
from typing import Mapping, Tuple, TypeVar

DType = TypeVar("DType")

# Dimensions
I = CartesianDimension("I")
J = CartesianDimension("J")


@operator
def laplacian(inp: Field[DType]):
    return -4.0 * inp + (inp(I + 1) + inp(I - 1) + inp(J + 1) + inp(J - 1))


@operator
def flux(inp: Field[DType], D: Dimension):
    lap = laplacian(inp)
    flux = lap - lap(D + 1)
    # same limiter as the raw iterator view: compare against the D+1 neighbor
    return gt.where(flux * (inp(D + 1) - inp) > 0.0, 0.0, flux)


@operator
def hdiff_stencil(inp: Field[DType], coeff: Field[DType]) -> Field[DType]:
    flx = flux(inp, I)
    fly = flux(inp, J)
    return inp - (coeff * (flx - flx(I - 1) + fly - fly(J - 1)))


@program
def hdiff(
    inp: Field[[I, J], DType],
    coeff: Field[[I, J], DType],
    out: Storage[[I, J], DType],
    domain: Mapping[Dimension, Tuple[int, int]],
):
    hdiff_stencil.apply(domain, [inp, coeff], [out])
```

## Vertical advection

## FVM nabla

### Raw iterator view

```python=
V2E = offset("V2E")
E2V = offset("E2V")


@fundef
def compute_zavgS(pp, S_M):
    # zavg = 0.5 * (deref(shift(E2V, 0)(pp)) + deref(shift(E2V, 1)(pp)))
    zavg = 0.5 * library.sum()(shift(E2V)(pp))
    return deref(S_M) * zavg


@fundef
def compute_pnabla(pp, S_M, sign, vol):
    zavgS = lift(compute_zavgS)(pp, S_M)
    # pnabla_M = reduce(lambda a, b, c: a + b * c, 0)(shift(V2E)(zavgS), sign)
    pnabla_M = library.sum(lambda a, b: a * b)(shift(V2E)(zavgS), sign)
    return pnabla_M / deref(vol)


@fundef
def pnabla(pp, S_MXX, S_MYY, sign, vol):
    return compute_pnabla(pp, S_MXX, sign, vol), compute_pnabla(pp, S_MYY, sign, vol)


@fendef
def nabla(
    n_nodes,
    out_MXX,
    out_MYY,
    pp,
    S_MXX,
    S_MYY,
    sign,
    vol,
):
    closure(
        domain(named_range(Vertex, 0, n_nodes)),
        pnabla,
        [out_MXX, out_MYY],
        [pp, S_MXX, S_MYY, sign, vol],
    )
```

### Classic GT4Py

```python=
V2E = offset("V2E")
E2V = offset("E2V")


@fundef
def compute_zavgS(pp, S_M):
    zavg = 0.5 * sum(pp[v] for v in E2V)
    return S_M * zavg


@fundef
def compute_pnabla(pp, S_M, sign, vol):
    zavgS = compute_zavgS(pp, S_M)
    # pnabla_M = reduce(lambda a, b, c: a + b * c, 0)(shift(V2E)(zavgS), sign)
    pnabla_M = sum(zavgS[e] * sign[e] for e in V2E)
    return pnabla_M / vol


@fundef
def pnabla(pp, S_MXX, S_MYY, sign, vol):
    return compute_pnabla(pp, S_MXX, sign, vol), compute_pnabla(pp, S_MYY, sign, vol)


@fendef
def nabla(
    n_nodes,
    out_MXX,
    out_MYY,
    pp,
    S_MXX,
    S_MYY,
    sign,
    vol,
):
    closure(
        domain(named_range(Vertex, 0, n_nodes)),
        pnabla,
        [out_MXX, out_MYY],
        [pp, S_MXX, S_MYY, sign, vol],
    )
```

### Field view **WIP**

```python
V2E = offset("V2E")
E2V = offset("E2V")


@operator
def compute_zavgS(pp, S_M):
    zavg = 0.5 * sum(pp[v] for v in E2V)
    return S_M * zavg


@operator
def compute_pnabla(pp, S_M, sign, vol):
    zavgS = compute_zavgS(pp, S_M)
    # pnabla_M = reduce(lambda a, b, c: a + b * c, 0)(shift(V2E)(zavgS), sign)
    pnabla_M = sum(zavgS[e] * sign[e] for e in V2E)
    return pnabla_M / vol


@operator
def pnabla(pp, S_MXX, S_MYY, sign, vol):
    return compute_pnabla(pp, S_MXX, sign, vol), compute_pnabla(pp, S_MYY, sign, vol)


@program
def nabla(
    n_nodes,
    out_MXX,
    out_MYY,
    pp,
    S_MXX,
    S_MYY,
    sign,
    vol,
):
    closure(
        domain(named_range(Vertex, 0, n_nodes)),
        pnabla,
        [out_MXX, out_MYY],
        [pp, S_MXX, S_MYY, sign, vol],
    )
```

## Horizontal indirection from ICON

### Raw iterator view

#### Alternative A

Note: `compute_shift` cannot be materialized, because its return value is a function (a shift), not a value.

```python=
I = offset("I")


@fundef
def compute_shift(cond):
    return if_(deref(cond) < 0, shift(I, -1), shift(I, 1))


@fundef
def foo(inp, cond):
    return deref(compute_shift(cond)(inp))


@fendef
def fencil(size, inp, cond, out):
    closure(domain(named_range(IDim, 0, size)), foo, [out], [inp, cond])
```

#### Alternative B

Introduces `dyn_shift(runtime_int, (lower_bound, upper_bound))`, a shift by an offset known only at runtime, where the tuple gives the bounds of that offset.

```python=
I = offset("I")


@fundef
def compute_shift(cond):
    return if_(deref(cond) < 0, -1, 1)


@fundef
def foo(inp, cond):
    return deref(dyn_shift(compute_shift(cond), (-1, 1))(shift(I)(inp)))


@fendef
def fencil(size, inp, cond, out):
    closure(domain(named_range(IDim, 0, size)), foo, [out], [inp, cond])
```

## Maybe: FV3 remapping