Final evaluation

Work that is merged

PR: pymc4/#306
Issues: pymc4/#187, pymc4/#171

The main goal of the work during the summer was to provide support for various samplers and for sampling of discrete variables. We have written and pushed an interface that can easily be extended with additional sampling algorithms.

First, we have provided the base class for all the pymc4 sampling algorithms:

class _BaseSampler(metaclass=abc.ABCMeta):
    def _sample(self, ...):
        ...

    @abc.abstractmethod
    def trace_fn(self, ...):
        """
        Define the tracing logic for each sampler
        """
        pass

The class has an abstract trace_fn method that defines the tracing logic for each sampling algorithm subclass. For example, the NUTS subclass:

class NUTS(_BaseSampler):
    _name = "nuts"
    _grad = True
    _adaptation = mcmc.DualAveragingStepSizeAdaptation
    _kernel = mcmc.NoUTurnSampler

    def trace_fn(self, current_state, pkr):
        return (
            pkr.inner_results.target_log_prob,
            pkr.inner_results.leapfrogs_taken,
            pkr.inner_results.has_divergence,
            pkr.inner_results.energy,
            pkr.inner_results.log_accept_ratio,
        ) + tuple(self.deterministics_callback(*current_state))

The _grad attribute identifies sampling algorithms that calculate gradients in the step method. We also reuse tensorflow_probability logic by including the adaptation policy: for base NUTS the adaptation policy is mcmc.DualAveragingStepSizeAdaptation, and another policy can be provided simply by changing the appropriate class attribute. The tracing method can likewise be modified by subclassing the class.
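
To make the attribute-driven design concrete, here is a hypothetical, plain-Python sketch (no tensorflow_probability) of how the class attributes drive kernel construction: the shared builder reads `_kernel` and `_adaptation` from the concrete subclass, so swapping the adaptation policy is a one-attribute change. All names below are illustrative, not the actual pymc4 API.

```python
class _BaseSampler:
    _kernel = None      # transition-kernel factory
    _adaptation = None  # optional adaptation policy wrapping the kernel
    _grad = False       # whether the kernel needs gradients of the target

    def build_kernel(self):
        kernel = self._kernel()
        if self._adaptation is not None:
            # Wrap the inner kernel with the adaptation policy, the way
            # DualAveragingStepSizeAdaptation wraps NoUTurnSampler.
            kernel = self._adaptation(kernel)
        return kernel


class FakeNUTSKernel:
    pass


class FakeDualAveraging:
    def __init__(self, inner_kernel):
        self.inner_kernel = inner_kernel


class NUTS(_BaseSampler):
    _kernel = FakeNUTSKernel
    _adaptation = FakeDualAveraging
    _grad = True
```

Calling `NUTS().build_kernel()` then returns the adaptation wrapper around the inner kernel; a subclass that sets `_adaptation = None` gets the bare kernel instead.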

The next step was to support the compound step. For that we can implement another sampling algorithm class:

class CompoundStep(_BaseSampler):
    _name = "compound"
    _adaptation = None
    _kernel = _CompoundStepTF
    _grad = False
    
    def trace_fn(self, current_state, pkr):
        ...

    def _assign_default_methods(
        self,
        ...
    ):
    """
    Assign the appropriate sampling algorithm for each
    variable and merge equal samplers
    """    

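The merging step can be illustrated as follows: variables whose assigned sampler compares equal are grouped so a single kernel can update them jointly. `merge_samplers` and the sampler keys below are hypothetical names, not the actual pymc4 implementation.

```python
def merge_samplers(assignments):
    """Group variable names by their (comparable) sampler assignment."""
    groups = {}
    for var, sampler in assignments:
        groups.setdefault(sampler, []).append(var)
    return groups


# Continuous variables share a NUTS step; the discrete one gets RWM:
merged = merge_samplers([
    ("mu", "nuts"),
    ("sigma", "nuts"),
    ("k", "rwm_categorical"),
])
# merged == {"nuts": ["mu", "sigma"], "rwm_categorical": ["k"]}
```
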
Then we need to implement a sub-class of tfp...TransitionKernel that provides the compound one_step logic:

class _CompoundGibbsStepTF(kernel_base.TransitionKernel):
    def one_step(self, current_state, previous_kernel_results, seed=None):
        ...
    def bootstrap_results(self, init_state):
        """
        Returns an object with the same type as returned by `one_step(...)[1]`
        Compound bootstrap step
        """
        ...
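
The compound one_step idea can be sketched framework-free (the real class subclasses the TFP TransitionKernel): each sub-step updates only its own slice of the state, Gibbs-style, and per-step kernel results are carried alongside. All names here are hypothetical.

```python
def compound_one_step(steps, state, results):
    """steps: list of (update_fn, var_indices); results: per-step results."""
    state = list(state)
    new_results = []
    for (update_fn, idxs), pkr in zip(steps, results):
        part = [state[i] for i in idxs]
        part, pkr = update_fn(part, pkr)  # sub-kernel updates its variables
        for i, value in zip(idxs, part):
            state[i] = value
        new_results.append(pkr)
    return state, new_results


# Toy "kernels": one increments its variables, the other doubles its own.
steps = [
    (lambda part, pkr: ([x + 1 for x in part], pkr), [0, 1]),
    (lambda part, pkr: ([x * 2 for x in part], pkr), [2]),
]
new_state, _ = compound_one_step(steps, [1, 2, 3], [None, None])
# new_state == [2, 3, 6]
```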

We have also provided support for custom proposal generation functions:

class Proposal(metaclass=abc.ABCMeta):
    @abc.abstractmethod
    def _fn(self, state_parts: List[tf.Tensor], seed: Optional[int]) -> List[tf.Tensor]:
        """
        Proposal function that is passed as the argument
        to RWM kernel
        """
        pass

    @abc.abstractmethod
    def __eq__(self, other) -> bool:
        """
        Compare proposal functions for equality
        """
        pass


class CategoricalUniformFn(Proposal):
    """
    Categorical proposal sub-class whose `_fn` samples a new proposal
    from a categorical distribution with uniform probabilities.
    """

    _name = "categorical_uniform_fn"

    def _fn(self, state_parts: List[tf.Tensor], seed: Optional[int]) -> List[tf.Tensor]:
        with tf.name_scope(self._name or "categorical_uniform_fn"):
            part_seeds = samplers.split_seed(seed, n=len(state_parts), salt="CategoricalUniformFn")
            deltas = tf.nest.map_structure(
                lambda x, s: tfd.Categorical(logits=tf.ones(self.classes)).sample(
                    seed=s, sample_shape=tf.shape(x)
                ),
                state_parts,
                part_seeds,
            )
            return deltas

    def __eq__(self, other) -> bool:
        return self._name == other._name and self.classes == other.classes

The compound sampling algorithm accepts a list of variables and, optionally, a list of sampling algorithms (if none is given, the appropriate sampling logic is chosen), and assigns the appropriate proposal generation function to each distribution.
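
This is also why `Proposal.__eq__` matters for the compound assigner: variables are merged only when their proposal functions compare equal, so equality must be by configuration, not object identity. A plain-Python analogue of the `CategoricalUniformFn` comparison above:

```python
class CategoricalUniformFn:
    _name = "categorical_uniform_fn"

    def __init__(self, classes):
        self.classes = classes  # number of categories to propose from

    def __eq__(self, other):
        # Equal name and equal configuration => samplers can be merged.
        return (
            self._name == getattr(other, "_name", None)
            and self.classes == getattr(other, "classes", None)
        )
```

Two distinct instances with the same `classes` compare equal and can therefore share a single kernel.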

Work in progress

PR: pymc4/#287
Support for Sequential Monte Carlo (SMC). Thanks to @junpenglao for providing the tfp implementation of the SMC algorithm; because of this, adding support in pymc4 was much easier. Our job was mainly to provide the sample_smc logic with modified logp functions that can calculate probabilities separately for the prior and the likelihood.
This is implemented in separate SamplingState sub-class:

class SMCSamplingState(SamplingState):
    """
    Subclass of `SamplingState` which adds the support of
    log probability collection separately for likelihood and
    prior.
    """

    __slots__ = ()

    def collect_log_prob_smc(self, is_prior):
        """
        Collects log probabilities for likelihood variables in SMC.
        Since SMC requires the `draws` dimension to be kept explicitly
        while the graph is evaluated, we can't combine SMC probability
        collection with the NUTS log probability collection.
        """
        ...

Execution for SMC is also separated from the main execution logic, with the distinct behaviour provided in a _sample_unobserved function.
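
A toy illustration of why the prior/likelihood split is needed: SMC anneals only the likelihood term with an inverse temperature `beta`, so the two log probabilities must be collected separately. `tempered_logp` below is an illustrative helper, not the pymc4 API.

```python
def tempered_logp(log_prior, log_likelihood, beta):
    # beta moves from 0.0 (sampling the prior) to 1.0 (full posterior);
    # only the likelihood term is tempered.
    return log_prior + beta * log_likelihood
```

For example, with `log_prior = -1.0` and `log_likelihood = -4.0`, `beta = 0.0` gives the prior value `-1.0` and `beta = 1.0` gives the full posterior value `-5.0`.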

Ideas from the proposal:

  • Support for more samplers; tricky because of multimodal/discrete distributions.
  • Support for a sampler method assigner as in pymc3 (not exactly as in pymc3, but it doesn't need to be).
  • Support for optimized samplers, e.g. BinaryMetropolis.
    This one is not required: the logic of pymc3's BinaryMetropolis is provided by using the appropriate proposal function for (e.g.) Bernoulli.
  • Progress bar for samplers; includes some hacking on the tfp side too.
  • Add support for SMC (work still not merged).
  • Fix all the remaining issues with discrete distribution sampling and design a more user-friendly interface. Additionally, fix the issues with XLA.
  • Support for the step methods CompoundStep and Gibbs. It should be discussed whether there is a need for that with the current design.