`GenFlux.jl` is a Gen DSL that implements the generative function interface, allowing Flux.jl models to be used as Gen generative functions. For example, the `@genflux` macro turns a Flux `Chain` into a generative function `g`:
# Wrap a Flux convolutional classifier as a Gen generative function `g`.
# Layer sizes assume single-channel 28×28 inputs (e.g. MNIST) — TODO confirm:
# 28 -conv5-> 24 -pool2-> 12 -conv5-> 8 -pool2-> 4, and 4 * 4 * 20 = 320,
# which matches the input width of the first Dense layer.
# `glorot_uniform64` is not defined in this snippet; presumably a Float64
# variant of Flux's `glorot_uniform` — confirm where it is defined.
g = @genflux Chain(Conv((5, 5), 1 => 10; init = glorot_uniform64),   # 1 -> 10 channels
                   MaxPool((2, 2)),
                   x -> relu.(x),
                   Conv((5, 5), 10 => 20; init = glorot_uniform64),  # 10 -> 20 channels
                   x -> relu.(x),
                   MaxPool((2, 2)),
                   x -> flatten(x),                                  # to (features, batch)
                   # NOTE(review): there is no nonlinearity between the two
                   # Dense layers, so together they act as one affine map.
                   # NOTE(review): newer Flux releases renamed Dense's `initW`
                   # keyword to `init` — check against the pinned Flux version.
                   Dense(320, 50; initW = glorot_uniform64),
                   Dense(50, 10; initW = glorot_uniform64),
                   softmax)                                          # 10 class probabilities
Now you can use `g` as a modelling component in your probabilistic programs:
# Classification model: push a batch of inputs through the Flux network `g`
# and sample one categorical label per column of the resulting probabilities.
#
# `xs` is annotated `AbstractArray{Float64}` rather than `Vector{Float64}`.
# The first layer of `g` is a 2-D `Conv`, which expects a 4-D (W, H, C, N)
# batch array, so a `Vector{Float64}` annotation rejected the very inputs the
# network can process. The wider annotation still accepts vectors, so existing
# callers keep working.
#
# Address `:y => i` holds the sampled label for the i-th column of `probs`.
# Returns the vector of sampled labels.
@gen function f(xs::AbstractArray{Float64})
    # Traced at address `:probs`; presumably a (classes × batch) matrix given
    # the final `softmax` in `g` — confirm.
    probs ~ g(xs)
    # `eachcol` yields views; `collect` materializes each column as a Vector
    return [{:y => i} ~ categorical(collect(p)) for (i, p) in enumerate(eachcol(probs))]
end
This lets you train the parameters of `g` by gradient-based optimization (here, ADAM) of the log probability of the observed labels:
# Fit the parameters of `g` with ADAM by repeatedly conditioning `f` on
# labelled mini-batches, then report classification accuracy on held-out data.
update = ParamUpdate(Flux.ADAM(5e-5, (0.9, 0.999)), g)
for i in 1:1500
    # Draw a mini-batch of 100 labelled examples and pin each label at `:y => k`
    xs, ys = next_batch(loader, 100)
    constraints = choicemap()
    for (k, label) in enumerate(ys)
        constraints[:y => k] = label
    end
    trace, weight = generate(f, (xs,), constraints)

    # Add this trace's parameter gradients to the accumulators
    accumulate_param_gradients!(trace)
    # Take one ADAM step, which also resets the gradient accumulators
    apply!(update)
    println("i: $i, weight: $weight")
end

# Fraction of held-out examples whose sampled label matches the true label
test_accuracy = mean(f(test_x) .== test_y)
println("Test set accuracy: $test_accuracy")
# Example output -> Test set accuracy: 0.9392