From 7aa6472de471d46a1302bc8a3c80e6f123f9a456 Mon Sep 17 00:00:00 2001
From: "Anthony D. Blaom"
Date: Tue, 9 Jan 2024 11:22:19 +1300
Subject: [PATCH] Clarify `input_scitype` for Static models

oops
---
 docs/src/adding_models_for_general_use.md | 42 ++++++++++++++---------
 1 file changed, 26 insertions(+), 16 deletions(-)

diff --git a/docs/src/adding_models_for_general_use.md b/docs/src/adding_models_for_general_use.md
index a30ad96f6..76388899b 100755
--- a/docs/src/adding_models_for_general_use.md
+++ b/docs/src/adding_models_for_general_use.md
@@ -1,4 +1,4 @@
-# Adding Models for General Use 
+# Adding Models for General Use
 
 !!! note
 
@@ -975,7 +975,7 @@ also appears in the EvoTrees.jl package.
 
 Here "user-supplied data" is what the MLJ user supplies when constructing a machine, as
 in `machine(models, args...)`, which coincides with the arguments expected by `fit(model, verbosity,
-args...)` when `reformat` is not overloaded. 
+args...)` when `reformat` is not overloaded.
 
 Overloading `reformat` is permitted for any `Model` subtype, except for subtypes of
 `Static`. Here is a complete list of
@@ -992,7 +992,7 @@ responsibilities for such an implementation, for some
 serving as a data front-end for operations like `predict`.
 
 It must always hold that `reformat(model, args...)[1] = reformat(model, args[1])`.
-  
+
 The fallback is `reformat(model, args...) = args` (i.e., slurps provided data).
 
 *Important.* `reformat(model::SomeModelType, args...)` must always return a tuple, even if
@@ -1204,7 +1204,7 @@ Your document string must include the following components, in order:
 - A closing *"See also"* sentence which includes a `@ref` link to the raw model type (if
   you are wrapping one).
 
-## Unsupervised models
+## Unsupervised
 
 Unsupervised models implement the MLJ model interface in a very
 similar fashion. The main differences are:
@@ -1214,28 +1214,30 @@ similar fashion. The main differences are:
   although this is not a hard requirement. For example, a feature selection tool (wrapping
   some supervised model) might also include a target `y` as input. Furthermore, in the case
   of models that subtype `Static <: Unsupervised` (see also [Static
-  transformers](@ref) `fit` has no training arguments at all, but does not need to be
+  transformers](@ref)) `fit` has no training arguments at all, but does not need to be
   implemented as a fallback returns `(nothing, nothing, nothing)`.
 
 - A `transform` and/or `predict` method is implemented, and has the same signature as
   `predict` does in the supervised case, as in `MLJModelInterface.transform(model,
   fitresult, Xnew)`. However, it may only have one data argument `Xnew`, unless `model <:
-  Static`, in which case there is no restriction. A use-case for `predict` is K-means 
-  `MLJModelInterface.predict(model, fitresult, Xnew)`. A use-case is
-  clustering that `predict`s labels and `transform`s
-  input features into a space of lower dimension. See [Transformers
-  that also predict](@ref) for an example.
+  Static`, in which case there is no restriction. A use-case for `predict` is K-means
+  clustering that `predict`s labels and `transform`s input features into a space of lower
+  dimension. See [Transformers that also predict](@ref) for an example.
 
-- The `target_scitype` trait continues to refer to the output of `predict`, if
-  implemented, while a trait, `output_scitype`, is for the output of `transform`.
+- The `target_scitype` refers to the output of `predict`, if implemented. A new trait,
+  `output_scitype`, is for the output of `transform`. Unless the model is `Static` (see
+  below) the trait `input_scitype` is for the single data argument of `transform` (and
+  `predict`, if implemented). If `fit` has more than one data argument, you must overload
+  the trait `fit_data_scitype`, which bounds the allowed `data` passed to `fit(model,
+  verbosity, data...)` and will always be a `Tuple` type.
 
 - An `inverse_transform` can be optionally implemented. The signature is the same as
   `transform`, as in `MLJModelInterface.inverse_transform(model, fitresult, Xout)`, which:
   - must make sense for any `Xout` for which `scitype(Xout) <:
-    output_scitype(SomeSupervisedModel)` (see below); and 
+    output_scitype(SomeSupervisedModel)` (see below); and
   - must return an object `Xin` satisfying `scitype(Xin) <:
-    input_scitype(SomeSupervisedModel)`. 
+    input_scitype(SomeSupervisedModel)`.
 
 For sample implementatations, see MLJ's [built-in
 transformers](https://github.com/JuliaAI/MLJModels.jl/blob/dev/src/builtins/Transformers.jl)
 and the clustering models at
@@ -1245,8 +1247,16 @@
 
 ## Static models (models that do not generalize)
 
-See [Static transformers](@ref) for basic implementation of models that do not generalize
-to new data but do have hyperparameters.
+A model type subtypes `Static <: Unsupervised` if it does not generalize to new data but
+nevertheless has hyperparameters. See [Static transformers](@ref) for examples. In the
+`Static` case, `transform` can have multiple arguments and `input_scitype` refers to the
+allowed scitype of the slurped data, *even if there is only a single argument.* For
+example, if the signature is `transform(static_model, X1, X2)`, then the allowed
+`input_scitype` might be `Tuple{Table(Continuous), Table(Continuous)}`; if the signature
+is `transform(static_model, X)`, the allowed `input_scitype` might be
+`Tuple{Table(Continuous)}`. The other traits are as for regular `Unsupervised` models, as
+described above.
+
 
 ### Reporting byproducts of a static transformation
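The `Static` convention introduced above can be illustrated with a short sketch. This is not part of the patch; the model name `Averager` and its trait values are hypothetical, chosen only to show a two-argument `transform` paired with a `Tuple`-valued `input_scitype`:

```julia
import MLJModelInterface
const MMI = MLJModelInterface

# Hypothetical static transformer: returns a weighted average of two vectors.
mutable struct Averager <: MMI.Static
    mix::Float64
end

# No `fit` is required: the fallback for `Static` models returns
# `(nothing, nothing, nothing)`.

# `transform` receives the (trivial) fitresult plus any number of data arguments:
MMI.transform(m::Averager, _, y1, y2) = m.mix*y1 .+ (1 - m.mix)*y2

# Because the model is `Static`, `input_scitype` bounds the scitype of the
# *slurped* data arguments, so it is declared as a `Tuple` type:
MMI.input_scitype(::Type{<:Averager}) =
    Tuple{AbstractVector{MMI.Continuous}, AbstractVector{MMI.Continuous}}
MMI.output_scitype(::Type{<:Averager}) = AbstractVector{MMI.Continuous}
```

Under MLJ's usual workflow for static transformers, such a model is bound to a machine with no training arguments, as in `mach = machine(Averager(0.5))`, and the data is supplied only at transform time, as in `transform(mach, y1, y2)`.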
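For contrast with the non-`Static` case described in the patch, where `input_scitype` bounds the single data argument of `fit` and `transform` directly (no `Tuple` wrapper), here is a minimal sketch; `UnivariateShifter` is again a hypothetical name used only for illustration:

```julia
import MLJModelInterface
const MMI = MLJModelInterface

# Hypothetical unsupervised model that learns the mean of a vector and
# subtracts it in `transform`.
mutable struct UnivariateShifter <: MMI.Unsupervised end

function MMI.fit(::UnivariateShifter, verbosity, x)
    fitresult = sum(x)/length(x)   # the learned mean
    cache = nothing
    report = NamedTuple()
    return fitresult, cache, report
end

MMI.transform(::UnivariateShifter, fitresult, xnew) = xnew .- fitresult
MMI.inverse_transform(::UnivariateShifter, fitresult, z) = z .+ fitresult

# `input_scitype` bounds the single data argument; it is *not* a `Tuple` type:
MMI.input_scitype(::Type{<:UnivariateShifter})  = AbstractVector{MMI.Continuous}
MMI.output_scitype(::Type{<:UnivariateShifter}) = AbstractVector{MMI.Continuous}
```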