LaplaceRedux.jl is a library written in pure Julia that can be used for effortless Bayesian Deep Learning through Laplace Approximation (LA). In the development of this package I have drawn inspiration from this Python library and its companion paper (Daxberger et al. 2021).
The stable version of this package can be installed as follows:
using Pkg
Pkg.add("LaplaceRedux.jl")
The development version can be installed like so:
using Pkg
Pkg.add("https://github.com/JuliaTrustworthyAI/LaplaceRedux.jl")
If you are new to Deep Learning in Julia or simply prefer learning through videos, check out this awesome YouTube tutorial by doggo.jl 🐶. You can also find a video of my presentation at JuliaCon 2022 on YouTube.
LaplaceRedux.jl can be used for any neural network trained in Flux.jl. Below we show basic usage examples involving two simple models for a regression and a classification task, respectively.
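For illustration, the examples below assume a pre-trained Flux model nn and training data data. A minimal sketch of what these might look like (the architecture, toy data and training loop are placeholders, not the exact setup from the docs):

using Flux

# Toy regression data (placeholder): inputs as a 1×N matrix, targets as a vector.
X = reshape(collect(Float32, range(-3, 3; length=100)), 1, :)
y = vec(sin.(X)) .+ 0.3f0 .* randn(Float32, 100)
data = zip(eachcol(X), y)                      # iterable of (input, target) pairs

# A small MLP; any Flux model works.
nn = Chain(Dense(1, 10, tanh), Dense(10, 1))

# Standard Flux training loop (sketch).
opt_state = Flux.setup(Adam(0.1), nn)
for epoch in 1:100
    Flux.train!((m, xi, yi) -> Flux.mse(m(xi), yi), nn, data, opt_state)
end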
A complete worked example for a regression model can be found in the docs. Here we jump straight to Laplace Approximation and take the pre-trained model nn as given. LA can then be implemented as follows, where we specify the model likelihood. The plot shows the fitted values overlaid with a 95% confidence interval. As expected, predictive uncertainty quickly increases in areas that are not populated by any training data.
la = Laplace(nn; likelihood=:regression)   # wrap the pre-trained network in a Laplace object
fit!(la, data)                             # fit the Laplace Approximation to the data
optimize_prior!(la)                        # optimise the prior precision
plot(la, X, y; zoom=-5, size=(500,500))    # fitted values with 95% confidence interval
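Once fitted, the posterior predictive can also be queried directly rather than only plotted. A minimal sketch, assuming the exported predict function returns the predictive mean and variance for regression (the exact return format may differ between versions, so check the docs):

X_new = reshape(collect(Float32, range(-5, 5; length=200)), 1, :)
fμ, fvar = predict(la, X_new)                  # predictive mean and variance (assumption)
upper = vec(fμ) .+ 1.96 .* sqrt.(vec(fvar))    # approximate 95% credible band
lower = vec(fμ) .- 1.96 .* sqrt.(vec(fvar))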
Once again we jump straight to LA and refer to the docs for a complete worked example involving binary classification. In this case we need to specify likelihood=:classification. The plot below shows the resulting posterior predictive distributions as contours in the two-dimensional feature space: note how the Plugin Approximation on the left compares to the Laplace Approximation on the right.
la = Laplace(nn; likelihood=:classification)
fit!(la, data)
la_untuned = deepcopy(la) # saving for plotting
optimize_prior!(la; n_steps=100)
# Plot the posterior predictive distribution:
zoom=0
p_plugin = plot(la, X, ys; title="Plugin", link_approx=:plugin, clim=(0,1))
p_untuned = plot(la_untuned, X, ys; title="LA - raw (λ=$(unique(diag(la_untuned.prior.P₀))[1]))", clim=(0,1), zoom=zoom)
p_laplace = plot(la, X, ys; title="LA - tuned (λ=$(round(unique(diag(la.prior.P₀))[1],digits=2)))", clim=(0,1), zoom=zoom)
plot(p_plugin, p_untuned, p_laplace, layout=(1,3), size=(1700,400))
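Beyond plotting, the predicted class probabilities can be queried in the same way. A short sketch, assuming predict accepts the link_approx keyword used in the plotting calls above, with :probit for the Laplace (probit) approximation (an assumption, not shown in the snippet above):

x_new = [0.5f0, 0.5f0]                                   # hypothetical test point in feature space
p_plugin  = predict(la, x_new; link_approx=:plugin)      # plugin estimate
p_laplace = predict(la, x_new; link_approx=:probit)      # Laplace / probit approximation (assumption)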
This project was presented at JuliaCon 2022 in July 2022. See here for details.
Contributions are very much welcome! Please follow the SciML ColPrac guide. You may want to start by having a look at any open issues.
Daxberger, Erik, Agustinus Kristiadi, Alexander Immer, Runa Eschenhagen, Matthias Bauer, and Philipp Hennig. 2021. "Laplace Redux - Effortless Bayesian Deep Learning." Advances in Neural Information Processing Systems 34.