Skip to content

Commit

Permalink
Merge pull request #114 from JuliaAI/dev
Browse files Browse the repository at this point in the history
For a 1.3 release
  • Loading branch information
ablaom authored Sep 6, 2021
2 parents 6288bac + d9a7c69 commit 05091c5
Show file tree
Hide file tree
Showing 4 changed files with 113 additions and 33 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "MLJModelInterface"
uuid = "e80e1ace-859a-464e-9ed9-23947d8ae3ea"
authors = ["Thibaut Lienart and Anthony Blaom"]
version = "1.2.0"
version = "1.3.0"

[deps]
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Expand Down
40 changes: 32 additions & 8 deletions src/MLJModelInterface.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
module MLJModelInterface

const MODEL_TRAITS = [
:input_scitype,
const MODEL_TRAITS =
[:input_scitype,
:output_scitype,
:target_scitype,
:fit_data_scitype,
Expand Down Expand Up @@ -38,7 +38,17 @@ const ABSTRACT_MODEL_SUBTYPES =
:Deterministic,
:Interval,
:JointProbabilistic,
:Static]
:Static,
:Annotator,
:SupervisedAnnotator,
:UnsupervisedAnnotator,
:SupervisedDetector,
:UnsupervisedDetector,
:ProbabilisticSupervisedDetector,
:ProbabilisticUnsupervisedDetector,
:DeterministicSupervisedDetector,
:DeterministicUnsupervisedDetector]


# ------------------------------------------------------------------------
# Dependencies
Expand Down Expand Up @@ -69,7 +79,8 @@ export @mlj_model, metadata_pkg, metadata_model
# model api
export fit, update, update_data, transform, inverse_transform,
fitted_params, predict, predict_mode, predict_mean, predict_median,
predict_joint, evaluate, clean!, reformat, training_losses
predict_joint, evaluate, clean!, reformat, training_losses,
augmented_transform

# model traits
for trait in MODEL_TRAITS
Expand Down Expand Up @@ -118,17 +129,30 @@ abstract type Model <: MLJType end
# ------------------------------------------------------------------------
# Model subtypes

abstract type Supervised <: Model end
abstract type Unsupervised <: Model end
abstract type Supervised <: Model end
abstract type Unsupervised <: Model end
abstract type Annotator <: Model end

abstract type Probabilistic <: Supervised end
abstract type Deterministic <: Supervised end
abstract type Interval <: Supervised end

abstract type Static <: Unsupervised end

abstract type JointProbabilistic <: Probabilistic end

abstract type Static <: Unsupervised end

# Annotator hierarchy: the two direct subtypes of `Annotator`.
abstract type SupervisedAnnotator <: Annotator end
abstract type UnsupervisedAnnotator <: Annotator end

# Detectors are annotators; each detector type sits under the annotator
# type of the same (supervised/unsupervised) kind.
abstract type UnsupervisedDetector <: UnsupervisedAnnotator end
abstract type SupervisedDetector <: SupervisedAnnotator end

# Probabilistic variants of the two detector types.
abstract type ProbabilisticSupervisedDetector <: SupervisedDetector end
abstract type ProbabilisticUnsupervisedDetector <: UnsupervisedDetector end

# Deterministic variants of the two detector types.
abstract type DeterministicSupervisedDetector <: SupervisedDetector end
abstract type DeterministicUnsupervisedDetector <: UnsupervisedDetector end

# ------------------------------------------------------------------------
# includes

Expand Down
34 changes: 26 additions & 8 deletions src/model_api.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,13 @@ fit(::Static, ::Integer, data...) = (nothing, nothing, nothing)
# Fallback for supervised models that don't support sample weights: the
# weights `w` are dropped and the call is forwarded to the four-argument
# `fit` method.
function fit(m::Supervised, verbosity, X, y, w)
    return fit(m, verbosity, X, y)
end

# Fallback for unsupervised detectors called with a fourth argument `y`
# (e.g. labels supplied only for evaluation): `y` is discarded and the
# call is forwarded to the three-argument `fit` method.
function fit(
    m::Union{ProbabilisticUnsupervisedDetector,DeterministicUnsupervisedDetector},
    verbosity,
    X,
    y,
)
    return fit(m, verbosity, X)
end

