🚜 Documentation is still WIP. Feel free to open a PR if you have any questions. 🚧
`TDAOpt.jl` is a Julia package for optimizing statistical and topological loss functions defined on point clouds and on functions over fixed grids.
From the Julia REPL, type `]` to enter the Pkg REPL mode and run:
pkg> add https://github.com/sidv23/TDAOpt.jl
There are 3 (and a half) main functions exported by the current API:
- `dist()`: computes a distance between two point clouds/measures
- `val()`: evaluates a loss functional on a source point cloud/measure
- `backprop()`: minimizes a loss functional starting from a source point cloud/measure
- `∇w`: computes the Wasserstein gradient of a loss functional (used in `backprop` when specified)
These functions belong to the following parts of the current API:
A discrepancy is a `struct` which configures the parameters for measuring a distance between two input matrices (`x` and `y`) or two input measures (`μ` and `ν`). The ones currently implemented are:
- Statistical (MMD, Sinkhorn)
- Topological (Wasserstein Matching)
Every discrepancy should extend (i.e., add a method to) the function `dist`.
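As an illustration of the discrepancy interface, here is a minimal, self-contained sketch of an MMD-style discrepancy with a Gaussian kernel. The names `AbstractDiscrepancy` and `GaussianMMD`, the bandwidth field `σ`, and the convention that each row of a matrix is one point are assumptions made for this example, not TDAOpt.jl's actual definitions.

```julia
using Statistics  # for mean

# Hypothetical supertype for this sketch (not TDAOpt.jl's own type)
abstract type AbstractDiscrepancy end

# Hypothetical discrepancy: MMD with a Gaussian kernel of bandwidth σ
struct GaussianMMD <: AbstractDiscrepancy
    σ::Float64
end

# Kernel matrix K[i, j] = exp(-|a_i - b_j|^2 / (2σ^2)), where a_i and b_j
# are rows of a and b (assumed layout: one point per row)
function gauss_kernel(a::AbstractMatrix, b::AbstractMatrix, σ::Real)
    sq = [sum(abs2, a[i, :] .- b[j, :]) for i in axes(a, 1), j in axes(b, 1)]
    return exp.(-sq ./ (2 * σ^2))
end

# Extending `dist` for this discrepancy: biased (V-statistic) MMD estimate
function dist(d::GaussianMMD, x::AbstractMatrix, y::AbstractMatrix)
    mmd2 = mean(gauss_kernel(x, x, d.σ)) + mean(gauss_kernel(y, y, d.σ)) -
           2 * mean(gauss_kernel(x, y, d.σ))
    return sqrt(max(mmd2, 0.0))  # clamp tiny negative round-off before sqrt
end

# Usage
x = randn(50, 2)
y = randn(60, 2) .+ 3.0
d = GaussianMMD(1.0)
dist(d, x, y)   # clearly positive, since y is shifted away from x
dist(d, x, x)   # zero up to round-off
```

Because the discrepancy is just a struct holding its parameters, adding another one (e.g. a Sinkhorn divergence) amounts to defining a new struct and one more `dist` method for it.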
Losses are `struct`s that define how to compute a loss functional. The abstract supertype is `AbstractLossFunction`. Every loss is expected to extend the function `val`, which takes a `source` (i.e., a matrix `x` or a measure `μ`) and an `AbstractLossFunction`, and evaluates the loss functional.
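The loss interface can be sketched in the same self-contained way. Below, a hypothetical `MeanDiscrepancy` (the Euclidean distance between the centroids of two point clouds, one point per row) stands in for a real discrepancy so that only the standard library is needed. `AbstractLossFunction`, `StatLoss`, and `val` mirror the names used in this section, but their definitions here are assumptions, not the package's source.

```julia
using Statistics  # for mean

# Hypothetical stand-in discrepancy: distance between centroids (rows = points)
struct MeanDiscrepancy end
dist(::MeanDiscrepancy, x::AbstractMatrix, y::AbstractMatrix) =
    sqrt(sum(abs2, mean(x; dims=1) .- mean(y; dims=1)))

# Abstract supertype for losses, mirroring the name used in this README
abstract type AbstractLossFunction end

# A StatLoss-like struct: fixes a discrepancy `d` and a `target`
struct StatLoss{D,T} <: AbstractLossFunction
    d::D
    target::T
end

# Each loss extends `val(source, loss)`; here it is dist(d, source, target)
val(source, loss::StatLoss) = dist(loss.d, source, loss.target)

# Usage: the loss is zero when the source and target centroids coincide
target = [0.0 0.0; 2.0 2.0]        # centroid (1, 1)
loss = StatLoss(MeanDiscrepancy(), target)
val([1.0 1.0; 1.0 1.0], loss)      # 0.0
val([4.0 5.0; 4.0 5.0], loss)      # 5.0, the distance from (4, 5) to (1, 1)
```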
Currently, the implemented `AbstractLossFunction`s fall into the following categories:
- `StatLoss`: fixes a reference discrepancy `d` and a `target` (matrix or measure), and evaluates
  `val(source, Loss) = dist(Loss.d, source, Loss.target)`
- `TopLoss`: fixes a reference discrepancy `d`, a persistence diagram constructor `dgmFun`, and a `target` (persistence diagram), and evaluates
  `val(source, Loss) = dist(Loss.d, Loss.dgmFun(source), Loss.target)`
- `BarycenterStatLoss`: fixes a reference discrepancy `d` along with the `targets` $\{\nu_1, \nu_2, \dots, \nu_M\}$ and the `weights` $\{\lambda_1, \lambda_2, \dots, \lambda_M\}$, and evaluates
  `val(source, Loss)` $= \sum_{i=1}^{M}$ `Loss.weights[i] * dist(Loss.d, source, Loss.targets[i])^2`
- `BarycenterTopLoss`: fixes a reference discrepancy `d` along with `dgmFun`, the precomputed persistence diagram `targets` $\{D_1, D_2, \dots, D_M\}$, and the `weights` $\{\lambda_1, \lambda_2, \dots, \lambda_M\}$, and evaluates
  `val(source, Loss)` $= \sum_{i=1}^{M}$ `Loss.weights[i] * dist(Loss.d, Loss.dgmFun(source), Loss.targets[i])^2`
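The barycenter losses apply the same pattern to several weighted targets. Here is a minimal sketch of a `BarycenterStatLoss`-like struct. To stay self-contained it uses a hypothetical centroid-distance discrepancy `MeanDiscrepancy` (one point per row); the field names follow the list above, but the code is illustrative rather than the package's implementation.

```julia
using Statistics  # for mean

# Hypothetical stand-in discrepancy: distance between centroids (rows = points)
struct MeanDiscrepancy end
dist(::MeanDiscrepancy, x::AbstractMatrix, y::AbstractMatrix) =
    sqrt(sum(abs2, mean(x; dims=1) .- mean(y; dims=1)))

abstract type AbstractLossFunction end

# Fixes a discrepancy `d`, M targets ν_1, …, ν_M, and M weights λ_1, …, λ_M
struct BarycenterStatLoss{D,T} <: AbstractLossFunction
    d::D
    targets::Vector{T}
    weights::Vector{Float64}
end

# val(source) = Σ_i λ_i * dist(d, source, ν_i)^2
function val(source, loss::BarycenterStatLoss)
    length(loss.targets) == length(loss.weights) ||
        throw(ArgumentError("expected one weight per target"))
    return sum(w * dist(loss.d, source, ν)^2 for (ν, w) in zip(loss.targets, loss.weights))
end

# Usage: two targets with centroids (0, 0) and (2, 0), equal weights
ν1 = [0.0 0.0; 0.0 0.0]
ν2 = [2.0 0.0; 2.0 0.0]
loss = BarycenterStatLoss(MeanDiscrepancy(), [ν1, ν2], [0.5, 0.5])
val([1.0 0.0; 1.0 0.0], loss)   # 0.5 * 1^2 + 0.5 * 1^2 = 1.0
val([0.0 0.0; 0.0 0.0], loss)   # 0.5 * 0^2 + 0.5 * 2^2 = 2.0
```

A `BarycenterTopLoss` would follow the same shape, with `dist(loss.d, loss.dgmFun(source), D)` inside the sum.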
An `AbstractBackprop` object configures the parameters for performing backpropagation on a specified `AbstractLossFunction`. Each concrete `AbstractBackprop` type dispatches to its own method of the function `backprop`.
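To make the `backprop` idea concrete, here is a minimal sketch of plain gradient descent on the points of a source matrix. The `GradientDescent` struct, its fields `η` and `iters`, and the use of central finite differences in place of automatic differentiation are all assumptions for this illustration; the package's `backprop` methods may work differently. For simplicity, the loss is passed as an ordinary Julia function of the source matrix rather than as an `AbstractLossFunction`.

```julia
# Hypothetical supertype and configuration struct for this sketch
abstract type AbstractBackprop end

struct GradientDescent <: AbstractBackprop
    η::Float64    # step size
    iters::Int    # number of gradient steps
end

# Central finite-difference gradient of a scalar function f at matrix x
function fd_grad(f, x::AbstractMatrix; h = 1e-6)
    g = similar(x, Float64)
    for k in eachindex(x)
        xp = copy(x); xp[k] += h
        xm = copy(x); xm[k] -= h
        g[k] = (f(xp) - f(xm)) / (2h)
    end
    return g
end

# Dispatches on the backprop configuration: repeated Euclidean gradient steps
function backprop(b::GradientDescent, x0::AbstractMatrix, loss)
    x = float(copy(x0))   # work on a copy so the caller's matrix is untouched
    for _ in 1:b.iters
        x .-= b.η .* fd_grad(loss, x)
    end
    return x
end

# Usage: pull a point cloud's centroid (rows = points) toward (1, 1)
loss(x) = sum(abs2, vec(sum(x; dims=1) ./ size(x, 1)) .- [1.0, 1.0])
x = backprop(GradientDescent(0.5, 200), randn(20, 2), loss)
```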
The main difference between an `AbstractBackprop` and an `AbstractGradflow` is how the gradients are computed: `AbstractBackprop` methods compute the usual gradients, while `AbstractGradflow` methods compute the Wasserstein gradient (or use the JKO method).
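The relationship between the two kinds of gradient can be sketched for an empirical measure $\mu_n = \frac{1}{n}\sum_{i=1}^{n} \delta_{x_i}$ with uniform weights: differentiating $F(\mu_n)$ with respect to a particle position gives $\nabla_{x_i} F(\mu_n) = \frac{1}{n} \nabla_W F(\mu_n)(x_i)$, so the Wasserstein gradient at each particle is $n$ times its Euclidean gradient. The sketch below assumes uniform weights and the potential energy $F(\mu) = \int \tfrac{1}{2}|x|^2 \, d\mu$, whose Wasserstein gradient is $\nabla_W F(\mu)(x) = x$. The helpers `euclid_grad`, `wasserstein_grad`, and `flow_step` are hypothetical and are not the package's `∇w`.

```julia
# Potential energy F(μ_n) = (1/n) Σ_i |x_i|^2 / 2 for an empirical measure
# with uniform weights (rows of x are the particles x_i)
F(x::AbstractMatrix) = sum(abs2, x) / (2 * size(x, 1))

# Euclidean gradient of F with respect to the particle positions:
# ∂F/∂x_i = x_i / n  (computed analytically for this F)
euclid_grad(x::AbstractMatrix) = x ./ size(x, 1)

# Hypothetical helper: Wasserstein gradient evaluated at the particles,
# obtained by rescaling the Euclidean gradient by n (uniform weights assumed)
wasserstein_grad(x::AbstractMatrix) = size(x, 1) .* euclid_grad(x)

# One explicit Wasserstein gradient-flow step moves each particle along -∇_W F
flow_step(x::AbstractMatrix, η::Real) = x .- η .* wasserstein_grad(x)

x = randn(100, 2)
wasserstein_grad(x) ≈ x          # true: ∇_W F(μ)(x) = x for this F
flow_step(x, 0.1) ≈ 0.9 .* x     # true: particles contract toward the origin
```

The rescaling is why the distinction matters in practice: with a fixed step size, a plain Euclidean step moves each particle by an amount that shrinks like $1/n$, while the Wasserstein step does not depend on the number of particles. The JKO scheme, which replaces the explicit step with a proximal step in Wasserstein distance, is not covered by this sketch.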