# A performance-portable JAX backend using the DaCe parallel programming framework

## Description

[JAX](https://jax.readthedocs.io/en/latest/#) provides a functional NumPy-like API for numerical computing combined with a powerful system of composable function transformations that supports automatic differentiation (`grad`, `jvp`, ...), optimization (`jit`), vectorization (`vmap`), and parallelization (`pmap`) of functions. Its performance and flexibility have turned JAX into one of the most exciting and easy-to-use frameworks for numerical computing and machine learning in Python. For example, the `jit` function transformation takes a function and returns a semantically identical function which is internally compiled to binary code for the available hardware (including accelerators).

```python
import jax.numpy as jnp
from jax import jit, random

x = random.normal(random.PRNGKey(0), (5000, 5000))

def f(w, b, x):
    return jnp.tanh(jnp.dot(x, w) + b)

fast_f = jit(f)  # compiled on first call for the given argument shapes and dtypes
```

JAX captures the computation graph of _pure_ user-defined functions in an evaluable data structure called **Jaxpr** (JAX expression), which is used as the intermediate representation for function transformations. These transformations are implemented as JAX interpreters: Jaxpr evaluators that provide alternative interpretations of the primitive operations recorded in the computation graph.

[DaCe](https://github.com/spcl/dace) is a parallel programming framework developed at the Scalable Parallel Computing Laboratory (SPCL) in the Department of Computer Science at ETH Zurich. SPCL is a relevant CSCS partner which contributes to several projects of the Scientific Software Development (SSD) working structure, such as [GT4Py](https://github.com/gridtools/gt4py) or [COSMA](https://github.com/eth-cscs/COSMA). DaCe can take code written in different programming languages and map it to high-performance CPU, GPU, and FPGA programs, which can be optimized to achieve state-of-the-art performance. Internally, DaCe uses the Stateful DataFlow multiGraph (SDFG) data-centric intermediate representation: a transformable, interactive representation of code based on data movement.
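To make the Jaxpr machinery more concrete, the following small sketch uses only public JAX APIs (`jax.make_jaxpr`) to trace the function from the example above and walk the recorded equations. A DaCe-backed interpreter of the kind proposed in this project would translate each primitive into SDFG nodes instead of merely printing it; that translation is only indicated here as a comment.

```python
import jax
import jax.numpy as jnp

def f(w, b, x):
    return jnp.tanh(jnp.dot(x, w) + b)

# Trace f with abstract values to obtain its Jaxpr (no actual computation runs).
closed_jaxpr = jax.make_jaxpr(f)(
    jnp.ones((3, 2)), jnp.ones((2,)), jnp.ones((4, 3))
)
print(closed_jaxpr)  # e.g. { lambda ; a:f32[3,2] b:f32[2] c:f32[4,3]. let ... }

# Each equation records one primitive application; a DaCe-backed interpreter
# would emit corresponding SDFG nodes for each of these.
for eqn in closed_jaxpr.jaxpr.eqns:
    print(eqn.primitive.name,
          [str(v) for v in eqn.invars], "->",
          [str(v) for v in eqn.outvars])
```

On the DaCe side, SDFGs are usually obtained through DaCe's Python frontend. The minimal sketch below is adapted from DaCe's canonical `axpy` example (exact API details may differ between DaCe versions); it shows the path from annotated Python code to an SDFG and compiled binary that the proposed transformation would target.

```python
import dace
import numpy as np

N = dace.symbol('N')  # symbolic array size, resolved when the program is called

@dace.program
def axpy(a: dace.float64, x: dace.float64[N], y: dace.float64[N]):
    y[:] += a * x

# Obtain the data-centric IR; transformations can be applied to it before compilation.
sdfg = axpy.to_sdfg()

# Compile and execute the generated code directly with NumPy arrays.
x = np.random.rand(1024)
y = np.random.rand(1024)
axpy(2.0, x, y)
```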
In this internship we would like to explore the feasibility of using numerical functional programming frameworks like JAX as frontends for HPC optimization frameworks like DaCe. We propose to add a new JAX function transformation that uses DaCe as a code-generation backend for JAX by creating DaCe SDFGs directly from Jaxprs. Ideally, at the end of the internship JAX users could compile JAX programs with DaCe very cleanly by just decorating functions with the new transformation (e.g. `@dace_jit`).

This project will allow the intern to gather practical knowledge and hands-on experience with the internals of industry-grade frameworks like JAX, state-of-the-art optimizing compilers like DaCe, and the connections between them. Additionally, the intern will have the opportunity to present the work and the acquired knowledge to different CSCS teams, which could help in future design and implementation decisions.

Candidates should be competent and independent in Python. A solid understanding of functional programming, compilers, and performance optimization is desirable but optional.

## Milestones

1. Understanding core JAX concepts (primitives, tracing, Jaxprs, and interpreters)
2. Understanding core DaCe concepts (SDFGs and transformations)
3. Develop a new JAX transformation to create SDFGs from Jaxprs
4. Unoptimized code generation and compilation of the obtained SDFGs
5. Validation of results from different JAX programs
6. Implement additional optimizations in DaCe
7. Final benchmarking of the DaCe transformation using an HPC benchmark suite (e.g. [pyhpc](https://github.com/dionhaefner/pyhpc-benchmarks))

## Details

- Supervisors: Enrique G. Paredes, Linus Groner
- Working place: ETH Zurich
- Expected duration: 4 months

## References

- \[JAX\]: https://jax.readthedocs.io/en/latest/index.html
- \[JAX Transformations\]: https://jax.readthedocs.io/en/latest/notebooks/Writing_custom_interpreters_in_Jax.html
- \[DaCe\]: https://github.com/spcl/dace
- \[DaCe paper\]: https://arxiv.org/pdf/1902.10345.pdf
- \[HPC benchmarks for Python - pyhpc\]: https://github.com/dionhaefner/pyhpc-benchmarks