This package provides a simple interface for defining, training, and deploying Mixture Density Networks (MDNs). MDNs were first proposed by Bishop (1994). We can think of an MDN as a specialized type of Artificial Neural Network (ANN) which takes some features X and returns a distribution over the labels Y under a Gaussian Mixture Model (GMM). Unlike a standard ANN, an MDN maintains the full conditional distribution P(Y|X), which makes it particularly well-suited for situations where we want to quantify the uncertainty in our predictions. Moreover, because GMMs can represent multimodal distributions, MDNs are capable of modelling one-to-many relationships, which occur when a single input X can be associated with more than one output Y.
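Concretely, for a model with K mixture components (the mixtures hyperparameter used below), the network maps each input X to mixture weights π_1(X), …, π_K(X) and per-component parameters μ_k(X) and σ_k(X), so that the conditional density it learns is

P(Y | X) = ∑_{k=1}^{K} π_k(X) · N(Y; μ_k(X), σ_k(X)²),   with ∑_{k=1}^{K} π_k(X) = 1.

This notation follows Bishop's formulation and is purely illustrative; it does not correspond to identifiers exported by the package.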
This package also implements the interface specified by MLJModelInterface and is thus fully compatible with the MLJ ecosystem. Two examples follow: the first demonstrates the package's native interface, and the second demonstrates the same workflow in conjunction with MLJ.
using Flux, MixtureDensityNetworks, Distributions, CairoMakie, Logging, TerminalLoggers
const n_samples = 1000
const epochs = 1000
const batchsize = 128
const mixtures = 8
const layers = [128, 128]
function main()
    # Generate Data
    X, Y = generate_data(n_samples)

    # Create Model
    model = MixtureDensityNetwork(1, 1, layers, mixtures)

    # Fit Model
    model, report = MixtureDensityNetworks.fit!(model, X, Y; epochs=epochs, opt=Flux.Adam(1e-3), batchsize=batchsize)

    # Plot Learning Curve
    fig, _, _ = lines(1:epochs, report.learning_curve, axis=(;xlabel="Epochs", ylabel="Loss"))
    save("LearningCurve.png", fig)

    # Plot Learned Distribution
    Ŷ = model(X)
    fig, ax, plt = scatter(X[1,:], rand.(Ŷ), markersize=4, label="Predicted Distribution")
    scatter!(ax, X[1,:], Y[1,:], markersize=3, label="True Distribution")
    axislegend(ax, position=:lt)
    save("PredictedDistribution.png", fig)

    # Plot Conditional Distribution
    cond = model(reshape([-2.1], (1,1)))[1]
    fig = Figure(resolution=(1000, 500))
    density(fig[1,1], rand(cond, 10000), npoints=10000)
    save("ConditionalDistribution.png", fig)
end
main()
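Note that generate_data (provided by the package) returns the features X and labels Y as 1×n matrices, which is the layout MixtureDensityNetwork expects here for a model with one input and one output dimension. To run the example on your own data, it is enough to reshape it into that layout. A minimal sketch, with synthetic data standing in for generate_data:

x = collect(range(-3.0, 3.0; length=1000))
X = reshape(x, 1, :)                               # features as a 1×n matrix
Y = reshape(sin.(x) .+ 0.1 .* randn(1000), 1, :)   # labels as a 1×n matrix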
Below is the same workflow using the MLJ interface.

using MixtureDensityNetworks, Distributions, CairoMakie, MLJ
const n_samples = 1000
const epochs = 500
const batchsize = 128
const mixtures = 8
const layers = [128, 128]
function main()
    # Generate Data
    X, Y = generate_data(n_samples)

    # Create Model
    mach = MLJ.machine(MDN(epochs=epochs, mixtures=mixtures, layers=layers, batchsize=batchsize), MLJ.table(X'), Y[1,:])

    # Fit Model on Training Data, Then Evaluate on Test
    @info "Evaluating..."
    evaluation = MLJ.evaluate!(
        mach,
        resampling=Holdout(shuffle=true),
        measure=[rsq, rmse, mae, mape],
        operation=MLJ.predict_mean,
        verbosity=2 # Need to set verbosity=2 to show training progress during evaluation
    )
    names = ["R²", "RMSE", "MAE", "MAPE"]
    metrics = round.(evaluation.measurement, digits=3)
    @info "Metrics: " * join(["$name: $metric" for (name, metric) in zip(names, metrics)], ", ")

    # Fit Model on Entire Dataset
    @info "Training..."
    MLJ.fit!(mach)

    # Plot Learning Curve
    fig, _, _ = lines(1:epochs, MLJ.training_losses(mach), axis=(;xlabel="Epochs", ylabel="Loss"))
    save("LearningCurve.png", fig)

    # Plot Learned Distribution
    Ŷ = MLJ.predict(mach) .|> rand
    fig, ax, plt = scatter(X[1,:], Ŷ, markersize=4, label="Predicted Distribution")
    scatter!(ax, X[1,:], Y[1,:], markersize=3, label="True Distribution")
    axislegend(ax, position=:lt)
    save("PredictedDistribution.png", fig)

    # Plot Conditional Distribution
    cond = MLJ.predict(mach, MLJ.table(reshape([-2.1], (1,1))))[1]
    fig = Figure(resolution=(1000, 500))
    density(fig[1,1], rand(cond, 10000), npoints=10000)
    save("ConditionalDistribution.png", fig)
end
main()
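The predicted conditional distributions in both examples (such as cond above) behave like Distributions.jl distributions. Assuming they are MixtureModel objects, their Gaussian components and weights can be inspected with the standard Distributions API. A minimal sketch, using a hypothetical two-component mixture in place of a fitted prediction:

using Distributions

# A stand-in for a predicted conditional distribution (fitted values will differ)
cond = MixtureModel([Normal(-1.0, 0.3), Normal(2.0, 0.5)], [0.4, 0.6])

# Inspect the mixture weights and the Gaussian components
for (w, c) in zip(probs(cond), components(cond))
    println("weight = $w, mean = $(mean(c)), std = $(std(c))")
end

# Point summaries of the full mixture are also available
mean(cond)   # overall conditional mean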