GenTraceKernelDSL.jl

A DSL for defining stochastic maps between traces of Gen generative functions
Author: probcomp
Started in: January 2021

This package provides a DSL for constructing trace kernels (stochastic maps between the traces of Gen generative functions) for use as proposals in (generalized) Metropolis-Hastings or in sequential Monte Carlo. Specifically, trace kernels are used to define a type of sequential Monte Carlo algorithm called an SMCP3 algorithm.

This package can be viewed as a refactoring of Gen's Trace Translator functionality, described in Marco Cusumano-Towner's thesis and in the arXiv preprint Automating Involutive MCMC using Probabilistic and Differentiable Programming. Unlike Gen's trace transform DSL, this package does not enforce a separation between the "probabilistic" and "differentiable" components of a trace translator: users may freely mix sampling with deterministic transformations to describe arbitrary stochastic transformations.

DSL

A kernel function is declared using the @kernel macro. The kernel's body may contain deterministic Julia code, as well as ~ expressions, familiar from Gen:

  • {:x} ~ dist(args) samples from a Gen distribution at address :x
  • {:x} ~ gen_fn(args) samples from a Gen generative function at address :x, and evaluates to the trace of the function, rather than its return value
  • {:x} ~ kernel_fn(args) calls another @kernel-defined function at address :x, and evaluates to its return value.
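The sketch below shows the three forms of ~ side by side. The names noisy, perturb, and syntax_demo are hypothetical and exist only for illustration; it also assumes Gen and GenTraceKernelDSL are installed and that the latter exports @kernel.

```julia
using Gen, GenTraceKernelDSL

# Hypothetical generative function, used only to illustrate the syntax.
@gen function noisy(mu)
    y ~ normal(mu, 1.0)
    return y
end

# Hypothetical helper kernel.
@kernel function perturb(x)
    eps ~ normal(0.0, 0.1)
    return x + eps
end

@kernel function syntax_demo(mu)
    a = {:a} ~ normal(mu, 1.0)   # distribution: evaluates to the sampled value
    tr = {:g} ~ noisy(a)         # generative function: evaluates to its trace
    y = tr[:y]                   # so its choices are read off the trace
    b = {:k} ~ perturb(y)        # kernel: evaluates to its return value
    return b
end
```

Note the asymmetry: the generative function call yields a trace (use tr[:y] or get_retval(tr) to get at its contents), while the kernel call yields perturb's return value directly.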

As in Gen, x = {:x} ~ f() can be shortened to x ~ f(). For generative function or kernel calls, the {*} ~ f() syntax can be used to splice the choices made by f into the "top level" of the caller's choicemap.
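A minimal sketch of the shorthand and of splicing, using two hypothetical kernels (two_draws and caller). The comments describe where each choice is addressed, following the rules stated above.

```julia
using Gen, GenTraceKernelDSL

# Hypothetical sub-kernel that makes two choices.
@kernel function two_draws(mu)
    a ~ normal(mu, 1.0)   # shorthand for a = {:a} ~ normal(mu, 1.0)
    b ~ normal(a, 1.0)    # shorthand for b = {:b} ~ normal(a, 1.0)
    return a + b
end

@kernel function caller(mu)
    # Nested call: choices are recorded at :inner => :a and :inner => :b
    s1 = {:inner} ~ two_draws(mu)
    # Spliced call: choices are recorded at the top-level addresses :a and :b
    s2 = {*} ~ two_draws(mu)
    return s1 + s2
end
```

Calling two_draws twice is safe here only because the nested and spliced calls write to different addresses; splicing it twice would put two choices at :a.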

Kernels intended for use as MH proposals should accept a current trace as their first argument, and return a Tuple of: (1) a ChoiceMap of proposed values to update in the trace, and (2) a ChoiceMap specifying a reverse move.
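As a hedged illustration of this return convention, here is a random-walk proposal for a hypothetical one-variable model. The reverse move is the same kernel applied to the proposed trace, so the second choicemap records the step it would have to sample (-dx) to return to the original :x. The two-argument metropolis_hastings(trace, kernel) call is taken from this README; that it returns (new_trace, accepted) like Gen's built-in MH variants is an assumption.

```julia
using Gen, GenTraceKernelDSL

# Hypothetical model: one latent :x and one observed :y.
@gen function model()
    x ~ normal(0.0, 1.0)
    y ~ normal(x, 0.5)
    return x
end

# Random-walk MH proposal written as a kernel.
@kernel function random_walk(trace)
    dx ~ normal(0.0, 0.3)   # forward move samples a step
    new_x = trace[:x] + dx
    # (1) values to update in the trace,
    # (2) choices the reverse move (this same kernel) must make to undo it
    return choicemap(:x => new_x), choicemap(:dx => -dx)
end

# Run n MH steps. Assumes metropolis_hastings(trace, kernel) returns
# (new_trace, accepted), as Gen's other MH variants do.
function run_mh(trace, n)
    for _ in 1:n
        trace, _ = metropolis_hastings(trace, random_walk)
    end
    return trace
end

init_trace, _ = generate(model, (), choicemap(:y => 1.0))
trace = run_mh(init_trace, 100)
```

Wrapping the loop in a function avoids Julia's top-level soft-scope pitfalls when reassigning trace inside a for loop in a script.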

Kernels intended for use as SMC proposals should be written in pairs: a forward kernel and a backward kernel. The forward (backward) kernel should accept a previous (subsequent) model trace as its first argument, and return a Tuple containing: (1) a ChoiceMap specifying a proposed next (previous) model state, and (2) a ChoiceMap of the choices the backward (forward) kernel would need to make to recover the previous (subsequent) model state. See GenSMCP3 for details on using kernel DSL proposals within SMC, and for inter-operation between the kernel DSL and Gen's particle filtering library.
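A minimal sketch of a matching forward/backward pair, assuming a hypothetical model whose only latent choice is :x. The forward kernel shifts :x by dx; the backward kernel shifts it back by dx, so each kernel reports the dx the other would need to sample. The kernel names are hypothetical, and how the pair is handed to GenSMCP3 is not shown here.

```julia
using Gen, GenTraceKernelDSL

# Forward kernel: previous trace -> proposed next state.
@kernel function shift_forward(prev_trace)
    dx ~ normal(0.0, 0.5)
    x_next = prev_trace[:x] + dx
    # (1) proposed next state, (2) backward-kernel choices that recover prev state
    return choicemap(:x => x_next), choicemap(:dx => dx)
end

# Backward kernel: subsequent trace -> proposed previous state.
@kernel function shift_backward(next_trace)
    dx ~ normal(0.0, 0.5)
    x_prev = next_trace[:x] - dx
    # (1) proposed previous state, (2) forward-kernel choices that recover next state
    return choicemap(:x => x_prev), choicemap(:dx => dx)
end
```

The check that matters is the round trip: forward with dx = d takes x to x + d, and backward with the same d takes x + d back to x (and vice versa).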

For example, here is what Gen's example split-merge proposal looks like written in the DSL:

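```julia
using Gen, GenTraceKernelDSL

# merge_mean and split_mean are helper functions assumed to be defined
# elsewhere (as in Gen's split-merge example); they are not defined here.
@kernel function split_merge_proposal(trace)
    if trace[:z]
        # Currently two means, switch to one
        m, u = merge_mean(trace[:m1], trace[:m2])
        # The reverse (split) move must sample this u to recover m1 and m2
        return choicemap(:z => false, :m => m), choicemap(:u => u)
    else
        # Currently one mean, switch to two
        u ~ uniform_continuous(0, 1)
        m1, m2 = split_mean(trace[:m], u)
        # The reverse (merge) move is deterministic, so it needs no choices
        return choicemap(:z => true, :m1 => m1, :m2 => m2), choicemap()
    end
end
```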
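The split-merge example calls merge_mean and split_mean without defining them. For the proposal to be a valid reversible move, the two must be mutual inverses, and merge_mean must return a u in (0, 1) so that the split branch could have sampled it. The following is a hypothetical pair satisfying both properties for means anywhere on the real line, using a logistic map on the difference of the two means; the definitions in Gen's own example may differ.

```julia
# Hypothetical helpers, not necessarily those used in Gen's example.
logistic(d) = 1 / (1 + exp(-d))
logit(u) = log(u / (1 - u))

# Merge two means into their average m, plus an auxiliary u in (0, 1)
# that encodes their difference so the merge can be undone.
function merge_mean(m1, m2)
    m = (m1 + m2) / 2
    u = logistic(m1 - m2)
    return m, u
end

# Inverse of merge_mean: recover (m1, m2) from (m, u).
function split_mean(m, u)
    d = logit(u)
    return m + d / 2, m - d / 2
end
```

A quick round trip such as split_mean(merge_mean(1.5, -0.7)...) should return values within floating-point error of (1.5, -0.7).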

Kernels can be passed to Gen.metropolis_hastings(trace, proposal) for Metropolis-Hastings MCMC, or used with GenSMCP3 in sequential Monte Carlo.

See example.jl for a full example.