###### tags: `frontend`
# Frontend proposals
Let's sketch some examples for the different execution models and for possible beautified frontends.
List of proposals:
**Raw iterator view**: as implemented in the prototype.
Rough idea:
Specification of the computations using the parallel model constructs.
Benefits:
- Full flexibility to express any computation and pattern supported by the model, without the extra work of designing a DSL.
- It works as an IR serialization mechanism.
Problems:
- Not at all user friendly
- Verbose
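To make the raw iterator view more concrete, here is a toy embedded model of `deref`, `shift` and `lift` on a 1D NumPy array. It is a minimal sketch under stated assumptions: an iterator is modeled as a (field, index) pair, offsets are plain integers along a single dimension (so `shift(1)` stands in for `shift(I, 1)`), and `lift` re-evaluates the stencil at the shifted position instead of materializing a temporary. The names `FieldIterator`, `Lifted` and `apply_stencil` are hypothetical; this is not the prototype's implementation.

```python
import numpy as np


class FieldIterator:
    """Toy iterator: a 1D field plus the index it currently points to."""

    def __init__(self, field, pos):
        self.field = field
        self.pos = pos


class Lifted:
    """Toy result of `lift`: a stencil applied to (possibly shifted) iterators."""

    def __init__(self, stencil, args):
        self.stencil = stencil
        self.args = args


def deref(it):
    # Lifted iterators are evaluated on demand instead of being materialized.
    if isinstance(it, Lifted):
        return it.stencil(*it.args)
    return it.field[it.pos]


def shift(offset):
    # Returns a function, like `shift(I, 1)` in the examples (1D, so no axis).
    def apply(it):
        if isinstance(it, Lifted):
            return Lifted(it.stencil, [shift(offset)(a) for a in it.args])
        return FieldIterator(it.field, it.pos + offset)
    return apply


def lift(stencil):
    # Turns a value-returning stencil into an iterator-returning one.
    def apply(*args):
        return Lifted(stencil, list(args))
    return apply


def laplacian(inp):
    return -2.0 * deref(inp) + deref(shift(1)(inp)) + deref(shift(-1)(inp))


def lap_of_lap(inp):
    lap = lift(laplacian)(inp)
    return laplacian(lap)  # shifts and derefs the lifted iterator


def apply_stencil(stencil, fields, start, stop):
    """Toy `closure`: evaluate the stencil at every point of [start, stop)."""
    return np.array(
        [stencil(*(FieldIterator(f, i) for f in fields)) for i in range(start, stop)]
    )
```

For example, `apply_stencil(laplacian, [np.arange(10.0) ** 2], 1, 9)` gives the constant second difference of `i**2`, and applying `lap_of_lap` to the same field gives zeros. Even in this tiny model every access needs an explicit `deref`, which is where the verbosity comes from.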
**Classic GT4Py (or "simplified" iterator view)**: implicit `lift` and `deref` on top of the iterator view
Rough idea:
- `deref` is inserted when a value is expected but an iterator is received, e.g. when an iterator is passed to a builtin that works on values (arithmetic operations, `if`, etc.)
- `lift` is inserted when an iterator is expected but a value is received, e.g. when a shift is applied to a computed value, the producer of that value has to be lifted
- "Beautified" shift syntax with the current GT4Py look-and-feel
- Python `if` is translated to the iterator view ternary (`if_`)
Note that because `lift` and `deref` are implicit, this frontend can express less than the iterator view allows.
Benefits:
- It is usually conceptually simpler to write and reason about code written in a local perspective (operations to perform for each point in the domain)
Problems:
- Doesn't allow direct (embedded) execution
- Weird objects that are iterators and values at the same time
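The "weird objects that are iterators and values at the same time" can be sketched as a wrapper that keeps an iterator internally, supports the beautified `inp[I + 1]` shift syntax through `__getitem__`, and dereferences itself implicitly when used in arithmetic. The `Offset` and `Point` classes, the 1D-only model and the eager dereferencing are illustrative assumptions, not part of any GT4Py API.

```python
import numpy as np


class Offset:
    """Named offset enabling the `I + 1` / `I - 1` shift syntax (1D toy)."""

    def __init__(self, name):
        self.name = name

    def __add__(self, k):
        return (self.name, k)

    def __sub__(self, k):
        return (self.name, -k)


def _val(x):
    """Implicit deref: unwrap a Point, pass plain numbers through."""
    return x.value() if isinstance(x, Point) else x


class Point:
    """Both an iterator (shift with []) and a value (arithmetic derefs it)."""

    def __init__(self, field, pos):
        self.field = field
        self.pos = pos

    def __getitem__(self, offset):
        _name, k = offset  # single dimension: the name is not checked
        return Point(self.field, self.pos + k)

    def value(self):
        return float(self.field[self.pos])

    def __add__(self, other):
        return self.value() + _val(other)

    __radd__ = __add__

    def __sub__(self, other):
        return self.value() - _val(other)

    def __rsub__(self, other):
        return _val(other) - self.value()

    def __mul__(self, other):
        return self.value() * _val(other)

    __rmul__ = __mul__


I = Offset("I")


def laplacian(inp):
    # Same look-and-feel as the classic GT4Py example, restricted to 1D.
    return -2.0 * inp + inp[I + 1] + inp[I - 1]
```

Shifting the *result* of `laplacian` (e.g. `laplacian(inp)[I + 1]`) would fail in this sketch, because the arithmetic already produced a plain float; supporting it is exactly the implicit `lift` the frontend would have to insert.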
**Field view**: use Python array notation (e.g. [NumPy](https://numpy.org/doc/stable/user/quickstart.html#the-basics), [Python array API standard](https://data-apis.org/array-api/latest/), ...)
Rough idea:
Use classic Python array notation to describe operations between fields. This works very well for element-wise operations but less well for more complex operations like conditionals. The syntax for shifting/remapping fields could be defined using the `[]` indexing (`__getitem__()`) operator, but this needs some thought since it can easily get verbose when dealing with boundaries or undefined neighbors.
Benefits:
- Direct execution in Python can be implemented very efficiently without too much effort by relying on existing high-performance array libraries, both on CPUs (e.g. NumPy) and on accelerators (e.g. CuPy, JAX, PyTorch, ...)
- Very familiar syntax for scientists with a Python background, since it matches the syntax of most libraries in the scientific Python ecosystem.
Problems:
- More verbose syntax for non element-wise operations like conditionals, shifting/remapping, ...
- Typical stencil point operations are conceptually harder to think about at the global field level.
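As a rough illustration of the field view style, the 1D Laplacian and flux limiter below are written in plain NumPy: shifts become slices, so every shifted operand has to be trimmed to a common interior, and the conditional becomes `np.where`. This is only a sketch of array notation using real NumPy calls, with the sign conventions of the raw iterator hdiff example; the function names are hypothetical and it does not reflect any proposed GT4Py API.

```python
import numpy as np


def laplacian_1d(inp):
    """Second difference on the interior points inp[1:-1] (output is 2 shorter)."""
    return -2.0 * inp[1:-1] + inp[2:] + inp[:-2]


def limited_flux_1d(inp):
    """flux = lap - lap[+1], zeroed where it has the same sign as inp[+1] - inp."""
    lap = laplacian_1d(inp)       # defined on inp indices 1 .. n-2
    flux = lap[:-1] - lap[1:]     # defined on inp indices 1 .. n-3
    grad = inp[2:-1] - inp[1:-2]  # inp[i+1] - inp[i] for i in 1 .. n-3
    return np.where(flux * grad > 0.0, 0.0, flux)
```

The index bookkeeping in the comments is the kind of verbosity mentioned above for shifts and boundaries.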
## Horizontal diffusion
### Raw iterator view
```python=
I = offset("I")
J = offset("J")

@fundef
def laplacian(inp):
    return -4.0 * deref(inp) + (
        deref(shift(I, 1)(inp))
        + deref(shift(I, -1)(inp))
        + deref(shift(J, 1)(inp))
        + deref(shift(J, -1)(inp))
    )

@fundef
def flux(d):
    def flux_impl(inp):
        lap = lift(laplacian)(inp)
        flux = deref(lap) - deref(shift(d, 1)(lap))
        return if_(flux * (deref(shift(d, 1)(inp)) - deref(inp)) > 0.0, 0.0, flux)

    return flux_impl

@fundef
def hdiff_sten(inp, coeff):
    flx = lift(flux(I))(inp)
    fly = lift(flux(J))(inp)
    return deref(inp) - (
        deref(coeff)
        * (
            deref(flx)
            - deref(shift(I, -1)(flx))
            + deref(fly)
            - deref(shift(J, -1)(fly))
        )
    )

@fendef
def hdiff(inp, coeff, out, x, y):
    closure(
        domain(named_range(IDim, 0, x), named_range(JDim, 0, y)),
        hdiff_sten,
        [out],
        [inp, coeff],
    )
```
### Classic GT4Py
```python=
I = offset("I")
J = offset("J")

@fundef
def laplacian(inp):
    return -4.0 * inp + (inp[I + 1] + inp[I - 1] + inp[J + 1] + inp[J - 1])

@fundef
def flux(d):
    def flux_impl(inp):
        lap = laplacian(inp)  # implicit lift: lap is shifted below
        flux = lap - lap[d + 1]
        # same neighbor (d + 1) as in the raw iterator version
        if flux * (inp[d + 1] - inp) > 0.0:
            return 0.0
        else:
            return flux

    return flux_impl

@fundef
def hdiff_sten(inp, coeff):
    flx = flux(I)(inp)
    fly = flux(J)(inp)
    return inp - (coeff * (flx - flx[I - 1] + fly - fly[J - 1]))

@fendef
def hdiff(inp, coeff, out, x, y):
    closure(
        domain(named_range(IDim, 0, x), named_range(JDim, 0, y)),
        hdiff_sten,
        [out],
        [inp, coeff],
    )
```
### Field view
**WIP**
```python=
from typing import Mapping, Tuple, TypeVar

DType = TypeVar("DType")

# Dimensions
I = CartesianDimension("I")
J = CartesianDimension("J")

@operator
def laplacian(inp: Field[DType]):
    return -4.0 * inp + (inp(I + 1) + inp(I - 1) + inp(J + 1) + inp(J - 1))

@operator
def flux(inp: Field[DType], D: Dimension):
    lap = laplacian(inp)
    flux = lap - lap(D + 1)
    # same neighbor (D + 1) as in the raw iterator version
    return gt.where(flux * (inp(D + 1) - inp) > 0.0, 0.0, flux)

@operator
def hdiff_stencil(inp: Field[DType], coeff: Field[DType]) -> Field[DType]:
    flx = flux(inp, I)
    fly = flux(inp, J)
    return inp - (coeff * (flx - flx(I - 1) + fly - fly(J - 1)))

@program
def hdiff(
    inp: Field[[I, J], DType],
    coeff: Field[[I, J], DType],
    out: Storage[[I, J], DType],
    domain: Mapping[Dimension, Tuple[int, int]],
):
    hdiff_stencil.apply(domain, [inp, coeff], [out])
```
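When comparing the three hdiff variants, a plain NumPy reference of the same computation may be handy. The sketch below follows the sign conventions of the raw iterator version (Laplacian = -4·inp + neighbors, flux = lap - lap[+1], limiter on inp[+1] - inp). As a simplifying assumption it uses periodic boundaries via `np.roll`, so every shifted access stays inside the array; the fencil instead computes on `[0, x) × [0, y)` and expects `inp` to carry a halo. The helper `at` is hypothetical.

```python
import numpy as np


def at(a, di=0, dj=0):
    """Value of `a` at (i + di, j + dj), assuming periodic boundaries."""
    return np.roll(a, shift=(-di, -dj), axis=(0, 1))


def laplacian(inp):
    return -4.0 * inp + (at(inp, 1, 0) + at(inp, -1, 0) + at(inp, 0, 1) + at(inp, 0, -1))


def flux(inp, di, dj):
    lap = laplacian(inp)
    flx = lap - at(lap, di, dj)
    return np.where(flx * (at(inp, di, dj) - inp) > 0.0, 0.0, flx)


def hdiff(inp, coeff):
    flx = flux(inp, 1, 0)  # flux(I)
    fly = flux(inp, 0, 1)  # flux(J)
    return inp - coeff * (flx - at(flx, -1, 0) + fly - at(fly, 0, -1))
```

With periodic boundaries and a constant `coeff`, the flux differences telescope, so `hdiff(inp, coeff).sum()` equals `inp.sum()` up to rounding, which makes a cheap sanity check.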
## Vertical advection
## FVM nabla
### Raw iterator view
```python=
V2E = offset("V2E")
E2V = offset("E2V")

@fundef
def compute_zavgS(pp, S_M):
    # zavg = 0.5 * (deref(shift(E2V, 0)(pp)) + deref(shift(E2V, 1)(pp)))
    zavg = 0.5 * library.sum()(shift(E2V)(pp))
    return deref(S_M) * zavg

@fundef
def compute_pnabla(pp, S_M, sign, vol):
    zavgS = lift(compute_zavgS)(pp, S_M)
    # pnabla_M = reduce(lambda a, b, c: a + b * c, 0)(shift(V2E)(zavgS), sign)
    pnabla_M = library.sum(lambda a, b: a * b)(shift(V2E)(zavgS), sign)
    return pnabla_M / deref(vol)

@fundef
def pnabla(pp, S_MXX, S_MYY, sign, vol):
    return compute_pnabla(pp, S_MXX, sign, vol), compute_pnabla(pp, S_MYY, sign, vol)

@fendef
def nabla(
    n_nodes,
    out_MXX,
    out_MYY,
    pp,
    S_MXX,
    S_MYY,
    sign,
    vol,
):
    closure(
        domain(named_range(Vertex, 0, n_nodes)),
        pnabla,
        [out_MXX, out_MYY],
        [pp, S_MXX, S_MYY, sign, vol],
    )
```
### Classic GT4Py
```python=
V2E = offset("V2E")
E2V = offset("E2V")

@fundef
def compute_zavgS(pp, S_M):
    zavg = 0.5 * sum(pp[v] for v in E2V)
    return S_M * zavg

@fundef
def compute_pnabla(pp, S_M, sign, vol):
    zavgS = compute_zavgS(pp, S_M)
    # pnabla_M = reduce(lambda a, b, c: a + b * c, 0)(shift(V2E)(zavgS), sign)
    pnabla_M = sum(zavgS[e] * sign[e] for e in V2E)
    return pnabla_M / vol

@fundef
def pnabla(pp, S_MXX, S_MYY, sign, vol):
    return compute_pnabla(pp, S_MXX, sign, vol), compute_pnabla(pp, S_MYY, sign, vol)

@fendef
def nabla(
    n_nodes,
    out_MXX,
    out_MYY,
    pp,
    S_MXX,
    S_MYY,
    sign,
    vol,
):
    closure(
        domain(named_range(Vertex, 0, n_nodes)),
        pnabla,
        [out_MXX, out_MYY],
        [pp, S_MXX, S_MYY, sign, vol],
    )
```
### Field view
**WIP**
```python=
V2E = offset("V2E")
E2V = offset("E2V")

@operator
def compute_zavgS(pp, S_M):
    zavg = 0.5 * sum(pp[v] for v in E2V)
    return S_M * zavg

@operator
def compute_pnabla(pp, S_M, sign, vol):
    zavgS = compute_zavgS(pp, S_M)
    # pnabla_M = reduce(lambda a, b, c: a + b * c, 0)(shift(V2E)(zavgS), sign)
    pnabla_M = sum(zavgS[e] * sign[e] for e in V2E)
    return pnabla_M / vol

@operator
def pnabla(pp, S_MXX, S_MYY, sign, vol):
    return compute_pnabla(pp, S_MXX, sign, vol), compute_pnabla(pp, S_MYY, sign, vol)

@program
def nabla(
    n_nodes,
    out_MXX,
    out_MYY,
    pp,
    S_MXX,
    S_MYY,
    sign,
    vol,
):
    closure(
        domain(named_range(Vertex, 0, n_nodes)),
        pnabla,
        [out_MXX, out_MYY],
        [pp, S_MXX, S_MYY, sign, vol],
    )
```
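In plain array notation, the neighbor sums of the nabla example map onto connectivity tables and NumPy fancy indexing. This is a sketch under assumptions: `e2v` is a hypothetical `(n_edges, 2)` integer table, `v2e` a hypothetical `(n_vertices, max_neighbors)` table in which every vertex has exactly `max_neighbors` edges (no padding for missing neighbors), and `sign` is stored per vertex-edge pair with the same shape as `v2e`.

```python
import numpy as np


def compute_zavgS(pp, S_M, e2v):
    # pp[e2v] has shape (n_edges, 2): the two vertex values of every edge.
    zavg = 0.5 * pp[e2v].sum(axis=1)
    return S_M * zavg


def compute_pnabla(pp, S_M, sign, vol, e2v, v2e):
    zavgS = compute_zavgS(pp, S_M, e2v)
    # zavgS[v2e] has shape (n_vertices, max_neighbors), matching `sign`.
    pnabla_M = (zavgS[v2e] * sign).sum(axis=1)
    return pnabla_M / vol


def nabla(pp, S_MXX, S_MYY, sign, vol, e2v, v2e):
    return (
        compute_pnabla(pp, S_MXX, sign, vol, e2v, v2e),
        compute_pnabla(pp, S_MYY, sign, vol, e2v, v2e),
    )
```

For instance, on a ring of three vertices with `e2v = [[0, 1], [1, 2], [2, 0]]` and `v2e = [[2, 0], [0, 1], [1, 2]]`, a constant `pp` gives a zero gradient whenever each row of `sign` sums to zero. Handling irregular meshes would need a sentinel for missing neighbors, which is one of the verbosity issues noted for the field view.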
## Horizontal indirection from ICON
### Raw iterator view
#### Alternative A
Note: `compute_shift` cannot be materialized, because its return value is a function (a shift), not a value.
```python=
I = offset("I")

@fundef
def compute_shift(cond):
    return if_(deref(cond) < 0, shift(I, -1), shift(I, 1))

@fundef
def foo(inp, cond):
    return deref(compute_shift(cond)(inp))

@fendef
def fencil(size, inp, cond, out):
    closure(domain(named_range(IDim, 0, size)), foo, [out], [inp, cond])
```
#### Alternative B
Introduce `dyn_shift(runtime_int, (lower_bound, upper_bound))`: a shift by an offset known only at run time, where the bounds state the range that offset may take.
```python=
I = offset("I")

@fundef
def compute_shift(cond):
    return if_(deref(cond) < 0, -1, 1)

@fundef
def foo(inp, cond):
    return deref(dyn_shift(compute_shift(cond), (-1, 1))(shift(I)(inp)))

@fendef
def fencil(size, inp, cond, out):
    closure(domain(named_range(IDim, 0, size)), foo, [out], [inp, cond])
```
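The intended semantics of both alternatives (`out[i] = inp[i - 1]` where `cond[i] < 0`, else `inp[i + 1]`) can be sketched in NumPy with a run-time offset array. The explicit bounds check mirrors what the `(lower_bound, upper_bound)` argument of `dyn_shift` would presumably let a backend know ahead of time, namely how much halo `inp` needs. The helper name `dyn_shift_apply` and the layout (`inp` carries `-lower_bound` halo points before the domain and `upper_bound` after it) are assumptions for illustration.

```python
import numpy as np


def dyn_shift_apply(inp, offsets, bounds):
    """Return inp[i + offsets[i]] for every domain point i (hypothetical helper).

    `inp` is assumed to carry a halo of -bounds[0] points before and
    bounds[1] points after the domain, so domain point i is inp[i - bounds[0]].
    """
    lower, upper = bounds
    offsets = np.asarray(offsets)
    if offsets.min() < lower or offsets.max() > upper:
        raise ValueError("run-time offset outside the declared bounds")
    i = np.arange(offsets.size)
    return inp[i - lower + offsets]


def foo(inp, cond):
    offsets = np.where(cond < 0, -1, 1)  # plays the role of compute_shift
    return dyn_shift_apply(inp, offsets, (-1, 1))
```

Because the bounds are static, a backend could size the halo (here one point on each side) without inspecting `cond`, whereas in Alternative A the set of possible shifts is only visible through the branches of `if_`.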
## Maybe: FV3 remapping