# [GT4Py] Workflowify toolchain
<!-- Add the tag for the current cycle number in the top bar -->
- Shaped by:
- Appetite (FTEs, weeks):
- Developers: <!-- Filled in at the betting table unless someone is specifically required here -->
## Problem
Currently, workflows only start after lowering to ITIR. We should apply the same strategy to express all toolchain steps.
Besides being a general cleanup, this enables new features such as:
- backends/transformations that start earlier than ITIR, e.g. `jax.jit`
- wrapping/unwrapping transformations for non-GT4Py fields, e.g. possibly xarrays
## Appetite
Full cycle, 2 developers or 1 developer plus support.
## Solution
In a first iteration, we propose to keep the changes transparent to users, i.e. backends should behave as before. They will contain an allocator part and a transformation part.
### Refactoring decorator.py
The main work is to refactor `decorator.py` by moving all hard-coded transformations to workflow steps to be used in backends. Consider partially refactoring `decorator.py` before workflowifying it.
For `Program`, the refactorings are (possibly incomplete):
- `_process_args`: part of transformations
- `itir`: part of transformations
- `__post_init__`: discuss whether this early checking should be kept. If yes, it should use the new steps to get the ITIR representation. Ideally, however, the check should not rely on ITIR (possibly as a follow-up).
- `from_function`: part of transformations; currently includes lowering to PAST. If this should (as now) run when the program is created (at module load time, or after construction with `with_backend` etc.), a backend transformation workflow needs both steps that run at construction and steps that run on execution.
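
The split between construction-time and execution-time steps could look like the following minimal sketch. `TransformWorkflow`, `SketchProgram` and `chain` are hypothetical names, not the existing GT4Py workflow API. Steps that need only the definition (e.g. lowering to PAST) run once when the program object is created; steps that need the call arguments run on every call.

```python
from dataclasses import dataclass, field
from typing import Any, Callable

Step = Callable[[Any], Any]


def chain(*steps: Step) -> Step:
    """Compose steps left to right into a single step."""

    def run(value: Any) -> Any:
        for step in steps:
            value = step(value)
        return value

    return run


@dataclass(frozen=True)
class TransformWorkflow:
    # steps that need only the definition: run once, at construction time
    on_construction: Step
    # steps that need the call arguments: run on every call
    on_call: Callable[[Any, tuple, dict], Any]


@dataclass
class SketchProgram:
    definition: Callable[..., Any]
    workflow: TransformWorkflow
    ir: Any = field(init=False)

    def __post_init__(self) -> None:
        # e.g. parsing and lowering, which only needs the Python function
        self.ir = self.workflow.on_construction(self.definition)

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        # program arguments are only injected here
        return self.workflow.on_call(self.ir, args, kwargs)
```

Early checks like the one in `__post_init__` then naturally become part of `on_construction`.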
For `FieldOperator`, these are:
- `from_function`: see above
- `__gt_itir__`: has a caching mechanism that would have to be translated to the workflow step
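
A minimal sketch of translating such a cache into a generic workflow step, assuming the step input (or a key derived from it, e.g. a hash of the definition) is hashable. `CachedStep` and its fields are illustrative names here, not a confirmed GT4Py API.

```python
from collections.abc import Hashable
from dataclasses import dataclass, field
from typing import Callable, Generic, TypeVar

In = TypeVar("In")
Out = TypeVar("Out")


@dataclass
class CachedStep(Generic[In, Out]):
    """Wraps a workflow step and memoizes its result per cache key."""

    step: Callable[[In], Out]
    # maps the step input to a hashable cache key; the input itself by default
    key: Callable[[In], Hashable] = lambda inp: inp
    _cache: dict = field(default_factory=dict, repr=False)

    def __call__(self, inp: In) -> Out:
        k = self.key(inp)
        if k not in self._cache:
            self._cache[k] = self.step(inp)
        return self._cache[k]
```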
### Non-trivial refactorings to enable starting from Python functions
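These constructs currently work by wrapping lower-level IR around their own implementation (e.g. `as_program()` creates a PAST node directly) instead of starting from a Python function.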
#### `field_operator.as_program()`
Creates a past node and then constructs a `Program`.
Ideally, we would create a Python function that can then be used as a normal program, i.e. all transformations would work normally. During the transition in this project, we could consider injecting an additional step into backends that have a corresponding slot for this operation. The decision should fit into the appetite of the project.
##### Refactor to a Python function
In case the decision is to do the full refactoring to a Python function, this is the sketch of the involved steps:
- generate the Python source code of a program calling the field_operator/scan_operator
- `eval` the source code in context of the field_operator
Note: to make parsing work with the eval'ed function, there are two options:
1. Store the source code along with the eval'ed function in the program (probably the easier solution if it covers all use-cases)
2. Update the `linecache` along the lines of https://stackoverflow.com/a/50885938/5085250
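
A minimal sketch of both steps combined with option 2, assuming the parser obtains the source through `inspect` (which reads it via `linecache`). Since a `def` is a statement, the generated code needs `compile(..., "exec")` plus `exec` rather than `eval`. `make_program_source` and `define_function` are hypothetical helpers, and the generated signature (`out` passed as keyword) is only illustrative.

```python
import inspect
import linecache
from typing import Any, Callable


def make_program_source(name: str, fieldop_name: str, params: list[str]) -> str:
    """Hypothetical: source of a program that calls a field operator with `out`."""
    args = ", ".join(params)
    return f"def {name}({args}, out):\n    {fieldop_name}({args}, out=out)\n"


def define_function(source: str, name: str, namespace: dict[str, Any]) -> Callable[..., Any]:
    # `def` is a statement, so compile in "exec" mode and exec it in the namespace
    filename = f"<generated {name}>"
    exec(compile(source, filename, "exec"), namespace)
    # option 2: register the source in linecache so that inspect.getsource works;
    # mtime=None keeps linecache.checkcache from evicting the entry
    linecache.cache[filename] = (len(source), None, source.splitlines(keepends=True), filename)
    return namespace[name]
```

Option 1 would instead keep `source` next to the returned function and hand it to the parser directly, which avoids mutating the global `linecache`.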
#### `ProgramWithBoundArgs`
Similar to [`field_operator.as_program()`](#field_operatoras_program), as a transition step we could inject the transformations into the backends.
##### Refactor to Python function
The clean solution would generate a new Python function with the bound arguments. The steps would be:
- extend Program/PAST to allow variable definitions (requires lowering)
- generate a new Python function with the bound args removed from its signature
- add the variable definitions with the bound values to its body
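
On the Python side, the last two steps can be sketched with the standard `ast` module. `bind_args_source` is a hypothetical helper; it assumes a single function without default values and bound values that are Python literals (e.g. scalars). Field arguments and the PAST/lowering extension for variable definitions are not covered.

```python
import ast
import textwrap
from typing import Any


def bind_args_source(source: str, bound: dict[str, Any]) -> str:
    """Return the source of the function in `source` with the `bound` parameters
    removed from its signature and defined as variables at the top of its body."""
    module = ast.parse(textwrap.dedent(source))
    func = module.body[0]
    if not isinstance(func, ast.FunctionDef):
        raise ValueError("expected a single function definition")
    if func.args.defaults:
        # removing parameters would misalign positional defaults
        raise ValueError("default values are not supported in this sketch")
    # 1. remove the bound parameters from the signature
    func.args.args = [arg for arg in func.args.args if arg.arg not in bound]
    # 2. prepend `name = value` definitions for the bound values
    definitions = [
        ast.Assign(targets=[ast.Name(id=name, ctx=ast.Store())], value=ast.Constant(value))
        for name, value in bound.items()
    ]
    func.body = definitions + func.body
    return ast.unparse(ast.fix_missing_locations(module))
```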
### Additional notes
- All itir transformations (pass_manager.apply_common_transformations) should be a single step for this project.
- The program has the `format_itir` function, which is probably mainly used by GT4Py developers. We could consider removing it and providing a convenient formatter as a backend, or keep it with a hard-coded formatter for convenient use.
## Rabbit holes
See `field_operator.as_program()` and `ProgramWithBoundArgs`. TODO describe the way out.
## No-gos
<!-- Anything specifically excluded from the concept: functionality or use cases we intentionally aren’t covering to fit the ## appetite or make the problem tractable -->
## Progress
<!-- Don't fill during shaping. This area is for collecting TODOs during building. As first task during building add a preliminary list of coarse-grained tasks for the project and refine them with finer-grained items when it makes sense as you work on them. -->
- [x] Add `PAST` -> `ITIR` workflow step to `GTFNBackend`
    - [x] add step
    - [x] add code path in `Program.__call__` to run this step
    - [x] get tests with gtfn to pass
- [x] Add `PAST` -> `ITIR` step to all backends
    - [x] roundtrip
    - [x] double roundtrip
    - [x] dace
    - [x] remove special treatment in `Program.__call__`
- [x] Add `PAST` -> function definition step
    - [x] add step
    - [x] add code path in `Program.__call__` to generate `definition` if it doesn't exist
    - [x] add a test that executes `FieldOperator.as_program` manually and runs the result in embedded mode
- [x] Refactor `Backend` and `ModularExecutor`
    - [x] assumption that `Backend` is a `ProgramExecutor` is now false, update test utils etc
    - [x] all tests green
    - [x] Merge [#1479](https://github.com/GridTools/gt4py/pull/1479)
    - [x] remove duplicate code paths
    - [x] move args processing into lowering step?
    - [x] split off `PAST` -> function definition: [#1487](https://github.com/GridTools/gt4py/pull/1487)
    - [x] cleanup?
- [x] Add function definition -> `PAST` workflow step
    - [x] create step and stage that works only with program function definitions
    - [x] design a way to skip that stage for `FOAST` -> `PAST`
        - consider non-linearity in workflows / step with alternate code paths
        - consider stage that can contain either
    - [x] attach to all backends
    - [x] merge [#1500](https://github.com/GridTools/gt4py/pull/1500)
- current design with respect to the considerations below:
    - embedded is still `backend=None`
    - arguments are handed to the transforms workflow, which chooses when to inject them
    - steps of the transforms workflow are accessible and can be used for partial transformations (with or without arguments)
    - function -> IR is part of the backend
- [x] Add `FOAST` -> `PAST` workflow step
    - [x] create step and stages
    - [x] design way to start from `FOAST` or `PAST` / program function
        - without having to pass `FOAST` -> `PAST` workflow to `field_operator` separately
        - without overcomplicating the backend
        - while keeping things modular
    - [x] add to all backends
    - [ ] merge
- [x] Add field operator definition -> `FOAST` step
## Design considerations / open questions (non-blocking)
### What does it mean for embedded execution to be equivalent to `transforms=None`?
Does it mean "no additional toolchain is run"? If so, what about DSL linting?
If not, is `transforms=None` or `backend=None` really the interface we should aim for?
### When should the arguments be injected?
Currently the decorator does all the transformations that don't require program arguments.
The backend is always called with program arguments and does the rest. If everything becomes a single workflow, it becomes an open question when the program arguments should be injected.
In other words, what is on-demand and what is jit? If arguments are always injected before calling the toolchain, then DSL linting is no longer possible on-demand.
### Should the transformations "function -> IR" be part of the backend?
If yes:
```python!
# normal backend
@field_operator(backend=gtfn)
# alternative: transforms=gtfn.transforms
def foo(...):
    ...

# normal embedded
@field_operator(backend=None)
def foo(...):
    ...

# embedded with transforms
@field_operator(backend=Backend(transforms=jax_jit))
def foo(...):
    ...

class Backend:
    transforms: Workflow[InputClosure, OutputClosure] = noop_transforms
    jit: Workflow[OutputClosure, CompiledProgram] = noop_jit
    executor: ProgramExecutor = noop_executor

    def __call__(self, program: InputClosure):
        program_call = self.jit(self.transforms(program))
        return self.executor(program_call.program, *program_call.args, **program_call.kwargs)

    def get_foast(...): ...
    def get_past(...): ...
    def get_itir(...): ...
    def lint(...): ...
```
The backend will need to expose functionality such as DSL linting and inspecting intermediate IRs. By imposing a fixed structure on the transforms workflow, this could be done without adding significant logic to backends.
Benefit: it is all automatically consistent with how this particular backend will do things.
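To illustrate, a minimal sketch assuming the imposed structure is a fixed sequence of named steps (`func_to_foast`, `foast_to_past`, `past_to_itir`; `StructuredTransforms` and `SketchBackend` are hypothetical names, not existing GT4Py API). Introspection methods such as `get_past` are derived once from that structure, so every backend exposes them without extra logic, and they always match what the backend executes. For brevity every definition is treated as a field operator; a program definition would enter at the PAST step.

```python
from dataclasses import dataclass
from typing import Any, Callable

Step = Callable[[Any], Any]


@dataclass(frozen=True)
class StructuredTransforms:
    """Hypothetical fixed structure imposed on every backend's transforms."""

    func_to_foast: Step
    foast_to_past: Step
    past_to_itir: Step

    def __call__(self, definition: Any) -> Any:
        return self.past_to_itir(self.foast_to_past(self.func_to_foast(definition)))


@dataclass(frozen=True)
class SketchBackend:
    transforms: StructuredTransforms
    executor: Callable[..., Any]

    # Introspection only runs a prefix of the backend's own transforms,
    # so it is consistent with what __call__ executes.
    def get_foast(self, definition: Any) -> Any:
        return self.transforms.func_to_foast(definition)

    def get_past(self, definition: Any) -> Any:
        return self.transforms.foast_to_past(self.get_foast(definition))

    def get_itir(self, definition: Any) -> Any:
        return self.transforms(definition)

    def __call__(self, definition: Any, *args: Any, **kwargs: Any) -> Any:
        return self.executor(self.get_itir(definition), *args, **kwargs)
```

A `lint` method could be derived the same way, by running only the steps up to the IR that the linter inspects, which keeps DSL linting available before any arguments are known.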
If not (closer to the current state):
```python!
import typing
from typing import Optional

# normal backend
@field_operator(backend=gtfn, transforms=default_transforms)
# alternative: transforms=gtfn.transforms
def foo(...):
    ...

# normal embedded
@field_operator(backend=None, transforms=None)
def foo(...):
    ...

# embedded with transforms
@field_operator(backend=None, transforms=jax_jit)
def foo(...):
    ...

# analogous for program
class FieldOperator:
    backend: Optional[Backend]
    transforms: Optional[Workflow[FieldOpDef, IR] | Workflow[FieldOpDef, CompiledProgram]]
    # Workflow[FieldOpDef, IR]: transforms to something that can be passed to backends
    # Workflow[FieldOpDef, CompiledProgram]: transforms to runnable program (jax.jit?)

    def __call__(self, *args, **kwargs):
        call = FieldOpCall(self.definition, args=args, kwargs=kwargs)
        # should transforms get the args and kwargs? I believe jax requires them
        transformed = self.transforms(call) if self.transforms else call
        if self.backend:
            return self.backend(typing.cast(IRClosure, transformed))
        else:
            return run_embedded(transformed)
```
## Scratchpad
```python!
@field_operator(backend=gtfn)
def foo(a: Field) -> Field:
    return a

# foo.with_backend(roundtrip)(a, out=bar)

class Program:
    def __call__(self, *args, **kwargs):
        executable = make_executable_workflow(self.backend)(self.definition)
        transformed_args = make_transform_args_workflow(self.backend)(*args)
        executable(*transformed_args)
        # alternatively, hand everything to the backend:
        self.backend(self.definition, *args, **kwargs)

class Backend:
    transforms: Workflow[definition, ProgramCall]
    otf_workflow: Workflow[ProgramCall, CompiledProgram]
    ...

    def __call__(self, program, *args, **kwargs):
        self.otf_workflow(
            self.transforms(program)
        )(*args, **kwargs)

prog(xarray, transformations=xarray_to_field)

foo_prog = foo.as_program()
foo_prog
```