MLJFlux.jl

Wrapping deep learning models from the package Flux.jl for use in the MLJ.jl toolbox
145 stars, last updated 3 months ago, started in June 2019.

An interface to Flux deep learning models for the MLJ machine learning framework

Code Snippet

using MLJ, MLJFlux, RDatasets, Plots

Grab some data and split into features and target:

iris = RDatasets.dataset("datasets", "iris");
y, X = unpack(iris, ==(:Species), rng=123);
X = Float32.(X);      # To optimise for GPUs (Flux uses Float32 by default)
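The shuffle-and-split step can be sketched in plain Julia. This is a minimal sketch, not MLJ's `unpack`: the table is assumed to be a NamedTuple of equal-length columns, and `split_target` is a hypothetical helper. It shows the two ideas at work: every column is permuted with the same seeded row order so rows stay aligned, and the feature columns are converted to `Float32`, the element type Flux layers use by default. The exact permutation will differ from the one `unpack(...; rng=123)` produces.

```julia
using Random

# Hypothetical helper (not part of MLJ): split a column table into a target
# vector and Float32 feature columns, shuffling rows with a seeded RNG.
function split_target(table::NamedTuple, target::Symbol; rng=MersenneTwister(123))
    n = length(first(table))
    perm = randperm(rng, n)  # one row order shared by every column
    y = table[target][perm]
    feature_names = Tuple(k for k in keys(table) if k != target)
    features = Tuple(Float32.(table[k][perm]) for k in feature_names)
    return y, NamedTuple{feature_names}(features)
end

# Usage on a toy table with two numeric features and a label column
toy = (a = [1.0, 2.0, 3.0], b = [4.0, 5.0, 6.0], label = ["x", "y", "z"])
y, X = split_target(toy, :label)
```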

Load model code and instantiate an MLJFlux model:

NeuralNetworkClassifier = @load NeuralNetworkClassifier pkg=MLJFlux

clf = NeuralNetworkClassifier(
    builder=MLJFlux.MLP(; hidden=(5,4)),
    batch_size=8,
    epochs=50,
    acceleration=CUDALibs()  # train on a CUDA GPU; omit this line to train on the CPU
)
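To make `MLP(; hidden=(5,4))` concrete for iris (4 features, 3 species), here is a plain-Julia sketch of the network shape it describes. It assumes each layer is a dense layer with a bias, hidden layers use a ReLU activation, and the classifier turns the final scores into probabilities with a softmax; `layer_sizes`, `count_params` and `forward` are hypothetical helpers, not MLJFlux or Flux functions.

```julia
# Hypothetical sketch (plain Julia, no Flux) of a dense network with the
# shape MLP(hidden=(5, 4)) describes for iris: 4 inputs, 3 classes.
relu(x) = max(zero(x), x)

function softmax(z::AbstractVector)
    e = exp.(z .- maximum(z))  # shift by the max for numerical stability
    return e ./ sum(e)
end

# Layer widths from input to output, e.g. [4, 5, 4, 3]
layer_sizes(n_in, hidden, n_out) = [n_in; collect(hidden); n_out]

# Each dense layer has a weight matrix (out x in) plus a bias vector (out)
count_params(sizes) = sum(sizes[i] * sizes[i+1] + sizes[i+1] for i in 1:length(sizes)-1)

# Forward pass: ReLU on hidden layers, no activation on the last dense
# layer, then softmax to turn scores into class probabilities
function forward(x::AbstractVector, Ws, bs)
    h = x
    for (k, (W, b)) in enumerate(zip(Ws, bs))
        z = W * h .+ b
        h = k < length(Ws) ? relu.(z) : z
    end
    return softmax(h)
end

sizes = layer_sizes(4, (5, 4), 3)
println(count_params(sizes))  # 4*5+5 + 5*4+4 + 4*3+3 = 64
```

So even this small network has 64 trainable parameters, which is why a few hundred epochs on 150 rows can be enough for the losses to settle.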

Wrap the model in iteration controls:

stop_conditions = [
    Step(1),            # Apply controls every epoch
    NumberLimit(1000),  # Don't train for more than 1000 steps
    Patience(4),        # Stop after 4 consecutive increases in validation loss
    NumberSinceBest(5), # Or if the best loss occurred 5 iterations ago
    TimeLimit(30/60),   # Or if 30 minutes have passed (the limit is given in hours)
]
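To see how `Patience` and `NumberSinceBest` differ, here is a hypothetical re-implementation of the two rules as we understand them from IterationControl.jl's descriptions: `Patience(n)` stops after n consecutive increases in the loss, and `NumberSinceBest(n)` stops when the best loss occurred n control cycles ago. Both are checked at the latest cycle. This is a sketch, not the package's code; the function names are made up.

```julia
# Hypothetical sketches of two stopping rules. `losses` holds one
# out-of-sample loss per control cycle; each rule is checked at the latest cycle.

# Patience(n): fire when the last n cycles were all increases.
function patience_triggered(losses::AbstractVector, n::Int)
    run = 0
    i = length(losses)
    while i > 1 && losses[i] > losses[i-1]
        run += 1
        i -= 1
    end
    return run >= n
end

# NumberSinceBest(n): fire when the best loss so far is n or more cycles old.
function number_since_best_triggered(losses::AbstractVector, n::Int)
    isempty(losses) && return false
    return length(losses) - argmin(losses) >= n
end

# A noisy plateau: no long run of increases, but no new best either
plateau = [1.0, 0.8, 0.9, 0.85, 0.95, 0.9, 0.99]
println(patience_triggered(plateau, 4))           # false
println(number_since_best_triggered(plateau, 5))  # true
```

On a noisy plateau the loss rarely rises several times in a row, so `Patience` can stay silent while `NumberSinceBest` still notices that no improvement has happened for a while.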

validation_losses = []
train_losses = []
callbacks = [
    WithLossDo(loss->push!(validation_losses, loss)),
    WithTrainingLossesDo(losses->push!(train_losses, losses[end])),
]
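The callbacks are ordinary closures that push into vectors defined outside them. Below is a sketch with a hypothetical driver, `simulate_cycles!`, standing in for the iteration loop. It assumes the training-loss callback is handed a vector of losses (here, as an assumption, the history so far) and that only its newest entry should be recorded, which is what `losses[end]` does.

```julia
# Closures capture the outer vectors and push into them in place
validation_losses = Float64[]
train_losses = Float64[]
on_loss = loss -> push!(validation_losses, loss)
on_training_losses = losses -> push!(train_losses, losses[end])

# Hypothetical driver: one call of each callback per control cycle. The
# training-loss callback receives a vector, the loss callback a scalar.
function simulate_cycles!(val_per_cycle, train_per_cycle)
    history = Float64[]
    for (v, t) in zip(val_per_cycle, train_per_cycle)
        push!(history, t)
        on_loss(v)
        on_training_losses(history)
    end
    return nothing
end

simulate_cycles!([0.9, 0.7, 0.6], [0.8, 0.5, 0.4])
```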

iterated_model = IteratedModel(
    model=clf,
    resampling=Holdout(fraction_train=0.5), # losses and stopping are based on held-out (out-of-sample) data
    measures=log_loss,
    controls=vcat(stop_conditions, callbacks),
);
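`Holdout(fraction_train=0.5)` and `measures=log_loss` together determine the number the stopping rules watch. A minimal sketch with hypothetical helper names: the first half of the (already shuffled) rows is used for training and the rest for evaluation, and the multiclass log loss is the mean negative log of the probability assigned to the true class. Rounding the split size and clamping probabilities away from 0 and 1 are assumptions of this sketch; for 150 rows and 0.5 no rounding is needed.

```julia
# Hypothetical helper: split row indices into a training part and a
# held-out part, in order (the rows were already shuffled by unpack).
function holdout_rows(n::Int, fraction_train::Real)
    ntrain = round(Int, fraction_train * n)
    return 1:ntrain, (ntrain + 1):n
end

# Mean multiclass log loss. `probs[i, c]` is the predicted probability of
# class c for row i, and `y[i]` is the index of the true class of row i.
function mean_log_loss(probs::AbstractMatrix, y::AbstractVector{<:Integer}; tol=eps())
    total = 0.0
    for (i, c) in enumerate(y)
        p = clamp(probs[i, c], tol, 1 - tol)  # keep log(p) finite
        total -= log(p)
    end
    return total / length(y)
end

train_rows, test_rows = holdout_rows(150, 0.5)  # 1:75 and 76:150
```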

Train the wrapped model:

julia> mach = machine(iterated_model, X, y)
julia> fit!(mach)

[ Info: No iteration parameter specified. Using `iteration_parameter=:(epochs)`. 
[ Info: final loss: 0.1284184007796247
[ Info: final training loss: 0.055630706
[ Info: Stop triggered by NumberSinceBest(5) stopping criterion. 
[ Info: Total of 811 iterations. 
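One way to read the reported `final loss`: log loss averages -log p over the held-out rows, where p is the probability given to the true class, so exp(-loss) is the geometric mean of those probabilities. The arithmetic, with the value copied from the log above:

```julia
# exp(-mean(-log p)) equals the geometric mean of the p values
geometric_mean_true_prob(mean_log_loss) = exp(-mean_log_loss)

final_loss = 0.1284184007796247  # copied from the log output above
println(round(geometric_mean_true_prob(final_loss); digits=3))  # 0.879
```

In other words, on held-out flowers the model puts roughly 88% probability on the correct species, in the geometric-mean sense.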

Inspect results:

julia> plot(train_losses, label="Training Loss")
julia> plot!(validation_losses, label="Validation Loss", linewidth=2, size=(800,400))
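Beyond the plot, the recorded vectors can be queried directly. A small sketch with a hypothetical helper, `best_cycle`, assuming both vectors receive one entry per control cycle (both callbacks fire on every cycle) so their indices line up. The numbers in the usage line are made up for illustration.

```julia
# Hypothetical helper: locate the control cycle with the lowest validation
# loss and report the train/validation gap at that cycle.
function best_cycle(validation_losses::AbstractVector, train_losses::AbstractVector)
    length(validation_losses) == length(train_losses) ||
        throw(ArgumentError("loss vectors must have one entry per cycle"))
    k = argmin(validation_losses)
    return (cycle = k,
            validation = validation_losses[k],
            train = train_losses[k],
            gap = validation_losses[k] - train_losses[k])
end

# Made-up curves, for illustration only
result = best_cycle([0.9, 0.5, 0.4, 0.45], [0.8, 0.4, 0.2, 0.1])
```

A gap that keeps widening after the best cycle, with training loss still falling, is the usual sign of overfitting that the stopping rules above are meant to catch.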