DataLoaders
A parallel iterator for large machine learning datasets that don't fit into memory, inspired by PyTorch's DataLoader class.
It uses ThreadPools.jl to process batches in parallel while keeping the primary thread free.
It builds on LearnBase.jl's data access pattern, so many data containers work out of the box, and custom containers are easily supported by implementing LearnBase.getobs and LearnBase.nobs.
Usage
Options
DataLoader(dataset, batchsize; kwargs...)
Arguments
- dataset: A data container supporting the LearnBase data access pattern
- batchsize::Integer = 1: Number of samples to batch together
Keyword arguments
- shuffle::Bool = true: Whether to shuffle the observations before iterating
- numworkers::Integer = max(1, Threads.nthreads() - 1): Number of workers to spawn to load data in parallel. The primary thread is kept free.
- transformfn: Function that is applied to individual samples before batching
- collatefn: Function that collates multiple samples into a batch. For the default behavior, see collate.
- droplast::Bool = false: Whether to drop the last batch when nobs(dataset) is not divisible by batchsize. true ensures all batches have the same size, but some samples might be dropped.
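For illustration, here is a minimal sketch combining the keyword arguments above. It assumes that a plain Vector works as a data container through LearnBase's built-in array support; the doubling transformfn is only a placeholder.

using DataLoaders

data = collect(1:10)                  # 10 scalar observations

dataloader = DataLoader(
    data, 4;
    shuffle = false,                  # keep the original order
    numworkers = 1,                   # a single background worker
    transformfn = x -> 2x,            # double each sample before batching
    droplast = true                   # 10 is not divisible by 4, so the 2 leftover samples are dropped
)

for batch in dataloader
    # two batches of 4 transformed samples each
end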
Simple example
dataset = ([1, 2, 3, 4, 5, 6], [2, 4, 6, 8, 10, 12])
dataloader = DataLoader(dataset, 2, numworkers = 1)
for batch in dataloader
# do your stuff
end
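Assuming the default collate keeps the (x, y) tuple structure of the samples (see collate for the actual behavior), each batch can also be destructured directly:

for (xs, ys) in dataloader
    # xs and ys are each expected to hold batchsize = 2 values per iteration
end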
Note: if your dataset fits into memory, as in this toy example, you don't need parallel workers.
Custom data container example
using Images: load
import LearnBase: getobs, nobs
# Custom data container
struct ImageDataset
files::Vector{String}
end
# Implementing `LearnBase.jl` interface
getobs(ds::ImageDataset, idx::Integer) = load(ds.files[idx])
nobs(ds::ImageDataset) = length(ds.files)
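# Optional sanity check of the interface before using a DataLoader;
# the file names here are placeholders for illustration only
ds = ImageDataset(["img_001.jpg", "img_002.jpg"])
nobs(ds)        # 2
getobs(ds, 1)   # loads and returns the first image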
# Now you can use your custom container
dataset = ImageDataset(["image1.jpg", "image2.jpg", ...])
dataloader = DataLoader(dataset, 8, shuffle = true)
for batch in dataloader
# do your stuff
end
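Per-sample preprocessing can be hooked in through the transformfn keyword. The sketch below reuses the dataset from the example above and assumes imresize from Images.jl with an arbitrary 128x128 target size, both chosen only for illustration.

using Images: imresize

# resize every loaded image before it is collated into a batch
dataloader = DataLoader(
    dataset, 8;
    shuffle = true,
    transformfn = img -> imresize(img, (128, 128))
)

for batch in dataloader
    # all images within a batch now share the same size
end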
Note: To use multiple workers (default behavior), you have to set the JULIA_NUM_THREADS
environment variable before starting your session.
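For example, you can start Julia with JULIA_NUM_THREADS=8 julia (or julia --threads 8 on recent Julia versions) and then confirm the thread count from inside the session:

Threads.nthreads()   # threads available; by default, workers use all but one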