owned this note
Published
Linked with GitHub
# BP inference in Pyro
## Interface sketch v1
Consider something like `poutine.collapse()`.
```python=3
def model(x, y):
    """Sketch v1: bipartite matching via a speculative ``infer_combinatorial`` handler.

    Samples a global ``radius``, builds pairwise edge logits from distances,
    then declares a doubly-stochastic edge assignment inside two plates.
    NOTE(review): ``infer_combinatorial`` and ``pyro.constrain_sum`` are
    proposed API, not existing Pyro functions.
    """
    i_plate = pyro.plate("i", len(x), dim=-2)
    j_plate = pyro.plate("j", len(y), dim=-1)
    # Sample globals and compute parameters.
    radius = pyro.sample("radius", LogNormal(0., 1.))
    edge_logits = (-1 / radius) * torch.cdist(x, y)
    assert edge_logits.shape == (len(x), len(y))
    # This transforms to a single pyro.factor statement.
    # FIXME this misses the total scale.
    with infer_combinatorial():
        with i_plate, j_plate:
            # FIXME edge_logits is multinomial, but we're treating it
            # as Bernoulli logits.
            edges = pyro.sample("edges", Bernoulli(logits=edge_logits))
        # Each row and each column of the edge matrix must sum to one,
        # i.e. edges is constrained to be a permutation-like matching.
        with i_plate:
            pyro.constrain_sum(edges, dim=-1, value=1)
        with j_plate:
            pyro.constrain_sum(edges, dim=-2, value=1)


# Now the guide should sample only global variables.
guide = AutoLowRankMultivariateNormal(model)
```
## Interface sketch v2
Consider a Pyro-oblivious PyTorch-based DSL for approximating partition functions. The results of these computations could be added to pyro.factor statements in models.
```python=3
def inner_model(edge_logits):
    """Sketch v2: Pyro-oblivious DSL description of the matching problem.

    Declares a binary edge matrix weighted by ``edge_logits`` whose rows and
    columns each sum to one. NOTE(review): ``dsl`` is a proposed partition-
    function DSL, not an existing library.
    """
    edges = dsl.variable(edge_logits.shape, type="binary")
    dsl.factor(edges, edge_logits)
    dsl.constrain_sum(edges, dim=-1, value=1)
    dsl.constrain_sum(edges, dim=-2, value=1)


def model(x, y):
    """Pyro model that folds the DSL's log partition function into the ELBO."""
    # ...sample globals and compute edge_logits...
    problem = InferCombinatorial(inner_model)
    problem.solve(edge_logits)
    pyro.factor("inner", problem.log_partition_function)


guide = AutoNormal(model)  # only sees globals.
```
## Interface sketch v3
It would be nice to be able to transition to guide-side amortized VI (variational inference), as suggested by Mehrtash.
```python=3
def model(x, y):
    """Sketch v3: model side of guide-amortized mean-field relaxation.

    The masked Bernoulli site is a placeholder whose value is supplied by the
    guide; the combinatorial problem contributes a KL term via ``pyro.factor``.
    NOTE(review): ``inner_model``, ``i_plate``/``j_plate`` and ``edge_logits``
    are assumed to be defined as in the earlier sketches.
    """
    # ...sample globals and compute edge_logits...
    with i_plate, j_plate:
        edges = pyro.sample("edges", Bernoulli(0.5).mask(False),
                            infer={"mean_field": "relax"})
    problem = InferCombinatorial(inner_model)
    problem.something(edge_logits)
    pyro.factor("inner", problem.kl_divergence(edges))


def guide(x, y):
    """Guide: learned posterior over globals plus amortized edge marginals."""
    # Sample globals from a learned posterior.
    loc = pyro.param("radius_loc", torch.tensor(0.))
    scale = pyro.param("radius_scale", torch.tensor(1.),
                       constraint=constraints.positive)
    pyro.sample("radius", LogNormal(loc, scale))
    # Sample amortized mean field values. Note that mean_field
    # is a nonstandard interpretation, similar to enumeration.
    edge_marginals = my_fancy_nn(x, y, radius)
    pyro.sample("edges", Delta(edge_marginals))
```
## Example phylogenetic model
```python=3
def model(leaf_data, case_data):
    """Example phylogenetic model: mutations + infection counts via the DSL.

    The combinatorial inner problem tracks per-(time, mutation) infection
    counts; the outer Pyro model conditions on case counts and on subsampled
    leaf sequences. (Parameters renamed from ``leaf_sequences`` to match the
    ``leaf_data`` / ``case_data`` the body actually reads, consistent with the
    phylogeographic model below.)
    """
    T = num_time_steps
    M = num_mutations  # approximately same as "strain"

    def inner_model():
        mutations = dsl.variable((T, M), type="bernoulli")
        dsl.constrain_sum(mutations, dim=(-3, -2), value=1)
        num_infected = dsl.variable((T, M), type="count")
        # Initial-state factor at time 0, then a Markov transition factor.
        dsl.factor(num_infected[0], init_state,
                   type="negative_binomial")
        dsl.factor(num_infected[1:],
                   transition(num_infected[:-1], mutations[1:]),
                   type="negative_binomial")
        # TODO dsl.factor("aggregate epi info")

    problem = InferCombinatorial(inner_model)
    problem.solve()
    pyro.factor("inner", problem.log_partition_function)
    # Condition on case counts.
    # NOTE(review): ``num_infected`` is local to inner_model; this presumably
    # should read problem.marginals["num_infected"] — confirm intent.
    pyro.sample("case_data",
                # marginalize over strains
                Poisson(num_infected.sum(-1)),  # fixed "Poission" typo
                obs=case_data)
    # Condition on genetic sequences.
    with pyro.plate("leaves", len(leaf_data["seq"]),
                    subsample_size=100) as subsample:
        leaf_seq = leaf_data["seq"][subsample]
        t = leaf_data["date"][subsample]
        with pyro.plate("mutations", M):
            mean_infected = problem.marginals["num_infected"]
            assert mean_infected.shape == (T, M)
            probs = mean_infected / mean_infected.sum(-1, True)
            pyro.sample("leaf_obs", Categorical(probs[t]),
                        obs=leaf_seq)  # was obs=seq, which is unbound
    return problem.marginals  # returns a dict
```
## Example phylogeographic model
```python=3
def model(leaf_data, case_data, mobile_data, flight_data):
    """Example phylogeographic model: adds a region axis and global flux.

    Globals: a transportation ``flux`` field conditioned on mobile-phone and
    flight data. Locals: a DSL problem over (time, region, mutation) counts,
    conditioned on subsampled leaf sequences with region annotations.
    """
    T = num_time_steps
    R = num_regions
    M = num_mutations
    # Globals
    flux = pyro.sample("flux", TransportationDistribution(...))
    assert flux.shape == (T, R)  # no mutation information
    pyro.sample("mobile_phones", dist.Poisson(f1(flux)),
                obs=mobile_data)
    pyro.sample("air_travel", dist.Poisson(f2(flux)),
                obs=flight_data)

    # Locals
    def inner_model():
        mutations = dsl.variable((T, R, M), type="bernoulli")
        dsl.constrain_sum(mutations, dim=(-3, -2), value=1)
        num_infected = dsl.variable((T, R, M), type="count")
        dsl.factor(num_infected[0], init_state,
                   type="negative_binomial")
        dsl.factor(num_infected[1:],  # fixed missing comma (syntax error)
                   transition(flux, num_infected[:-1], mutations[1:]),
                   type="negative_binomial")  # ???
        dsl.factor("TODO aggregate epi info")

    problem = InferCombinatorial(inner_model)
    problem.solve()
    pyro.factor("inner", problem.log_partition_function)
    with pyro.plate("leaves", len(leaf_data["seq"]),
                    subsample_size=100) as subsample:
        leaf_seq = leaf_data["seq"][subsample]
        t = leaf_data["date"][subsample]
        r = leaf_data["region"][subsample]
        with pyro.plate("mutations", M):
            mean_infected = problem.marginals["num_infected"]
            assert mean_infected.shape == (T, R, M)
            probs = mean_infected / mean_infected.sum(-1, True)
            pyro.sample("leaf_obs", Categorical(probs[t, r]),
                        obs=leaf_seq)  # was obs=seq, which is unbound
    return problem.marginals  # returns a dict
```