GenFlux.jl

GenFlux.jl is a Gen DSL that implements the generative function interface, allowing Flux.jl models to be used as Gen generative functions.


For example, you can wrap a Flux Chain with the @genflux macro (full example available here):

# Assumed imports; glorot_uniform64 (a Float64 Glorot initializer) is provided by the full example
using Gen, Flux, GenFlux
using Statistics: mean

g = @genflux Chain(Conv((5, 5), 1 => 10; init = glorot_uniform64),
                   MaxPool((2, 2)),
                   x -> relu.(x),
                   Conv((5, 5), 10 => 20; init = glorot_uniform64),
                   x -> relu.(x),
                   MaxPool((2, 2)),
                   x -> flatten(x),
                   Dense(320, 50; initW = glorot_uniform64),
                   Dense(50, 10; initW = glorot_uniform64),
                   softmax)
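
Since @genflux returns an ordinary Gen generative function, g supports the generative function interface directly. A minimal sketch (the 28×28×1 input shape is an assumption inferred from the layer sizes above):

xs = rand(Float64, 28, 28, 1, 1)   # hypothetical dummy batch: one 28×28 grayscale image
trace = simulate(g, (xs,))         # run the wrapped network through the GFI
probs = get_retval(trace)          # 10×1 matrix of class probabilities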

Now you can use g as a modelling component in your probabilistic programs:

@gen function f(xs::Vector{Float64})
    probs ~ g(xs)
    [{:y => i} ~ categorical(p |> collect) for (i, p) in enumerate(eachcol(probs))]
end
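
Each sampled label is traced at the address :y => i, so labels can be read off a trace and, as in the training loop below, constrained to observed data. A minimal sketch, where xs_batch is a hypothetical batch in whatever format the full example feeds to f:

trace = simulate(f, (xs_batch,))   # run the model forward, recording the choices
choices = get_choices(trace)
choices[:y => 1]                   # sampled class label for the first input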

You can then train the parameters of g by stochastic gradient-based optimization of the trace log probability, using Gen's ParamUpdate with a Flux ADAM optimizer:

update = ParamUpdate(Flux.ADAM(5e-5, (0.9, 0.999)), g)
for i = 1 : 1500
    # Generate a trace constrained to the observed labels for this batch
    (xs, ys) = next_batch(loader, 100)
    constraints = choicemap([(:y => i) => y for (i, y) in enumerate(ys)]...)
    (trace, weight) = generate(f, (xs,), constraints)

    # Increment gradient accumulators
    accumulate_param_gradients!(trace)

    # Perform the ADAM update, then reset the gradient accumulators
    apply!(update)
    println("i: $i, weight: $weight")
end
test_accuracy = mean(f(test_x) .== test_y)
println("Test set accuracy: $test_accuracy")
# Test set accuracy: 0.9392
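
Note that f samples each label from the predicted class probabilities, so the accuracy above is stochastic. A hedged alternative sketch that predicts deterministically by taking the argmax of the probabilities from g (calling g directly, like f above, and assuming test_y holds labels in 1:10):

probs = g(test_x)                                     # 10×N matrix of class probabilities
preds = [argmax(collect(p)) for p in eachcol(probs)]  # most probable class per column
println("Argmax accuracy: $(mean(preds .== test_y))")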
