# [DaCe] Improve `concat_where` Lowering

<!-- Add the tag for the current cycle number on top -->

- Shaped by: Philip
- Appetite (FTEs, weeks):
- Developers: <!-- Filled in at the betting table unless someone is specifically required here -->

## Problem

> Note:
> This exact bug might have gone away due to a change in the ICON4Py user code.
> However, the underlying problem should still be solved.

The handling of `concat_where` in the GTIR to DaCe lowering, especially of nested cases, needs to be improved.
Consider the following code, which is part of `fused_velocity_advection_stencil_8_to_13_corrector` (currently known as `interpolate_horizontal_kinetic_energy_to_cells_and_compute_contravariant_corrected_w`).
The actual problem is that some values are computed although they are not needed.
In the full specialization mode of the benchmark repo it is possible to remove some of them, mostly through the splitting transformations and some dead dataflow transformations.
However, when using SL1 or when the producer is a Map, this is sometimes not possible.
In that case `__tmp10` has to be kept and cannot be merged with `__tmp11`, because the superfluous write to the last column of `__tmp10` has to go somewhere.

```python=
@gtx.field_operator
def _compute_contravariant_corrected_w(
    w: fa.CellKField[wpfloat],
    contravariant_correction_at_cells_on_half_levels: fa.CellKField[vpfloat],
    nflatlev: gtx.int32,
    nlev: gtx.int32,
) -> fa.CellKField[vpfloat]:
    contravariant_corrected_w_at_cells_on_half_levels = concat_where(
        (nflatlev + 1 <= dims.KDim) & (dims.KDim < nlev),
        astype(w, vpfloat) - contravariant_correction_at_cells_on_half_levels,
        astype(w, vpfloat),
    )
    return concat_where(dims.KDim == nlev, 0.0, contravariant_corrected_w_at_cells_on_half_levels)
```

This leads to the following SDFG (for clarity the `RemoveCopyChain` transformation is disabled, but it is unable to handle that case anyway):

![Initial_situation](https://hackmd.io/_uploads/r1xvG0EIge.png)

The problem is that `__tmp10` is materialized in full, i.e. with shape `(81, )`, although the last level is not needed. This is visible in the copy to `__tmp11`, which only copies the subset `[0:80]`.
The last column of `__tmp11` is initialized by the right Map in the picture above.
This is a big problem for the optimizer, because it starts from the assumption that everything that is computed is (potentially[^potentially_needed]) needed.

Another issue is that, depending on the order of the `concat_where` expressions, the resulting SDFGs differ considerably.

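The unused last level of `__tmp10` described above can be reproduced outside of DaCe. The following NumPy sketch models both `concat_where` calls of the first snippet as element-wise selects on a single cell column. This is only an approximation of the GT4Py semantics: the value of `nflatlev` and the random input data are made up, and the array names merely mirror `__tmp10` and `__tmp11` from the SDFG.

```python
import numpy as np

# Illustrative sizes only: 81 half levels matches the (81, ) shape in the SDFG,
# the value of nflatlev is an assumption.
nlev = 80
nflatlev = 5
k = np.arange(nlev + 1)  # K indices 0..nlev of a single cell column

rng = np.random.default_rng(0)
w = rng.random(nlev + 1)
corr = rng.random(nlev + 1)

# Inner concat_where, modelled as a select over all 81 levels (like __tmp10).
inner_mask = (nflatlev + 1 <= k) & (k < nlev)
tmp10 = np.where(inner_mask, w - corr, w)

# Outer concat_where: level nlev is set to 0.0 (like __tmp11).
tmp11 = np.where(k == nlev, 0.0, tmp10)

# Poison the last level of the intermediate and recompute the output.
tmp10_poisoned = tmp10.copy()
tmp10_poisoned[nlev] = np.nan
tmp11_poisoned = np.where(k == nlev, 0.0, tmp10_poisoned)

# If the output is unchanged, the last level of the intermediate is never read.
last_level_is_dead = np.array_equal(tmp11, tmp11_poisoned)
needed_shape = tmp10[:nlev].shape
print(last_level_is_dead, needed_shape)
```

In this model the output does not change when the last level of the intermediate is overwritten, so only the first 80 levels of the intermediate are consumed, which matches the `[0:80]` subset seen in the copy to `__tmp11`.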
Regarding the ordering: after some experimentation we found the following variant, which is particularly well suited and in fact the desired form:

```python=
@gtx.field_operator
def _compute_contravariant_corrected_w(
    w: fa.CellKField[wpfloat],
    covar_corr_at_cells_on_half_levels: fa.CellKField[vpfloat],
    nflatlev: gtx.int32,
    nlev: gtx.int32,
) -> fa.CellKField[vpfloat]:
    contravariant_corrected_w_at_cells_on_half_levels = concat_where(
        dims.KDim < (nflatlev + 1),
        astype(w, vpfloat),
        concat_where(
            dims.KDim < nlev,
            astype(w, vpfloat) - covar_corr_at_cells_on_half_levels,
            0.0
            # If `nlev` is not the last level, we would need:
            #   concat_where(dims.KDim == nlev, 0.0, astype(w, vpfloat))
        )
    )
    return contravariant_corrected_w_at_cells_on_half_levels
```

This snippet leads to the following SDFG:

![Best_attempt](https://hackmd.io/_uploads/Hy_ghA48ex.png)

## Appetite

<!-- Explain how much time we want to spend and how that constrains the solution -->

## Solution

The solution would be to modify the lowering in such a way that the intermediates "just have the right sizes".
Thus, in the first example `__tmp10` would not include the last level.
Furthermore, as shown above, the resulting SDFG should no longer, or at least to a lesser extent, depend on how the user code is written (this might not be possible to fix in the lowering, but it is not an issue if the intermediates are generated in a smarter way).

## Rabbit holes

<!-- Details about the solution worth calling out to avoid problems -->

## No-gos

<!-- Anything specifically excluded from the concept: functionality or use cases we intentionally aren’t covering to fit the appetite or make the problem tractable -->

## Progress

<!-- Don't fill during shaping. This area is for collecting TODOs during building.
As first task during building add a preliminary list of coarse-grained tasks for the project and refine them with finer-grained items when it makes sense as you work on them.
-->

- [x] Task 1 ([PR#xxxx](https://github.com/GridTools/gt4py/pulls))
  - [x] Subtask A
  - [x] Subtask X
- [ ] Task 2
  - [x] Subtask H
  - [ ] Subtask J
- [ ] Discovered Task 3
  - [ ] Subtask L
  - [ ] Subtask S
- [ ] Task 4

<!-- ================================================== -->

[^potentially_needed]: A case where this might be violated is indirect access, i.e. expressions such as `intermediate_result[runtime_table[i]]`; here the full array `intermediate_result` has to be materialized.