Flexible and performant GEMM kernels in Julia
This package contains a framework to instantiate flexible, performant GEMM (General Matrix Multiplication) kernels. You can use this framework to define your own GEMM kernels, or use one of the predefined interfaces that this package also provides.
The package can be installed using Julia's built-in package manager.
Open the Julia REPL, type ] to enter Pkg mode, and run:
pkg> add GemmKernels
Most people will be interested in the BLAS-like interface that is available as GemmKernels.mul!:
julia> using GemmKernels, CUDA
julia> A = CUDA.rand(2048, 2048)
julia> B = CUDA.rand(2048, 2048)
julia> C = CUDA.zeros(2048, 2048)
julia> GemmKernels.mul!(C, A, B)
For more control, e.g. to use optimized layouts or to fuse the multiplication with a bias, you need to use the low-level GemmKernels.matmul interface (see the examples directory).
The kernels in this package are expected to deliver around 50% to 80% of the performance of state-of-the-art libraries such as cuBLAS and CUTLASS. The exact performance depends on the specific invocation (e.g. the size of the matrices and the data type) and on the GPU architecture.
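Comparisons like these are usually made on throughput. A GEMM of an M x K matrix with a K x N matrix performs M·N·K multiply-adds, i.e. 2·M·N·K floating-point operations. The helpers below are a minimal sketch for turning measured run times into TFLOP/s and into a fraction of a reference library's performance; gemm_tflops and relative_performance are hypothetical names, not part of GemmKernels, and the timings in the example are made up for illustration.

```julia
# Hypothetical helpers (not part of GemmKernels) for comparing GEMM timings.

# Throughput in TFLOP/s of C (M x N) = A (M x K) * B (K x N) that took `seconds`.
# Each of the M*N outputs needs K multiply-adds, i.e. 2K floating-point operations.
gemm_tflops(M, N, K, seconds) = 2.0 * M * N * K / seconds / 1e12

# Fraction of a reference library's performance reached, given both run times.
relative_performance(t_ours, t_reference) = t_reference / t_ours

# Example with made-up timings: a 2048^3 GEMM taking 1.5 ms vs. 1.0 ms.
println(gemm_tflops(2048, 2048, 2048, 1.5e-3))
println(relative_performance(1.5e-3, 1.0e-3))
```

A relative performance of 0.5 to 0.8 corresponds to the 50% to 80% range quoted above.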
For example, on an NVIDIA RTX 2080 Ti, we can achieve competitive performance for a mixed-precision multiplication with FP16 inputs and FP32 output.
The GEMM kernels above are implemented using a framework that decomposes GEMM kernels into orthogonal components:
- Params determine the tiling size and launch configuration of the GEMM kernel. The tiling sizes are specified in logical coordinates, i.e. with a meaning specified by the user.
- Layouts convert the logical coordinates of tiles to physical offsets in memory.
- Transforms are used to apply arbitrary Julia functors to the GEMM's inputs or outputs. They are applied after every load and before every store.
- Operators are responsible for performing the matrix multiplication itself. They load tiles from shared memory, perform the matrix multiplication, and store the resulting tile back to shared memory.
- Epilogues copy tiles of the resultant matrix to global memory, and can be used to implement arbitrary post-processing, such as adding a bias vector to the resultant matrix.
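The division of labour between these components can be illustrated with a small CPU-only sketch. Every name in it (TileParams, colmajor_offset, tile_mma!, tiled_gemm) is hypothetical and chosen for illustration; this is not GemmKernels' API. It only mirrors the roles listed above, including a bias-adding epilogue, and makes no attempt to model shared memory or GPU threads.

```julia
# CPU-only sketch; all names are hypothetical stand-ins, not GemmKernels code.

# "Params": tile sizes in logical (row, column, reduction) coordinates.
struct TileParams
    tm::Int
    tn::Int
    tk::Int
end

# "Layout": logical (i, j) -> linear offset into column-major storage.
colmajor_offset(i, j, nrows) = (j - 1) * nrows + i

# "Operator": multiply-accumulate one tile, c += a * b.
function tile_mma!(c, a, b)
    for j in axes(c, 2), k in axes(a, 2), i in axes(c, 1)
        c[i, j] += a[i, k] * b[k, j]
    end
    return c
end

function tiled_gemm(A::Matrix, B::Matrix, bias::Vector;
                    p::TileParams = TileParams(4, 4, 4),
                    transform_in = identity, transform_out = identity)
    M, K = size(A)
    K2, N = size(B)
    K == K2 || throw(DimensionMismatch("inner dimensions differ"))
    length(bias) == M || throw(DimensionMismatch("bias length must equal M"))
    T = float(promote_type(eltype(A), eltype(B)))
    out = Vector{T}(undef, M * N)              # flat column-major output buffer
    for j0 in 1:p.tn:N, i0 in 1:p.tm:M
        rows = i0:min(i0 + p.tm - 1, M)
        cols = j0:min(j0 + p.tn - 1, N)
        acc = zeros(T, length(rows), length(cols))   # accumulator tile
        for k0 in 1:p.tk:K
            ks = k0:min(k0 + p.tk - 1, K)
            a = transform_in.(A[rows, ks])     # "transform" applied after load
            b = transform_in.(B[ks, cols])
            tile_mma!(acc, a, b)
        end
        # "Epilogue": add the bias, apply the output transform before the
        # store, and write through the layout into the flat buffer.
        for (jj, j) in enumerate(cols), (ii, i) in enumerate(rows)
            out[colmajor_offset(i, j, M)] = transform_out(acc[ii, jj] + bias[i])
        end
    end
    return reshape(out, M, N)
end
```

For instance, tiled_gemm(A, B, bias; transform_out = x -> max(x, 0.0)) fuses a bias and a ReLU into the multiplication, which is the kind of fusion the low-level interface is meant for.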
Each of these components corresponds to a set of functions with a predetermined interface. These functions can be customised by the user through Julia's multiple dispatch functionality.
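As a sketch of how multiple dispatch enables this kind of customisation, the example below models layouts as types and specialises an offset function per layout. AbstractLayout, ColMajor, RowMajor, PaddedColMajor, physical_offset and load are hypothetical names for illustration only; GemmKernels' real layout interface may differ.

```julia
# Hypothetical sketch (not GemmKernels' actual interface): each layout is a
# type, and the logical-to-physical mapping is specialised by dispatch.
abstract type AbstractLayout end
struct ColMajor <: AbstractLayout end
struct RowMajor <: AbstractLayout end

# 1-based linear offset of logical element (i, j) in an nrows x ncols matrix.
physical_offset(::ColMajor, i, j, nrows, ncols) = (j - 1) * nrows + i
physical_offset(::RowMajor, i, j, nrows, ncols) = (i - 1) * ncols + j

# Generic code is written once against the interface; dispatch picks the method.
load(layout::AbstractLayout, data::AbstractVector, i, j, nrows, ncols) =
    data[physical_offset(layout, i, j, nrows, ncols)]

# A user-defined layout only needs a new type and one new method:
# column-major storage with a leading dimension `ld` larger than nrows.
struct PaddedColMajor <: AbstractLayout
    ld::Int
end
physical_offset(l::PaddedColMajor, i, j, nrows, ncols) = (j - 1) * l.ld + i
```

The generic load function never changes; adding PaddedColMajor is enough for all code written against AbstractLayout to support the new layout.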
The package currently provides two main operators, both for NVIDIA GPUs:
- WMMAOperator: for using Tensor Cores through the WMMA APIs;
- FPUOperator: for other data types or input sizes.
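Tensor Cores operate on fixed-size fragments rather than on individual elements. The CPU emulation below sketches how a WMMA-style operator might process a tile, assuming the m16n16k16 fragment shape (one of the shapes WMMA supports) with Float16 inputs and Float32 accumulation. fragment_mma! is a hypothetical name, and this is not how WMMAOperator is implemented; it only shows why tile sizes must be multiples of the fragment shape.

```julia
# CPU emulation (hypothetical, not GemmKernels code) of a Tensor Core style
# operator: the tile is split into 16x16x16 fragments, the inputs are Float16,
# and accumulation happens in Float32.
const FM, FN, FK = 16, 16, 16

function fragment_mma!(C::Matrix{Float32}, A::Matrix{Float16}, B::Matrix{Float16})
    M, K = size(A)
    N = size(B, 2)
    @assert M % FM == 0 && N % FN == 0 && K % FK == 0 "tile must be a multiple of the fragment shape"
    for j in 0:FN:N-1, i in 0:FM:M-1
        c = C[i+1:i+FM, j+1:j+FN]              # load accumulator fragment
        for k in 0:FK:K-1
            a = Float32.(A[i+1:i+FM, k+1:k+FK])  # widen inputs for the product
            b = Float32.(B[k+1:k+FK, j+1:j+FN])
            c .+= a * b                          # one fragment multiply-accumulate
        end
        C[i+1:i+FM, j+1:j+FN] .= c             # store accumulator fragment
    end
    return C
end
```

Inputs whose sizes or element types do not fit such fragments are the case the FPUOperator covers.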
Optimized layouts are available for diagonal matrices and matrices of complex/dual numbers.
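To illustrate why such layouts help, the sketch below shows two standard identities, assuming a split (planar) storage of complex numbers and a diagonal stored as a vector. split_complex_gemm and diagonal_gemm are hypothetical names; whether GemmKernels' layouts work exactly this way is an assumption, and the code only demonstrates the underlying arithmetic.

```julia
using LinearAlgebra

# Hypothetical sketches (not GemmKernels code) of why specialised layouts help.

# Complex numbers in a split ("planar") layout: real and imaginary parts are
# separate real matrices, so one complex GEMM becomes four real GEMMs that
# real-valued hardware can execute.
function split_complex_gemm(Ar, Ai, Br, Bi)
    Cr = Ar * Br - Ai * Bi   # real part
    Ci = Ar * Bi + Ai * Br   # imaginary part
    return Cr, Ci
end

# A diagonal matrix only needs its diagonal d stored; multiplying it with B
# scales row i of B by d[i], so no dense reduction loop is needed.
diagonal_gemm(d::AbstractVector, B::AbstractMatrix) = d .* B
```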
For more details on the implementation and performance results, please see our accompanying paper (pre-print available on arXiv). The CITATION.bib file in the root of this repository contains a citation in BibTeX format.