## Unitary.jl

A differentiable parametrization of a group of unitary matrices.
Author: pevnak · 14 stars · started in June 2019 · last updated 2 years ago


This package implements a differentiable parametrization of the group of unitary matrices, as described in the paper *Sum-Product-Transform Networks: Exploiting Symmetries using Invertible Transformations* by Tomas Pevny, Vasek Smidl, Martin Trapp, Ondrej Polacek, and Tomas Oberhuber (2020), https://arxiv.org/abs/2005.01297

The actual `Dense` node implementing `f(x) = σ.(W * x .+ b)`, where `W` is stored in SVD form, has moved to https://github.com/pevnak/SumProductTransform.jl to keep this package simple. Because in the paper we experimented with several ways of implementing dense matrices that support efficient inversion and determinant calculation, this repository contains a few more representations than the unitary ones.

• `Givens` - representation of a unitary matrix using Givens rotations
• `UnitaryHouseholder` - representation of a unitary matrix using Householder reflections, an approach common in machine learning
• LU - representation of a matrix using the LU decomposition
• LDU - representation of a matrix using the LDU decomposition
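As the paragraph above notes, the goal is a dense layer whose inverse and determinant are cheap. The following is a minimal sketch in plain Julia, assuming only the standard library `LinearAlgebra`, of why that holds for the affine part of a layer whose `W` is in SVD form with Givens-parametrized orthogonal factors, plus the analogous LDU case. All names (`givens_product`, `SVDAffine`, `forward`, `inverse`, `logabsdet_svd`, `ldu_factors`, `ldu_logabsdet`, `ldu_solve`) are hypothetical and are not Unitary.jl's API; the loops mutate arrays for readability, so this version is not meant to be differentiated with Zygote.

```julia
using LinearAlgebra

# Hypothetical helper (not Unitary.jl's Givens type): an n×n orthogonal matrix as a
# product of n(n-1)/2 Givens rotations, one per coordinate pair (i, j), each set by
# a single angle θ[k]. Every rotation has determinant +1, so the product does too.
function givens_product(θ::AbstractVector{T}, n::Integer) where {T<:Real}
    length(θ) == n * (n - 1) ÷ 2 || throw(ArgumentError("need one angle per pair (i, j)"))
    Q = Matrix{T}(I, n, n)
    k = 0
    for i in 1:n-1, j in i+1:n
        k += 1
        c, s = cos(θ[k]), sin(θ[k])
        # Left-multiply Q by the rotation acting on rows i and j.
        for col in 1:n
            qi, qj = Q[i, col], Q[j, col]
            Q[i, col] = c * qi - s * qj
            Q[j, col] = s * qi + c * qj
        end
    end
    return Q
end

# Hypothetical affine part of a Dense node with W = U * Diagonal(s) * V',
# where U and V are orthogonal (here: Givens products). σ is omitted.
struct SVDAffine{T<:Real}
    U::Matrix{T}
    s::Vector{T}
    V::Matrix{T}
    b::Vector{T}
end

# y = W * x .+ b, evaluated factor by factor without forming W.
forward(d::SVDAffine, x) = d.U * (d.s .* (d.V' * x)) .+ d.b

# x = V * Diagonal(1 ./ s) * U' * (y .- b), because orthogonality gives inv(U) = U'.
inverse(d::SVDAffine, y) = d.V * ((d.U' * (y .- d.b)) ./ d.s)

# log|det W| = sum(log|s_i|), because |det U| = |det V| = 1.
logabsdet_svd(d::SVDAffine) = sum(log ∘ abs, d.s)

# Hypothetical LDU variant: W = L * Diagonal(dg) * R with unit lower-triangular L and
# unit upper-triangular R (only the strict triangles of Lp and Rp are used).
ldu_factors(Lp::AbstractMatrix, dg::AbstractVector, Rp::AbstractMatrix) =
    (UnitLowerTriangular(Lp), Diagonal(dg), UnitUpperTriangular(Rp))

# det(L) = det(R) = 1, so log|det W| depends only on the diagonal.
ldu_logabsdet(dg::AbstractVector) = sum(log ∘ abs, dg)

# W \ y = R \ (D \ (L \ y)): two triangular solves and one diagonal scaling.
ldu_solve(L, D, R, y) = R \ (D \ (L \ y))

# Usage: build a 4×4 layer from random angles and invert it.
n = 4
U = givens_product(randn(n * (n - 1) ÷ 2), n)
V = givens_product(randn(n * (n - 1) ÷ 2), n)
layer = SVDAffine(U, [2.0, 1.5, 0.5, 3.0], V, zeros(n))
x = randn(n, 3)
inverse(layer, forward(layer, x)) ≈ x  # holds up to rounding
```

Inverting the layer needs only products with the orthogonal factors and a division by the diagonal, and the log-determinant is a sum over `n` numbers, whereas a general `inv` or `logabsdet` requires an O(n³) factorization.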

The usage is simple:

```julia
using Unitary, Flux, BenchmarkTools
using Unitary: Givens, lowup

x = randn(Float32, 50, 100)
xx = randn(Float32, 100, 50)

# 50×50 unitary matrix parametrized by Givens rotations
a = Givens(50)
@btime a * x;    # multiplication from the left
#  224.097 μs (4 allocations: 20.00 KiB)
@btime xx * a;   # multiplication from the right
#  79.517 μs (4 allocations: 20.00 KiB)

# gradients with respect to the parameters of `a`
ps = Flux.params(a)
@btime gradient(() -> sum(a * x), ps);
#  890.323 μs (58 allocations: 71.52 KiB)
#  891.481 μs (60 allocations: 72.42 KiB)
@btime gradient(() -> sum(xx * a), ps);
#  473.158 μs (58 allocations: 71.52 KiB)
#  468.794 μs (60 allocations: 72.42 KiB)

a = Givens(50)
ps = Flux.params(a)  # re-collect the parameters; otherwise `ps` still refers to the previous `a`
@btime a * x;
#  646.874 μs (10154 allocations: 2.37 MiB)

@btime xx * a;
#  726.198 μs (10204 allocations: 2.39 MiB)

@btime gradient(() -> sum(a * x), ps);
#  103.869 ms (44538 allocations: 179.60 MiB)

@btime gradient(() -> sum(xx * a), ps);
#  105.061 ms (44688 allocations: 179.67 MiB)
```

These matrices support only multiplication, since that is what they were designed for. You can always convert them to ordinary matrices using `Matrix`, but this conversion is currently not differentiable.
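A minimal Julia sketch of this multiplication-only pattern, using a product of Householder reflections `H(v) = I - 2vv'/(v'v)`: the operator stores only the vectors `v`, implements `*` by applying one reflection at a time, and recovers a dense matrix by multiplying the identity. `HouseholderSketch`, `dimension`, and these methods are hypothetical illustrations assuming only `LinearAlgebra`; they are not Unitary.jl's `UnitaryHouseholder` implementation.

```julia
using LinearAlgebra

# Hypothetical multiplication-only operator: a product of Householder reflections
# H(v) = I - 2 v v' / (v' v), stored only through the vectors v.
struct HouseholderSketch{T<:Real}
    vs::Vector{Vector{T}}
end

# Size n of the (implicit) n×n matrix.
dimension(h::HouseholderSketch) = length(first(h.vs))

# h * x applies H(v_K) * ... * H(v_1) to x, one reflection at a time,
# without ever forming the n×n matrix.
function Base.:*(h::HouseholderSketch{T}, x::AbstractMatrix) where {T}
    size(x, 1) == dimension(h) || throw(DimensionMismatch("x must have $(dimension(h)) rows"))
    y = copyto!(similar(x, promote_type(T, eltype(x))), x)  # work on a copy of x
    for v in h.vs
        # H(v) * y = y - v * (2 v' y / (v' v))
        y .-= v * ((2 / dot(v, v)) .* (v' * y))
    end
    return y
end

# Dense conversion: apply the operator to the identity matrix.
Base.Matrix(h::HouseholderSketch{T}) where {T} = h * Matrix{T}(I, dimension(h), dimension(h))
```

For example, `Matrix(HouseholderSketch([randn(4) for _ in 1:2]))` returns a dense 4×4 orthogonal matrix. Applying `K` reflections to an `n×m` input this way costs O(K·n·m) operations, and the temporaries have the size of the input rather than `n×n`.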
