TensorRules.jl

Macros to define custom adjoints for TensorOperations.jl
Author: ho-oto
Started in October 2020



TensorRules.jl provides a macro @∇ (typed as \nabla<tab>), which enables automatic differentiation (AD) libraries (e.g., Zygote.jl, ForwardDiff.jl) to be used with the @tensor and @tensoropt macros of TensorOperations.jl.

TensorRules.jl defines its custom adjoints with ChainRulesCore.jl, so it works with any AD library that supports ChainRulesCore.jl.

How to use

julia> using TensorOperations, TensorRules, Zygote;
julia> function foo(a, b, c) # define function with Einstein summation
           # d_F = \sum_{A,B,C,D,E} a_{A,B,C} b_{C,D,E,F} c_{A,B,D,E}
           @tensor d[F] := a[A, B, C] * b[C, D, E, F] * c[A, B, D, E]
           return d[1]
       end;
julia> a, b, c = randn(3, 4, 5), randn(5, 6, 7, 8), randn(3, 4, 6, 7);
julia> gradient(foo, a, b, c); # try to obtain gradient of `foo` by Zygote
ERROR: this intrinsic must be compiled to be called
Stacktrace:
...
julia> @∇ function foo(a, b, c) # use @∇
           @tensor d[F] := a[A, B, C] * b[C, D, E, F] * c[A, B, D, E]
           return d[1]
       end;
julia> gradient(foo, a, b, c); # it works!
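To make concrete what Zygote has to compute for foo, the sketch below writes d[1] and its gradient with respect to a as explicit loops. It is plain Julia with no TensorOperations or Zygote, and the names foo_loops, grad_a and fd_entry are hypothetical. Because d[1] is linear in a, the gradient entry for a[A,B,C] is the sum over D and E of b[C,D,E,1] * c[A,B,D,E], which a central finite difference should reproduce.

```julia
# Hypothetical plain-Julia sketch (no TensorOperations, no Zygote):
# d[1] = Σ_{A,B,C,D,E} a[A,B,C] * b[C,D,E,1] * c[A,B,D,E]

# Primal value d[1] as explicit loops (the open index F is fixed to 1).
function foo_loops(a, b, c)
    s = zero(eltype(a))
    for A in axes(a, 1), B in axes(a, 2), C in axes(a, 3), D in axes(b, 2), E in axes(b, 3)
        s += a[A, B, C] * b[C, D, E, 1] * c[A, B, D, E]
    end
    return s
end

# d[1] is linear in a, so ∂d[1]/∂a[A,B,C] = Σ_{D,E} b[C,D,E,1] * c[A,B,D,E].
function grad_a(a, b, c)
    g = zero(a)
    for A in axes(a, 1), B in axes(a, 2), C in axes(a, 3), D in axes(b, 2), E in axes(b, 3)
        g[A, B, C] += b[C, D, E, 1] * c[A, B, D, E]
    end
    return g
end

# Central finite difference for a single entry of `a`.
function fd_entry(a, b, c, idx; h = 1e-6)
    ap = copy(a); ap[idx] += h
    am = copy(a); am[idx] -= h
    return (foo_loops(ap, b, c) - foo_loops(am, b, c)) / (2h)
end

a, b, c = randn(3, 4, 5), randn(5, 6, 7, 8), randn(3, 4, 6, 7)
g = grad_a(a, b, c)
println(isapprox(g[2, 3, 4], fd_entry(a, b, c, CartesianIndex(2, 3, 4)); rtol = 1e-5, atol = 1e-6))
```

The loops are only a reference for checking; the point of @∇ is that the real gradient is produced from the @tensor expression itself.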

How it works

The strategy of TensorRules.jl is very similar to that of TensorGrad.jl.

@∇ transforms functions that contain @tensor or @tensoropt macros. First, @∇ detects the @tensor and @tensoropt expressions in the function definition and converts each of them into an inlined helper function. Then, @∇ defines custom adjoint rules (rrule) for the generated functions.
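The two steps can be sketched on a single matrix product without ChainRulesCore. In this hypothetical sketch, _contract stands in for a generated helper and rrule_contract for its rrule: it returns the primal value together with a pullback closure. The sketch assumes the ChainRules convention for complex inputs, where the cotangent of x is ∂L/∂Re(x) + im * ∂L/∂Im(x), and checks one entry of Δx1 against finite differences of a real loss whose cotangent with respect to y is W.

```julia
# Hypothetical sketch (no ChainRulesCore) of the two steps, on a matrix product.

# Step 1: the contraction lives in its own helper:
#   y[A, B] = Σ_C x1[A, C] * x2[C, B]
_contract(x1, x2) = x1 * x2

# Step 2: an rrule-style function returns the primal value and a pullback closure.
# `'` is the conjugate transpose, so complex inputs follow the ChainRules convention.
function rrule_contract(x1, x2)
    y = _contract(x1, x2)
    contract_pullback(Δy) = (Δy * x2', x1' * Δy)
    return y, contract_pullback
end

x1, x2 = randn(ComplexF64, 3, 4), randn(ComplexF64, 4, 2)
W = randn(ComplexF64, 3, 2)                          # cotangent of y for the loss below
loss(x) = real(sum(conj.(W) .* _contract(x, x2)))

y, back = rrule_contract(x1, x2)
Δx1, Δx2 = back(W)

# Finite-difference check: Δx1[1, 1] should equal ∂loss/∂Re(x1[1,1]) + im * ∂loss/∂Im(x1[1,1]).
h = 1e-6
E = zeros(ComplexF64, size(x1)); E[1, 1] = 1
fd = (loss(x1 + h * E) - loss(x1 - h * E)) / (2h) +
     im * (loss(x1 + im * h * E) - loss(x1 - im * h * E)) / (2h)
println(isapprox(Δx1[1, 1], fd; rtol = 1e-6, atol = 1e-8))
```

Keeping the pullback as a closure lets it reuse x1 and x2 from the forward pass instead of recomputing them.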

For example, the following definition

@∇ function foo(a, b, c, d, e, f)
    @tensoropt !C x[A, B] := conj(a[A, C]) * sin.(b)[C, D] * c.d[D, B] + d * e[1, 2][A, B]
    x = x + f
    @tensor x[A, B] += a[A, C] * (a * a)[C, B]
    return x
end

will be converted into code equivalent to

function foo(a, b, c, d, e, f)
    x = _foo_1(a, sin.(b), c.d, d, e[1, 2])
    x = x + f
    x += _foo_2(a, a * a)
    return x
end

@inline _foo_1(x1, x2, x3, x4, x5) =
    @tensoropt !C _[A, B] := conj(x1[A, C]) * x2[C, D] * x3[D, B] + x4 * x5[A, B]

@inline _foo_2(x1, x2) = @tensor _[A, B] := x1[A, C] * x2[C, B]

function rrule(::typeof(_foo_1), x1, x2, x3, x4, x5)
    f = _foo_1(x1, x2, x3, x4, x5)
    function _foo_1_pullback(Δf)
        Δx1 = InplaceableThunk(
            Thunk(() -> @tensoropt !C Δx1[A, C] := conj(Δf[A, B]) * x2[C, D] * x3[D, B]),
            Δx1 -> @tensoropt !C Δx1[A, C] += conj(Δf[A, B]) * x2[C, D] * x3[D, B]
        )
        Δx2 = InplaceableThunk(
            Thunk(() -> @tensoropt !C Δx2[C, D] := conj(conj(x1[A, C]) * conj(Δf[A, B]) * x3[D, B])),
            Δx2 -> @tensoropt !C Δx2[C, D] += conj(conj(x1[A, C]) * conj(Δf[A, B]) * x3[D, B])
        )
        Δx3 = InplaceableThunk(
            Thunk(() -> @tensoropt !C Δx3[D, B] := conj(conj(x1[A, C]) * x2[C, D] * conj(Δf[A, B]))),
            Δx3 -> @tensoropt !C Δx3[D, B] += conj(conj(x1[A, C]) * x2[C, D] * conj(Δf[A, B]))
        )
        Δx4 = Thunk(() -> first(@tensoropt !C Δx4[] := conj(conj(Δf[A, B]) * x5[A, B])))
        Δx5 = InplaceableThunk(
            Thunk(() -> @tensoropt !C Δx5[A, B] := conj(x4 * conj(Δf[A, B]))),
            Δx5 -> @tensoropt !C Δx5[A, B] += conj(x4 * conj(Δf[A, B]))
        )
        return (NO_FIELDS, Δx1, Δx2, Δx3, Δx4, Δx5)
    end
    return f, _foo_1_pullback
end

function rrule(::typeof(_foo_2), x1, x2)
    ...
end
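The cotangents in _foo_1_pullback can be cross-checked numerically. The hypothetical sketch below (plain Julia, no TensorOperations) assumes x1, x2, x3, x5 are matrices and x4 is a scalar, so that _foo_1 reduces to conj(x1) * x2 * x3 + x4 * x5. It rewrites each contraction from the pullback as a matrix expression and compares one entry per argument with central finite differences of the real loss Re Σ conj(W) .* f, whose cotangent with respect to f is W. The names f1, f1_pullback, perturb and fd_cotangent are illustrative only.

```julia
# Matrix reading of _foo_1 for 2-D x1, x2, x3, x5 and scalar x4 (assumption).
f1(x1, x2, x3, x4, x5) = conj(x1) * x2 * x3 + x4 * x5

# The five cotangents of _foo_1_pullback as matrix expressions (`'` = conjugate transpose).
function f1_pullback(Δf, x1, x2, x3, x4, x5)
    Δx1 = conj(Δf) * transpose(x2 * x3)   # conj(Δf[A,B]) * x2[C,D] * x3[D,B]
    Δx2 = transpose(x1) * Δf * x3'        # conj(conj(x1[A,C]) * conj(Δf[A,B]) * x3[D,B])
    Δx3 = (conj(x1) * x2)' * Δf           # conj(conj(x1[A,C]) * x2[C,D] * conj(Δf[A,B]))
    Δx4 = sum(Δf .* conj.(x5))            # conj(conj(Δf[A,B]) * x5[A,B]), fully summed
    Δx5 = conj(x4) * Δf                   # conj(x4 * conj(Δf[A,B]))
    return Δx1, Δx2, Δx3, Δx4, Δx5
end

# Perturb one entry of an array argument, or the whole scalar argument.
perturb(x::Number, idx, δ) = x + δ
perturb(x::AbstractArray, idx, δ) = (y = copy(x); y[idx] += δ; y)

# Central differences along the real and imaginary directions:
# cotangent = ∂L/∂Re(x) + im * ∂L/∂Im(x).
function fd_cotangent(loss, args, k, idx; h = 1e-6)
    at(δ) = loss(ntuple(i -> i == k ? perturb(args[i], idx, δ) : args[i], length(args))...)
    re_part = (at(h) - at(-h)) / (2h)
    im_part = (at(im * h) - at(-im * h)) / (2h)
    return complex(re_part, im_part)
end

x1, x2, x3 = randn(ComplexF64, 2, 3), randn(ComplexF64, 3, 4), randn(ComplexF64, 4, 5)
x4, x5 = randn(ComplexF64), randn(ComplexF64, 2, 5)
W = randn(ComplexF64, 2, 5)                       # plays the role of Δf
args = (x1, x2, x3, x4, x5)
loss(xs...) = real(sum(conj.(W) .* f1(xs...)))    # real loss with cotangent W w.r.t. f

Δ = f1_pullback(W, args...)
check(k, idx) = isapprox(idx === nothing ? Δ[k] : Δ[k][idx],
                         fd_cotangent(loss, args, k, idx); rtol = 1e-6, atol = 1e-8)
probes = [(1, CartesianIndex(2, 3)), (2, CartesianIndex(1, 4)), (3, CartesianIndex(4, 5)),
          (4, nothing), (5, CartesianIndex(2, 5))]
println(all(check(k, idx) for (k, idx) in probes))
```

Because f is linear in each argument separately, the finite differences carry no truncation error, so any mismatch would point to a wrong conj or transpose in a cotangent.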

Because the adjoints are wrapped in Thunk and InplaceableThunk, each one is evaluated only if it is actually needed; InplaceableThunk additionally allows a gradient to be accumulated in place.
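The idea behind lazy cotangents can be illustrated with stand-in types. LazyThunk and LazyInplaceThunk below are hypothetical and are not the ChainRulesCore types; they only mimic the behaviour: a closure that computes a fresh cotangent on demand, plus a function that adds the cotangent into an existing buffer. A counter shows that unused cotangents are never computed.

```julia
using LinearAlgebra  # for the 5-argument mul!

# Hypothetical stand-ins illustrating Thunk / InplaceableThunk.
struct LazyThunk{F}
    compute::F        # () -> fresh cotangent array, evaluated only on demand
end

struct LazyInplaceThunk{T<:LazyThunk,G}
    thunk::T          # used when a fresh array is needed
    add_into::G       # buf -> buf .+ cotangent, written into buf without a new result array
end

const evaluations = Ref(0)  # counts how many cotangents were actually computed

# Pullback of the hypothetical y = x1 * x2; running it computes nothing yet.
function toy_pullback(Δy, x1, x2)
    Δx1 = LazyInplaceThunk(
        LazyThunk(() -> (evaluations[] += 1; Δy * x2')),
        buf -> (evaluations[] += 1; mul!(buf, Δy, x2', true, true)),  # buf += Δy * x2'
    )
    Δx2 = LazyThunk(() -> (evaluations[] += 1; x1' * Δy))
    return Δx1, Δx2
end

x1, x2 = randn(3, 4), randn(4, 2)
Δy = ones(3, 2)
Δx1, Δx2 = toy_pullback(Δy, x1, x2)
println(evaluations[])      # 0: nothing evaluated yet
g1 = zeros(3, 4)
Δx1.add_into(g1)            # accumulate into an existing gradient buffer
println(evaluations[])      # 1: Δx2 was never requested, so never computed
println(g1 ≈ Δy * x2')
```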

Unsupported features

  • @∇ uses the @capture macro from MacroTools.jl to parse Expr objects. Because of a limitation of @capture, index notations that parse as :typed_vcat or :typed_hcat (A[a; b], A[a b]) are unsupported. Please use the A[a, b] style.
  • Specifying the contraction order with ord=(...) or NCON-style indices is unsupported. Please use @tensoropt and specify the cost of each bond instead.
  • Since Zygote.jl does not support in-place mutation, @tensor A[] = ... cannot be used inside the function. Please use :=, += and -= instead.
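The first limitation comes from how Julia parses bracket indexing. A quick check with Meta.parse (plain Julia; this does not call MacroTools.jl or TensorRules.jl) shows that the three styles produce different expression heads, so a pattern written for the A[a, b] form does not match the other two.

```julia
# Expr heads produced by the three index styles.
srcs = ("A[a, b]", "A[a; b]", "A[a b]")
heads = Dict(src => Meta.parse(src).head for src in srcs)
for src in srcs
    # A[a, b] => :ref,  A[a; b] => :typed_vcat,  A[a b] => :typed_hcat
    println(rpad(src, 8), "=> ", repr(heads[src]))
end
```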

TODO

  • support frule
  • support @tensor block (@tensor begin ... end)
  • support higher-order differentiation (by applying @∇ to rrule and frule recursively)
    • add more tests (higher-order differentiation is not well tested since Zygote.jl has poor support for it...😞)
    • better support for InplaceableThunk (in this version, InplaceableThunk is disabled when @∇ i foo(...) = ... is used with i > 1)
  • use @thunk?
