- Shaped by: Hannes, Enrique
- Appetite (FTEs, weeks): 2 FTEs, full cycle
- Developers: <!-- Filled in at the betting table unless someone is specifically required here -->

## Problem

The current embedded execution of `gt4py.next` programs uses the embedded implementation of the Iterator IR primitives, which means all the looping and shifting is done in Python and is therefore painfully slow. This approach implies that only small domains are viable for complex programs, and it also makes debugging the DSL code and inspecting intermediate results very hard, because it executes all the intermediate steps for every point in the domain instead of executing each operation for the whole domain.

We would like to address the current issues with a new embedded execution implementation which ideally would allow:

- fast Python execution, at the speed of NumPy, CuPy, JAX or similar ndarray libraries, without parsing or applying transformations to the DSL code written by the user.
- debugging user code by stepping through the application with regular Python debugging tools.

## Appetite

Full cycle with a team of 2 or more, as this feature is the basis for several improvements (see optional research tasks) and requires solid testing. If time is left, there are optional research tasks at the end of this document which could result in shaped projects for the next cycle.

## Solution

We propose to implement the embedded execution of DSL code as a lightweight interface on top of [data array](https://data-apis.org/array-api/latest/) libraries. The design should be simple and flexible enough to support additions and extensions interacting with the embedded execution (e.g. `grad` or `vmap` JAX-style function transformations). The current solution will reuse as much as possible the design ideas and knowledge gathered from previous experiments like Enrique's Embedded Field-View ([EEFV](https://github.com/egparedes/gt4py_embedded/tree/gtnextv2-dispatch)) or components used in other models like Till's implementation of index spaces for FVM ([TIIS](https://github.com/ckuehnlein/FVM_GT4Py/blob/main/src/fvms/utils/index_space.py)).

### Domain-related concepts

- **Dimension** might be conceptually defined as a set of (ordered) tagged coordinates representing a degree of freedom of a space.
- **Coordinates** are indices associated with a specific **Dimension**.
- **Domain** (or **CartesianSet**? in TIIS) is conceptually a set of Dimensions (that is, the order does NOT matter), although currently gt4py treats the domain as a cartesian product (tuple) of dimensions (order DOES matter).

In EEFV, `Dimension` was implemented as an actual class with a custom `DimensionType` metaclass to add extra sequence-like functionality. If the `Dimension` is bounded, it defines a valid `range` of coordinates. Subranges are easily represented as subclasses with a custom `__subclasscheck__` method. Another possibility is to define a `Dimension` as a simple `(tag, range)` tuple-like value and coordinates as `(tag, index)` values, as in TIIS. Both approaches should work fine (a sketch of the tuple-like variant follows below).
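As a rough illustration of the tuple-like (TIIS-style) option, a minimal sketch could look like the following. All names and signatures here (`UnitRange`, `Dimension`, `Coordinate`, `Domain`) are placeholders for discussion, not a final API.

```python
# Illustrative sketch only: TIIS-style value types for the concepts above.
from __future__ import annotations

import dataclasses
from typing import Optional


@dataclasses.dataclass(frozen=True)
class UnitRange:
    """Half-open range of coordinates [start, stop) with unit step."""

    start: int
    stop: int

    def __and__(self, other: UnitRange) -> UnitRange:
        """Intersection with another range (possibly empty)."""
        return UnitRange(max(self.start, other.start), min(self.stop, other.stop))

    def __len__(self) -> int:
        return max(0, self.stop - self.start)


@dataclasses.dataclass(frozen=True)
class Dimension:
    """A `(tag, range)`-like value; `range=None` means unbounded."""

    tag: str
    range: Optional[UnitRange] = None


# A coordinate is just a `(tag, index)` value.
Coordinate = tuple[str, int]

# Conceptually a set of dimensions; stored as a tuple to match the current
# cartesian-product view used by gt4py.
Domain = tuple[Dimension, ...]

I = Dimension("I", UnitRange(0, 10))
J = Dimension("J", UnitRange(0, 5))
domain: Domain = (I, J)
```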
### Field implementation

A **Field** is conceptually a profunctor supporting `fmap` and `contramap`/`lmap` operations. The `fmap` operation is not directly exposed since we only allow mapping a certain number of builtin (primitive) operations defined for numeric values (e.g. arithmetic ops, sin, cos, log, ...). Therefore, we can define the Field concept as a Protocol supporting:

- arithmetic operators
- supported math functions (the whole list of `gtx.builtins`)
- `remap()` (equivalent to `contramap`/`lmap`)
- `domain` property
- `buffer` (or `ndarray`) property to expose the content as a memory buffer
- a dispatching mechanism for the implementation of the other mathematical primitive operations
- optionally, expose `__array__`, `__cuda_array_interface__` and `__dlpack__` methods to interact with other libraries

`Connectivity` fields (fields of indices) should additionally define an `inverse_image()` function to support restricting the domain of the connectivity according to the domain of the remapped field (see [_remap_](https://hackmd.io/D2gZEUINQ3SVaylITZqEQA?both#remap-aka-unstructured-shift-implementation) and EEFV for details).

The actual implementation of the Field protocol using data array libraries could just wrap a data buffer and implement the required operations with calls to the data array library. To keep track of the mapping between the domain and the buffer data, it should implement:

- `domain`: describes the current domain of the field
- `ndarray`: a view of the buffer sliced with the current domain
- `extended_buffer`: wrapped data array
- `extended_domain`: the extended domain of the original buffer

<!--
This is not completely correct since shifting per se doesn't really shrink the domain, only the intersection with another field after shifting. Probably needs a bit more discussion.

Explanation:
- A newly constructed field will typically have domain == extended_domain after construction. Shifting in one direction will shrink the view_domain but keep the buffer_domain. This ensures that shifting in the opposite direction will again produce the original Field.
-->

### Operations

An `Operation` is a primitive builtin function operating on fields (and scalar values) provided by gt4py. To support advanced use cases (e.g. transformations) in the future, the implementation should use custom implementations of the operation if provided by the argument types. A pseudo-code example of how this idea could work:

```python!
from typing import Any, Callable, Generic, ParamSpec, TypeVar

P = ParamSpec("P")
R = TypeVar("R")
Value = Any  # placeholder: a Field or a scalar


class BuiltinOp(Generic[R, P]):
    default_impl: Callable[P, R]

    def __call__(self, *args: Value, **options: Any) -> Value:
        actual_impl = self.dispatch(*args)
        return actual_impl(*args, **options)

    def dispatch(self, *args: Value) -> Callable[P, R]:
        # The first argument type providing a custom implementation wins,
        # otherwise fall back to the default implementation.
        arg_types = tuple(type(arg) for arg in args)
        for atype in arg_types:
            if atype.has_impl(self):
                return atype.get_impl(self)
        return self.default_impl
```

Existing `gtx.builtins` will need to be adapted to this idea.

#### `remap` (aka unstructured shift) implementation

`remap` basically _contramaps_ a connectivity field on top of a regular data field.

```python!
def remap(
    data_field: Field[[B, C, D], scalar],
    connectivity: Field[[A], B],
) -> Field[[A, C, D], scalar]:
    ...
```

The algorithm can be divided into three main steps:

1. Compute the new dimension range: this only depends on the domain of the data field and the connectivity, and thus the result should be cached. Note that the remap operation is only valid if the new dimension range is contiguous.
2. Restrict the domain of the connectivity to the dimension range computed in the previous step.
3. `contramap` the restricted connectivity, which is equivalent to using NumPy _advanced indexing_ on the data field with the connectivity (also `np.take()`); see the sketch below.
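To make step 3 concrete, the core of the operation in plain NumPy (with a made-up vertex field and edge-to-vertex connectivity, and ignoring the Field/Domain bookkeeping of steps 1 and 2) would be:

```python
import numpy as np

# Hypothetical example data, ignoring the Field/Domain wrappers:
# a vertex field and an edge-to-vertex connectivity (2 neighbors per edge).
vertex_values = np.array([10.0, 11.0, 12.0, 13.0, 14.0])    # Field[[Vertex], float]
e2v = np.array([[0, 1], [1, 2], [2, 3], [3, 4], [4, 0]])     # Field[[Edge, E2VDim], Vertex]

# Step 3: contramapping the connectivity is just advanced indexing of the
# data buffer with the connectivity buffer (equivalently np.take).
remapped = vertex_values[e2v]                                # Field[[Edge, E2VDim], float]
assert remapped.shape == e2v.shape
assert np.array_equal(remapped, np.take(vertex_values, e2v))

# Steps 1 and 2 would additionally compute the (contiguous) range of vertex
# indices actually referenced by the connectivity, restrict the domains
# accordingly, and cache that result on the wrapping Field.
```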
#### _fmapping_ data operations (binary or ternary) with automatic domain intersection

A sketch of the algorithm could be:

1. Compute the intersection of the view_domains; let's call it `intersection_domain`.
2. Construct buffers corresponding to `intersection_domain`, i.e. slice each buffer with `intersection_domain` to get the `view_buffers`.
3. Compute `new_buffer` as the buffer (NumPy) operation applied to the `view_buffers` (e.g. binary operations or `where`).
4. Create a new field with buffer_domain == view_domain == `intersection_domain` and `new_buffer` as buffer.

### `slice` notation and semantics

Slicing a field (AKA `__getitem__`) returns a new field defined only on a subset of the original field but with exactly the same values. The restriction can be specified as an absolute or relative subdomain.

- Absolute:
  - `field_vk[K(-2):K(20)]` ?
  - `field_vk[V[0:-1], K[-2:20]]` ?
  - `field_vk.at[K(-2):K(20)]` ?
- Relative:
  - `field_vk[2:-2, 3:-3]` == `field_vk.at_dims(K, V)[2:-2, 3:-3]` ?

### Implement non-overlapping concat

Add a new builtin `concat` that takes Fields with non-overlapping (view-)domains and combines them. The requirement is that the union of the domains is again a hypercube. Due to the change in `where` semantics described in https://hackmd.io/fCXnShnFR96kFau7lw36ew#Open-field-view-questions, `concat` is the only way to implement boundary conditions. Therefore, user code has to be changed from `where` to `concat` in these cases. (Note: in ITIR the compute domain is given by the `domain` argument of the stencil_closure, therefore a lowered `where` will behave as before in unchecked gtfn execution. However, if we execute in embedded field view or enable (a hypothetical) domain size checking, the domain after `where` will be too small or even empty (if the boundary field is only defined on a slice), as it will take the intersection of all 3 arguments.)

Lowering `concat` to ITIR requires knowledge of the domain of each participating field, e.g.

```python!
concat(interior, boundary)
```

should be lowered to something like (simplified)

```python!
if_(interior_lower <= index < interior_upper,
    interior,
    if_(boundary_lower <= index < boundary_upper, boundary, NaN))
```

where `xxx_lower` and `xxx_upper` are the domain bounds (for simplicity here in 1D). However, these bounds are not accessible in ITIR; they would have to be added as extra (possibly symbolic) parameters. This might drastically complicate the implementation of `concat`, see rabbit holes.
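On the embedded side (as opposed to the ITIR lowering discussed in the rabbit holes below), the non-overlapping combination itself is simple. A 1D sketch in plain NumPy, where the hypothetical helper `concat_1d` and `(start, stop)` tuples stand in for the real Field/Domain machinery:

```python
import numpy as np


def concat_1d(a_range, a_values, b_range, b_values):
    """Combine two 1D 'fields' with non-overlapping index ranges.

    Ranges are half-open (start, stop) tuples standing in for field domains.
    """
    (a_start, a_stop), (b_start, b_stop) = a_range, b_range
    # Domains must not overlap ...
    if max(a_start, b_start) < min(a_stop, b_stop):
        raise ValueError("domains overlap")
    # ... and their union must again be a contiguous range (a 1D hypercube).
    start, stop = min(a_start, b_start), max(a_stop, b_stop)
    if (a_stop - a_start) + (b_stop - b_start) != stop - start:
        raise ValueError("union of domains is not contiguous")

    out = np.empty(stop - start, dtype=np.result_type(a_values, b_values))
    out[a_start - start : a_stop - start] = a_values
    out[b_start - start : b_stop - start] = b_values
    return (start, stop), out


# Interior defined on [1, 9), boundary on [0, 1): the union is [0, 9).
domain, values = concat_1d((1, 9), np.arange(1, 9), (0, 1), np.array([-1]))
assert domain == (0, 9) and values[0] == -1
```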
### Testing

- Add it as a new backend and run all tests in embedded NumPy mode (possibly other embedded backends, e.g. CuPy).
- Rewrite tests in `test_foast_to_itir.py` to first execute the field view program (`testee`) in embedded field view to compute the reference, then run the lowered ITIR program in itir.embedded.

<!--
- `k_top = field({K=K_MAX-1:K_MAX}, 0.)`
- `extend(x, TODO(K_MAX-1:K_MAX))` == `concatenate(x, k_top)`
- `x.at_dims(K)[K_MAX-1:K_MAX].set(k_top)`
- `concatenate(x.at_dims(K)[:KMAX-1], k_top[KMAX-1:KMAX], x.at_dims(K)[KMAX:])`
- lowers to `if_(K_MAX-1 <= k_index <= K_MAX, k_top, x)`
-->

## Rabbit holes

### `concat`

As described above, `concat` lowering might be complicated due to passing the domain argument. Therefore, in case no better solution is found in time, we will only allow the subset of `concat`s where at most 1 argument has non-absolute (or no) slicing. In the example from above, adding slicing

```python!
concat(interior, boundary[:D(1)])
```

implies that `interior` is used for `index(D) >= 1` (as `concat` is non-overlapping). Therefore we can lower directly to

```python!
if_(index < 1, boundary, interior)
```

## No-gos

- FieldIR is excluded.
- Don't support JAX jitting.

## Optional research

- `concat`:
  - full `concat` support: explore what is needed for a (clean!) support of `concat` without limitations in the lowering to ITIR
  - JAX `field.at(...).set(...)` style for combining fields with precedence
    - could be implemented as a library function with slicing and non-overlapping `concat`, but maybe for JAX compatibility we might want to make it a builtin
- Sketch the FieldIR
- How to implement ITIR embedded column mode with the embedded Field implementation (see https://github.com/GridTools/gt4py/pull/1283 and https://github.com/GridTools/gt4py/pull/1141 for problems in the current implementation -> ITIR column mode requires the same intersection semantics as field view)

## Tasks

- [x] Add minimal implementations of UnitRange and CartesianSet to gtx.common
  - [x] UnitRange (Sam)
  - [x] Domain and slicing
  - [ ] (refactor Domain based on CartesianSet)
- [x] Add skeletons of Field and BuiltinOp from Enrique's prototype to a new draft PR (Enrique)
- [ ] Implement gamma function
- [ ] Replace LocatedField in iterator.embedded with new Field implementation (Hannes)
- [x] Add `__gt_dims__` and `__gt_origin__` to Field protocol
- [x] In ffront.program execution, extract the domain from the new Field
- [x] With Enrique: discuss behavior of slicing (`__getitem__`)
- [ ] Discuss: implement a wrapper field that translates an N-dimensional array to N-m dimensions and makes m dimensions tuple dimensions
- [ ] Add implementations for all builtin ops currently supported in the field-view frontend
- [ ] Add support for Cartesian shift (Hannes)
- [ ] Add support for unstructured shift
- [ ] Add support for `where` with tuples
- [ ] Implement `broadcast` for scalars: requires a proper implementation of ConstantField
- [x] Implement intersection of domains for field operations
- [x] Implement slicing of fields
- [ ] Finalize specification of `concat()`
- [ ] Implement `concat()`
  - [ ] in embedded
  - [ ] in lowering to ITIR

## Discussions

### slicing

2 use cases:

- inside of the DSL
- outside

We want:

- for now, relative slicing with integer slices (`tuple[int | slice, ...]`), relative to the domain (not the buffer, but for ndarray it's the same)
- absolute slicing with `Region` aka `NamedIndicesOrRanges`
- note that `tuple[int | slice, ...]` includes `f[1:10]` (i.e. relative from the start or absolute, depending on the definition of absolute)

We postpone the decision of what `RegionLike` is, e.g. do we want something like `(I[1:10], J[2:4])` to construct a Region.

```python
NamedRange: TypeAlias = tuple[Dimension, UnitRange]
NamedIndex: TypeAlias = tuple[Dimension, int]
NamedIndicesOrRanges = Sequence[NamedRange | NamedIndex]  # find better name?
SliceLike = NamedIndicesOrRanges | tuple[slice | int, ...]


@overload
def __getitem__(self, i: NamedIndicesOrRanges) -> Field:
    """Absolute slicing with dimension names."""
    ...

@overload
def __getitem__(self, i: tuple[slice | int, ...]) -> Field:
    """Relative slicing with ordered dimension access."""
    ...

@overload
def __getitem__(self, i: tuple[int, ...]) -> Value:
    ...

@overload
def __getitem__(self, i: Sequence[NamedIndex]) -> Field | Value:
    # Value in case len(i) == len(self.domain)
    ...

def __getitem__(self, i: SliceLike) -> Field | Value:
    # actual implementation
    ...
```

#### examples

```python!
f: Field[[I:(1, 3), J:(2, 5)], tuple[i_idx, j_idx]]
f.ndarray == [[(1,2), (1,3), (1,4)], [(2,2), (2,3), (2,4)]]
f.shape == (2, 3)
```

##### absolute

```python
f[0, 0] == f[I[1], J[2]]
f[I[0], J[0]]  # out of bounds
```
##### relative pseudocode

```python!
fa: int = f[1, 0] == (2,2)

f0: Field[[J:(2,5)], ...] = f[1]
f0.ndarray == [(2,2), (2,3), (2,4)]
f0.shape == (3,)

f1: Field[[I:(1,3)], ...] = f[0:2, 1]
f1.ndarray == [(1,3), (2,3)]
f1.shape == (2,)

f2: Field[[I:(1,3)], ...] = f[:, 1] == f1
f3: int = f[-1, -1] == (2,4)

f4: Field[[I:(1,3)], ...] = f[..., 2]
f4.ndarray == [(1,4), (2,4)]

f5 = f[:] == f[...]  # TODO check numpy
f6 = f[1, :]
# etc...

f7: Field[[J:(3,4)], ...] = f[1, 1:-1]
f8: Field[[J:(3,4)], ...] = f[1, -2:-1]
```

### scratchpad

```python
class Dimension:
    def __getitem__(self, s: slice) -> NamedRange: ...

RegionLike = Region | tuple[NamedRange | NamedIndex, ...]
NamedSliceOrIndexLike
DomainConstructible =

class Domain:
    @classmethod
    def from(self, tuple[NamedRange, ...]): ...
    def from(self, dict[Dimension, UnitRange]): ...
    def from(self, **kwargs: Tuple[Dimension, UnitRange]): ...

field_vk.at(I=(0, 10))
field_vk.at(I[0:10])
# line below has the same problems but nicer syntax (therefore exclude this)
field_vk[I[0:10], J[0:10]]  # this would be nice

I(1)
I(0:10)  # unfortunately an error
I[1]  # reserved for offset f(I+1)
f[I[1:10]]

I[1]
I[0:10]
f(I[1])
f[I[1]]

I(slice(0, 10))

NamedRangeOrIndex =

def __getitem__(self, *args: NamedRangeOrIndex) -> Field:
    ...
```

```python!
f[i:j] == f[Domain(I:(f.domain[I].start + i, f.domain[I].start + j))]  # for i, j > 0
# something else for negative indices
```

might not be the same as `f.ndarray[i:j]`

#### ideas

- Absolute:
  - `field_vk[K(-2):K(20)]` ?
  - `field_vk[V[0:-1], K[-2:20]]` ?
  - `field_vk.at(K(-2):K(20))` ?
- Relative:
  - `field_vk[2:-2, 3:-3]` == `field_vk.at_dims(K, V)[2:-2, 3:-3]` ?

```python!
I = Dimension("I")

class Dimension:
    def __add__(self, step: int) -> Connectivity:
        return FieldOffset("I", source=self, target=(self,))[step]

inp(I + 1)
inp(I[1])

def lap(inp):
    return inp(Ioff[1])
    return inp(I + 1)  # remap inp by one
    # return inp(I[1])  # access inp at (absolute) index I=1

C2E = tuple[Connectivity, ...]
Khaldim_neighbors = tuple[K2Khalf[-1], K2Khalf[+1]]
lap_conn = tuple[I+1, I-1, J+1, J-1]

def lap(inp):
    neighbor_sum(inp(lap_conn), axis=implicit)

inp[:, C2E(0)]
C2E[0:10]  # -> NamedRange
inp(C2E[0])

v: Field[Vertex]
vv: Field[Vertex, V2VDim] = v(V2V)
vv(V2V)  # error: V2VDim twice
vvv: Field[Vertex, V2V, WrapDim[V2V, 1]] = vv.remap(V2V, newdim=WrapDim(V2VDim, 1))
```