# Scans in the IR & Field View Frontend
## Meeting 2022-05-11
## Current IR
### Python Example
```python=
@fundef
def tridiag_forward(state, a, b, c, d):
    def initial():
        return make_tuple(deref(c) / deref(b), deref(d) / deref(b))

    def step():
        return make_tuple(
            deref(c) / (deref(b) - deref(a) * tuple_get(0, state)),
            (deref(d) - deref(a) * tuple_get(1, state))
            / (deref(b) - deref(a) * tuple_get(0, state)),
        )

    return if_(is_none(state), initial, step)()


@fundef
def tridiag_backward(x_kp1, cp, dp):
    return if_(is_none(x_kp1), deref(dp), deref(dp) - deref(cp) * x_kp1)


@fundef
def solve_tridiag(a, b, c, d):
    tup = lift(scan(tridiag_forward, True, None))(a, b, c, d)
    cp = tuple_get(0, tup)
    dp = tuple_get(1, tup)
    return scan(tridiag_backward, False, None)(cp, dp)
```
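As a reference for what the two scans compute, the following plain-Python sketch applies the same recurrences (the Thomas algorithm) to a single column stored as lists. It is not GT4Py code: `solve_tridiag_column` is a hypothetical name, and `None` plays the role of the undefined state that `is_none` checks for.

```python
def solve_tridiag_column(a, b, c, d):
    """Hypothetical reference: forward and backward scan on one column.

    a: sub-diagonal, b: diagonal, c: super-diagonal, d: right-hand side.
    a[0] and c[-1] do not enter the result.
    """
    n = len(d)
    cp, dp = [0.0] * n, [0.0] * n

    # Forward scan: the state (cp, dp) is carried from level k-1 to level k.
    state = None
    for k in range(n):
        if state is None:  # first level: the `initial` branch
            state = (c[k] / b[k], d[k] / b[k])
        else:  # all other levels: the `step` branch
            denom = b[k] - a[k] * state[0]
            state = (c[k] / denom, (d[k] - a[k] * state[1]) / denom)
        cp[k], dp[k] = state

    # Backward scan: x_kp1 is the solution at level k+1.
    x = [0.0] * n
    x_kp1 = None
    for k in reversed(range(n)):
        x_kp1 = dp[k] if x_kp1 is None else dp[k] - cp[k] * x_kp1
        x[k] = x_kp1
    return x
```

For example, the system with `a = [0, 1, 1]`, `b = [2, 2, 2]`, `c = [1, 1, 0]` and `d = [4, 8, 8]` has the solution `[1, 2, 3]` (up to rounding).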
### C++ Example
```cpp=
struct forward_scan : fwd {
    static constexpr auto prologue() {
        return tuple(scan_pass(
            [](auto /*acc*/, auto const & /*a*/, auto const &b, auto const &c, auto const &d) {
                return tuple(deref(c) / deref(b), deref(d) / deref(b));
            },
            identity()));
    }

    static constexpr auto body() {
        return scan_pass(
            [](auto acc, auto const &a, auto const &b, auto const &c, auto const &d) {
                auto [cp, dp] = acc;
                return tuple(
                    deref(c) / (deref(b) - deref(a) * cp), (deref(d) - deref(a) * dp) / (deref(b) - deref(a) * cp));
            },
            identity());
    }
};

struct backward_scan : bwd {
    static constexpr auto prologue() {
        return tuple(scan_pass(
            [](auto /*xp*/, auto const &cpdp) {
                auto [cp, dp] = deref(cpdp);
                return dp;
            },
            identity()));
    }

    static GT_FUNCTION constexpr auto body() {
        return scan_pass(
            [](auto xp, auto const &cpdp) {
                auto [cp, dp] = deref(cpdp);
                return dp - cp * xp;
            },
            identity());
    }
};

constexpr inline auto tridiagonal_solve =
    [](auto executor, auto const &a, auto const &b, auto const &c, auto const &d, auto &cpdp, auto &x) {
        using float_t = sid::element_type<decltype(a)>;
        return executor()
            .arg(a)
            .arg(b)
            .arg(c)
            .arg(d)
            .arg(cpdp)
            .arg(x)
            .assign(4_c, forward_scan(), tuple<float_t, float_t>(0, 0), 0_c, 1_c, 2_c, 3_c)
            .assign(5_c, backward_scan(), float_t(0), 4_c);
    };
```
### Differences
```haskell=
type ColumnStencil = ColumnIterators… → ColumnValue
type ScanFun = Accumulator → ScalarIterators… → Accumulator
-- Current Python API
scan :: ScanFun → Bool → Init → ColumnStencil
scan body is_forward initial_accumulator_value = …
-- C++-like
type ScanFunWithProjection = (ScanFun, (Accumulator → SubsetOfAccumulator))
scan :: [ScanFunWithProjection] → ScanFunWithProjection → [ScanFunWithProjection] → Bool → Init → ColumnStencil
scan prologues body epilogues is_forward initial_accumulator_value = …
-- There is also a fold! It cannot be used on its own in a useful way, but…
fold :: [ScanFun] → ScanFun → [ScanFun] → Bool → Init → ??
fold prologues body epilogues is_forward initial_accumulator_value = …
-- … it can be passed to merge, which creates a special composition of scans and folds:
-- it passes the final value of the accumulator to the next scan or fold.
-- This allows a reduction result (computed using fold) to be broadcast
-- to a whole column (using scan).
merge :: ScansOrFolds… → ColumnStencil
```
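To make the C++-like `scan` signature above concrete, here is a hypothetical plain-Python model on a single column. It assumes that prologues and epilogues are counted in loop direction (so for a backward scan the first prologue runs at the top level), that each pass is a `(scan function, projection)` pair, and that only the projection of the accumulator is written to the output column. Unlike the Haskell signature, the input columns are passed directly after `init` instead of returning a `ColumnStencil`.

```python
from typing import Any, Callable, Sequence

ScanFun = Callable[..., Any]       # (acc, *values_at_level) -> new acc
Projection = Callable[[Any], Any]  # acc -> value stored at this level
Pass = tuple[ScanFun, Projection]


def scan(prologues: Sequence[Pass], body: Pass, epilogues: Sequence[Pass],
         is_forward: bool, init: Any, *columns: Sequence[Any]) -> list[Any]:
    """Hypothetical reference semantics of the C++-like scan on one column."""
    n = len(columns[0])
    assert len(prologues) + len(epilogues) <= n, "column too short"
    # Prologues and epilogues are relative to the loop direction.
    order = range(n) if is_forward else range(n - 1, -1, -1)
    out: list[Any] = [None] * n
    acc = init
    for i, k in enumerate(order):
        if i < len(prologues):
            fun, proj = prologues[i]
        elif i >= n - len(epilogues):
            fun, proj = epilogues[i - (n - len(epilogues))]
        else:
            fun, proj = body
        acc = fun(acc, *(col[k] for col in columns))
        out[k] = proj(acc)  # only the projection is stored in memory
    return out


# Usage: a running mean. The accumulator (sum, count) carries more data
# than the stored value; the prologue replaces an `is_none` check.
first_pass = (lambda acc, v: (v, 1), lambda acc: acc[0] / acc[1])
mean_pass = (lambda acc, v: (acc[0] + v, acc[1] + 1), lambda acc: acc[0] / acc[1])
print(scan([first_pass], mean_pass, [], True, None, [2.0, 4.0, 6.0]))  # [2.0, 3.0, 4.0]
```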
* **Prologues and epilogues.** The Python tests currently use `is_none` to distinguish the first level from all other levels. This strategy cannot single out the second or any subsequent level, nor levels relative to the end of the loop. The C++ implementation allows an arbitrary number of `prologues` and `epilogues`, that is, special levels at the beginning and at the end of the loop, respectively. A transformation from the former to the latter is possible in theory but probably non-trivial in the general case.
* **Projections.** The C++ backend supports an explicit ‘projection’ of the accumulator. This enables accumulator types that contain more data than the return value. For example, in the COSMO vertical advection solver, the scan function returns both the solution of the tridiagonal system and a time-integrated version of it. However, only the time-integrated version is stored in memory for later use; the unmodified solution of the linear system at one level is only passed on to the next level and never stored in memory. This optimization could probably be performed automatically at a later stage (removal of write-only column fields), but it cannot be expressed within the IR, as there is no way to represent such scans. Further, explicit projection might be more user-friendly, as it lets users specify what they actually want to return and what is only needed for the accumulation.
* **Folds.** In addition to scans, the C++ backend supports folds. They do not write anything to memory but only accumulate a value. They are only useful in combination with the following feature:
* **Merged stages.** The C++ backend allows multiple scans and folds to be merged into a single one, passing the accumulator through all stages. That is, the accumulator is not reinitialized after a scan or fold; instead, its final value is passed as the initializer to the next scan or fold. This allows column-wise reductions followed by broadcasting: a fold accumulates a value in a register, and a subsequent scan broadcasts that value to the whole column.
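The reduce-then-broadcast pattern can be sketched in plain Python as follows. This is a hypothetical model on a single column, not the C++ API: `fold_stage`, `scan_stage` and `merge_stages` are illustrative names, and for simplicity all stages receive the same input columns. Each stage returns its final accumulator, which `merge_stages` hands to the next stage as its initial value.

```python
from typing import Any, Callable, Optional, Sequence

Stage = Callable[..., tuple[Any, Optional[list[Any]]]]


def _order(n: int, is_forward: bool) -> range:
    return range(n) if is_forward else range(n - 1, -1, -1)


def fold_stage(body: Callable[..., Any], is_forward: bool = True) -> Stage:
    """Accumulates over the column; writes nothing to memory."""
    def run(acc: Any, *columns: Sequence[Any]):
        for k in _order(len(columns[0]), is_forward):
            acc = body(acc, *(col[k] for col in columns))
        return acc, None
    return run


def scan_stage(body: Callable[..., Any], is_forward: bool = True) -> Stage:
    """Accumulates over the column and stores the accumulator at each level."""
    def run(acc: Any, *columns: Sequence[Any]):
        out: list[Any] = [None] * len(columns[0])
        for k in _order(len(columns[0]), is_forward):
            acc = body(acc, *(col[k] for col in columns))
            out[k] = acc
        return acc, out
    return run


def merge_stages(init: Any, *stages: Stage):
    """Runs the stages in order; the final accumulator of one stage is the
    initial accumulator of the next. Returns the last stage's output column."""
    def run(*columns: Sequence[Any]):
        acc, out = init, None
        for stage in stages:
            acc, out = stage(acc, *columns)
        return out
    return run


# Column-wise reduction followed by a broadcast to the whole column.
column_sum = merge_stages(
    0.0,
    fold_stage(lambda acc, v: acc + v),  # sum the column in a register
    scan_stage(lambda acc, v: acc),      # write that sum to every level
)
print(column_sum([1.0, 2.0, 3.0]))  # [6.0, 6.0, 6.0]
```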
## Field View
Tricky part: we work on 2D slices of 3D fields.
```python=
# Slice[[I, J, K], float]: a horizontal slice of Field[[I, J, K], float];
#   can be accessed with offsets
# Slice[float]: a horizontal slice of temporary state data;
#   cannot be accessed with offsets


def tridiag_forward():
    # optionally, a projection function can be passed (defaults to identity)
    @scan_pass(projection=lambda state: state)
    def initial(
        state: tuple[Slice[float], Slice[float]],
        a: Slice[[I, J, K], float],
        b: Slice[[I, J, K], float],
        c: Slice[[I, J, K], float],
        d: Slice[[I, J, K], float],
    ) -> tuple[Slice[float], Slice[float]]:
        return c / b, d / b

    @scan_pass
    def step(
        state: tuple[Slice[float], Slice[float]],
        a: Slice[[I, J, K], float],
        b: Slice[[I, J, K], float],
        c: Slice[[I, J, K], float],
        d: Slice[[I, J, K], float],
    ) -> tuple[Slice[float], Slice[float]]:
        cp, dp = state
        return c / (b - a * cp), (d - a * dp) / (b - a * cp)

    # no initializer for the state, so it is undefined in `initial`
    return forward_scan(step, prologues=[initial])


# simpler implementation without prologues, making use of init instead
def tridiag_forward2():
    @scan_pass
    def step(
        state: tuple[Slice[float], Slice[float]],
        a: Slice[[I, J, K], float],
        b: Slice[[I, J, K], float],
        c: Slice[[I, J, K], float],
        d: Slice[[I, J, K], float],
    ) -> tuple[Slice[float], Slice[float]]:
        cp, dp = state
        return c / (b - a * cp), (d - a * dp) / (b - a * cp)

    # with init=(0.0, 0.0), the first step reduces to (c / b, d / b),
    # so no prologue is needed
    return forward_scan(step, init=(0.0, 0.0))


def tridiag_backward():
    @scan_pass
    def initial(
        x_kp1: Slice[float],
        cp: Slice[[I, J, K], float],
        dp: Slice[[I, J, K], float],
    ) -> Slice[float]:
        return dp

    @scan_pass
    def step(
        x_kp1: Slice[float],
        cp: Slice[[I, J, K], float],
        dp: Slice[[I, J, K], float],
    ) -> Slice[float]:
        return dp - cp * x_kp1

    return backward_scan(step, prologues=[initial])


# simpler implementation using init instead of a prologue
def tridiag_backward2():
    @scan_pass
    def step(
        x_kp1: Slice[float],
        cp: Slice[[I, J, K], float],
        dp: Slice[[I, J, K], float],
    ) -> Slice[float]:
        return dp - cp * x_kp1

    # with init=0.0, the first step reduces to dp
    return backward_scan(step, init=0.0)


@field_operator
def solve_tridiag(
    a: Field[[I, J, K], float],
    b: Field[[I, J, K], float],
    c: Field[[I, J, K], float],
    d: Field[[I, J, K], float],
) -> Field[[I, J, K], float]:
    cp, dp = tridiag_forward()(a, b, c, d)
    return tridiag_backward()(cp, dp)
```
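To illustrate the slice semantics, here is a NumPy sketch of what `tridiag_forward2` and `tridiag_backward2` could compute when the scans are executed level by level, each step operating on a whole horizontal (I, J) slice. The function name `solve_tridiag_3d` and the array layout `[i, j, k]` are assumptions for illustration, not part of the proposed frontend. The state is a plain 2D array that is never accessed with offsets, while the inputs are read as `[:, :, k]` slices.

```python
import numpy as np


def solve_tridiag_3d(a, b, c, d):
    """Hypothetical sketch: both scans of `solve_tridiag` on [i, j, k] arrays,
    vectorized over the horizontal (I, J) slice of each K level."""
    nk = d.shape[2]
    cp = np.empty_like(d)
    dp = np.empty_like(d)

    # Forward scan with init=(0.0, 0.0), as in `tridiag_forward2`.
    # The state is a pair of 2D slices.
    state = (np.zeros(d.shape[:2]), np.zeros(d.shape[:2]))
    for k in range(nk):
        a_k, b_k, c_k, d_k = a[:, :, k], b[:, :, k], c[:, :, k], d[:, :, k]
        cp_km1, dp_km1 = state
        denom = b_k - a_k * cp_km1
        state = (c_k / denom, (d_k - a_k * dp_km1) / denom)
        cp[:, :, k], dp[:, :, k] = state

    # Backward scan with init=0.0, as in `tridiag_backward2`.
    x = np.empty_like(d)
    x_kp1 = np.zeros(d.shape[:2])
    for k in reversed(range(nk)):
        x_kp1 = dp[:, :, k] - cp[:, :, k] * x_kp1
        x[:, :, k] = x_kp1
    return x
```

Only the loop over K is sequential here; every operation inside the loop body acts on an entire horizontal slice at once, which is the parallelism a backend would exploit.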
## Questions
- general condition (e.g. `if(is_none(...))` or `if(k==0)`) vs. prologues and epilogues
- ...
- what is the state (non-shiftable field), what are the input fields (2D view of a 3D field)
    - we use 0-d fields -> cannot shift anything
    - we can always pass the same field multiple times after shifting
```python
from typing import Annotated, TypeVar

T = TypeVar("T")
Accumulator = Annotated[T, "accumulator_field"]

# Alternative: accumulator as a keyword-only argument defaulting to INIT
# @scan_operator(init=0.0, direction=backward)
# def tridiag_backward(
#     cp: float,
#     dp: float,
#     *,
#     x_kp1: LiteralValue = INIT
# ) -> float:
#     return dp - cp * x_kp1
# tridiag_backward(cp, dp)


@scan_operator(init=0.0, direction=backward)
def tridiag_backward(
    x_kp1: float,
    cp: float,
    dp: float,
) -> float:
    return dp - cp * x_kp1


@field_operator
def solve_tridiag(
    a: Field[[I, J, K], float],
    b: Field[[I, J, K], float],
    c: Field[[I, J, K], float],
    d: Field[[I, J, K], float],
) -> Field[[I, J, K], float]:
    # a tuple target cannot be annotated in Python, so no type annotation here
    cp, dp = tridiag_forward(a, b, c, d)
    return tridiag_backward(cp, dp)
```
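To pin down the semantics assumed by the `scan_operator` proposal, the following is a hypothetical pure-Python mock operating on lists (one value per K level) instead of fields. `scan_operator`, `Direction`, `forward` and `backward` are defined here only for illustration and are not an existing API. The accumulator is passed as the first argument, and the return value is both the new accumulator and the output at that level, as in the `tridiag_backward` example above.

```python
from enum import Enum
from typing import Any, Callable, Sequence


class Direction(Enum):  # hypothetical stand-in for `forward`/`backward`
    FORWARD = "forward"
    BACKWARD = "backward"


forward, backward = Direction.FORWARD, Direction.BACKWARD


def scan_operator(*, init: Any, direction: Direction = forward):
    """Hypothetical mock of the proposed decorator on plain lists."""
    def decorator(fun: Callable[..., Any]) -> Callable[..., list[Any]]:
        def apply(*columns: Sequence[Any]) -> list[Any]:
            n = len(columns[0])
            order = range(n) if direction is forward else range(n - 1, -1, -1)
            acc, out = init, [None] * n
            for k in order:
                # scalar values at level k; the accumulator comes first
                acc = fun(acc, *(col[k] for col in columns))
                out[k] = acc
            return out
        return apply
    return decorator


@scan_operator(init=0.0, direction=backward)
def tridiag_backward(x_kp1: float, cp: float, dp: float) -> float:
    return dp - cp * x_kp1


print(tridiag_backward([0.5, 0.5, 0.0], [2.0, 4.0, 3.0]))  # [0.75, 2.5, 3.0]
```

Passing the accumulator positionally keeps the decorated function an ordinary scalar function, so the same body could also be tested directly with plain floats.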
### 3 use cases
- vertical advection
- FVM
- Lagrangian, e.g. https://gist.github.com/havogt/c52d19f2f7557c2c048d12be7330bda8