A lightweight package for controlling iterative algorithms, with a view to training and optimizing machine learning models.
Builds on EarlyStopping.jl and is inspired by LearningStrategies.jl.
Other related software: DynamicIterators.jl.
```julia
using Pkg
Pkg.add("IterationControl")
```
Suppose you have some kind of object, `SquareRooter(x)`, for iteratively computing approximations to the square root of `x`:
```julia
model = SquareRooter(9)

julia> model.root
1.0

train!(model, 2) # train for 2 iterations

julia> model.root
3.4

train!(model, 1) # train for 1 more iteration

julia> model.root
3.023529411764706
```
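For concreteness, here is one possible `SquareRooter` definition consistent with the output above. It is an illustration only, not something provided by the package:

```julia
# Hypothetical definition, for illustration: Newton's method for √x.
mutable struct SquareRooter
    x::Float64     # number whose square root is sought
    root::Float64  # current approximation to √x
    SquareRooter(x) = new(x, 1.0)
end

function train!(m::SquareRooter, n)
    for _ in 1:n
        m.root = (m.root + m.x/m.root)/2  # one Newton iteration
    end
    return m.root
end

loss(m::SquareRooter) = abs(m.x - m.root^2)  # used in the loss example below
```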
Then we can replace the integer argument `n` in `train!(model, n)` with a number of more sophisticated controls by "lifting" the method `train!` to the `IterationControl.train!` method defined in this package:
```julia
using IterationControl
IterationControl.train!(model::SquareRooter, n) = train!(model, n) # lifting
```
By definition, the lifted `train!` has the same functionality as the original one:
```julia
model = SquareRooter(9)
IterationControl.train!(model, 2)

julia> model.root
3.4
```
But now we can also do this:
```julia
julia> IterationControl.train!(model, Step(2), NumberLimit(3), Info(m->m.root));
[ Info: 3.4
[ Info: 3.00009155413138
[ Info: 3.0
[ Info: Stop triggered by NumberLimit(3) stopping criterion.
```
Here each control is repeatedly applied in sequence until one of them triggers a stop. The first control, `Step(2)`, says, "Train the model two more iterations"; the second asks, "Have I been applied 3 times yet?", signaling a stop (at the end of the current control cycle) if so; and the third logs the value of the function `m -> m.root`, evaluated on `model`, to `Info`. In this example only the second control can terminate model iteration.
If `model` admits a method returning a loss (in this case the difference between `x` and the square of `root`), then we can lift that method to `IterationControl.loss` to enable control using loss-based stopping criteria, such as a loss threshold. In the demonstration below, we also include a callback:
```julia
model = SquareRooter(4)
train!(model, 1)

julia> loss(model)
2.25

IterationControl.loss(model::SquareRooter) = loss(model) # lifting

losses = Float64[]
callback(model) = push!(losses, loss(model))

julia> IterationControl.train!(model, Step(1), Threshold(0.0001), Callback(callback));
[ Info: Stop triggered by Threshold(0.0001) stopping criterion.

julia> losses
2-element Array{Float64,1}:
 0.002439396192741583
 3.716891878724482e-7
```
In many applications to machine learning, "loss" will be an out-of-sample loss, computed after some iterations. If `model` additionally generates user-inspectable "training losses" (one per iteration), then similarly lifting the appropriate access function to `IterationControl.training_losses` enables Prechelt's progress-modified generalization loss stopping criterion, `PQ` (see Table 1 below). `PQ` is the only criterion from the EarlyStopping.jl package not otherwise enabled when `IterationControl.loss` is overloaded as above.
Reference. Prechelt, Lutz (1998): "Early Stopping - But When?", in Neural Networks: Tricks of the Trade, ed. G. Orr, Springer.
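For example, if the model kept a record of its per-iteration training losses, the lifting might look like this (a sketch only; the `training_losses(model)` accessor is hypothetical and not defined above):

```julia
# Hypothetical: assumes the model provides a `training_losses(model)` method
# returning the batch of losses for the iterations since the last `train!` call.
IterationControl.training_losses(model::SquareRooter) = training_losses(model) # lifting
```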
The interface just described is sufficient for controlling conventional machine learning models with an iteration parameter, as this tree boosting example shows.
For online or incremental training, lift the method for ingesting data into the model to `IterationControl.ingest!(model, datum)` and use the control `Data(data)`. Here `data` is any iterator generating the `datum` items to be ingested (one per application of the control). By default, the `Data` control becomes passive after `data` is exhausted. Do `?Data` for details. (See "Access to model through a wrapper" below on dealing with any model wrapping necessary to implement data ingestion.)
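A sketch of the lifting step, assuming some online model type `MyOnlineModel` with an `ingest!(model, datum)` method for absorbing a single observation (both names are hypothetical):

```julia
# Hypothetical model type and method, for illustration only:
IterationControl.ingest!(model::MyOnlineModel, datum) = ingest!(model, datum) # lifting

# `stream` is any iterator of observations; one item is ingested per control cycle:
IterationControl.train!(model, Data(stream), Step(1), NumberLimit(100))
```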
A simple particle tracking example is given here.
The `IterationControl.train!` method can be given the keyword argument `verbosity=...`, defaulting to `1`. The larger `verbosity`, the noisier.
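For example, passing `verbosity=0` gives a quieter run:

```julia
IterationControl.train!(model, Step(2), NumberLimit(3); verbosity=0)
```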
The return value of `IterationControl.train!` is a tuple of `(control, report)` tuples, where `report` is generated by `control` at the end of training. For example, the final loss can be accessed from the report of the `WithLossDo()` control:
```julia
model = SquareRooter(9)
reports = IterationControl.train!(model, Step(1), WithLossDo(println), NumberLimit(3));

julia> last(reports[2])
(loss = 0.1417301038062284, done = false, log = "")

julia> last(reports[2]).loss
0.1417301038062284
```
Controls are repeatedly applied in sequence until a control triggers a stop. Each control type has a detailed doc-string. Below is a short summary, with some advanced options omitted.
control | description | enabled if these are overloaded | can trigger a stop | notation in Prechelt |
---|---|---|---|---|
`Step(n=1)` | Train model for `n` iterations | `train!` | no | |
`Info(f=identity)` | Log to `Info` the value of `f(model)` | `train!` | no | |
`Warn(predicate, f="")` | Log to `Warn` the value of `f` or `f(model)` if `predicate(model)` holds | `train!` | no | |
`Error(predicate, f="")` | Log to `Error` the value of `f` or `f(model)` if `predicate(model)` holds and then stop | `train!` | yes | |
`Callback(f=_->nothing)` | Call `f(model)` | `train!` | yes | |
`TimeLimit(t=0.5)` | Stop after `t` hours | `train!` | yes | |
`NumberLimit(n=100)` | Stop after `n` applications of the control | `train!` | yes | |
`NumberSinceBest(n=6)` | Stop when the best loss occurred `n` control applications ago | `train!` | yes | |
`WithNumberDo(f=n->@info(n))` | Call `f(n + 1)` where `n` is the number of complete control cycles so far | `train!` | yes | |
`WithLossDo(f=x->@info("loss: $x"))` | Call `f(loss)` where `loss` is the current loss | `train!`, `loss` | yes | |
`WithTrainingLossesDo(f=v->@info(v))` | Call `f(v)` where `v` is the current batch of training losses | `train!`, `training_losses` | yes | |
`InvalidValue()` | Stop when a `NaN`, `Inf` or `-Inf` loss/training loss is encountered | `train!` | yes | |
`Threshold(value=0.0)` | Stop when `loss < value` | `train!`, `loss` | yes | |
`GL(alpha=2.0)` | Stop after the "Generalization Loss" exceeds `alpha` | `train!`, `loss` | yes | GL_α |
`Patience(n=5)` | Stop after `n` consecutive loss increases | `train!`, `loss` | yes | UP_s |
`PQ(alpha=0.75, k=5)` | Stop after the "Progress-modified GL" exceeds `alpha` | `train!`, `loss`, `training_losses` | yes | PQ_α |
`Warmup(c; n=1)` | Wait for `n` loss updates before checking criterion `c` | `train!` | no | |
`Data(data)` | Call `ingest!(model, item)` on the next `item` in the iterable `data` | `train!`, `ingest!` | yes | |
Table 1. Atomic controls
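For illustration, here is one way several of these atomic controls might be combined, continuing with the `SquareRooter` model from above. This is a sketch only; constructor keyword forms are written as displayed in the table:

```julia
IterationControl.train!(model,
    Step(5),             # train 5 iterations per control cycle
    InvalidValue(),      # stop on NaN/Inf/-Inf loss
    Patience(n=3),       # stop after 3 consecutive loss increases
    TimeLimit(t=0.001))  # or stop after about 3.6 seconds (t is in hours)
```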
Stopping option. All the following controls trigger a stop if the provided function `f` returns `true` and `stop_if_true=true` is specified in the constructor: `Callback`, `WithNumberDo`, `WithLossDo`, `WithTrainingLossesDo`.
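For instance, assuming `stop_if_true` is passed as a keyword argument, a `WithLossDo` control can be made to stop training as soon as its function returns `true`:

```julia
# Sketch: stop once the loss drops below 0.01.
IterationControl.train!(model, Step(1),
    WithLossDo(x -> x < 0.01, stop_if_true=true))
```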
There are also several control wrappers to modify a control's behavior:
wrapper | description |
---|---|
`IterationControl.skip(control; predicate=1)` | Apply `control` every `predicate` applications of the control wrapper (`predicate` can also be a function; see doc-string) |
`IterationControl.louder(control; by=1)` | Increase the verbosity level of `control` by the specified value (negative values lower verbosity) |
`IterationControl.with_state_do(control; f=...)` | Apply `control` and call `f(x)`, where `x` is the internal state of `control`; useful for debugging. The default `f` logs the state to `Info`. Warning: internal control state is not yet part of the public API. |
`IterationControl.composite(controls...)` | Apply each control in `controls` in sequence; mostly for under-the-hood use |
Table 2. Wrapped controls
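For example, `IterationControl.skip` might be used to log the current root only every third control cycle (a sketch based on the earlier example):

```julia
IterationControl.train!(model, Step(1),
    IterationControl.skip(Info(m -> m.root), predicate=3),
    NumberLimit(9))
```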
Note that functions ordinarily applied to `model` by some control (e.g., a `Callback`) will instead be applied to `IterationControl.expose(model)` if `IterationControl.expose` is appropriately overloaded.
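For example, if models are trained inside some wrapper type, one might overload `expose` along these lines (the wrapper type and its field are hypothetical):

```julia
# Hypothetical wrapper: controls like `Callback(f)` will now receive the
# underlying atomic model rather than the wrapper itself.
IterationControl.expose(model::MyWrapper) = model.atomic_model
```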
There is no abstract control type; any object can be a control. Behavior is implemented using a functional-style interface with six methods, of which only the two `update!` methods below are compulsory: the fallback for `done` always returns `false`, the fallback for `takedown` returns an empty `NamedTuple()`, and the fallbacks for the traits `needs_loss` and `needs_training_losses` (discussed below) return `false`:
```julia
update!(control, model, verbosity, n) -> state   # initialization
update!(control, model, verbosity, n, state) -> state
done(control, state)::Bool
takedown(control, verbosity, state) -> human_readable_named_tuple
```
Here `n` is the control cycle count, i.e., one more than the number of completed control cycles.
If it is nonsensical to apply `control` to any model for which `loss(model)` has not been overloaded, and we want an error thrown when this is attempted, then declare `IterationControl.needs_loss(control::MyControl) = true`. Otherwise `control` is applied anyway, and `loss`, if called, returns `nothing`.
A second trait, `needs_training_losses(control)`, serves an analogous purpose for training losses.
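As a sketch of this interface (not part of the package), here is a custom control that stops when one further control cycle changes the `SquareRooter`'s `root` field by less than a tolerance:

```julia
import IterationControl: update!, done, takedown

struct RootConverged  # hypothetical custom control
    tol::Float64
end

# first application (initializes state):
update!(c::RootConverged, model, verbosity, n) =
    (prev = model.root, converged = false)

# subsequent applications:
update!(c::RootConverged, model, verbosity, n, state) =
    (prev = model.root, converged = abs(model.root - state.prev) < c.tol)

done(c::RootConverged, state) = state.converged

takedown(c::RootConverged, verbosity, state) = (converged = state.converged,)
```

Then, for example, `IterationControl.train!(model, Step(1), RootConverged(1e-8))` stops once an extra iteration changes the root by less than `1e-8`.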
Here's a simplified version of how `IterationControl.train!` calls these methods:
```julia
function train!(model, controls...; verbosity::Int=1)

    control = composite(controls...)

    # before training:
    verbosity > 1 && @info "Using these controls: $(flat(control)). "

    # first training event:
    n = 1 # counts control cycles
    state = update!(control, model, verbosity, n)
    finished = done(control, state)

    # checks that model supports control:
    if needs_loss(control) && loss(model) === nothing
        throw(ERR_NEEDS_LOSS)
    end
    if needs_training_losses(control) && training_losses(model) === nothing
        throw(ERR_NEEDS_TRAINING_LOSSES)
    end

    # subsequent training events:
    while !finished
        n += 1
        state = update!(control, model, verbosity, n, state)
        finished = done(control, state)
    end

    # finalization:
    return takedown(control, verbosity, state)
end
```