# Exogenation Design Document

## The Problem

Consider the following probabilistic programs:

```python
def model():
    x = pyro.sample("x", Normal(loc=1., scale=2.))
    y = pyro.sample("y", Normal(loc=x, scale=1.))
```

```python
def exogenated_model():
    x_u = pyro.sample("x_u", Normal(0., 1.))
    _x = AffineTransform(loc=1., scale=2.)(x_u)
    x = pyro.sample("x", Delta(_x))

    y_u = pyro.sample("y_u", Normal(0., 1.))
    _y = AffineTransform(loc=x, scale=1.)(y_u)
    y = pyro.sample("y", Delta(_y))
```

These two programs represent the same structural causal model and should be distributionally equivalent for all potential outcomes. Under `chirho`'s current implementation, however, the distributions over their potential outcomes are not equivalent:

```python
with MultiWorldCounterfactual(first_available_dim=-1):
    with do(actions={"x": torch.tensor(0.0)}):
        model()
```

```python
with MultiWorldCounterfactual(first_available_dim=-1):
    with do(actions={"x": torch.tensor(0.0)}):
        exogenated_model()
```

Only the intervened `exogenated_model` behaves like a standard structural causal model. There, `x_u` and `y_u` remain scalars and are shared across counterfactual worlds. In particular, `x` does not appear as a parameter of the exogenous `y_u` distribution; it only broadcasts against `y_u` in the downstream `AffineTransform`. Thus `x_u.shape == y_u.shape == ()` while `x.shape == y.shape == (2,)`, and `indices_of(x_u) == indices_of(y_u) == IndexSet()` while `indices_of(x) == indices_of(y) == IndexSet(x={0, 1})`.

The intervened `model`, however, independently samples exogenous noise for `y` across the full broadcast shape of its parameters. That is, it treats batch dimensions arising from intervention the same as any other batch dimension.

This mismatch means users must explicitly write `exogenated_model`-style programs when they wish to remain faithful to the standard SCM formulation. Unfortunately, this obscures useful structure.
For example, it becomes far more difficult to identify and exploit conjugacy relationships, to symbolically generate excised counterfactual distributions, or to straightforwardly compute likelihoods after conditioning on `Delta`-distributed sites. It also prevents us from exploiting easy analytic inversions from factual observations to exogenous noise, which should then be propagated to counterfactuals.

## Desired Interface

The user writes only `model`, and we automatically transform it to behave like `exogenated_model` under `MultiWorldCounterfactual` and `do`. We want the user to write a model that:

1. preserves symbolic structure, allowing for analytic computation of likelihoods and distributional transformations (such as excision);
2. can be forward simulated in a way that yields the expected joint distribution over potential outcomes (i.e., where noise is shared across worlds);
3. admits an *analytic* means of sampling from noise distributions conditional on factual observations, with those samples propagating to counterfactual worlds.

### Caveats for Desired Interface

Unfortunately, much as a directed acyclic Bayesian network is consistent with many SCMs, `model` too can be mapped to many SCMs. Consider this alternative:

```python
def exogenated_model2():
    x_u = pyro.sample("x_u", Uniform(0., 1.))
    _x = Normal(0., 1.).icdf(x_u)
    _x = AffineTransform(loc=1., scale=2.)(_x)
    x = pyro.sample("x", Delta(_x))

    y_u = pyro.sample("y_u", Uniform(0., 1.))
    _y = Normal(0., 1.).icdf(y_u)
    _y = AffineTransform(loc=x, scale=1.)(_y)
    y = pyro.sample("y", Delta(_y))
```

Both `exogenated_model` and `exogenated_model2` are consistent with `model`, differing only in how transformations are partitioned between exogenous noise and structural equations.
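The caveat can be checked numerically. The sketch below avoids external dependencies: plain floats and `statistics.NormalDist` stand in for the torch distributions, and the function names are illustrative only. It recovers the `y` noise implied by a single observation under each exogenation. The two noise values differ, yet pushing each through its own structural equation under `do(x=0.0)` gives the same counterfactual `y`. For this particular pair they agree because the Uniform noise is exactly $\Phi$ (the standard normal CDF) applied to the Normal noise.

```python
from statistics import NormalDist

STD_NORMAL = NormalDist(0.0, 1.0)


def y_from_normal_noise(x, y_u):
    # exogenated_model: y_u ~ Normal(0, 1), y = AffineTransform(loc=x, scale=1)(y_u)
    return x + 1.0 * y_u


def y_from_uniform_noise(x, y_u):
    # exogenated_model2: y_u ~ Uniform(0, 1), y = AffineTransform(loc=x, scale=1)(icdf(y_u))
    return x + 1.0 * STD_NORMAL.inv_cdf(y_u)


def normal_noise_from_y(x, y):
    # Invert exogenated_model's structural equation for y
    return (y - x) / 1.0


def uniform_noise_from_y(x, y):
    # Invert exogenated_model2's structural equation for y
    return STD_NORMAL.cdf((y - x) / 1.0)


x_obs, y_obs = 1.5, 2.0
u_normal = normal_noise_from_y(x_obs, y_obs)    # 0.5
u_uniform = uniform_noise_from_y(x_obs, y_obs)  # Phi(0.5), about 0.691

# Same observation, different exogenous noise values...
print(u_normal, u_uniform)
# ...but the same counterfactual y under do(x=0.0), about 0.5 in both cases
print(y_from_normal_noise(0.0, u_normal), y_from_uniform_noise(0.0, u_uniform))
```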
### Brief Formalization

For a random variable $Y$ with parents $Pa(Y)$, we can write:

$$
P_Y = [f_n(Pa(Y), \cdot) \circ \dots \circ f_1(Pa(Y), \cdot) \circ g_m \circ \dots \circ g_1]_{\#} P_\Omega
$$

where each $f_i$ may depend on the endogenous variables $Pa(Y)$ and each $g_j$ is an exogenous transformation (independent of endogenous variables). The exogenous noise distribution is:

$$
P_{U_Y} = [g_m \circ \dots \circ g_1]_{\#} P_\Omega
$$

We can freely move transformations across the boundary between the $g$ sequence (exogenous) and the $f$ sequence (structural equation). $g_m$ can always be absorbed into the structural equation. $f_1$ can be moved into the exogenous sequence, provided it does not depend on any endogenous variables. Different partitions yield the same $P_Y$ but different $P_{U_Y}$.

The practical consequences of different $P_{U_Y}$ are not entirely clear to me yet. I conjecture that they only matter when using soft conditioning and trying to ensure a well-conditioned soft posterior. They may also matter if some sampling strategies for equivalently distributed $P_{U_Y}$ are more efficient than others (e.g., the Box-Muller transform for sampling normals vs. pushing a uniform sample through the normal inverse CDF).

As we'll see below, we can sidestep these complexities by requiring the user to define a base distribution and a transform chain, as in the `exogenated_model`s. This must be done in a way that retains the information required for likelihood computation, symbolic manipulation, and inversion to factual exogenous noise values.

## Implementation, Take 1

Spoiler alert: this doesn't work, but it is still instructive.

We'll start by requiring that users write distributions as explicit `TransformedDistribution`s. For example, these are equivalent representations:

```python
y_dist = TransformedDistribution(
    Normal(0., 1.),
    [AffineTransform(loc=x, scale=1.)]
)

y_dist = TransformedDistribution(
    Uniform(0., 1.),
    # NormalICDFTransform: a Uniform(0, 1) -> Normal(0, 1) icdf transform (not a torch.distributions class)
    [NormalICDFTransform(), AffineTransform(loc=x, scale=1.)]
)
```

Both represent $Y \sim \mathcal{N}(x, 1)$ with different choices of base exogenous noise.
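Both chains yield the same closed-form likelihood through the change-of-variables formula:

$$
\log p_Y(y) = \log p_U(u) - \sum_k \log\lvert \partial t_k / \partial z_{k-1} \rvert
$$

Here $u$ is obtained by inverting the chain, and $z_{k-1}$ is the input to the $k$-th transform $t_k$, so $z_0 = u$.

The sketch below avoids external dependencies. Each transform is modeled as a hypothetical `(forward, inverse, log_abs_det_jacobian)` triple of plain functions rather than a torch `Transform`. `normal_icdf_transform` stands in for the `NormalICDFTransform` above.

```python
import math
from statistics import NormalDist

STD_NORMAL = NormalDist(0.0, 1.0)


def affine_transform(loc, scale):
    # (forward, inverse, log|d forward/du|) for u -> loc + scale * u
    return (
        lambda u: loc + scale * u,
        lambda y: (y - loc) / scale,
        lambda u: math.log(abs(scale)),
    )


def normal_icdf_transform():
    # Uniform(0, 1) -> Normal(0, 1) via the inverse CDF; d icdf(u)/du = 1 / pdf(icdf(u))
    return (
        STD_NORMAL.inv_cdf,
        STD_NORMAL.cdf,
        lambda u: -math.log(STD_NORMAL.pdf(STD_NORMAL.inv_cdf(u))),
    )


def transformed_log_prob(base_log_prob, transforms, y):
    # Change of variables: invert the chain back to base noise, subtracting the
    # forward log|Jacobian| of every transform evaluated at its own input.
    log_det = 0.0
    value = y
    for _forward, inverse, log_abs_det in reversed(transforms):
        value = inverse(value)
        log_det += log_abs_det(value)
    return base_log_prob(value) - log_det


def normal_base_log_prob(u):
    return math.log(STD_NORMAL.pdf(u))


def uniform_base_log_prob(u):
    return 0.0 if 0.0 < u < 1.0 else -math.inf


x, y = 1.0, 2.3
lp_normal_base = transformed_log_prob(normal_base_log_prob, [affine_transform(x, 1.0)], y)
lp_uniform_base = transformed_log_prob(
    uniform_base_log_prob, [normal_icdf_transform(), affine_transform(x, 1.0)], y
)
reference = math.log(NormalDist(x, 1.0).pdf(y))
print(lp_normal_base, lp_uniform_base, reference)  # agree up to floating-point error
```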
**What we don't want**: base distributions that depend on intervened variables:

```python
y_dist = TransformedDistribution(
    Normal(x, 1.),
    []
)
```

This violates the exogeneity assumption: the base noise should be truly exogenous and independent of interventions. It should raise an error, which could be triggered, for example, by an `indices_of` check on the base distribution's parameters.

Some distributions, like `Normal(loc=x, scale=1)`, can be written as `TransformedDistribution(Normal(0, 1), [AffineTransform(loc=x, scale=1)])`. For these, we can exogenate by splitting the sample into two sites: one for the exogenous noise and one for the endogenous value. We still have access to the full distribution, which has a closed-form likelihood given the inputs to its transforms, so we can easily recover that likelihood when needed:

```python
import pyro
from pyro.distributions import Delta, TransformedDistribution


def _pyro_sample(msg):
    # Only exogenate transformed distributions; this also prevents re-entry on the
    # noise and Delta sites created below.
    if not isinstance(msg["fn"], TransformedDistribution):
        return
    if msg.get("is_observed", False):
        _pyro_observed_sample(msg)
    else:
        _pyro_unobserved_sample(msg)


def _pyro_unobserved_sample(msg):
    name = msg["name"]
    fn = msg["fn"]  # TransformedDistribution
    base_dist, transforms = unwrap_to_base(fn)  # hypothetical helper

    # Sample noise
    u = pyro.sample(f"{name}_u", base_dist)

    # Apply transforms and create Delta site
    value = u
    for transform in transforms:
        value = transform(value)
    msg["value"] = pyro.sample(name, Delta(value))
    msg["stop"] = True


def _pyro_observed_sample(msg):
    name = msg["name"]
    fn = msg["fn"]  # TransformedDistribution
    obs_value = msg["value"]  # Pyro stores the observed value in msg["value"]
    base_dist, transforms = unwrap_to_base(fn)

    # Infer noise by inverting transforms (valid only for bijective transforms; revisited later)
    u_implied = obs_value
    for transform in reversed(transforms):
        u_implied = transform.inv(u_implied)

    # Mask the noise prior and add the original likelihood instead
    with pyro.poutine.mask(mask=False):
        u = pyro.sample(f"{name}_u", base_dist, obs=u_implied)
    pyro.factor(f"{name}_log_prob", fn.log_prob(obs_value))

    # Apply transforms and create Delta site
    value = u
    for transform in transforms:
        value = transform(value)
    msg["value"] = pyro.sample(name, Delta(value))
    msg["stop"] = True
```
This preserves the original distribution's likelihood for observed sites while separating noise from the structural equation.

## The PyTorch Broadcasting Problem

This approach fails because of PyTorch's distribution broadcasting semantics. **Before** sampling, PyTorch's `TransformedDistribution` expands the base distribution to the maximum shape of the full transformed distribution, as determined by broadcasting against the transforms' parameters. When `x` has shape `(2,)` due to intervention, the base distribution `Normal(0, 1)` is expanded to shape `(2,)` as well, even though it should remain scalar.

This shows up in `indices_of`. The expanded base distribution's parameters (e.g., `loc` and `scale` for a normal), and hence any sample `u` drawn from it, have `indices_of(u) == IndexSet(x={0, 1})`. This indicates that the noise is computationally downstream of the intervention, even though it should be upstream with respect to the desired causal structure.

It also highlights an important requirement: we do want independent noise across any batch dimension that does not correspond to an intervention. Suppose we could sample the base noise `u` from a distribution expanded to match all batch dimensions except those induced by intervention. Then transforms involving intervened variables (like `x`) would broadcast as desired, sharing noise across worlds.
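A dependency-free sketch of both halves of that requirement, with nested lists standing in for a `(worlds, data)` batch. `shape_excluding_interventions` is a hypothetical helper that computes the "all batch dimensions except intervention ones" shape. Intervention dimensions are assumed to be given as negative batch indices, as with index plates.

Sampling at the full shape gives independent noise in every world. Sampling at the reduced shape and broadcasting keeps noise independent across the data plate but shared across worlds. In that case the cross-world difference in `y` is exactly the intervention's effect on `x`.

```python
import random


def shape_excluding_interventions(batch_shape, interv_dims):
    """Set every intervention dimension (a negative index, as for index plates) to size 1."""
    shape = list(batch_shape)
    for dim in interv_dims:
        if -dim <= len(shape):  # dims left of the shape are implicitly size 1 already
            shape[dim] = 1
    return tuple(shape)


def sample_noise(shape, rng):
    # A (worlds, data) grid of independent Normal(0, 1) draws as nested lists
    n_worlds, n_data = shape
    return [[rng.gauss(0.0, 1.0) for _ in range(n_data)] for _ in range(n_worlds)]


def structural_y(xs, noise):
    # y[w][n] = x[w] + 1 * noise[w][n]; a size-1 world dimension broadcasts across worlds
    return [[x + noise[w if len(noise) > 1 else 0][n] for n in range(len(noise[0]))]
            for w, x in enumerate(xs)]


rng = random.Random(0)
xs = [1.5, 0.0]      # x in the factual world and under do(x=0.0)
full_shape = (2, 3)  # dim -2: intervention on x; dim -1: a data plate of size 3

# What PyTorch's expansion does: independent noise in every world
y_independent = structural_y(xs, sample_noise(full_shape, rng))
# Desired: expand along the data plate only, then broadcast across worlds
shared_shape = shape_excluding_interventions(full_shape, [-2])  # (1, 3)
y_shared = structural_y(xs, sample_noise(shared_shape, rng))

print(shared_shape)
print([b - a for a, b in zip(*y_shared)])       # each entry is -1.5 up to rounding
print([b - a for a, b in zip(*y_independent)])  # entries also include random noise differences
```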
## One Possible Solution: Squeeze Intervention Dimensions

Sample from the pre-broadcast base distribution, then collapse the intervention batch dimensions to size 1:

```python
# Sample from base (pre-broadcast to all dims, including intervention dims)
u_broadcast = pyro.sample(f"{name}_u", base_dist)

# Keep the first element along intervention dims (as size-1 dims), preserving other batch dims
u_shared = squeeze_intervention_dims(u_broadcast)

# Now when we apply transforms with intervened x, broadcasting happens correctly
value = u_shared
for transform in transforms:
    value = transform(value)  # x broadcasts here, not in base_dist
```

Here, `squeeze_intervention_dims` identifies batch dimensions that arose from interventions (via `indices_of`) and slices to keep only the first element along those dimensions, effectively extracting the shared noise.

## Handling Observed Sites

For observed sites with bijective transforms, we can analytically infer the noise:

```python
if is_observed:
    # Invert transforms to recover noise from the factual observation
    u_implied = factual_obs_value  # sliced into the factual world
    for transform in reversed(transforms):
        u_implied = transform.inv(u_implied)  # invert: obs -> noise

    # Because we already sliced into the factual world, u_implied has size-1
    # intervention dimensions, i.e., the same shape as u_shared above.

    # The forward transforms then produce shared noise, as before.
    value = u_implied
    for transform in transforms:
        value = transform(value)  # x broadcasts here, not in base_dist

    # TODO: factual conditioning, adding likelihoods only for the factual world.
```

This shared noise `u` then propagates to counterfactual worlds when we apply the transforms with intervened values. Note that, unlike the unobserved case, this does not create a separate sample site for the noise at all.
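The observed path amounts to abduction followed by prediction. First invert the chain at the factual value, then push the single recovered noise value forward with each world's parent values. The sketch below avoids external dependencies. A Python list stands in for the intervention dimension, and hypothetical `(forward, inverse)` pairs replace torch `Transform`s.

```python
def affine(loc, scale):
    # Hypothetical (forward, inverse) pair standing in for AffineTransform(loc, scale)
    return (lambda u: loc + scale * u, lambda y: (y - loc) / scale)


def forward_chain(u, transforms):
    for forward, _inverse in transforms:
        u = forward(u)
    return u


def invert_chain(y, transforms):
    for _forward, inverse in reversed(transforms):
        y = inverse(y)
    return y


x_factual, y_factual = 1.5, 2.0
x_counterfactual = 0.0

# Abduction: recover the noise implied by the factual observation of y
u_implied = invert_chain(y_factual, [affine(loc=x_factual, scale=1.0)])

# Action and prediction: reuse that single noise value in every world, varying only x
worlds = [x_factual, x_counterfactual]
y_worlds = [forward_chain(u_implied, [affine(loc=x, scale=1.0)]) for x in worlds]
print(u_implied, y_worlds)  # 0.5 [2.0, 0.5]
```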
One issue: it's not entirely clear how to check whether the base distribution directly involves endogenous variables, or whether the involvement comes only from the transforms and the pre-broadcasting of the base distribution's parameters.

## Limitations

**Categorical and Other Non-Bijective Distributions**: For distributions whose transform chain is not bijective (e.g., categoricals), we cannot analytically invert to recover noise from observations. Solutions include:

1. **Soft conditioning**: Approximate hard conditioning with tight distributions (not very practical, and can probably be avoided).
2. **Hooks for custom implementations of analytical pre-images**: For non-bijective transforms, compute the pre-image analytically and sample noise from the base distribution restricted to that subset.

**This is Hacky**: We may instead be able to roll our own version of `TransformedDistribution` that uses `indices_of` so that base distribution parameter expansion applies only to non-interventional batch dimensions. A custom `TransformedDistribution` would let us address both problems: the unwanted broadcasting of base noise, and conditional noise sampling for non-bijective chains.

## Customizing `TransformedDistribution` (WIP)

The key issue with PyTorch's `TransformedDistribution.__init__` (lines 79-98) is that it pre-broadcasts base distribution parameters to the maximum shape, including intervention dimensions. We need to expand the base distribution only along non-interventional batch dimensions.

### PyTorch's current behavior:

```python
# In TransformedDistribution.__init__
base_batch_shape = expanded_base_shape[: len(expanded_base_shape) - base_event_dim]
base_distribution = base_distribution.expand(base_batch_shape)  # ← Expands to ALL dims
```

### Sketch of Modification:

TODO: this would need to use Pyro's wrapped `TransformedDistribution`.

```python
from torch.distributions import TransformedDistribution

from chirho.indexed.ops import indices_of, get_index_plates, union
from chirho.indexed.internals import _index_plate_dims


# TODO does this break other methods of TransformedDistribution?
class ExogenousTransformedDistribution(TransformedDistribution):
    def __init__(self, base_distribution, transforms, validate_args=None):
        # First, use the same code as PyTorch to determine what the full forward shape would be...
        ...

        # TODO does this also return pyro.plate dims that we do want to expand on?
        # Get active index plates (intervention dimensions).
        # Any dimension corresponding to an index plate should be excluded from base dist expansion.
        name_to_dim = _index_plate_dims()  # Returns {intervention_name: dim, ...}

        # Compute base_batch_shape with intervention dimensions set to 1
        # (they will broadcast via transform params that are downstream of the intervention)
        base_batch_shape_non_interv = compute_shape_excluding_interventions(  # hypothetical helper
            base_batch_shape,  # The shape PyTorch would normally expand to
            name_to_dim,
        )

        # Expand base dist ONLY along non-intervention dimensions
        base_distribution = base_distribution.expand(base_batch_shape_non_interv)

        # Continue with standard TransformedDistribution initialization...
        ...

    # TODO Hopefully no need to override sample() or rsample()
```

### Handling Observed Sites: Bijective Case

For distributions with bijective transform chains, we can analytically invert to recover noise:

```python
import pyro
from pyro.distributions import Delta
from torch.distributions import TransformedDistribution


class ExogenousTransformedDistribution(TransformedDistribution):
    def sample_from_noise_conditional(self, obs_value):
        """
        Sample exogenous noise conditional on an observed value.

        For bijective transforms: invert the transform chain analytically and
        sample from Delta(u).
        For non-bijective transforms: must be overridden by subclasses (see below).

        Returns a sample from the conditional noise distribution.
        """
        if not self.is_bijective():
            raise NotImplementedError(
                f"{type(self).__name__} must override sample_from_noise_conditional() "
                "for non-bijective transforms"
            )

        # Invert transforms: obs_value -> noise
        u = obs_value
        for transform in reversed(self.transforms):
            u = transform.inv(u)

        # Sample from Delta at the inverted noise value
        return pyro.sample("noise", Delta(u))

    def is_bijective(self):
        """Check whether all transforms in the chain are bijective."""
        # torch.distributions.Transform defines a boolean class attribute `bijective`;
        # `.inv` always exists, but calling it may raise NotImplementedError.
        return all(t.bijective for t in self.transforms)
```

### Handling Non-Bijective Transforms

Ideally, each `Transform` would define a set-based inversion (for some abstract set representation). These inversions could be composed through transform chains to automatically construct conditional noise samplers, by renormalizing the base distribution to the pre-image set.

In the near term, specific implementations of `ExogenousTransformedDistribution` can override `sample_from_noise_conditional()` with their own analytical pre-image sampling strategy. For example, a `GumbelMaxCategorical` would handle categorical distributions via the Gumbel-max trick (see the Appendix).

## Propagating Conditional, Factual Noise to Counterfactuals (WIP)

TODO: figure this out, using as much existing factual conditioning machinery as possible.

## Distributional Transformation Example: Excision (WIP)

TODO

# Appendix

## Example: Categorical with Gumbel-Max (AI Warning)

The Gumbel-max trick represents a categorical distribution over $K$ categories with probabilities $\mathbf{p} = (p_1, \ldots, p_K)$ as:

```python
import torch
import pyro
from pyro.distributions import Gumbel

# Assumes p: a tensor of K category probabilities
# Base noise: K independent Gumbel(0, 1) samples
u = pyro.sample("u", Gumbel(0., 1.).expand([K]).to_event(1))

# Structural equation: argmax of perturbed log probabilities
logits = u + torch.log(p)
category = torch.argmax(logits, dim=-1)
```

**Pre-image for observed category $c$**: If we observe `category = c`, we know:

$$
u_c + \log(p_c) > u_i + \log(p_i) \quad \text{for all } i \neq c
$$

Equivalently: $u_c > u_i + \log(p_i/p_c)$ for all $i \neq c$.
**Conditional sampling strategy**:

```python
import torch
import pyro
from pyro.distributions import Categorical, Delta, Gumbel


def _pyro_observed_sample_categorical(msg):
    name = msg["name"]
    c = torch.as_tensor(msg["value"])  # observed category index
    log_probs = msg["fn"].logits       # normalized log(p_1), ..., log(p_K)
    K = log_probs.shape[-1]
    is_c = torch.arange(K) == c.unsqueeze(-1)

    # Sample noise from its conditional distribution given category = c, top-down:
    # max_i(u_i + log p_i) ~ Gumbel(logsumexp(log p), 1) is independent of the argmax,
    # and the observed category attains it.
    top = pyro.sample(f"{name}_u_top", Gumbel(torch.logsumexp(log_probs, -1), 1.)).unsqueeze(-1)

    # For i != c: u_i + log p_i ~ Gumbel(log p_i, 1) truncated above at `top`,
    # sampled as -log(exp(-top) + exp(-g)) with g ~ Gumbel(log p_i, 1).
    g = pyro.sample(f"{name}_u_rest", Gumbel(log_probs, torch.ones_like(log_probs)).to_event(1))
    perturbed = torch.where(is_c, top, -torch.logaddexp(-top, -g))
    u = perturbed - log_probs  # exogenous Gumbel(0, 1) noise

    # Forward: recompute category (equals c by construction)
    category = torch.argmax(u + log_probs, dim=-1)

    # Create Delta site with the original likelihood
    with pyro.poutine.mask(mask=False):
        msg["value"] = pyro.sample(name, Delta(category))
    pyro.factor(f"{name}_log_prob", Categorical(logits=log_probs).log_prob(c))
    msg["stop"] = True
```

The argmax operation is not bijective, yet we can still characterize its pre-image analytically as a region of noise space. This lets us sample from the conditional distribution of the noise given the observation, and that shared noise then propagates correctly through counterfactual worlds.

Naturally, we would need some way of generalizing the handler implementation so that it can access these alternative implementations of noise pre-image sampling.

## A General View

In essence, these transform chains are programs with three properties:

- symbolic identifiers, affording whatever symbolic manipulation may be desired (such as excision);
- closed-form likelihoods;
- closed-form conditional pre-images on noise.

## Notes

1. Custom transformed distribution class that has the desired semantics.
2. View this as a design experiment for future long-term implementations.
3. Add a working example for excision (how does excision extract the necessary symbolic information?).

# Alternative

Add a new operation:

```python
from typing import Callable, List, Optional


@effectful_op  # or whatever this syntax is
def rsample(
    base_dist: Dist,
    transforms: List[Callable],
    analytic_inv: Optional[Callable],
    analytic_likelihood: Optional[Callable],
):
    raise NotImplementedError()
```

```python
def _pyro_rsample(msg):
    # TODO if unobserved:
    #   sample from the base distribution and propagate;
    #   add a log likelihood here for priors?

    # TODO if observed:
    #   compute a closed-form likelihood the same way
    #   TransformedDistribution does;
    #   if that's not possible, use analytic_likelihood

    # TODO if observed and in a counterfactual:
    #   when conditioning on a factual setting in a
    #   counterfactual, infer the noise automatically
    #   if possible (if all transforms have .inv defined)
    #   and then propagate that noise forward rather
    #   than sampling independently

    # TODO see this for how TransformedDistribution computes the log likelihood:
    # https://github.com/pytorch/pytorch/blob/v2.9.0/torch/distributions/transformed_distribution.py#L165

    # TODO pass the log likelihood to the Delta sample site.
    # This will propagate to the trace and expose the site
    # for intervention.
    ...
```

Or:

```python
def _pyro_rsample(msg):
    base_dist = msg["kwargs"]["base_dist"]
    transforms = msg["kwargs"]["transforms"]
    analytic_inv = ...
    analytic_likelihood = ...

    # Immediately construct a sample site so that
    # observation, etc., still work. This also
    # serves as the default unobserved case.
    pyro.sample(
        msg["name"],
        PlaceholderNotImplementedDist(),
        infer=dict(**msg["kwargs"], exog=True)
    )
    msg["stop"] = True
```

We would then handle everything with sample handlers. This is just syntactic sugar for embedding the information required for exogenation.
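The sketch below makes the intended dispatch of the first `_pyro_rsample` variant concrete in plain Python. Scalars stand in for tensors and a list stands in for the world dimension. `ExogenousSite` and `run_site` are hypothetical names, not Pyro or chirho APIs. The branches mirror the TODOs. When the site is unobserved, the noise is sampled once. When it is observed, the noise is recovered analytically from the factual world and propagated to every world.

```python
import math
import random
from dataclasses import dataclass
from typing import Callable, List, Optional, Sequence

Chain = Sequence[Callable[[float], float]]


@dataclass
class ExogenousSite:
    """Hypothetical payload of the proposed `rsample` op (not a Pyro or chirho API)."""

    sample_base: Callable[[random.Random], float]
    transforms: Callable[[float], Chain]  # parent value -> transform chain
    analytic_inv: Optional[Callable[[float, float], float]] = None  # (parent, obs) -> noise
    analytic_likelihood: Optional[Callable[[float, float], float]] = None  # (parent, obs) -> log p


def push_forward(chain: Chain, u: float) -> float:
    for transform in chain:
        u = transform(u)
    return u


def run_site(site: ExogenousSite, parents: List[float], rng: random.Random,
             obs: Optional[float] = None, factual: int = 0):
    """Return (shared noise, per-world values, factual log-likelihood or None).

    `parents` holds the parent value in each world; world `factual` is the observed one.
    """
    log_lik = None
    if obs is None:
        u = site.sample_base(rng)  # unobserved: one noise draw shared by all worlds
    elif site.analytic_inv is not None:
        u = site.analytic_inv(parents[factual], obs)  # observed: abduction from the factual world
        if site.analytic_likelihood is not None:
            log_lik = site.analytic_likelihood(parents[factual], obs)
    else:
        raise NotImplementedError("no analytic pre-image; fall back to another strategy")
    values = [push_forward(site.transforms(x), u) for x in parents]
    return u, values, log_lik


# y = x + 1 * u_y with u_y ~ Normal(0, 1)
y_site = ExogenousSite(
    sample_base=lambda rng: rng.gauss(0.0, 1.0),
    transforms=lambda x: [lambda u: x + 1.0 * u],
    analytic_inv=lambda x, y: (y - x) / 1.0,
    analytic_likelihood=lambda x, y: -0.5 * (y - x) ** 2 - 0.5 * math.log(2 * math.pi),
)
u, values, log_lik = run_site(y_site, parents=[1.5, 0.0], rng=random.Random(0), obs=2.0)
print(u, values)  # 0.5 [2.0, 0.5]
```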