# Broadcasting in aesara
I'm doing my best here to summarize what's going on in #1089 from my
point of view, and to figure out what broadcasting semantics we should
use.
## How theano handled broadcasting
I think there has been some confusion about how exactly broadcasting used to
work in theano, not least of all because this wasn't all that well documented.
I'll explain my understanding of this here.
Theano broadcasting was designed so that in a given graph it is always
possible to know statically where broadcasting is happening and what is being
broadcasted. This was explicitly stated in the
[docs](https://github.com/Theano/Theano/blob/master/doc/tutorial/broadcasting.txt#L35).
I'll call this property of a graph (independent of *how* it is achieved)
"static broadcasting". If we cannot tell from the graph alone where
broadcasting is happening, I'll call that "dynamic broadcasting".
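To make the distinction concrete, plain NumPy is "dynamic" in this sense:
whether an expression broadcasts depends only on the runtime shapes, so nothing
in the expression itself tells us which case we are in (a trivial sketch):
```python
import numpy as np

def f(x, y):
    # The same expression either broadcasts or doesn't,
    # depending only on the runtime shapes of x and y.
    return x + y

f(np.zeros(10), np.zeros(10))  # no broadcasting
f(np.zeros(1), np.zeros(10))   # x is broadcast to length 10
```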
So how did theano achieve static broadcasting? It deviated from NumPy
semantics and introduced an attribute `broadcastable` as part of the type of a
TensorVariable. This specifies, for each axis of the Variable, whether
broadcasting is allowed to occur along that axis.
If the flag is `True` for an axis, the only valid length of that axis is 1,
and the axis may be broadcasted in the graph if necessary.
If the flag is `False`, the length of that axis can be arbitrary (including
1), but theano will *never broadcast that axis*.
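To illustrate (a small sketch, not from the original discussion): the flags are
part of the variable's type, and theano's `row`/`col` constructors are just
matrices with one axis declared broadcastable.
```python
import theano.tensor as tt

x = tt.dvector("x")   # broadcastable == (False,): any length, never broadcast
r = tt.row("r")       # broadcastable == (True, False): axis 0 must have length 1
c = tt.col("c")       # broadcastable == (False, True): axis 1 must have length 1

print(x.broadcastable)  # (False,)
print(r.broadcastable)  # (True, False)
print(c.broadcastable)  # (False, True)
```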
This means, however, that theano did not always behave the same as NumPy.
If we add two TensorVariables that are marked as non-broadcastable,
theano would ensure that they are never broadcasted during graph evaluation.
In cases where NumPy would in fact broadcast, theano would instead raise an
exception saying that their shapes are not compatible:
```python
import numpy as np
import theano.tensor as tt

x = tt.dvector("x")
y = tt.dvector("y")
assert not x.broadcastable[0]
assert not y.broadcastable[0]
z = x + y
z.eval({x: np.zeros(1), y: np.zeros(10)})
# ValueError
```
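For comparison, NumPy happily broadcasts those same shapes. In theano you had
to opt in by giving the variable a type whose first axis is declared
broadcastable; a sketch of constructing such a type explicitly (helpers like
`addbroadcast` served a similar purpose):
```python
import numpy as np
import theano.tensor as tt

# NumPy broadcasts the length-1 array without complaint:
np.zeros(1) + np.zeros(10)   # -> result of shape (10,)

# In theano, broadcasting has to be allowed by the type:
x_bc = tt.TensorType("float64", broadcastable=(True,))("x_bc")
y = tt.dvector("y")
z = x_bc + y
z.eval({x_bc: np.zeros(1), y: np.zeros(10)})  # works, result has shape (10,)
```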
There was an explicit check in the [elemwise
code](https://github.com/Theano/Theano/blob/master/theano/tensor/elemwise.py#L722)
to make sure that the elemwise op did not in fact broadcast those two arrays,
so that rewrites and gradients could rely on the fact that no unexpected
broadcasting was happening.
This was reported as a bug
[here](https://github.com/aesara-devs/aesara/issues/335), and the check was
then removed [here](https://github.com/aesara-devs/aesara/pull/928).
I think, however, that this was originally a conscious design decision by the
theano developers, not a bug in the first place. (The error message was bad as
well, but that's a different topic...)
I think at that time the reasons for that check and the implications of its
removal were not widely understood and not really discussed. (I'm not trying to
blame anyone here, this stuff happens.)
## Why anyone might want static broadcasting
The theano developers clearly understood that they were deviating from NumPy,
so why would they (or anyone) have wanted static broadcasting in the first place?
We stumbled into one reason for this soon after we removed the broadcasting
check from elemwise: Gradients depend on whether broadcasting happened or not.
Let's say we have
```python
import aesara.tensor as at

x = at.dvector()
y = at.dvector()
z = (x + y).sum()
dz = at.grad(z, x)
```
Then `dz` will be `ones_like(x)` if `x` was not broadcasted. But if it was,
each value of `x` is used in `z` several times, once for each value in `y`,
so the gradient should be `ones_like(x).sum(keepdims=True)`.
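A quick way to convince yourself of this is to check both cases numerically
with plain NumPy (this is just the arithmetic behind the gradient, not aesara
code):
```python
import numpy as np

y = np.arange(10.0)

# Case 1: no broadcasting. z = (x + y).sum() grows by 1 when any single x[i]
# grows by 1, so dz/dx is ones_like(x).
x = np.zeros(10)
grad_no_bc = np.ones_like(x)                  # shape (10,)

# Case 2: x has length 1 and is broadcast. The single x value enters z once
# per element of y, so dz/dx is ones_like(y).sum(keepdims=True).
x_bc = np.zeros(1)
grad_bc = np.ones_like(y).sum(keepdims=True)  # array([10.])
```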
The associated [issue](https://github.com/aesara-devs/aesara/issues/1089)
contains long discussions about different options for how to represent this
additional complexity in the graph. There clearly are ways to do this,
and I don't want to go through all the options here right now, but I think (?)
everyone agrees that they introduce significant additional complexity into the
graph and make rewrites more difficult.
Just to illustrate that additional complexity, here is the graph for computing
the gradient of the sum of 10 matrices with the current work-in-progress
[solution](https://github.com/aesara-devs/aesara/pull/1260):
<details>
<summary>dprint output</summary>

```python
import aesara
import aesara.tensor as at

inputs = [at.dmatrix() for _ in range(10)]
z = sum(inputs).sum()
dz = at.grad(z, inputs[0])
aesara.dprint(dz)
# output
if{} [id A]
|TensorFromScalar [id B]
| |Composite{AND(EQ(i0, 1), NEQ(i1, 1))} [id C]
| |ScalarFromTensor [id D]
| | |Subtensor{int64} [id E]
| | |Shape [id F]
| | | |<TensorType(float64, (None, None))> [id G]
| | |ScalarConstant{1} [id H]
| |ScalarFromTensor [id I]
| |Subtensor{int64} [id J]
| |Shape [id K]
| | |Elemwise{add,no_inplace} [id L]
| | |InplaceDimShuffle{x,x} [id M]
| | | |TensorConstant{0} [id N]
| | |<TensorType(float64, (None, None))> [id G]
| |ScalarConstant{1} [id O]
|InplaceDimShuffle{0,x} [id P]
| |Sum{axis=[1], acc_dtype=float64} [id Q]
| |if{} [id R]
| |TensorFromScalar [id S]
| | |Composite{AND(EQ(i0, 1), NEQ(i1, 1))} [id T]
| | |ScalarFromTensor [id U]
| | | |Subtensor{int64} [id V]
| | | |Shape [id F]
| | | |ScalarConstant{0} [id W]
| | |ScalarFromTensor [id X]
| | |Subtensor{int64} [id Y]
| | |Shape [id K]
| | |ScalarConstant{0} [id Z]
| |InplaceDimShuffle{x,0} [id BA]
| | |Sum{axis=[0], acc_dtype=float64} [id BB]
| | |if{} [id BC]
| | |TensorFromScalar [id BD]
| | | |Composite{AND(EQ(i0, 1), NEQ(i1, 1))} [id BE]
| | | |ScalarFromTensor [id BF]
| | | | |Subtensor{int64} [id BG]
| | | | |Shape [id BH]
| | | | | |Elemwise{add,no_inplace} [id L]
| | | | |ScalarConstant{1} [id BI]
| | | |ScalarFromTensor [id BJ]
| | | |Subtensor{int64} [id BK]
| | | |Shape [id BL]
| | | | |Elemwise{add,no_inplace} [id BM]
| | | | |Elemwise{add,no_inplace} [id L]
| | | | |<TensorType(float64, (None, None))> [id BN]
| | | |ScalarConstant{1} [id BO]
| | |InplaceDimShuffle{0,x} [id BP]
| | | |Sum{axis=[1], acc_dtype=float64} [id BQ]
| | | |if{} [id BR]
| | | |TensorFromScalar [id BS]
| | | | |Composite{AND(EQ(i0, 1), NEQ(i1, 1))} [id BT]
| | | | |ScalarFromTensor [id BU]
| | | | | |Subtensor{int64} [id BV]
| | | | | |Shape [id BH]
| | | | | |ScalarConstant{0} [id BW]
| | | | |ScalarFromTensor [id BX]
| | | | |Subtensor{int64} [id BY]
| | | | |Shape [id BL]
| | | | |ScalarConstant{0} [id BZ]
| | | |InplaceDimShuffle{x,0} [id CA]
| | | | |Sum{axis=[0], acc_dtype=float64} [id CB]
| | | | |if{} [id CC]
| | | | |TensorFromScalar [id CD]
| | | | | |Composite{AND(EQ(i0, 1), NEQ(i1, 1))} [id CE]
| | | | | |ScalarFromTensor [id CF]
| | | | | | |Subtensor{int64} [id CG]
| | | | | | |Shape [id CH]
| | | | | | | |Elemwise{add,no_inplace} [id BM]
| | | | | | |ScalarConstant{1} [id CI]
| | | | | |ScalarFromTensor [id CJ]
| | | | | |Subtensor{int64} [id CK]
| | | | | |Shape [id CL]
| | | | | | |Elemwise{add,no_inplace} [id CM]
| | | | | | |Elemwise{add,no_inplace} [id BM]
| | | | | | |<TensorType(float64, (None, None))> [id CN]
| | | | | |ScalarConstant{1} [id CO]
| | | | |InplaceDimShuffle{0,x} [id CP]
| | | | | |Sum{axis=[1], acc_dtype=float64} [id CQ]
| | | | | |if{} [id CR]
| | | | | |TensorFromScalar [id CS]
| | | | | | |Composite{AND(EQ(i0, 1), NEQ(i1, 1))} [id CT]
| | | | | | |ScalarFromTensor [id CU]
| | | | | | | |Subtensor{int64} [id CV]
| | | | | | | |Shape [id CH]
| | | | | | | |ScalarConstant{0} [id CW]
| | | | | | |ScalarFromTensor [id CX]
| | | | | | |Subtensor{int64} [id CY]
| | | | | | |Shape [id CL]
| | | | | | |ScalarConstant{0} [id CZ]
| | | | | |InplaceDimShuffle{x,0} [id DA]
| | | | | | |Sum{axis=[0], acc_dtype=float64} [id DB]
| | | | | | |if{} [id DC]
| | | | | | |TensorFromScalar [id DD]
| | | | | | | |Composite{AND(EQ(i0, 1), NEQ(i1, 1))} [id DE]
| | | | | | | |ScalarFromTensor [id DF]
| | | | | | | | |Subtensor{int64} [id DG]
| | | | | | | | |Shape [id DH]
| | | | | | | | | |Elemwise{add,no_inplace} [id CM]
| | | | | | | | |ScalarConstant{1} [id DI]
| | | | | | | |ScalarFromTensor [id DJ]
| | | | | | | |Subtensor{int64} [id DK]
| | | | | | | |Shape [id DL]
| | | | | | | | |Elemwise{add,no_inplace} [id DM]
| | | | | | | | |Elemwise{add,no_inplace} [id CM]
| | | | | | | | |<TensorType(float64, (None, None))> [id DN]
| | | | | | | |ScalarConstant{1} [id DO]
| | | | | | |InplaceDimShuffle{0,x} [id DP]
| | | | | | | |Sum{axis=[1], acc_dtype=float64} [id DQ]
| | | | | | | |if{} [id DR]
| | | | | | | |TensorFromScalar [id DS]
| | | | | | | | |Composite{AND(EQ(i0, 1), NEQ(i1, 1))} [id DT]
| | | | | | | | |ScalarFromTensor [id DU]
| | | | | | | | | |Subtensor{int64} [id DV]
| | | | | | | | | |Shape [id DH]
| | | | | | | | | |ScalarConstant{0} [id DW]
| | | | | | | | |ScalarFromTensor [id DX]
| | | | | | | | |Subtensor{int64} [id DY]
| | | | | | | | |Shape [id DL]
| | | | | | | | |ScalarConstant{0} [id DZ]
| | | | | | | |InplaceDimShuffle{x,0} [id EA]
| | | | | | | | |Sum{axis=[0], acc_dtype=float64} [id EB]
| | | | | | | | |if{} [id EC]
| | | | | | | | |TensorFromScalar [id ED]
| | | | | | | | | |Composite{AND(EQ(i0, 1), NEQ(i1, 1))} [id EE]
| | | | | | | | | |ScalarFromTensor [id EF]
| | | | | | | | | | |Subtensor{int64} [id EG]
| | | | | | | | | | |Shape [id EH]
| | | | | | | | | | | |Elemwise{add,no_inplace} [id DM]
| | | | | | | | | | |ScalarConstant{1} [id EI]
| | | | | | | | | |ScalarFromTensor [id EJ]
| | | | | | | | | |Subtensor{int64} [id EK]
| | | | | | | | | |Shape [id EL]
| | | | | | | | | | |Elemwise{add,no_inplace} [id EM]
| | | | | | | | | | |Elemwise{add,no_inplace} [id DM]
| | | | | | | | | | |<TensorType(float64, (None, None))> [id EN]
| | | | | | | | | |ScalarConstant{1} [id EO]
| | | | | | | | |InplaceDimShuffle{0,x} [id EP]
| | | | | | | | | |Sum{axis=[1], acc_dtype=float64} [id EQ]
| | | | | | | | | |if{} [id ER]
| | | | | | | | | |TensorFromScalar [id ES]
| | | | | | | | | | |Composite{AND(EQ(i0, 1), NEQ(i1, 1))} [id ET]
| | | | | | | | | | |ScalarFromTensor [id EU]
| | | | | | | | | | | |Subtensor{int64} [id EV]
| | | | | | | | | | | |Shape [id EH]
| | | | | | | | | | | |ScalarConstant{0} [id EW]
| | | | | | | | | | |ScalarFromTensor [id EX]
| | | | | | | | | | |Subtensor{int64} [id EY]
| | | | | | | | | | |Shape [id EL]
| | | | | | | | | | |ScalarConstant{0} [id EZ]
| | | | | | | | | |InplaceDimShuffle{x,0} [id FA]
| | | | | | | | | | |Sum{axis=[0], acc_dtype=float64} [id FB]
| | | | | | | | | | |if{} [id FC]
| | | | | | | | | | |TensorFromScalar [id FD]
| | | | | | | | | | | |Composite{AND(EQ(i0, 1), NEQ(i1, 1))} [id FE]
| | | | | | | | | | | |ScalarFromTensor [id FF]
| | | | | | | | | | | | |Subtensor{int64} [id FG]
| | | | | | | | | | | | |Shape [id FH]
| | | | | | | | | | | | | |Elemwise{add,no_inplace} [id EM]
| | | | | | | | | | | | |ScalarConstant{1} [id FI]
| | | | | | | | | | | |ScalarFromTensor [id FJ]
| | | | | | | | | | | |Subtensor{int64} [id FK]
| | | | | | | | | | | |Shape [id FL]
| | | | | | | | | | | | |Elemwise{add,no_inplace} [id FM]
| | | | | | | | | | | | |Elemwise{add,no_inplace} [id EM]
| | | | | | | | | | | | |<TensorType(float64, (None, None))> [id FN]
| | | | | | | | | | | |ScalarConstant{1} [id FO]
| | | | | | | | | | |InplaceDimShuffle{0,x} [id FP]
| | | | | | | | | | | |Sum{axis=[1], acc_dtype=float64} [id FQ]
| | | | | | | | | | | |if{} [id FR]
| | | | | | | | | | | |TensorFromScalar [id FS]
| | | | | | | | | | | | |Composite{AND(EQ(i0, 1), NEQ(i1, 1))} [id FT]
| | | | | | | | | | | | |ScalarFromTensor [id FU]
| | | | | | | | | | | | | |Subtensor{int64} [id FV]
| | | | | | | | | | | | | |Shape [id FH]
| | | | | | | | | | | | | |ScalarConstant{0} [id FW]
| | | | | | | | | | | | |ScalarFromTensor [id FX]
| | | | | | | | | | | | |Subtensor{int64} [id FY]
| | | | | | | | | | | | |Shape [id FL]
| | | | | | | | | | | | |ScalarConstant{0} [id FZ]
| | | | | | | | | | | |InplaceDimShuffle{x,0} [id GA]
| | | | | | | | | | | | |Sum{axis=[0], acc_dtype=float64} [id GB]
| | | | | | | | | | | | |if{} [id GC]
| | | | | | | | | | | | |TensorFromScalar [id GD]
| | | | | | | | | | | | | |Composite{AND(EQ(i0, 1), NEQ(i1, 1))} [id GE]
| | | | | | | | | | | | | |ScalarFromTensor [id GF]
| | | | | | | | | | | | | | |Subtensor{int64} [id GG]
| | | | | | | | | | | | | | |Shape [id GH]
| | | | | | | | | | | | | | | |Elemwise{add,no_inplace} [id FM]
| | | | | | | | | | | | | | |ScalarConstant{1} [id GI]
| | | | | | | | | | | | | |ScalarFromTensor [id GJ]
| | | | | | | | | | | | | |Subtensor{int64} [id GK]
| | | | | | | | | | | | | |Shape [id GL]
| | | | | | | | | | | | | | |Elemwise{add,no_inplace} [id GM]
| | | | | | | | | | | | | | |Elemwise{add,no_inplace} [id FM]
| | | | | | | | | | | | | | |<TensorType(float64, (None, None))> [id GN]
| | | | | | | | | | | | | |ScalarConstant{1} [id GO]
| | | | | | | | | | | | |InplaceDimShuffle{0,x} [id GP]
| | | | | | | | | | | | | |Sum{axis=[1], acc_dtype=float64} [id GQ]
| | | | | | | | | | | | | |if{} [id GR]
| | | | | | | | | | | | | |TensorFromScalar [id GS]
| | | | | | | | | | | | | | |Composite{AND(EQ(i0, 1), NEQ(i1, 1))} [id GT]
| | | | | | | | | | | | | | |ScalarFromTensor [id GU]
| | | | | | | | | | | | | | | |Subtensor{int64} [id GV]
| | | | | | | | | | | | | | | |Shape [id GH]
| | | | | | | | | | | | | | | |ScalarConstant{0} [id GW]
| | | | | | | | | | | | | | |ScalarFromTensor [id GX]
| | | | | | | | | | | | | | |Subtensor{int64} [id GY]
| | | | | | | | | | | | | | |Shape [id GL]
| | | | | | | | | | | | | | |ScalarConstant{0} [id GZ]
| | | | | | | | | | | | | |InplaceDimShuffle{x,0} [id HA]
| | | | | | | | | | | | | | |Sum{axis=[0], acc_dtype=float64} [id HB]
| | | | | | | | | | | | | | |if{} [id HC]
| | | | | | | | | | | | | | |TensorFromScalar [id HD]
| | | | | | | | | | | | | | | |Composite{AND(EQ(i0, 1), NEQ(i1, 1))} [id HE]
| | | | | | | | | | | | | | | |ScalarFromTensor [id HF]
| | | | | | | | | | | | | | | | |Subtensor{int64} [id HG]
| | | | | | | | | | | | | | | | |Shape [id HH]
| | | | | | | | | | | | | | | | | |Elemwise{add,no_inplace} [id GM]
| | | | | | | | | | | | | | | | |ScalarConstant{1} [id HI]
| | | | | | | | | | | | | | | |ScalarFromTensor [id HJ]
| | | | | | | | | | | | | | | |Subtensor{int64} [id HK]
| | | | | | | | | | | | | | | |Shape [id HL]
| | | | | | | | | | | | | | | | |Elemwise{add,no_inplace} [id HM]
| | | | | | | | | | | | | | | | |Elemwise{add,no_inplace} [id GM]
| | | | | | | | | | | | | | | | |<TensorType(float64, (None, None))> [id HN]
| | | | | | | | | | | | | | | |ScalarConstant{1} [id HO]
| | | | | | | | | | | | | | |InplaceDimShuffle{0,x} [id HP]
| | | | | | | | | | | | | | | |Sum{axis=[1], acc_dtype=float64} [id HQ]
| | | | | | | | | | | | | | | |if{} [id HR]
| | | | | | | | | | | | | | | |TensorFromScalar [id HS]
| | | | | | | | | | | | | | | | |Composite{AND(EQ(i0, 1), NEQ(i1, 1))} [id HT]
| | | | | | | | | | | | | | | | |ScalarFromTensor [id HU]
| | | | | | | | | | | | | | | | | |Subtensor{int64} [id HV]
| | | | | | | | | | | | | | | | | |Shape [id HH]
| | | | | | | | | | | | | | | | | |ScalarConstant{0} [id HW]
| | | | | | | | | | | | | | | | |ScalarFromTensor [id HX]
| | | | | | | | | | | | | | | | |Subtensor{int64} [id HY]
| | | | | | | | | | | | | | | | |Shape [id HL]
| | | | | | | | | | | | | | | | |ScalarConstant{0} [id HZ]
| | | | | | | | | | | | | | | |InplaceDimShuffle{x,0} [id IA]
| | | | | | | | | | | | | | | | |Sum{axis=[0], acc_dtype=float64} [id IB]
| | | | | | | | | | | | | | | | |if{} [id IC]
| | | | | | | | | | | | | | | | |TensorFromScalar [id ID]
| | | | | | | | | | | | | | | | | |Composite{AND(EQ(i0, 1), NEQ(i1, 1))} [id IE]
| | | | | | | | | | | | | | | | | |ScalarFromTensor [id IF]
| | | | | | | | | | | | | | | | | | |Subtensor{int64} [id IG]
| | | | | | | | | | | | | | | | | | |Shape [id IH]
| | | | | | | | | | | | | | | | | | | |Elemwise{add,no_inplace} [id HM]
| | | | | | | | | | | | | | | | | | |ScalarConstant{1} [id II]
| | | | | | | | | | | | | | | | | |ScalarFromTensor [id IJ]
| | | | | | | | | | | | | | | | | |Subtensor{int64} [id IK]
| | | | | | | | | | | | | | | | | |Shape [id IL]
| | | | | | | | | | | | | | | | | | |Elemwise{add,no_inplace} [id IM]
| | | | | | | | | | | | | | | | | | |Elemwise{add,no_inplace} [id HM]
| | | | | | | | | | | | | | | | | | |<TensorType(float64, (None, None))> [id IN]
| | | | | | | | | | | | | | | | | |ScalarConstant{1} [id IO]
| | | | | | | | | | | | | | | | |InplaceDimShuffle{0,x} [id IP]
| | | | | | | | | | | | | | | | | |Sum{axis=[1], acc_dtype=float64} [id IQ]
| | | | | | | | | | | | | | | | | |if{} [id IR]
| | | | | | | | | | | | | | | | | |TensorFromScalar [id IS]
| | | | | | | | | | | | | | | | | | |Composite{AND(EQ(i0, 1), NEQ(i1, 1))} [id IT]
| | | | | | | | | | | | | | | | | | |ScalarFromTensor [id IU]
| | | | | | | | | | | | | | | | | | | |Subtensor{int64} [id IV]
| | | | | | | | | | | | | | | | | | | |Shape [id IH]
| | | | | | | | | | | | | | | | | | | |ScalarConstant{0} [id IW]
| | | | | | | | | | | | | | | | | | |ScalarFromTensor [id IX]
| | | | | | | | | | | | | | | | | | |Subtensor{int64} [id IY]
| | | | | | | | | | | | | | | | | | |Shape [id IL]
| | | | | | | | | | | | | | | | | | |ScalarConstant{0} [id IZ]
| | | | | | | | | | | | | | | | | |InplaceDimShuffle{x,0} [id JA]
| | | | | | | | | | | | | | | | | | |Sum{axis=[0], acc_dtype=float64} [id JB]
| | | | | | | | | | | | | | | | | | |if{} [id JC]
| | | | | | | | | | | | | | | | | | |TensorFromScalar [id JD]
| | | | | | | | | | | | | | | | | | | |Composite{AND(EQ(i0, 1), NEQ(i1, 1))} [id JE]
| | | | | | | | | | | | | | | | | | | |ScalarFromTensor [id JF]
| | | | | | | | | | | | | | | | | | | | |Subtensor{int64} [id JG]
| | | | | | | | | | | | | | | | | | | | |Shape [id JH]
| | | | | | | | | | | | | | | | | | | | | |Elemwise{add,no_inplace} [id IM]
| | | | | | | | | | | | | | | | | | | | |ScalarConstant{1} [id JI]
| | | | | | | | | | | | | | | | | | | |ScalarFromTensor [id JJ]
| | | | | | | | | | | | | | | | | | | |Subtensor{int64} [id JK]
| | | | | | | | | | | | | | | | | | | |Shape [id JL]
| | | | | | | | | | | | | | | | | | | | |Elemwise{add,no_inplace} [id JM]
| | | | | | | | | | | | | | | | | | | | |Elemwise{add,no_inplace} [id IM]
| | | | | | | | | | | | | | | | | | | | |<TensorType(float64, (None, None))> [id JN]
| | | | | | | | | | | | | | | | | | | |ScalarConstant{1} [id JO]
| | | | | | | | | | | | | | | | | | |InplaceDimShuffle{0,x} [id JP]
| | | | | | | | | | | | | | | | | | | |Sum{axis=[1], acc_dtype=float64} [id JQ]
| | | | | | | | | | | | | | | | | | | |if{} [id JR]
| | | | | | | | | | | | | | | | | | | |TensorFromScalar [id JS]
| | | | | | | | | | | | | | | | | | | | |Composite{AND(EQ(i0, 1), NEQ(i1, 1))} [id JT]
| | | | | | | | | | | | | | | | | | | | |ScalarFromTensor [id JU]
| | | | | | | | | | | | | | | | | | | | | |Subtensor{int64} [id JV]
| | | | | | | | | | | | | | | | | | | | | |Shape [id JH]
| | | | | | | | | | | | | | | | | | | | | |ScalarConstant{0} [id JW]
| | | | | | | | | | | | | | | | | | | | |ScalarFromTensor [id JX]
| | | | | | | | | | | | | | | | | | | | |Subtensor{int64} [id JY]
| | | | | | | | | | | | | | | | | | | | |Shape [id JL]
| | | | | | | | | | | | | | | | | | | | |ScalarConstant{0} [id JZ]
| | | | | | | | | | | | | | | | | | | |InplaceDimShuffle{x,0} [id KA]
| | | | | | | | | | | | | | | | | | | | |Sum{axis=[0], acc_dtype=float64} [id KB]
| | | | | | | | | | | | | | | | | | | | |Elemwise{second} [id KC]
| | | | | | | | | | | | | | | | | | | | |Elemwise{add,no_inplace} [id JM]
| | | | | | | | | | | | | | | | | | | | |InplaceDimShuffle{x,x} [id KD]
| | | | | | | | | | | | | | | | | | | | |Elemwise{second,no_inplace} [id KE]
| | | | | | | | | | | | | | | | | | | | |Sum{acc_dtype=float64} [id KF]
| | | | | | | | | | | | | | | | | | | | | |Elemwise{add,no_inplace} [id JM]
| | | | | | | | | | | | | | | | | | | | |TensorConstant{1.0} [id KG]
| | | | | | | | | | | | | | | | | | | |Elemwise{second} [id KC]
| | | | | | | | | | | | | | | | | | |if{} [id JR]
| | | | | | | | | | | | | | | | | |if{} [id JC]
| | | | | | | | | | | | | | | | |if{} [id IR]
| | | | | | | | | | | | | | | |if{} [id IC]
| | | | | | | | | | | | | | |if{} [id HR]
| | | | | | | | | | | | | |if{} [id HC]
| | | | | | | | | | | | |if{} [id GR]
| | | | | | | | | | | |if{} [id GC]
| | | | | | | | | | |if{} [id FR]
| | | | | | | | | |if{} [id FC]
| | | | | | | | |if{} [id ER]
| | | | | | | |if{} [id EC]
| | | | | | |if{} [id DR]
| | | | | |if{} [id DC]
| | | | |if{} [id CR]
| | | |if{} [id CC]
| | |if{} [id BR]
| |if{} [id BC]
|if{} [id R]
```
</details>
I think there is, however, also a deeper performance problem introduced
by not knowing the broadcasting pattern statically. It affects both the
derivative code and the forward code; I'll focus on the forward code here,
because I think it is easier to explain.
Let's use as an example a simple elemwise operation `exp(x * y)`, where `x`
and `y` are vectors. With static broadcasting we know exactly whether
there is any broadcasting when we compile the function, so we can generate
code under that assumption. In a dynamic broadcasting environment we need
to produce machine code that can deal with both the broadcasting and the
non-broadcasting case. Let's see how my best attempts at generic and
non-generic code compare (I'm using numba here, but this is really about the
code that LLVM is able to produce; if we did this in C, I think the results
would look very similar):
```python
import numba
import numpy as np

################################################
# Specialized functions for known broadcasting #
################################################

@numba.njit
def foo_known_no_broadcast(x, y, out):
    (n,) = x.shape
    assert x.shape == y.shape
    assert x.shape == out.shape
    for i in range(n):
        out[i] = np.exp(x[i] * y[i])


@numba.njit
def foo_known_broadcast_x(x, y, out):
    (n,) = y.shape
    assert x.shape == (1,)
    assert y.shape == out.shape
    x = x[0]
    for i in range(n):
        out[i] = np.exp(x * y[i])


@numba.njit
def foo_known_broadcast_y(x, y, out):
    (n,) = x.shape
    assert y.shape == (1,)
    assert x.shape == out.shape
    y = y[0]
    for i in range(n):
        out[i] = np.exp(x[i] * y)


#############################
# Different generic options #
#############################

@numba.njit
def foo_generic(x, y, out):
    x, y = np.broadcast_arrays(x, y)
    (n,) = x.shape
    assert x.shape == y.shape
    assert x.shape == out.shape
    for i in range(n):
        out[i] = np.exp(x[i] * y[i])


@numba.njit
def foo_generic2(x, y, out):
    # Emulate broadcasting by hand: a length-1 input gets stride 0.
    (n,) = x.shape
    (m,) = y.shape
    if n == m:
        stride_x = 1
        stride_y = 1
    elif n == 1:
        stride_x = 0
        stride_y = 1
    elif m == 1:
        stride_x = 1
        stride_y = 0
    else:
        assert False
    n = max(n, m)
    assert out.shape == (n,)
    for i in range(n):
        out[i] = np.exp(x[i * stride_x] * y[i * stride_y])


@numba.njit
def foo_dispatch(x, y, out):
    # Check the shapes at runtime and call the matching specialized loop.
    (n,) = x.shape
    (m,) = y.shape
    if m == 1:
        foo_known_broadcast_y(x, y, out)
    elif n == 1:
        foo_known_broadcast_x(x, y, out)
    else:
        foo_known_no_broadcast(x, y, out)
```
So here we have specialized implementations for `elemwise(exp(x * y))` where we
assume we already know whether we are broadcasting, and two different
generic versions that compute the same thing but work regardless of the
broadcasting that's actually happening. We also have a third generic
implementation (`foo_dispatch`) that just dispatches to the specialized
versions depending on the broadcasting scenario.
So how does their performance compare?
```python
N = 100_000
x = np.random.randn(N)
y_no_bc = np.random.randn(N)
y_bc = np.random.randn(1)
out = np.random.randn(N)
# We don't broadcast anything
%timeit foo_known_no_broadcast(x, y_no_bc, out)
144 µs ± 650 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
%timeit foo_generic(x, y_no_bc, out)
533 µs ± 384 ns per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
%timeit foo_generic2(x, y_no_bc, out)
146 µs ± 221 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
%timeit foo_dispatch(x, y_no_bc, out)
143 µs ± 563 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
# We broadcast y
%timeit foo_known_broadcast_y(x, y_bc, out)
140 µs ± 439 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
%timeit foo_generic(x, y_bc, out)
530 µs ± 2.24 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
%timeit foo_generic2(x, y_bc, out)
556 µs ± 481 ns per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
%timeit foo_dispatch(x, y_bc, out)
140 µs ± 297 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
```
The two truly generic versions `foo_generic` and `foo_generic2` are each
slower by a factor of ~3.5 in at least one of the scenarios.
Only `foo_dispatch` matches the performance of the specialized implementations
in both cases, so based on this it looks pretty nice. We do run into issues,
however, if the number of inputs grows: for two inputs there are 3 different
options (`x` is broadcasted, `y` is broadcasted, nothing is broadcasted), but
if we have `n` vector inputs for a more general elemwise, we have $2^n - 1$
different cases. I think $n=10$ would not be all that unusual for elemwise ops
in practice, and we just can't compile 1023 specialized loops to cover all of
those cases.
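Just to spell out that count (an illustrative snippet, nothing aesara-specific):
each input is either broadcast (length 1) or not, and the all-broadcast pattern
has equal shapes everywhere, so it is presumably already covered by the
no-broadcast loop, which leaves $2^n - 1$ distinct cases.
```python
from itertools import product

def n_specialized_loops(n_inputs):
    # All True/False broadcast patterns except "every input is broadcast",
    # which needs no loop of its own.
    patterns = [p for p in product((False, True), repeat=n_inputs) if not all(p)]
    return len(patterns)

n_specialized_loops(2)   # -> 3
n_specialized_loops(10)  # -> 1023
```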
So to summarize: If we know where broadcasting happens statically, we can produce
much simpler gradient graphs, and we can generate faster code.
## My personal (tentative) conclusion
I think the downsides to continuing to use something similar to theano
broadcasting are mainly in three areas:
- Usability can be impacted by the broadcasting differences between
theano/aesara and NumPy. I think better error messages and better
documentation for those differences should help a lot with this though.
- We need additional checks in the backends that make sure no unintended
broadcasting is happening.
- We have to continue to deal with broadcastable flags of some kind in the code
base.
I tend to think, however, that the advantages outweigh the problems:
- We keep the graphs nice and simple, which also helps a lot in all future debugging and rewrites.
- We have more options to improve performance by taking advantage of known broadcasting patterns.