ContinuousNormalizingFlows.jl

Implementations of Infinitesimal Continuous Normalizing Flows Algorithms in Julia
Author: impICNF
Started in: November 2021

Citing

See CITATION.bib for the relevant reference(s).

Installation

using Pkg
Pkg.add("ContinuousNormalizingFlows")

Usage

# Enable Logging
using Logging, TerminalLoggers
global_logger(TerminalLogger())

# Parameters
nvars = 1
naugs = nvars
# n_in = nvars # without augmentation
n_in = nvars + naugs # with augmentation
n = 1024

# Model
using ContinuousNormalizingFlows, Lux, ADTypes, Enzyme #, CUDA, ComputationalResources
nn = Chain(Dense(n_in => 3 * n_in, tanh), Dense(3 * n_in => n_in, tanh))
# icnf = construct(RNODE, nn, nvars) # use defaults
icnf = construct(
    RNODE,
    nn,
    nvars, # number of variables
    naugs; # number of augmented dimensions
    compute_mode = DIJacVecMatrixMode(AutoEnzyme(; function_annotation = Enzyme.Const)), # process data in batches
    tspan = (0.0f0, 13.0f0), # use a longer time span
    steer_rate = 1.0f-1, # add random noise to the end of the time span
    # resource = CUDALibs(), # process data on the GPU
    # inplace = true, # use the in-place versions of functions
)

# Data
using Distributions
data_dist = Beta{Float32}(2.0f0, 4.0f0)
r = rand(data_dist, nvars, n)
r = convert.(Float32, r)

# Fit It
using DataFrames, MLJBase #, ForwardDiff, ADTypes, OptimizationOptimisers
df = DataFrame(transpose(r), :auto)
# model = ICNFModel(icnf) # use defaults
model = ICNFModel(
    icnf;
    batch_size = 256, # use bigger batches
    # n_epochs = 100, # use fewer epochs
    # optimizers = (Adam(),), # use a different optimizer
    # adtype = AutoForwardDiff(), # use ForwardDiff
)
mach = machine(model, df)
fit!(mach)
ps, st = fitted_params(mach)

# Store It
using JLD2, UnPack
jldsave("fitted.jld2"; ps, st) # save
@unpack ps, st = load("fitted.jld2") # load

# Use It
d = ICNFDist(icnf, TestMode(), ps, st) # direct way
# d = ICNFDist(mach, TestMode()) # alternative way
actual_pdf = pdf.(data_dist, vec(r))
estimated_pdf = pdf(d, r)
new_data = rand(d, n)

# Evaluate It
using Distances
mad_ = meanad(estimated_pdf, actual_pdf)
msd_ = msd(estimated_pdf, actual_pdf)
tv_dis = totalvariation(estimated_pdf, actual_pdf) / n
res_df = DataFrame(; mad_, msd_, tv_dis)
display(res_df)

# Plot It
using CairoMakie
f = Figure()
ax = Makie.Axis(f[1, 1]; title = "Result")
lines!(ax, 0.0f0 .. 1.0f0, x -> pdf(data_dist, x); label = "actual")
lines!(ax, 0.0f0 .. 1.0f0, x -> pdf(d, vcat(x)); label = "estimated")
axislegend(ax)
save("result-fig.svg", f)
save("result-fig.png", f)
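The `pdf(d, r)` call in the script above returns densities from the fitted flow. Continuous normalizing flows in general compute a log-density by integrating an ODE dz/dt = f(z, t) together with the trace of its Jacobian. They use the instantaneous change-of-variables rule log p_x(x) = log p_z(z(0)) - ∫ tr(∂f/∂z) dt.

The sketch below checks that rule under several assumptions:
- a hypothetical linear field f(z) = A z, whose Jacobian trace is exactly tr(A);
- a fixed-step Euler solver;
- a standard normal base density.

It is a generic illustration, not the RNODE implementation used by this package. `A`, `lognormal`, `cnf_logpdf`, and `exact_logpdf` are made-up names.

```julia
using LinearAlgebra: tr

# Hypothetical linear vector field f(z) = A * z; its Jacobian is A everywhere.
const A = [-0.5 0.2; 0.0 -0.3]

# Log-density of a standard normal base distribution in length(z) dimensions.
lognormal(z) = -0.5 * sum(abs2, z) - length(z) / 2 * log(2π)

# Map a data point x (at time T) back to its base sample z (at time 0) with
# explicit Euler steps, accumulating the integral of tr(∂f/∂z) = tr(A).
function cnf_logpdf(x; T = 1.0, steps = 10_000)
    dt = T / steps
    z = copy(x)
    trace_integral = 0.0
    for _ in 1:steps
        z = z - dt * (A * z)          # backward step: z(t - dt) ≈ z(t) - dt * f(z(t))
        trace_integral += dt * tr(A)  # running integral of the Jacobian trace
    end
    # Instantaneous change of variables: log p_x(x) = log p_z(z(0)) - ∫ tr(∂f/∂z) dt
    return lognormal(z) - trace_integral
end

# Closed form for a linear field: x = exp(A T) z0, so
# log p_x(x) = log p_z(exp(-A T) x) - T * tr(A).
exact_logpdf(x; T = 1.0) = lognormal(exp(-A * T) * x) - T * tr(A)

x = [0.3, -0.7]
println(cnf_logpdf(x), " ≈ ", exact_logpdf(x))
```

A linear field has the closed-form flow x = exp(A T) z0, which is why `exact_logpdf` can serve as a reference. A neural vector field such as `nn` has no closed form, so its ODE must be solved numerically.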
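The "Evaluate It" step compares the estimated and actual density vectors with Distances.jl. As a rough guide to what those numbers mean, the sketch below recomputes them with Base Julia and the Statistics standard library. It assumes Distances.jl's documented definitions:
- `meanad` is the mean absolute difference.
- `msd` is the mean squared difference.
- `totalvariation` is half the sum of absolute differences, so the script's division by `n` gives a per-sample value.

The helper names and input vectors are hypothetical toy values, not output from the model.

```julia
using Statistics: mean

# Hypothetical helpers mirroring Distances.jl's documented definitions.
mean_abs_dev(a, b) = mean(abs.(a .- b))        # like Distances.meanad
mean_sq_dev(a, b) = mean(abs2.(a .- b))        # like Distances.msd
total_variation(a, b) = sum(abs.(a .- b)) / 2  # like Distances.totalvariation

# Toy density values (not real model output).
estimated = [0.5, 1.0, 2.0, 1.5]
actual = [0.5, 1.5, 1.0, 1.5]

n = length(actual)
mad_ = mean_abs_dev(estimated, actual)            # (0 + 0.5 + 1 + 0) / 4 = 0.375
msd_ = mean_sq_dev(estimated, actual)             # (0 + 0.25 + 1 + 0) / 4 = 0.3125
tv_dis = total_variation(estimated, actual) / n   # (1.5 / 2) / 4 = 0.1875
println((; mad_, msd_, tv_dis))
```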