SymbolicNeuralNetworks.jl

Analytic neural networks based on Symbolics.jl
Author: JuliaGNI
Popularity: 3 Stars
Last updated: 10 months ago
Started in: July 2023


SymbolicNeuralNetworks.jl uses Symbolics.jl to speed up neural network training in two ways: by accelerating the evaluation of the network, and by simplifying the computation of network derivatives that appear in loss functions. The package is built on AbstractNeuralNetworks.jl and can be used with GeometricMachineLearning.jl.

To accelerate evaluation, the package replaces the network's evaluation method with code generated by Symbolics.jl, applies some optimizations to that code, and builds the corresponding function with RuntimeGeneratedFunctions.jl.
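The general idea of symbolic code generation can be sketched with Symbolics.jl alone. The snippet below is a minimal illustration, not the package's actual implementation: the one-layer toy network and the names x, W, b, f_oop and f_ip are hypothetical. It builds a symbolic forward pass and turns it into a callable function with build_function; passing expression = Val{false} makes Symbolics.jl return a runtime-generated function instead of an expression.

```julia
using Symbolics

# Symbolic input and parameters of a one-layer toy network (hypothetical names).
@variables x[1:2] W[1:2, 1:2] b[1:2]

# Turn the symbolic arrays into plain arrays of scalar symbols.
xs = Symbolics.scalarize(x)
Ws = Symbolics.scalarize(W)
bs = Symbolics.scalarize(b)

# Symbolic forward pass: y = tanh.(W * x + b)
y = tanh.(Ws * xs .+ bs)

# For array-valued expressions, build_function returns an out-of-place
# and an in-place version of the generated function.
f_oop, f_ip = build_function(y, xs, Ws, bs; expression = Val{false})

# Evaluate the generated function on numeric values.
x0 = [0.5, 0.8]
W0 = [1.0 0.0; 0.0 1.0]
b0 = [0.0, 0.0]
y0 = f_oop(x0, W0, b0)
```

The generated function contains only the scalar arithmetic of the forward pass, which is what makes this approach attractive for small networks.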

A neural network can be symbolized with the symbolize method, which creates a new, symbolic neural network from it:

symbolize(neuralnet, dim)

where neuralnet is a neural network defined in the framework of AbstractNeuralNetworks.jl and dim is the dimension of the input.

Example

using SymbolicNeuralNetworks
using GeometricMachineLearning
using Symbolics

@variables sx[1:2]
@variables nn(sx)[1:1]
Dx1 = Differential(sx[1])
Dx2 = Differential(sx[2])
vectorfield = [0 1; -1 0] * [Dx1(nn[1]), Dx2(nn[1])]
eqs = (x = sx, nn = nn, vectorfield = vectorfield)

arch = HamiltonianNeuralNetwork(2)
shnn = SymbolicNeuralNetwork(arch; eqs = eqs)

hnn = NeuralNetwork(arch, Float64)
fun_vectorfield = functions(shnn).vectorfield
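To see what the vectorfield equation above encodes, it can help to replace the network output by a known function. The following sketch assumes a harmonic-oscillator Hamiltonian H = (q² + p²)/2 in place of nn (the names z, H, gradH, vf and vf_fun are illustrative, not part of the package) and uses only Symbolics.jl to build the same symplectic gradient, J ∇H with J = [0 1; -1 0], as a callable function.

```julia
using Symbolics

@variables sx[1:2]
z = Symbolics.scalarize(sx)

# Hypothetical stand-in for the network output: H(q, p) = (q^2 + p^2) / 2
H = (z[1]^2 + z[2]^2) / 2

# Symbolic gradient of H with respect to the input
gradH = Symbolics.gradient(H, z)

# Same structure as the vectorfield in the example: J * ∇H
vf = [0 1; -1 0] * gradH

# Out-of-place generated function (first element of the returned tuple)
vf_fun = build_function(vf, z; expression = Val{false})[1]

v = vf_fun([0.5, 0.8])
```

For this choice of H the vector field is (p, -q), so evaluating at (0.5, 0.8) should give a vector close to (0.8, -0.5).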

Performance

Let us compare the time needed to compute the vector field with the SymbolicNeuralNetwork version and with the Zygote version:

using Zygote

ω∇ₓnn(x, params) = [0 1; -1 0] * Zygote.gradient(x->hnn(x, params)[1], x)[1]

println("Comparison of performances between Zygote and SymbolicNeuralNetwork for ω∇ₓnn")
x = [0.5, 0.8]
@time ω∇ₓnn(x, hnn.params)[1]
@time fun_vectorfield(x, hnn.params)
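Note that the first call timed with @time also includes compilation. A steadier comparison can be obtained with BenchmarkTools.jl, which runs each call many times. The sketch below is self-contained and deliberately does not use the network above: the toy function H, the fixed weights W0 and the names vf_zygote and vf_symbolic are assumptions made for illustration only.

```julia
using BenchmarkTools
using Symbolics
using Zygote

# Toy scalar function standing in for the network output (hypothetical).
const W0 = [1.0 2.0; 3.0 4.0]
const J = [0 1; -1 0]
H(x) = sum(tanh.(W0 * x))

# Version based on automatic differentiation: J * ∇H via Zygote
vf_zygote(x) = J * Zygote.gradient(H, x)[1]

# Version based on symbolic differentiation and code generation
@variables sx[1:2]
z = Symbolics.scalarize(sx)
vf_expr = J * Symbolics.gradient(H(z), z)
vf_symbolic = build_function(vf_expr, z; expression = Val{false})[1]

x = [0.5, 0.8]

# Both versions should agree before their timings are compared.
@assert vf_zygote(x) ≈ vf_symbolic(x)

# Interpolating x with $ avoids measuring global-variable access.
@btime vf_zygote($x)
@btime vf_symbolic($x)
```

Timings depend on the machine and package versions, so no reference numbers are given here.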

As another example, consider training a SympNet (an intrinsically structure-preserving architecture available in GeometricMachineLearning.jl) on a harmonic oscillator, with data taken from GeometricProblems.jl: