Caution
This package should be considered deprecated and won't receive any updates. Distributed training will become a native feature of Lux, so it makes little sense for me to maintain an additional package that does the same thing. Track LuxDL/Lux.jl#494 for further updates.
Distributed Data Parallel Training of Neural Networks
Stable release:

```julia
] add FluxMPI
```

Latest development version:

```julia
] add FluxMPI#main
```
A minimal example of data-parallel training with Lux:

```julia
using CUDA, FluxMPI, Lux, Optimisers, Random, Zygote

# Initialize MPI and the FluxMPI environment
FluxMPI.Init()
CUDA.allowscalar(false)

model = Chain(Dense(1 => 256, tanh), Dense(256 => 512, tanh), Dense(512 => 256, tanh),
    Dense(256 => 1))

# Seed each rank differently so every process generates its own data
rng = Random.default_rng()
Random.seed!(rng, local_rank())

# Set up parameters/states and synchronize them across processes
ps, st = Lux.setup(rng, model) .|> gpu
ps = FluxMPI.synchronize!(ps; root_rank = 0)
st = FluxMPI.synchronize!(st; root_rank = 0)

# Toy regression data (different on each rank)
x = rand(rng, 1, 16) |> gpu
y = x .^ 2

# DistributedOptimizer sums the gradients across processes during the update
opt = DistributedOptimizer(Adam(0.001f0))
st_opt = Optimisers.setup(opt, ps)

loss(p) = sum(abs2, model(x, p, st)[1] .- y)

st_opt = FluxMPI.synchronize!(st_opt; root_rank = 0)

# Warm-up run before timing the training loop
gs_ = gradient(loss, ps)[1]
Optimisers.update(st_opt, ps, gs_)

t1 = time()

for epoch in 1:100
    global ps, st_opt
    l, back = Zygote.pullback(loss, ps)
    FluxMPI.fluxmpi_println("Epoch $epoch: Loss $l")
    gs = back(one(l))[1]
    st_opt, ps = Optimisers.update(st_opt, ps, gs)
end

FluxMPI.fluxmpi_println(time() - t1)
```
Run the code using `mpiexecjl -n 3 julia --project=. <filename>.jl`.

For larger examples, see:

- Deep Equilibrium Models -- Deep Implicit Neural Networks & Infinite Time Neural ODEs
- ImageNet Training with Lux.jl
We follow the Lux Style Guide. All contributions must adhere to this style guide.
Changelog:

- Dropped support for MPI v0.19.
- `FLUXMPI_DISABLE_CUDAMPI_SUPPORT` is no longer used. Instead, use `FluxMPI.disable_cudampi_support()` to set up a LocalPreferences.toml file (see the sketch below).
- `clean_(print/println)` functions are now `fluxmpi_(print/println)`.
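As a rough illustration, the preference can be set once from the REPL (a sketch; the note about restarting Julia is an assumption based on the usual Preferences.jl behavior and is not stated in the changelog):

```julia
using FluxMPI

# Writes the corresponding entry into LocalPreferences.toml for the active project.
# Assumption: restart Julia afterwards so the preference is picked up at package load time.
FluxMPI.disable_cudampi_support()
```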
- Dropped support for `LearnBase`, aka `DataLoaders.jl`. `DistributedDataContainer` is now only compatible with `MLUtils.jl` (see the sketch below).
- `DistributedOptimiser` name changed to `DistributedOptimizer`.
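A hedged sketch of the MLUtils.jl route, assuming `DistributedDataContainer` implements the MLUtils data-container interface so each rank iterates only its own shard (the data shape and batch size here are arbitrary):

```julia
using FluxMPI, MLUtils, Random

FluxMPI.Init()

data = randn(Random.default_rng(), Float32, 4, 128)  # 128 observations, 4 features each
ddata = DistributedDataContainer(data)               # each rank sees only its own subset

# A standard MLUtils data loader over the distributed container
dloader = DataLoader(ddata; batchsize = 16, shuffle = true)

for (i, batch) in enumerate(dloader)
    FluxMPI.fluxmpi_println("batch $i has size ", size(batch))
end
```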
- Introduces a new API for gradient synchronization (see the sketch below):
  - Don't wrap the optimizer in `DistributedOptimiser`.
  - Instead, just add a call to `allreduce_gradients(gs::NamedTuple)`.
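A sketch of that API, reusing the names from the quick-start example above and assuming `allreduce_gradients` returns the synchronized gradient `NamedTuple`:

```julia
# Plain optimizer -- no DistributedOptimizer wrapper needed with this API
st_opt = Optimisers.setup(Adam(0.001f0), ps)
st_opt = FluxMPI.synchronize!(st_opt; root_rank = 0)

for epoch in 1:100
    global ps, st_opt
    l, back = Zygote.pullback(loss, ps)
    gs = back(one(l))[1]
    gs = allreduce_gradients(gs)  # synchronize gradients across all workers
    st_opt, ps = Optimisers.update(st_opt, ps, gs)
end
```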
- Internal `MPIExtensions` functions renamed:
  - `Allreduce!` --> `allreduce!`
  - `Bcast!` --> `bcast!`
  - `Reduce!` --> `reduce!`
- CUDA-unaware MPI bug resolved (LuxDL/Lux.jl#18).
- Disable CUDA-aware MPI support from `FluxMPI` using `FLUXMPI_DISABLE_CUDAMPI_SUPPORT=true`.
- Temporarily re-added dependencies on `MLDataUtils` and `LearnBase` to ensure `DataLoaders.jl` still works -- this will be dropped in a future release.
- `DistributedOptimiser` no longer averages the gradients. Instead, the values are summed across the processes. To ensure averaging, divide the loss by `total_workers()` (see the sketch below).
- `rrule`s and `frule`s are defined for `local_rank()` and `total_workers()` -- they can now be safely used inside loss functions.
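For instance, the quick-start loss could be written as follows so that the summed gradients correspond to an average (calling `total_workers()` inside the loss is safe because of the `rrule`s mentioned above):

```julia
# Dividing by the number of workers turns the summed gradients into an average
loss(p) = sum(abs2, model(x, p, st)[1] .- y) / total_workers()
```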
- `fluxmpi_print` and `fluxmpi_println` print the current time even if `FluxMPI` has not been initialized.
- Calling `local_rank` or `total_workers` before `FluxMPI.Init` doesn't lead to a segfault. Rather, we throw an error.
- `MLDataUtils` and `LearnBase` dependencies have been dropped (see #17).
- `Zygote` and `Flux` dependencies have been removed.
  - No dispatch of `FluxMPI.synchronize!` is available for `Zygote.Params` anymore. Instead, users should manually broadcast the function over `Zygote.Params` (see the sketch below).
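A rough sketch of that manual synchronization for Flux-style implicit parameters, assuming `FluxMPI.synchronize!` updates plain parameter arrays in place when called on them directly:

```julia
using Flux, FluxMPI

FluxMPI.Init()

model = Flux.Chain(Flux.Dense(1 => 16, tanh), Flux.Dense(16 => 1))
ps = Flux.params(model)  # Zygote.Params

# Synchronize each parameter array individually instead of the whole Params object
foreach(p -> FluxMPI.synchronize!(p; root_rank = 0), ps)
```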
- `broadcast_parameters` has been renamed to `FluxMPI.synchronize!` since it synchronizes a lot more than just the trainable parameters now.
- `DistributedOptimiser` is no longer tied to Flux. It can handle essentially any training setup that is compatible with Optimisers.jl.