From 4fdc1ab73e38ed9687a15d98b925d534380de1fe Mon Sep 17 00:00:00 2001 From: "Anthony D. Blaom" Date: Mon, 16 Aug 2021 14:19:04 +1200 Subject: [PATCH 1/5] integrate new traits: predict_scitype, etc --- src/MLJModelInterface.jl | 5 +++ src/model_traits.jl | 70 ++++++++++++++++++++++++++++++++++++++++ test/model_traits.jl | 64 ++++++++++++++++++++++++++++++++++++ 3 files changed, 139 insertions(+) diff --git a/src/MLJModelInterface.jl b/src/MLJModelInterface.jl index 5e9a3d5..e3c52f8 100644 --- a/src/MLJModelInterface.jl +++ b/src/MLJModelInterface.jl @@ -4,6 +4,10 @@ const MODEL_TRAITS = [ :input_scitype, :output_scitype, :target_scitype, + :training_scitype, + :predict_scitype, + :transform_scitype, + :inverse_transform_scitype, :is_pure_julia, :package_name, :package_license, @@ -18,6 +22,7 @@ const MODEL_TRAITS = [ :name, :is_supervised, :prediction_type, + :abstract_type, :implemented_methods, :hyperparameters, :hyperparameter_types, diff --git a/src/model_traits.jl b/src/model_traits.jl index 5ccefb2..c9cc847 100644 --- a/src/model_traits.jl +++ b/src/model_traits.jl @@ -16,3 +16,73 @@ StatisticalTraits.prediction_type(::Type{<:Interval}) = :interval implemented_methods(M::Type) = implemented_methods(get_interface_mode(), M) implemented_methods(model) = implemented_methods(typeof(model)) implemented_methods(::LightInterface, M) = errlight("implemented_methods") + +for M in ABSTRACT_MODEL_SUBTYPES + @eval(StatisticalTraits.abstract_type(::Type{<:$M}) = $M) +end + +StatisticalTraits.training_scitype(M::Type{<:Model}) = input_scitype(M) +StatisticalTraits.training_scitype(::Type{<:Static}) = Tuple{} +function StatisticalTraits.training_scitype(M::Type{<:Supervised}) + I = input_scitype(M) + T = target_scitype(M) + ret = Tuple{I,T} + if supports_weights(M) + W = AbstractVector{Union{Continuous,Count}} # weight scitype + return Union{ret,Tuple{I,T,W}} + elseif supports_class_weights(M) + W = AbstractDict{Finite,Union{Continuous,Count}} + return 
Union{ret,Tuple{I,T,W}} + end + return ret +end + +StatisticalTraits.transform_scitype(M::Type{<:Unsupervised}) = + output_scitype(M) + +StatisticalTraits.inverse_transform_scitype(M::Type{<:Unsupervised}) = + input_scitype(M) + +StatisticalTraits.predict_scitype(M::Type{<:Deterministic}) = target_scitype(M) + + +## FALLBACKS FOR `predict_scitype` FOR `Probabilistic` MODELS + +# This seems less than ideal but should reduce the number of `Unknown` +# in `predict_scitype` for models which, historically, have not +# implemented the trait. + +StatisticalTraits.predict_scitype(M::Type{<:Probabilistic}) = + _density(target_scitype(M)) + +_density(::Any) = Unknown +for T in [:Continuous, :Count, :Textual] + eval(quote + _density(::Type{AbstractArray{$T,D}}) where D = + AbstractArray{Density{$T},D} + end) +end + +for T in [:Finite, + :Multiclass, + :OrderedFactor, + :Infinite, + :Continuous, + :Count, + :Textual] + eval(quote + _density(::Type{AbstractArray{<:$T,D}}) where D = + AbstractArray{Density{<:$T},D} + _density(::Type{Table($T)}) = Table(Density{$T}) + end) +end + +for T in [:Finite, :Multiclass, :OrderedFactor] + eval(quote + _density(::Type{AbstractArray{<:$T{N},D}}) where {N,D} = + AbstractArray{Density{<:$T{N}},D} + _density(::Type{AbstractArray{$T{N},D}}) where {N,D} = + AbstractArray{Density{$T{N}},D} + _density(::Type{Table($T{N})}) where N = Table(Density{$T{N}}) + end) +end diff --git a/test/model_traits.jl b/test/model_traits.jl index 1b8c5f7..49be95f 100644 --- a/test/model_traits.jl +++ b/test/model_traits.jl @@ -79,3 +79,67 @@ import .Fruit @test docstring(Float64) == "Float64" @test docstring(Fruit.Banana) == "Banana" end + +@testset "`_density` - helper for predict_scitype fallback" begin + for T in [Continuous, Count, Textual] + @test M._density(AbstractArray{T,3}) == + AbstractArray{Density{T},3} + end + + for T in [Finite, + Multiclass, + OrderedFactor, + Infinite, + Continuous, + Count, + Textual] + @test M._density(AbstractVector{<:T}) == + 
AbstractVector{Density{<:T}} + @test M._density(Table(T)) == Table(Density{T}) + end + + for T in [Finite, Multiclass, OrderedFactor] + @test M._density(AbstractArray{<:T{2},3}) == + AbstractArray{Density{<:T{2}},3} + @test M._density(AbstractArray{T{2},3}) == + AbstractArray{Density{T{2}},3} + @test M._density(Table(T{2})) == Table(Density{T{2}}) + end +end + +@mlj_model mutable struct P2 <: Probabilistic end +M.target_scitype(::Type{<:P2}) = AbstractVector{<:Multiclass} +M.input_scitype(::Type{<:P2}) = Table(Continuous) + +@mlj_model mutable struct U2 <: Unsupervised end +M.output_scitype(::Type{<:U2}) = AbstractVector{<:Multiclass} +M.input_scitype(::Type{<:U2}) = Table(Continuous) + +@mlj_model mutable struct S2 <: Static end +M.output_scitype(::Type{<:S2}) = AbstractVector{<:Multiclass} +M.input_scitype(::Type{<:S2}) = Table(Continuous) + +@testset "operation scitypes" begin + @test predict_scitype(P2()) == AbstractVector{Density{<:Multiclass}} + @test transform_scitype(P2()) == Unknown + @test transform_scitype(U2()) == AbstractVector{<:Multiclass} + @test inverse_transform_scitype(U2()) == Table(Continuous) + @test predict_scitype(U2()) == Unknown + @test transform_scitype(S2()) == AbstractVector{<:Multiclass} + @test inverse_transform_scitype(S2()) == Table(Continuous) +end + +@testset "abstract_type, training_scitype" begin + @test abstract_type(P2()) == Probabilistic + @test abstract_type(S1()) == Supervised + @test abstract_type(U1()) == Unsupervised + @test abstract_type(D1()) == Deterministic + @test abstract_type(P1()) == Probabilistic + + @test training_scitype(P2()) == + Tuple{Table(Continuous),AbstractVector{<:Multiclass}} + @test training_scitype(U2()) == Table(Continuous) + @test training_scitype(S2()) == Tuple{} +end + +true From d29e1a72548a501d7ffa33b53b4d286fa8d2b417 Mon Sep 17 00:00:00 2001 From: "Anthony D. 
Blaom" Date: Mon, 16 Aug 2021 14:31:35 +1200 Subject: [PATCH 2/5] training_scitype -> fit_data_scitype --- src/MLJModelInterface.jl | 2 +- src/model_traits.jl | 6 +++--- test/model_traits.jl | 8 ++++---- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/src/MLJModelInterface.jl b/src/MLJModelInterface.jl index e3c52f8..a18c2ad 100644 --- a/src/MLJModelInterface.jl +++ b/src/MLJModelInterface.jl @@ -4,7 +4,7 @@ const MODEL_TRAITS = [ :input_scitype, :output_scitype, :target_scitype, - :training_scitype, + :fit_data_scitype, :predict_scitype, :transform_scitype, :inverse_transform_scitype, diff --git a/src/model_traits.jl b/src/model_traits.jl index c9cc847..cf1074b 100644 --- a/src/model_traits.jl +++ b/src/model_traits.jl @@ -21,9 +21,9 @@ for M in ABSTRACT_MODEL_SUBTYPES @eval(StatisticalTraits.abstract_type(::Type{<:$M}) = $M) end -StatisticalTraits.training_scitype(M::Type{<:Model}) = input_scitype(M) -StatisticalTraits.training_scitype(::Type{<:Static}) = Tuple{} -function StatisticalTraits.training_scitype(M::Type{<:Supervised}) +StatisticalTraits.fit_data_scitype(M::Type{<:Model}) = input_scitype(M) +StatisticalTraits.fit_data_scitype(::Type{<:Static}) = Tuple{} +function StatisticalTraits.fit_data_scitype(M::Type{<:Supervised}) I = input_scitype(M) T = target_scitype(M) ret = Tuple{I,T} diff --git a/test/model_traits.jl b/test/model_traits.jl index 49be95f..3ce1142 100644 --- a/test/model_traits.jl +++ b/test/model_traits.jl @@ -129,17 +129,17 @@ M.input_scitype(::Type{<:S2}) = Table(Continuous) @test inverse_transform_scitype(S2()) == Table(Continuous) end -@testset "abstract_type, training_scitype" begin +@testset "abstract_type, fit_data_scitype" begin @test abstract_type(P2()) == Probabilistic @test abstract_type(S1()) == Supervised @test abstract_type(U1()) == Unsupervised @test abstract_type(D1()) == Deterministic @test abstract_type(P1()) == Probabilistic - @test training_scitype(P2()) == + @test fit_data_scitype(P2()) == 
Tuple{Table(Continuous),AbstractVector{<:Multiclass}} - @test training_scitype(U2()) == Table(Continuous) - @test training_scitype(S2()) == Tuple{} + @test fit_data_scitype(U2()) == Table(Continuous) + @test fit_data_scitype(S2()) == Tuple{} end true From 544d7fc0050233d0f5fa7a9a4ed188a3992b365b Mon Sep 17 00:00:00 2001 From: "Anthony D. Blaom" Date: Mon, 16 Aug 2021 15:14:06 +1200 Subject: [PATCH 3/5] make sure fit_data_scitype is a Tuple --- src/model_traits.jl | 2 +- test/model_traits.jl | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/model_traits.jl b/src/model_traits.jl index cf1074b..da435a8 100644 --- a/src/model_traits.jl +++ b/src/model_traits.jl @@ -21,7 +21,7 @@ for M in ABSTRACT_MODEL_SUBTYPES @eval(StatisticalTraits.abstract_type(::Type{<:$M}) = $M) end -StatisticalTraits.fit_data_scitype(M::Type{<:Model}) = input_scitype(M) +StatisticalTraits.fit_data_scitype(M::Type{<:Model}) = Tuple{input_scitype(M)} StatisticalTraits.fit_data_scitype(::Type{<:Static}) = Tuple{} function StatisticalTraits.fit_data_scitype(M::Type{<:Supervised}) I = input_scitype(M) diff --git a/test/model_traits.jl b/test/model_traits.jl index 3ce1142..86eb76f 100644 --- a/test/model_traits.jl +++ b/test/model_traits.jl @@ -138,7 +138,7 @@ end @test fit_data_scitype(P2()) == Tuple{Table(Continuous),AbstractVector{<:Multiclass}} - @test fit_data_scitype(U2()) == Table(Continuous) + @test fit_data_scitype(U2()) == Tuple{Table(Continuous)} @test fit_data_scitype(S2()) == Tuple{} end From 8a45a8edddcba2f8ef3c176f60c411f3e7788548 Mon Sep 17 00:00:00 2001 From: "Anthony D. 
Blaom" Date: Tue, 17 Aug 2021 08:40:04 +1200 Subject: [PATCH 4/5] make fit_data_scitype(::Type{<:Model}) fallback `Unknown` --- src/model_traits.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/model_traits.jl b/src/model_traits.jl index da435a8..4342a48 100644 --- a/src/model_traits.jl +++ b/src/model_traits.jl @@ -21,7 +21,8 @@ for M in ABSTRACT_MODEL_SUBTYPES @eval(StatisticalTraits.abstract_type(::Type{<:$M}) = $M) end -StatisticalTraits.fit_data_scitype(M::Type{<:Model}) = Tuple{input_scitype(M)} +StatisticalTraits.fit_data_scitype(M::Type{<:Unsupervised}) = + Tuple{input_scitype(M)} StatisticalTraits.fit_data_scitype(::Type{<:Static}) = Tuple{} function StatisticalTraits.fit_data_scitype(M::Type{<:Supervised}) I = input_scitype(M) From 0acb64cf34d2d3f7d3ddbafba075d2ec0855fc12 Mon Sep 17 00:00:00 2001 From: "Anthony D. Blaom" Date: Tue, 17 Aug 2021 08:41:06 +1200 Subject: [PATCH 5/5] bump StatisticalTraits,ScientificTypesBase compat; bump 1.2 --- Project.toml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/Project.toml b/Project.toml index ff2be6e..352a9c3 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "MLJModelInterface" uuid = "e80e1ace-859a-464e-9ed9-23947d8ae3ea" authors = ["Thibaut Lienart and Anthony Blaom"] -version = "1.1.3" +version = "1.2.0" [deps] Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" @@ -9,8 +9,8 @@ ScientificTypesBase = "30f210dd-8aff-4c5f-94ba-8e64358c1161" StatisticalTraits = "64bff920-2084-43da-a3e6-9bb72801c0c9" [compat] -ScientificTypesBase = "1, 2" -StatisticalTraits = "2" +ScientificTypesBase = "2.1" +StatisticalTraits = "2.1" julia = "1" [extras]