Clarify input_scitype for Static models
oops
ablaom committed Jan 8, 2024
1 parent a834e97 commit 7aa6472
Showing 1 changed file with 26 additions and 16 deletions.
42 changes: 26 additions & 16 deletions docs/src/adding_models_for_general_use.md
@@ -1,4 +1,4 @@
# Adding Models for General Use
# Adding Models for General Use

!!! note

@@ -975,7 +975,7 @@ also appears in the EvoTrees.jl package.
Here "user-supplied data" is what the MLJ user supplies when
constructing a machine, as in `machine(model, args...)`, which
coincides with the arguments expected by `fit(model, verbosity,
args...)` when `reformat` is not overloaded.
args...)` when `reformat` is not overloaded.

Overloading `reformat` is permitted for any `Model`
subtype, except for subtypes of `Static`. Here is a complete list of
@@ -992,7 +992,7 @@ responsibilities for such an implementation, for some
serving as a data front-end for operations like `predict`. It must
always hold that `reformat(model, args...)[1] = reformat(model,
args[1])`.

The fallback is `reformat(model, args...) = args` (i.e., slurps provided data).
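
The slurping behaviour of the fallback can be illustrated without loading any package. The sketch below is only an illustration: `ToyModel` and `toy_reformat` are hypothetical stand-ins that mimic the documented fallback, not the MLJModelInterface implementation itself.

```julia
# Hypothetical stand-ins: `ToyModel` and `toy_reformat` are illustrative names,
# not part of MLJModelInterface.
struct ToyModel end

# Mimics the documented fallback: slurp the provided data into a tuple.
toy_reformat(model, args...) = args

X = [1.0 2.0; 3.0 4.0; 5.0 6.0]   # 3 observations, 2 features
y = [0.1, 0.2, 0.3]

training_data = toy_reformat(ToyModel(), X, y)   # data front-end for `fit`
prediction_data = toy_reformat(ToyModel(), X)    # data front-end for `predict`

# A tuple comes back in both cases, even when only one argument is supplied.
@assert training_data isa Tuple && prediction_data isa Tuple
@assert prediction_data === (X,)

# The first component is the same object in both cases.
@assert training_data[1] === prediction_data[1]
```

A real overloading would add a method to `MLJModelInterface.reformat` for the model type and would typically convert the data into a model-specific representation, but it should return a tuple in the same way.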

*Important.* `reformat(model::SomeModelType, args...)` must always return a tuple, even if
@@ -1204,7 +1204,7 @@ Your document string must include the following components, in order:
- A closing *"See also"* sentence which includes a `@ref` link to the raw model type (if you are wrapping one).


## Unsupervised models
## Unsupervised

Unsupervised models implement the MLJ model interface in a very
similar fashion. The main differences are:
@@ -1214,28 +1214,30 @@ similar fashion. The main differences are:
although this is not a hard requirement. For example, a feature selection tool (wrapping
some supervised model) might also include a target `y` as input. Furthermore, in the
case of models that subtype `Static <: Unsupervised` (see also [Static
transformers](@ref) `fit` has no training arguments at all, but does not need to be
transformers](@ref)) `fit` has no training arguments at all, but does not need to be
implemented, because a fallback returns `(nothing, nothing, nothing)`.

- A `transform` and/or `predict` method is implemented, and has the same signature as
`predict` does in the supervised case, as in `MLJModelInterface.transform(model,
fitresult, Xnew)`. However, it may only have one data argument `Xnew`, unless `model <:
Static`, in which case there is no restriction. A use-case for `predict` is K-means
`MLJModelInterface.predict(model, fitresult, Xnew)`. A use-case is
clustering that `predict`s labels and `transform`s
input features into a space of lower dimension. See [Transformers
that also predict](@ref) for an example.
Static`, in which case there is no restriction. A use-case for `predict` is K-means
clustering that `predict`s labels and `transform`s input features into a space of lower
dimension. See [Transformers that also predict](@ref) for an example.

- The `target_scitype` trait continues to refer to the output of `predict`, if
implemented, while a trait, `output_scitype`, is for the output of `transform`.
- The `target_scitype` refers to the output of `predict`, if implemented. A new trait,
`output_scitype`, is for the output of `transform`. Unless the model is `Static` (see
below) the trait `input_scitype` is for the single data argument of `transform` (and
`predict`, if implemented). If `fit` has more than one data argument, you must overload
the trait `fit_data_scitype`, which bounds the allowed `data` passed to `fit(model,
verbosity, data...)` and will always be a `Tuple` type.

- An `inverse_transform` can be optionally implemented. The signature
is the same as `transform`, as in
`MLJModelInterface.inverse_transform(model, fitresult, Xout)`, which:
- must make sense for any `Xout` for which `scitype(Xout) <:
output_scitype(SomeSupervisedModel)` (see below); and
output_scitype(SomeSupervisedModel)` (see below); and
- must return an object `Xin` satisfying `scitype(Xin) <:
input_scitype(SomeSupervisedModel)`.
input_scitype(SomeSupervisedModel)`.
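
The points in the list above can be put together in a minimal sketch, assuming MLJModelInterface is installed. The model `Centerer` is hypothetical and exists only for illustration; it learns the mean of a vector in `fit`, subtracts it in `transform`, and adds it back in `inverse_transform`.

```julia
import MLJModelInterface
const MMI = MLJModelInterface

# Hypothetical transformer (illustration only): centers a vector of
# continuous values about the mean seen in training.
mutable struct Centerer <: MMI.Unsupervised end

# `fit` has a single data argument and returns (fitresult, cache, report).
function MMI.fit(::Centerer, verbosity::Int, X)
    fitresult = sum(X) / length(X)   # the learned mean
    cache = nothing
    report = nothing
    return fitresult, cache, report
end

# Same signature as `predict` in the supervised case.
MMI.transform(::Centerer, fitresult, Xnew) = Xnew .- fitresult

# Optional inverse, mapping outputs back to inputs.
MMI.inverse_transform(::Centerer, fitresult, Xout) = Xout .+ fitresult

# Traits: scitype of the data argument and of the output of `transform`.
MMI.input_scitype(::Type{<:Centerer}) = AbstractVector{MMI.Continuous}
MMI.output_scitype(::Type{<:Centerer}) = AbstractVector{MMI.Continuous}

model = Centerer()
X = [1.0, 2.0, 6.0]
fitresult, _, _ = MMI.fit(model, 0, X)
W = MMI.transform(model, fitresult, X)
Xrecovered = MMI.inverse_transform(model, fitresult, W)
```

Here `Xrecovered` should agree with `X` up to floating point error. In ordinary use these methods are not called directly; the user binds the model to data in a machine and calls `fit!`, `transform` and `inverse_transform` on that.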

For sample implementations, see MLJ's [built-in
transformers](https://github.com/JuliaAI/MLJModels.jl/blob/dev/src/builtins/Transformers.jl)
@@ -1245,8 +1247,16 @@ and the clustering models at

## Static models (models that do not generalize)

See [Static transformers](@ref) for basic implementation of models that do not generalize
to new data but do have hyperparameters.
A model type subtypes `Static <: Unsupervised` if it does not generalize to new data but
nevertheless has hyperparameters. See [Static transformers](@ref) for examples. In the
`Static` case, `transform` can have multiple arguments and `input_scitype` refers to the
allowed scitype of the slurped data, *even if there is only a single argument.* For
example, if the signature is `transform(static_model, X1, X2)`, then the allowed
`input_scitype` might be `Tuple{Table(Continuous), Table(Continuous)}`; if the signature
is `transform(static_model, X)`, the allowed `input_scitype` might be
`Tuple{Table(Continuous)}`. The other traits are as for regular `Unsupervised` models, as
described above.
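
As a minimal sketch of the two-argument case, assuming MLJModelInterface is installed, consider a hypothetical static model `Blender` (illustration only) that forms a weighted combination of two vectors. Because the data is slurped, `input_scitype` is declared as a `Tuple` type with one entry per data argument.

```julia
import MLJModelInterface
const MMI = MLJModelInterface

# Hypothetical static model (illustration only): it learns nothing from data
# but has one hyperparameter, `weight`.
mutable struct Blender <: MMI.Static
    weight::Float64
end

# No `fit` method is defined; the fallback returns (nothing, nothing, nothing).
# The second argument of `transform` is the (unused) fitresult.
MMI.transform(model::Blender, _, x1, x2) =
    model.weight .* x1 .+ (1 - model.weight) .* x2

# `input_scitype` bounds the slurped data, so it is a `Tuple` type.
MMI.input_scitype(::Type{<:Blender}) =
    Tuple{AbstractVector{MMI.Continuous}, AbstractVector{MMI.Continuous}}
MMI.output_scitype(::Type{<:Blender}) = AbstractVector{MMI.Continuous}

model = Blender(0.25)
fitresult, _, _ = MMI.fit(model, 0)   # no training arguments at all
blended = MMI.transform(model, fitresult, [1.0, 2.0], [5.0, 6.0])
```

For a static model with signature `transform(static_model, X)` the same pattern applies, with `input_scitype` declared as a one-element `Tuple` type.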


### Reporting byproducts of a static transformation
