From 185286f97c92233da549a2ff794f73503a1cc26d Mon Sep 17 00:00:00 2001 From: "Anthony D. Blaom" Date: Tue, 30 Aug 2022 19:02:49 +1200 Subject: [PATCH 01/22] add OrderedCollections to test deps --- Project.toml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 81553de..483f7a6 100644 --- a/Project.toml +++ b/Project.toml @@ -19,9 +19,10 @@ DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7" InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240" Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a" +OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" ScientificTypes = "321657f4-b219-11e9-178b-2701a2544e81" Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["CategoricalArrays", "DataFrames", "Distances", "InteractiveUtils", "Markdown", "ScientificTypes", "Tables", "Test"] +test = ["CategoricalArrays", "DataFrames", "Distances", "InteractiveUtils", "Markdown", "OrderedCollections", "ScientificTypes", "Tables", "Test"] From 1dbae102861e9401017cd6e1a6dce9e70becd1af Mon Sep 17 00:00:00 2001 From: "Anthony D. Blaom" Date: Tue, 30 Aug 2022 19:11:59 +1200 Subject: [PATCH 02/22] fix a deprecation warning --- test/data_utils.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/data_utils.jl b/test/data_utils.jl index 012ff25..6c1eb75 100644 --- a/test/data_utils.jl +++ b/test/data_utils.jl @@ -323,9 +323,9 @@ end eval(:(module UserSide import MLJModelInterface: metadata_model, metadata_pkg struct A end - descr = "something" + human_name = "Big Foot" # Smoke tests. - metadata_model(A; descr=descr) + metadata_model(A; human_name) metadata_pkg(A) end)) end From 101ffa3e5a8f57f3acecaa75b5babe2ac4d4e0a7 Mon Sep 17 00:00:00 2001 From: "Anthony D. Blaom" Date: Tue, 30 Aug 2022 19:12:54 +1200 Subject: [PATCH 03/22] add report method for merging fit reports with operation reports --- src/model_api.jl | 36 ++++++++++++++++++++++++++++++++++++ test/model_api.jl | 8 +++++++- test/runtests.jl | 1 + 3 files changed, 44 insertions(+), 1 deletion(-) diff --git a/src/model_api.jl b/src/model_api.jl index b041adb..31667ef 100644 --- a/src/model_api.jl +++ b/src/model_api.jl @@ -173,6 +173,8 @@ the feature importances from the model's `fitresult` and `report` as an abstract vector of `feature::Symbol => importance::Real` pairs (e.g `[:gender =>0.23, :height =>0.7, :weight => 0.1]`). +# New model implementations + The following trait overload is also required: `MLJModelInterface.reports_feature_importances(::Type{<:M}) = true` @@ -182,3 +184,37 @@ If for some reason a model is sometimes unable to report feature importances the """ function feature_importances end + +_named_tuple(named_tuple) = named_tuple +_named_tuple(::Nothing) = NamedTuple() + +""" + MLJModelInterface.report(model, report_given_method) + +Merge the reports in the dictionary `report_given_method` into a single +property-accessible object. The possible keys of the dictionary are `:fit` and the +symbolic names of MLJModelInterface.jl operations, such as `:predict` or +`:transform`. Each value will be the `report` component returned by a training method +(`fit` or `update`), in the case of `:fit`, or the corresponding operation. + +# New model implementations + +Overloading this method is optional, unless some value in the dictionary +`report_given_method` is possibly neither a named tuple nor `nothing`. + +A fallback returns the usual named tuple merge of the dictionary values, ignoring any +`nothing` values. It is the responsibility of the implementation to ensure individual +reports will never have clashing keys. + +""" +function report(model, report_given_method) + + # Note that we want to avoid copying values in each individual report named tuple, and + # merge the reports in a reproducible order. + + methods = collect(keys(report_given_method)) |> sort! + reports = [report_given_method[method] for method in methods] + + return merge(reports...) +end + diff --git a/test/model_api.jl b/test/model_api.jl index 489584c..5875c9e 100644 --- a/test/model_api.jl +++ b/test/model_api.jl @@ -5,7 +5,6 @@ end f0::Int end - mutable struct APIx1 <: Static end @testset "selectrows(model, data...)" begin @@ -95,3 +94,10 @@ mutable struct UnivariateFiniteFitter <: Probabilistic end @test yhat == fill(DummyUnivariateFinite(), 3) end + +@testset "fallback for `report()` method" begin + report_given_method = + OrderedCollections.OrderedDict(:predict=>(y=7,), :fit=>(x=1, z=3)) + @test MLJModelInterface.report(APIx0(f0=1), report_given_method) == + (x=1, z=3, y=7) +end diff --git a/test/runtests.jl b/test/runtests.jl index 69eefe7..693f2da 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -3,6 +3,7 @@ using ScientificTypesBase, ScientificTypes using Tables, Distances, CategoricalArrays, InteractiveUtils import DataFrames: DataFrame import Markdown +import OrderedCollections const M = MLJModelInterface const FI = M.FullInterface From 912f82c9f1c8382c0558596592b2a0b248a47fac Mon Sep 17 00:00:00 2001 From: "Anthony D. Blaom" Date: Wed, 31 Aug 2022 13:14:08 +1200 Subject: [PATCH 04/22] add forgotten switch for `nothing` --- src/model_api.jl | 3 +-- test/model_api.jl | 6 +++++- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/src/model_api.jl b/src/model_api.jl index 31667ef..9c2f3ed 100644 --- a/src/model_api.jl +++ b/src/model_api.jl @@ -213,8 +213,7 @@ function report(model, report_given_method) # merge the reports in a reproducible order. methods = collect(keys(report_given_method)) |> sort! - reports = [report_given_method[method] for method in methods] + reports = [_named_tuple(report_given_method[method]) for method in methods] return merge(reports...) end - diff --git a/test/model_api.jl b/test/model_api.jl index 5875c9e..4347d87 100644 --- a/test/model_api.jl +++ b/test/model_api.jl @@ -97,7 +97,11 @@ end @testset "fallback for `report()` method" begin report_given_method = - OrderedCollections.OrderedDict(:predict=>(y=7,), :fit=>(x=1, z=3)) + OrderedCollections.OrderedDict( + :predict=>(y=7,), + :fit=>(x=1, z=3), + :transform=>nothing, + ) @test MLJModelInterface.report(APIx0(f0=1), report_given_method) == (x=1, z=3, y=7) end From e78f6edff1a33c98d1314bb161f602df47b41fef Mon Sep 17 00:00:00 2001 From: "Anthony D. Blaom" Date: Thu, 1 Sep 2022 11:27:45 +1200 Subject: [PATCH 05/22] tweak warning for bad docstrings --- src/metadata_utils.jl | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/src/metadata_utils.jl b/src/metadata_utils.jl index 51d6d5a..513d96b 100644 --- a/src/metadata_utils.jl +++ b/src/metadata_utils.jl @@ -65,13 +65,16 @@ function _extend!(program::Expr, trait::Symbol, value, T) end end -const DEPWARN_DOCSTRING = - "`metadata_model` should not be called with the keyword argument "* - "`descr` or `docstring`. Implementers of the MLJ model interface "* - "should instead create an MLJ-compliant docstring in the usual way. "* - "See https://alan-turing-institute.github.io/MLJ.jl/dev/adding_models_for_general_use/#Document-strings for details. " +const depwarn_docstring(T) = + """ -""" + Regarding $T: `metadata_model` should not be called with the keyword argument `descr` + or `docstring`. Implementers of the MLJ model interface should instead create an + MLJ-compliant docstring in the usual way. See + https://alan-turing-institute.github.io/MLJ.jl/dev/adding_models_for_general_use/#Document-strings + for details. + + """ metadata_model(T; args...) Helper function to write the metadata for a model `T`. @@ -122,7 +125,7 @@ function metadata_model( supports_training_losses::Union{Nothing,Bool}=nothing, reports_feature_importances::Union{Nothing,Bool}=nothing, ) - docstring === nothing || Base.depwarn(DEPWARN_DOCSTRING, :metadata_model) + docstring === nothing || Base.depwarn(depwarn_docstring(T), :metadata_model) program = quote end From cf5c5e3be6cdecc886f71e99b745650617387061 Mon Sep 17 00:00:00 2001 From: "Anthony D. Blaom" Date: Thu, 1 Sep 2022 11:34:12 +1200 Subject: [PATCH 06/22] oops --- src/metadata_utils.jl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/metadata_utils.jl b/src/metadata_utils.jl index 513d96b..d005a67 100644 --- a/src/metadata_utils.jl +++ b/src/metadata_utils.jl @@ -75,6 +75,8 @@ const depwarn_docstring(T) = for details. """ + +""" metadata_model(T; args...) Helper function to write the metadata for a model `T`. From dc1a12ae24fcf90a1eb727efc34b742e0b130adc Mon Sep 17 00:00:00 2001 From: "Anthony D. Blaom" Date: Tue, 30 Aug 2022 19:02:49 +1200 Subject: [PATCH 07/22] add OrderedCollections to test deps --- Project.toml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 81553de..483f7a6 100644 --- a/Project.toml +++ b/Project.toml @@ -19,9 +19,10 @@ DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7" InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240" Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a" +OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" ScientificTypes = "321657f4-b219-11e9-178b-2701a2544e81" Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["CategoricalArrays", "DataFrames", "Distances", "InteractiveUtils", "Markdown", "ScientificTypes", "Tables", "Test"] +test = ["CategoricalArrays", "DataFrames", "Distances", "InteractiveUtils", "Markdown", "OrderedCollections", "ScientificTypes", "Tables", "Test"] From b30f791736795a0363c6fdbcb9000b78f0a2b62b Mon Sep 17 00:00:00 2001 From: "Anthony D. Blaom" Date: Tue, 30 Aug 2022 19:11:59 +1200 Subject: [PATCH 08/22] fix a deprecation warning --- test/data_utils.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/data_utils.jl b/test/data_utils.jl index 012ff25..6c1eb75 100644 --- a/test/data_utils.jl +++ b/test/data_utils.jl @@ -323,9 +323,9 @@ end eval(:(module UserSide import MLJModelInterface: metadata_model, metadata_pkg struct A end - descr = "something" + human_name = "Big Foot" # Smoke tests. - metadata_model(A; descr=descr) + metadata_model(A; human_name) metadata_pkg(A) end)) end From 438a2d2c3fd8c50feab3ab63e2df55295b99db1d Mon Sep 17 00:00:00 2001 From: "Anthony D. Blaom" Date: Tue, 30 Aug 2022 19:12:54 +1200 Subject: [PATCH 09/22] add report method for merging fit reports with operation reports --- src/model_api.jl | 36 ++++++++++++++++++++++++++++++++++++ test/model_api.jl | 8 +++++++- test/runtests.jl | 1 + 3 files changed, 44 insertions(+), 1 deletion(-) diff --git a/src/model_api.jl b/src/model_api.jl index b041adb..31667ef 100644 --- a/src/model_api.jl +++ b/src/model_api.jl @@ -173,6 +173,8 @@ the feature importances from the model's `fitresult` and `report` as an abstract vector of `feature::Symbol => importance::Real` pairs (e.g `[:gender =>0.23, :height =>0.7, :weight => 0.1]`). +# New model implementations + The following trait overload is also required: `MLJModelInterface.reports_feature_importances(::Type{<:M}) = true` @@ -182,3 +184,37 @@ If for some reason a model is sometimes unable to report feature importances the """ function feature_importances end + +_named_tuple(named_tuple) = named_tuple +_named_tuple(::Nothing) = NamedTuple() + +""" + MLJModelInterface.report(model, report_given_method) + +Merge the reports in the dictionary `report_given_method` into a single +property-accessible object. The possible keys of the dictionary are `:fit` and the +symbolic names of MLJModelInterface.jl operations, such as `:predict` or +`:transform`. Each value will be the `report` component returned by a training method +(`fit` or `update`), in the case of `:fit`, or the corresponding operation. + +# New model implementations + +Overloading this method is optional, unless some value in the dictionary +`report_given_method` is possibly neither a named tuple nor `nothing`. + +A fallback returns the usual named tuple merge of the dictionary values, ignoring any +`nothing` values. It is the responsibility of the implementation to ensure individual +reports will never have clashing keys. + +""" +function report(model, report_given_method) + + # Note that we want to avoid copying values in each individual report named tuple, and + # merge the reports in a reproducible order. + + methods = collect(keys(report_given_method)) |> sort! + reports = [report_given_method[method] for method in methods] + + return merge(reports...) +end + diff --git a/test/model_api.jl b/test/model_api.jl index 489584c..5875c9e 100644 --- a/test/model_api.jl +++ b/test/model_api.jl @@ -5,7 +5,6 @@ end f0::Int end - mutable struct APIx1 <: Static end @testset "selectrows(model, data...)" begin @@ -95,3 +94,10 @@ mutable struct UnivariateFiniteFitter <: Probabilistic end @test yhat == fill(DummyUnivariateFinite(), 3) end + +@testset "fallback for `report()` method" begin + report_given_method = + OrderedCollections.OrderedDict(:predict=>(y=7,), :fit=>(x=1, z=3)) + @test MLJModelInterface.report(APIx0(f0=1), report_given_method) == + (x=1, z=3, y=7) +end diff --git a/test/runtests.jl b/test/runtests.jl index 69eefe7..693f2da 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -3,6 +3,7 @@ using ScientificTypesBase, ScientificTypes using Tables, Distances, CategoricalArrays, InteractiveUtils import DataFrames: DataFrame import Markdown +import OrderedCollections const M = MLJModelInterface const FI = M.FullInterface From d6983094ad773672cdde17dc7bd87c1b5e63c80b Mon Sep 17 00:00:00 2001 From: "Anthony D. Blaom" Date: Wed, 31 Aug 2022 13:14:08 +1200 Subject: [PATCH 10/22] add forgotten switch for `nothing` --- src/model_api.jl | 3 +-- test/model_api.jl | 6 +++++- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/src/model_api.jl b/src/model_api.jl index 31667ef..9c2f3ed 100644 --- a/src/model_api.jl +++ b/src/model_api.jl @@ -213,8 +213,7 @@ function report(model, report_given_method) # merge the reports in a reproducible order. methods = collect(keys(report_given_method)) |> sort! - reports = [report_given_method[method] for method in methods] + reports = [_named_tuple(report_given_method[method]) for method in methods] return merge(reports...) end - diff --git a/test/model_api.jl b/test/model_api.jl index 5875c9e..4347d87 100644 --- a/test/model_api.jl +++ b/test/model_api.jl @@ -97,7 +97,11 @@ end @testset "fallback for `report()` method" begin report_given_method = - OrderedCollections.OrderedDict(:predict=>(y=7,), :fit=>(x=1, z=3)) + OrderedCollections.OrderedDict( + :predict=>(y=7,), + :fit=>(x=1, z=3), + :transform=>nothing, + ) @test MLJModelInterface.report(APIx0(f0=1), report_given_method) == (x=1, z=3, y=7) end From 3b477ecc645256d4e0c31fde4a21cbb2f8847869 Mon Sep 17 00:00:00 2001 From: "Anthony D. Blaom" Date: Wed, 7 Sep 2022 14:24:25 +1200 Subject: [PATCH 11/22] handle clashes in keys of reports --- src/model_api.jl | 24 ++++++++++++++++++------ test/model_api.jl | 9 +++++++++ 2 files changed, 27 insertions(+), 6 deletions(-) diff --git a/src/model_api.jl b/src/model_api.jl index 9c2f3ed..bfe9a33 100644 --- a/src/model_api.jl +++ b/src/model_api.jl @@ -187,6 +187,8 @@ function feature_importances end _named_tuple(named_tuple) = named_tuple _named_tuple(::Nothing) = NamedTuple() +_keys(named_tuple) = keys(named_tuple) +_keys(::Nothing) = () """ MLJModelInterface.report(model, report_given_method) @@ -195,25 +197,35 @@ Merge the reports in the dictionary `report_given_method` into a single property-accessible object. The possible keys of the dictionary are `:fit` and the symbolic names of MLJModelInterface.jl operations, such as `:predict` or `:transform`. Each value will be the `report` component returned by a training method -(`fit` or `update`), in the case of `:fit`, or the corresponding operation. +(`fit` or `update`) dispatched on the `model` type, in the case of `:fit`, or the +corresponding operation. # New model implementations -Overloading this method is optional, unless some value in the dictionary -`report_given_method` is possibly neither a named tuple nor `nothing`. +Overloading this method is optional, unless `fit`/`update` or an operation generates a +report that is niether a named tuple nor `nothing`. A fallback returns the usual named tuple merge of the dictionary values, ignoring any -`nothing` values. It is the responsibility of the implementation to ensure individual -reports will never have clashing keys. +`nothing` values, assuming there are no conflicts between the keys of the dictionary +values. In that case, each report is first wrapped in a named tuple with one entry, such +as `(predict=predict_report,)`. """ function report(model, report_given_method) + return_keys = vcat(collect.(_keys.(values(report_given_method)))...) + need_to_wrap = return_keys != unique(return_keys) + # Note that we want to avoid copying values in each individual report named tuple, and # merge the reports in a reproducible order. methods = collect(keys(report_given_method)) |> sort! - reports = [_named_tuple(report_given_method[method]) for method in methods] + reports = map(methods) do method + tup = _named_tuple(report_given_method[method]) + isempty(tup) ? NamedTuple() : + need_to_wrap ? NamedTuple{(method,)}((tup,)) : + tup + end return merge(reports...) end diff --git a/test/model_api.jl b/test/model_api.jl index 4347d87..30c6748 100644 --- a/test/model_api.jl +++ b/test/model_api.jl @@ -104,4 +104,13 @@ end ) @test MLJModelInterface.report(APIx0(f0=1), report_given_method) == (x=1, z=3, y=7) + + report_given_method = + OrderedCollections.OrderedDict( + :predict=>(y=7,), + :fit=>(y=1, z=3), + :transform=>nothing, + ) + @test MLJModelInterface.report(APIx0(f0=1), report_given_method) == + (fit=(y=1, z=3), predict=(y=7,)) end From 06c5971fd147b4b10026224ca4a224b4f79a3e39 Mon Sep 17 00:00:00 2001 From: "Anthony D. Blaom" Date: Mon, 12 Sep 2022 09:47:38 +1200 Subject: [PATCH 12/22] improved key clash handling --- src/model_api.jl | 14 +++++++------- test/model_api.jl | 2 +- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/src/model_api.jl b/src/model_api.jl index bfe9a33..32d5830 100644 --- a/src/model_api.jl +++ b/src/model_api.jl @@ -194,8 +194,8 @@ _keys(::Nothing) = () MLJModelInterface.report(model, report_given_method) Merge the reports in the dictionary `report_given_method` into a single -property-accessible object. The possible keys of the dictionary are `:fit` and the -symbolic names of MLJModelInterface.jl operations, such as `:predict` or +property-accessible object. It is supposed that the possible keys of the dictionary are +`:fit` and the symbolic names of MLJModelInterface.jl operations, such as `:predict` or `:transform`. Each value will be the `report` component returned by a training method (`fit` or `update`) dispatched on the `model` type, in the case of `:fit`, or the corresponding operation. @@ -203,12 +203,12 @@ corresponding operation. # New model implementations Overloading this method is optional, unless `fit`/`update` or an operation generates a -report that is niether a named tuple nor `nothing`. +report that is neither a named tuple nor `nothing`. A fallback returns the usual named tuple merge of the dictionary values, ignoring any -`nothing` values, assuming there are no conflicts between the keys of the dictionary -values. In that case, each report is first wrapped in a named tuple with one entry, such -as `(predict=predict_report,)`. +`nothing` values, and assuming there are no conflicts between the keys of the dictionary +values (the individual reports). If there is a key conflict, all operation reports are +first wrapped in a named tuple of length one, as in `(predict=predict_report,)`. """ function report(model, report_given_method) @@ -223,7 +223,7 @@ function report(model, report_given_method) reports = map(methods) do method tup = _named_tuple(report_given_method[method]) isempty(tup) ? NamedTuple() : - need_to_wrap ? NamedTuple{(method,)}((tup,)) : + (need_to_wrap && method !== :fit) ? NamedTuple{(method,)}((tup,)) : tup end diff --git a/test/model_api.jl b/test/model_api.jl index 30c6748..72f7f92 100644 --- a/test/model_api.jl +++ b/test/model_api.jl @@ -112,5 +112,5 @@ end :transform=>nothing, ) @test MLJModelInterface.report(APIx0(f0=1), report_given_method) == - (fit=(y=1, z=3), predict=(y=7,)) + (y=1, z=3, predict=(y=7,)) end From c2aacbab5ab6980fd3de01a7c73127d4fac9a896 Mon Sep 17 00:00:00 2001 From: "Anthony D. Blaom" Date: Mon, 12 Sep 2022 11:48:19 +1200 Subject: [PATCH 13/22] empty merged reports should be replaced with `nothing` in `report()` fallback --- src/model_api.jl | 4 +++- test/model_api.jl | 5 +++++ 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/src/model_api.jl b/src/model_api.jl index 32d5830..a9f1572 100644 --- a/src/model_api.jl +++ b/src/model_api.jl @@ -227,5 +227,7 @@ function report(model, report_given_method) tup end - return merge(reports...) + ret = merge(reports...) + isempty(ret) && return nothing + return ret end diff --git a/test/model_api.jl b/test/model_api.jl index 72f7f92..bdb2ca9 100644 --- a/test/model_api.jl +++ b/test/model_api.jl @@ -113,4 +113,9 @@ end ) @test MLJModelInterface.report(APIx0(f0=1), report_given_method) == (y=1, z=3, predict=(y=7,)) + + @test MLJModelInterface.report( + APIx0(f0=1), + OrderedCollections.OrderedDict(:fit => nothing, :transform => NamedTuple()), + ) == nothing end From d4a3419041c811c4f98893e82bbe1e28cea71f89 Mon Sep 17 00:00:00 2001 From: "Anthony D. Blaom" Date: Thu, 15 Sep 2022 12:01:12 +1200 Subject: [PATCH 14/22] overload fitted_params(::Static, ..) = nothing --- src/model_api.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/model_api.jl b/src/model_api.jl index a9f1572..4f493b8 100644 --- a/src/model_api.jl +++ b/src/model_api.jl @@ -100,7 +100,7 @@ part of the tuple returned by `fit`. """ fitted_params(::Model, fitresult) = (fitresult=fitresult,) - +fitted_params(::Static, ::Nothing) = nothing """ predict(model, fitresult, new_data...) From 49e128be2436f49ac7b711a4540971b57112a2f1 Mon Sep 17 00:00:00 2001 From: "Anthony D. Blaom" Date: Thu, 15 Sep 2022 16:40:39 +1200 Subject: [PATCH 15/22] make merge fallback more robust --- src/model_api.jl | 26 ++++++++++++++++---------- test/model_api.jl | 7 +++++++ 2 files changed, 23 insertions(+), 10 deletions(-) diff --git a/src/model_api.jl b/src/model_api.jl index 4f493b8..78c7d5b 100644 --- a/src/model_api.jl +++ b/src/model_api.jl @@ -185,8 +185,9 @@ If for some reason a model is sometimes unable to report feature importances the """ function feature_importances end -_named_tuple(named_tuple) = named_tuple +_named_tuple(named_tuple::NamedTuple) = named_tuple _named_tuple(::Nothing) = NamedTuple() +_named_tuple(something_else) = (report=something_else,) _keys(named_tuple) = keys(named_tuple) _keys(::Nothing) = () @@ -194,32 +195,37 @@ _keys(::Nothing) = () MLJModelInterface.report(model, report_given_method) Merge the reports in the dictionary `report_given_method` into a single -property-accessible object. It is supposed that the possible keys of the dictionary are +property-accessible object. It is supposed that the keys of the dictionary are `:fit` and the symbolic names of MLJModelInterface.jl operations, such as `:predict` or `:transform`. Each value will be the `report` component returned by a training method (`fit` or `update`) dispatched on the `model` type, in the case of `:fit`, or the -corresponding operation. +report component returned by an operation that supports reporting. # New model implementations -Overloading this method is optional, unless `fit`/`update` or an operation generates a -report that is neither a named tuple nor `nothing`. +Overloading this method is optional, unless the model generates reports that are neither +named tuples nor `nothing`. -A fallback returns the usual named tuple merge of the dictionary values, ignoring any -`nothing` values, and assuming there are no conflicts between the keys of the dictionary -values (the individual reports). If there is a key conflict, all operation reports are -first wrapped in a named tuple of length one, as in `(predict=predict_report,)`. +Assuming each dictionary value is a named tuple or `nothing`, the fallback returns the +usual named tuple merge of the dictionary values, ignoring any `nothing` values, and +assuming there are no conflicts between the keys of the dictionary values (the individual +reports). If there is a key conflict, all operation reports are first wrapped in a named +tuple of length one, as in `(predict=predict_report,)`. + +If any dictionary `value` is neither a named tuple nor `nothing`, it is first wrapped as +`(report=value, )` """ function report(model, report_given_method) return_keys = vcat(collect.(_keys.(values(report_given_method)))...) - need_to_wrap = return_keys != unique(return_keys) # Note that we want to avoid copying values in each individual report named tuple, and # merge the reports in a reproducible order. methods = collect(keys(report_given_method)) |> sort! + length(methods) == 1 && return report_given_method[only(methods)] + need_to_wrap = return_keys != unique(return_keys) reports = map(methods) do method tup = _named_tuple(report_given_method[method]) isempty(tup) ? NamedTuple() : diff --git a/test/model_api.jl b/test/model_api.jl index bdb2ca9..dd6664a 100644 --- a/test/model_api.jl +++ b/test/model_api.jl @@ -118,4 +118,11 @@ end APIx0(f0=1), OrderedCollections.OrderedDict(:fit => nothing, :transform => NamedTuple()), ) == nothing + + @test MLJModelInterface.report( + APIx0(f0=1), + OrderedCollections.OrderedDict(:fit => 42), + ) == 42 + + end From 25986d230b614e3d99539e113d7f9ecf0a48fd70 Mon Sep 17 00:00:00 2001 From: "Anthony D. Blaom" Date: Fri, 16 Sep 2022 08:14:53 +1200 Subject: [PATCH 16/22] make sure empty tuples are scrubbed to `nothing` in report return value --- src/model_api.jl | 13 ++++++------- test/model_api.jl | 12 +++++++++++- 2 files changed, 17 insertions(+), 8 deletions(-) diff --git a/src/model_api.jl b/src/model_api.jl index 78c7d5b..326b725 100644 --- a/src/model_api.jl +++ b/src/model_api.jl @@ -188,6 +188,8 @@ function feature_importances end _named_tuple(named_tuple::NamedTuple) = named_tuple _named_tuple(::Nothing) = NamedTuple() _named_tuple(something_else) = (report=something_else,) +_scrub(x) = x +_scrub(x::NamedTuple) = isempty(x) ? nothing : x _keys(named_tuple) = keys(named_tuple) _keys(::Nothing) = () @@ -213,7 +215,7 @@ reports). If there is a key conflict, all operation reports are first wrapped in tuple of length one, as in `(predict=predict_report,)`. If any dictionary `value` is neither a named tuple nor `nothing`, it is first wrapped as -`(report=value, )` +`(report=value, )` before merging. """ function report(model, report_given_method) @@ -224,16 +226,13 @@ function report(model, report_given_method) # merge the reports in a reproducible order. methods = collect(keys(report_given_method)) |> sort! - length(methods) == 1 && return report_given_method[only(methods)] + length(methods) == 1 && return _scrub(report_given_method[only(methods)]) need_to_wrap = return_keys != unique(return_keys) reports = map(methods) do method tup = _named_tuple(report_given_method[method]) isempty(tup) ? NamedTuple() : (need_to_wrap && method !== :fit) ? NamedTuple{(method,)}((tup,)) : - tup + tup end - - ret = merge(reports...) - isempty(ret) && return nothing - return ret + return _scrub(merge(reports...)) end diff --git a/test/model_api.jl b/test/model_api.jl index dd6664a..9bacc63 100644 --- a/test/model_api.jl +++ b/test/model_api.jl @@ -117,12 +117,22 @@ end @test MLJModelInterface.report( APIx0(f0=1), OrderedCollections.OrderedDict(:fit => nothing, :transform => NamedTuple()), - ) == nothing + ) |> isnothing @test MLJModelInterface.report( APIx0(f0=1), OrderedCollections.OrderedDict(:fit => 42), ) == 42 + @test MLJModelInterface.report( + APIx0(f0=1), + OrderedCollections.OrderedDict(:fit => nothing), + ) |> isnothing + + @test MLJModelInterface.report( + APIx0(f0=1), + OrderedCollections.OrderedDict(:fit => NamedTuple()), + ) |> isnothing + end From 65082874a74230136a1775edad34b6ccb662052f Mon Sep 17 00:00:00 2001 From: "Anthony D. Blaom" Date: Fri, 16 Sep 2022 16:03:22 +1200 Subject: [PATCH 17/22] bump compat julia = "1.6" and update ci --- .github/workflows/ci.yml | 1 - Project.toml | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 24170b9..9b5d87f 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -17,7 +17,6 @@ jobs: fail-fast: false matrix: version: - - '1.0' - '1.6' - '1' os: diff --git a/Project.toml b/Project.toml index 483f7a6..8a1f80c 100644 --- a/Project.toml +++ b/Project.toml @@ -11,7 +11,7 @@ StatisticalTraits = "64bff920-2084-43da-a3e6-9bb72801c0c9" [compat] ScientificTypesBase = "3.0" StatisticalTraits = "3.2" -julia = "1" +julia = "1.6" [extras] CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597" From ed0fd8c6d365a8cdb21d82ca16ad0bbaa8bb22de Mon Sep 17 00:00:00 2001 From: "Anthony D. Blaom" Date: Fri, 16 Sep 2022 16:20:52 +1200 Subject: [PATCH 18/22] trivial commit --- test/model_api.jl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/test/model_api.jl b/test/model_api.jl index 9bacc63..452a031 100644 --- a/test/model_api.jl +++ b/test/model_api.jl @@ -136,3 +136,5 @@ end end + + From b8ad6f703073c0ec1da762251a1ca1f9936e9df2 Mon Sep 17 00:00:00 2001 From: "Anthony D. Blaom" Date: Fri, 23 Sep 2022 10:33:42 +1200 Subject: [PATCH 19/22] rm redundant `const` --- src/metadata_utils.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/metadata_utils.jl b/src/metadata_utils.jl index d005a67..9fda47a 100644 --- a/src/metadata_utils.jl +++ b/src/metadata_utils.jl @@ -65,7 +65,7 @@ function _extend!(program::Expr, trait::Symbol, value, T) end end -const depwarn_docstring(T) = +depwarn_docstring(T) = """ Regarding $T: `metadata_model` should not be called with the keyword argument `descr` From 82aa20a81f152598be5511cfe58cfbb865f06ef6 Mon Sep 17 00:00:00 2001 From: "Anthony D. Blaom" Date: Fri, 23 Sep 2022 10:44:46 +1200 Subject: [PATCH 20/22] improve docstring oops --- src/model_api.jl | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/model_api.jl b/src/model_api.jl index 326b725..21e55e5 100644 --- a/src/model_api.jl +++ b/src/model_api.jl @@ -197,11 +197,11 @@ _keys(::Nothing) = () MLJModelInterface.report(model, report_given_method) Merge the reports in the dictionary `report_given_method` into a single -property-accessible object. It is supposed that the keys of the dictionary are -`:fit` and the symbolic names of MLJModelInterface.jl operations, such as `:predict` or -`:transform`. Each value will be the `report` component returned by a training method -(`fit` or `update`) dispatched on the `model` type, in the case of `:fit`, or the -report component returned by an operation that supports reporting. +property-accessible object. It is supposed that each key of the dictionary is either +`:fit` or the name of an operation, such as `:predict` or `:transform`. Each value will be +the `report` component returned by a training method (`fit` or `update`) dispatched on the +`model` type, in the case of `:fit`, or the report component returned by an operation that +supports reporting. # New model implementations @@ -212,7 +212,7 @@ Assuming each dictionary value is a named tuple or `nothing`, the fallback retur usual named tuple merge of the dictionary values, ignoring any `nothing` values, and assuming there are no conflicts between the keys of the dictionary values (the individual reports). If there is a key conflict, all operation reports are first wrapped in a named -tuple of length one, as in `(predict=predict_report,)`. +tuple of length one, as in `(predict=predict_report,)`. A `:fit` report is never wrapped. If any dictionary `value` is neither a named tuple nor `nothing`, it is first wrapped as `(report=value, )` before merging. From 2319e858c24bdb52f674a0fb269f69ab9063c723 Mon Sep 17 00:00:00 2001 From: "Anthony Blaom, PhD" Date: Fri, 23 Sep 2022 10:48:45 +1200 Subject: [PATCH 21/22] tweak docstring Co-authored-by: Okon Samuel <39421418+OkonSamuel@users.noreply.github.com> --- src/model_api.jl | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/model_api.jl b/src/model_api.jl index 21e55e5..595d905 100644 --- a/src/model_api.jl +++ b/src/model_api.jl @@ -208,10 +208,11 @@ supports reporting. Overloading this method is optional, unless the model generates reports that are neither named tuples nor `nothing`. -Assuming each dictionary value is a named tuple or `nothing`, the fallback returns the -usual named tuple merge of the dictionary values, ignoring any `nothing` values, and -assuming there are no conflicts between the keys of the dictionary values (the individual -reports). If there is a key conflict, all operation reports are first wrapped in a named +Assuming each value in the `report_given_method` dictionary is either a named tuple +or `nothing`, and there are no conflicts between the keys of the dictionary values +(the individual reports), the fallback returns the usual named tuple merge of the +dictionary values, ignoring any `nothing` value. If there is a key conflict, all operation +reports are first wrapped in a named tuple of length one, as in `(predict=predict_report,)`. A `:fit` report is never wrapped. If any dictionary `value` is neither a named tuple nor `nothing`, it is first wrapped as From a346af91f69760831ce7bd65155e5d2bf51b88e1 Mon Sep 17 00:00:00 2001 From: "Anthony D. Blaom" Date: Wed, 5 Oct 2022 12:11:05 +1300 Subject: [PATCH 22/22] bump 1.7.0 --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 8a1f80c..02ec4a9 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "MLJModelInterface" uuid = "e80e1ace-859a-464e-9ed9-23947d8ae3ea" authors = ["Thibaut Lienart and Anthony Blaom"] -version = "1.6.0" +version = "1.7.0" [deps] Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"