
Author: probcomp
Stars: 3
Started: December 2020
Last updated: 3 years ago



GenFlux.jl is a Gen DSL that implements the generative function interface, so Flux.jl models can be used as Gen generative functions.

(full example available here)

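```julia
using Flux, Gen, GenFlux

# glorot_uniform64 is not part of Flux; it is assumed to be a Float64 Glorot
# initializer defined in the full example.
g = @genflux Chain(Conv((5, 5), 1 => 10; init = glorot_uniform64),
                   MaxPool((2, 2)),
                   x -> relu.(x),
                   Conv((5, 5), 10 => 20; init = glorot_uniform64),
                   x -> relu.(x),
                   MaxPool((2, 2)),
                   x -> flatten(x),
                   Dense(320, 50; initW = glorot_uniform64),
                   Dense(50, 10; initW = glorot_uniform64),
                   softmax)  # normalize the 10 outputs into class probabilities
```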
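
The 320 in Dense(320, 50) is the length of the flattened feature map. Assuming single-channel 28×28 inputs (MNIST-sized images; the snippet does not say so explicitly), the arithmetic can be checked with a small plain-Julia sketch. The helpers conv_out, pool_out and flattened_size are hypothetical, and they assume Flux's defaults: no padding and stride 1 for Conv, and a stride equal to the window size for MaxPool.

```julia
# Side length after a "valid" convolution (no padding, stride 1) with a k×k kernel.
conv_out(n::Int, k::Int) = n - k + 1
# Side length after non-overlapping p×p pooling (stride p), discarding any remainder.
pool_out(n::Int, p::Int) = n ÷ p

# Length of the flattened feature vector for a square side×side single-channel input.
# The relu layers are skipped because they do not change the shape.
function flattened_size(side::Int)
    s = conv_out(side, 5)   # Conv((5, 5), 1 => 10):  28 -> 24 for a 28×28 input
    s = pool_out(s, 2)      # MaxPool((2, 2)):        24 -> 12
    s = conv_out(s, 5)      # Conv((5, 5), 10 => 20): 12 -> 8
    s = pool_out(s, 2)      # MaxPool((2, 2)):         8 -> 4
    return s * s * 20       # flatten: 4 × 4 spatial positions × 20 channels
end

println(flattened_size(28))  # 320
```

A different input size would require changing the first Dense layer to match.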

Now you can use g as a modelling component in your probabilistic programs:

```julia
@gen function f(xs::Vector{Float64})
    probs ~ g(xs)
    # One categorical label per column (example) of the network's output
    [{:y => i} ~ categorical(p |> collect) for (i, p) in enumerate(eachcol(probs))]
end
```
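
When every {:y => i} choice is constrained to an observed label, each categorical draw contributes the log probability of that label. Assuming g itself makes no random choices, the weight returned by generate then reduces to the sum in this plain-Julia sketch. label_loglik is a hypothetical helper, and probs is a toy 2-class, 2-example matrix, not real network output.

```julia
# Log-likelihood of integer labels ys under per-example class probabilities.
# probs has one column per example (classes × batch), matching eachcol(probs) in f.
function label_loglik(probs::AbstractMatrix{<:Real}, ys::AbstractVector{<:Integer})
    # Each {:y => i} ~ categorical(probs[:, i]) contributes log(probs[ys[i], i]).
    return sum(log(probs[y, i]) for (i, y) in enumerate(ys))
end

# Toy example: 2 classes, 2 examples (each column sums to 1).
probs = [0.7 0.2;
         0.3 0.8]
ys = [1, 2]
println(label_loglik(probs, ys))  # equals log(0.7) + log(0.8)
```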

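You can then train the parameters of g with gradient-based optimization (ADAM in the loop below) that maximizes the log probability of traces whose labels are constrained to the training data: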

```julia
using Statistics  # for mean

# loader, next_batch, test_x and test_y are defined in the full example (not shown here)
update = ParamUpdate(Flux.ADAM(5e-5, (0.9, 0.999)), g)
for i = 1 : 1500
    # Create trace from data
    (xs, ys) = next_batch(loader, 100)
    constraints = choicemap([(:y => i) => y for (i, y) in enumerate(ys)]...)
    (trace, weight) = generate(f, (xs,), constraints)

    # Increment gradient accumulators
    accumulate_param_gradients!(trace)

    # Perform ADAM update and then reset gradient accumulators
    apply!(update)
    println("i: $i, weight: $weight")
end
test_accuracy = mean(f(test_x) .== test_y)
println("Test set accuracy: $test_accuracy")
# Test set accuracy: 0.9392
```
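
Flux.ADAM(5e-5, (0.9, 0.999)) sets the step size η = 5e-5 and the moment decay rates β = (0.9, 0.999). The sketch below shows the standard Adam update rule for a single parameter vector in plain Julia. It is not Flux's or GenFlux's internal code: it assumes ε = 1e-8, AdamState and adam_step! are hypothetical names, and it is written as a descent step on a loss, which for this model would be the negative of the training objective.

```julia
# Adam optimizer state for one parameter vector (standard update rule, illustrative only).
mutable struct AdamState
    eta::Float64          # step size
    beta1::Float64        # decay rate of the first-moment estimate
    beta2::Float64        # decay rate of the second-moment estimate
    eps::Float64          # small constant for numerical stability (assumed 1e-8)
    m::Vector{Float64}    # running mean of gradients
    v::Vector{Float64}    # running mean of squared gradients
    t::Int                # number of steps taken so far
end

AdamState(n::Int; eta=5e-5, betas=(0.9, 0.999), eps=1e-8) =
    AdamState(eta, betas[1], betas[2], eps, zeros(n), zeros(n), 0)

# One descent step on parameters θ, given the gradient g of a loss to be minimized.
function adam_step!(θ::Vector{Float64}, g::Vector{Float64}, s::AdamState)
    s.t += 1
    @. s.m = s.beta1 * s.m + (1 - s.beta1) * g
    @. s.v = s.beta2 * s.v + (1 - s.beta2) * g^2
    mhat = s.m ./ (1 - s.beta1^s.t)   # bias-corrected first moment
    vhat = s.v ./ (1 - s.beta2^s.t)   # bias-corrected second moment
    @. θ -= s.eta * mhat / (sqrt(vhat) + s.eps)
    return θ
end

# Usage on the toy loss sum(θ .^ 2), whose gradient is 2θ.
θ = [1.0, -2.0]
state = AdamState(length(θ))
adam_step!(θ, 2 .* θ, state)
```

On the first step the bias-corrected moments equal g and g², so each coordinate moves by about η against the sign of its gradient regardless of the gradient's scale, which is one reason a small η such as 5e-5 still gives steady progress.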
