Wrapping deep learning models from the package Flux.jl for use in the MLJ.jl toolbox
Started in June 2019


An interface to the Flux deep learning models for the MLJ machine learning framework.


MLJFlux makes it possible to apply the machine learning meta-algorithms provided by MLJ - such as out-of-sample performance evaluation, hyper-parameter optimization, and iteration control - to some classes of supervised deep learning models. It does this by providing an interface to the Flux framework.

The guiding vision of this package is to make evaluating and optimizing basic Flux models more convenient for users already familiar with the MLJ workflow. This goal will likely place restrictions on the class of Flux models that can be used, at least in the medium term. For example, online learning, reinforcement learning, and adversarial networks are currently out of scope.

Currently, MLJFlux is also limited to training models when all training data fits into memory.

Basic idea

Each MLJFlux model has a builder hyperparameter, an object encoding instructions for creating a neural network given the data that the model eventually sees (e.g., the number of classes in a classification problem). While each MLJFlux model has a simple default builder, users will generally need to define their own builders to get good results, and this requires familiarity with the Flux API for defining a neural network chain.

In the future MLJFlux may provide a larger assortment of canned builders. Pull requests introducing new ones are most welcome.


Installation

using Pkg
Pkg.activate("my_environment", shared=true)
Pkg.add("MLJFlux")
Pkg.add("MLJ")
Pkg.add("RDatasets")  # for the demo below


The following introductory example uses a default builder and does not standardize the input features.

For an example implementing early stopping and snapshots, using MLJ's IteratedModel wrapper, see the MNIST dataset example.

Loading some data and instantiating a model

using MLJ
import RDatasets
iris = RDatasets.dataset("datasets", "iris");
y, X = unpack(iris, ==(:Species), colname -> true, rng=123);
NeuralNetworkClassifier = @load NeuralNetworkClassifier

julia> clf = NeuralNetworkClassifier()
NeuralNetworkClassifier(
	builder = Short(
			n_hidden = 0,
			dropout = 0.5,
			σ = NNlib.σ),
	finaliser = NNlib.softmax,
	optimiser = ADAM(0.001, (0.9, 0.999), IdDict{Any,Any}()),
	loss = Flux.crossentropy,
	epochs = 10,
	batch_size = 1,
	lambda = 0.0,
	alpha = 0.0,
	optimiser_changes_trigger_retraining = false) @ 160

Incremental training

import Random.seed!; seed!(123)
mach = machine(clf, X, y)
fit!(mach)

julia> training_loss = cross_entropy(predict(mach, X), y) |> mean

# Increasing learning rate and adding iterations:
clf.optimiser.eta = clf.optimiser.eta * 2
clf.epochs = clf.epochs + 5

julia> fit!(mach, verbosity=2)
[ Info: Updating Machine{NeuralNetworkClassifier{Short,…},…} @804.
[ Info: Loss is 0.8686
[ Info: Loss is 0.8228
[ Info: Loss is 0.7706
[ Info: Loss is 0.7565
[ Info: Loss is 0.7347
Machine{NeuralNetworkClassifier{Short,…},…} @804 trained 2 times; caches data
	1:  Source @985`Table{AbstractVector{Continuous}}`
	2:  Source @367`AbstractVector{Multiclass{3}}`

julia> training_loss = cross_entropy(predict(mach, X), y) |> mean

Accessing the Flux chain (model)

julia> fitted_params(mach).chain
Chain(Chain(Dense(4, 3, σ), Flux.Dropout{Float64}(0.5, false), Dense(3, 3)), softmax)

Evolution of out-of-sample performance

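r = range(clf, :epochs, lower=1, upper=200, scale=:log10)
curve = learning_curve(clf, X, y, range=r, measure=cross_entropy)
using Plots
plot(curve.parameter_values, curve.measurements,
	 xlab=curve.parameter_name, xscale=curve.parameter_scale,
	 ylab = "Cross Entropy")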


In MLJ a model is a mutable struct storing hyperparameters for some learning algorithm indicated by the model name, and that's all. In particular, an MLJ model does not store learned parameters.

Warning: In Flux the term "model" has another meaning. However, as all Flux "models" used in MLJFlux are Flux.Chain objects, we call them chains and restrict use of "model" to models in the MLJ sense.

MLJFlux provides four model types, for use with input features X and targets y of the scientific type indicated in the table below. The parameters n_in, n_out and n_channels refer to information passed to the builder, as described under Defining a new builder below.

model type                          | prediction type | scitype(X) <: _                                  | scitype(y) <: _
NeuralNetworkRegressor              | Deterministic   | Table(Continuous) with n_in columns              | AbstractVector{<:Continuous} (n_out = 1)
MultitargetNeuralNetworkRegressor   | Deterministic   | Table(Continuous) with n_in columns              | Table(Continuous) with n_out columns
NeuralNetworkClassifier             | Probabilistic   | Table(Continuous) with n_in columns              | AbstractVector{<:Finite} with n_out classes
ImageClassifier                     | Probabilistic   | AbstractVector{<:Image{W,H}} with n_in = (W, H)  | AbstractVector{<:Finite} with n_out classes

Table 1. Input and output types for MLJFlux models

Non-tabular input

Any AbstractMatrix{<:AbstractFloat} object Xmat can be forced to have scitype Table(Continuous) by replacing it with X = MLJ.table(Xmat). This wrapping, and the unwrapping that happens under the hood, compiles to a no-op. Sparse matrices are also supported, but the implementation has not yet been optimized for sparse data, so use it with caution.
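A minimal sketch of this wrapping, assuming only MLJ is installed; the random matrix is a stand-in for real feature data.

```julia
using MLJ  # provides `table`, `matrix`, `scitype` and `Table`

# Stand-in feature matrix: 100 observations (rows) by 4 features (columns)
Xmat = rand(Float32, 100, 4)

# Wrap the matrix so that it has scitype Table(Continuous)
X = MLJ.table(Xmat)
println(scitype(X))

# Converting back to a matrix recovers the original values
Xback = MLJ.matrix(X)
```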

Instructions for coercing common image formats into some AbstractVector{<:Image} are here.

Warm restart

MLJ machines cache state enabling the "warm restart" of model training, as demonstrated in the example above. In the case of MLJFlux models, fit!(mach) will use a warm restart if:

  • only model.epochs has changed since the last call; or

  • only model.epochs or model.optimiser have changed since the last call and model.optimiser_changes_trigger_retraining == false (the default) (the "state" part of the optimiser is ignored in this comparison). This allows one to dynamically modify learning rates, for example.

Here model=mach.model is the associated MLJ model.
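The rules above can be seen in a minimal sketch like the following. It assumes MLJ and MLJFlux are installed and uses synthetic data from MLJ's make_blobs in place of a real dataset; the hyperparameter values are arbitrary. With the default verbosity, MLJ logs whether each fit! call trains the machine from scratch or updates it.

```julia
using MLJ
import MLJFlux

# Synthetic data (an assumption for illustration): 100 points, 2 features, 3 classes
X, y = make_blobs(100, 2; centers=3, rng=123)

clf = MLJFlux.NeuralNetworkClassifier(epochs=5, rng=123)
mach = machine(clf, X, y)
fit!(mach, verbosity=0)     # first call: cold start, 5 epochs

clf.epochs = 8              # only `epochs` has changed ...
fit!(mach, verbosity=0)     # ... so this should be a warm restart (3 more epochs)

clf.batch_size = 10         # a different hyperparameter has changed ...
fit!(mach, verbosity=0)     # ... so this should retrain from scratch
```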

The warm restart feature makes it possible to apply early stopping criteria, as defined in EarlyStopping.jl. For an example, see /examples/mnist/. MLJ's IteratedModel wrapper, mentioned above, provides this kind of control for arbitrary iterative models.
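A minimal sketch of early stopping through MLJ's IteratedModel wrapper, assuming MLJ (which re-exports the wrapper and its controls) and MLJFlux are installed. The synthetic make_blobs data and the particular controls chosen (Step, Patience, NumberLimit) are illustrative assumptions, not the settings used in /examples/mnist/.

```julia
using MLJ
import MLJFlux

# Synthetic data (an assumption for illustration)
X, y = make_blobs(200, 2; centers=3, rng=123)

clf = MLJFlux.NeuralNetworkClassifier(rng=123)

# Training is driven one epoch at a time, monitoring out-of-sample cross entropy
iterated_clf = IteratedModel(model=clf,
                             resampling=Holdout(fraction_train=0.7),
                             measure=cross_entropy,
                             controls=[Step(1),          # train one more epoch per cycle
                                       Patience(5),      # stop after 5 consecutive loss increases
                                       NumberLimit(50)]) # stop after at most 50 cycles

mach = machine(iterated_clf, X, y)
fit!(mach, verbosity=0)
```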

Training on a GPU

When instantiating a model for training on a GPU, specify acceleration=CUDALibs(), as in

using MLJ
ImageClassifier = @load ImageClassifier
model = ImageClassifier(epochs=10, acceleration=CUDALibs())
mach = machine(model, X, y) |> fit!

In this example, the data X, y is copied onto the GPU under the hood on the call to fit! and cached for use in any warm restart (see above). The Flux chain used in training is always copied back to the CPU at the conclusion of fit!, and made available as fitted_params(mach).chain.

Random number generators and reproducibility

Every MLJFlux model includes an rng hyper-parameter that is passed to builders for the purposes of weight initialization. This can be any AbstractRNG or the seed (integer) for a MersenneTwister that will be reset on every cold restart of model (machine) training.

Until there is a mechanism for doing so, rng is not passed to dropout layers, and one must manually seed the GLOBAL_RNG for reproducibility when using a builder that includes Dropout (such as MLJFlux.Short). If training models on a GPU (i.e., acceleration isa CUDALibs), one must additionally call CUDA.seed!(...).
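The effect of rng on weight initialization can be checked by calling a builder directly, as in this sketch. It assumes MLJFlux.build takes the RNG as its second argument, following the builder signatures described under Defining a new builder below. It does not cover dropout: as noted above, reproducible dropout additionally needs something like Random.seed!(123) before calling fit!.

```julia
using Random
import Flux
import MLJFlux

# The Short builder with 5 hidden units (an arbitrary choice)
builder = MLJFlux.Short(n_hidden=5, dropout=0.5)

# Building twice from identically seeded RNGs should give identical initial weights
chain1 = MLJFlux.build(builder, MersenneTwister(123), 4, 3)
chain2 = MLJFlux.build(builder, MersenneTwister(123), 4, 3)
weights1, _ = Flux.destructure(chain1)
weights2, _ = Flux.destructure(chain2)

# A differently seeded RNG gives (almost surely) different weights
weights3, _ = Flux.destructure(MLJFlux.build(builder, MersenneTwister(456), 4, 3))
```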

Built-in builders

The following builders are provided out-of-the-box. Query their doc-strings for advanced options and further details.

builder                                              | description
MLJFlux.Linear(σ=relu)                               | vanilla linear network with activation function σ
MLJFlux.Short(n_hidden=0, dropout=0.5, σ=sigmoid)    | fully connected network with one hidden layer and dropout
MLJFlux.MLP(hidden=(10,))                            | general multi-layer perceptron
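To see what each built-in builder produces, the following sketch asks each one for a chain, much as an MLJFlux model does once it knows n_in and n_out, and applies it to a random batch. The sizes are hypothetical (4 input features, 3 outputs), and the call MLJFlux.build(builder, rng, n_in, n_out) follows the signature described under Defining a new builder below.

```julia
using Random
import MLJFlux

rng = MersenneTwister(0)
n_in, n_out = 4, 3    # e.g. 4 input features and 3 classes

# One chain per built-in builder
chains = (
    MLJFlux.build(MLJFlux.Linear(), rng, n_in, n_out),
    MLJFlux.build(MLJFlux.Short(n_hidden=8), rng, n_in, n_out),
    MLJFlux.build(MLJFlux.MLP(hidden=(16, 8)), rng, n_in, n_out),
)

# Apply each chain to a batch of 10 observations (one per column)
x = rand(Float32, n_in, 10)
output_sizes = [size(chain(x)) for chain in chains]
```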

Model hyperparameters

All models share the following hyper-parameters:

  1. builder: Default = MLJFlux.Linear(σ=Flux.relu) (regressors) or MLJFlux.Short(n_hidden=0, dropout=0.5, σ=Flux.σ) (classifiers)

  2. optimiser: The optimiser to use for training. Default = Flux.ADAM()

  3. loss: The loss function used for training. Default = Flux.mse (regressors) and Flux.crossentropy (classifiers)

  4. epochs: Number of epochs to train for. Default = 10

  5. batch_size: The batch_size for the data. Default = 1

  6. lambda: The regularization strength. Default = 0. Range = [0, ∞)

  7. alpha: The L2/L1 mix of regularization. Default = 0. Range = [0, 1]

  8. rng: The random number generator (RNG) passed to builders, for weight initialization, for example. Can be any AbstractRNG or the seed (integer) for a MersenneTwister that is reset on every cold restart of model (machine) training. Default = GLOBAL_RNG.

  9. acceleration: Use CUDALibs() for training on GPU; default is CPU1().

  10. optimiser_changes_trigger_retraining: True if fitting an associated machine should trigger retraining from scratch whenever the optimiser changes. Default = false
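The following sketch sets several of these hyperparameters explicitly on a regressor. The data are random stand-ins and the values are arbitrary. Note that the keyword for the number of epochs is epochs, matching the clf display in the introductory example.

```julia
using MLJ
import MLJFlux
import Flux

# Stand-in regression data: 50 rows, 3 continuous features
X = MLJ.table(rand(Float32, 50, 3))
y = rand(Float32, 50)

model = MLJFlux.NeuralNetworkRegressor(
    builder    = MLJFlux.MLP(hidden=(16,)),
    optimiser  = Flux.ADAM(0.01),   # learning rate 0.01
    loss       = Flux.mse,
    epochs     = 20,
    batch_size = 8,
    lambda     = 0.01,              # regularization strength
    alpha      = 0.5,               # even mix of L1 and L2 penalties
    rng        = 42)                # integer seed for weight initialization

mach = machine(model, X, y)
fit!(mach, verbosity=0)
ŷ = predict(mach, X)
```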

The classifiers have an additional hyperparameter finaliser (default = Flux.softmax), which is the operation applied to the unnormalized output of the final layer to obtain probabilities (outputs summing to one). It should return a vector of the same length as its input.
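To see what the default finaliser does, this sketch applies Flux.softmax to a vector of made-up unnormalized scores for three classes:

```julia
import Flux

z = Float32[2.0, 1.0, 0.1]   # unnormalized scores for 3 classes
p = Flux.softmax(z)          # probabilities: same length as z, summing to one
```

The largest score keeps the largest probability, so the ordering of the classes is unchanged.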

Defining a new builder

Following is an example defining a new builder for creating a simple fully-connected neural network with two hidden layers, with n1 nodes in the first hidden layer, and n2 nodes in the second, for use in any of the first three models in Table 1. The definition includes one mutable struct and one method:

using Flux
import MLJFlux

mutable struct MyBuilder <: MLJFlux.Builder
	n1 :: Int
	n2 :: Int
end

function MLJFlux.build(nn::MyBuilder, rng, n_in, n_out)
	init = Flux.glorot_uniform(rng)
	return Chain(Dense(n_in, nn.n1, init=init),
				 Dense(nn.n1, nn.n2, init=init),
				 Dense(nn.n2, n_out, init=init))
end

Note here that n_in and n_out depend on the size of the data (see Table 1).

For a concrete image classification example, see examples/mnist.

More generally, defining a new builder means defining a new struct sub-typing MLJFlux.Builder and defining a new method with one of these signatures:

MLJFlux.build(builder::MyBuilder, rng, n_in, n_out)
MLJFlux.build(builder::MyBuilder, rng, n_in, n_out, n_channels) # for use with `ImageClassifier`

This method must return a Flux.Chain instance, chain, subject to the following conditions:

  • chain(x) must make sense:

    • for any x isa Array{<:AbstractFloat, 2} of size (n_in, batch_size), where batch_size is any integer (for use with one of the first three model types); or

    • for any x isa Array{<:Float32, 4} of size (W, H, n_channels, batch_size), where (W, H) = n_in, n_channels is 1 or 3, and batch_size is any integer (for use with ImageClassifier).

  • The object returned by chain(x) must be an AbstractFloat array of size (n_out, batch_size), that is, a vector of length n_out for each observation.
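As an illustration of the image case, this sketch builds a small convolutional chain by hand (not through a builder; the layer sizes are hypothetical) for 28×28 gray images and checks that a batch of size (W, H, n_channels, batch_size) is mapped to an output of size (n_out, batch_size).

```julia
import Flux
using Flux: Chain, Conv, MaxPool, Dense, relu

n_in, n_channels, n_out = (28, 28), 1, 10

chain = Chain(Conv((3, 3), n_channels => 4, relu; pad=1),  # 28×28×1 → 28×28×4
              MaxPool((2, 2)),                             # → 14×14×4
              Flux.flatten,                                # → 784 per image
              Dense(14 * 14 * 4, n_out))                   # → n_out per image

x = rand(Float32, n_in..., n_channels, 5)   # a batch of 5 images
out = chain(x)
```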

Alternatively, use MLJFlux.@builder(neural_net) to automatically create a builder for any valid Flux chain expression neural_net, where the symbols n_in, n_out, n_channels and rng can appear literally, with the interpretations explained above. For example,

builder = MLJFlux.@builder Chain(Dense(n_in, 128), Dense(128, n_out, tanh))
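Such a builder can be exercised outside a model by calling MLJFlux.build on it directly, as in this sketch. Here n_in = 4 and n_out = 2 are arbitrary values standing in for what a model would infer from the data; the @builder expression needs Chain, Dense and tanh to be in scope, hence using Flux.

```julia
using Random
using Flux
import MLJFlux

builder = MLJFlux.@builder Chain(Dense(n_in, 128), Dense(128, n_out, tanh))

# What a model would do once it knows n_in = 4 and n_out = 2
chain = MLJFlux.build(builder, MersenneTwister(0), 4, 2)
out = chain(rand(Float32, 4, 6))   # a batch of 6 observations
```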

Loss functions

Currently, the loss function specified by loss=... is applied internally by Flux and needs to conform to the Flux API. You cannot, for example, supply one of MLJ's probabilistic loss functions, such as MLJ.cross_entropy to one of the classifier constructors, although you should use MLJ loss functions in MLJ meta-algorithms.
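For contrast, here is a Flux loss called the Flux way. This sketch assumes Flux's convention of passing predicted probabilities first and one-hot encoded targets second; the 3-class values are made up. MLJ's cross_entropy, by comparison, takes a vector of predicted distributions and a categorical target vector, as in the training_loss computation in the introductory example.

```julia
import Flux

ŷ = fill(1f0 / 3, 3, 2)                        # uniform predictions: 3 classes × 2 observations
y = Flux.onehotbatch([:a, :c], [:a, :b, :c])   # true classes, one-hot encoded
loss = Flux.crossentropy(ŷ, y)                 # mean over observations, ≈ log(3) here
```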

An image classification example

An expanded version of this example, with early stopping and snapshots, is available here.

We define a builder that builds a chain with six alternating convolution and max-pool layers, and a final dense layer, which we apply to the MNIST image dataset.

First we define a generic builder (working for any image size, color or gray):

using MLJ
using Flux
using MLDatasets

# helper function
function flatten(x::AbstractArray)
	return reshape(x, :, size(x)[end])
end

import MLJFlux
mutable struct MyConvBuilder <: MLJFlux.Builder
	filter_size::Int
	channels1::Int
	channels2::Int
	channels3::Int
end

function MLJFlux.build(b::MyConvBuilder, rng, n_in, n_out, n_channels)

	k, c1, c2, c3 = b.filter_size, b.channels1, b.channels2, b.channels3

	mod(k, 2) == 1 || error("`filter_size` must be odd. ")

	# padding to preserve image size on convolution:
	p = div(k - 1, 2)

	front = Chain(
			   Conv((k, k), n_channels => c1, pad=(p, p), relu),
			   MaxPool((2, 2)),
			   Conv((k, k), c1 => c2, pad=(p, p), relu),
			   MaxPool((2, 2)),
			   Conv((k, k), c2 => c3, pad=(p, p), relu),
			   MaxPool((2, 2)),
			   flatten)
	d = Flux.outputsize(front, (n_in..., n_channels, 1)) |> first
	return Chain(front, Dense(d, n_out))
end
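The dimension d computed with Flux.outputsize can be checked by hand. For MNIST-sized inputs (28×28, one channel) and MyConvBuilder(3, 16, 32, 32), each padded convolution preserves the image size and each MaxPool((2, 2)) halves it, rounding down: 28 → 14 → 7 → 3. The flattened output therefore has 3 × 3 × 32 = 288 entries. This sketch rebuilds the front of the chain, using Flux.flatten in place of the helper, to confirm this.

```julia
import Flux
using Flux: Chain, Conv, MaxPool, relu

k, c1, c2, c3 = 3, 16, 32, 32
p = div(k - 1, 2)   # padding that preserves image size

front = Chain(
    Conv((k, k), 1 => c1, relu; pad=(p, p)), MaxPool((2, 2)),   # 28×28 → 14×14
    Conv((k, k), c1 => c2, relu; pad=(p, p)), MaxPool((2, 2)),  # 14×14 → 7×7
    Conv((k, k), c2 => c3, relu; pad=(p, p)), MaxPool((2, 2)),  # 7×7 → 3×3
    Flux.flatten)

d = Flux.outputsize(front, (28, 28, 1, 1)) |> first   # 3 * 3 * 32
```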

Next, we load some of the MNIST data and check that the scientific types conform to those in the table above:

N = 500
Xraw, yraw = MNIST.traindata();
Xraw = Xraw[:,:,1:N];
yraw = yraw[1:N];

julia> scitype(Xraw)
AbstractArray{Unknown, 3}

julia> scitype(yraw)
AbstractVector{Count}

Inputs should have element scitype GrayImage:

X = coerce(Xraw, GrayImage);

For classifiers, target must have element scitype <: Finite:

y = coerce(yraw, Multiclass);
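Both coercions can be tried on stand-in data before downloading MNIST. This sketch uses random arrays and made-up labels in place of Xraw and yraw, and assumes, as in the example, that MLJ's coerce accepts a 3-dimensional array of reals for GrayImage.

```julia
using MLJ

# Stand-ins for raw MNIST data: three 28×28 arrays and integer labels
Xraw = rand(Float32, 28, 28, 3)
yraw = [5, 0, 4]

X = coerce(Xraw, GrayImage)    # vector of three 28×28 gray images
y = coerce(yraw, Multiclass)   # categorical vector with 3 levels
```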

Instantiating an image classifier model:

ImageClassifier = @load ImageClassifier
clf = ImageClassifier(builder=MyConvBuilder(3, 16, 32, 32))

And evaluating the accuracy of the model on a 30% holdout set:

mach = machine(clf, X, y)

julia> evaluate!(mach,
				 resampling=Holdout(rng=123, fraction_train=0.7),
				 operation=predict_mode,
				 measure=misclassification_rate)
│ _.measure              │ _.measurement │ _.per_fold │
│ misclassification_rate │ 0.0467        │ [0.0467]   │

Adding new models to MLJFlux (advanced)

This section is mainly for MLJFlux developers. It assumes familiarity with the MLJ model API.

If one subtypes a new model type as either MLJFlux.MLJFluxProbabilistic or MLJFlux.MLJFluxDeterministic, then instead of defining new methods for MLJModelInterface.fit and MLJModelInterface.update, one can make use of fallbacks by implementing the lower-level methods shape, build, and fitresult. See the classifier source code for an example.

One still needs to implement a new predict method.