# Tensor IR

###### tags: `IR` `Felix’s remains`

This document introduces _tensor IR_, a higher-level alternative or companion to _iterator IR_.

## Table Of Contents

[ToC]

## Motivation

The iterator IR proved to be a large step forward in terms of analyzability and transformability compared to previous (non-pure) stencil IRs. Nevertheless, it still has some weaknesses and lacks definitions in some areas. Some of those are:

- Domain size handling is largely undefined and hard to fix.
- Dependency analysis with nested lifts, shifts, etc. is hard.
- Thus, correct domain extent analysis is very hard, if not impossible.
- The transformation for popping up $\mathrm{lift}$ calls is extremely complex and requires handling of special cases for builtins.
- $\mathrm{can\_deref}$ of shifted lifted stencils is still not well understood.
- The combination of an untyped IR with type-dependent overloads sometimes makes analysis complex.
- Abstract offsets do not make sense: most transformations require knowledge of the offset types to work, thus this abstraction is not useful and complicates transformations.
- Handling of iterators of tuples vs. tuples of iterators is tricky and can restrict composability.
- Iterator IR is very low-level: iterators are basically an abstraction of pointers. Computations on pointers are very hard to optimize, and all modern optimization frameworks employ much higher abstraction levels where possible (MLIR, DaCe, …).
- Transformation to higher-level representations like DaCe is complex.
- Data-flow analysis that includes the _amount_ of data, which is important for any reasonable performance analysis, is very complex due to the single-element view.
- Vectorization and blocking are difficult (or impossible, because knowledge is missing in the IR) to implement, again due to the single-element view.
- The user's placement of $\mathrm{lift}$ calls influences where temporary buffers can be used and thus performance.
  This contradicts the idea of separating the concerns of performance and calculation. Note that more complex transformations could probably insert additional lifts, but, given the limitations listed above, it is hard to decide where this would make sense.

The proposed tensor-based IR tries to fix these issues. First, it is fully typed, does not include any abstract offsets but uses explicitly typed offsets, and gets rid of the complex and possibly optimization-limiting lifts. Further, it is much closer to modern tensor-based libraries such as Jax, PyTorch and DaCe. Last but not least, it is close to the field-view frontend, which is already the user frontend of GT4Py.

## Main Commonalities with and Differences to Iterator IR

| | Iterator IR | Tensor IR |
|---|---|---|
| Pure functional ‘stencils’ | Yes | Yes |
| Non-pure, procedural ‘fencils’ | Yes | Yes |
| Abstraction | Low-level, pointer-like | High-level, array-like |
| Typing | Untyped[^1] | Fully, explicitly typed |
| Basic Types | Iterators, Columns, Scalars, Tuples, Functions | Tensors, Functions |
| Special Handling of Columns | Required | None |
| Position[^2] | Implicit, through positional iterators | Explicit |
| Lifting | Required | Not required |
| Tmp. Granularity[^3] | Requires user-placed _lifts_ | Any expression |
| Tmp. Complexity[^4] | Requires complex transforms and analysis | Simple |
| Offsets | Abstract, not part of IR | Explicit |
| Tuple Handling | Ambiguous (tuple of iterators vs. iterator of tuples) | Unique |
| Data Flow | Hard to track (lifts, derefs, unknown data size) | Trivial to track |
| Domain specification | Unclear, implicit | Part of IR, explicit |

[^1]: There has always been the notion of ‘iterators’ and ‘values’, but no formal definitions were available in the beginning of iterator IR, and the current GT4Py implementation is based on untyped lambda calculus.
[^2]: ‘Position’ refers to the n-dimensional index where a computation happens.
[^3]: This refers to the granularity of computations whose results can be placed into temporary data buffers (without requiring additional transformations).
[^4]: Iterator IR requires a very complex transformation (‘PopupTmps’ in GT4Py) as a first step for storing results of lifted functions into temporary buffers. Further, the analysis for finding the appropriate buffer sizes and domain sizes is again pretty complex.

## Rationale

This section justifies all design changes compared to iterator IR. Well-discussed topics, like using pure functional computations, are not revisited here.

### Pointer-Level Abstraction vs. Tensor-Level Abstraction

Iterator IR is heavily influenced by GridTools C++ and its SID concept. SID is an abstraction of pointers and pointer arithmetic for n-dimensional strided data. This is well suited for implementing low-level array-based computations and has proven to be a useful abstraction on the GridTools C++ level. Iterators in the iterator IR are almost identical to SIDs.

The main advantage of iterator IR’s iterators over SID is the possibility to ‘lift’ stencil functions so that they return iterators instead of values. This allows composing multiple stencils. Lifts can be replaced by inline computations or by computations into temporary buffers. The former is a trivial transformation, but the latter is very complex and additionally requires analysis of dependencies and domain sizes. This complexity comes from two sources: first, from the requirement of lifting a value-based computation to a pointer-based one, and second, from the inherent complexity of pointer-based data-flow analysis.

Recent developments in machine learning brought much attention to tensor-based computations, including stencil-like patterns in convolutional neural networks.
All modern machine learning frameworks and standards like [TensorFlow](https://www.tensorflow.org/), [PyTorch](https://pytorch.org/), [Jax](https://github.com/google/jax), [array API](https://data-apis.org/array-api), and lower-level tools like [XLA](https://www.tensorflow.org/xla) and [MLIR](https://mlir.llvm.org/) use an array/tensor-based abstraction; none uses a pointer-based one. All these tools try to raise the abstraction level compared to classical compilers, as this is required to apply domain-specific optimizations. [DaCe](https://github.com/spcl/dace), which is not solely focused on machine learning, also chose the same route.

Further, proper domain size specification and dependency analysis are big unsolved problems in iterator IR, yet the main problems to solve in the future, like automatic halo exchanges and Cartesian boundary conditions in combination with temporaries, require a solution to them.

Last but not least, on the array-based abstraction level, no lifts and derefs are required. This simplifies the transformations required to introduce global temporary buffers as well as various analysis passes, and greatly simplifies the complex handling of $\mathrm{can\_deref}$.

### Typing

Iterator IR is based on untyped lambda calculus; the type system was developed at a later stage. Nevertheless, the IR included function overloads that require argument types to be resolved correctly. From the beginning, it was clear that we require iterators, scalars, columns and functions. But the types were never well-defined; for example, it is unclear whether the location type of an iterator should be part of its type, or what exactly a column is and how its size is defined (fixed and part of the type? dynamic?). This leads to additional difficulties, for example: where should the scan function get the column size from? Do we support columns of different sizes in a single stencil or fencil?

Time has shown that many analysis passes require the type of function arguments.
Currently, we often play some tricks to guess the type or just assume something. For example, our reduction unrolling pass requires a partial shift to be syntactically present, even though in theory it could work on any partially shifted field. Further, types encoded in the IR give additional safety and lead to earlier error messages for users. The tensor computation frameworks mentioned earlier also use typed intermediate representations. Last but not least, specification of the compute domain size and extent analysis have posed major problems in iterator IR that remain largely unsolved.

Thus, in tensor IR, all types are well-defined and all expressions are explicitly typed. Additionally, the size of a tensor is part of its type. The sizes of the involved tensors unambiguously define the length of the scan axis; different column heights are easily supported within a single stencil or fencil.

### Abstract Offsets vs. Explicit Offsets

On the one hand, in iterator IR, offset types are ‘abstract’, that is, not part of the IR. Offsets are only understood by an ‘offset provider’. On the other hand, paradoxically, the IR supports partial and full (unstructured) shifts, but it is impossible to decide within the IR (that is, without the offset provider) which is which, as this would require knowing the types of the offsets. Many transformations also require explicit knowledge of the offset types (to distinguish Cartesian from unstructured shifts, or to find partial shifts). Thus, in the current implementation, the offset provider is a required argument to many transformations. Further, this abstraction complicates the fencil interface design in the backends, as it is unclear how these abstract dimensions and neighbor lookup tables should be handled.
Also, while the ‘uniform’ handling of Cartesian and unstructured shifts seems attractive at first sight, it does not make much sense on closer inspection: the two require totally different dependency analyses; for neighbor reductions, we need partial unstructured shifts, but partial Cartesian shifts are not needed and do not make sense, as we cannot reduce over the infinite set of Cartesian neighbors. So a distinction by either type or function name (or both) should be encoded in the IR.

Thus, in tensor IR, all offsets are part of the IR. Cartesian shifts require a dimension name and an integer literal offset; unstructured shifts require a neighbor lookup table in the form of a tensor and an integer literal offset (or only the table, for partial shifts). Note that Cartesian shifts with dynamic/run-time offsets can be modeled in tensor IR by unstructured shifts with dynamically computed neighbor tables.

### Tuple Handling

Another area where confusion arises over and over again is the handling of iterators of tuples vs. tuples of iterators. The two are very similar, but require a different access sequence ($\mathrm{deref}$ then $\mathrm{tuple\_get}$ vs. $\mathrm{tuple\_get}$ then $\mathrm{deref}$) and are thus incompatible. It is possible to convert between them within iterator IR, but this requires additional lifted stencils, leading to complex dependency analysis (and tricky lowering passes from higher-level abstractions if this is to be done automatically).

Thus, tensor IR does not support general tuples, but only supports tuple types as element types within tensors. That is, it eliminates one of the two incompatible views. This knowingly disallows some constructs that are theoretically allowed in iterator IR (but not in the current type checker), for example tuples of functions. But such a use case has not appeared in any of the stencil examples or stencils ported from ICON, so this feature is considered superfluous.

### Implicit vs. Explicit Position

In iterator IR, an iterator does not know its ‘global’ position. For computations that require a position, positional iterators can be used; that is, iterators that just store an index along one dimension. But for some computations supported by iterator IR, namely unstructured shifts, a global index is required, as it specifies the lookup location. Further, an explicit position makes some computations easier to understand, as it clarifies which values belong together, and it significantly simplifies dependency analysis and domain computations.

Thus, tensor IR defines dimensions and valid coordinate intervals for all tensors, similar to [xarray](https://xarray.dev). This enables intuitive broadcasting rules and allows handling neighbor tables as normal tensors, as all required metadata can be encoded in the dimension names and intervals.

## Data Types

### The Tensor Type

Tensor type: $\mathrm{Tensor}[T, D…]$, where $T$ is the element type (e.g., float), and $D…$ are any number of dimension definitions. A dimension definition is of the form $name: [start\,..\,stop]$, where $name$ is the name of the dimension and $[start\,..\,stop]$ defines an integer interval (on which this tensor provides values).

### The Function Type

The function type is of the form $(\mathrm{A}, …) → R$, representing a function mapping from one or more argument values to a single return value.

### The Tensor Element Type

The tensor element type can be either a scalar element type (e.g., $\mathrm{float64}$) or a possibly nested tuple element type (e.g., $\mathrm{(float32, int64)}$). The backend is free to choose the data layout of tensors with tuple elements during lowering (e.g., tuple of tensors, tensor of tuples, mixed).

### Literals

For function parametrization, literal types are required. Alternatively, the builtin functions that take such literal parameters could also be pushed to the grammar level.
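
To make the type definitions above concrete, here is a minimal Python sketch of how they could be represented inside a compiler. The class names `Interval`, `TensorType` and `FunctionType` are hypothetical (they are not taken from the linked prototype), and intervals are assumed to be half-open, matching the `IDim[0:5]` notation of the code examples in this document.

```python
from __future__ import annotations

from dataclasses import dataclass
from typing import Tuple, Union

# A scalar element type is just its name (e.g. "float64"); a tuple element
# type is a (possibly nested) Python tuple of element types.
ElementType = Union[str, Tuple["ElementType", ...]]


@dataclass(frozen=True)
class Interval:
    start: int
    stop: int  # assumed exclusive, as in the IDim[0:5] notation


@dataclass(frozen=True)
class TensorType:
    element: ElementType
    # Ordered (name, interval) pairs; a zero-dimensional tensor has none.
    dims: Tuple[Tuple[str, Interval], ...] = ()

    def __str__(self) -> str:
        parts = [_fmt_element(self.element)]
        parts += [f"{name}[{iv.start}:{iv.stop}]" for name, iv in self.dims]
        return "tensor<" + ", ".join(parts) + ">"


@dataclass(frozen=True)
class FunctionType:
    args: Tuple[Union[TensorType, FunctionType], ...]
    result: Union[TensorType, FunctionType]

    def __str__(self) -> str:
        return "(" + ", ".join(str(a) for a in self.args) + ") → " + str(self.result)


def _fmt_element(e: ElementType) -> str:
    # Render nested tuple element types recursively, e.g. (float32, int64).
    if isinstance(e, tuple):
        return "(" + ", ".join(_fmt_element(x) for x in e) + ")"
    return e


if __name__ == "__main__":
    vertex = ("Vertex", Interval(0, 5440))
    pp = TensorType("float", (vertex,))
    out = TensorType(("float", "float"), (vertex,))
    print(pp)                 # tensor<float, Vertex[0:5440]>
    print(out)                # tensor<(float, float), Vertex[0:5440]>
    print(TensorType("int"))  # tensor<int>
    print(FunctionType((pp, pp), out))
```

Frozen dataclasses make types hashable and comparable by value, which is convenient when types are used as keys during type checking; a real implementation might instead store dimensions in a mapping.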
## High-Level Description and Examples of Tensor IR

The main data type of tensor IR is the parametric tensor as defined above. Each tensor has an element type and a set of any number of dimensions, each with a finite extent. For example, $\mathrm{Tensor}\bigl[\mathrm{int}, \{\mathrm{x}: [-3\,..\,5]\}\bigr]$ is a one-dimensional tensor, defined on a dimension $\mathrm{x}$ in the interval $[-3\,..\,5]$.

Broadcasting rules follow the idea that resulting tensors are defined on the union of all dimensions, but only on the intersection of the respective intervals. For example, if we take the tensor from above and multiply it with a 2D tensor of type $\mathrm{Tensor}\bigl[\mathrm{int}, \{\mathrm{x}: [1\,..\,9], \mathrm{y}: [5\,..\,8]\}\bigr]$, the resulting 2D tensor will have the type $\mathrm{Tensor}\bigl[\mathrm{int}, \{\mathrm{x}: [1\,..\,5], \mathrm{y}: [5\,..\,8]\}\bigr]$. Note that taking the union of all dimensions is compatible with the SID/iterator IR view, where shifting along an undefined axis is a no-op, while the intersection of intervals makes sure that we do not compute on undefined values.

Cartesian shifts simply alter the interval of the given dimension. Unstructured shifts perform an indirect lookup and also modify the dimensions. Neighbor tables are normal tensors with specially tagged dimensions. See the following section for a detailed description of all builtins.

Here is a basic Laplacian fencil as a first example. Note the similarity to iterator IR, but all lifts and derefs are gone.
Further, the argument types are annotated:

```
laplacian(
    out: tensor<float, IDim[0:5], JDim[0:7], KDim[0:9]>,
    inp: tensor<float, IDim[-1:6], JDim[-1:8], KDim[0:9]>
) {
    out ← (
        λ(inp: tensor<float, IDim[-1:6], JDim[-1:8], KDim[0:9]>) →
            ⟪iₒ, -1ₒ, iₒ, 1ₒ, iₒ, -1ₒ⟫(inp) - ⟪iₒ, -1ₒ, iₒ, 1ₒ⟫(inp)
            - (⟪iₒ, 1ₒ, iₒ, -1ₒ⟫(inp) - ⟪iₒ, 1ₒ⟫(inp))
            + (⟪jₒ, -1ₒ, jₒ, 1ₒ, jₒ, -1ₒ⟫(inp) - ⟪jₒ, -1ₒ, jₒ, 1ₒ⟫(inp)
               - (⟪jₒ, 1ₒ, jₒ, -1ₒ⟫(inp) - ⟪jₒ, 1ₒ⟫(inp)))
    )(inp);
}
```

The following shows the classic tridiagonal solver. Note that there is no special column type; everything just works on tensors:

```
solve_tridiag(
    a: tensor<float, IDim[0:3], JDim[0:7], KDim[0:5]>,
    b: tensor<float, IDim[0:3], JDim[0:7], KDim[0:5]>,
    c: tensor<float, IDim[0:3], JDim[0:7], KDim[0:5]>,
    d: tensor<float, IDim[0:3], JDim[0:7], KDim[0:5]>,
    x: tensor<float, IDim[0:3], JDim[0:7], KDim[0:5]>
) {
    x ← (λ(
        a: tensor<float, IDim[0:3], JDim[0:7], KDim[0:5]>,
        b: tensor<float, IDim[0:3], JDim[0:7], KDim[0:5]>,
        c: tensor<float, IDim[0:3], JDim[0:7], KDim[0:5]>,
        d: tensor<float, IDim[0:3], JDim[0:7], KDim[0:5]>
    ) → __scan(
        λ(x_kp1: tensor<float, IDim[0:3], JDim[0:7]>, cpdp: tensor<(float, float), IDim[0:3], JDim[0:7]>) →
            cpdp[1] - cpdp[0] × x_kp1,
        False,
        0.0
    )(__scan(
        λ(
            state: tensor<(float, float), IDim[0:3], JDim[0:7]>,
            a: tensor<float, IDim[0:3], JDim[0:7]>,
            b: tensor<float, IDim[0:3], JDim[0:7]>,
            c: tensor<float, IDim[0:3], JDim[0:7]>,
            d: tensor<float, IDim[0:3], JDim[0:7]>
        ) → __make_tuple(c / (b - a × state[0]), (d - a × state[1]) / (b - a × state[0])),
        True,
        __make_tuple(0.0, 0.0)
    )(a, b, c, d)))(a, b, c, d);
}
```

In the following, the FVM nabla fencil is shown.
Note how all lifts and derefs are gone, and how neighbor lookup tables are passed explicitly as normal tensors:

```
nabla(
    n_nodes: tensor<int>,
    out: tensor<(float, float), Vertex[0:5440]>,
    pp: tensor<float, Vertex[0:5440]>,
    S_MXX: tensor<float, Edge[0:16167]>,
    S_MYY: tensor<float, Edge[0:16167]>,
    sign: tensor<float, Vertex[0:5440], _NB_0[0:7]>,
    vol: tensor<float, Vertex[0:5440]>,
    E2V: tensor<int, Edge[0:16167], _NB_Vertex[0:2]>,
    V2E: tensor<int, Vertex[0:5440], _NB_Edge[0:7]>
) {
    out ← (
        λ(
            pp: tensor<float, Vertex[0:5440]>,
            S_MXX: tensor<float, Edge[0:16167]>,
            S_MYY: tensor<float, Edge[0:16167]>,
            sign: tensor<float, Vertex[0:5440], _NB_0[0:7]>,
            vol: tensor<float, Vertex[0:5440]>
        ) → __make_tuple(
            __reduce(λ(
                acc: tensor<float, Vertex[0:5440]>,
                a_n: tensor<float, Vertex[0:5440]>,
                c_n: tensor<float, Vertex[0:5440]>
            ) → acc + a_n × c_n, 0.0)(
                ⟪V2E⟫((λ(pp_: tensor<float, Vertex[0:5440]>, S_M: tensor<float, Edge[0:16167]>) →
                    S_M × (0.5 × (⟪E2V, 0ₒ⟫(pp_) + ⟪E2V, 1ₒ⟫(pp_))))(pp, S_MXX)),
                sign
            ) / vol,
            __reduce(λ(
                acc: tensor<float, Vertex[0:5440]>,
                a_n: tensor<float, Vertex[0:5440]>,
                c_n: tensor<float, Vertex[0:5440]>
            ) → acc + a_n × c_n, 0.0)(
                ⟪V2E⟫((λ(pp_: tensor<float, Vertex[0:5440]>, S_M: tensor<float, Edge[0:16167]>) →
                    S_M × (0.5 × (⟪E2V, 0ₒ⟫(pp_) + ⟪E2V, 1ₒ⟫(pp_))))(pp, S_MYY)),
                sign
            ) / vol
        )
    )(pp, S_MXX, S_MYY, sign, vol);
}
```

## Builtins

This section describes all tensor IR builtin functions and compares them to the respective definitions in iterator IR. To describe those builtins, we first introduce some operations on intervals and dimensions. First, we use the intersection of intervals:

$$
[start_1\,..\,stop_1] \cap [start_2\,..\,stop_2] = [\mathrm{max}(start_1, start_2)\,..\,\mathrm{min}(stop_1, stop_2)].
$$

Then, we define the following recursively defined broadcasting function:

$$
\mathrm{bcast}(x) = \begin{cases}
\mathrm{bcast}(\{name: i_1 \cap i_2\} \cup r) & \text{if } x = \{name: i_1, name: i_2\} \cup r, \\
x & \text{otherwise.}
\end{cases}
$$

That is, broadcasting keeps all available dimension names and takes the intersection of all intervals defined per dimension.

### Mathematical & Logical Builtins

All mathematical builtins work by broadcasting the input dimensions. The signatures are as follows:

$$
\begin{align}
\mathrm{and} &:: (\mathrm{Tensor}[\mathrm{bool}, D_1], \mathrm{Tensor}[\mathrm{bool}, D_2]) → \mathrm{Tensor}[\mathrm{bool}, \mathrm{bcast}(D_1 \cup D_2)], \\
\mathrm{divides} &:: (\mathrm{Tensor}[T, D_1], \mathrm{Tensor}[T, D_2]) → \mathrm{Tensor}[T, \mathrm{bcast}(D_1 \cup D_2)], \\
\mathrm{eq} &:: (\mathrm{Tensor}[T, D_1], \mathrm{Tensor}[T, D_2]) → \mathrm{Tensor}[\mathrm{bool}, \mathrm{bcast}(D_1 \cup D_2)], \\
\mathrm{greater} &:: (\mathrm{Tensor}[T, D_1], \mathrm{Tensor}[T, D_2]) → \mathrm{Tensor}[\mathrm{bool}, \mathrm{bcast}(D_1 \cup D_2)], \\
\mathrm{if} &:: (\mathrm{Tensor}[\mathrm{bool}, D_1], \mathrm{Tensor}[T, D_2], \mathrm{Tensor}[T, D_3]) → \mathrm{Tensor}[T, \mathrm{bcast}(D_1 \cup D_2 \cup D_3)], \\
\mathrm{less} &:: (\mathrm{Tensor}[T, D_1], \mathrm{Tensor}[T, D_2]) → \mathrm{Tensor}[\mathrm{bool}, \mathrm{bcast}(D_1 \cup D_2)], \\
\mathrm{minus} &:: (\mathrm{Tensor}[T, D_1], \mathrm{Tensor}[T, D_2]) → \mathrm{Tensor}[T, \mathrm{bcast}(D_1 \cup D_2)], \\
\mathrm{multiplies} &:: (\mathrm{Tensor}[T, D_1], \mathrm{Tensor}[T, D_2]) → \mathrm{Tensor}[T, \mathrm{bcast}(D_1 \cup D_2)], \\
\mathrm{not} &:: (\mathrm{Tensor}[\mathrm{bool}, D_1]) → \mathrm{Tensor}[\mathrm{bool}, \mathrm{bcast}(D_1)], \\
\mathrm{or} &:: (\mathrm{Tensor}[\mathrm{bool}, D_1], \mathrm{Tensor}[\mathrm{bool}, D_2]) → \mathrm{Tensor}[\mathrm{bool}, \mathrm{bcast}(D_1 \cup D_2)], \\
\mathrm{plus} &:: (\mathrm{Tensor}[T, D_1], \mathrm{Tensor}[T, D_2]) →
\mathrm{Tensor}[T, \mathrm{bcast}(D_1 \cup D_2)].
\end{align}
$$

### Lift and Deref

The iterator IR’s $\mathrm{lift}$ and $\mathrm{deref}$ builtins do not have any corresponding function in tensor IR and can be replaced by no-ops.

### Tuples

The builtins $\mathrm{make\_tuple}$ and $\mathrm{tuple\_get}$ create and access, respectively, tensors with tuple element types. They have the following signatures:

$$
\begin{align}
\mathrm{make\_tuple} &:: (\mathrm{Tensor}[T_1, D_1], \mathrm{Tensor}[T_2, D_2], …) → \mathrm{Tensor}[(T_1, T_2, …), \mathrm{bcast}(D_1 \cup D_2 \cup …)], \\
\mathrm{tuple\_get} &:: (\mathrm{Tensor}[(…, T_i, …), D_1], i) → \mathrm{Tensor}[T_i, D_1].
\end{align}
$$

### Cartesian Shifts

Cartesian shifts require a dimension name and an integer literal. Cartesian shifts change the interval of the shifted tensor (but not the data).

$$
\begin{align}
\mathrm{cartesian\_shift} ::\; & (name, i) → \\
&(\mathrm{Tensor}[T, \{name: [start\,..\,stop], …\}]) → \mathrm{Tensor}[T, \{name: [start + i\,..\,stop + i], …\}]
\end{align}
$$

Note: there is no need to support run-time offsets; for this case, unstructured shifts with computed neighbor tables can be used.

### Unstructured Shifts

There are two functions for unstructured shifts: the first one initiates the shift (shifting to all neighbors, also known as a partial shift), the second one applies the shift (selecting one specific neighbor). The functions have the following type signatures:

$$
\begin{align}
\mathrm{shift\_init} ::\; & (\mathrm{Tensor}[\mathrm{int}, \{dst: i_{dst}, \mathrm{NB\_}src: [0\,..\,n]\}]) →\\
&(\mathrm{Tensor}[T, \{src: i_{src}, …\}]) → \mathrm{Tensor}[T, \{dst: i_{dst}, \mathrm{NB}_x: [0\,..\,n], …\}], \\
\mathrm{shift\_apply} ::\; & (j) → \\
&(\mathrm{Tensor}[T, \{dst: i_{dst}, \mathrm{NB}_x: [0\,..\,n], …\}]) → \mathrm{Tensor}[T, \{dst: i_{dst}, …\}].
\end{align}
$$

The full shift is just the combination of these functions, that is:

$$
\mathrm{unstructured\_shift}(table, j)(x) = \mathrm{shift\_apply}(j)(\mathrm{shift\_init}(table)(x)).
$$

The signature of this function is the following:

$$
\begin{align}
\mathrm{unstructured\_shift} ::\; & (\mathrm{Tensor}[\mathrm{int}, \{dst: i_{dst}, \mathrm{NB\_}src: [0\,..\,n]\}], j) →\\
&(\mathrm{Tensor}[T, \{src: i_{src}, …\}]) → \mathrm{Tensor}[T, \{dst: i_{dst}, …\}].
\end{align}
$$

There are multiple important points to note here:

- Neighbor tables are ordinary tensors with integer elements.
- The neighbor dimension of the neighbor table tensor is named after the $src$ dimension and tagged with the special prefix $\mathrm{NB\_}$, that is, $\mathrm{NB\_}src$. This tag marks the correspondence between the neighbor table’s neighbor dimension and the $src$ dimension.
- The $\mathrm{shift\_init}$ function creates a tensor with a new neighbor dimension, $\mathrm{NB}_x$, where $x$ is a counter. Each additional $\mathrm{shift\_init}$ increases this counter. That is, after $n$ partial shifts, there will be $n$ sequentially numbered neighbor dimensions.
- The $\mathrm{shift\_apply}$ function removes the neighbor dimension _with the highest index_. This ensures correct ordering of nested reductions with multiple partial shifts.

### General Shifts

Iterator IR’s general variadic shift function with overloads for Cartesian and unstructured shifts can easily be implemented by dispatching the shifts to one of the specific Cartesian or unstructured functions described above.

### Can_deref

On the one hand, $\mathrm{can\_deref}$ is no longer strictly required, as neighbor table entries can now be queried explicitly, since neighbor tables are just tensors. On the other hand, tracking multiple partial shifts manually is cumbersome and error-prone, so it probably makes sense to keep it. The signature is:

$$
\mathrm{can\_deref}::(\mathrm{Tensor}[T, D]) → \mathrm{Tensor}[\mathrm{bool}, D]
$$

Note that there is no definition of a special ‘masked’ tensor type.
Implementations are free to decide how masking is implemented. For performance reasons, it probably makes sense to define a single value per data type as masked, e.g., a NaN value for floats and the most negative value for signed integers. Note that the IEEE 754 standard explicitly defines a ‘payload’ for NaN values, which can be used to store additional information (‘NaN-boxing’, https://en.wikipedia.org/wiki/IEEE_754#Representation%20and%20encoding%20in%20memory). But of course, a tensor could also have an additional (optional) reference to a boolean array with the mask values. Renaming the function would also make sense, as there is no $\mathrm{deref}$ anymore…

### Reduce

The definition of $\mathrm{reduce}$ is very similar to the one in iterator IR. The dimension of the reduction is the neighbor dimension with the highest index. The signature looks as follows:

$$
\begin{align}
\mathrm{reduce} ::\; & ((\mathrm{Tensor}[T_r, D_r], \mathrm{Tensor}[T_1, D_1 \setminus \mathrm{NB}_x], …) → \mathrm{Tensor}[T_r, D_r], init) → \\
& (\mathrm{Tensor}[T_1, D_1], …) → \mathrm{Tensor}[T_r, D_r],
\end{align}
$$

where $D_r = \mathrm{bcast}((D_1 \setminus \mathrm{NB}_x) \cup …)$, that is, the broadcast of all arguments’ dimensions without the reduced neighbor dimension $\mathrm{NB}_x$, and $init$ is a scalar (non-tensor) literal of type $T_r$.

Some notes:

- The implementation of sparse fields as pre-shifted tensors becomes trivial.
- A generalization to arbitrary reduction dimensions would be trivial.

This specification tries to follow the corresponding definition of $\mathrm{reduce}$ in iterator IR as closely as possible.
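
The semantics of the unstructured shift builtins and $\mathrm{reduce}$ can be illustrated with a small value-level sketch in Python. It assumes that a tensor over a single location dimension is a plain list indexed by location, that a neighbor table is a list of rows of source indices (one row per destination location), and that there are no masked (missing) neighbors. The function names mirror the builtins but are hypothetical, and the triangle mesh and `sign` values are made up for illustration. The most recently added neighbor axis ($\mathrm{NB}_x$ with the highest $x$) is stored directly below the location axis, so both $\mathrm{shift\_apply}$ and $\mathrm{reduce}$ act on it.

```python
from typing import Any, Callable, List, Sequence

Table = Sequence[Sequence[int]]  # neighbor table: table[dst][nb] -> src index


def shift_init(table: Table) -> Callable[[Sequence[Any]], List[List[Any]]]:
    """Partial shift: gather all neighbors of each destination location.

    The new neighbor axis is placed directly below the location axis; any
    neighbor axes the input already had stay nested inside the elements.
    """
    def apply(x: Sequence[Any]) -> List[List[Any]]:
        return [[x[src] for src in row] for row in table]
    return apply


def shift_apply(j: int) -> Callable[[Sequence[Sequence[Any]]], List[Any]]:
    """Select neighbor j, removing the most recently added neighbor axis."""
    def apply(x: Sequence[Sequence[Any]]) -> List[Any]:
        return [nbs[j] for nbs in x]
    return apply


def unstructured_shift(table: Table, j: int) -> Callable[[Sequence[Any]], List[Any]]:
    # Full shift = shift_apply(j) applied after shift_init(table).
    return lambda x: shift_apply(j)(shift_init(table)(x))


def reduce(fun: Callable[..., Any], init: Any) -> Callable[..., List[Any]]:
    """Fold `fun` over the most recently added neighbor axis of all arguments."""
    def apply(*args: Sequence[Sequence[Any]]) -> List[Any]:
        out = []
        for per_location in zip(*args):           # one entry per location
            acc = init
            for nb_values in zip(*per_location):  # one entry per neighbor
                acc = fun(acc, *nb_values)
            out.append(acc)
        return out
    return apply


# Made-up triangle mesh: 3 vertices, 3 edges (e0: 0→1, e1: 1→2, e2: 2→0).
E2V = [[0, 1], [1, 2], [2, 0]]
V2E = [[0, 2], [0, 1], [1, 2]]
pp = [1.0, 2.0, 4.0]
S_M = [1.0, 1.0, 1.0]
sign = [[1.0, -1.0], [-1.0, 1.0], [-1.0, 1.0]]  # +1 if the edge starts at the vertex

# Edge values, as in the nabla example: S_M × 0.5 × (pp at both edge ends).
p0 = unstructured_shift(E2V, 0)(pp)
p1 = unstructured_shift(E2V, 1)(pp)
edge = [s * 0.5 * (a + b) for s, a, b in zip(S_M, p0, p1)]

# Vertex values: reduce over the V2E neighbors, weighted by sign.
flux = reduce(lambda acc, a_n, c_n: acc + a_n * c_n, 0.0)(shift_init(V2E)(edge), sign)
print(flux)  # [-1.0, 1.5, -0.5]
```

The final division by `vol` in the nabla example is an ordinary element-wise operation and is omitted here. Because each `shift_init` nests the previous neighbor axes inside the new one, `shift_apply(0)(shift_init(V2E)(shift_init(E2V)(pp)))` selects along the V2E axis and keeps the E2V axis, matching the "highest index first" rule.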
### Scan

The signature of $\mathrm{scan}$ is again very similar to iterator IR:

$$
\begin{align}
\mathrm{scan} ::\; & ((\mathrm{Tensor}[T_r, D_{r} \setminus D_s], \mathrm{Tensor}[T_1, D_1 \setminus D_s], …) → \mathrm{Tensor}[T_r, D_r \setminus D_s], forward, init) → \\
& (\mathrm{Tensor}[T_1, D_1], …) → \mathrm{Tensor}[T_r, D_r],
\end{align}
$$

where $D_s$ is the dimension to scan over, $D_r = \mathrm{bcast}(D_1 \cup …)$, $forward$ is a boolean literal that defines the scan direction, and $init$ is a scalar (non-tensor) literal of type $T_r$.

### Builtins for Domain Specifications

No builtins are required for domain specifications (at the fencil level), as all domains are explicitly defined by the tensor sizes. But, to support assignments to a subset of an output tensor or to select only a subset of an input tensor at the fencil level, an additional builtin, here called $\mathrm{subset}$, can be introduced:

$$
\mathrm{subset} :: (\mathrm{Tensor}[T, D_1], D_2) → \mathrm{Tensor}[T, D_2],
$$

where for the sets of dimensions $D_1$ and $D_2$, the following holds: $D_2 = \{name: i \;|\; name: j \in D_1 \;\land\; i \subseteq j \}$. That is, $D_2$ only contains subintervals of dimensions also defined in $D_1$.

Further, to enable straightforward extent/dependency analysis for Cartesian boundary conditions, a concatenation function could make sense. That is, something like the following:

$$
\mathrm{concat} :: (\mathrm{Tensor}[T, \{name: i_1\} \cup D], \mathrm{Tensor}[T, \{name: i_2\} \cup D], …) → \mathrm{Tensor}[T, \{name: i_1 \cup i_2 \cup …\} \cup D],
$$

where $i_1, i_2, …$ are disjoint, adjacent intervals (so that their union is again an interval).

### New Builtin for Accessing Dimensions and Intervals

Instead of using positional tensors (which is easy to do), tensor IR could also provide a builtin for directly extracting the interval of a dimension of an arbitrary expression.
That is, something like the following:

$$
\mathrm{pos} :: (name, \mathrm{Tensor}[T, \{name: i\} \cup D]) → \mathrm{Tensor}[\mathrm{int}, \{name: i\}].
$$

### New Builtin for Manually Adding New Dimensions

For creating neighbor tables on the fly (and maybe for other applications), it could make sense to include a builtin function that allows for manual broadcasting, that is, manually specifying new dimensions:

$$
\mathrm{add\_dim} :: (name, start, stop, \mathrm{Tensor}[T, D]) → \mathrm{Tensor}[T, \{name: [start\,..\,stop]\} \cup D]
$$

In combination with the $\mathrm{pos}$ builtin above, this could easily be used for arbitrary neighbor lookups along a single axis (for example, vertical indirection).

## Notes on Domain Sizes

The above specification is based on compile-time domain sizes. In my (Felix Thaler’s) opinion, this is the only reasonable choice for the following reasons:

- For fully correct dependency analysis, the domain size _must_ be known as soon as boundary conditions come into play.
- For distributed applications, a large part of the structure of the domain decomposition _has_ to be known at compile time (otherwise we would have to assume that neighbor tables could be randomly shuffled).
- Analysis of run-time unstructured domains or mixed run-time/compile-time domains is very hard.
- We do weather and climate simulations. That is, we run on the same domain over and over again, and the domain is constant at run time. There might be a couple of grids, but daily simulations or long-running climate simulations might profit heavily from compile-time knowledge of the domain.

Nevertheless, the presented model could also support dynamic tensor sizes by introducing symbolic variables and expressions for the interval bounds, just like DaCe does for dynamic array sizes.

## Fencil Level of Tensor IR

At the fencil level, not much changes compared to the iterator IR.
The most noticeable change is the lack of domain specifications in stencil closures, as the domain is implicitly defined by the size of the output tensor. See also [Notes on Domain Sizes](#Notes-on-Domain-Sizes).

## Lifting and Lowering

### Lifting from Iterator IR to Tensor IR

Lifting from iterator IR to (fully typed) tensor IR is relatively straightforward using just ad-hoc polymorphism, as long as the data types and sizes of the input arguments are known.

### Lowering from Field View

Lowering from field view should be relatively easy, as tensor IR is much closer to field view than iterator IR is.

### Lowering to DaCe

In contrast to iterator IR, lowering tensor IR to DaCe should be trivial, as tensor types map directly to DaCe.

### Lowering to Machine Learning Frameworks

Lowering of tensor IR to modern machine learning frameworks like Jax or PyTorch is straightforward.

### Lowering to MLIR

Lowering of tensor IR to MLIR can easily be achieved by using Jax, but manual lowering to MLIR’s native dialects or (if desired) a custom MLIR dialect should also be straightforward.

## Storing Intermediate Results into Temporary Buffers

Iterator IR allows for two interpretations of dereferencing the result of a lifted function: either on-the-fly computation or computation into an intermediate (global) temporary buffer. While the first case is trivial to implement (inlining), the second requires a very complex transformation (‘PopupTmps’ in GT4Py), followed by complex dependency analysis to determine the buffer and domain sizes. In contrast, the transformation in tensor IR is quite trivial, as all tensors are explicitly sized and no lifting or derefing is required.
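
Because every tensor type carries its intervals, the buffer and domain sizes this transformation needs can be derived mechanically from the type rules in the Builtins section. A minimal Python sketch of the two rules involved, interval intersection under broadcasting and the interval offset of Cartesian shifts, is shown here; it assumes half-open intervals as in the `I[0:100]` notation, and the helper names are hypothetical. Applied to the Laplacian example from above (net shifts of -1, 0 and +1 along each horizontal dimension), it reproduces the declared output size.

```python
from typing import Dict, Tuple

Interval = Tuple[int, int]  # half-open [start, stop), as in I[0:100]
Dims = Dict[str, Interval]  # dimension name -> valid interval


def intersect(a: Interval, b: Interval) -> Interval:
    return (max(a[0], b[0]), min(a[1], b[1]))


def bcast(*operands: Dims) -> Dims:
    """Union of all dimension names, intersection of intervals per name."""
    result: Dims = {}
    for dims in operands:
        for name, interval in dims.items():
            result[name] = intersect(result[name], interval) if name in result else interval
    return result


def cartesian_shift(name: str, offset: int, dims: Dims) -> Dims:
    """Type rule of a Cartesian shift: only the interval of `name` moves."""
    start, stop = dims[name]  # KeyError if the tensor lacks this dimension
    return {**dims, name: (start + offset, stop + offset)}


# Laplacian: inp is defined on IDim[-1:6], JDim[-1:8], KDim[0:9].
inp: Dims = {"IDim": (-1, 6), "JDim": (-1, 8), "KDim": (0, 9)}
laplacian = bcast(
    cartesian_shift("IDim", -1, inp), inp, cartesian_shift("IDim", 1, inp),
    cartesian_shift("JDim", -1, inp), cartesian_shift("JDim", 1, inp),
)
print(laplacian)  # {'IDim': (0, 5), 'JDim': (0, 7), 'KDim': (0, 9)}

# Broadcasting example: the x intervals intersect, y is added.
print(bcast({"x": (-3, 5)}, {"x": (1, 9), "y": (5, 8)}))  # {'x': (1, 5), 'y': (5, 8)}
```

Plain dictionaries keep the sketch short; in a compiler these intervals would simply be read off the tensor types attached to each expression.
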
For example, consider the following 1D fencil in tensor IR:

```
fencil(
    out: tensor<float, I[1:99]>,
    inp: tensor<float, I[0:100]>
) {
    out ← (λ(x: tensor<float, I[0:100]>) → ⟪I, 1⟫(cos(x)) + ⟪I, -1⟫(cos(x)))(inp)
}
```

A simple analysis could find that the expensive expression `cos(x)` is evaluated twice and should thus be stored into a global temporary buffer. As everything is explicitly sized, we already know the dimensions and size of the required temporary buffer. So the replacement is pretty simple:

1. Add a new argument to the fencil. Its size and data type are already known.
2. Create a new stencil closure with the expression `cos(x)` as output and the same inputs as the original function; store the result into the temporary buffer.
3. Add a new parameter (here `y`) to the original function, replace the expression `cos(x)` by `y`, and pass the temporary buffer as an additional argument.
4. (Optionally) prune all unused function parameters/arguments.

```
fencil(
    out: tensor<float, I[1:99]>,
    inp: tensor<float, I[0:100]>,
    tmp: tensor<float, I[0:100]>
) {
    tmp ← (λ(x: tensor<float, I[0:100]>) → cos(x))(inp)
    out ← (λ(y: tensor<float, I[0:100]>) → ⟪I, 1⟫(y) + ⟪I, -1⟫(y))(tmp)
}
```

Note that compute extent analysis and similar analyses also become relatively simple with the fully typed and sized tensor representation: we can just slice the output of a computation and push the updated dimensions through the expression tree. Further, when using DaCe, Jax, PyTorch, or MLIR as a backend, the decision between inlining and introducing buffers can easily be left to those tools, so this transformation might not even be required at the IR level.

## FAQ

### How do you know that it works?

A working implementation that lifts iterator IR to tensor IR and lowers it to Jax (and thus MLIR) for evaluation is available here: https://github.com/fthaler/gt4py/tree/tensor-ir-pos

### What are the limitations compared to iterator IR?
- Arbitrary tuples are not supported, for example, tuples of functions (note that these are neither supported by the iterator IR type checking pass, nor ever used in a test or example).
- Only a single, special iterator IR computation in the GT4Py tests does not translate easily to tensor IR, due to tensor IR’s stricter handling of location types: the on-the-fly computation of ICON’s sign field. However, this computation has no field view equivalent, and thanks to the explicit handling of neighbor tables in tensor IR, it should not be hard to write an alternative implementation (which in turn would be incompatible with iterator IR). Further, it is questionable whether this computation is reasonable to perform on the fly at all.