# Design changes in pytensor to think about

## Fusion optimizer

https://github.com/pymc-devs/pytensor/pull/121

- Are we fusing too much now (ignoring bugs)?
- Should we try to restrict to Elemwise subgraphs that share something other than constants?
- Should we try to only fuse "vectorizable" Ops (whatever those are, probably nothing from scipy.special)?
- Backend-specific truncation rules (if on numba, don't merge Ops that are evaluated in object mode)
- Don't run at all in JAX?

## RandomVariable in numba

A PR was merged in numba for Generator support? We could use that to implement RandomVariables?
Seems there is also a stale aesara PR already...

## Function object for numba and JAX

We can get rid of a lot of C-specific things. But how do we handle updates and shared variables?
It would be great if the compiled function could be called from numba/jax directly.
(Some work for shared variables is in nutpie.)

- Perhaps have a version without updates only (that raises if there were any)
- Two levels: one that manages input validation and shared variables, and a lower-level one that requires everything to be provided manually

## Compilation times with numba backend

- Really bad right now?
- Mostly because of numba itself? It compiles any intermediate function separately.

Some other things:
- Pass no_cpython_wrapper=True where possible.
- Make numba functions global so they can be reused.
- impl mul/add as global functions. Changes to CAReduce necessary? But maybe merge that with Elemwise?
- Add a decorator to funcify: `cache_by_op` or so, that reuses the numba function from a previous funcify call.

## Vectorization in numba

There seems to be an issue in numba that sometimes prevents the llvm autovectorizer from working.
https://numba.discourse.group/t/compilation-pipeline-compile-time-and-vectorization/1716

Maybe inline="always" can help?

## Scan Op redesign?

- Experimenting with "Scalar Scan Op" in https://github.com/pymc-devs/pytensor/pull/174

https://www.sjalander.com/research/pdf/sjalander-tecs2020.pdf

Let's try to implement a new loop op, and see if that can do what we did before?
Long-term deprecate the old scan at some point?

## Types / Variables / .type attribute

Max said something?

## numba llvm vs numba code?

CAReduce? Maybe fuse those two very late in the pipeline for numba

## ifelse?

It's not actually if-else right now, but a switch in numba.
I think we need subgraphs in IfElse?

```python=
out = switch(cond, op1(a, b), op2(c, d))
# def scalar_out(a, b, c, d):
#     _1 = op1(a, b)
#     _2 = op2(c, d)
#     out = cond ? _1 : _2
```

```python=
res = empty(cond.shape)
# Evaluate each branch only on the entries where it is selected.
res = set_subtensor(res[cond], op1(a[cond], b[cond]))
res = set_subtensor(res[~cond], op2(c[~cond], d[~cond]))
```

```python=
from pytensor.graph.basic import Apply
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.op import Op
from pytensor.tensor import dvector, iscalar, le, lt, scalar, set_subtensor
from pytensor.tensor.elemwise import Elemwise


# Design sketch: calling a FunctionGraph directly (e.g. self.if_branch(*inputs))
# is pseudocode for "evaluate the compiled inner graph".
class IfElse(Op):
    __props__ = ("if_branch", "else_branch")

    def __init__(self, if_branch: FunctionGraph, else_branch: FunctionGraph):
        self.if_branch = if_branch
        self.else_branch = else_branch
        # assert same input and output types

    def make_node(self, condition, *inputs):
        outputs = [out.type() for out in self.if_branch.outputs]
        return Apply(self, [condition, *inputs], outputs)

    def perform(self, node, inputs, output_storage):
        condition, *inputs = inputs
        if condition:
            out = self.if_branch(*inputs)
        else:
            out = self.else_branch(*inputs)
        output_storage[0][0] = out


# The Python control flow we want to represent:
def foo(val1, val2):
    if val1:
        out = val2 * 2
    else:
        out = val2 * 3
    return out


val1 = scalar()
val2 = scalar()

x = scalar()
fgraph1 = FunctionGraph([x], [x * 2])
fgraph2 = FunctionGraph([x], [x * 3])
out = IfElse(fgraph1, fgraph2)(val1, val2)
fgraph_inner = FunctionGraph([val1, val2], [out])

x = dvector()
y = dvector()
Elemwise(scalar_op=fgraph_inner)(x, y)

# Can we implement grad for this?


class DoWhile(Op):
    __props__ = ("inner",)

    def __init__(self, inner: FunctionGraph):
        # Inner outputs are (continue_flag, *new_state); state types must match.
        assert [i.type for i in inner.inputs] == [o.type for o in inner.outputs[1:]]
        self.inner = inner

    def make_node(self, *inputs):
        return Apply(self, list(inputs), [i.type() for i in self.inner.inputs])

    def perform(self, node, inputs, output_storage):
        while True:
            condition, *inputs = self.inner(*inputs)
            if not condition:
                break
        for storage, value in zip(output_storage, inputs):
            storage[0] = value


# Example 1: scalar state (x, y, z)
x = scalar()
y = scalar()
z = scalar()
x_out = 2 * x + z
y_out = 2 * x * y
z_out = z
continue_ = le(x_out + y_out, 100)
inner = FunctionGraph([x, y, z], [continue_, x_out, y_out, z_out])

x = scalar()
y = scalar()
z = scalar()
x_out, y_out, z_out = DoWhile(inner)(x, y, z)

# Example 2: walk over a vector, doubling one entry per iteration
x = dvector()
idx = iscalar()
y_out = 2 * x[idx]
x_out = set_subtensor(x[idx], y_out)
idx_out = idx + 1
continue_ = lt(idx_out, x.shape[0])
inner = FunctionGraph([x, idx], [continue_, x_out, idx_out])

x = dvector()
idx = iscalar()
x_out, idx_out = DoWhile(inner)(x, idx)


class While(Op):
    __props__ = ("inner",)

    def __init__(self, inner: FunctionGraph):
        assert [i.type for i in inner.inputs] == [o.type for o in inner.outputs[1:]]
        self.inner = inner

    def make_node(self, first_condition, *inputs):
        outputs = [i.type() for i in self.inner.inputs]
        return Apply(self, [first_condition, *inputs], outputs)

    def perform(self, node, inputs, output_storage):
        condition, *inputs = inputs
        while condition:
            condition, *inputs = self.inner(*inputs)
        for storage, value in zip(output_storage, inputs):
            storage[0] = value
```

### Reverse mode autodiff with loops

Problem: What properties must a loop have so that we can compute reverse mode autodiff for it?

Taking ideas from https://github.com/GiggleLiu/NiLang.jl

Let's say a loop is:
- pre-condition
- state update
- post-condition

i.e.
```python
state = init_state
if pre_condition(state):
    while True:
        state = update_state(state)
        if post_condition(state):
            break
```

I think it probably needs to be reversible. Something like:

A loop `loop1 = (pre, update, post)` is called reversible if we can find a second loop
`loop2 = (pre', update', post')` such that
$\mathrm{loop1}(\mathrm{loop2}(\mathrm{loop1}(\mathrm{init\_state}))) = \mathrm{loop1}(\mathrm{init\_state})$
for all states `init_state`.

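To make the reversibility property concrete, here is a minimal pure-Python sketch (no pytensor involved). Everything in it is a hypothetical illustration: `run_loop`, `LIMIT`, and the `(idx, trace)` state layout are assumptions, not existing API. The forward loop keeps doubling a value and records every intermediate value in a trace, and the reverse loop only walks the index back to the start. The property is checked only for initial states with `idx == 0` and a positive start value, not for arbitrary states.

```python
from typing import Tuple

# State is (idx, trace): idx points at the current value inside the trace.
State = Tuple[int, Tuple[float, ...]]

LIMIT = 100.0  # hypothetical stopping threshold for this toy example


def run_loop(pre, update, post, state: State) -> State:
    """Run a (pre, update, post) loop following the definition above."""
    if pre(state):
        while True:
            state = update(state)
            if post(state):
                break
    return state


def fwd_update(state: State) -> State:
    idx, trace = state
    # Drop any stale entries after idx, then record the doubled value.
    return idx + 1, trace[: idx + 1] + (2.0 * trace[idx],)


def rev_update(state: State) -> State:
    idx, trace = state
    # Step back one entry; the trace itself is left untouched.
    return idx - 1, trace


def above_limit(state: State) -> bool:
    idx, trace = state
    return trace[idx] > LIMIT


def loop1(state: State) -> State:
    # Forward loop: double until the current value exceeds LIMIT.
    return run_loop(lambda s: not above_limit(s), fwd_update, above_limit, state)


def loop2(state: State) -> State:
    # Reverse loop: walk the index back to the first trace entry.
    return run_loop(lambda s: s[0] > 0, rev_update, lambda s: s[0] == 0, state)


for x0 in (3.0, 50.0, 250.0):
    init_state = (0, (x0,))
    final_state = loop1(init_state)
    # loop1(loop2(loop1(init_state))) == loop1(init_state)
    assert loop1(loop2(final_state)) == final_state
```

Because the trace keeps every intermediate value, a backward pass could read them in reverse order while `loop2` runs, which is the reason for storing all intermediate states in the `state` variable.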
I'm not 100% sure, but I think reversibility might be the property we need for reverse mode autodiff.

So we could allow arbitrary loops, but single out the loops that we know are reversible, either because a user supplied the reverse loop or because we can know it for some other reason (for example because the loop stores all intermediate states in the `state` variable). For those loops we could then generate reverse mode derivatives.

Forward derivatives should be much easier anyway; they just depend on the derivatives of `update_state`.

#### Example

```python
from typing import Optional

import pytensor.tensor as pt
from pytensor.graph.basic import Apply
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.op import Op


# Design sketch: calling a FunctionGraph directly is pseudocode for
# "evaluate the compiled inner graph".
class Loop(Op):
    def __init__(
        self,
        pre_condition: FunctionGraph,
        update: FunctionGraph,
        post_condition: FunctionGraph,
        reverse: Optional["Loop"] = None,
    ):
        self._state_types = [inp.type for inp in pre_condition.inputs]
        # validate that
        # pre_condition: state_types -> [bool]
        # update: state_types -> state_types
        # post_condition: state_types -> [bool]
        # reverse: Optional[Loop[state_types=state_types]]
        self.pre_condition = pre_condition
        self.update = update
        self.post_condition = post_condition
        self.reverse = reverse

    def make_node(self, *state):
        # assert [item.type for item in state] == self._state_types
        return Apply(self, list(state), [t() for t in self._state_types])

    def perform(self, node, inputs, output_storage):
        state = inputs
        if self.pre_condition(*state):
            while True:
                state = self.update(*state)
                if self.post_condition(*state):
                    break
        for storage, value in zip(output_storage, state):
            storage[0] = value

    def grads(self, *args):  # arguments left open in this sketch
        if not self.reverse:
            raise NotImplementedError()
        ...


# We can use this to represent arbitrary loops, even loops that cannot easily
# be transformed to do reverse mode autodiff.
# But we could have a function that automatically produces reverse loops
# for special cases, like a loop that remembers all intermediate states:

def tracing_while_loop(maxlength, init_states, update, break_func, constant_args):
    # One preallocated trace per state variable, holding every intermediate value.
    traces = [
        pt.empty((maxlength, *state.shape), dtype=state.dtype)
        for state in init_states
    ]
    idx = pt.iscalar()
    constant_sym = [const.type() for const in constant_args]
    state = traces + [idx, *constant_sym]

    pre = (~break_func(state)) & (maxlength > 0)
    user_updates = update(
        [tr[idx] for tr in traces], index=idx, constant_args=constant_sym
    )
    all_updates = [
        pt.set_subtensor(tr[idx + 1], user_update)
        for tr, user_update in zip(traces, user_updates)
    ]
    all_updates += [idx + 1, *constant_sym]
    # Loop.perform breaks when post is true: stop once the user condition
    # fires or the next write would fall outside the trace.
    post = break_func(state) | (idx + 1 >= maxlength)

    loop_rev = ...  # Loop that yields the tr objects in reverse...
    loop = Loop(
        FunctionGraph(state, [pre]),
        FunctionGraph(state, all_updates),
        FunctionGraph(state, [post]),
        reverse=loop_rev,
    )
    return loop(init_states)


# Maybe the loop should be this instead?
state = initial_state
assert pre_condition(state)
while not post_condition(state):
    state = update(state)
    assert not pre_condition(state)

# reverse
state = final_state
assert post_condition(state)
while not pre_condition(state):
    state = rev_update(state)
    assert not post_condition(state)

# or, written with an explicit break:
state = initial_state
assert pre_condition(state)
if not post_condition(state):
    while True:
        state = update(state)
        assert not pre_condition(state)
        if post_condition(state):
            break
```

## Dims

## EGG

https://docs.rs/egg/latest/egg/tutorials/index.html