This package exposes the scikit-learn interface. Packages that implement this interface can be used in conjunction with ScikitLearn.jl (pipelines, cross-validation, hyperparameter tuning, etc.).
This is an intentionally slim package (~100 LOC, no dependencies). That way, ML libraries can import ScikitLearnBase without dragging along all of ScikitLearn's dependencies.
The docs contain an overview of the API and a more thorough specification.
There are two implementation strategies for an existing machine learning package:
- Create a new type that wraps the existing type. The new type can usually be written entirely on top of the existing codebase (i.e. without modifying it). This gives more implementation freedom and a more consistent interface among the various ScikitLearn.jl models. Here's an example from DecisionTree.jl.
- Use the existing type. This requires less code, and is usually better when the model type already contains the hyperparameters / fitting arguments.
For models with simple hyperparameters, it boils down to this:

```julia
import ScikitLearnBase

mutable struct NaiveBayes
    # The model hyperparameters (not learned from data)
    bias::Float64
    # The parameters learned from data
    counts::Matrix{Int}
    # A constructor that accepts the hyperparameters as keyword arguments
    # with sensible defaults
    NaiveBayes(; bias=0.0) = new(bias)
end

# This will define `clone`, `set_params!` and `get_params` for the model
ScikitLearnBase.@declare_hyperparameters(NaiveBayes, [:bias])

# NaiveBayes is a classifier
ScikitLearnBase.is_classifier(::NaiveBayes) = true   # not required for transformers

function ScikitLearnBase.fit!(model::NaiveBayes, X, y)
    # X should be of size (n_samples, n_features)
    # ... compute and store model.counts here
    return model
end

function ScikitLearnBase.predict(model::NaiveBayes, X)
    # ... return a vector of predicted classes here
end
```
Models with more complex hyperparameter specifications should implement `clone`, `get_params` and `set_params!` explicitly instead of using `@declare_hyperparameters`.
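As a sketch of what an explicit implementation can look like (the `KernelModel` type and its fields are hypothetical; only the three method names come from the API, and the `deep` keyword follows the scikit-learn convention):

```julia
import ScikitLearnBase

mutable struct KernelModel
    kernel::String
    gamma::Float64
    KernelModel(; kernel="rbf", gamma=1.0) = new(kernel, gamma)
end

# Return a fresh, unfitted copy carrying the same hyperparameters
ScikitLearnBase.clone(m::KernelModel) =
    KernelModel(kernel=m.kernel, gamma=m.gamma)

ScikitLearnBase.get_params(m::KernelModel; deep=true) =
    Dict(:kernel => m.kernel, :gamma => m.gamma)

function ScikitLearnBase.set_params!(m::KernelModel; params...)
    for (name, value) in params
        setproperty!(m, name, value)  # converts to the field's declared type
    end
    return m
end
```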
More examples of PRs that implement the interface: GaussianMixtures.jl, GaussianProcesses.jl, DecisionTree.jl, LowRankModels.jl
Note: if the model performs unsupervised learning, implement `transform` instead of `predict`.
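For illustration, here is a hedged sketch of a transformer that centers each feature (the `Standardizer` type is hypothetical, not part of the API; only `fit!` and `transform` are interface methods):

```julia
import ScikitLearnBase
using Statistics: mean

mutable struct Standardizer
    mean_::Vector{Float64}   # learned from data by fit!
    Standardizer() = new()
end

function ScikitLearnBase.fit!(t::Standardizer, X, y=nothing)
    t.mean_ = vec(mean(X, dims=1))   # per-feature column means
    return t
end

# Subtract the learned column means from each row
ScikitLearnBase.transform(t::Standardizer, X) = X .- t.mean_'
```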
Once your library implements the API, file an issue/PR to add it to the list of models.