###### tags: `frontend`, `shaping`
GT4Py frontend design (Main)
============================
[TOC]
# Reference documents
## Frontend
- [Frontend proposals](https://hackmd.io/IgPQAkoTSEyGlEfY1H4GRw)
- [Frontend proposal: Field view with local operations](https://hackmd.io/NCbpdFt7SfudVQOvF5p5yA)
- [A Frontend Specification for the Functional GT4Py Model (Outdated)](https://hackmd.io/6t7aobS2QWK1I92uwfI79w?view)
## Functional model
- [New model project](https://hackmd.io/rT-Uwd2ASq6-h5rdlovZ_g)
### Iterator view
- [Iterator view model description](https://github.com/GridTools/concepts/blob/master/Iterator-View.md)
- [Iterator view model description (legacy)](https://hackmd.io/@gridtools/B1Yh6B8cd)
- [Iterator view presentation](https://docs.google.com/presentation/d/1QjGZ3V8YIeOESTnp7bwF1AAate5v47Tc4IoVnqyDH1s/edit?usp=sharing)
### Field view
- [Field view (without sizes)](https://hackmd.io/@havogt/ryaxJKX5_)
- [Mixed / Till Model](https://hackmd.io/8WXjrwEGQaabUnqMFtvg5A?both)
- [WIP Field view design notes](https://hackmd.io/QqcxL37aSTaIw28Wyyv0rg)
### Cartesian accessor view ("First iteration")
- [Cartesian accessor view (first "new model")](https://hackmd.io/@gridtools/SkEKFlGEO)
- [Repository with toy implementations](https://github.com/fthaler/gt4py_new_model): `main` contains the accessor view, then several branches with slightly different models for unstructured.
## Classic GT4Py model
- [Problems with the old models](https://hackmd.io/@havogt/rk_ELMUq_)
## Scientific Python ecosystem
- [Scientific Python Ecosystem Coordination (SPEC)](https://scientific-python.org/)
- [Python array API standard: A common API for array and tensor Python libraries](https://data-apis.org/array-api/latest/#)
# Meetings
## 2021-10-25
### Brainstorm
- Embedded field execution
- We have embedded iterator execution (https://github.com/GridTools/gt4py/pull/482)
- Plain iterator tracing (see PR above)
- Trace or parse?
- Beautified iterator frontend
- Lifted iterator / field view frontend
- Numba
In which order do we implement frontend styles?
Do we target embedded or IR generation first?
Which technology for IR generation: tracing or parsing?
### Conclusion:
- Implement field-view style frontend using parsing with all the features (including `scan()`, etc.)
- Use current gt4py repo (`iterator_prototype`? branch)
- Use Python 3.10
- Take a look at Till's previous unstructured prototype for the Python-to-first-IR step (https://github.com/GridTools/gt4py/blob/master/src/gtc_unstructured/frontend/py_to_gtscript.py)
## 2021-08-30
### Agenda
1. Review conclusions about conditionals from previous meeting
2. Address model questions
3. Review and discuss the approaches of two basic proposals (local and positional)
4. Introduction and discussion of `shifts`
### Minutes
- Enrique:
+ write examples to explain why lift should be explicit
+ write a clear syntax proposal for shifting in cartesian and unstructured
## 2021-08-26 Discussion meeting
### Agenda
1. Quick review of the new model understanding by looking at the mathematical definitions of the main concepts
2. Continue discussion of the semi-formal definition of the proposed interfaces ([A Frontend Specification for the Functional GT4Py Model](https://hackmd.io/@gridtools/HJxTggueY/edit)), focusing especially on the following aspects:
- Mapping of `if`s/conditionals to ternary operators (restrictions and algorithm)
- Evaluate the proposed algorithm in potentially confusing cases (like the new ICON examples provided by Ben)
- (Maybe: review ideas for the injection/access of external symbols within the DSL)
### Minutes
- Till + EGP: Remove integer requirement for indices
- Till: clean up conditionals
- Ben: Add sentences about _symbol definedness_ algorithm
## 2021-08-24 Discussion meeting
### Notes
- Neighbor table contents in ICON are only defined at setup, hence at runtime. Only the rank / dimension of the neighbor tables is known in advance.
### Agenda
1. Present documents with semi-formal definition of the proposed interfaces: [A Frontend Specification for the Functional GT4Py Model](https://hackmd.io/@gridtools/HJxTggueY/edit)
- Discuss general structure/DSL workflow
- Discuss general approach/interface to introduce external symbols in the DSL (imports, functions or implicit?)
- Collect list of open questions for both syntax and concepts
- Collect list of issues for the model IR spec document
### Minutes
Action points:
- Till: Fix externals, globals example in frontend spec
- Ben: Collect examples of tricky `if` statements with assignments
## 2021-08-17 Discussion meeting
### Agenda
1. Reevaluate/clarify project goals
2. Discuss examples
- [simple examples](#Simple-stencil-operations)
### Minutes
...
## 2021-08-05 Kick-off meeting
### Agenda
1. Refine agenda
2. Organization
- Plan for the first 4 weeks of the cycle
1. Definition of goals, features, test cases for the evaluation, organization of work and any other initial discussions
2. First iteration of proposed frontends: discussions and feedback based on the defined test cases. Implementation of the refined proposals.
3. Second iteration of proposed frontends: discussion and feedback from external collaborators, scientists and other potential users. Further refinements.
4. Design and discussion of the technical aspects of the final implementation: AST visitors vs tracing, architectural design, etc. Write a draft for the shaped project.
- Sync meetings:
+ How often?
+ On demand possible?
3. Discuss design goals (see [here](#Goals))
### Minutes
- Action points: collect code snippets covering all the patterns we would like to express with the frontend
+ Ben: ICON, Dusk or artificial examples
+ Johann: column-based examples ?
+ Till, Johann, Enrique: collect general examples
# Design
## Goals
- Primary goals:
+ Simple programming model: the frontend language should be easy to understand and explain to scientific users.
+ Development friendly (debuggability): the user should have simple means to debug the algorithms being implemented. Ideally, the frontend language is semantically valid Python code that can be executed and debugged with regular Python tools (`print()`, `pdb`, etc.). We call this approach **embedded** execution; it must meet some minimum performance requirements to be useful at scale.
+ Expressive and sound mapping to Iterator IR model
- Secondary questions:
+ Fits well with scientific Python programming conventions
+ Similarity with other existing programming conventions: in general it is easier to explain languages using familiar programming conventions (although there are often contradictory conventions in different languages and/or domains).
+ Extensibility and interoperability: the frontend language should not prevent extensions to the language itself or its builtins, nor the definition of low-level interfaces for adding user-written functions to a program.
- Open questions:
+ Starting with cartesian or unstructured dialects ?
+ Multiple frontends with different tradeoffs ?
+ Different languages at different levels:
* 1 -> Fencil/control flow level
* 2 -> Stencil/kernel level
* 1.5-> Operator level ???
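
To make the **embedded** execution goal listed under the primary goals more concrete, here is a minimal sketch in plain Python with NumPy. The names `average_with_right` and `apply_over` are hypothetical and not part of GT4Py; the sketch only illustrates that a stencil written as ordinary Python can be run point by point and inspected with `print()` or `pdb`.

```python
import numpy as np


def average_with_right(inp, *, position):
    # Hypothetical local operator: ordinary Python, so print()/pdb work here.
    (i,) = position
    return 0.5 * (inp[i] + inp[i + 1])


def apply_over(stencil, domain, *inputs, out):
    # Naive embedded execution: call the stencil once per point of the domain.
    for i in domain:
        out[i] = stencil(*inputs, position=(i,))
    return out


inp = np.arange(5, dtype=float)  # [0, 1, 2, 3, 4]
out = np.zeros(4)
apply_over(average_with_right, range(4), inp, out=out)
```

A point-by-point loop like this is easy to debug but slow, which is why the goal mentions minimum performance requirements.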
## Features
Coarse overview of the features we want to cover:
- Arithmetic operations on fields
- Shift-like / neighborhood chains (multiple locations)
- Reductions
- ~~Runtime indirections~~<br>
*If the condition that decides which neighbor is accessed is itself part of the stencil, this case can be avoided (see [here](https://github.com/tehrengruber/gt4py_functional_frontend/blob/master/tests/iterator_tests/test_horizontal_indirection.py) for an example).*
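
As an illustration of the remark on runtime indirections, the following NumPy sketch selects between two neighbors with fixed offsets instead of reading a neighbor index computed at runtime. The helper `shift_i` and the boundary clamping are assumptions made only for this example; the linked test may be structured differently.

```python
import numpy as np


def shift_i(field, offset):
    # Hypothetical helper: value at i + offset, clamped at the boundaries.
    idx = np.clip(np.arange(field.size) + offset, 0, field.size - 1)
    return field[idx]


def select_neighbor(inp, cond):
    # Both candidate neighbors use statically known offsets; the condition
    # only selects between the two already-shifted values.
    return np.where(cond, shift_i(inp, -1), shift_i(inp, 1))


inp = np.array([10.0, 20.0, 30.0, 40.0])
cond = np.array([True, False, True, False])
result = select_neighbor(inp, cond)
```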
## Evaluation test cases
**IMPORTANT**: the categorized and up-to-date collection of test cases is at: [GT4Py Frontend design - Evaluation test cases](https://hackmd.io/@gridtools/r13huPBHF).
A collection of code snippets to cover all patterns we would like to express in the frontend. The snippets should be expressed using the *semantic IR* as well as the proposed frontend syntax flavors.
__[WIP]__ This section is work-in-progress. We are not trying to design "the" frontend, but a user-friendly (see [goals](#Goals)) proposal that maps easily to the *semantic IR*.
~~A collection of code snippets covering all patterns we would like to express with the frontend. The snippets should be expressed as regular Python code with NumPy arrays as storages and later serve as tests for experimental implementations.~~
Existing snippets:
- https://github.com/tehrengruber/gt4py_functional_frontend/tree/master/tests/iterator_tests
- [Collection from last cycle](https://hackmd.io/rT-Uwd2ASq6-h5rdlovZ_g#Use-cases)
- [Classic unstructured prototype unit-tests](https://github.com/GridTools/gt4py/blob/master/tests/gtc_unstructured_tests/unit_tests/stencil_definitions.py)
### Simple stencil operations
**General program structure**
```python=
T = TypeVar("T")
OneOrMore = Union[T, Tuple[T, ...]]
DT = TypeVar("DT", bound=gt.ScalarType)
# Field typing conventions
field: Field[[*Axes], DT]
# Axes names: V or Vertex, V2V or VV or V_V, ...
# Decorator to define local operators (TODO: define options)
@gt.local_operator(...)
def stencil(...):
return ...
# Feature for debugging & testing: direct execution of stencils
stencil[my_domain](...)
# TODO(tehrengruber): tentative
# mapping from axis to connectivity / neighbor table
# user can provide something different (e.g. diamond connectivity)
connectivities = {
# Connectivity[From, To, max_neighbors, has_skip_values]
E2V: Connectivity[E, V, 2, False](range(0, 10), [[0, 1], ...])
# ...
}
# Decorator to define 'fencils', which are entry points/compiled units
# of the generated code from Python
@gt.program(connectivities=connectivities, ...)
def fencil_def(
field_a: Field[[V], DT],
field_b: Field[[V], DT],
out_field: Field[[V], DT],
):
vertices = gt.IndexRange[Vertex](3, 10)
# Closure syntax:
# stencil[domain](*inputs, out=[outputs]) -->
# apply_stencil(domain, stencil, [inputs], [outputs])
stencil[vertices](field_a, field_b, out=[out_field])
```
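
The closure syntax `stencil[domain](*inputs, out=[outputs])` in the comment above can be mimicked in plain Python with `__getitem__`. The sketch below is only one possible way to do it: the `LocalOperator` class and this `apply_stencil` signature are hypothetical stand-ins, and a `range` takes the place of `gt.IndexRange`.

```python
import numpy as np


def apply_stencil(domain, stencil, inputs, outputs):
    # Evaluate the stencil at every position and write into the first output.
    for pos in domain:
        outputs[0][pos] = stencil(*inputs, position=pos)


class LocalOperator:
    """Hypothetical wrapper that gives a function the `op[domain](...)` call style."""

    def __init__(self, fn):
        self.fn = fn

    def __getitem__(self, domain):
        def bound(*inputs, out):
            apply_stencil(domain, self.fn, list(inputs), out)

        return bound


@LocalOperator
def add(a, b, *, position):
    return a[position] + b[position]


field_a = np.arange(10.0)
field_b = np.ones(10)
out_field = np.zeros(10)
vertices = range(3, 10)  # stands in for gt.IndexRange[Vertex](3, 10)
add[vertices](field_a, field_b, out=[out_field])
```

Only positions 3 to 9 are written; the rest of `out_field` keeps its initial value.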
1. **Regular horizontal stencil**
+ _Semantic Model IR:_
```python=
@fundef
def sum_of_all_vertex_neighbors_of_a_vertex(inp):
# length(V2V) == 4
return (
deref(shift(V2V, 0)(inp))
+ deref(shift(V2V, 1)(inp))
+ deref(shift(V2V, 2)(inp))
+ deref(shift(V2V, 3)(inp))
)
```
+ _Lifted iterator view:_
```python=
def sum_of_all_vertex_neighbors_of_a_vertex(inp):
# length(V2V) == 4
return (
inp(V2V[0]) + inp(V2V[1]) + inp(V2V[2]) + inp(V2V[3])
) # It would be an out-of-bounds error if `V2V[i]` is not pointing to a valid position
def sum_of_all_vertex_neighbors_of_a_vertex(inp):
# not all length(V2V) have to be the same
return sum(inp(v) for v in V2V)
```
+ _Indexing proposal:_
```python=
@gt.local_operator
def sum_of_all_vertex_neighbors_of_a_vertex(
    inp: Field[[V], DT], *, position: Position[V]
):
    (v,) = position
    v1, v2, v3, v4 = neighbors(v, V2V)
    return inp[v1] + inp[v2] + inp[v3] + inp[v4]
    # return sum(inp[v] for v in neighbors(v, V2V))
```
+ _Transforming views proposal:_
```python=
@gt.local_operator # Set domain kind: @gt.local_operator(domain=[V]) ??
def sum_of_all_vertex_neighbors_of_a_vertex(
inp: LocatedView[[V], [V], DT] #, *, position: Position[V] --> Strictly optional
) -> DT:
V2V: Connectivity[[V], [(V,) * VV]] = gt.connectivities("V2V") # length(*OutAxes) == VV
return inp(V2V(0)) + inp(V2V(1)) + inp(V2V(2)) + inp(V2V(3))
# alternative:
# v1, v2, v3, v4 = inp(v2v)
# return v1 + v2 + v3 + v4
```
Explanation:
```python=
class Connectivity(Protocol[[*InAxes], [*OutAxes]]):
    def __call__(self, offset: int) -> Connectivity:
        ...

    def __getitem__(self, position: Position[*InAxes]) -> Position[*OutAxes]:
        ...

e2v: Connectivity[Index[E], [Index[V], ...]]  # length(*OutAxes) == EV

class LocatedView(Protocol[Position[*PosAxes], [*FieldAxes], DT]):
    position: Position[*PosAxes]
    field: Field[[*FieldAxes], DT]

    def __call__(self, connectivity: Connectivity) -> LocatedView:
        ...  # return a new view with self.position mapped through the connectivity

    @property
    def __value__(self):
        return self.field[self.position]

CoLocatedView = LocatedView[[*Axes], [*Axes], DT]
```
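
For reference, the regular horizontal stencil of this item can be written directly with NumPy. This is a sketch under assumptions: the neighbor table `v2v` below is made up (every vertex has exactly 4 neighbors), and column `i` of the table plays the role of `shift(V2V, i)`.

```python
import numpy as np

# Hypothetical V2V table: 5 vertices, each connected to the 4 others.
v2v = np.array(
    [
        [1, 2, 3, 4],
        [0, 2, 3, 4],
        [0, 1, 3, 4],
        [0, 1, 2, 4],
        [0, 1, 2, 3],
    ]
)
inp = np.array([1.0, 2.0, 4.0, 8.0, 16.0])


def sum_of_all_vertex_neighbors_of_a_vertex(inp, v2v):
    # inp[v2v[:, i]] gathers, for every vertex, the value at its i-th neighbor.
    return inp[v2v[:, 0]] + inp[v2v[:, 1]] + inp[v2v[:, 2]] + inp[v2v[:, 3]]


out = sum_of_all_vertex_neighbors_of_a_vertex(inp, v2v)
```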
2. **Regular horizontal stencil with shifting**
+ _Semantic Model IR:_
```python=
@fundef
def v2e2cell_stencil(inp):
return deref(shift(V2E, 0)(shift(E2C, 0)(inp)))
```
+ _Indexing proposal:_
```python=
@gt.local_operator
def v2e2cell_stencil(inp: Field[[C], DT], *, position: Position[V]):
(v,) = position
c = neighbor(v, (V2E, 0), (E2C, 0))
return inp[c]
```
+ _Lifted iterator view:_
```python=
@fundef
def v2e2cell_stencil(inp: View[Cell]):
tmp1: View[Edge] = inp(E2C[0])
tmp2: View[Vertex] = tmp1(V2E[0])
# todo: discuss
# inp(V2E[0], E2C[0])
return tmp2
```
+ _Transforming views proposal:_
```python=
@gt.local_operator
def v2e2cell_stencil(inp: LocatedView[[V], [C], DT]) -> DT:
#V2E, E2C = gt.connectivities("V2E", "E2C")
return inp(V2E(0), E2C(0))
```
3. **Regular horizontal stencil with shifting and lifting**
+ _Semantic Model IR:_
```python=
# input is an iterator with E-position and a buffer defined on C-positions
@fundef
def first_cell_nb(inp):
    # shift(E2C, 0)(inp) returns an iterator with
    # C-position and a buffer defined on C-positions
    return deref(shift(E2C, 0)(inp))

@fundef
def v2e2cell_stencil(inp):
    # lift(first_cell_nb) takes an
    # iterator with E-positions and a buffer defined on C-positions
    # and returns an iterator with
    # C-positions and the same buffer
    return deref(lift(first_cell_nb)(shift(V2E, 0)(inp)))
    # alternative:
    return deref(shift(V2E, 0)(lift(first_cell_nb)(inp)))
```
+ _Lifted iterator view:_
```python=
def first_cell_nb(inp):
return inp(E2C[0])
def v2e2cell_stencil(inp):
return first_cell_nb(inp(V2E[0]))
return first_cell_nb(inp)(V2E[0])
```
+ _Indexing proposal:_
```python=
@gt.local_operator
def first_cell_nb(inp, *, position: Position[E]):
(e,) = position
c = neighbor(e, (E2C, 0))
return inp[c]
@gt.local_operator
def v2e2cell_stencil(inp, *, position: Position[V]):
(v,) = position
# call directly
return first_cell_nb(inp, position=neighbor(v, (V2E, 0)))
# lift
tmp: Field[[E], DT] = first_cell_nb[...](inp)
return tmp[neighbor(v, (V2E, 0))]
```
+ _Transforming views proposal:_
```python=
@gt.local_operator
def first_cell_nb(inp: LocatedView[[E], [C], DT]) -> DT:
return inp(E2C(0))
@gt.local_operator
def v2e2cell_stencil(inp: LocatedView[[V], [C], DT]) -> DT:
tmp_on_cells: LocatedView[[C], [C], DT] = first_cell_nb[...](inp(V2E(0)))
return tmp_on_cells ## tmp_on_cells(0) --> explicit deref() ??
```
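
The interplay of `shift`, `deref` and `lift` in this item can be explored with a toy model in plain Python. Everything below is a hypothetical sketch, not the actual iterator implementation: the `Iterator` class, the `TABLES` dictionary and the connectivity contents are assumptions chosen to keep the example small.

```python
from dataclasses import dataclass
from typing import Callable

# Hypothetical connectivities: V2E maps a vertex to edges, E2C an edge to cells.
TABLES = {"V2E": [[2, 0], [1, 2]], "E2C": [[0, 1], [1, 2], [2, 0]]}


@dataclass(frozen=True)
class Iterator:
    """A position plus a function that yields a value at a given position."""

    pos: int
    get: Callable[[int], float]

    def shift(self, conn, i):
        return Iterator(TABLES[conn][self.pos][i], self.get)

    def deref(self):
        return self.get(self.pos)


def lift(stencil):
    # lift(stencil)(it) is an iterator whose deref runs `stencil` at the
    # position where it is eventually dereferenced.
    def lifted(it):
        return Iterator(it.pos, lambda p: stencil(Iterator(p, it.get)))

    return lifted


def first_cell_nb(inp):
    return inp.shift("E2C", 0).deref()


def v2e2cell_shift_then_lift(inp):
    return lift(first_cell_nb)(inp.shift("V2E", 0)).deref()


def v2e2cell_lift_then_shift(inp):
    return lift(first_cell_nb)(inp).shift("V2E", 0).deref()


cells = [10.0, 20.0, 30.0]
results_a = [v2e2cell_shift_then_lift(Iterator(v, cells.__getitem__)) for v in (0, 1)]
results_b = [v2e2cell_lift_then_shift(Iterator(v, cells.__getitem__)) for v in (0, 1)]
```

In this toy model both orderings of `lift` and `shift` produce the same values, which mirrors the two `return` variants shown in the Semantic Model IR snippet.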
4. Regular horizontal stencil with reduction
+ _Semantic Model IR:_
```python=
@fundef
def sum_of_all_vertex_neighbors_of_a_vertex(inp):
return reduce(lambda a, b: a+b, shift(V2V)(inp))
```
+ _Lifted iterator view:_
```python=
def sum_of_all_vertex_neighbors_of_a_vertex(inp):
    return sum(inp(v) for v in V2V)
```
+ _Indexing proposal:_
```python=
@gt.local_operator
def sum_of_all_vertex_neighbors_of_a_vertex(inp: Field[[V], DT], *, position: Position[V]):
(v,) = position
# reduce
return reduce(lambda a, b: a+b, inp[neighbors(v, V2V)])
# sum + slicing
return sum(inp[neighbors(v, V2V)])
# list compr + sum + slicing
return sum(inp[v2] for v2 in neighbors(v, V2V))
```
+ _Transforming views proposal:_
```python=
@gt.local_operator
def sum_of_all_vertex_neighbors_of_a_vertex(inp: LocatedView[[V], [V], DT]) -> DT:
v_neighbors: LocatedView[[V, VV]] = v2v_conn(inp)
# reduce
return reduce(lambda a, b: a + b, inp(V2V), init=0)
# sum + slicing
return sum(inp(V2V))
# list compr + sum + slicing
return sum(neighbor_v for neighbor_v in inp(V2V))
```
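
When vertices have different numbers of neighbors, a reduction has to ignore missing entries. The sketch below shows one way to do that with NumPy and `functools.reduce`; the `SKIP` marker, the table contents and the function name `neighbor_sum` are assumptions for illustration only.

```python
from functools import reduce

import numpy as np

SKIP = -1  # hypothetical marker for a missing neighbor

v2v = np.array([[1, 2, SKIP], [0, 2, 3], [0, 1, SKIP], [1, SKIP, SKIP]])
inp = np.array([1.0, 2.0, 4.0, 8.0])


def neighbor_sum(inp, table):
    def add_column(acc, col):
        valid = col != SKIP
        # Gather with a safe index (0) where the neighbor is missing,
        # then mask the gathered value out again.
        return acc + np.where(valid, inp[np.where(valid, col, 0)], 0.0)

    # Iterating over table.T visits one neighbor slot (column) at a time.
    return reduce(add_column, table.T, np.zeros(table.shape[0]))


out = neighbor_sum(inp, v2v)
```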
5. Sparse-field
TODO (tehrengruber): How to select `inp[neighbor(v, V2V, 1), 1]` in the semantic model?
+ _Semantic Model IR:_
```python=
@fundef
def sum_of_all_vertex_neighbors_of_a_vertex(inp: Iterator[[V, ~V2V], [V, V2E]]: Iterator[[V], Field[V, DT]]):
tmp = shift(V2V, 0)(inp)
return deref(tmp)
#return (
# deref(shift(0)(inp))
# + deref(shift(1)(inp))
# + deref(shift(2)(inp))
# + deref(shift(3)(inp))
#)
```
+ _Indexing proposal:_
```python=
@gt.local_operator
def sum_of_all_vertex_neighbors_of_a_vertex(inp: Field[[V, V2V], DT], *, position: Position[V]):
(v,) = position
return sum(inp[v, :])
return inp[v, 0]+inp[v, 1]+inp[v, 2]+inp[v, 3]
```
+ _Transforming views proposal:_
```python=
@gt.local_operator
def sum_of_all_vertex_neighbors_of_a_vertex(inp: LocatedView[[V], [V, V2V], DT]) -> DT:
    # a: LocatedView[[V, V2V], [V, ~V2V]] = inp(None)
    return inp[0] + inp[1] + inp[2] + inp[3]
    # TODO: define and discuss the issue of dereferencing iterators/views
    # with indices of lower dimensionality
```
6. Nabla
+ _Semantic Model IR:_
```python=
@fundef
def compute_zavgS(pp, S_M):
zavg = 0.5 * (deref(shift(E2V, 0)(pp)) + deref(shift(E2V, 1)(pp)))
# zavg = 0.5 * reduce(lambda a, b: a + b, 0)(shift(E2V)(pp))
# zavg = 0.5 * library.sum()(shift(E2V)(pp))
return deref(S_M) * zavg
@fundef
def compute_pnabla(pp:Iterator[V], S_M: Iterator[], sign, vol):
zavgS = lift(compute_zavgS)(pp, S_M)
# pnabla_M = reduce(lambda a, b, c: a + b * c, 0)(shift(V2E)(zavgS), sign)
# pnabla_M = library.sum(lambda a, b: a * b)(shift(V2E)(zavgS), sign)
pnabla_M = library.dot(shift(V2E)(zavgS), sign)
return pnabla_M / deref(vol)
# dot impl with zip?
@fundef
def zip_kern(a, b):
return (deref(a), deref(b))
@fundef
def zip(a, b):
return shift(V2V)(lift(zip_kern(a, b)))
@fundef
def dot(a, b):
return reduce(lambda prev, c: prev+c[0]*c[1], zip(a, b), 0)
```
+ _Lifted iterator view:_
```python=
def compute_zavgS(pp: View[Vertex], S_M: View[Edge]) -> View[Edge]:
zavg: View[Edge] = 0.5 * (pp(E2V[0]) + pp(E2V[1]))
# zavg = lift(lambda a, b: 0.5 * (deref(shift(E2V, 0)(a)) + deref(shift(E2V, 1)(a))))(pp)
return S_M * zavg
# return lift(lambda a,b: deref(a)*deref(b))(S_M, zavg)
def compute_pnabla(pp: View[Vertex], S_M: View[Edge], sign, vol: View[Vertex]) -> View[Vertex]:
    zavgS: View[Edge] = compute_zavgS(pp, S_M)
    pnabla_M: View[Vertex] = sum(
        zavgS(v) * sign[i] for i, v in enumerate(V2E)
    )  # TODO I don't like addressing the sign field with `i`, dusk solves this by imprecise syntax
    # pnabla_M = lift(reduce( lambda a, b, c: a+b*c, 0))(shift(V2E)(zavgS), sign)
    return pnabla_M / vol
    # return lift(lambda a,b: deref(a)/deref(b))(pnabla_M, vol)
```
+ _Field view:_
```python=
@field_operator
def compute_zavgS(e2v, pp, S_M):
zavg = 0.5 * gt.sum(pp[e2v], axis="E2V")
return S_M * zavg
@field_operator
def compute_pnabla(v2e, e2v, pp, S_M, sign, vol):
    zavgS = compute_zavgS(e2v, pp, S_M)
    pnabla_M = gt.sum(zavgS[v2e] * sign, axis="V2E")
    out = gt.if_(pp > 6 and S_M < 4, pp + 4, pp - S_M)
    return pnabla_M / vol
```
+ _Indexing proposal:_
```python=
@local_operator
def compute_zavgS(pp: Field[[V, K], DT], S_M: Field[[E], DT], *, position: Position[E, K]):
    (e, k) = position
    v1, v2 = neighbors(e, E2V)  # or short-hand: vertices(e)
    zavg = 0.5 * (pp[v1, k] + pp[v2, k])
    return S_M[e] * zavg
@local_operator
def dot(a: Field[LocalField, DT], b: Field[LocalField, DT]):
return sum(ai*bi for ai, bi in zip(a, b))
@local_operator
def compute_pnabla(pp: Field[[V], DT], S_M: Field[[E], DT], sign: Field[[V, V2E], DT], vol: Field[[V], DT], *, position: Position[V]):
(v,) = position
zavgS = compute_zavgS[...](pp, S_M)
pnabla_M = dot(zavgS[neighbors(v, V2E), k], sign[v, :])
return pnabla_M / vol[v]
# todo: discuss with EGP
#zavgS_on_nb_edges: Field[[V2E], DT] = zavgS[v2e[v, :]]
```
+ _Transforming views proposal:_
```python=
#TODO: Fix the lifted part
# @local_operator
# def compute_zavgS(
# pp: LocatedView[[E, K], [V, K], DT], S_M: LocatedView[[V, K], [E], DT]
# ) -> DT:
#     zavg = 0.5 * (pp(E2V(0)) + pp(E2V(1)))
# return S_M * zavg
# @local_operator
# def compute_pnabla(
# pp: LocatedView[[V, K], [V, K], DT],
# S_M: LocatedView[[V, K], [E], DT],
# sign: LocatedView[[V, K], [V, V2E], DT],
# vol: LocatedView[[V, K], [V], DT],
# ):
# zavgS = compute_zavgS[...](pp, S_M)
# pnabla_M = dot(zavgS(E2V), sign[v, :])
# return pnabla_M / vol[v]
```
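
A NumPy reference version of the nabla example can serve as a test oracle for the proposals above. It is a sketch under assumptions: the tiny mesh (3 vertices, 3 edges), the connectivity tables and all field values are made up, and the K dimension is left out.

```python
import numpy as np

# Hypothetical mesh: a single triangle with 3 vertices and 3 edges.
e2v = np.array([[0, 1], [1, 2], [2, 0]])  # edge -> its 2 vertices
v2e = np.array([[0, 2], [0, 1], [1, 2]])  # vertex -> its 2 edges

pp = np.array([1.0, 2.0, 3.0])  # defined on vertices
S_M = np.array([2.0, 2.0, 2.0])  # defined on edges
sign = np.array([[1.0, -1.0], [-1.0, 1.0], [-1.0, 1.0]])  # on (vertex, V2E)
vol = np.array([1.0, 2.0, 4.0])  # defined on vertices


def compute_zavgS(pp, S_M):
    # Average pp over the two vertices of each edge, then scale by S_M.
    zavg = 0.5 * pp[e2v].sum(axis=1)
    return S_M * zavg


def compute_pnabla(pp, S_M, sign, vol):
    zavgS = compute_zavgS(pp, S_M)
    # Signed sum of the edge values over the edges adjacent to each vertex.
    pnabla_M = (zavgS[v2e] * sign).sum(axis=1)
    return pnabla_M / vol


zavgS = compute_zavgS(pp, S_M)
pnabla = compute_pnabla(pp, S_M, sign, vol)
```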
7. [WIP] ICON reductions & dusk/dawn weights
Similar examples exist in ICON
+ dusk syntax:
```python=
@stencil
def reductions_and_weights(
edges_a: Field[Edge],
cells_a: Field[Cell],
vertices_a: Field[Vertex],
):
with domain.upward:
# These examples all involve the following two triangle shape:
# (we call this the diamond shape, orientation is irrelevant)
#
# *
# / \
# / \
# *-----*
# \ /
# \ /
# *
# simple difference between neighboring cells of an edge:
edges_a = sum_over(Edge > Cell, cells_a, weights=[-1, 1])
# difference between the upper and lower vertex of the diamond
edges_a = sum_over(Edge > Cell > Vertex, vertices_a, weights=[-1, 1, 0, 0])
# sum of the upper and lower vertex of the diamond (selective/filtered sum)
edges_a = sum_over(Edge > Cell > Vertex, vertices_a, weights=[1, 1, 0, 0])
```
There are multiple ways to express this. The goal is for the frontend to offer a
concise and readable way to express these examples.
+ Field view syntax:
```python=
@gt.field_operator
def nabv_ref(weights, u_vert, v_vert, primal_normal_vert_v1, primal_normal_vert_v2):
return np.sum(
(
u_vert[diamond_arr] * primal_normal_vert_v1
+ v_vert[diamond_arr] * primal_normal_vert_v2
)
* weights,
axis=-1,
)
@gt.field_operator
def nabv_tang_ref(u_vert, v_vert, primal_normal_vert_v1, primal_normal_vert_v2):
weights = np.asarray([[1.0, 1.0, 0.0, 0.0]] * n_edges)
return nabv_ref(
weights, u_vert, v_vert, primal_normal_vert_v1, primal_normal_vert_v2
)
@gt.field_operator
def nabv_norm_ref(u_vert, v_vert, primal_normal_vert_v1, primal_normal_vert_v2):
weights = np.asarray([[0.0, 0.0, 1.0, 1.0]] * n_edges)
return nabv_ref(
weights, u_vert, v_vert, primal_normal_vert_v1, primal_normal_vert_v2
)
@gt.field_operator
def z_nabla4_e2_ref(
nabv_norm, nabv_tang, z_nabla2_e, inv_vert_vert_length, inv_primal_edge_length
):
return 4.0 * (
(nabv_norm - 2.0 * z_nabla2_e) * inv_vert_vert_length ** 2
+ (nabv_tang - 2.0 * z_nabla2_e) * inv_primal_edge_length ** 2
)
```
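
The weighted reductions from the dusk snippet can be read as a gather followed by a weighted sum over the local axis. The following NumPy sketch shows that reading; `sum_over` here is a hypothetical stand-in (not the dusk function), and the tables, including the ordering of the diamond vertices, are invented for the example.

```python
import numpy as np


def sum_over(table, field, weights):
    # table has shape (n_edges, n_neighbors) and indexes into `field`;
    # weights has one entry per neighbor slot.
    return (field[table] * np.asarray(weights)).sum(axis=-1)


# Edge > Cell: each edge has two neighboring cells.
e2c = np.array([[0, 1], [1, 2]])
cells_a = np.array([1.0, 4.0, 9.0])
cell_diff = sum_over(e2c, cells_a, [-1, 1])

# Edge > Cell > Vertex: assumed ordering (upper, lower, left, right) of the diamond.
e2c2v = np.array([[0, 3, 1, 2], [1, 4, 2, 3]])
vertices_a = np.array([1.0, 2.0, 3.0, 4.0, 5.0])
far_diff = sum_over(e2c2v, vertices_a, [-1, 1, 0, 0])
far_sum = sum_over(e2c2v, vertices_a, [1, 1, 0, 0])
```

Zero weights act as a filter, so the selective sum needs no separate connectivity.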
8. [WIP] Scan pass
+ _Semantic Model IR:_
```python=
@fundef
def sum_scanpass(state, inp):
return if_(is_none(state), deref(inp), state + deref(inp))
@fundef
def ksum(inp):
return scan(sum_scanpass, True, None)(inp)
@fendef(column_axis=KDim)
def ksum_fencil(i_size, k_size, inp, out):
closure(
domain(named_range(IDim, 0, i_size), named_range(KDim, 0, k_size)),
ksum,
[out],
[inp],
)
```
+ _Indexing proposal:_
```python=
# option 1
@gt.local_operator
def sum_scanpass(state: Optional[DT], inp: Field[], *, position):
(v, k) = position
if state is None:
return inp[v, k]
return state+inp[v, k]
# option 2
@gt.scan_pass(forward=True, init=None, column_axis=K)
def sum_solver(state, inp, *, position):
# ... same as before
@gt.column_operator(column_axis=K)
def ksum(inp: Field[[V, K]], inp_2d: Field[[V]]):
(v,) = position
# option 1
return scan(sum_scanpass, forward=True, init=None)(inp, position=position)
# option 2
return sum_solver(inp, position=position)
@gt.program(...)
def ksum_fencil(i_size, k_size, inp, out):
    domain = IndexRange[IDim](0, i_size) * IndexRange[KDim](0, k_size)
    return ksum[domain](inp, out=out)
```
+ _Transforming views proposal:_
```python=
# option 1
@gt.local_operator
def sum_scanpass(
state: Optional[DT], inp: LocatedView[[*Axes, K], [*Axes, K], DT]
):
if state is None:
return inp
return state + inp
# option 2
@gt.scan_pass(forward=True, init=None, column_axis=K)
def sum_solver(state, inp):
# ... same as before
@gt.column_operator(column_axis=K)
def ksum(inp: LocatedView[[*Axes, K], [*Axes, K], DT]):
# option 1
return scan(sum_scanpass, forward=True, init=None)(inp, position=position)
# option 2
return sum_solver(inp, position=position)
@gt.program(...)
def ksum_fencil(i_size, k_size, inp, out):
    domain = IndexRange[IDim](0, i_size) * IndexRange[KDim](0, k_size)
    return ksum[domain](inp, out=out)
```
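
The intended meaning of a scan pass can be sketched in plain Python for a single column. This is an illustration only: the `scan` helper below is a hypothetical re-implementation that takes values rather than iterators, with the argument order `(scan_pass, forward, init)` borrowed from the Semantic Model IR snippet.

```python
import numpy as np


def scan(scan_pass, forward, init):
    """Hypothetical column scan: the previous result is passed in as `state`."""

    def run(column):
        out = np.empty(len(column), dtype=float)
        state = init
        levels = range(len(column))
        for k in levels if forward else reversed(levels):
            state = scan_pass(state, column[k])
            out[k] = state
        return out

    return run


def sum_scanpass(state, inp):
    # The first level has no previous state.
    return inp if state is None else state + inp


column = np.array([1.0, 2.0, 3.0, 4.0])
ksum_forward = scan(sum_scanpass, True, None)(column)
ksum_backward = scan(sum_scanpass, False, None)(column)
```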
9. Tridiag solver
+ _Semantic Model IR:_
```python=
@fundef
def tridiag_forward(state, a, b, c, d):
    # not traceable
# if is_none(state):
# cp_k = deref(c) / deref(b)
# dp_k = deref(d) / deref(b)
# else:
# cp_km1, dp_km1 = state
# cp_k = deref(c) / (deref(b) - deref(a) * cp_km1)
# dp_k = (deref(d) - deref(a) * dp_km1) / (deref(b) - deref(a) * cp_km1)
# return make_tuple(cp_k, dp_k)
# variant a
# return if_(
# is_none(state),
# make_tuple(deref(c) / deref(b), deref(d) / deref(b)),
# make_tuple(
# deref(c) / (deref(b) - deref(a) * nth(0, state)),
# (deref(d) - deref(a) * nth(1, state))
# / (deref(b) - deref(a) * nth(0, state)),
# ),
# )
# variant b
def initial():
return make_tuple(deref(c) / deref(b), deref(d) / deref(b))
def step():
return make_tuple(
deref(c) / (deref(b) - deref(a) * nth(0, state)),
(deref(d) - deref(a) * nth(1, state)) / (deref(b) - deref(a) * nth(0, state)),
)
return if_(is_none(state), initial, step)()
@fundef
def tridiag_backward(x_kp1, cp, dp):
# if is_none(x_kp1):
# x_k = deref(dp)
# else:
# x_k = deref(dp) - deref(cp) * x_kp1
# return x_k
return if_(is_none(x_kp1), deref(dp), deref(dp) - deref(cp) * x_kp1)
@fundef
def solve_tridiag(a, b, c, d):
tup = lift(scan(tridiag_forward, True, None))(a, b, c, d)
cp = nth(0, tup)
dp = nth(1, tup)
return scan(tridiag_backward, False, None)(cp, dp)
@fendef
def fen_solve_tridiag(i_size, j_size, k_size, a, b, c, d, x):
closure(
domain(
named_range(IDim, 0, i_size),
named_range(JDim, 0, j_size),
named_range(KDim, 0, k_size),
),
solve_tridiag,
[x],
[a, b, c, d],
)
```
+ _Indexing proposal:_
```python=
@gt.local_operator
def tridiag_forward(state, a, b, c, d, *, position):
v, k = position
if state is None:
return (c[v, k] / b[v, k], d[v, k] / b[v, k])
cp_km1, dp_km1 = state
cp_k = c[v, k] / (b[v, k] - a[v, k] * cp_km1)
dp_k = (d[v, k] - a[v, k] * dp_km1) / (b[v, k] - a[v, k] * cp_km1)
return (cp_k, dp_k)
@gt.local_operator
def tridiag_backward(x_kp1, cp, dp, *, position):
(v, k) = position
if x_kp1 is None:
return dp[v, k]
return dp[v, k]-cp[v, k] * x_kp1
@gt.column_operator(column_axis=K)
def solve_tridiag(a, b, c, d, *, position):
(v,) = position
# option 1
# cp, dp: Field[[K], DT], Field[[K], DT]
# forward = scan(tridiag_forward, True, None)
    # cp, dp = forward[...](a, b, c, d)
cp, dp = scan(tridiag_forward, True, None)[...](a, b, c, d)
return scan(tridiag_backward, forward=False, init=None)(cp, dp)
# option 2
cp, dp = tridiag_forward[...](a, b, c, d)
return tridiag_backward(cp, dp, position=position)
@gt.program(...)
def fen_solve_tridiag(i_size, j_size, k_size, a, b, c, d, x):
domain = IndexRange[IDim](0, i_size)*IndexRange[JDim](0, j_size)*IndexRange[KDim](0, k_size)
solve_tridiag[domain](a, b, c, d, out=[x])
```
+ _Transforming views proposal:_
```python=
@gt.local_operator
def tridiag_forward(
state: Tuple[float, float], a: LocatedView[[*Axes, K], [*Axes, K], DT], b : ..., c, d
):
if state is None:
return (c / b, d / b)
cp_km1, dp_km1 = state
cp_k = c / (b - a * cp_km1)
dp_k = (d - a * dp_km1) / (b - a * cp_km1)
return (cp_k, dp_k)
@gt.local_operator
def tridiag_backward(x_kp1: float, cp: ..., dp):
if x_kp1 is None:
return dp
return dp - cp * x_kp1
@gt.column_operator(column_axis=K)
def solve_tridiag(a, b, c, d):
# option 1
# cp, dp: Field[[K], DT], Field[[K], DT]
# forward = scan(tridiag_forward, True, None)
    # cp, dp = forward[...](a, b, c, d)
cp, dp = scan(tridiag_forward, True, None)[...](a, b, c, d)
return scan(tridiag_backward, forward=False, init=None)(cp, dp)
# option 2
cp, dp = tridiag_forward[...](a, b, c, d)
return tridiag_backward(cp, dp, position=position)
@gt.program(...)
def fen_solve_tridiag(i_size, j_size, k_size, a, b, c, d, x):
domain = IndexRange[IDim](0, i_size)*IndexRange[JDim](0, j_size)*IndexRange[KDim](0, k_size)
solve_tridiag[domain](a, b, c, d, out=[x])
```
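
A NumPy reference for a single column makes the forward and backward passes easy to check against a dense solve. The sketch uses the same recurrences as `tridiag_forward` and `tridiag_backward` above; the function layout and the test data are assumptions, and the coefficient convention (`a` sub-diagonal, `b` diagonal, `c` super-diagonal) is the usual one for this algorithm.

```python
import numpy as np


def solve_tridiag(a, b, c, d):
    n = len(d)
    cp = np.empty(n)
    dp = np.empty(n)
    # Forward pass (state is None on the first level).
    cp[0] = c[0] / b[0]
    dp[0] = d[0] / b[0]
    for k in range(1, n):
        denom = b[k] - a[k] * cp[k - 1]
        cp[k] = c[k] / denom
        dp[k] = (d[k] - a[k] * dp[k - 1]) / denom
    # Backward pass (state is None on the last level).
    x = np.empty(n)
    x[-1] = dp[-1]
    for k in range(n - 2, -1, -1):
        x[k] = dp[k] - cp[k] * x[k + 1]
    return x


a = np.array([0.0, 1.0, 1.0, 1.0])  # sub-diagonal, a[0] is unused
b = np.array([4.0, 4.0, 4.0, 4.0])  # diagonal
c = np.array([1.0, 1.0, 1.0, 0.0])  # super-diagonal, c[-1] is unused
d = np.array([1.0, 2.0, 3.0, 4.0])  # right-hand side

x = solve_tridiag(a, b, c, d)
matrix = np.diag(b) + np.diag(a[1:], -1) + np.diag(c[:-1], 1)
```

Multiplying `matrix` by `x` should reproduce `d` up to rounding.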
10. Laplacian Cartesian
+ _Semantic Model IR:_
```python=
@fundef
def laplacian(inp):
return -4.0 * deref(inp) + (
deref(shift(I, 1)(inp))
+ deref(shift(I, -1)(inp))
+ deref(shift(J, 1)(inp))
+ deref(shift(J, -1)(inp))
)
```
+ _Indexing proposal:_
```python=
@gt.local_operator
def laplacian(inp: Field[[I, J], DT], *, position):
(i, j) = position
return -4.0 * inp[i, j] + (inp[i+1, j] + inp[i-1, j] + inp[i, j+1] + inp[i, j-1])
```
```python=
@gt.local_operator
def laplacian(inp: Field[[I, J], DT], *, pos):
return -4.0 * inp[pos] + (inp[pos + i] + inp[pos - i] + inp[pos + j] + inp[pos - j])
```
+ [WIP] _Accessor proposal:_
```python=
@gt.local_operator(domain=Domain[I, J, ...])
def laplacian(inp: Accessor[[I+1/2, J, K], DT]):
return -4.0 * inp() + (inp(I+1) + inp(I-1) + inp(J+1) + inp(J-1))
```
__todo(tehrengruber): add example for staggered projection__
```
0 1 2 3
1/2 3/2
1 2
x---x---x---x
| C | C | C |
x---x---x---x
| C | C | C |
x---x---x---x
f_i = Field[I]
f_ie = Field[I+1/2]
stencil[I+1/2](f_i, f_ie)
def stencil(f_i, f_ie):
return 2.0 + f_i[-0.5] + f_ie
f_i: Accessor[I]
f_ie: Accessor[I+1/2]
f_ie[0.5]: Accessor[I]
```
Position[[I, J, K]]
+ _Transforming views proposal:_
```python=
```
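
Independent of the proposals in this item, a NumPy version of the Cartesian Laplacian is useful as a reference result. It is a sketch under one assumption that the snippets above leave open: the stencil is only applied to interior points and the boundary is left at zero.

```python
import numpy as np


def laplacian(inp):
    out = np.zeros_like(inp)
    # Interior points only; slices play the role of the +1/-1 shifts in I and J.
    out[1:-1, 1:-1] = -4.0 * inp[1:-1, 1:-1] + (
        inp[2:, 1:-1] + inp[:-2, 1:-1] + inp[1:-1, 2:] + inp[1:-1, :-2]
    )
    return out


i, j = np.meshgrid(np.arange(5.0), np.arange(5.0), indexing="ij")
inp = i**2 + j**2
lap = laplacian(inp)
```

For the quadratic input `i**2 + j**2` the discrete Laplacian is exactly 4 at every interior point, which makes the result easy to verify.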
11. Horizontal diff Cartesian
+ [WIP] _Accessor proposal:_
```python=
@gt.local_operator(domain=Domain[I+1/2, J, ...])
def flux_x(inp: Accessor[[I, J], DT]):
    lap = laplacian[...](inp)
    flux = lap(I-1/2) - lap(I+1/2)
    if flux * (inp(I+1/2) - inp(I-1/2)) > 0.0:
        return 0
    return flux
# tentative
flux_y = gt.permute_axis(flux_x, {I: J, J: I})
```
12. [WIP] BCs Cartesian
+ _Accessor proposal:_
```python=
@gt.local_operator(domain=(I, J))
def laplacian(boundary: IndexRange[I, J], inp: Accessor[[I, J], DT]):
    if (i, j) in boundary:
        return 0.
    return -4.0 * inp() + (inp(I+1) + inp(I-1) + inp(J+1) + inp(J-1))

@gt.local_operator(domain=(I, J))
def laplacian(inp: Accessor[[I, J], DT]):
    if (i, j) not in gt.domain[1:-1, 1:-1]:
        return 0.
    return -4.0 * inp() + (inp(I+1) + inp(I-1) + inp(J+1) + inp(J-1))
```
13. BCs Unstructured
```python=
@gt.local_operator
def laplacian(inp: Field[[V], DT], on_boundary: Field[[V], bool], *, position: Position[V]):
    (v,) = position
    if on_boundary[v]:
        return 0.
    return -4.0 * inp[v] + sum(inp[v2] for v2 in neighbors(v, V2V))
```
```python=
@gt.local_operator
def laplacian(inp: Field[[V], DT], *, position: Position[V]):
(v,) = position
if v in boundary_vertices:
return 0.
return -4.0 * inp[v] + sum(inp[v2] for v2 in neighbors(v, V2V))
```
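
The boundary-mask variant can be checked against a vectorized NumPy version in which the `if` becomes a `np.where`. The neighbor table and field values below are made up for illustration, and every vertex is assumed to have exactly 4 neighbors.

```python
import numpy as np

# Hypothetical V2V table with 4 neighbor slots per vertex.
v2v = np.array([[1, 3, 1, 3], [0, 2, 0, 2], [1, 3, 1, 3], [0, 2, 0, 2]])
inp = np.array([1.0, 2.0, 4.0, 8.0])
on_boundary = np.array([True, False, False, True])


def laplacian(inp, on_boundary, v2v):
    # Compute the stencil everywhere, then overwrite boundary vertices with 0.
    interior = -4.0 * inp + inp[v2v].sum(axis=1)
    return np.where(on_boundary, 0.0, interior)


out = laplacian(inp, on_boundary, v2v)
```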
14. Todo: PBCs Cartesian
15. FVM upstream
- *Positional field composition*
```python=
@gt.local_operator
def upstream_x(grid, rho: Field[(I, J), DT], velx: DT, *, position: Position[I+1/2, J]):
i, j = position
if (i, j) in grid.staggered_domain["I"][:1, :]:
rho_lb = rho[i, j]
return (rho_lb if velx > 0 else rho[i, j])*velx
elif (i, j) in grid.staggered_domain["I"][-1:, :]:
rho_rb = rho[i-1, j]
return (rho[i-1, j] if velx > 0 else rho_rb)*velx
return (rho[i-1, j] if velx > 0 else rho[i, j])*velx
@gt.local_operator
def upstream_y(grid, rho: Field[(I, J), DT], vely: DT, *, position: Position[I, J+1/2]):
    i, j = position
    if (i, j) in grid.staggered_domain["J"][:, :1]:
        rho_lb = rho[i, j]
        return (rho_lb if vely > 0 else rho[i, j])*vely
    elif (i, j) in grid.staggered_domain["J"][:, -1:]:
        rho_rb = rho[i, j-1]
        return (rho[i, j-1] if vely > 0 else rho_rb)*vely
    return (rho[i, j-1] if vely > 0 else rho[i, j])*vely
@gt.local_operator
def advector(dt, rho: Field[(I, J), DT], flux_x: Field[(I+1/2, J), DT], flux_y: Field[(I, J+1/2), DT], *, position: Position[I, J]):
    i, j = position
    return rho[i, j] - dt*(flux_x[i-1/2, j]-flux_x[i+1/2, j]+flux_y[i, j-1/2]-flux_y[i, j+1/2])

@gt.local_operator
def sub(rho: Field[(I, J), DT], flux_x: Field[(I+1/2, J), DT], *, position: Position[I, J]):
    i, j = position
    return rho[i, j] - flux_x[i-1/2, j]
@gt.field_operator
def advect(grid, rho: Field[(I, J), DT], vel_x: DT, vel_y: DT, dt: DT):
flux_x = upstream_x[grid.staggered_domain["I"]](grid, rho, vel_x)
flux_y = upstream_y[grid.staggered_domain["J"]](grid, rho, vel_y)
return advector[grid.cell_domain](dt, rho, flux_x, flux_y)
@gt.local_operator
def init_rho(*, position: Position[I, J]):
i, j = position
return 1 if I(10)<i<I(20) else 0
def advect_fvm():
Nx, Ny = 50, 10
# setup grid
# TODO: wrong
cell_domain = ProductSet.from_shape((Nx, Ny))
staggered_domain = {
"I": cell_domain.extend((0, 1), 0),
"J": cell_domain.extend(0, (0, 1))}
grid = SimpleNamespace(cell_domain=cell_domain, staggered_domain=staggered_domain)
# initial data setup
rho = init_rho[cell_domain]() # step function
vel_x, vel_y = 1, 0
dt = 1
t_end = 10
# advect density
t = 0
while t<t_end:
advect(grid, rho, vel_x, vel_y, dt, out=rho)
plot_field(rho)
t+=dt
```
- *Field view*
```python=
@gt.field_operator
def upstream_x(
grid, rho: Field[(I, J), DT], velx: Field[(I, J), DT]
) -> Field[(I+1/2, J), DT]:
rho_lb = rho(I-1)
first = where(velx > 0, rho_lb, rho) * velx
rho_rb = rho(I - 1)
last = where(velx > 0, rho(I - 1), rho_rb) * velx
interior = where(velx > 0, rho(I - 1), rho) * velx
# interior = interior[domain[:, 1:-1]]
domain = grid.staggered_domain["I"]
return combine(
rho(I-1)[rho.domain[:1, :]], interior[domain[1:-1, :]], last[domain[-1:, :]]
)(I+1/2)
# return piecewise([rho.domain[:, 1], rho.domain[:, -1]], [rho(I-1), rho**2, rho(I+1)])
# return (first[domain[:1, :]] | interior[domain[1:-1, :]] | last[domain[-1:, :]])(I+1/2)
@gt.field_operator
def upstream_y(
grid, rho: Field[(I, J), DT], vely: Field[(I, J), DT]
) -> Field[(I, J+1/2), DT]:
rho_lb = rho + 1.0
first = where(vely > 0, rho_lb, rho) * vely
rho_rb = rho(J - 1)
last = where(vely > 0, rho(J - 1), rho_rb) * vely
interior = where(vely > 0, rho(J - 1), rho) * vely
# interior = interior[domain[:, 1:-1]]
domain = grid.staggered_domain["J"]
return combine(
first[domain[:, :1]], interior[domain[:, 1:-1]], last[domain[:, -1:]]
)(J+1/2)
# return (first[domain[:, :1]] | interior[domain[:, 1:-1]] | last[domain[:, -1:]])(J+1/2)
| C |
x x
x X flux_x(I-1/2)
X x flux_x(I+1/2)
X flux_x(I-1/2) - flux_x(I+1/2)
H | C | H
x x
x x edge_field(C2E, 0)
x x edge_field(C2E, 1)
X edge_field(C2E, 0) - edge_field(C2E, 1)
| C | C |
x x
x x cell_field(C2E, 0)
x x cell_field(C2E, 1)
X cell_field(C2E, 0) - cell_field(C2E, 1)
I_m_half = [[0_e]]
I_p_half = [[1_e]]
flux_x[x_e]
out[0_c] = flux_x[I_m_half[0_c][0]]
out = [0]*len(I_m_half)
for c in range(len(I_m_half)):
out[c] = flux_x[I_m_half[c][0]]
out = remap(I_m_half)(flux_x)
| C | C | C | C |
X
X X
I I I B
e_avg = average(c_in) # 5 edges
@gt.field_operator
def advector(dt, rho: Field[(I, J), DT], flux_x: Field[(I+1/2, J), DT], flux_y: Field[(I, J+1/2), DT]):
#equivalent:
# flux_x[1:, :].remap(I-1/2)
# Field(cell_domain, flux_x[1:, :].data)
return rho - dt * (
flux_x(I-1/2) - flux_x(I+1/2)
+ flux_y[:, 1:](J-1/2) - flux_y[:, :-1](J+1/2)
)
@gt.field_operator
def advect(rho: Field[(I, J+1/2), DT]):
result = rho**2
another_result = result[I-1/2] - 2
return result, another_result[:-1, :]
@gt.field_operator
def advect(grid, rho: Field[(I, J), DT], vel_x: DT, vel_y: DT, dt: DT):
flux_x = upstream_x(grid, rho, vel_x)[grid.staggered_domain["I"]]
# flux_x = upstream_x(grid, rho, vel_x)[grid.staggered_domain["I"]]
flux_y = upstream_y(grid, rho, vel_y)
# flux_y = upstream_y(grid, rho, vel_y)[grid.staggered_domain["J"]]
return advector(dt, rho, flux_x, flux_y)[rho.domain]
@gt.field_operator
def init_rho(cell_domain):
index_field = Field(domain=cell_domain, data=cell_domain)
index_field = index_field(cell_domain)
return gt.where(I(10) < index_field.I_data < I(20), 1, 0)
# left = Field(cell_domain & (I[0:10] * J[:]), data=0)
# left = Field(cell_domain[0:10, ...], data=0)
# center = Field(cell_domain[10:20, ...], data=1)
# right = Field(cell_domain[20:, ...], data=0)
# return combine(left, center, right)
def advect_fvm():
Nx, Ny = 50, 10
# todo: wrong
# setup grid
cell_domain = ProductSet.from_shape((Nx, Ny))
staggered_domain = {
"I": cell_domain.extend((0, 1), 0),
"J": cell_domain.extend(0, (0, 1))}
grid = SimpleNamespace(cell_domain=cell_domain, staggered_domain=staggered_domain)
# initial data setup
rho = init_rho(cell_domain) # step function
vel_x, vel_y = 1, 0
dt = 1
t_end = 10
# advect density
t = 0
while t<t_end:
advect(grid, rho, vel_x, vel_y, dt, out=rho)
plot_field(rho)
t+=dt
```
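For comparison, the upwind scheme that both variants express can be sketched in plain Python without any GT4Py machinery. This is only a 1D illustration under assumptions that are not part of the notes above: lists stand in for fields, the grid spacing is 1, boundary edges reuse the nearest cell value, the update uses the usual conservative form `rho - dt * (F_right - F_left)`, and the names `upstream_flux` and `advect_step` are hypothetical.

```python
def upstream_flux(rho, vel):
    """Upwind flux on the n + 1 cell edges of a 1D grid with n cells."""
    n = len(rho)
    flux = []
    for e in range(n + 1):
        # cell values left and right of edge e; boundaries reuse the nearest cell
        left = rho[e - 1] if e > 0 else rho[0]
        right = rho[e] if e < n else rho[n - 1]
        flux.append((left if vel > 0 else right) * vel)
    return flux


def advect_step(rho, vel, dt):
    """One conservative update of every cell (grid spacing assumed to be 1)."""
    flux = upstream_flux(rho, vel)
    return [rho[i] - dt * (flux[i + 1] - flux[i]) for i in range(len(rho))]


# step function as in init_rho: 1 for 10 < i < 20, else 0
rho = [1.0 if 10 < i < 20 else 0.0 for i in range(50)]
for _ in range(10):
    rho = advect_step(rho, vel=1.0, dt=1.0)
```

With `vel = 1` and `dt = 1` each step moves the profile by exactly one cell, which makes the sketch easy to check by hand.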
- 16. Elliptic solver
*Positional with field composition*
```python=
# `domain`, `dx` and `dy` are assumed to be defined in the enclosing scope
@gt.local_operator
def laplacian(domain, phi: Field[(I, J), DT], *, position: Position[I, J]):
    i, j = position
    # negative Laplacian in the interior, zero on the boundary
    if position in domain[1:-1, 1:-1]:
        return -(phi[i-1, j]-2*phi[i, j]+phi[i+1, j])/(dx*dx) - (phi[i, j-1]-2*phi[i, j]+phi[i, j+1])/(dy*dy)
    return 0.

@gt.field_operator
def step(domain, β: DT, L_: Field[(I, J), DT], ϕ: Field[(I, J), DT], r: Field[(I, J), DT]):
    ϕ = ϕ + β * r   # update solution
    r = r + β * L_  # update residual
    return ϕ, r

# driver code
def steepest_descent(ϕ_0, L, R, ϵ=0.01, max_it=10000):
    ϕ = ϕ_0     # solution
    r = L(ϕ)-R  # residual
    # ensure BCs are respected
    r[0, :]=r[:, 0]=r[-1, :]=r[:, -1]=0
    for it in range(0, max_it):
        L_ = laplacian[domain](domain, r)  # Laplacian of the residual
        β: DT = - sum(r*r)/sum(r * L_)     # step size
        step(domain, β, L_, ϕ, r, out=(ϕ, r))
        # debug output
        print(f"it: {it}, ||r||_L2 = {np.linalg.norm(r)}")
        #if plotting_enabled:
        #    plot(ϕ)
        # break when converged
        if np.linalg.norm(r) <= ϵ:
            break
    return ϕ
```
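The iteration above can be checked against a plain-Python version. The following is a minimal 1D sketch, assuming lists instead of fields, homogeneous Dirichlet boundaries and unit grid spacing; the helper names `dot` and the argument `rhs` are hypothetical and only serve the illustration.

```python
import math


def laplacian(phi, dx=1.0):
    """Negative 1D Laplacian on interior points; boundary entries stay 0."""
    out = [0.0] * len(phi)
    for i in range(1, len(phi) - 1):
        out[i] = -(phi[i - 1] - 2 * phi[i] + phi[i + 1]) / (dx * dx)
    return out


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def steepest_descent(phi_0, rhs, eps=1e-8, max_it=10000):
    phi = list(phi_0)                                   # solution
    r = [lp - f for lp, f in zip(laplacian(phi), rhs)]  # residual
    r[0] = r[-1] = 0.0                                  # ensure BCs are respected
    for _ in range(max_it):
        if math.sqrt(dot(r, r)) <= eps:                 # converged
            break
        l_r = laplacian(r)
        beta = -dot(r, r) / dot(r, l_r)                 # step size
        phi = [p + beta * ri for p, ri in zip(phi, r)]  # update solution
        r = [ri + beta * li for ri, li in zip(r, l_r)]  # update residual
    return phi


n = 10
rhs = [0.0] + [1.0] * (n - 2) + [0.0]
phi = steepest_descent([0.0] * n, rhs)
```

For this right-hand side the discrete solution is the parabola `i * (n - 1 - i) / 2`, so the result is easy to verify. The convergence test runs before the step size is computed to avoid dividing by zero once the residual vanishes.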
- 17. Average
*Positional with field composition*
```python=
def backward_diff(inp):
return forward_diff(inp)[i-1]
@gt.local_operator
def laplacian(inp):
return -4*inp[i, j]+inp[i-1, j]+inp[i+1, j]+inp[i, j-1]+inp[i, j+1]
return (inp[i-1] - inp) - (+inp - inp[i+1]) + (...)
return forward_diff(inp)[i-1] - forward_diff(inp)
return backward_diff(inp) - forward_diff(inp)
@gt.local_operator
def forward_diff_i(inp):
return inp[i+1, j]-inp[i, j]
@gt.local_operator
def central_diff_i(inp):
return inp[i+1, j]-inp[i-1, j]
@gt.field_operator
def laplacian(inp):
@gt.local_operator
def average(inp: Field[(Edge, K), DT], *, position: Position[Vertex, K]) -> DT:
v, k = position
return sum(inp[e, k] for e in neighbors(v, V2E))
@gt.local_operator
def average_same(inp: Field[(Vertex, K), DT], *, position: Position[Vertex, K]):
v, k = position
return sum(inp[n, k] for n in neighbors(v, V2V))
@gt.field_operator
def average_composition(vertex_domain, inp: Field[(Edge, K), DT]):
tmp: Field[(Vertex, K), DT] = average[vertex_domain](inp)
return average_same[vertex_domain](tmp)
def driver():
vertex_domain = Domain((Vertex, 0, 10), (K, 0, 10))
edge_domain = Domain((Edge, 0, 10), (K, 0, 10))
inp_edge = Field.from_function(edge_domain, lambda e, k: 42)
out_vertex = Field.from_function(vertex_domain, lambda v, k: 0)
average[vertex_domain](inp_edge, out=out_vertex)
average_composition(vertex_domain, inp_edge, out=out_vertex)
```
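What `average[vertex_domain](inp_edge)` is meant to compute can be sketched with a connectivity table and nested lists. The table `V2E`, the `apply` helper and the tiny three-vertex mesh below are assumptions made for illustration only. As in the notes, the operator sums over the neighbors without dividing by their count.

```python
# hypothetical V2E connectivity: for each vertex, the indices of its edges
V2E = [[0, 1], [1, 2], [2, 0]]


def neighbors(v, table):
    return table[v]


def average(inp, v, k):
    """Mirrors the local operator above: sum of `inp` over the edges of `v`."""
    return sum(inp[e][k] for e in neighbors(v, V2E))


def apply(local_op, domain, inp):
    """Evaluate a local operator at every (v, k) position of `domain`."""
    n_v, n_k = domain
    return [[local_op(inp, v, k) for k in range(n_k)] for v in range(n_v)]


inp_edge = [[42] * 2 for _ in range(3)]          # 3 edges, 2 levels
out_vertex = apply(average, (3, 2), inp_edge)    # 3 vertices, 2 levels
```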
- Different iterator positions in the same stencil
*Beautified Iterator*
```python=
def foo(it0, it1):
    return it0() + it1()

def bar(inp):
    return foo(inp, shift(I+1)(inp))
```
*Positional*
```python=
@localop
def foo(inp, inp2, *, position):
    return inp[position] + inp2[position]

@localop
def shift(inp, dim, offset, *, position):
    pos = position.shift(dim, offset)
    return inp[pos]

@fieldop
def bar(inp):
    return foo[...](inp, shift[...](inp, I, 1))
```
```
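The iterator variant can be made concrete with a few lines of plain Python. This is a sketch under assumptions: the `Iterator` class is hypothetical, fields are 1D lists, and `shift` takes a plain integer offset instead of `I+1`. It shows that `foo` receives two iterators over the same data that sit at different positions.

```python
class Iterator:
    """Hypothetical iterator: a 1D field (list) plus a current position."""

    def __init__(self, data, pos):
        self.data = data
        self.pos = pos

    def __call__(self):
        # dereference at the current position
        return self.data[self.pos]


def shift(offset):
    """Return a function that moves an iterator by `offset` positions."""
    def shifted(it):
        return Iterator(it.data, it.pos + offset)
    return shifted


def foo(it0, it1):
    return it0() + it1()


def bar(inp):
    return foo(inp, shift(1)(inp))


data = [1, 2, 4, 8]
result = [bar(Iterator(data, i)) for i in range(len(data) - 1)]
```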
---
TODO: calling "fencils"
TODO: non-field arguments, gridtools global parameters
map
## Questions
Question: Is it a problem that a user can accidentally assume there is only one position while the caller passes local views at different positions? How likely is this problem to appear, and how hard is it to track down when it does?
```python=
@local_operator
def compute_zavgS(pp: LocatedView[[E, K], [], DT], S_M: Field[[E], DT]):
    assert pp.pos == S_M.pos
    return pp + S_M
```
Indexing: signature defines how many positions
Transforming views: every argument has its own position
# Frontend specifications
## Field view (WIP)
### Roadmap for start of 2022
1. Take another look at annotations canonicalization for externals & fieldop body, parsed & embedded
2. Merge Externals work
3. Resolve conflicts in Typing work and merge
4. Grab some simple [evaluation test cases](https://hackmd.io/OrYAdREkRSCUs-LjUEZo7A) and make them work
    - make a way to get offset info at parsing time (might later on interact with embedded execution)
    - track field dimension types through operations in FOAST typing (optionally?)
    - lifting in lowering
5. Create FieldOperatorAST -> IteratorIR-Program machinery
### Toolchain entry points / decorators
```python=
# note: typing here is just an approximation since this is actually invalid
P = ParamSpec("P")
T = TypeVar("T")
OneOrMore = T | Tuple[T, ...]
def field_operator(
    func: Callable[P, OneOrMore[T]]
) -> FieldOperator[[P.args, P.kwargs + "out"], OneOrMore[T]]: ...

# `func` must be a function in which some argument names are forbidden:
#   - "out": used in fencil calls
#   - r"__[a-zA-Z_]+__": reserved by gt4py ?
# Type annotations on the `func` arguments are mandatory, but their actual
# format has not yet been decided
```
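The constraints in the comments above can be sketched as a plain Python decorator. This is not the GT4Py implementation, only an illustration under assumptions: fields are lists, `out` is filled by slice assignment, and the check uses `inspect.signature` from the standard library.

```python
import functools
import inspect
import re

_RESERVED = re.compile(r"__[a-zA-Z_]+__")


def field_operator(func):
    """Wrap `func`, rejecting forbidden argument names and adding `out`."""
    for name in inspect.signature(func).parameters:
        if name == "out" or _RESERVED.fullmatch(name):
            raise TypeError(f"forbidden argument name: {name!r}")

    @functools.wraps(func)
    def wrapper(*args, out=None, **kwargs):
        result = func(*args, **kwargs)
        if out is None:
            return result
        out[:] = result  # write into the caller-provided buffer
        return out

    return wrapper


@field_operator
def add(a, b):
    return [x + y for x, y in zip(a, b)]


buf = [0, 0]
add([1, 2], [3, 4], out=buf)
```

Reserving `out` at decoration time means a name clash is reported when the operator is defined, not when it is first called with an output buffer.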
**Open questions:**
- Management of external symbols
### Builtins
#### Constant fields
```python=
def constant(
value: str | number, *, dtype=None, dimensions = None
) -> Field: ...
f32 = functools.partial(constant, dtype=gt4py.float32)
f64 = functools.partial(constant, dtype=gt4py.float64)
i32 = functools.partial(constant, dtype=gt4py.int32)
i64 = functools.partial(constant, dtype=gt4py.int64)
```
#### Mathematical functions
```python=
def abs(f: Field) -> Field: ...
def max(f: Field, g: Field) -> Field: ... # allow multiple arguments ??
def min(f: Field, g: Field) -> Field: ... # allow multiple arguments ??
def mod(f: Field, g: Field) -> Field: ...
def sin(f: Field) -> Field: ...
def cos(f: Field) -> Field: ...
def tan(f: Field) -> Field: ...
def arcsin(f: Field) -> Field: ...
def arccos(f: Field) -> Field: ...
def arctan(f: Field) -> Field: ...
def sqrt(f: Field) -> Field: ...
def exp(f: Field) -> Field: ...
def log(f: Field, base: Field = math.e) -> Field: ...
def isfinite(f: Field) -> Field: ...
def isinf(f: Field) -> Field: ...
def isnan(f: Field) -> Field: ...
def floor(f: Field) -> Field: ...
def ceil(f: Field) -> Field: ...
def trunc(f: Field) -> Field: ...
```
#### Control flow
```python=
def where(
mask: Field[bool], true_branch: Field, else_branch: Field
) -> Field: ...
if_ = where
```
#### Neighbor reductions
```python=
def sum(f: Field, axis: Dimension) -> Field: ...
```
TODO: discuss the following problem
```
e: Field[Edge]
e_e2v: Field[Edge, E2V] = e(E2V)
e_e2v_0: Field[Edge] = e(E2V(0)) # adds dimension
e_e2v_0: Field[Edge] = e_e2v[E2V(0)] # removes dimension
```
#### Advanced patterns
```python=
def scan(
scan_pass: Optional[FieldOperator[[State, *Fields], OneOrMore[Field]]] = None,
init: Optional[Any] = None,
*,
reverse: bool = False,
) -> FieldOperator: ...
```
### Statements
- All values are always fields
- Scalar literals are not allowed; scalar values must always be defined as constant fields
- Broadcasting behavior:
+ Constant fields without dimensions (`field.dims is None`) are compatible with any other field and are thus implicitly broadcast to any required dimension
+ Other fields need to be broadcast explicitly (for now)
```python
# broadcasting syntax:
...
```
- Supported syntactic elements:
+ Name expressions / symbol definitions
+ Arithmetic operators (`+`, `-`, `*`, `/`, `**`)
+ Shifts
+ Function calls (builtins and user-defined)
+ Return statements
```python=
# Field operator definition example:
def field_op(
a: Field, b: Field, *, const: Field = constant(1.0)
) -> Field:
# Constants
constant_field = constant("1.143234", dtype=float64)
# Arithmetic operations
# Broadcasting behavior: implicit or explicit ?
c = a + b
d = b - c
e = c * d
f = d / e
g = e ** f
h = a - f64("1.11111111111121111111111112")
# Function calls
## builtins should have been imported in the module ??
# from gt import abs, max, min, ... ??
ha = abs(f)
hb = max(f, g)
hc = min(f, g)
hd = mod(f, g)
he = sin(f)
hf = cos(f)
hg = tan(f)
hh = arcsin(f)
hi = arccos(f)
hj = arctan(f)
hk = sqrt(f)
hl = exp(f)
hm = log(f)
hn = log(f, g)
ho = isfinite(f)
hp = isinf(f)
hq = isnan(f)
hr = floor(f)
hs = ceil(f)
ht = trunc(f)
## user-defined functions should have been made visible to
# this operator using external symbol management
m = other_field_op(f, g)
# Shifts:
## - Unstructured offsets
ja = g(E2V)
jb = g(E2V[0])
jc = g(E2V[0])(C2E)
# Neighbor reductions:
## - Simple
e_field = sum(e_e2v_field, axis=E2V)
e_field = sum(e_e2v_field[i] for i in E2V) # ?
#e_e2v_field[i] is a slice of e_e2v_field and thus it is also a field
## - Weighted
local_weights = Field[E2V]
# TODO: Update once broadcasting syntax is defined
e_field = sum(e_e2v_field * local_weights, axis=E2V)
e_field = sum((e_e2v_field * local_weights)[i] for i in E2V)
# control flow
out = where(mask, a, b)
# return
return out
```
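The broadcasting rule listed under Statements (dimensionless constants combine with anything, all other fields must match) can be sketched as follows. The `Field` class, its list-based storage and the simplified `constant` signature are assumptions for illustration and do not reflect an agreed design.

```python
class Field:
    """Hypothetical field: `dims` is a tuple of names, or None for a constant."""

    def __init__(self, data, dims=None):
        self.data = data
        self.dims = dims

    def __add__(self, other):
        if self.dims is None and other.dims is None:
            return Field(self.data + other.data)
        if self.dims is None:    # implicit broadcast of the left constant
            return Field([self.data + x for x in other.data], other.dims)
        if other.dims is None:   # implicit broadcast of the right constant
            return Field([x + other.data for x in self.data], self.dims)
        if self.dims != other.dims:
            raise TypeError("fields with different dims need an explicit broadcast")
        return Field([x + y for x, y in zip(self.data, other.data)], self.dims)


def constant(value, *, dtype=float):
    """Simplified stand-in for the `constant` builtin: no dims, one value."""
    return Field(dtype(value))


a = Field([1.0, 2.0], dims=("I",))
b = a + constant("0.5")
```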
### Scan passes
- builtin `scan` definition:
```python=
def scan(
    scan_pass: Optional[FieldOperator[[OutFieldsT, *OtherFields], OutFieldsT]] = None,
    init: Optional[Any] = None,
    *,
    reverse: bool = False,
) -> FieldOperator:
    def _scan(...):
        ...

    return (
        _scan(scan_pass)
        if scan_pass is not None
        else functools.partial(_scan, init=init, reverse=reverse)
    )
```
- application:
```python=
# decorator use
@scan(init=None, reverse=False)
def my_scan_pass(state: Optional[OutFieldsT], f: Field) -> OutFieldsT:
    return max(f, state) if state is not None else f

# direct use (on the undecorated `my_scan_pass`)
def other_operator(a: Field, b: Field) -> Field:
    # column_op: Callable[[Dims+1, ...], Field[Dims+1]] = scan(my_scan_pass, init=None)
    c = scan(my_scan_pass, init=None)(a)
    d = scan(my_scan_pass, init=None)(b)
```
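The dual decorator / direct-call behavior of `scan` can be tried out in plain Python. The sketch below assumes that a column is a list and that the scan pass is called once per level with the previous state; unlike the outline above it forwards `init` and `reverse` in the direct-call case as well, which is a choice made here and not something the notes settle.

```python
import functools


def scan(scan_pass=None, init=None, *, reverse=False):
    """Usable as `scan(f, init)` or as `@scan(init=..., reverse=...)`."""

    def _scan(scan_pass, init=None, reverse=False):
        def column_op(*columns):
            levels = list(zip(*columns))   # one tuple of values per level
            if reverse:
                levels.reverse()
            state, out = init, []
            for level in levels:
                state = scan_pass(state, *level)
                out.append(state)
            if reverse:
                out.reverse()              # restore the original level order
            return out
        return column_op

    if scan_pass is not None:
        return _scan(scan_pass, init=init, reverse=reverse)
    return functools.partial(_scan, init=init, reverse=reverse)


# decorator use
@scan(init=None, reverse=False)
def running_max(state, f):
    return max(f, state) if state is not None else f


column = [3, 1, 4, 1, 2]
forward = running_max(column)
# direct use, scanning from the last level to the first
backward = scan(lambda s, f: f if s is None else max(f, s), reverse=True)(column)
```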