# [DaCe] Improving the Jax to DaCe (J2D) Translator

###### tags: `cycle 21`

<!-- Add the tag for the current cycle number on top -->

- Shaped by: [Philip](mailto:philip.mueller@cscs.ch), [Edoardo](mailto:edoardo.paone@cscs.ch)
- Appetite (FTEs, weeks):
- Developers: <!-- Filled in at the betting table unless someone is specifically required here -->

## Problem

In [cycle 18](https://hackmd.io/@gridtools/rkj1xNGmT) (the last cycle of 2023) a [Jax to DaCe translator](https://github.com/philip-paul-mueller/jax_to_dace) (J2D) was implemented. However, it has remained in a very early prototype/proof-of-concept state to this day. During the QPM for 24Q2 it was decided that J2D should be kept alive and turned into production-ready code. There are several reasons for this:

- It will bring all Jax features, most importantly automatic differentiation, to DaCe.
- It could serve as a case study for how to create the CombinedIR in a DaCe-friendly way.
- It could serve as an entry point for embedded execution into DaCe.

However, the current implementation has several issues; only the most severe ones are listed here:

- Only a limited, albeit important, subset of Jax's primitives is supported.
- Due to internal changes in Jax itself, the translator only works with version `0.4.20`. The main reason is that it depends on the behaviour of Jax internals.
- While there are tests, they are not systematic and have to be run manually.
- There is no `@jit` equivalent yet; the whole process has to be guided manually. This means that the Jaxpr must be created as a first step and then be passed to the translator.
- The implementation has accumulated a lot of technical debt during its evolution.
- Most importantly, the current project targets stencil operations as they are encountered in ICON.

<!-- Although `jax.grad` can be handled, as it is just translated into a different Jaxpr, I am not sure about `vmap` and similar transformations. -->

On the other hand, the _design_ of the translator, i.e.
that each primitive is handled by a dedicated object and the whole process is driven by a parent object, is very promising and should be kept. An example of where the implementation falls short is the handling of subexpressions, i.e. `pjit` and `cond` (see below): the design supports arbitrarily nested expressions, but the implementation is quite hacky.

## Appetite

The whole productization will take several cycles, as Jax has a lot of features that we might have to cover. However, as a goal, we should aim to support the full [pyhpc benchmark suite](https://github.com/dionhaefner/pyhpc-benchmarks/tree/master) by the end of cycle **22**.

## Solution

In the following it is assumed that the reader is familiar with the [Jaxpr language](https://jax.readthedocs.io/en/latest/tutorials/jaxpr.html) that describes Jax operations. The work is organized into two phases, which are covered in detail below.

#### Phase 1) Defining the Interfaces/Classes

As of now there are three important classes, and chances are high that we will keep them, at least functionality-wise. In the first step we will study the current implementation and its shortcomings and identify ways to avoid them.

Another problem is that some terms are used for related but different things, and not always consistently. Thus, we should also create a glossary with proper definitions and stick to it.

###### `JaxIntrinsicTranslatorInterface`

A class derived from this interface implements the translation of a _single_ Jax equation/primitive; a simple example would be the addition of two fields. The interface was designed to keep as much business logic as possible out of the individual translators, a goal that was indeed achieved. They should follow the UNIX philosophy: "do one thing and do it well".
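To illustrate the one-primitive-per-translator idea, the following is a minimal, DaCe-free sketch. `FakeEqn`, `FakeState` and `AddTranslator` are hypothetical stand-ins and not the prototype's API: a real translator would build tasklets and memlets inside a `dace.SDFGState` instead of recording strings. The two methods mirror, in simplified and untyped form, the interface functions `canHandle()` and `translateEqn()` discussed next.

```python
from dataclasses import dataclass, field


@dataclass
class FakeEqn:
    """Hypothetical stand-in for a Jax equation (not `jax.core.JaxprEqn`)."""

    primitive: str
    invars: list[str]
    outvars: list[str]


@dataclass
class FakeState:
    """Hypothetical stand-in for a `dace.SDFGState`; it only records code."""

    tasklets: list[str] = field(default_factory=list)


class AddTranslator:
    """Handles exactly one primitive: the element-wise `add`."""

    def canHandle(self, translator, inVarNames, outVarNames, eqn) -> bool:
        # Like most prototype translators, only the primitive name is checked.
        return eqn.primitive == "add"

    def translateEqn(self, translator, inVarNames, outVarNames, eqn, eqnState):
        # The driver has already created all input/output variables, so only
        # the computation itself is emitted. Literal inputs (`None` entries in
        # `inVarNames`) are not handled in this sketch.
        lhs, rhs = inVarNames
        (out,) = outVarNames
        eqnState.tasklets.append(f"{out} = {lhs} + {rhs}")


eqn = FakeEqn("add", ["a", "b"], ["c"])
state = FakeState()
addTranslator = AddTranslator()
# No driver is needed for this sketch, so `None` is passed as `translator`.
if addTranslator.canHandle(None, eqn.invars, eqn.outvars, eqn):
    addTranslator.translateEqn(None, eqn.invars, eqn.outvars, eqn, state)
print(state.tasklets)  # ['c = a + b']
```

Because the translator knows nothing about the surrounding SDFG, it can be tested in isolation, which is one of the benefits of this design.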
Currently the interface defines the following functions (we will likely retain them):

```python
from typing import Union

import dace
from jax.core import JaxprEqn


def translateEqn(
    self,
    translator,
    inVarNames: list[Union[str, None]],
    outVarNames: list[str],
    eqn: JaxprEqn,
    eqnState: Union[dace.SDFGState, None],
):
    ...
```

This function creates the actual SDFG representation of the current Jax equation, which is passed as `eqn`. The equation should be constructed inside `eqnState` (which might be `None` in some cases). The driver translator, which is passed as `translator`, is responsible for setting up the context, such as ensuring that all input/output variables are created inside the SDFG.

```python
def canHandle(
    self,
    translator,
    inVarNames: list[Union[str, None]],
    outVarNames: list[str],
    eqn: JaxprEqn,
) -> bool:
    ...
```

This function is called by the driver, which is passed as `translator`. With the exception of `eqnState`, it receives the same arguments as `translateEqn()`. As its name says, its main task is to determine whether the translator is _able_ to handle the primitive. As a side note, most of the translators only check the name of the primitive, but in principle more elaborate checks are supported.

The two functions above will most likely stay, as they have proven useful. For some primitives and in some situations, the state should not be allocated by the driver but by the equation translator itself. To signal this to the driver, we use the function:

```python
def needsExternalState(
    self,
    translator,
    inVarNames: list[Union[str, None]],
    outVarNames: list[str],
    eqn: JaxprEqn,
) -> bool:
    ...
```

This function should also be retained, but its usage in the main driver and its semantics should be made clearer. Most likely all of these primitive translators will be reused.

###### `JaxprBaseTranslator`

This is, in some sense, the driver translator; however, it is not directly exposed to the user (it should be renamed, this time with a leading `_`).
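The driver's core loop can be sketched as follows, assuming a dictionary from primitive names to translators in place of the prototype's linear scan. `SketchDriver`, `BinaryOpTranslator` and `FakeEqn` are hypothetical stand-ins; no DaCe objects are created and each "state" is just a list of code lines.

```python
from dataclasses import dataclass


@dataclass
class FakeEqn:
    """Hypothetical stand-in for a Jax equation."""

    primitive: str
    invars: list[str]
    outvars: list[str]


class BinaryOpTranslator:
    """Hypothetical translator for a single binary element-wise primitive."""

    def __init__(self, primitive: str, op: str):
        self.primitive = primitive
        self.op = op

    def translateEqn(self, translator, inVarNames, outVarNames, eqn, eqnState):
        lhs, rhs = inVarNames
        (out,) = outVarNames
        eqnState.append(f"{out} = {lhs} {self.op} {rhs}")


class SketchDriver:
    """Hypothetical driver: walks the equations and dispatches by name."""

    def __init__(self, translators):
        # Name-to-translator map: a dictionary lookup replaces the linear scan.
        self.translatorMap = {t.primitive: t for t in translators}

    def translateJaxpr(self, eqns):
        states = []  # One stand-in "state" (a list of code lines) per equation.
        for eqn in eqns:
            subTranslator = self.translatorMap.get(eqn.primitive)
            if subTranslator is None:
                raise NotImplementedError(f"No translator for '{eqn.primitive}'.")
            eqnState = []
            subTranslator.translateEqn(self, eqn.invars, eqn.outvars, eqn, eqnState)
            states.append(eqnState)
        return states


driver = SketchDriver([BinaryOpTranslator("add", "+"), BinaryOpTranslator("mul", "*")])
# Roughly what a Jaxpr for `d = a * b + c` contains: two equations.
eqns = [FakeEqn("mul", ["a", "b"], ["t"]), FakeEqn("add", ["t", "c"], ["d"])]
print(driver.translateJaxpr(eqns))  # [['t = a * b'], ['d = t + c']]
```

A plain dictionary can only dispatch on the primitive name. If the more elaborate `canHandle()` checks are to be kept, one option would be to map each name to a short list of candidate translators and call `canHandle()` only on those.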
`JaxprBaseTranslator` contains, among others, the following members:

- A list of all intrinsic translators.
- A map from Jax primitive names to the translators.
- If a translation is active, the SDFG and the current head state, i.e. the state into which the next equation will be put.
- A list of names that should be avoided. The reason for this list is that Jax uses `a`, `b`, and so on for its internal variables. However, since `w`, for example, is a quite popular variable name, e.g. as an argument name, the Jax variable `w` must be renamed. There is also a list of forbidden variable names, such as `if` or `auto`, which are valid Jax names but reserved keywords in Python or C++.

Its main functions are the following:

```python
def translateJaxpr(
    self,
    jaxpr: ClosedJaxpr,
    # ... further arguments omitted
):
    ...
```

This is the main function: it receives a `Jaxpr` object and iterates through its equations. For each equation it finds the respective translator and calls it. Its inner implementation needs some consolidation. One problem is that the object was once designed to be stateless, which was not a good idea, but some aspects of this remain. Furthermore, the translators are found by a linear scan, which should be changed, e.g. into a lookup in the name-to-translator map.

There are several functions that allow the translator to handle nested Jaxpr expressions, i.e. the `cond` and `pjit` primitives. While the design is quite nice, its execution is not.

```python
def clone(self) -> "JaxprBaseTranslator":
    ...
```

As its name suggests, this function returns a new instance of `JaxprBaseTranslator`. However, the new object is not a copy of the current state; instead, it is in the state right after its construction. This object is then used to translate the nested expression. Handling the nested Jaxpr requires a few other things as well, mostly revolving around the `TranslatedSDFG` object.
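The fresh-instance semantics of `clone()` described above can be sketched as follows. `SketchTranslator`, `startTranslation()` and the string stand-ins for the SDFG and head state are hypothetical; the point is that only the configuration (the registered translators) is carried over, while per-translation state such as the active SDFG is not.

```python
class SketchTranslator:
    """Hypothetical driver with configuration and per-translation state."""

    def __init__(self, subTranslators):
        # Configuration: copied into every clone.
        self.subTranslators = dict(subTranslators)
        # Per-translation state: only set while a translation is running.
        self.sdfg = None
        self.headState = None

    def startTranslation(self, sdfgName):
        # Strings stand in for the SDFG and its current head state.
        self.sdfg = sdfgName
        self.headState = f"{sdfgName}_init"

    def clone(self) -> "SketchTranslator":
        # Not a copy of the current state: the new object looks exactly as
        # if it had just been constructed with the same configuration.
        return type(self)(self.subTranslators)


outer = SketchTranslator({"add": object()})
outer.startTranslation("outer_sdfg")
inner = outer.clone()  # Would be used to translate, e.g., the body of a `pjit`.
print(inner.sdfg, inner.subTranslators.keys() == outer.subTranslators.keys())
# None True
```

How the result of the nested translation is embedded back into the outer SDFG is not shown here.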
While the idea of translating nested expressions with a freshly constructed translator is attractive, since it isolates the handling of nested Jaxprs in a few well-defined locations, its current implementation is not good and must be redone. In particular, the cloning of objects currently has some associated weirdness. Implementing nested Jaxprs in this way, i.e. putting each of them inside its own state, makes it possible to inline the resulting SDFG.

Another important function that should be retained is:

```python
def _addArray(
    self,
    arg: JaxVariable,
    # ... further arguments omitted
) -> str:
    ...
```

This function translates a Jax variable into its equivalent SDFG array variable. This job is not simple, since the function has to ensure that there is no name clash while also avoiding any forbidden name. Note that this function can also generate fully symbolic arrays, i.e. arrays whose shapes are not known at compile time, but this feature is currently disabled. In light of future goals we should enable it.

With the exception of cloning, this class can be considered quite clean.

##### Additional Things

Some features supported by Jax are specific to its parallelization schemes. Examples are the `pmap` and `vmap` transformations. Since they expand to XLA-specific constructs, such as the `xla_pmap` primitive, it is not clear what we should do with them in DaCe.

###### `JaxprToSDFG`

This class is the translator used by the user: the user passes a `Jaxpr` object and some constraints, such as whether the code should run on the GPU, the argument names and how the return values should be handled, and `JaxprToSDFG` returns an SDFG. Originally, there was only this class, but at some point a lot of its functionality was moved into `JaxprBaseTranslator`. What remained inside this class is the logic that turns the translation into a usable SDFG. For example, it generates the special names `__return` or `__return_{}` and sets up the correct argument names, since this functionality is only needed for the top-level SDFG and not for nested SDFGs.
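The top-level naming described above can be sketched with two small helpers. Both `makeReturnNames()` and `makeArgNames()` are hypothetical, and the sketch assumes that a single output is called `__return` while several outputs are numbered by position, i.e. the `{}` in `__return_{}` is filled with the output index.

```python
def makeReturnNames(nOutputs: int) -> list[str]:
    """Hypothetical helper: names for the outputs of the top-level SDFG."""
    if nOutputs == 1:
        return ["__return"]
    # Several outputs: fill in `__return_{}` with the position of the output.
    return [f"__return_{i}" for i in range(nOutputs)]


def makeArgNames(jaxInVars: list[str], userArgNames: list[str]) -> dict[str, str]:
    """Hypothetical helper: map Jax input variables to the user's argument names."""
    if len(jaxInVars) != len(userArgNames):
        raise ValueError("Expected exactly one argument name per Jaxpr input.")
    if len(set(userArgNames)) != len(userArgNames):
        raise ValueError("Argument names must be unique.")
    return dict(zip(jaxInVars, userArgNames))


print(makeReturnNames(1))  # ['__return']
print(makeReturnNames(3))  # ['__return_0', '__return_1', '__return_2']
print(makeArgNames(["a", "b"], ["field", "w"]))  # {'a': 'field', 'b': 'w'}
```

Keeping this logic out of `JaxprBaseTranslator` means nested SDFGs never see these special names, which matches the split of responsibilities described above.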
`JaxprToSDFG` needs a serious overhaul; however, it should stay. A jit decorator, such as `@jax2dace.jit`, should most likely be implemented as a wrapper around it.

#### Phase 2) Actually Doing It

Once we have identified the shortcomings of the current implementation and decided how we want to solve them, we implement the solution. It is very likely that we can reuse some code from the prototype in the new project; the primitive translators are a good candidate. However, we will start with an empty git repo inside the `eth-cscs` organization and work with PRs, even if a PR only imports prototype code.

###### Styling & Coding Guide

The prototype does not have a style guide and, despite being written by a single person, follows two different styles (the story behind it is funny/sad). The new project will follow _a_ standard more consistently. We summarize the most important points here (one task is to actually write a guide):

- Code should follow the automatic indentation of `vim`, i.e. to format your code, run the `gg=G` command in `vim`.
- Variables are named in camel case, starting with a lower-case character.
- You can add `_` to separate groups of words more clearly.
- Member variables of classes follow the pattern `__m_nameOfTheMemberVar`. Only in exceptional cases should you create "public" member variables.
- Do not use `property`, as properties hide what is going on.
- At the end of every scope, put an `# end` comment to mark it, for example:
  ```python
  def myFunction(x):
      # Doing some stuff
      return x
  # end def:
  ```
  ```python
  for i in range(10):
      ...
  # end for(i):
  ```
- Comments and documentation are not evil; use them. A function should generally have a description of its arguments, what it does and what its _intent_ is.
- In my (Philip's) view, type annotations go against the core of Python; if you want types, use a typed language to begin with.
  However, since they are needed, use them, especially in function signatures.
- Be explicit.

###### Tests

Tests are important, so you should add them. In the beginning, the original test (which is basically a Jupyter notebook) should suffice, since it should cover every code path in the subtranslators. However, it should be ported to `pytest`. Another idea would be to make use of the Jax tests. However, since they cover all of Jax, supporting all of them would be too much work.

###### Project Scope and Aims

The prototype explicitly targeted ICON stencils, whereas the new J2D translator aims at supporting Jax in general and bridging Jax applications to DaCe. Furthermore, we have explicitly decided to create a dedicated repo for the translator instead of adding just another interface to DaCe.

## Rabbit Holes

While the final goal is to create a Jax to DaCe translator, the goal of this cycle is to polish the current prototype so that it can act as a solid base for future changes.

The bare-bones translator, i.e. the successor of `JaxprToSDFG`, will not support [pytrees](https://jax.readthedocs.io/en/latest/pytrees.html). Instead, this feature should be handled by our equivalent of the `@jit` annotation. In addition, the translator _cannot_ support them, since an SDFG, at least in the Python world, is always transformed into a `CompiledSDFG`. Thus, the best option would be to implement pytree support inside DaCe itself.

We also ignore [sharding](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html), which is Jax's way to parallelize. The reason is that it is far out of scope and that there is not yet a corresponding concept in DaCe itself. The current way to "disable" it, which is also what the prototype does, is to translate a `device_put` primitive into a copy instruction.

Furthermore, we will not implement primitives beyond those we already have. However, at least in the next cycle, we have to implement `scan` and `scatter`.
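Coming back to pytrees: because their handling is moved into our `@jit` equivalent, such a wrapper has to flatten (possibly nested) arguments into the flat argument list that the bare translator and the resulting SDFG expect. The pure-Python sketch below is hypothetical (`flattenArgs`, `unflattenArgs` and `jitSketch` are illustrative names) and only handles plain tuples, lists and dicts; a real implementation would most likely rely on Jax's own `jax.tree_util.tree_flatten()` and `jax.tree_util.tree_unflatten()` instead.

```python
from typing import Any


def flattenArgs(tree: Any) -> tuple[list[Any], Any]:
    """Flatten plain tuples/lists/dicts into leaves plus a structure description."""
    if isinstance(tree, dict):
        keys = sorted(tree)  # Sorted keys give a deterministic leaf order.
        leaves, values = flattenArgs([tree[k] for k in keys])
        return leaves, ("dict", keys, values)
    if isinstance(tree, (list, tuple)):  # Named tuples are not supported.
        leaves, children = [], []
        for child in tree:
            childLeaves, childStructure = flattenArgs(child)
            leaves.extend(childLeaves)
            children.append(childStructure)
        return leaves, ("seq", type(tree), children)
    return [tree], None  # Anything else is a leaf, e.g. an array.


def unflattenArgs(leaves: list[Any], structure: Any) -> Any:
    """Rebuild the nested object from its leaves and structure description."""
    it = iter(leaves)

    def build(s):
        if s is None:
            return next(it)
        if s[0] == "dict":
            _, keys, values = s
            return dict(zip(keys, build(values)))
        _, seqType, children = s
        return seqType(build(c) for c in children)

    return build(structure)


def jitSketch(flatFunction):
    """Hypothetical @jit-like wrapper around a function taking flat arguments."""

    def wrapper(*args):
        leaves, _ = flattenArgs(args)
        return flatFunction(*leaves)

    return wrapper


args = ({"u": 1.0, "v": 2.0}, [3.0, (4.0,)])
leaves, structure = flattenArgs(args)
print(leaves)  # [1.0, 2.0, 3.0, 4.0]


@jitSketch
def flatSum(*fields):
    # Stands in for the compiled SDFG, which only accepts flat arguments.
    return sum(fields)


print(flatSum({"u": 1.0, "v": 2.0}, [3.0, (4.0,)]))  # 10.0
```

Rebuilding nested outputs works the same way with `unflattenArgs()`, provided the output structure is recorded when the function is traced; this part is not shown.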
## No-gos

<!-- Anything specifically excluded from the concept: functionality or use cases we intentionally aren’t covering to fit the appetite or make the problem tractable -->

## Progress

- [x] Setup of the [JaCe repo](https://github.com/GridTools/jace).
- [ ] Initial [PR](https://github.com/GridTools/jace/pull/3) (still under review).

<!-- -------------------------------------------- -->