PR: pymc4/#306
Issues: pymc4/#187, pymc4/#171
The main goal of the work during the summer was to provide support for various samplers and for sampling discrete variables. We have written and pushed an interface that can easily be extended to any number of sampling algorithms.
First, we have provided the base class for all the pymc4
sampling algorithms:
class _BaseSampler(metaclass=abc.ABCMeta):
def _sample(self, ...):
...
@abc.abstractmethod
def trace_fn(self, ...):
"""
Support a tracing for each sampler
"""
pass
The class has an abstract trace_fn
method which defines the tracing logic for each sampling algorithm subclass. For example, the NUTS
subclass:
class NUTS(_BaseSampler):
_name = "nuts"
_grad = True
_adaptation = mcmc.DualAveragingStepSizeAdaptation
_kernel = mcmc.NoUTurnSampler
def trace_fn(self, current_state, pkr):
return (
pkr.inner_results.target_log_prob,
pkr.inner_results.leapfrogs_taken,
pkr.inner_results.has_divergence,
pkr.inner_results.energy,
pkr.inner_results.log_accept_ratio,
) + tuple(self.deterministics_callback(*current_state))
The _grad
attribute identifies sampling algorithms that compute gradients in the step method. We also support the tensorflow_probability
adaptation logic by including an adaptation policy: for the base NUTS sampler the policy is mcmc.DualAveragingStepSizeAdaptation
. Another policy can easily be provided by changing the appropriate class attribute, and the tracing method can be modified by subclassing the class.
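The class-attribute-driven design can be illustrated with a small self-contained sketch. The `Fake*` classes below are stand-ins for the tensorflow_probability kernels and adaptation policies, and `kernel_chain` is a hypothetical helper; none of these names are part of pymc4 or tfp:

```python
import abc


class FakeAdaptation:
    """Stand-in for mcmc.DualAveragingStepSizeAdaptation."""


class FakeSimpleAdaptation:
    """Stand-in for an alternative adaptation policy."""


class FakeKernel:
    """Stand-in for mcmc.NoUTurnSampler."""


class _BaseSampler(metaclass=abc.ABCMeta):
    _adaptation = None
    _kernel = None
    _grad = False

    def kernel_chain(self):
        # The base class reads the class attributes, so subclasses
        # reconfigure the sampler declaratively instead of overriding
        # the sampling logic itself.
        kernel = self._kernel()
        if self._adaptation is not None:
            return self._adaptation, kernel
        return None, kernel


class NUTS(_BaseSampler):
    _grad = True
    _adaptation = FakeAdaptation
    _kernel = FakeKernel


class NUTSSimpleAdapt(NUTS):
    # Swap the adaptation policy by overriding one attribute.
    _adaptation = FakeSimpleAdaptation


adaptation, kernel = NUTSSimpleAdapt().kernel_chain()
print(adaptation is FakeSimpleAdaptation)  # True
```

Because the configuration lives in class attributes, a new sampler variant is a three-line subclass rather than a reimplementation of the step logic.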
The next thing was to support the compound step. For that we can easily implement another sampling algorithm class:
class CompoundStep(_BaseSampler):
_name = "compound"
_adaptation = None
_kernel = _CompoundStepTF
_grad = False
def trace_fn(self, current_state, pkr):
...
def _assign_default_methods(
self,
...
):
"""
Assign the appropriate sampling algorithm for each
variable and merge equal samplers
"""
Then we need to implement a sub-class of tfp...TransitionKernel
to provide the compound one_step
logic:
class _CompoundGibbsStepTF(kernel_base.TransitionKernel):
def one_step(self, current_state, previous_kernel_results, seed=None):
...
def bootstrap_results(self, init_state):
"""
Returns an object with the same type as returned by `one_step(...)[1]`
Compound bootstrap step
"""
...
We have also provided support for custom proposal generation functions:
class Proposal(metaclass=abc.ABCMeta):
@abc.abstractmethod
def _fn(self, state_parts: List[tf.Tensor], seed: Optional[int]) -> List[tf.Tensor]:
"""
Proposal function that is passed as the argument
to RWM kernel
"""
pass
@abc.abstractmethod
def __eq__(self, other) -> bool:
"""
Comparison of the proposal func
"""
pass
class CategoricalUniformFn(Proposal):
"""
Categorical proposal sub-class with the `_fn` that is sampling new proposal
from a categorical distribution with uniform probabilities.
"""
_name = "categorical_uniform_fn"
def _fn(self, state_parts: List[tf.Tensor], seed: Optional[int]) -> List[tf.Tensor]:
with tf.name_scope(self._name or "categorical_uniform_fn"):
part_seeds = samplers.split_seed(seed, n=len(state_parts), salt="CategoricalUniformFn")
deltas = tf.nest.map_structure(
lambda x, s: tfd.Categorical(logits=tf.ones(self.classes)).sample(
seed=s, sample_shape=tf.shape(x)
),
state_parts,
part_seeds,
)
return deltas
def __eq__(self, other) -> bool:
return self._name == other._name and self.classes == other.classes
The compound sampling algorithm accepts a list of variables and, optionally, a list of sampling algorithms (if none is given, appropriate sampling logic is chosen automatically), and assigns the appropriate proposal generation function to each distribution.
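The proposal interface can be mirrored in a pure-Python sketch: new values are drawn uniformly over the classes for every entry of each state part, and `__eq__` lets the compound step merge kernels whose proposals match. This mimics `CategoricalUniformFn` without the tensorflow_probability machinery; the class name and plain-list state are illustrative only:

```python
import random


class CategoricalUniformSketch:
    _name = "categorical_uniform_fn"

    def __init__(self, classes):
        self.classes = classes

    def _fn(self, state_parts, rng=random):
        # Propose an independent draw from {0, ..., classes - 1} for
        # every element, ignoring the current values entirely.
        return [
            [rng.randrange(self.classes) for _ in part]
            for part in state_parts
        ]

    def __eq__(self, other):
        # Two proposals are "equal" (so their kernels can be merged)
        # when both the name and the number of classes match.
        return self._name == other._name and self.classes == other.classes


p, q = CategoricalUniformSketch(3), CategoricalUniformSketch(3)
print(p == q)  # True
print(all(0 <= v < 3 for v in p._fn([[0, 0, 0]])[0]))  # True
```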
PR: pymc4/#287
Support for Sequential Monte Carlo (SMC). Thanks to @junpenglao for providing the tfp
implementation of the SMC algorithm; this made supporting it in pymc4
much easier. Our job was to provide the sample_smc
logic with modified logp
functions that calculate the probabilities for the prior and the likelihood separately.
This is implemented in separate SamplingState
sub-class:
class SMCSamplingState(SamplingState):
"""
Subclass of `SamplingState` which adds the support of
log probability collection separately for likelihood and
prior.
"""
__slots__ = ()
def collect_log_prob_smc(self, is_prior):
"""
Collects log probabilities for likelihood variables in SMC.
Since SMC requires the `draws` dimension to be kept explicitly
while the graph is evaluated, we can't combine SMC probability
collection with the NUTS log probability collection.
"""
...
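Why the prior/likelihood split matters can be shown with a toy tempered log-probability: SMC anneals from the prior (beta = 0) to the posterior (beta = 1) by scaling only the likelihood term, which is impossible if the two are collected as one sum. The one-parameter normal model and function names below are illustrative, not the `SMCSamplingState` implementation:

```python
import math


def normal_logpdf(x, mu, sigma):
    # Log density of N(mu, sigma^2) at x.
    return (
        -0.5 * math.log(2 * math.pi * sigma**2)
        - (x - mu) ** 2 / (2 * sigma**2)
    )


def tempered_logp(theta, data, beta):
    log_prior = normal_logpdf(theta, 0.0, 1.0)  # prior: theta ~ N(0, 1)
    log_lik = sum(normal_logpdf(y, theta, 1.0) for y in data)
    # Only the likelihood is scaled by the inverse temperature beta,
    # so the prior and likelihood terms must be available separately.
    return log_prior + beta * log_lik


data = [0.5, -0.2, 1.1]
print(tempered_logp(0.0, data, 0.0) == normal_logpdf(0.0, 0.0, 1.0))  # True
```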
Execution for SMC
is also separated from the main execution logic, with a distinct _sample_unobserved
function.
BinaryMetropolis
(not quite like in pymc3
, but it shouldn't be) is provided by using the proposal function for (e.g.) Bernoulli
variables.
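A minimal sketch of such a proposal for binary variables: each entry of the new state is resampled uniformly from {0, 1} (equivalently, flipped with probability 1/2). The function name and plain-list state are hypothetical, not the pymc4 API:

```python
import random


def bernoulli_uniform_proposal(state_parts, rng=random):
    # Propose a fresh uniform 0/1 value for every binary entry,
    # independent of the current state.
    return [[rng.randrange(2) for _ in part] for part in state_parts]


proposal = bernoulli_uniform_proposal([[0, 1, 1, 0]])
print(all(v in (0, 1) for v in proposal[0]))  # True
```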