🚜 Documentation is still WIP. Feel free to open a PR if you have any questions. 🚧
TDAOpt.jl is a Julia package for optimizing statistical and topological loss functions defined on point-clouds, as well as on functions defined on fixed grids.
From the Julia REPL, type `]` to enter the Pkg REPL mode and run

```
pkg> add https://github.com/sidv23/TDAOpt.jl
```

There are 3 (and a half) main functions exported by the current API.
- `dist()`: computes a distance between two point-clouds/measures
- `val()`: evaluates a loss functional on a source point-cloud/measure
- `backprop()`: performs minimization of a loss functional starting from a source point-cloud/measure
- `∇w`: computes the Wasserstein gradient of a loss functional (used in `backprop` when specified)
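As a rough orientation, a hypothetical end-to-end workflow might look like the sketch below; every constructor and call signature in it is an assumption for illustration, not the documented API.

```julia
# Hypothetical workflow; every constructor and call signature below is an
# assumption used for illustration, not the documented TDAOpt.jl API.
using TDAOpt

x = randn(2, 200)          # source point-cloud (2 × n matrix)
y = randn(2, 200) .+ 2.0   # target point-cloud

d = Sinkhorn()             # a statistical discrepancy (constructor assumed)
dist(d, x, y)              # distance between the two point-clouds

loss = StatLoss(d, y)      # fix the discrepancy and the target (signature assumed)
val(x, loss)               # evaluate the loss functional at the source

x̂ = backprop(x, loss)      # minimize the loss starting from x (signature assumed)
```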
These functions belong to the following parts of the current API:
A discrepancy is a struct which configures the parameters for measuring a distance between two input matrices (`x` and `y`) or two input measures (`μ` and `ν`). The discrepancies currently implemented are:
- Statistical (MMD, Sinkhorn)
- Topological (Wasserstein Matching)
Every discrepancy should dispatch on (i.e. extend) the function `dist`.
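As an illustration of this dispatch mechanism, a user-defined discrepancy could extend `dist` roughly as in the sketch below; the struct, its field, and the method signature are illustrative assumptions, not part of the package.

```julia
# Minimal sketch, not part of TDAOpt.jl: a Gaussian-kernel MMD discrepancy that
# extends `dist` for two point-clouds given as d × n matrices.
import TDAOpt: dist

struct GaussianMMD
    σ::Float64   # kernel bandwidth (illustrative field)
end

function dist(D::GaussianMMD, x::AbstractMatrix, y::AbstractMatrix)
    # Gaussian kernel on columns (points) of the input matrices
    k(a, b) = exp(-sum(abs2, a .- b) / (2 * D.σ^2))
    # average kernel value over all pairs of columns
    avg(A, B) = sum(k(A[:, i], B[:, j]) for i in axes(A, 2), j in axes(B, 2)) /
                (size(A, 2) * size(B, 2))
    return avg(x, x) + avg(y, y) - 2 * avg(x, y)
end
```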
Losses are structs which define methods for computing loss functionals. The abstract supertype is `AbstractLossFunction`. Every loss is expected to extend the function `val`, which takes a source (i.e. a matrix `x` or a measure `μ`) and an `AbstractLossFunction`, and evaluates the loss functional.
Currently, the implemented subtypes of `AbstractLossFunction` fall into the following categories:
- `StatLoss`: Fixes a reference discrepancy `d` and a `target` (matrix or measure).
  `val(source, Loss) = dist(Loss.d, source, Loss.target)`
- `TopLoss`: Fixes a reference discrepancy `d`, a persistence diagram constructor `dgmFun`, and a `target` (persistence diagram).
  `val(source, Loss) = dist(Loss.d, Loss.dgmFun(source), Loss.target)`
- `BarycenterStatLoss`: Fixes a reference discrepancy `d` along with `targets` $\{\nu_1, \nu_2, \dots, \nu_M\}$ and `weights` $\{\lambda_1, \lambda_2, \dots, \lambda_M\}$.
  `val(source, Loss) =` $\sum\limits_{i=1}^{M}$ `weights[i] * dist(Loss.d, source, Loss.targets[i])^2`
- `BarycenterTopLoss`: Fixes a reference discrepancy `d` along with `dgmFun`, the precomputed persistence diagram `targets` $\{D_1, D_2, \dots, D_M\}$, and `weights` $\{\lambda_1, \lambda_2, \dots, \lambda_M\}$.
  `val(source, Loss) =` $\sum\limits_{i=1}^{M}$ `weights[i] * dist(Loss.d, Loss.dgmFun(source), Loss.targets[i])^2`
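Written out in Julia, the relations above amount to the following sketch; the field names follow the descriptions in this README, and the actual definitions in the package may differ.

```julia
# Sketch only: field names (d, target, targets, weights, dgmFun) follow the
# descriptions above; the actual struct layouts in TDAOpt.jl may differ.
val(source, L::StatLoss) = dist(L.d, source, L.target)

val(source, L::TopLoss) = dist(L.d, L.dgmFun(source), L.target)

val(source, L::BarycenterStatLoss) =
    sum(L.weights[i] * dist(L.d, source, L.targets[i])^2 for i in eachindex(L.targets))

val(source, L::BarycenterTopLoss) =
    sum(L.weights[i] * dist(L.d, L.dgmFun(source), L.targets[i])^2 for i in eachindex(L.targets))
```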
An `AbstractBackprop` object configures the parameters for performing backpropagation on a specified `AbstractLossFunction`. Each concrete subtype dispatches its own method of the function `backprop`.
The main difference between an `AbstractBackprop` and an `AbstractGradflow` is how the gradients are computed: `AbstractBackprop` methods compute the usual (Euclidean) gradients, while `AbstractGradflow` methods compute the Wasserstein gradient (or use the JKO scheme).
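Conceptually, the two update rules differ as in the sketch below; the use of Zygote as the AD backend and the call signature of `∇w` are assumptions.

```julia
# Conceptual sketch only, not the package's internals.
using TDAOpt, Zygote

# An AbstractBackprop-style step follows the ordinary (Euclidean) gradient of
# the loss functional, computed here with Zygote (AD backend assumed):
euclidean_step(x, loss, η) = x .- η .* Zygote.gradient(z -> val(z, loss), x)[1]

# An AbstractGradflow-style step follows the Wasserstein gradient instead
# (call signature of ∇w is assumed):
wasserstein_step(x, loss, η) = x .- η .* ∇w(x, loss)
```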