ParameterSchedulers
ParameterSchedulers.jl provides common machine learning (ML) schedulers for hyper-parameters. Though this package is framework agnostic, a convenient interface for pairing schedules with Flux.jl optimizers is available. Using this package with Flux is as simple as:
```julia
using Flux, ParameterSchedulers
using ParameterSchedulers: Scheduler

opt = Scheduler(Exp(λ = 1e-2, γ = 0.8), Momentum())
```
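Because the package is framework agnostic, a schedule can also be evaluated on its own, without Flux. The following is a minimal sketch, assuming the keyword names used in this README (`λ`, `γ`) and that a schedule is called with a 1-based iteration index, as the plotting examples do with `s.(t)`. Other releases of the package may use different keyword names.

```julia
using ParameterSchedulers

# Same schedule as in the Flux example above.
s = Exp(λ = 1e-2, γ = 0.8)

# A schedule is callable: s(t) returns the parameter value at iteration t.
for t in 1:5
    println("iteration ", t, ": ", s(t))
end

# Broadcasting evaluates several iterations at once.
values = s.(1:5)
```

The resulting values can be passed to any training loop that accepts a learning rate or similar hyper-parameter per iteration.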
Available Schedules
The table below lists the common schedules implemented; ParameterSchedulers also provides utilities for creating more exotic schedules. You can read this paper for more information on the schedules below.
Schedule | Description | Type | Example |
---|---|---|---|
`Step` | Exponential decay by `γ` every step in `step_sizes` | Decay | `Step(λ = 1.0, γ = 0.8, step_sizes = [2, 3, 2])` |
`Exp` | Exponential decay by `γ` every iteration | Decay | `Exp(λ = 1.0, γ = 0.5)` |
`Poly` | Polynomial decay at degree `p` | Decay | `Poly(λ = 1.0, p = 2, max_iter = t[end])` |
`Inv` | Inverse decay at a rate set by `γ` and `p` | Decay | `Inv(λ = 1.0, p = 2, γ = 0.8)` |
`Triangle` | Triangle wave function | Cyclic | `Triangle(λ0 = 0.0, λ1 = 1.0, period = 2)` |
`TriangleDecay2` | Triangle wave function with half the amplitude every `period` | Cyclic | `TriangleDecay2(λ0 = 0.0, λ1 = 1.0, period = 2)` |
`TriangleExp` | Triangle wave function with exponential amplitude decay at rate `γ` | Cyclic | `TriangleExp(λ0 = 0.0, λ1 = 1.0, period = 2, γ = 0.8)` |
`Sin` | Sine function | Cyclic | `Sin(λ0 = 0.0, λ1 = 1.0, period = 2)` |
`SinDecay2` | Sine function with half the amplitude every `period` | Cyclic | `SinDecay2(λ0 = 0.0, λ1 = 1.0, period = 2)` |
`SinExp` | Sine function with exponential amplitude decay at rate `γ` | Cyclic | `SinExp(λ0 = 0.0, λ1 = 1.0, period = 2, γ = 0.8)` |
`Cos` | Cosine annealing | Cyclic | `Cos(λ0 = 0.0, λ1 = 1.0, period = 4)` |

Each example is plotted over `t = 1:10 |> collect` with UnicodePlots. The plot for `Step` is produced as follows; the other rows differ only in the line that defines `s`:

```julia
using UnicodePlots, ParameterSchedulers

t = 1:10 |> collect
s = Step(λ = 1.0, γ = 0.8, step_sizes = [2, 3, 2])
lineplot(t, s.(t); width = 15, height = 3, border = :ascii, labels = false)
```
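To make the decay descriptions concrete, the sketch below writes out two of them as plain Julia functions. These are hypothetical reference functions (`exp_decay` and `step_decay` are names invented here), not the package's implementation. They assume a 1-based iteration index and that each entry of `step_sizes` is the number of iterations spent at one level before the next decay; the package's own indexing may differ.

```julia
# Hypothetical reference functions, for illustration only.

# Exponential decay: multiply by γ once per iteration after the first.
exp_decay(λ, γ, t) = λ * γ^(t - 1)

# Step decay: multiply by γ each time a stage from step_sizes has been completed.
function step_decay(λ, γ, step_sizes, t)
    boundaries = cumsum(step_sizes)           # last iteration of each stage
    n_decays = count(b -> t > b, boundaries)  # stages already completed at iteration t
    return λ * γ^n_decays
end

# With step_sizes = [2, 3, 2], the value is held for 2 iterations,
# then 3, then 2, decaying by γ between stages:
# approximately 1.0, 1.0, 0.8, 0.8, 0.8, 0.64, 0.64, 0.512
step_values = [step_decay(1.0, 0.8, [2, 3, 2], t) for t in 1:8]
```

Comparing such hand-written values against `s.(t)` for the matching schedule is a quick way to confirm which indexing convention an installed version uses.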