This package provides high-performance, GPU- and AD-friendly monotonic spline functions in Julia for use in normalizing flows and, more generally, in parameter transformations.
This package currently includes the monotonic rational quadratic splines defined in "Neural Spline Flows, Durkan et al. 2019".
Please see the Documentation linked below for details.
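using MonotonicSplines, Plots, InverseFunctions, ChangesOfVariables
f = rand(RQSpline)                        # draw a random rational quadratic spline
f.pX, f.pY, f.dYdX                        # knot x-positions, knot y-positions, derivatives at the knots
plot(f, xlims = (-6, 6)); plot!(inverse(f), xlims = (-6, 6))   # plot f and its inverse
x = 1.2
y = f(x)                                  # forward transformation
with_logabsdet_jacobian(f, x)             # returns (y, log|df/dx|)
inverse(f)(y)                             # inverse transformation, recovers x
with_logabsdet_jacobian(inverse(f), y)    # returns (x, log|d(f^-1)/dy|)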
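As a quick sanity check on the example above, the following sketch verifies numerically that `inverse(f)` undoes `f` and that the forward and inverse log-abs-det-Jacobian values cancel. It uses only the calls shown above; the variable names `x_back`, `ladj_fwd`, and `ladj_inv` are illustrative, and both residuals are expected to vanish only up to floating-point error.

```julia
using MonotonicSplines, InverseFunctions, ChangesOfVariables

f = rand(RQSpline)      # random spline, as in the example above
x = 1.2

# Forward pass: transformed value and log|df/dx| at x
y, ladj_fwd = with_logabsdet_jacobian(f, x)

# Inverse pass: recovered input and log|d(f^-1)/dy| at y
x_back, ladj_inv = with_logabsdet_jacobian(inverse(f), y)

# Both residuals should be zero up to floating-point error
@show x_back - x
@show ladj_fwd + ladj_inv
```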
Given a set
Here
Consider a single sample
In the context of Normalizing Flows, the set of parameters
This neural net takes the first
These "raw" spline parameters are then processed as described in "Neural Spline Flows, Durkan et al. 2019" to obtain
MonotonicSplines.jl is designed with parallelism in mind: it transforms batches of samples simultaneously using spline functions.
To this end, the parameters that characterize a whole set of spline functions are stored together in a single struct.
Now consider the task of transforming the entire set of samples
To achieve this, each of the
So the spline function
Given the output `params_raw` of the neural net, obtain the parameters `pX`, `pY`, and `dYdX` that characterize the desired spline functions as follows:
julia> pX, pY, dYdX = rqs_params_from_nn(params_raw, n_dims_to_transform)
Here, `params_raw` is a `3(K-1) * n_dims_to_transform x n_samples` matrix. `K` is again the number of spline segments, and `n_dims_to_transform` is the number of components of each sample that are transformed. The `i`-th column of `params_raw` is the output of the neural net for the `i`-th sample.
`pX`, `pY`, and `dYdX` are each `K x n_dims_to_transform x n_samples` arrays. The `[:,j,i]` entries hold the parameters that characterize the spline function for the `j`-th transformed component of the `i`-th sample in the sample set.
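The `[:,j,i]` layout described above can be illustrated with plain Julia arrays. This is a minimal sketch that does not call the package: `p` is a stand-in for any one of `pX`, `pY`, or `dYdX`, and the sizes are arbitrary illustrative values.

```julia
K, n_dims_to_transform, n_samples = 4, 2, 3   # illustrative sizes

# Stand-in parameter array, filled with a running index so slices are easy to inspect
p = reshape(collect(1.0:(K * n_dims_to_transform * n_samples)),
            K, n_dims_to_transform, n_samples)

j, i = 2, 3
params_ji = view(p, :, j, i)   # the K parameters for component j of sample i

@show size(p)
@show length(params_ji)
```

Because Julia arrays are stored in column-major order, the `K` entries of each `[:, j, i]` slice are contiguous in memory.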
We then define the set of spline functions by:
julia> rqs_splines = RQSpline(pX, pY, dYdX)
This creates an object holding the parameters that characterize `n_dims_to_transform x n_samples` spline functions.
To apply the spline functions characterized by the parameters stored in `rqs_splines`, first isolate the components of the sample set that are to be transformed, then call:
julia> Y_partial = rqs_splines(X_partial)
`X_partial` is an `n_dims_to_transform x n_samples` matrix holding the components of the sample set that are to be transformed. The `i`-th column of `X_partial` holds the `d`-th to `D`-th elements of the `i`-th sample in the sample set.
`Y_partial` is an `n_dims_to_transform x n_samples` matrix whose `i,j`-th entry is the transformed value of the `i,j`-th entry of `X_partial`.
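The step of isolating the components to be transformed can be sketched in plain Julia as follows. This is only an illustration under stated assumptions: `stand_in_transform` is a hypothetical placeholder for the call `rqs_splines(X_partial)`, the sizes `D`, `d`, and `n_samples` are arbitrary, and reassembling the result with `vcat` is one possible way to use the output rather than something the package does for you.

```julia
D, n_samples = 5, 4      # sample dimension and number of samples (illustrative)
d = 3                    # index of the first component to transform (illustrative)

X = randn(D, n_samples)  # one sample per column

X_partial = X[d:D, :]                      # components d..D of every sample
n_dims_to_transform = size(X_partial, 1)   # equals D - d + 1

# Hypothetical stand-in for `rqs_splines(X_partial)`: an elementwise monotonic map
stand_in_transform(A) = 2 .* A .+ 1
Y_partial = stand_in_transform(X_partial)

# Reassemble: untouched components first, transformed components after them
Y = vcat(X[1:d-1, :], Y_partial)
```

`Y_partial` has the same shape as `X_partial`, so `Y` has the same shape as `X`.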
For further details on the implementation, see the documentation for the stable version.