# Signatures of Ops (minikanren)

After discussing ideas about Op signatures as kanren expressions with Max for a while, I thought I'd try to write down a couple of thoughts about typing in general in pytensor. So this isn't exactly a proper writeup, just a hopefully understandable mess. :-)

Let's start with a compiled function. We have strong typing for those, so we can always tell what the type of such a function is:

```python
func = pytensor.function(...)
input_types = func.maker.fgraph.inputs
output_types = func.maker.fgraph.outputs
```

So `func` is a function from instances of `input_types` to instances of `output_types`.

For our purposes here I think there isn't really a difference between a compiled function and an `Apply` instance (i.e. a node). (It also stores input *variables*, which a compiled function doesn't, but let's ignore that right now.) Each of those simply represents a function from instances of `node.itypes` to instances of `node.otypes`.

So what is an `Op`? An `Op` (the Op instance, not the class) does a couple of things, but for now we care about the `make_node` method. It is a function that takes types as input (really variables, but again, let's ignore that) and returns an `Apply` instance. So we can think of the `Op` as a function that maps types to functions.

Strongly typed languages often have constructs like these. For instance, in Rust we can define generic functions (ignoring details...):

```rust
fn add<T>(a: T, b: T) -> T {
    a + b
}
```

It makes sense to think of `add` not as a function that returns a number, but as a function that takes a type as input and returns a function. So you could say that `add(f64)` is a function `(f64, f64) -> f64`. (I guess Haskell people might have more to say about this...)

So we can also ask what the type of `add` is: namely `T -> ((T, T) -> T)`. Often this is then combined with type inference. I don't actually have to say `add(f64)(1., 2.)` (or actually `add::<f64>(1., 2.)` in Rust syntax); I can just write `add(1., 2.)`, and the compiler will figure out what `T` I must have meant.

It seems to me that this is exactly what an `Op` in pytensor is doing, only that we don't have any representation for what the type of an Op is. For this part we depend purely on duck typing: we have a function (`make_node`) that we pass input types, and it returns the `Apply` instance, or an error.

Now, I'm not 100% sure yet if we need to know the type of an Op, or if duck typing is enough. But if we wanted, could we store it so that other parts of the library can understand that type? What could they do with this information? Does this simplify type inference, or allow us to prevent errors? Or simplify the `make_node` mess with its static shape info and, in the future, possibly static dim info?

So how could we store this, if that is what we want to do, and how would we then implement type inference?
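To make the duck-typing point concrete, here is a minimal sketch using only public pytensor APIs: building a graph calls `make_node` behind the scenes, and the resulting `Apply` node carries the inferred input and output types.

```python
import pytensor.tensor as pt

x = pt.vector("x", dtype="float64")
y = pt.vector("y", dtype="float32")

# Building the graph implicitly calls the Elemwise Op's make_node,
# which infers the output type (float64 here, after dtype promotion).
z = x + y

node = z.owner  # the Apply instance
print([inp.type for inp in node.inputs])
print([out.type for out in node.outputs])
```

Nothing in the library states up front that addition maps a `float64` vector and a `float32` vector to a `float64` vector; we only find out by calling `make_node` and inspecting the result.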
Let's say we want to write down the type of a 1d concatenation op. (And for fun I'm going to assume we have dimensions as first-class objects already...) We could for instance write it down like this (I think there are other choices as well):

```python
# Type vars (corresponding to the T in the add example above),
# i.e. the type arguments to the generic function
a = TypeVar()
b = TypeVar()
a_dtype = TypeVar()
b_dtype = TypeVar()

# Where clause, i.e. conditions that the input types must satisfy
where_clause = [
    isinstance(a, Dim),
    isinstance(b, Dim),
    isinstance(a_dtype, DType),
    isinstance(b_dtype, DType),
    compatible(a_dtype, b_dtype),
]

# Input types of the returned function
input_types = [
    TensorType(dims=(a,), dtype=a_dtype),
    TensorType(dims=(b,), dtype=b_dtype),
]

# Output types of the returned function
output_types = [
    TensorType(dims=(ConcatDim([a, b]),), dtype=PromotedDType([a_dtype, b_dtype])),
]

signature = Signature(
    type_vars=[a, b, a_dtype, b_dtype],
    where=where_clause,
    inputs=input_types,
    outputs=output_types,
)
```

or to make it a bit shorter:

```
(dim a, dim b, dtype x, dtype y)
    -> (Tensor((a,), x), Tensor((b,), y))
    -> Tensor((ConcatDim([a, b]),), PromotedDType([x, y]))
```

Note that we need functions like `ConcatDim` and `PromotedDType`... For other Ops we might also need to deal with more complicated type variables, which might represent lists of dimensions, for instance...

Max also has a couple of string representations:

Interactions
------------

* `(d),(d)->()` - dot product
* `(m,n),(n,p)->(m,p)` - matrix multiplication. Note that `(m,n?),(n,p?)->(m?,p?)` is not supported.

Reference to intermediate axes
------------------------------

The `.k.` token skips `k` dims; `...` skips any number of dims and can only be used once in a formula.

* `(M,.k.),(J,.k.)->(J,.k.)` - take_along_axis with dim=-k
* `(M,.1.),(J,.1.)->(J,.1.)` - take_along_axis with dim=-2, so 1 dim is skipped at the end

The numpy implementation of take_along_axis requires the number of dims to be the same, but allows broadcasting.

* `(d)->()` - reduction over the -1 axis
* `(d,...)->(...)` - reduction over the 0 axis
* `(.1.,d,...)->(.1.,...)` - reduction over the 1 axis
* `(.2.,d,...)->(.2.,...)` - reduction over the 2 axis
* `(.2.,d,...,k,.1.)->(.2.,...,.1.)` - reduction over the 2 and -2 axes

Static Shapes
-------------

Sometimes you know the size of an input or output dimension in advance.

* `(2)->()` - the last core dim is strictly 2
* `(2,.2.)->()` - the `-3` core dim is strictly 2

Broadcasting
------------

By default a signature assumes no broadcasting. To make a signature broadcast, prepend it with `+` (shapes broadcast) or `=` (shapes must be strictly equal).

* `+(d)->()` - reduction over the -1 axis; this now represents the Sum(-1) operator signature
* `=(d),(d)->()` - dot product that broadcasts to arbitrary dimensions
* `+(),()->()` - elemwise that broadcasts to arbitrary dimensions
* `=(),()->()` - elemwise that works on arbitrary dimensions but requires all of them to match
* `=(),(),()->(3,)` - stack operation

Sometimes Ops may not support broadcasting to more than, e.g., 1 or 2 dimensions. In that case the signature is specified like this:

* `+2(d),(d)->()` for regular broadcasting
* `=2(d),(d)->()` for strict broadcasting

We could then think about using minikanren to do automatic type inference. The TypeVar instances would be represented as `var()`, and we could then ask for a solution given the input types (a rough sketch follows at the end of this note).

Random note: this whole thing reminds me a lot of this: https://rust-lang.github.io/chalk/book/what_is_chalk.html
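The promised sketch: a rough illustration with the `kanren` package (pip package `miniKanren`) of how unification could solve for the type variables. The tuple encoding of tensor types here is purely hypothetical, just enough to show the mechanics; real constraints like `compatible(...)` or computed dims like `ConcatDim` would need custom goals on top of plain unification.

```python
from kanren import eq, run, var

# Hypothetical encoding: a tensor type is a tuple (dims, dtype).
# The logic variables play the role of the TypeVar instances above.
a, b = var(), var()
a_dtype, b_dtype = var(), var()

# The input part of the concat signature, as a pattern to unify against.
signature_inputs = (((a,), a_dtype), ((b,), b_dtype))

# Concrete input types at a particular call site.
concrete_inputs = (((3,), "float64"), ((5,), "float64"))

# Ask for an assignment of the type variables that makes the
# pattern equal to the concrete types.
solutions = run(1, (a, b, a_dtype, b_dtype), eq(signature_inputs, concrete_inputs))
print(solutions)  # ((3, 5, 'float64', 'float64'),)
```

With the solved values in hand, the output type would then be computed by evaluating `ConcatDim([3, 5])` and `PromotedDType(["float64", "float64"])` from the signature, and the where clause would be checked against the solution.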