MixtureDensityNetworks.jl

A simple interface for defining, training, and deploying MDNs.
Author JoshuaBillson
Started In
March 2023

This package provides a simple interface for defining, training, and deploying Mixture Density Networks (MDNs), which were first proposed by Bishop (1994). An MDN can be thought of as a specialized Artificial Neural Network (ANN) that takes features X and returns a distribution over the labels Y in the form of a Gaussian Mixture Model (GMM). Unlike a conventional ANN, which returns a single point estimate, an MDN models the full conditional distribution P(Y|X). This makes MDNs well suited to situations where we want a measure of the uncertainty in our predictions. Moreover, because GMMs can represent multimodal distributions, MDNs can model one-to-many relationships, which occur when an input X can be associated with more than one output Y.
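
To make the idea of a conditional mixture concrete, the following plain-Julia sketch evaluates a one-dimensional Gaussian mixture density and its negative log-likelihood for a single input. It is only an illustration: the helper names (`softmax_weights`, `normal_pdf`, `mixture_pdf`, `mixture_nll`) are hypothetical, and the softmax/exp parameterization is the common one from the MDN literature, assumed here rather than taken from this package's internals.

```julia
# Plain-Julia sketch of a 1-D Gaussian mixture output; all names are illustrative
# and are not part of MixtureDensityNetworks.jl.

# Numerically stable softmax: turns unconstrained logits into mixing weights.
function softmax_weights(logits::AbstractVector{<:Real})
    shifted = exp.(logits .- maximum(logits))
    return shifted ./ sum(shifted)
end

# Density of a univariate normal distribution.
normal_pdf(y, μ, σ) = exp(-0.5 * ((y - μ) / σ)^2) / (σ * sqrt(2π))

# Mixture density p(y | x) = Σ_k π_k N(y; μ_k, σ_k).
mixture_pdf(y, π_k, μ, σ) = sum(π_k .* normal_pdf.(y, μ, σ))

# Negative log-likelihood of one observation, the usual MDN training loss.
mixture_nll(y, π_k, μ, σ) = -log(mixture_pdf(y, π_k, μ, σ))

# Pretend these are the raw network outputs for a single input x (assumed values).
logits = [0.0, 0.0]          # equal mixing weights after softmax
μ = [-2.0, 2.0]              # two well-separated modes
σ = exp.([-1.0, -1.0])       # exp keeps the standard deviations positive

π_k = softmax_weights(logits)
println(mixture_pdf(-2.0, π_k, μ, σ))  # density at one of the modes
println(mixture_pdf(0.0, π_k, μ, σ))   # density between the modes (much lower)
println(mixture_nll(2.0, π_k, μ, σ))   # loss for an observation at the other mode
```

Because the density has two peaks, a single point prediction cannot describe it, which is the one-to-many situation described above.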

MLJ Compatibility

This package implements the interface specified by MLJModelInterface and is thus fully compatible with the MLJ ecosystem. The examples below demonstrate the package first through its native interface and then in conjunction with MLJ.

Example (Native Interface)

using Flux, MixtureDensityNetworks, Distributions, CairoMakie, Logging, TerminalLoggers

const n_samples = 1000
const epochs = 1000
const batchsize = 128
const mixtures = 8
const layers = [128, 128]

function main()
    # Generate Data
    X, Y = generate_data(n_samples)

    # Create Model
    model = MixtureDensityNetwork(1, 1, layers, mixtures)

    # Fit Model
    model, report = MixtureDensityNetworks.fit!(model, X, Y; epochs=epochs, opt=Flux.Adam(1e-3), batchsize=batchsize)

    # Plot Learning Curve
    fig, _, _ = lines(1:epochs, report.learning_curve, axis=(;xlabel="Epochs", ylabel="Loss"))
    save("LearningCurve.png", fig)

    # Plot Learned Distribution
    Ŷ = model(X)
    fig, ax, plt = scatter(X[1,:], rand.(Ŷ), markersize=4, label="Predicted Distribution")
    scatter!(ax, X[1,:], Y[1,:], markersize=3, label="True Distribution")
    axislegend(ax, position=:lt)
    save("PredictedDistribution.png", fig)

    # Plot Conditional Distribution
    cond = model(reshape([-2.1], (1,1)))[1]
    fig = Figure(resolution=(1000, 500))
    density(fig[1,1], rand(cond, 10000), npoints=10000)
    save("ConditionalDistribution.png", fig)
end

main()
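
The example above calls `generate_data` without showing what kind of data it returns. As a rough illustration only, the sketch below builds a toy one-to-many dataset in the same features-by-samples layout that the example indexes with `X[1,:]` and `Y[1,:]`. The function `make_toy_data` and its constants are hypothetical stand-ins, not the package's implementation.

```julia
using Random

# Hypothetical stand-in for `generate_data`; NOT the package's implementation.
function make_toy_data(n_samples::Int; rng=Random.default_rng())
    # Sample the target uniformly in [-10, 10), then derive the feature from it,
    # so that a single feature value can correspond to several targets.
    Y = rand(rng, Float32, 1, n_samples) .* 20f0 .- 10f0
    noise = randn(rng, Float32, 1, n_samples)
    X = 7f0 .* sin.(0.75f0 .* Y) .+ 0.5f0 .* Y .+ noise
    return X, Y
end

X, Y = make_toy_data(1000; rng=MersenneTwister(42))
println(size(X), " ", size(Y))  # (1, 1000) (1, 1000)
```

Generating X from Y, rather than the other way around, is what produces the one-to-many relationship: the mapping from Y to X is a noisy function, but its inverse is not a function at all.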

Example (MLJ Interface)

using MixtureDensityNetworks, Distributions, CairoMakie, MLJ

const n_samples = 1000
const epochs = 500
const batchsize = 128
const mixtures = 8
const layers = [128, 128]

function main()
    # Generate Data
    X, Y = generate_data(n_samples)

    # Create Model
    mach = MLJ.machine(MDN(epochs=epochs, mixtures=mixtures, layers=layers, batchsize=batchsize), MLJ.table(X'), Y[1,:])

    # Fit Model on Training Data, Then Evaluate on Test
    @info "Evaluating..."
    evaluation = MLJ.evaluate!(
        mach, 
        resampling=Holdout(shuffle=true), 
        measure=[rsq, rmse, mae, mape], 
        operation=MLJ.predict_mean, 
        verbosity=2  # Need to set verbosity=2 to show training progress during evaluation 
    )
    names = ["R²", "RMSE", "MAE", "MAPE"]
    metrics = round.(evaluation.measurement, digits=3)
    @info "Metrics: " * join(["$name: $metric" for (name, metric) in zip(names, metrics)], ", ")

    # Fit Model on Entire Dataset
    @info "Training..."
    MLJ.fit!(mach)

    # Plot Learning Curve
    fig, _, _ = lines(1:epochs, MLJ.training_losses(mach), axis=(;xlabel="Epochs", ylabel="Loss"))
    save("LearningCurve.png", fig)

    # Plot Learned Distribution
    Ŷ = MLJ.predict(mach) .|> rand
    fig, ax, plt = scatter(X[1,:], Ŷ, markersize=4, label="Predicted Distribution")
    scatter!(ax, X[1,:], Y[1,:], markersize=3, label="True Distribution")
    axislegend(ax, position=:lt)
    save("PredictedDistribution.png", fig)

    # Plot Conditional Distribution
    cond = MLJ.predict(mach, MLJ.table(reshape([-2.1], (1,1))))[1]
    fig = Figure(resolution=(1000, 500))
    density(fig[1,1], rand(cond, 10000), npoints=10000)
    save("ConditionalDistribution.png", fig)
end

main()
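
The evaluation step above scores the model with `operation=MLJ.predict_mean`, while the plots draw samples with `rand`. The following plain-Julia sketch, using hypothetical helpers `mixture_mean` and `sample_mixture` and made-up mixture parameters, illustrates why the two can look quite different when the predicted distribution is multimodal.

```julia
# Illustrative helpers; these names are assumptions, not part of the package or MLJ.

# Mean of a mixture: the weighted average of the component means.
mixture_mean(weights, means) = sum(weights .* means)

# Draw one sample: pick a component by its weight, then sample from that normal.
function sample_mixture(weights, means, stds)
    k = something(findfirst(cumsum(weights) .>= rand()), length(weights))
    return means[k] + stds[k] * randn()
end

weights = [0.5, 0.5]
means = [-2.0, 2.0]
stds = [0.1, 0.1]

println(mixture_mean(weights, means))  # 0.0

# Samples cluster around the two modes rather than around the mean.
samples = [sample_mixture(weights, means, stds) for _ in 1:1000]
println(count(s -> abs(s) < 1.0, samples))
```

For a bimodal prediction like this one, the mean falls between the modes where the density is low, so point metrics such as RMSE can understate how well the model captures the distribution.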