While Julia is great, there is still a lot of useful, differentiable Python code written in PyTorch, JAX, and similar frameworks. Given that PyCall.jl already makes calling Python seamless, one might wonder what it takes to differentiate through those pycalls. This library aims for that ideal.

Thanks to @pabloferz, this works on both CPU and GPU without any array copies via DLPack.jl.

Basic Usage

PyTorch

CPU only

Install Python dependencies
using PyCall
run(`$(PyCall.pyprogramname) -m pip install torch==1.11.0+cpu -f functorch`)
using PyCallChainRules.Torch: TorchModuleWrapper, torch
using Zygote

indim = 32
outdim = 16
torch_module = torch.nn.Linear(indim, outdim) # Can be anything subclassing torch.nn.Module
jlwrap = TorchModuleWrapper(torch_module)

batchsize = 64
input = randn(Float32, indim, batchsize)
output = jlwrap(input)

target = randn(Float32, outdim, batchsize)
loss(m, x, y) = sum(m(x) .- y) # use the argument y rather than the global target
grad, = Zygote.gradient(m->loss(m, input, target), jlwrap)
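
To make it easier to judge whether the returned gradient is plausible, here is a minimal sketch in plain Julia that does not touch PyTorch or the wrapper. It uses a hand-written linear layer W * x .+ b as a stand-in (an assumption for illustration, not the library's internals) with the same sum-of-differences loss, and compares the analytic gradient against central finite differences. The names W, b, loss_ref, and fd_grad are hypothetical.

```julia
# Plain-Julia stand-in for a linear layer; small sizes keep the check fast.
indim, outdim, batchsize = 4, 3, 5
W = randn(outdim, indim)
b = randn(outdim)
x = randn(indim, batchsize)
y = randn(outdim, batchsize)

# Same form as the example loss: sum over all entries of (prediction - target).
loss_ref(W, b) = sum(W * x .+ b .- y)

# Analytic gradients of this loss:
#   dL/dW = ones(outdim, batchsize) * x'   (every output entry has weight 1)
#   dL/db = batchsize for every bias entry
gW = ones(outdim, batchsize) * x'
gb = fill(float(batchsize), outdim)

# Central finite differences, one entry at a time.
function fd_grad(f, A; h = 1e-6)
    g = similar(A)
    for i in eachindex(A)
        Ap = copy(A); Ap[i] += h
        Am = copy(A); Am[i] -= h
        g[i] = (f(Ap) - f(Am)) / (2h)
    end
    return g
end

errW = maximum(abs.(fd_grad(Wp -> loss_ref(Wp, b), W) .- gW))
errb = maximum(abs.(fd_grad(bp -> loss_ref(W, bp), b) .- gb))
@assert errW < 1e-6 && errb < 1e-6
```

The same finite-difference helper could in principle be pointed at the wrapped module's parameters, although Float32 inputs would need a larger step and a looser tolerance.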


GPU

Install Python dependencies
using PyCall
# For CUDA 11 and PyTorch 1.11
run(`$(PyCall.pyprogramname) -m pip install torch==1.11.0+cu113 -f functorch`)
using CUDA
using PyCallChainRules.Torch: TorchModuleWrapper, torch
using Zygote

@assert CUDA.functional()

indim = 32
outdim = 16
torch_module = torch.nn.Linear(indim, outdim).to(device=torch.device("cuda:0")) # Can be anything subclassing torch.nn.Module
jlwrap = TorchModuleWrapper(torch_module)

batchsize = 64
input = cu(randn(Float32, indim, batchsize)) # move the input to the GPU
output = jlwrap(input)

target = cu(randn(Float32, outdim, batchsize)) # move the target to the GPU
loss(m, x, y) = sum(m(x) .- y)
grad, = Zygote.gradient(m->loss(m, input, target), jlwrap)


JAX

CPU only

Install Python dependencies
using PyCall
run(`$(PyCall.pyprogramname) -m pip install jax\["cpu"\]`) # for cpu version
using PyCallChainRules.Jax: JaxFunctionWrapper, jax, stax, pyto_dlpack
using DLPack # needed for DLPack.wrap below
using Zygote # needed for Zygote.gradient below

batchsize = 64
indim = 32
outdim = 16

init_lin, apply_lin = stax.Dense(outdim)
_, params = init_lin(jax.random.PRNGKey(0), (-1, indim))
params_jl = map(x->DLPack.wrap(x, pyto_dlpack), params)
jlwrap = JaxFunctionWrapper(jax.jit(apply_lin))
input = randn(Float32, indim, batchsize)
output = jlwrap(params_jl, input)

target = randn(Float32, outdim, batchsize)
loss(p, x, y) = sum(jlwrap(p, x) .- y)
grad, = Zygote.gradient(p->loss(p, input, target), params_jl)


GPU

Install Python dependencies
using PyCall
run(`$(PyCall.pyprogramname) -m pip install jax\["cuda"\] -f`)
using CUDA
using DLPack # needed for DLPack.wrap below
using Zygote # needed for Zygote.gradient below
using PyCallChainRules.Jax: JaxFunctionWrapper, jax, stax, pyto_dlpack

batchsize = 64
indim = 32
outdim = 16

init_lin, apply_lin = stax.Dense(outdim)
_, params = init_lin(jax.random.PRNGKey(0), (-1, indim))
params_jl = map(x->DLPack.wrap(x, pyto_dlpack), params)
jlwrap = JaxFunctionWrapper(jax.jit(apply_lin))
input = cu(randn(Float32, indim, batchsize)) # move the input to the GPU
output = jlwrap(params_jl, input)

target = cu(randn(Float32, outdim, batchsize)) # move the target to the GPU
loss(p, x, y) = sum(jlwrap(p, x) .- y)
grad, = Zygote.gradient(p->loss(p, input, target), params_jl)

When mixing JAX and Julia, it is recommended to disable JAX's GPU memory preallocation by setting the environment variable XLA_PYTHON_CLIENT_PREALLOCATE=false.
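
One way to set this variable from within Julia is sketched below. The sketch assumes the variable has to be in the process environment before PyCall starts the Python interpreter, so it is set before any using line that loads Python; this ordering is a cautious assumption rather than something the package documents. Exporting the variable in the shell before launching Julia should work equally well.

```julia
# Set this at the very top of the script or REPL session, before PyCall
# (and therefore Python and JAX) is loaded.
ENV["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

# Only afterwards load the Python-backed packages, for example:
# using CUDA
# using PyCallChainRules.Jax: JaxFunctionWrapper, jax, stax, pyto_dlpack
```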

Current Limitations

  • Inputs and outputs of wrapped Python functions can only be Python tensors or (possibly nested) tuples of Python tensors.
  • Keyword arguments should not be arrays and do not support differentiation.
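
As an illustration of the first limitation, the following sketch checks on the Julia side whether a value is an array or a (possibly nested) tuple made up only of arrays before it is handed to a wrapper. The helper is_supported_arg is hypothetical and not part of PyCallChainRules; it only mirrors the stated restriction and makes no claim about how the package validates its inputs.

```julia
# Hypothetical helper, not part of PyCallChainRules: accepts arrays and
# (possibly nested) tuples that contain only arrays.
is_supported_arg(x::AbstractArray) = true
is_supported_arg(x::Tuple) = all(is_supported_arg, x)
is_supported_arg(x) = false

ok_array  = is_supported_arg(randn(Float32, 3))
ok_nested = is_supported_arg((randn(2, 2), (randn(2), randn(2))))
bad_dict  = is_supported_arg((randn(2), Dict(:w => randn(2))))

@assert ok_array && ok_nested && !bad_dict
```

A structure such as a Dict or NamedTuple of parameters would therefore need to be flattened into a tuple first.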

