
Subspace Inference package for uncertainty analysis in deep neural networks and neural ordinary differential equations using Julia
Author: efmanu
Last updated 2 years ago. Started in January 2021.

Subspace Inference for Bayesian Deep Learning

This package constructs low-dimensional subspaces of a neural network's parameter space and performs Bayesian inference within those subspaces.

This work is based on the following publication:

Izmailov, P., Maddox, W. J., Kirichenko, P., Garipov, T., Vetrov, D., & Wilson, A. G. (2020, August). Subspace inference for Bayesian deep learning. In Uncertainty in Artificial Intelligence (pp. 1169-1179). PMLR.
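In the cited paper, inference is performed over a low-dimensional coordinate vector z instead of the full weight vector w. With \hat{w} the mean weights (W_swa) and P the projection matrix, the model can be written as below. Treating σ_m as the Gaussian likelihood noise and σ_p as the prior scale is our reading of the keyword descriptions, not something the package documents in detail, and K (the subspace dimension) and f (the network as a function of its flattened weights) are notation introduced here:

```latex
\begin{aligned}
w(z) &= \hat{w} + P z, \qquad \hat{w} \in \mathbb{R}^{D},\; P \in \mathbb{R}^{D \times K},\; z \in \mathbb{R}^{K},\; K \ll D \\
p(z \mid \mathcal{D}) &\propto p(\mathcal{D} \mid \hat{w} + P z)\, p(z) \\
\log p(\mathcal{D} \mid \hat{w} + P z) &= -\frac{1}{2\sigma_m^2} \sum_{i} \lVert y_i - f(x_i;\, \hat{w} + P z) \rVert^2 + \text{const}, \qquad p(z) = \mathcal{N}(0, \sigma_p^2 I)
\end{aligned}
```

Because z has only K entries, sampling from p(z | D) is far cheaper than sampling all D network weights.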

Subspace Inference

To estimate the uncertainty of a machine learning model with the subspace inference method, use:

subspace_inference(model, cost, data, opt; callback =()->(return 0),
	σ_z = 1.0,	σ_m = 1.0, σ_p = 1.0,
	itr =1000, T=25, c=1, M=20, print_freq=1)

Input Arguments

  • model : Machine learning model, e.g. Chain(Dense(10,2)). The model must be created with Flux's Chain.
  • cost : Cost function that takes the model as its first argument, e.g. L(m, x, y) = Flux.Losses.mse(m(x), y)
  • data : Inputs and outputs, e.g. X = rand(10,100); Y = rand(2,100); data = DataLoader(X,Y);
  • opt : Optimizer, e.g. opt = ADAM(0.1)

Keyword Arguments

  • callback : Callback function called during training, e.g. callback() = @show(L(X,Y))
  • σ_z : Standard deviation of the subspace
  • σ_m : Standard deviation of the likelihood model
  • σ_p : Standard deviation of the prior
  • itr : Number of sampling iterations
  • T : Number of steps for subspace calculation, e.g. T = 1
  • c : Moment update frequency, e.g. c = 1
  • M : Maximum number of columns in the deviation matrix, e.g. M = 3
  • print_freq : Loss printing frequency

Output

  • chn : Chain containing the posterior samples
  • lp : Log probabilities of all samples
  • W_swa : Mean (SWA) weights
  • re : Model reconstruction function

Example

using SubspaceInference
using Flux
using Flux: @epochs
using Flux.Data: DataLoader

l_m = 10
l_n = 100
O = 2

X = rand(l_m,l_n) #input
Y = rand(O,l_n) #output 

data =  DataLoader(X,Y, shuffle=true)

m = Chain(Dense(l_m, 20), Dense(20, 20), Dense(20, O)) #model

L(x, y) = Flux.Losses.mse(m(x), y) #cost function

ps = Flux.params(m) #model parameters

opt = ADAM(0.1) #optimizer

callback() = @show(L(X,Y)) #callback function

@epochs 1 Flux.train!(L, ps, data, opt, cb = () -> callback()) #training

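M = 3     # maximum number of columns in the deviation matrix
T = 10    # number of steps for subspace calculation
c = 1     # moment update frequency
itr = 10  # number of sampling iterations
L1(m, x, y) = Flux.Losses.mse(m(x), y) #cost function taking the model as its first argument
chn, lp, W_swa = subspace_inference(m, L1, data, opt, itr = itr, T = T, c = c, M = M)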
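The sampling step can be illustrated with a self-contained random-walk Metropolis-Hastings sketch in plain Julia. It is not the package's implementation: the toy linear model, the Gaussian prior and likelihood, the proposal size `stepsize`, and the helper names `logpost` and `mh_subspace` are assumptions made for illustration. Analogous to chn and lp, it returns the visited subspace coordinates and their log posterior values.

```julia
using LinearAlgebra, Random

# Hypothetical helper: log posterior of subspace coordinates z for a toy linear
# model Y ≈ W * X, where the flattened weights are w = w_swa + P * z.
function logpost(z, w_swa, P, X, Y; σ_m = 1.0, σ_p = 1.0)
    w = w_swa .+ P * z                         # map z back to full weight space
    W = reshape(w, size(Y, 1), size(X, 1))     # toy "network": one weight matrix
    resid = Y .- W * X
    loglik = -sum(abs2, resid) / (2 * σ_m^2)   # Gaussian likelihood with noise σ_m
    logprior = -sum(abs2, z) / (2 * σ_p^2)     # isotropic Gaussian prior on z
    return loglik + logprior
end

# Hypothetical helper: random-walk Metropolis-Hastings in the K-dimensional subspace.
function mh_subspace(w_swa, P, X, Y; itr = 1000, stepsize = 0.1,
                     σ_m = 1.0, σ_p = 1.0, rng = Random.default_rng())
    K = size(P, 2)
    z = zeros(K)
    lp_z = logpost(z, w_swa, P, X, Y; σ_m = σ_m, σ_p = σ_p)
    samples = zeros(K, itr)
    lps = zeros(itr)
    for t in 1:itr
        z_new = z .+ stepsize .* randn(rng, K)           # symmetric Gaussian proposal
        lp_new = logpost(z_new, w_swa, P, X, Y; σ_m = σ_m, σ_p = σ_p)
        if log(rand(rng)) < lp_new - lp_z                # accept with prob min(1, ratio)
            z, lp_z = z_new, lp_new
        end
        samples[:, t] = z                                # store current state
        lps[t] = lp_z
    end
    return samples, lps
end

# Usage on random data: 3 inputs and 2 outputs give D = 6 weights; K = 2.
rng = MersenneTwister(0)
X = randn(rng, 3, 50)
Y = randn(rng, 2, 50)
w_swa = zeros(6)
P = Matrix(qr(randn(rng, 6, 2)).Q)   # toy orthonormal 6×2 projection matrix
samples, lps = mh_subspace(w_swa, P, X, Y; itr = 500, rng = rng)
```

A symmetric proposal keeps the acceptance rule to a simple ratio of posteriors; only K numbers are perturbed per step, which is what makes sampling in the subspace cheap.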

Subspace Construction

To generate only the subspace, without running inference, use the subspace_construction function:

	subspace_construction(model, cost, data, opt;
		callback = ()->(return 0), T = 10, c = 1, M = 3,
		LR_init = 0.01, print_freq = 1)

Input Arguments

  • model : Machine learning model, e.g. Chain(Dense(10,2)). The model must be created with Flux's Chain.
  • cost : Cost function that takes the model as its first argument, e.g. L(m, x, y) = Flux.Losses.mse(m(x), y)
  • data : Inputs and outputs, e.g. X = rand(10,100); Y = rand(2,100); data = DataLoader(X,Y);
  • opt : Optimizer, e.g. opt = ADAM(0.1)

Keyword Arguments

  • callback : Callback function called during training, e.g. callback() = @show(L(X,Y))
  • T : Number of steps for subspace calculation, e.g. T = 1
  • c : Moment update frequency, e.g. c = 1
  • M : Maximum number of columns in the deviation matrix, e.g. M = 2
  • LR_init : Initial learning rate of the cyclic learning-rate schedule
  • print_freq : Loss printing frequency
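The exact cyclic schedule that starts from LR_init is not documented here. The sketch below assumes a common SWA-style choice: within each cycle of c steps the rate decays linearly from LR_init towards a hypothetical floor `LR_min`, reaches it on the last step of the cycle, and then restarts. `cyclic_lr` and `LR_min` are illustrative names, not part of the package.

```julia
# Hypothetical SWA-style cyclic schedule (not necessarily the package's):
# the learning rate decays linearly within each cycle of `c` steps,
# reaching LR_min on the cycle's last step, then restarts near LR_init.
function cyclic_lr(i::Integer; LR_init = 0.01, LR_min = 0.001, c::Integer = 5)
    t = (mod(i - 1, c) + 1) / c      # fraction of the current cycle completed, in (0, 1]
    return (1 - t) * LR_init + t * LR_min
end

# Learning rates for two full cycles (steps 1..10) with c = 5.
lrs = [cyclic_lr(i; LR_init = 0.01, LR_min = 0.001, c = 5) for i in 1:10]
```

Taking a snapshot at the low point of each cycle is the usual reason for pairing such a schedule with weight averaging.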

Output

  • W_swa : Mean weights
  • P : Projection Matrix
  • re : Model reconstruction function
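Following the procedure described in the cited paper, the construction can be sketched in plain Julia: snapshots of the flattened weights are averaged to give the mean (W_swa), the deviations of the last M snapshots from that mean form the columns of a deviation matrix, and its leading left singular vectors give P. The snapshots below are synthetic stand-ins for SGD iterates, a plain SVD is used for simplicity, and `build_subspace` and its `K` argument (number of directions kept) are hypothetical; the package may differ in these details.

```julia
using LinearAlgebra, Statistics, Random

# Hypothetical sketch of PCA-subspace construction (not the package's code).
# `snapshots` holds flattened weight vectors, e.g. one collected every c steps.
function build_subspace(snapshots::Vector{Vector{Float64}}; M::Int = 3, K::Int = 2)
    W = reduce(hcat, snapshots)              # D × T matrix, one snapshot per column
    w_swa = vec(mean(W; dims = 2))           # mean weights over all snapshots
    first_col = max(1, size(W, 2) - M + 1)
    A = W[:, first_col:end] .- w_swa         # deviation matrix from the last M snapshots
    U = svd(A).U                             # left singular vectors, largest singular value first
    P = U[:, 1:min(K, size(U, 2))]           # projection matrix: top-K directions
    return w_swa, P
end

rng = MersenneTwister(1)
# Synthetic stand-ins for SGD iterates: D = 6 weights, 10 snapshots around ones(6).
snapshots = [ones(6) .+ 0.1 .* randn(rng, 6) for _ in 1:10]
w_swa, P = build_subspace(snapshots; M = 3, K = 2)
```

Keeping only the last M deviations bounds memory at D × M numbers, which is why M appears as a keyword argument.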

Example

using SubspaceInference
using Flux
using Flux: @epochs
using Flux.Data: DataLoader

l_m = 10
l_n = 100
O = 2

X = rand(l_m,l_n) #input
Y = rand(O,l_n) #output 

data =  DataLoader(X,Y, shuffle=true)

m = Chain(Dense(l_m, 20), Dense(20, 20), Dense(20, O)) #model

L(x, y) = Flux.Losses.mse(m(x), y) #cost function

ps = Flux.params(m) #model parameters

opt = ADAM(0.1) #optimizer

callback() = @show(L(X,Y)) #callback function

@epochs 1 Flux.train!(L, ps, data, opt, cb = () -> callback()) #training

M = 3
T = 10
c = 1
L1(m, x, y) = Flux.Losses.mse(m(x), y) #cost function taking the model as its first argument
W_swa, P = subspace_construction(m, L1, data, opt, T = T, c = c, M = M)
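
The mean weights W_swa and projection matrix P can then be passed to SubspaceInference.inference to draw samples in the subspace:
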
chn, lp = SubspaceInference.inference(m, data, W_swa, P; σ_z = 1.0,
	σ_m = 1.0, σ_p = 1.0, itr=100, M = 3, alg = :mh)
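To turn subspace samples into predictions with error bars, each sample is mapped back to weight space and the resulting predictions are averaged. A minimal plain-Julia sketch, assuming samples are stored as columns of subspace coordinates z (not necessarily how chn is laid out) and using a toy linear model y = W x in place of the reconstructed network; `predictive_stats` is a hypothetical helper:

```julia
using Statistics, Random

# Hypothetical helper: map each subspace sample z to weights w = w_swa + P * z,
# predict with a toy linear model y = W * x, and summarize across samples.
function predictive_stats(Z, w_swa, P, x; n_out::Int)
    preds = map(eachcol(Z)) do z
        W = reshape(w_swa .+ P * z, n_out, length(x))
        W * x
    end
    Yhat = reduce(hcat, preds)                       # n_out × S matrix of predictions
    return vec(mean(Yhat; dims = 2)), vec(std(Yhat; dims = 2))
end

rng = MersenneTwister(2)
w_swa = randn(rng, 6)        # flattened 2×3 weight matrix
P = randn(rng, 6, 2)         # toy 2-dimensional subspace
Z = randn(rng, 2, 200)       # 200 stand-in samples of z
x = randn(rng, 3)
μ, σ = predictive_stats(Z, w_swa, P, x; n_out = 2)
```

The per-output standard deviation σ is the uncertainty estimate; with a real network, the reshape step would be replaced by the model reconstruction function.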