IterationControl.jl

A lightweight package for controlling iterative algorithms, with a view to training and optimizing machine learning models.

Builds on EarlyStopping.jl and is inspired by LearningStrategies.jl.

Other related software: DynamicIterators.jl.

Installation

```julia
using Pkg
Pkg.add("IterationControl")
```

Basic idea

Suppose you have some kind of object, `SquareRooter(x)`, for iteratively computing approximations to the square root of `x`:

```julia
model = SquareRooter(9)

julia> model.root
1.0

train!(model, 2) # train for 2 iterations

julia> model.root
3.4

train!(model, 1) # train for 1 more iteration

julia> model.root
3.023529411764706
```
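
For concreteness, here is a sketch of one possible definition, using Newton's method (the `SquareRooter` type is an illustration only and is not provided by the package); it is consistent with the outputs shown above and reappears in the examples below:

```julia
# Illustrative only: an iterative model approximating sqrt(x) by Newton's method.
mutable struct SquareRooter
    x::Float64                        # number whose square root is sought
    root::Float64                     # current approximation of sqrt(x)
    training_losses::Vector{Float64}  # losses from the last `train!` call
    SquareRooter(x) = new(x, 1.0, Float64[])
end

loss(model::SquareRooter) = abs(model.root^2 - model.x)
training_losses(model::SquareRooter) = model.training_losses

function train!(model::SquareRooter, n)
    model.training_losses = Float64[]
    for _ in 1:n
        model.root = (model.root + model.x/model.root)/2  # one Newton step
        push!(model.training_losses, loss(model))
    end
    return model
end
```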

Then we can replace the integer argument `n` in `train!(model, n)` with any number of more sophisticated controls by "lifting" the method `train!` to the `IterationControl.train!` method defined in this package:

```julia
using IterationControl
IterationControl.train!(model::SquareRooter, n) = train!(model, n) # lifting
```

By definition, the lifted `train!` has the same functionality as the original one:

```julia
model = SquareRooter(9)
IterationControl.train!(model, 2)

julia> model.root
3.4
```

But now we can also do this:

```julia
julia> IterationControl.train!(model, Step(2), NumberLimit(3), Info(m->m.root));
[ Info: 3.4
[ Info: 3.00009155413138
[ Info: 3.0
[ Info: Stop triggered by NumberLimit(3) stopping criterion.
```

Here each control is repeatedly applied in sequence until one of them triggers a stop. The first control `Step(2)` says, "Train the model two more iterations"; the second asks, "Have I been applied 3 times yet?", signaling a stop (at the end of the current control cycle) if so; and the third logs the value of the function `m -> m.root`, evaluated on `model`, to `Info`. In this example only the second control can terminate model iteration.

If `model` admits a method returning a loss (in this case the difference between `x` and the square of `root`) then we can lift that method to `IterationControl.loss` to enable control using loss-based stopping criteria, such as a loss threshold. In the demonstration below, we also include a callback:

```julia
model = SquareRooter(4)
train!(model, 1)

julia> loss(model)
2.25

IterationControl.loss(model::SquareRooter) = loss(model) # lifting

losses = Float64[]
callback(model) = push!(losses, loss(model))

julia> IterationControl.train!(model,
                               Step(1),
                               Threshold(0.0001),
                               Callback(callback));
[ Info: Stop triggered by Threshold(0.0001) stopping criterion.

julia> losses
2-element Array{Float64,1}:
 0.002439396192741583
 3.716891878724482e-7
```

In many applications to machine learning, "loss" will be an out-of-sample loss, computed after some iterations. If `model` additionally generates user-inspectable "training losses" (one per iteration) then similarly lifting the appropriate access function to `IterationControl.training_losses` enables Prechelt's progress-modified generalization loss stopping criterion, `PQ` (see Table 1 below).

`PQ` is the only criterion from the EarlyStopping.jl package not otherwise enabled when `IterationControl.loss` is overloaded as above.
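
Continuing with the `SquareRooter` sketch above, which records one training loss per iteration, the extra lifting needed for `PQ` might look like this (a sketch only; for this toy model the "loss" is not genuinely out-of-sample):

```julia
# lifting the training-loss accessor (SquareRooter is the sketch from above):
IterationControl.training_losses(model::SquareRooter) = training_losses(model)

model = SquareRooter(9)
IterationControl.train!(model, Step(2), PQ(alpha=0.75, k=5), NumberLimit(10))
```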

Reference. Prechelt, Lutz (1998): "Early Stopping - But When?", in Neural Networks: Tricks of the Trade, ed. G. Orr, Springer.

The interface just described is sufficient for controlling conventional machine learning models with an iteration parameter, as this tree boosting example shows.

Online and incremental training

For online or incremental training, lift the method for ingesting data into the model to `IterationControl.ingest!(model, datum)` and use the control `Data(data)`. Here `data` is any iterator generating the `datum` items to be ingested (one per application of the control). By default, the `Data` control becomes passive after `data` is exhausted. Do `?Data` for details. (See Access to model through a wrapper below on dealing with any model wrapping necessary to implement data ingestion.)

A simple particle tracking example is given here.
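
In outline, assuming a hypothetical online model type `MyOnlineModel` with its own one-observation ingest method, the setup might look like this:

```julia
# Hypothetical lifting for an online model (`MyOnlineModel` and `data_stream`
# are illustrative names, not part of the package):
IterationControl.ingest!(model::MyOnlineModel, datum) = ingest!(model, datum) # lifting

# each control cycle ingests one datum and then trains for one iteration:
IterationControl.train!(model, Data(data_stream), Step(1), NumberLimit(100))
```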

Verbose logging and inspecting control reports

The `IterationControl.train!` method can be given the keyword argument `verbosity=...`, defaulting to `1`. The larger the value of `verbosity`, the noisier the logging.

The return value of `IterationControl.train!` is a tuple of `(control, report)` tuples, where `report` is generated by `control` at the end of training. For example, the final loss can be accessed from the report of the `WithLossDo()` control:

```julia
model = SquareRooter(9)
reports = IterationControl.train!(model, Step(1), WithLossDo(println), NumberLimit(3));

julia> last(reports[2])
(loss = 0.1417301038062284, done = false, log = "")

julia> last(reports[2]).loss
0.1417301038062284
```

Controls provided

Controls are repeatedly applied in sequence until a control triggers a stop. Each control type has a detailed doc-string. Below is a short summary, with some advanced options omitted.

| control | description | enabled if these are overloaded | can trigger a stop | notation in Prechelt |
| ------- | ----------- | ------------------------------- | ------------------ | -------------------- |
| `Step(n=1)` | Train model for `n` iterations | `train!` | no | |
| `Info(f=identity)` | Log to `Info` the value of `f(model)` | `train!` | no | |
| `Warn(predicate, f="")` | Log to `Warn` the value of `f` or `f(model)` if `predicate(model)` holds | `train!` | no | |
| `Error(predicate, f="")` | Log to `Error` the value of `f` or `f(model)` if `predicate(model)` holds, and then stop | `train!` | yes | |
| `Callback(f=_->nothing)` | Call `f(model)` | `train!` | yes | |
| `TimeLimit(t=0.5)` | Stop after `t` hours | `train!` | yes | |
| `NumberLimit(n=100)` | Stop after `n` applications of the control | `train!` | yes | |
| `NumberSinceBest(n=6)` | Stop when the best loss occurred `n` control applications ago | `train!`, `loss` | yes | |
| `WithNumberDo(f=n->@info(n))` | Call `f(n + 1)`, where `n` is the number of complete control cycles so far | `train!` | yes | |
| `WithLossDo(f=x->@info("loss: $x"))` | Call `f(loss)`, where `loss` is the current loss | `train!`, `loss` | yes | |
| `WithTrainingLossesDo(f=v->@info(v))` | Call `f(v)`, where `v` is the current batch of training losses | `train!`, `training_losses` | yes | |
| `InvalidValue()` | Stop when a `NaN`, `Inf` or `-Inf` loss or training loss is encountered | `train!` | yes | |
| `Threshold(value=0.0)` | Stop when `loss < value` | `train!`, `loss` | yes | |
| `GL(alpha=2.0)` | Stop after the "generalization loss" exceeds `alpha` | `train!`, `loss` | yes | `GL_α` |
| `Patience(n=5)` | Stop after `n` consecutive loss increases | `train!`, `loss` | yes | `UP_s` |
| `PQ(alpha=0.75, k=5)` | Stop after the "progress-modified generalization loss" exceeds `alpha` | `train!`, `loss`, `training_losses` | yes | `PQ_α` |
| `Data(data)` | Call `ingest!(model, item)` on the next `item` in the iterable `data` | `train!`, `ingest!` | yes | |

Table 1. Atomic controls

Stopping option. All the following controls trigger a stop if the provided function `f` returns `true` and `stop_if_true=true` is specified in the constructor: `Callback`, `WithNumberDo`, `WithLossDo`, `WithTrainingLossesDo`.
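
For example, reusing the `SquareRooter` model, the following sketch stops as soon as the callback returns `true`:

```julia
# callback returning `true` once the residual is small enough (sketch only):
small_residual(model) = loss(model) < 1e-6

IterationControl.train!(SquareRooter(2),
                        Step(1),
                        Callback(small_residual, stop_if_true=true))
```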

There are also some control wrappers to modify a control's behavior:

| wrapper | description |
| ------- | ----------- |
| `IterationControl.skip(control; predicate=1)` | Apply `control` every `predicate` applications of the control wrapper (`predicate` can also be a function; see the doc-string) |
| `IterationControl.louder(control; by=1)` | Increase the verbosity level of `control` by the specified value (negative values lower verbosity) |
| `IterationControl.with_state_do(control; f=...)` | Apply `control` and call `f(x)`, where `x` is the internal state of `control`; useful for debugging. The default `f` logs the state to `Info`. Warning: internal control state is not yet part of the public API. |
| `IterationControl.composite(controls...)` | Apply each `control` in `controls` in sequence; mostly for under-the-hood use |

Table 2. Wrapped controls
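
For instance, to log the loss only on every tenth control cycle, one might write (a sketch):

```julia
IterationControl.train!(model,
                        Step(1),
                        IterationControl.skip(WithLossDo(), predicate=10),
                        NumberLimit(100))
```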

Note that functions ordinarily applied to `model` by some control (e.g., a `Callback`) will instead be applied to `IterationControl.expose(model)` if `IterationControl.expose` is appropriately overloaded.
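
For example, if a hypothetical wrapper type `MyWrapper` holds the object of actual interest in a field `inner`, one might declare:

```julia
# `MyWrapper` and its `inner` field are illustrative only:
IterationControl.expose(wrapper::MyWrapper) = wrapper.inner
```

Functions such as the `f` in `Callback(f)` are then applied to `wrapper.inner` rather than to the wrapper itself.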

Implementing new controls

There is no abstract control type; any object can be a control. Behavior is implemented using a functional-style interface with six methods, listed below. Only the two `update!` methods are compulsory: the fallbacks for `done`, `needs_loss` and `needs_training_losses` always return `false`, and the fallback for `takedown` returns an empty named tuple.

```julia
update!(control, model, verbosity, n) -> state  # initialization (first cycle)
update!(control, model, verbosity, n, state) -> state
done(control, state)::Bool
takedown(control, verbosity, state) -> report  # a named tuple
needs_loss(control)::Bool
needs_training_losses(control)::Bool
```

Here `n` is the control cycle count, i.e., one more than the number of completed control cycles.

If it is nonsensical to apply `control` to any model for which `loss(model)` has not been overloaded, and we want an error thrown when this is attempted, then declare `needs_loss(control::MyControl) = true`. Otherwise `control` is applied anyway, and `loss`, if called, returns `nothing`.

A second trait `needs_training_losses(control)` serves an analogous purpose for training losses.
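
By way of illustration, here is a sketch of a hypothetical control, `LossPlateau`, which stops when the loss changes by less than `tol` between consecutive applications (all names here are illustrative and not part of the package):

```julia
import IterationControl: update!, done, takedown, needs_loss

struct LossPlateau
    tol::Float64
end
LossPlateau(; tol=1e-6) = LossPlateau(tol)

needs_loss(::LossPlateau) = true  # error if `loss` has not been lifted

# initialization (first control cycle):
update!(control::LossPlateau, model, verbosity, n) =
    (prev=IterationControl.loss(model), plateaued=false)

# subsequent control cycles:
function update!(control::LossPlateau, model, verbosity, n, state)
    current = IterationControl.loss(model)
    return (prev=current, plateaued=abs(state.prev - current) < control.tol)
end

done(control::LossPlateau, state) = state.plateaued

takedown(control::LossPlateau, verbosity, state) = (plateaued=state.plateaued,)
```

Such a control could then appear in a call like `IterationControl.train!(model, Step(1), LossPlateau(tol=1e-8))`.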

Here's a simplified version of how `IterationControl.train!` calls these methods:

```julia
function train!(model, controls...; verbosity::Int=1)

    control = composite(controls...)

    # before training:
    verbosity > 1 && @info "Using these controls: $(flat(control)). "

    # first training event:
    n = 1 # counts control cycles
    state = update!(control, model, verbosity, n)
    finished = done(control, state)

    # check that model supports the controls:
    if needs_loss(control) && loss(model) === nothing
        throw(ERR_NEEDS_LOSS)
    end
    if needs_training_losses(control) && training_losses(model) === nothing
        throw(ERR_NEEDS_TRAINING_LOSSES)
    end

    # subsequent training events:
    while !finished
        n += 1
        state = update!(control, model, verbosity, n, state)
        finished = done(control, state)
    end

    # finalization:
    return takedown(control, verbosity, state)
end
```