# [GT4Py] Custom Field Collections (AKA pytrees)

- Shaped by:
- Appetite (FTEs, weeks):
- Developers: <!-- Filled in at the betting table unless someone is specifically required here -->

This task was originally shaped for [cycle 20](https://hackmd.io/team/gridtools?nav=overview&tags=%5B%22cycle20%22%5D) and has since been moved here. It has been revived for [cycle 31](https://hackmd.io/Yw8nkBs9Tgeayxo-zGgujA).

## Problem

GT4Py has basic support for collections via tuples. We identified several features that would profit from a general mechanism for collections. The following examples are not meant to be supported with exactly this syntax; they only illustrate the concepts.

### Expandable parameters

GridTools has a feature called `expandable parameters`, which allows executing an operation over a (dynamically-sized) collection of fields. A use case in weather and climate applications is tracers, where the same operation is performed on a list of fields.

Example:

```python
def foo(param: Field[[Dim], DType], tracers: ExpandableParameter[Field[[Dim], DType]]):
    return param * tracers
```

with the semantics (`...` is pseudo-code expressing a dynamic list)

```python!
def foo(param: Field[[Dim], DType], tracer0: Field[[Dim], DType], tracer1: Field[[Dim], DType], ...):
    return param * tracer0, param * tracer1, ...
```

Note that all elements in an expression need to either have the same container structure (the same number of elements in the `ExpandableParameter`) or be normal parameters (not `ExpandableParameter`s).

### NamedTuple

Syntactic sugar for long argument lists: allows grouping related fields.

Example: extraction

```python!
class Velocities[DomT, DT]:
    u: Field[DomT, DT]
    v: Field[DomT, DT]

def foo(velocity: Velocities[Dims[Dim], DType]) -> Field[[Dim], DType]:
    return sqrt(velocity.u**2 + velocity.v**2)
```

Example: creation

```python
def foo(
    u: Field[[Dim], DType], v: Field[[Dim], DType]
) -> Velocities[Dims[Dim], DType]:
    return Velocities(u, v)
```

### Distinguish Tracers

A collection of tracers needs to be filtered so that a certain algorithm is applied only to some of them. This is similar in spirit to using [`Literal` values for indexing collections](https://mypy.readthedocs.io/en/stable/literal_types.html#intelligent-indexing) or [tagged unions](https://mypy.readthedocs.io/en/stable/literal_types.html#tagged-unions) in Python type annotations.

```python!
class Tracer:
    data: Field[[Dims, ...], DType]
    flag: int

def foo(tracers: List[Tracer]):
    # only the tracers with flag == 1 are passed on to the algorithm
    algorithm([tracer for tracer in tracers if tracer.flag == 1])
```

### Mesh

It's not yet clear how this idea will evolve. Conceptually, a mesh is just a (named and fixed-size) collection of connectivities, which matches the idea of a NamedTuple used in a very specific case. If this idea is implemented, we could get rid of the confusing distinction between `Offset`s and offset providers (which _inject_ meaning into offsets through a magic keyword argument when calling programs) by making the `mesh` explicit in the signature of the operators.

```python!
def foo(mesh: Mesh[???], field: Field[[Edge], DType]):
    return field(mesh.v2el)
```

However, this makes the code inside the operators more verbose, since either connectivities always have to be prefixed with the mesh argument name (`field(E2V)` becomes `field(mesh.E2V)`) or an explicit _unpacking_ is required at the beginning of the operator (`E2V, C2V = mesh.E2V, mesh.C2V` or `E2V, C2V = mesh['E2V', 'C2V']`).

## Appetite

Full 4-week cycle

## Solution

With *pytree*s, JAX offers an extensible mechanism that provides this functionality.
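
To make the pytree idea concrete, here is a minimal pure-Python sketch of the concept. It is neither JAX's implementation or API nor a proposed GT4Py API. Each container type registers a `flatten` function, which splits an instance into its children plus static auxiliary data, and an `unflatten` function, which rebuilds it. Flattening recursively yields a flat list of leaves (which would be the fields) plus a description of the structure, and `tree_map` applies one operation to every leaf, which is roughly the semantics sketched above for expandable parameters. All names (`register_collection`, `tree_flatten`, `tree_unflatten`, `tree_map`, and the dataclass `Velocities`) are hypothetical, and plain floats stand in for fields.

```python
from dataclasses import dataclass, fields
from typing import Any, Callable

# Hypothetical registry (not the JAX or GT4Py API):
# container type -> (flatten, unflatten).
_REGISTRY: dict[type, tuple[Callable, Callable]] = {}


def register_collection(cls: type, flatten: Callable, unflatten: Callable) -> None:
    """flatten(obj) -> (children, aux); unflatten(aux, children) -> obj."""
    _REGISTRY[cls] = (flatten, unflatten)


def tree_flatten(tree: Any) -> tuple[list[Any], Any]:
    """Return (leaves, treedef). Unregistered types are leaves (treedef None)."""
    # Exact type lookup: subclasses such as NamedTuples need their own registration.
    entry = _REGISTRY.get(type(tree))
    if entry is None:
        return [tree], None
    flatten, _ = entry
    children, aux = flatten(tree)
    leaves: list[Any] = []
    child_defs = []
    for child in children:
        child_leaves, child_def = tree_flatten(child)
        leaves.extend(child_leaves)
        # Remember how many leaves each child contributed.
        child_defs.append((len(child_leaves), child_def))
    return leaves, (type(tree), aux, child_defs)


def tree_unflatten(treedef: Any, leaves: list[Any]) -> Any:
    """Rebuild the original structure from a treedef and a flat list of leaves."""
    if treedef is None:
        (leaf,) = leaves
        return leaf
    cls, aux, child_defs = treedef
    _, unflatten = _REGISTRY[cls]
    children, start = [], 0
    for count, child_def in child_defs:
        children.append(tree_unflatten(child_def, leaves[start:start + count]))
        start += count
    return unflatten(aux, children)


def tree_map(fn: Callable[[Any], Any], tree: Any) -> Any:
    """Apply fn to every leaf and rebuild a container with the same structure."""
    leaves, treedef = tree_flatten(tree)
    return tree_unflatten(treedef, [fn(leaf) for leaf in leaves])


# Tuples: the children are the elements, no auxiliary data needed.
register_collection(tuple, lambda t: (list(t), None), lambda aux, children: tuple(children))


@dataclass
class Velocities:  # hypothetical stand-in for a user-defined field collection
    u: Any
    v: Any


def _flatten_dataclass(obj: Any) -> tuple[list[Any], list[str]]:
    # The attribute names are the static auxiliary data.
    names = [f.name for f in fields(obj)]
    return [getattr(obj, name) for name in names], names


register_collection(
    Velocities,
    _flatten_dataclass,
    lambda names, children: Velocities(**dict(zip(names, children))),
)

# Usage: nested collections flatten to a single list of leaves.
vel = Velocities(u=3.0, v=(4.0, 5.0))
leaves, treedef = tree_flatten(vel)  # leaves == [3.0, 4.0, 5.0]
assert tree_unflatten(treedef, leaves) == vel

# Expandable-parameter-like usage: the same operation on every tracer.
param = 2.0
tracers = (1.0, 2.0, 3.0)
scaled = tree_map(lambda tracer: param * tracer, tracers)  # (2.0, 4.0, 6.0)
```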
We would like to implement a similar mechanism for GT4Py with the following features:

### Access mechanisms

We want:

- attribute access for NamedTuple-like containers
- element access for list-like containers
- possibly no explicit unpacking for some containers
- packing fields into containers to pass them around or return them from field operators

TODO:

#### Expose type information at compile-time (parsing)

To support any user-defined container, and not only a specific subset (e.g. NamedTuples, dataclasses and `attrs` classes), users should expose the list of attributes and types of their custom containers at compile-time. We could define an abstract typing `Protocol`, similar to `TypedDict`, that declares the expected content of the actual container object provided at run-time. This would expose the required typing information when compiling the operator, while constraining the actual object holding the data passed by the user at run-time.

However...

```python
class Velocities(gtx.FieldCollection):
    class StaggeredVelocities(gtx.FieldCollection):
        u: Field[Dims[I, K_2], float32]
        v: Field[Dims[J, K_2], float32]

    u: Field[Dims[I, K], float32]
    v: Field[Dims[J, K], float32]
    st: StaggeredVelocities

VelocitiesClass = dataclass(Velocities)

# Not Yet: ExpandableParameter = Sequence[Field[...]]
# Not Yet: tuple[Field[...], Field[...], Field[...]]
# Not Yet: ExpandableParameter = Annotated[Sequence[Field[...]], 5]

def foo(velocity: Velocities) -> Field[[Dim], DType]:
    return sqrt(velocity.u**2 + velocity.v**2) - sqrt(velocity.st.u**2 + velocity.st.v**2)

# Presumably equivalent: the same operator with the collection flattened into a tuple
def foo(velocity: tuple[u, v, st.u, st.v]) -> Field[[Dim], DType]:
    return sqrt(velocity[0]**2 + velocity[1]**2) - sqrt(velocity[2]**2 + velocity[3]**2)
```

#### Expose type information at run-time (compiled and embedded)

To extract the expected field data at run-time, we can use a mechanism similar to the [pytree protocol](https://jax.readthedocs.io/en/latest/pytrees.html) used by JAX or the tree registration mechanism used in
[autoray](https://github.com/jcmgray/autoray/blob/a0be3702f37bc64507ae501dd22717d224383ad2/autoray/autoray.py#L504), with the extra feature that it would be possible to verify at run-time that the extracted fields match the expected types.

### Lowering

TODO: How do we lower? Unpack containers to tuples? Where does this happen?

### Creation within the DSL (parsing)

The custom container types need to be accessible within the DSL, including in type checking.

## Rabbit holes

<!-- Details about the solution worth calling out to avoid problems -->

## No-gos

<!-- Anything specifically excluded from the concept: functionality or use cases we intentionally aren’t covering to fit the appetite or make the problem tractable -->

## Progress

<!--
Don't fill during shaping. This area is for collecting TODOs during building.
As a first task during building, add a preliminary list of coarse-grained tasks for the project and refine them with finer-grained items when it makes sense as you work on them.
-->

- [x] Task 1 ([PR#xxxx](https://github.com/GridTools/gt4py/pulls))
    - [x] Subtask A
    - [x] Subtask X
- [ ] Task 2
    - [x] Subtask H
    - [ ] Subtask J
- [ ] Discovered Task 3
    - [ ] Subtask L
    - [ ] Subtask S
- [ ] Task 4