"""
MLJModelInterface.update(model, verbosity, fitresult, cache, data...)
Expand Down Expand Up @@ -98,24 +105,33 @@ fitted_params(::Model, fitresult) = (fitresult=fitresult,)
predict(model, fitresult, new_data...)
`Supervised` models must implement the `predict` operation. Here
`new_data` is the output of `reformat` called on user-specified data.
`Supervised` and `SupervisedAnnotator` models must implement the
`predict` operation. Here `new_data` is the output of `reformat`
called on user-specified data.
"""
function predict end

"""
probabilistic supervised models may overload `predict_mean`
Model types `M` for which `prediction_type(M) == :probabilistic` may
overload `predict_mean`.
"""
function predict_mean end

"""
probabilistic supervised models may overload `predict_mode`
Model types `M` for which `prediction_type(M) == :probabilistic` may
overload `predict_mode`.
"""
function predict_mode end

"""
probabilistic supervised models may overload `predict_median`
Model types `M` for which `prediction_type(M) == :probabilistic` may
overload `predict_median`.
"""
function predict_median end

Expand All @@ -127,12 +143,14 @@ function predict_median end
function predict_joint end

"""
unsupervised methods must implement the `transform` operation
`Unsupervised` models must implement the `transform` operation.
"""
function transform end

"""
unsupervised methods may implement the `inverse_transform` operation
`Unsupervised` models may implement the `inverse_transform` operation.
"""
function inverse_transform end

Expand Down
70 changes: 54 additions & 16 deletions src/model_traits.jl
Original file line number Diff line number Diff line change
@@ -1,14 +1,33 @@
## OVERLOADING TRAIT DEFAULTS RELEVANT TO MODELS

StatisticalTraits.docstring(M::Type{<:MLJType}) = name(M)
StatisticalTraits.docstring(M::Type{<:Model}) =
# unexported aliases:
# Each alias is the union of the supervised and unsupervised variants of one
# detector family, so a single trait method below can cover both variants.
const Detector = Union{SupervisedDetector,UnsupervisedDetector}
const ProbabilisticDetector = Union{ProbabilisticSupervisedDetector,
ProbabilisticUnsupervisedDetector}
const DeterministicDetector = Union{DeterministicSupervisedDetector,
DeterministicUnsupervisedDetector}

# Short alias for `StatisticalTraits`, used below to qualify the trait
# methods being extended.
const StatTraits = StatisticalTraits

# Fallback docstring for any `MLJType`: just the type's name.
function StatTraits.docstring(M::Type{<:MLJType})
    return name(M)
end

# Docstring for models: the model name and its package, followed by a
# markdown link to the package URL.
function StatTraits.docstring(M::Type{<:Model})
    headline = "$(name(M)) from $(package_name(M)).jl.\n"
    link = "[Documentation]($(package_url(M)))."
    return headline * link
end

StatisticalTraits.is_supervised(::Type{<:Supervised}) = true
StatisticalTraits.prediction_type(::Type{<:Deterministic}) = :deterministic
StatisticalTraits.prediction_type(::Type{<:Probabilistic}) = :probabilistic
StatisticalTraits.prediction_type(::Type{<:Interval}) = :interval
# Every subtype of `Supervised` reports itself as supervised.
StatTraits.is_supervised(::Type{<:Supervised}) = true

# `prediction_type` is fixed by the abstract model type a model subtypes.
StatTraits.prediction_type(::Type{<:Deterministic}) = :deterministic
StatTraits.prediction_type(::Type{<:Probabilistic}) = :probabilistic
StatTraits.prediction_type(::Type{<:Interval}) = :interval
# Detectors (supervised or unsupervised, via the union aliases) use the
# same two symbols as `Probabilistic` and `Deterministic` models.
StatTraits.prediction_type(::Type{<:ProbabilisticDetector}) =
:probabilistic
StatTraits.prediction_type(::Type{<:DeterministicDetector}) =
:deterministic

# Fallback `target_scitype` for detectors: vectors whose elements are
# `Missing` or `OrderedFactor{2}`.
StatTraits.target_scitype(::Type{<:ProbabilisticDetector}) =
AbstractVector{<:Union{Missing,OrderedFactor{2}}}
StatTraits.target_scitype(::Type{<:DeterministicDetector}) =
AbstractVector{<:Union{Missing,OrderedFactor{2}}}

