This package implements the Gumbel-Softmax trick, together with a Rao-Blackwellized variant, in the Julia programming language. It supports both reverse-mode automatic differentiation (AD) through Zygote and forward-mode AD through ForwardDiff.
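For reference, the standard Gumbel-Softmax relaxation of a categorical distribution with unnormalised logits ℓ_1, …, ℓ_C and temperature τ > 0 is written below. We have not checked the package source, so treat the mapping onto its arguments as an assumption: tau plays the role of τ, and hard=true is assumed to return the one-hot vector z in the forward pass while gradients flow through the relaxed sample y (the straight-through estimator).

```latex
u_c \sim \mathrm{Uniform}(0, 1), \qquad g_c = -\log(-\log u_c), \qquad c = 1, \dots, C

y_c = \frac{\exp\big((\ell_c + g_c)/\tau\big)}{\sum_{c'=1}^{C} \exp\big((\ell_{c'} + g_{c'})/\tau\big)}

z = \operatorname{onehot}\Big(\arg\max_{c} \, y_c\Big) = \operatorname{onehot}\Big(\arg\max_{c} \, (\ell_c + g_c)\Big)
```

The index argmax_c(ℓ_c + g_c) is an exact sample from the categorical distribution with probabilities softmax(ℓ) (the Gumbel-max trick), and as τ → 0 the relaxed sample y approaches the one-hot z.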
The easiest way to install the package is directly from the Julia package registry with Pkg:
using Pkg
Pkg.add("GumbelSoftmax")
The expected input shape is (latent_dimension, categorical_dimension, batch_dimension). As an example, let's suppose we have 4 categorical distributions with 3 classes each and we want to sample 10 times. In this case, latent_dimension=4, categorical_dimension=3, and batch_dimension=10.
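To make the layout concrete, here is a from-scratch sketch of Gumbel-Softmax sampling over a (latent_dimension, categorical_dimension, batch_dimension) array using only Julia's standard library. The names softmax_along and gumbel_softmax_sketch are hypothetical and not part of GumbelSoftmax.jl; we assume the classes live on the second axis, so normalisation uses dims=2. The sketch covers the forward pass only and does not implement the straight-through gradient a differentiable hard sample needs.

```julia
using Random

# Hypothetical helper (not part of GumbelSoftmax.jl): numerically stable
# softmax along one axis of an array.
function softmax_along(x::AbstractArray; dims::Int)
    m = maximum(x; dims=dims)   # subtract the max for stability
    e = exp.(x .- m)
    return e ./ sum(e; dims=dims)
end

# Hypothetical sketch of Gumbel-Softmax sampling for logits laid out as
# (latent_dimension, categorical_dimension, batch_dimension).
# Assumption: the categorical axis is dims=2. Forward pass only.
function gumbel_softmax_sketch(rng::AbstractRNG, logits::AbstractArray{<:Real,3};
                               tau::Real=1.0, hard::Bool=false)
    u = rand(rng, size(logits)...)                   # u ~ Uniform(0, 1)
    g = -log.(-log.(u))                              # standard Gumbel noise
    y = softmax_along((logits .+ g) ./ tau; dims=2)  # relaxed (soft) sample
    hard || return y
    # Hard variant: one-hot at the argmax of each categorical slice.
    # A straight-through estimator would also route gradients through y;
    # this plain forward-pass sketch does not model gradients.
    idx = argmax(y; dims=2)
    z = zeros(eltype(y), size(y))
    z[idx] .= 1
    return z
end

rng = MersenneTwister(0)
logits = randn(rng, 4, 3, 10)   # 4 distributions, 3 classes, 10 samples
z_soft = gumbel_softmax_sketch(rng, logits; tau=0.1)
z_hard = gumbel_softmax_sketch(rng, logits; tau=0.1, hard=true)
```

Each slice z_soft[i, :, b] is a probability vector over the 3 classes, and each slice z_hard[i, :, b] is one-hot.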
using GumbelSoftmax, Random
Random.seed!(42)  # make the example reproducible
logits = randn(4, 3, 10)  # 4 distributions, 3 classes, 10 samples
samples = sample_gumbel_softmax(logits=logits, tau=0.1, hard=true)
# or with Rao-Blackwellization
k = 10 # number of Monte-Carlo samples
samples = sample_rao_gumbel_softmax(logits=logits, tau=0.1, k=k)
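The Rao-Blackwellized variant reduces the variance of the straight-through gradient by averaging the relaxed sample over k Gumbel perturbations that are all conditioned on the same hard sample. The sketch below shows one way to draw those conditional perturbations for a single distribution: the maximum is drawn as a Gumbel with location logsumexp(θ), and every other entry as a Gumbel truncated below it. It then applies this slice by slice to the 4 x 3 x 10 layout. The name rao_gumbel_softmax_sketch is hypothetical; we assume, without having checked, that sample_rao_gumbel_softmax follows a similar construction, and only forward values are computed here.

```julia
using Random

# Hypothetical helper (not part of GumbelSoftmax.jl): numerically stable softmax.
function softmax_along(x::AbstractArray; dims::Int)
    m = maximum(x; dims=dims)
    e = exp.(x .- m)
    return e ./ sum(e; dims=dims)
end

# Hypothetical sketch for ONE categorical distribution with logits θ (length C).
# Returns the hard sample index d and the Rao-Blackwellized soft sample.
function rao_gumbel_softmax_sketch(rng::AbstractRNG, θ::AbstractVector{<:Real};
                                   tau::Real=1.0, k::Int=10)
    C = length(θ)
    mθ = maximum(θ)
    logZ = mθ + log(sum(exp.(θ .- mθ)))          # logsumexp(θ)
    d = argmax(θ .- log.(-log.(rand(rng, C))))   # Gumbel-max hard sample
    avg = zeros(C)
    for _ in 1:k
        E = randexp(rng, C)                      # E_i ~ Exponential(1)
        # Perturbed logits conditioned on argmax == d: entry d is
        # Gumbel(logZ)-distributed, every other entry i is Gumbel(θ_i)
        # truncated to lie below entry d.
        g = -log.(E .* exp.(-θ) .+ E[d] * exp(-logZ))
        g[d] = logZ - log(E[d])
        avg .+= softmax_along(g ./ tau; dims=1)
    end
    return d, avg ./ k
end

# Apply the per-distribution sketch to every (latent, batch) slice.
rng = MersenneTwister(1)
logits = randn(rng, 4, 3, 10)   # (latent, categorical, batch)
soft_rb = zeros(size(logits))
hard_rb = zeros(size(logits))
for b in axes(logits, 3), i in axes(logits, 1)
    d, p = rao_gumbel_softmax_sketch(rng, logits[i, :, b]; tau=0.1, k=10)
    soft_rb[i, :, b] = p
    hard_rb[i, d, b] = 1.0
end
```

Every conditional perturbation keeps index d as its maximum, so each averaged slice of soft_rb still has its largest entry at the hard sample's index.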
We include an example that uses the Gumbel-Softmax trick to implement a discrete VAE. It can be found in examples/vae.jl and can be run with:
julia examples/vae.jl
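A discrete VAE with categorical latents usually regularises the encoder's logits with a KL term against a prior. As a sketch, assuming a uniform categorical prior (a common choice, but we have not checked what examples/vae.jl uses), the KL term for logits in the (latent_dimension, categorical_dimension, batch_dimension) layout can be computed as below. The names log_softmax_along and kl_to_uniform are hypothetical.

```julia
using Random

# Hypothetical helper: numerically stable log-softmax along one axis.
function log_softmax_along(x::AbstractArray; dims::Int)
    m = maximum(x; dims=dims)
    return x .- m .- log.(sum(exp.(x .- m); dims=dims))
end

# Hypothetical sketch: KL(q || Uniform(C)) for logits laid out as
# (latent_dimension, categorical_dimension, batch_dimension),
# summed over latent distributions and averaged over the batch.
# Assumption: uniform prior over the C classes.
function kl_to_uniform(logits::AbstractArray{<:Real,3})
    C = size(logits, 2)
    logq = log_softmax_along(logits; dims=2)
    q = exp.(logq)
    # KL = sum_c q_c * (log q_c - log(1/C)) = sum_c q_c * log q_c + log C
    kl = sum(q .* logq; dims=2) .+ log(C)   # shape (latent, 1, batch)
    return sum(kl) / size(logits, 3)
end

rng = MersenneTwister(2)
kl_random = kl_to_uniform(randn(rng, 4, 3, 10))
kl_flat = kl_to_uniform(zeros(4, 3, 10))   # uniform posterior, so KL should be ≈ 0
```

Working from log-probabilities avoids taking log(0) when the softmax of a very confident encoder underflows.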
Here are some results:
[Figure: VAE loss]
[Figure: VAE reconstructions and generated samples]