One More Einsum for Julia! With runtime order-specification and high-level adjoints for AD
Author: under-Peter | 75 stars | Last updated 2 years ago | Started in May 2019

OMEinsum - One More Einsum


This is a repository for the Google Summer of Code project on Differentiable Tensor Networks. It implements one function that both computer scientists and physicists love: the Einstein summation.

[Figure: definition of einsum]
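In one common notation, which may differ slightly from the figure's, an einsum takes input tensors $X^{(1)}, \dots, X^{(n)}$ labelled by index strings $\mathbf{x}_1, \dots, \mathbf{x}_n$ and output labels $\mathbf{y}$. It multiplies matching entries and sums over every label that does not appear in $\mathbf{y}$:

$$
Y_{\mathbf{y}} = \sum_{\text{labels} \notin \mathbf{y}} \; \prod_{k=1}^{n} X^{(k)}_{\mathbf{x}_k}
$$

For example, `ein"ik,kj->ij"` is matrix multiplication, $C_{ij} = \sum_k A_{ik} B_{kj}$, and `ein"ii->"` is the trace, $\sum_i A_{ii}$.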

To learn more about einsum, please check out my Nextjournal article or the NumPy manual.

Einstein summation can be implemented in no more than 20 lines of Julia code, and its automatic differentiation is also straightforward. Most of the effort in this package goes into performance: it uses Julia's multiple dispatch on traits so that users get the speed of specialized implementations such as BLAS functions, `sum`, and `permutedims` on both CPU and GPU, without paying runtime overhead.
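To illustrate the "20 lines" claim, here is a deliberately naive einsum in plain Julia with no OMEinsum dependency. `naive_einsum` is a hypothetical helper, not the package's implementation. It loops over every combination of label values, which is exactly the slow path that the specialized kernels avoid. It also assumes every output label appears in some input, so it cannot express `ein"->ii"`.

```julia
# Naive reference einsum: visit every combination of label values once.
# `ixs` holds one label tuple per input array, `iy` holds the output labels.
function naive_einsum(ixs, iy, xs...)
    # Read the dimension of every label from the input arrays.
    sizes = Dict{Char,Int}()
    for (ix, x) in zip(ixs, xs), (label, n) in zip(ix, size(x))
        sizes[label] = n
    end
    labels = unique(vcat(collect.(ixs)..., collect(iy)))
    T = promote_type(map(eltype, xs)...)
    y = zeros(T, Tuple(sizes[l] for l in iy))
    for ci in CartesianIndices(Tuple(sizes[l] for l in labels))
        val = Dict(zip(labels, Tuple(ci)))      # label => current index value
        term = one(T)
        for (ix, x) in zip(ixs, xs)
            term *= x[(val[l] for l in ix)...]  # repeated labels pick diagonals
        end
        y[(val[l] for l in iy)...] += term      # labels absent from iy get summed
    end
    return y
end
```

For example, `naive_einsum((('i','j'), ('j','k')), ('i','k'), A, B)` reproduces `A * B`, and `naive_einsum((('i','i'),), (), A)[]` gives the trace of `A`.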

Note on test coverage: coverage is below 100% because GPU code is not included in the coverage report, although we do test the GPU code on GitLab. Ignoring GPU code, the actual coverage is about 97%.

Warning: since v0.4, OMEinsum no longer optimizes the contraction order. You have to specify the contraction order manually with a nested einsum, e.g. `ein"(ijk,jkl),klm->im"(x, y, z)`.
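A rough operation count shows why the order matters. This sketch assumes a simple cost model in which a single einsum evaluated with plain loops visits every combination of its distinct labels once, and that every index has dimension 10. `step_cost` is a hypothetical helper. Real costs also depend on which kernels are used and on memory traffic.

```julia
# Rough cost model (an assumption for illustration): one einsum step costs the
# product of the dimensions of all distinct labels it touches.
step_cost(code::AbstractString, dims::Dict{Char,Int}) =
    prod(dims[c] for c in unique(filter(isletter, code)))

dims = Dict(c => 10 for c in "ijklm")   # every index has dimension 10

# Flat ein"ijk,jkl,klm->im": a single loop nest over i, j, k, l, m.
one_shot = step_cost("ijk,jkl,klm->im", dims)

# Nested ein"(ijk,jkl),klm->im": first sum out j, keeping k and l because the
# third tensor still needs them, then contract the intermediate with klm.
nested = step_cost("ijk,jkl->ikl", dims) + step_cost("ikl,klm->im", dims)
```

Under this model, the flat contraction needs 10^5 iterations, while the nested order needs 2 × 10^4.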


To install, type `]` in a Julia (>= 1.5) REPL to enter package mode, and then input:

pkg> add OMEinsum

Learn by Examples

To avoid runtime overhead, we recommend using the non-standard string literal `@ein_str`. The following examples illustrate how einsum works:

julia> using OMEinsum, SymEngine

julia> catty = fill(Basic(:🐱), 2, 2)
2×2 Array{Basic,2}:
 🐱  🐱
 🐱  🐱

julia> fish = fill(Basic(:🐟), 2, 3, 2)
2×3×2 Array{Basic,3}:
[:, :, 1] =
 🐟  🐟  🐟
 🐟  🐟  🐟

[:, :, 2] =
 🐟  🐟  🐟
 🐟  🐟  🐟

julia> snake = fill(Basic(:🐍), 3, 3)
3×3 Array{Basic,2}:
 🐍  🐍  🐍
 🐍  🐍  🐍
 🐍  🐍  🐍

julia> medicine = ein"ij,jki,kk->k"(catty, fish, snake)
3-element Array{Basic,1}:

julia> ein"ik,kj -> ij"(catty, catty) # multiply two matrices (here `catty` with itself)
2×2 Array{Basic,2}:
 2*🐱^2  2*🐱^2
 2*🐱^2  2*🐱^2

julia> ein"ij -> "(catty)[] # sum all entries; the result is a 0-dimensional array, and `[]` extracts its element

julia> ein"->ii"(asarray(snake[1,1]), size_info=IndexSize('i'=>5)) # get 5 x 5 identity matrix
5×5 Array{Basic,2}:
 🐍  0  0  0  0
 0  🐍  0  0  0
 0  0  🐍  0  0
 0  0  0  🐍  0
 0  0  0  0  🐍

Alternatively, you can construct the contraction code explicitly, which is useful when the code is only known at run time:

julia> einsum(EinCode((('i','k'),('k','j')),('i','j')),(a,b))
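When the code arrives as a string at run time, the label tuples that `EinCode` takes in the call above can be produced with a few lines of string handling. `parse_code` is a hypothetical helper written for illustration, not part of OMEinsum, and it assumes single-character labels.

```julia
# Hypothetical helper: split an einsum string such as "ik,kj->ij" into
# a tuple of input label tuples and a tuple of output labels.
function parse_code(s::AbstractString)
    lhs, rhs = split(replace(s, " " => ""), "->")
    ixs = Tuple(Tuple(collect(t)) for t in split(lhs, ","))
    iy = Tuple(collect(rhs))
    return ixs, iy
end

ixs, iy = parse_code("ik,kj->ij")   # (('i','k'), ('k','j')) and ('i','j')
```

The pair could then be passed on as `EinCode(ixs, iy)`, following the constructor call shown above.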

or use the macro-based interface `@ein`, which is closer to the standard way of writing einsum operations in physics:

julia> @ein c[i,j] := a[i,k] * b[k,j];

A table for reference

| code | meaning |
| --- | --- |
| `ein"ij,jk->ik"` | matrix-matrix multiplication |
| `ein"ijl,jkl->ikl"` | batched matrix-matrix multiplication |
| `ein"ij,j->i"` | matrix-vector multiplication |
| `ein"ij,ik,il->jkl"` | star contraction |
| `ein"ii->"` | trace |
| `ein"ij->i"` | sum over a dimension (here `j`) |
| `ein"ii->i"` | take the diagonal of a matrix |
| `ein"ijkl->ilkj"` | permute the dimensions of a tensor |
| `ein"i->ii"` | construct a diagonal matrix |
| `ein"->ii"` | broadcast a scalar to the diagonal of a matrix |
| `ein"ij,ij->ij"` | element-wise product |
| `ein"ij,kl->ijkl"` | outer product |
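Several rows of the table have familiar Base Julia or `LinearAlgebra` equivalents, which can help when checking results by hand. This sketch uses small made-up matrices and says nothing about which kernel OMEinsum actually dispatches to for each case.

```julia
using LinearAlgebra

A = [1 2; 3 4]; B = [5 6; 7 8]; v = [1, 2]

mm    = A * B                        # ein"ij,jk->ik"
mv    = A * v                        # ein"ij,j->i"
trace = tr(A)                        # ein"ii->"
rows  = vec(sum(A; dims=2))          # ein"ij->i" (sum over j)
d     = diag(A)                      # ein"ii->i"
D     = Matrix(Diagonal(v))          # ein"i->ii"
had   = A .* B                       # ein"ij,ij->ij"
# ein"ij,kl->ijkl": broadcasting over singleton dimensions gives A[i,j] * B[k,l]
outer = reshape(A, 2, 2, 1, 1) .* reshape(B, 1, 1, 2, 2)
# ein"ijkl->ilkj": output dimension order (i, l, k, j) = input dims (1, 4, 3, 2)
perm  = permutedims(rand(2, 3, 4, 5), (1, 4, 3, 2))
```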

Many of these are handled by special kernels (listed in the docs), but there is also a fallback which handles other cases (more like what Einsum.jl does, plus a GPU version).

It is sometimes helpful to specify the order of operations by inserting brackets, either because you know this will be more efficient or to help the computer see which kernels can be used. For example:

julia> @ein Z[o,s] := x[i,s] * (W[o,i,j] * y[j,s]);   # macro style

julia> Z = ein"is, (oij, js) -> os"(x, W, y);         # string style

This performs a matrix multiplication (summing over `j`) followed by a batched matrix multiplication (summing over `i`, with batch label `s`). Without the brackets, it instead uses the fallback `loop_einsum`, which is slower. Calling `allow_loops(false)` will print an error to help you spot such cases:

julia> @ein Zl[o,s] := x[i,s] * W[o,i,j] * y[j,s];

julia> Z ≈ Zl

julia> allow_loops(false);

julia> Zl = ein"is, oij, js -> os"(x, W, y);
┌ Error: using `loop_einsum` to evaluate
│   code = EinCode{((1, 2), (3, 1, 4), (4, 2)),(3, 2)}()
│   size.(xs) = ((10, 50), (20, 10, 10), (10, 50))
│   size(y) = (20, 50)
└ @ OMEinsum ~/.julia/dev/OMEinsum/src/loop_einsum.jl:26
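To see what the bracketed version `ein"is, (oij, js) -> os"` computes, here is a plain-Julia sketch of its two steps, using the array sizes that appear in the error log above (`x` is 10×50, `W` is 20×10×10, `y` is 10×50). The random data is made up, and the sketch illustrates the idea rather than reproducing OMEinsum's kernels.

```julia
# Sizes from the error log: o = 20, i = 10, j = 10, s = 50.
no, ni, nj, ns = 20, 10, 10, 50
x, W, y = randn(ni, ns), randn(no, ni, nj), randn(nj, ns)

# Step 1, inner brackets "oij,js->ois": one matrix multiply summing over j.
# Reshaping W to (o*i) × j works because Julia arrays are column-major.
T = reshape(reshape(W, no * ni, nj) * y, no, ni, ns)

# Step 2, "is,ois->os": for each batch label s, a matrix-vector product over i.
Z = similar(x, no, ns)
for s in 1:ns
    Z[:, s] = T[:, :, s] * x[:, s]
end

# Reference: the same sum as one explicit loop nest, like the slow fallback.
Zl = zeros(no, ns)
for o in 1:no, s in 1:ns, i in 1:ni, j in 1:nj
    Zl[o, s] += x[i, s] * W[o, i, j] * y[j, s]
end
```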

To see more examples using the GPU and autodiff, check out our asciinema demo.


For an application in tensor network algorithms, check out the TensorNetworkAD package, where OMEinsum is used to evaluate tensor contractions, permutations, and summations.

Toy Application: solving a 3-coloring problem on the Petersen graph

Let us focus on graphs in which every vertex has exactly three edges. A question one might ask is: how many different ways are there to colour the edges of the graph with three colours such that no vertex has two edges of the same colour?

The counting problem can be transformed into a contraction of rank-3 tensors, one per vertex, where each index corresponds to one of the vertex's edges. Consider the tensor s defined as

julia> s = map(x->Int(length(unique(x.I)) == 3), CartesianIndices((3,3,3)))

Contracting copies of s over their shared edge indices then gives the number of 3-colourings satisfying the above condition. For example, two vertices connected by three edges have 6 distinct colourings:

julia> ein"ijk,ijk->"(s,s)[]
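The two-vertex count can be checked without OMEinsum. Here `ein"ijk,ijk->"` multiplies two copies of `s` entrywise and sums every entry. Since `s` contains only 0s and 1s, this counts its nonzero entries, one for each of the 3! = 6 orderings of three distinct colours.

```julia
s = map(x -> Int(length(unique(x.I)) == 3), CartesianIndices((3, 3, 3)))

# s[i,j,k] == 1 exactly when the edge colours i, j, k are all different.
nonzero = count(!iszero, s)

# ein"ijk,ijk->"(s, s)[] written out: entrywise product, then a full sum.
two_vertices = sum(s .* s)
```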

Using this method, it is easy to show that the Petersen graph allows no 3-colouring, since the following contraction evaluates to 0:

julia> ein"afl,bhn,cjf,dlh,enj,ago,big,cki,dmk,eom->"(fill(s, 10)...)[]
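The contraction string itself encodes the graph: each comma-separated group is a vertex tensor, and each letter is an edge shared by exactly two vertices. This small plain-Julia check confirms that the string describes 10 vertices and 15 edges, with every edge label used exactly twice.

```julia
code = "afl,bhn,cjf,dlh,enj,ago,big,cki,dmk,eom"

vertices = split(code, ",")                  # one rank-3 tensor per vertex
edges = unique(filter(isletter, code))       # one label per edge
uses = Dict(e => count(==(e), code) for e in edges)

# In a valid edge encoding, every edge joins exactly two vertices.
every_edge_twice = all(==(2), values(uses))
```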

The Petersen graph consists of 10 vertices and 15 edges and looks like a pentagram embedded in a pentagon, with each point of the pentagram joined to one corner of the pentagon.

Confronted with this result, we can ask whether the Petersen graph allows a relaxed variant of 3-colouring in which one vertex may accept duplicate colours. The answer can be found using the gradient with respect to one vertex tensor:

julia> using Zygote: gradient

julia> gradient(x->ein"afl,bhn,cjf,dlh,enj,ago,big,cki,dmk,eom->"(x,s,s,s,s,s,s,s,s,s)[], s)[1] |> sum

This tells us that even if we allow duplicate colours on one vertex, the Petersen graph has no 3-colourings.
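Why does summing the gradient answer the relaxed question? The contraction is linear in each vertex tensor, so the sum of the gradient entries with respect to one tensor equals the contraction with that tensor replaced by all ones, which removes that vertex's constraint. The sketch below checks this identity on the two-vertex example in plain Julia. It computes the gradient exactly from unit tensors instead of using Zygote, and the helpers `f` and `unit` are hypothetical.

```julia
s = map(x -> Int(length(unique(x.I)) == 3), CartesianIndices((3, 3, 3)))

# Two-vertex graph with the first vertex tensor replaced by a variable x:
# f(x) = ein"ijk,ijk->"(x, s)[] = sum over i, j, k of x[i,j,k] * s[i,j,k].
f(x) = sum(x .* s)

# f is linear in x, so the partial derivative w.r.t. x[I] is f at the unit
# tensor that is 1 at position I and 0 elsewhere.
unit(I) = setindex!(zeros(Int, 3, 3, 3), 1, I)
grad = [f(unit(I)) for I in CartesianIndices(s)]

# Summing the gradient equals f(ones): the relaxed vertex accepts any colours.
relaxed = sum(grad)
```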


Suggestions and comments in the issues are welcome.


MIT License