# Design changes in pytensor to think about
## Fusion optimizer
https://github.com/pymc-devs/pytensor/pull/121
- Are we fusing too much now (ignoring bugs)?
- Should we try to restrict to Elemwise subgraphs that share something other than constants?
- Should we try to only fuse "vectorizable" Ops (whatever those are, probably nothing from scipy.special)
- Backend specific truncation rules (if on numba, don't merge Ops that are evaluated in object mode)
- Don't run at all in JAX?
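The truncation rules above could boil down to a small predicate that the fusion rewrite consults before merging two Elemwise nodes. A minimal sketch, assuming a hypothetical `should_fuse` helper and an illustrative allow-list (none of these names exist in pytensor):

```python
# Hypothetical backend-aware fusion predicate; all names are made up
# for illustration and are not part of the pytensor API.

# Scalar ops we assume are cheaply vectorizable (no scipy.special
# fallbacks, nothing numba would run in object mode).
VECTORIZABLE = {"add", "mul", "sub", "true_div", "exp", "log"}

def should_fuse(op_name: str, backend: str) -> bool:
    """Return True if an op may join a fused Elemwise subgraph."""
    if backend == "jax":
        # JAX fuses through XLA anyway; skip our own fusion entirely.
        return False
    if backend == "numba" and op_name not in VECTORIZABLE:
        # Don't merge ops that numba would evaluate in object mode.
        return False
    return op_name in VECTORIZABLE

def fuse_candidates(op_names, backend):
    return [name for name in op_names if should_fuse(name, backend)]
```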
## RandomVariable in numba
A PR adding Generator support was apparently merged in numba?
We could use that to implement RandomVariables. There also seems to be a stale aesara PR for this already...
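If numba's Generator support works out, the numba implementation of a RandomVariable could be an ordinary function that takes the `np.random.Generator` explicitly. A sketch in plain numpy (the function name and signature are made up; under numba this body would be jitted):

```python
import numpy as np

# Illustrative only: what a numba-funcified normal RV might look like.
# Generator methods such as rng.standard_normal are what numba's
# Generator support would make usable inside jitted code.
def normal_rv_impl(rng: np.random.Generator, loc, scale, size):
    return loc + scale * rng.standard_normal(size)

rng = np.random.default_rng(42)
draws = normal_rv_impl(rng, 1.0, 2.0, (3,))
```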
## Function object for numba and JAX
We can get rid of a lot of C specific things.
But how do we handle updates and shared variables?
It would be great if the compiled function could be called from numba/jax directly.
(some work for shared variables is in nutpie).
- Perhaps have a version that only supports functions without updates (and raises if there are any)
- Two levels: One that manages input validation and shared variables and the lower level one that requires everything be provided manually
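The two-level idea could be sketched like this. `RawFunction` and `Function` are hypothetical names; the point is that the lower level takes all shared storage explicitly (and so could plausibly be called from numba/jax code), while the upper level validates inputs and applies updates:

```python
# Hypothetical two-level function design; not the pytensor API.

class RawFunction:
    """Low level: every input, including shared storage, passed explicitly."""

    def __init__(self, impl, n_inputs):
        self._impl = impl
        self.n_inputs = n_inputs

    def __call__(self, *args):
        # No shared-variable handling; the caller provides everything.
        if len(args) != self.n_inputs:
            raise TypeError(f"expected {self.n_inputs} arguments, got {len(args)}")
        return self._impl(*args)

class Function:
    """High level: manages shared-variable storage and updates."""

    def __init__(self, raw, shared_values, updates=None):
        self.raw = raw
        self.shared_values = list(shared_values)
        self.updates = updates  # maps shared index -> output index

    def __call__(self, *args):
        outputs = self.raw(*args, *self.shared_values)
        if self.updates:
            for shared_idx, out_idx in self.updates.items():
                self.shared_values[shared_idx] = outputs[out_idx]
        return outputs
```

For example, a function with a shared counter that increments on every call would pass the counter through the raw layer and write the updated value back at the high level.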
## Compilation times with numba backend
- Really bad right now?
- Mostly because of numba itself? It compiles every intermediate function separately.
Some other things:
- pass no_cpython_wrapper=True where possible.
- Make numba functions global so they can be reused.
- Implement mul/add as global functions. Are changes to CAReduce necessary? But maybe merge that with Elemwise?
- Add a decorator to funcify: `cache_by_op` or so, that reuses the numba function from a previous funcify call.
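The `cache_by_op` idea might look roughly like this, assuming Ops with `__props__` hash and compare by their props (the decorator name and the funcify signature are illustrative, not existing API):

```python
# Hypothetical sketch: reuse the numba function compiled by a previous
# funcify call when the Op compares equal.

def cache_by_op(funcify):
    cache = {}

    def wrapper(op, **kwargs):
        try:
            return cache[op]
        except KeyError:
            fn = funcify(op, **kwargs)
            cache[op] = fn
            return fn
        except TypeError:
            # Unhashable op: fall back to compiling every time.
            return funcify(op, **kwargs)

    return wrapper
```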
## Vectorization in numba
There seems to be an issue in numba that sometimes prevents the llvm autovectorizer from working.
https://numba.discourse.group/t/compilation-pipeline-compile-time-and-vectorization/1716
Maybe inline="always" can help?
## Scan Op redesign?
- Experimenting with "Scalar Scan Op" in https://github.com/pymc-devs/pytensor/pull/174
https://www.sjalander.com/research/pdf/sjalander-tecs2020.pdf
Let's try to implement a new loop op and see if it can do what we did before. Long term, deprecate the old Scan at some point?
## Types / Variables / .type attribute
Max said something?
## numba llvm vs numba code? CAReduce?
Maybe fuse those two very late in the pipeline for numba
## ifelse? It's not actually if-else right now, but a switch in numba
I think we need subgraphs in IfElse?
```python=
out = switch(cond, op1(a, b), op2(c, d))
# def scalar_out(a, b, c, d):
#     _1 = op1(a, b)
#     _2 = op2(c, d)
#     out = cond ? _1 : _2
```
```python=
res = empty(cond.shape)
res = set_subtensor(res[cond], op1(a[cond], b[cond]))
res = set_subtensor(res[~cond], op2(c[~cond], d[~cond]))
```
```python=
class IfElse(Op):
    __props__ = ("if_branch", "else_branch")

    def __init__(self, if_branch: FunctionGraph, else_branch: FunctionGraph):
        # assert both branches have the same input and output types
        self.if_branch = if_branch
        self.else_branch = else_branch

    def make_node(self, condition, *inputs):
        outputs = [out.type() for out in self.if_branch.outputs]
        return Apply(self, [condition, *inputs], outputs)

    def perform(self, node, inputs, output_storage):
        condition, *rest = inputs
        if condition:
            outs = self.if_branch(*rest)
        else:
            outs = self.else_branch(*rest)
        for storage, out in zip(output_storage, outs):
            storage[0] = out

# The graph equivalent of
def foo(val1, val2):
    if val1:
        out = val2 * 2
    else:
        out = val2 * 3
    return out

val1 = scalar()
val2 = scalar()

x = scalar()
fgraph1 = FunctionGraph([x], [x * 2])
fgraph2 = FunctionGraph([x], [x * 3])
out = IfElse(fgraph1, fgraph2)(val1, val2)
fgraph_inner = FunctionGraph([val1, val2], [out])

x = dvector()
y = dvector()
Elemwise(scalar_op=fgraph_inner)(x, y)
# Can we implement grad for this?
class DoWhile(Op):
    __props__ = ("inner",)

    def __init__(self, inner: FunctionGraph):
        # inner maps the state to (continue_flag, *new_state)
        assert [inp.type for inp in inner.inputs] == [out.type for out in inner.outputs[1:]]
        self.inner = inner

    def make_node(self, *inputs):
        outputs = [inp.type() for inp in self.inner.inputs]
        return Apply(self, [*inputs], outputs)

    def perform(self, node, inputs, output_storage):
        state = list(inputs)
        while True:
            condition, *state = self.inner(*state)
            if not condition:
                break
        for storage, value in zip(output_storage, state):
            storage[0] = value
x = scalar()
y = scalar()
z = scalar()
x_out = 2 * x + z
y_out = 2 * x * y
z_out = z
continue_ = le(x_out + y_out, 100)
inner = FunctionGraph([x, y, z], [continue_, x_out, y_out, z_out])

x = scalar()
y = scalar()
z = scalar()
x_out, y_out, z_out = DoWhile(inner)(x, y, z)
x = vector()
y = scalar()
idx = iscalar()
y_out = 2 * x[idx]
x_out = set_subtensor(x[idx], y_out)
idx_out = idx + 1
continue_ = lt(idx_out, x.shape[0])
inner = FunctionGraph([x, y, idx], [continue_, x_out, y_out, idx_out])

x = vector()
y = scalar()
x_out, y_out, idx_out = DoWhile(inner)(x, y, 0)
class While(Op):
    __props__ = ("inner",)

    def __init__(self, inner: FunctionGraph):
        assert [inp.type for inp in inner.inputs] == [out.type for out in inner.outputs[1:]]
        self.inner = inner

    def make_node(self, first_condition, *inputs):
        outputs = [inp.type() for inp in self.inner.inputs]
        return Apply(self, [first_condition, *inputs], outputs)

    def perform(self, node, inputs, output_storage):
        condition, *state = inputs
        while condition:
            condition, *state = self.inner(*state)
        for storage, value in zip(output_storage, state):
            storage[0] = value
```
### Reverse mode autodiff with loops
Problem: What properties must a loop have so that we can compute reverse mode autodiff for it?
Taking ideas from https://github.com/GiggleLiu/NiLang.jl
Let's say a loop is:
- pre-condition
- state update
- post-condition
ie
```python
state = init_state
if pre_condition(state):
    while True:
        state = update_state(state)
        if post_condition(state):
            break
```
I think it probably needs to be reversible. Something like:
A loop `loop1 = (pre, update, post)` is called reversible if we can find a second loop `loop2 = (pre', update', post')` such that
`loop1(loop2(loop1(init_state))) == loop1(init_state)`
for all states `init_state`.
I'm not 100% sure, but I think that might be the property we need for reverse mode autodiff.
So we could allow arbitrary loops, but additionally have loops that we know are reversible, either because the user supplied the reverse loop or because we can deduce it for some other reason (e.g. because the loop stores all intermediate states in the `state` variable). For those loops we could then generate reverse-mode derivatives.
Forward derivatives should be much easier anyway; they just depend on the derivatives of `update_state`.
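The reversibility property can be checked on a toy example with plain Python functions standing in for the loop pieces: `loop1` doubles the state until it reaches a bound, and its reverse `loop2` halves it back. This only illustrates the `loop1(loop2(loop1(init_state))) == loop1(init_state)` identity, nothing pytensor-specific:

```python
# Toy stand-in for the (pre, update, post) loop structure above.
def make_loop(pre, update, post):
    def run(state):
        if pre(state):
            while True:
                state = update(state)
                if post(state):
                    break
        return state
    return run

# loop1 doubles until the state reaches 100; loop2 halves back below 100.
loop1 = make_loop(pre=lambda s: s < 100, update=lambda s: s * 2, post=lambda s: s >= 100)
loop2 = make_loop(pre=lambda s: s >= 100, update=lambda s: s // 2, post=lambda s: s < 100)

# The reversibility property holds for these two loops:
for init_state in [1, 3, 7, 50, 150]:
    assert loop1(loop2(loop1(init_state))) == loop1(init_state)
```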
#### Example
```python
class Loop(Op):
    def __init__(
        self,
        pre_condition: FunctionGraph,
        update: FunctionGraph,
        post_condition: FunctionGraph,
        reverse: Optional["Loop"] = None,
    ):
        self._state_types = [inp.type for inp in pre_condition.inputs]
        # validate that
        # pre_condition: state_types -> [bool]
        # update: state_types -> state_types
        # post_condition: state_types -> [bool]
        # reverse: Optional[Loop[state_types=state_types]]
        self.pre_condition = pre_condition
        self.update = update
        self.post_condition = post_condition
        self.reverse = reverse

    def make_node(self, *state):
        # assert [item.type for item in state] == self._state_types
        return Apply(self, list(state), [tp() for tp in self._state_types])

    def perform(self, node, inputs, output_storage):
        state = inputs
        if self.pre_condition(*state):
            while True:
                state = self.update(*state)
                if self.post_condition(*state):
                    break
        for storage, value in zip(output_storage, state):
            storage[0] = value

    def grad(self, inputs, output_grads):
        if not self.reverse:
            raise NotImplementedError()
        ...
# We can use this to represent arbitrary loops, even loops that cannot easily
# be transformed to do reverse mode autodiff.
# But we could have a function that automatically produces reverse loops
# for special cases, like a loop that remembers all intermediate states:
def tracing_while_loop(maxlength, init_states, update, break_func, constant_args):
    traces = [pt.empty((maxlength,) + state.shape, state.dtype) for state in init_states]
    idx = pt.iscalar()
    constant_sym = [const.type() for const in constant_args]
    state = traces + [
        idx,
        *constant_sym,
    ]
    pre = (~break_func(state)) & (maxlength > 0)
    user_updates = update([tr[idx] for tr in traces], index=idx, constant_args=constant_args)
    all_updates = [
        pt.set_subtensor(tr[idx + 1], user_update)
        for tr, user_update in zip(traces, user_updates)
    ]
    all_updates += [idx + 1] + constant_args
    post = (~break_func(state)) & (idx < maxlength)
    loop_rev = ...  # Loop that yields the tr objects in reverse...
    loop = Loop(
        FunctionGraph(state, [pre]),
        FunctionGraph(state, all_updates),
        FunctionGraph(state, [post]),
        reverse=loop_rev,
    )
    # initial state: trace buffers with init_states written at index 0, idx = 0, constants
    init_traces = [pt.set_subtensor(tr[0], init) for tr, init in zip(traces, init_states)]
    return loop(*init_traces, 0, *constant_args)
# Maybe the loop should be this instead?
state = initial_state
assert pre_condition(state)
while not post_condition(state):
    state = update(state)
    assert not pre_condition(state)

# reverse
state = final_state
assert post_condition(state)
while not pre_condition(state):
    state = rev_update(state)
    assert not post_condition(state)

# or equivalently
state = initial_state
assert pre_condition(state)
if not post_condition(state):
    while True:
        state = update(state)
        assert not pre_condition(state)
        if post_condition(state):
            break
```
## Dims
## EGG
https://docs.rs/egg/latest/egg/tutorials/index.html