TensorRules.jl
provides a macro @∇
(you can type ∇
by \nabla<tab>
), which
enables us to use automatic differentiation (AD) libraries (e.g.,
Zygote.jl
,
Diffractor.jl
)
with @tensor
and @tensoropt
macros in TensorOperations.jl
.
TensorRules.jl
uses ChainRulesCore.jl
to define custom adjoints.
So, you can use any AD library which supports ChainRulesCore.jl
.
julia> using TensorOperations, TensorRules, Zygote;
julia> function foo(a, b, c) # define function with Einstein summation
# d_F = \sum_{A,B,C,D} a_{A,B,C} b_{C,D,E,F} c_{A,B,D,E}
@tensor d[F] := a[A, B, C] * b[C, D, E, F] * c[A, B, D, E]
return d[1]
end;
julia> a, b, c = randn(3, 4, 5), randn(5, 6, 7, 8), randn(3, 4, 6, 7);
julia> gradient(foo, a, b, c); # try to obtain gradient of `foo` by Zygote
ERROR: this intrinsic must be compiled to be called
Stacktrace:
...
julia> @∇ function foo(a, b, c) # use @∇
@tensor d[F] := a[A, B, C] * b[C, D, E, F] * c[A, B, D, E]
return d[1]
end;
julia> gradient(foo, a, b, c); # it works!
The strategy of TensorRules.jl
is very similar to that of TensorGrad.jl
.
@∇
converts functions which contain @tensor
or @tensoropt
macro.
First, @∇
detects @tensor
or @tensoropt
expressions in the function definition
and converts them to inlined functions.
Then, @∇
defines custom adjoint rules for the generated functions.
For example, the following definition
@∇ function foo(a, b, c, d, e, f)
@tensoropt !C x[A, B] := conj(a[A, C]) * sin.(b)[C, D] * c.d[D, B] + d * e[1, 2][A, B]
x = x + f
@tensor x[A, B] += a[A, C] * (a * a)[C, B]
return x
end
will be converted to a code equivalent to
function foo(a, b, c, d, e, f)
x = _foo_1(a, sin.(b), c.d, d, e[1, 2])
x = x + f
x += _foo_2(a, a * a)
return x
end
@inline _foo_1(x1, x2, x3, x4, x5) =
@tensoropt !C _[A, B] := conj(x1[A, C]) * x2[C, D] * x3[D, B] + x4 * x5[A, B]
@inline _foo_2(x1, x2) = @tensor _[A, B] := x1[A, C] * x2[C, B]
function rrule(::typeof(_foo_1), x1, x2, x3, x4, x5)
f = _foo_1(x1, x2, x3, x4, x5)
Px1, Px2, Px3, Px4, Px5 = ProjectTo(x1), ProjectTo(x2), ProjectTo(x3), ProjectTo(x4), ProjectTo(x5)
function _foo_1_pullback(Δf)
fnΔx1(Δf, x1, x2, x3, x4, x5) = @tensoropt !C _[A, C] := conj(Δf[A, B]) * x2[C, D] * x3[D, B]
fnΔx1add!!(x, Δf, x1, x2, x3, x4, x5) = @tensoropt !C x[A, C] += conj(Δf[A, B]) * x2[C, D] * x3[D, B]
fnΔx2(Δf, x1, x2, x3, x4, x5) = @tensoropt !C _[C, D] := conj(conj(x1[A, C]) * conj(Δf[A, B]) * x3[D, B])
fnΔx2add!!(x, Δf, x1, x2, x3, x4, x5) = @tensoropt !C x[C, D] += conj(conj(x1[A, C]) * conj(Δf[A, B]) * x3[D, B])
fnΔx3(Δf, x1, x2, x3, x4, x5) = @tensoropt !C _[D, B] := conj(conj(x1[A, C]) * x2[C, D] * conj(Δf[A, B]))
fnΔx3add!!(x, Δf, x1, x2, x3, x4, x5) = @tensoropt !C x[D, B] += conj(conj(x1[A, C]) * x2[C, D] * conj(Δf[A, B]))
fnΔx4(Δf, x1, x2, x3, x4, x5) = first(@tensoropt !C _[] := conj(conj(Δf[A, B]) * x5[A, B]))
fnΔx5(Δf, x1, x2, x3, x4, x5) = @tensoropt !C _[A, B] := conj(x4 * conj(Δf[A, B]))
fnΔx5add!!(x, Δf, x1, x2, x3, x4, x5) = @tensoropt !C x[A, B] += conj(x4 * conj(Δf[A, B]))
Δx1 = InplaceableThunk(
Thunk(() -> Px1(fnΔx1(Δf, x1, x2, x3, x4, x5))),
x -> fnΔx1add!!(x, Δf, x1, x2, x3, x4, x5)
)
Δx2 = InplaceableThunk(
Thunk(() -> Px2(fnΔx2(Δf, x1, x2, x3, x4, x5))),
x -> fnΔx2add!!(x, Δf, x1, x2, x3, x4, x5)
)
Δx3 = InplaceableThunk(
Thunk(() -> Px3(fnΔx3(Δf, x1, x2, x3, x4, x5))),
x -> fnΔx3add!!(x, Δf, x1, x2, x3, x4, x5)
)
Δx4 = Thunk(() -> fnΔx4(Δf, x1, x2, x3, x4, x5))
Δx5 = InplaceableThunk(
Thunk(() -> Px5(fnΔx5(Δf, x1, x2, x3, x4, x5))),
x -> fnΔx5add!!(x, Δf, x1, x2, x3, x4, x5)
)
return (NoTangent(), Δx1, Δx2, Δx3, Δx4, Δx5)
end
return f, _foo_1_pullback
end
function rrule(::typeof(_foo_2), x1, x2)
...
end
By using Thunk
and InplaceableThunk
properly, adjoints will be evaluated only
if they are needed.
- `@∇` uses the `@capture` macro defined in `MacroTools.jl` to parse `Expr`. Because of a limitation of the `@capture` macro, index notations based on `:typed_vcat` and `:typed_hcat` (`A[a; b]`, `A[a b]`) are unsupported. Please use the `A[a, b]` style.
- Designations of contraction order based on `ord=(...)` or NCON style are unsupported. Please use `@tensoropt` and specify the cost of each bond.
- Since `Zygote.jl` does not support in-place operations, we cannot use `@tensor A[] = ...` in the expression. Please use `:=`, `+=` and `-=` instead.
- support `@tensor` block (`@tensor begin ... end`)
- support higher order differentiation (by applying `@∇` to `rrule` and `frule` recursively)
- better support of `InplaceableThunk`
- better support of