# implementation is deferred as it requires methodswith which depends upon
# InteractiveUtils which we don't want to bring here as a dependency
Expand All @@ -18,13 +37,13 @@ implemented_methods(model) = implemented_methods(typeof(model))
# Under `LightInterface` there is no real implementation; the call is
# handed to `errlight` with the name of the missing function.
function implemented_methods(::LightInterface, M)
    return errlight("implemented_methods")
end

for M in ABSTRACT_MODEL_SUBTYPES
@eval(StatisticalTraits.abstract_type(::Type{<:$M}) = $M)
@eval(StatTraits.abstract_type(::Type{<:$M}) = $M)
end

StatisticalTraits.fit_data_scitype(M::Type{<:Unsupervised}) =
StatTraits.fit_data_scitype(M::Type{<:Unsupervised}) =
Tuple{input_scitype(M)}
StatisticalTraits.fit_data_scitype(::Type{<:Static}) = Tuple{}
function StatisticalTraits.fit_data_scitype(M::Type{<:Supervised})
StatTraits.fit_data_scitype(::Type{<:Static}) = Tuple{}
function StatTraits.fit_data_scitype(M::Type{<:Supervised})
I = input_scitype(M)
T = target_scitype(M)
ret = Tuple{I,T}
Expand All @@ -37,24 +56,42 @@ function StatisticalTraits.fit_data_scitype(M::Type{<:Supervised})
end
return ret
end
# Unsupervised annotators are fit on the input alone.
function StatTraits.fit_data_scitype(M::Type{<:UnsupervisedAnnotator})
    return Tuple{input_scitype(M)}
end

# Supervised annotators are fit on an (input, target) pair.
function StatTraits.fit_data_scitype(M::Type{<:SupervisedAnnotator})
    return Tuple{input_scitype(M),target_scitype(M)}
end

# Special case: for `ProbabilisticUnsupervisedDetector` and
# `DeterministicUnsupervisedDetector` models, the target is allowed as
# an optional argument to `fit` (where it is ignored) so that the
# `machine` constructor will accept it as a valid argument, which in
# turn enables *evaluation* of the detector with labeled data. The
# trait is therefore a union of the one- and two-argument tuple types.
function StatTraits.fit_data_scitype(
    M::Type{<:Union{ProbabilisticUnsupervisedDetector,DeterministicUnsupervisedDetector}},
)
    I = input_scitype(M)
    return Union{Tuple{I},Tuple{I,target_scitype(M)}}
end

StatisticalTraits.transform_scitype(M::Type{<:Unsupervised}) =
StatTraits.transform_scitype(M::Type{<:Unsupervised}) =
output_scitype(M)

StatisticalTraits.inverse_transform_scitype(M::Type{<:Unsupervised}) =
StatTraits.inverse_transform_scitype(M::Type{<:Unsupervised}) =
input_scitype(M)

StatisticalTraits.predict_scitype(M::Type{<:Deterministic}) = target_scitype(M)
StatTraits.predict_scitype(M::Type{<:Union{
Deterministic,DeterministicDetector}}) = target_scitype(M)


## FALLBACKS FOR `predict_scitype` FOR `Probabilistic` MODELS
## FALLBACKS FOR `predict_scitype` FOR `Probabilistic` and
## `ProbabilisticDetector` MODELS

# This seems less than ideal but should reduce the number of `Unknown`
# in `predict_scitype` for models which, historically, have not
# implemented the trait.

StatisticalTraits.predict_scitype(M::Type{<:Probabilistic}) =
_density(target_scitype(M))
StatTraits.predict_scitype(
M::Type{<:Union{Probabilistic,ProbabilisticDetector}}
) = _density(target_scitype(M))

# Catch-all: a target scitype with no more specific `_density` method
# yields `Unknown`.
function _density(::Any)
    return Unknown
end
for T in [:Continuous, :Count, :Textual]
Expand All @@ -78,6 +115,7 @@ for T in [:Finite,
end)
end


for T in [:Finite, :Multiclass, :OrderedFactor]
eval(quote
_density(::Type{AbstractArray{<:$T{N},D}}) where {N,D} =
Expand Down

0 comments on commit 05091c5

Please sign in to comment.