Skip to content

Commit

Permalink
Merge pull request #163 from JuliaAI/dev
Browse files Browse the repository at this point in the history
For a 1.7.0 release
  • Loading branch information
ablaom authored Oct 4, 2022
2 parents 3571861 + a346af9 commit adb1258
Show file tree
Hide file tree
Showing 7 changed files with 118 additions and 14 deletions.
1 change: 0 additions & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ jobs:
fail-fast: false
matrix:
version:
- '1.0'
- '1.6'
- '1'
os:
Expand Down
7 changes: 4 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "MLJModelInterface"
uuid = "e80e1ace-859a-464e-9ed9-23947d8ae3ea"
authors = ["Thibaut Lienart and Anthony Blaom"]
version = "1.6.0"
version = "1.7.0"

[deps]
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Expand All @@ -11,17 +11,18 @@ StatisticalTraits = "64bff920-2084-43da-a3e6-9bb72801c0c9"
[compat]
ScientificTypesBase = "3.0"
StatisticalTraits = "3.2"
julia = "1"
julia = "1.6"

[extras]
CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597"
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a"
OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
ScientificTypes = "321657f4-b219-11e9-178b-2701a2544e81"
Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["CategoricalArrays", "DataFrames", "Distances", "InteractiveUtils", "Markdown", "ScientificTypes", "Tables", "Test"]
test = ["CategoricalArrays", "DataFrames", "Distances", "InteractiveUtils", "Markdown", "OrderedCollections", "ScientificTypes", "Tables", "Test"]
17 changes: 11 additions & 6 deletions src/metadata_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -65,11 +65,16 @@ function _extend!(program::Expr, trait::Symbol, value, T)
end
end

const DEPWARN_DOCSTRING =
"`metadata_model` should not be called with the keyword argument "*
"`descr` or `docstring`. Implementers of the MLJ model interface "*
"should instead create an MLJ-compliant docstring in the usual way. "*
"See https://alan-turing-institute.github.io/MLJ.jl/dev/adding_models_for_general_use/#Document-strings for details. "
depwarn_docstring(T) =
"""
Regarding $T: `metadata_model` should not be called with the keyword argument `descr`
or `docstring`. Implementers of the MLJ model interface should instead create an
MLJ-compliant docstring in the usual way. See
https://alan-turing-institute.github.io/MLJ.jl/dev/adding_models_for_general_use/#Document-strings
for details.
"""

"""
metadata_model(T; args...)
Expand Down Expand Up @@ -122,7 +127,7 @@ function metadata_model(
supports_training_losses::Union{Nothing,Bool}=nothing,
reports_feature_importances::Union{Nothing,Bool}=nothing,
)
docstring === nothing || Base.depwarn(DEPWARN_DOCSTRING, :metadata_model)
docstring === nothing || Base.depwarn(depwarn_docstring(T), :metadata_model)

program = quote end

Expand Down
57 changes: 56 additions & 1 deletion src/model_api.jl
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ part of the tuple returned by `fit`.
"""
fitted_params(::Model, fitresult) = (fitresult=fitresult,)

fitted_params(::Static, ::Nothing) = nothing
"""
predict(model, fitresult, new_data...)
Expand Down Expand Up @@ -173,6 +173,8 @@ the feature importances from the model's `fitresult` and `report` as an
abstract vector of `feature::Symbol => importance::Real` pairs
(e.g `[:gender =>0.23, :height =>0.7, :weight => 0.1]`).
# New model implementations
The following trait overload is also required:
`MLJModelInterface.reports_feature_importances(::Type{<:M}) = true`
Expand All @@ -182,3 +184,56 @@ If for some reason a model is sometimes unable to report feature importances the
"""
function feature_importances end

_named_tuple(named_tuple::NamedTuple) = named_tuple
_named_tuple(::Nothing) = NamedTuple()
_named_tuple(something_else) = (report=something_else,)
_scrub(x) = x
_scrub(x::NamedTuple) = isempty(x) ? nothing : x
_keys(named_tuple) = keys(named_tuple)
_keys(::Nothing) = ()

"""
MLJModelInterface.report(model, report_given_method)
Merge the reports in the dictionary `report_given_method` into a single
property-accessible object. It is supposed that each key of the dictionary is either
`:fit` or the name of an operation, such as `:predict` or `:transform`. Each value will be
the `report` component returned by a training method (`fit` or `update`) dispatched on the
`model` type, in the case of `:fit`, or the report component returned by an operation that
supports reporting.
# New model implementations
Overloading this method is optional, unless the model generates reports that are neither
named tuples nor `nothing`.
Assuming each value in the `report_given_method` dictionary is either a named tuple
or `nothing`, and there are no conflicts between the keys of the dictionary values
(the individual reports), the fallback returns the usual named tuple merge of the
dictionary values, ignoring any `nothing` value. If there is a key conflict, all operation
reports are first wrapped in a named
tuple of length one, as in `(predict=predict_report,)`. A `:fit` report is never wrapped.
If any dictionary `value` is neither a named tuple nor `nothing`, it is first wrapped as
`(report=value, )` before merging.
"""
function report(model, report_given_method)

return_keys = vcat(collect.(_keys.(values(report_given_method)))...)

# Note that we want to avoid copying values in each individual report named tuple, and
# merge the reports in a reproducible order.

methods = collect(keys(report_given_method)) |> sort!
length(methods) == 1 && return _scrub(report_given_method[only(methods)])
need_to_wrap = return_keys != unique(return_keys)
reports = map(methods) do method
tup = _named_tuple(report_given_method[method])
isempty(tup) ? NamedTuple() :
(need_to_wrap && method !== :fit) ? NamedTuple{(method,)}((tup,)) :
tup
end
return _scrub(merge(reports...))
end
4 changes: 2 additions & 2 deletions test/data_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -323,9 +323,9 @@ end
eval(:(module UserSide
import MLJModelInterface: metadata_model, metadata_pkg
struct A end
descr = "something"
human_name = "Big Foot"
# Smoke tests.
metadata_model(A; descr=descr)
metadata_model(A; human_name)
metadata_pkg(A)
end))
end
45 changes: 44 additions & 1 deletion test/model_api.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ end
f0::Int
end


mutable struct APIx1 <: Static end

@testset "selectrows(model, data...)" begin
Expand Down Expand Up @@ -95,3 +94,47 @@ mutable struct UnivariateFiniteFitter <: Probabilistic end
@test yhat == fill(DummyUnivariateFinite(), 3)

end

@testset "fallback for `report()` method" begin
report_given_method =
OrderedCollections.OrderedDict(
:predict=>(y=7,),
:fit=>(x=1, z=3),
:transform=>nothing,
)
@test MLJModelInterface.report(APIx0(f0=1), report_given_method) ==
(x=1, z=3, y=7)

report_given_method =
OrderedCollections.OrderedDict(
:predict=>(y=7,),
:fit=>(y=1, z=3),
:transform=>nothing,
)
@test MLJModelInterface.report(APIx0(f0=1), report_given_method) ==
(y=1, z=3, predict=(y=7,))

@test MLJModelInterface.report(
APIx0(f0=1),
OrderedCollections.OrderedDict(:fit => nothing, :transform => NamedTuple()),
) |> isnothing

@test MLJModelInterface.report(
APIx0(f0=1),
OrderedCollections.OrderedDict(:fit => 42),
) == 42

@test MLJModelInterface.report(
APIx0(f0=1),
OrderedCollections.OrderedDict(:fit => nothing),
) |> isnothing

@test MLJModelInterface.report(
APIx0(f0=1),
OrderedCollections.OrderedDict(:fit => NamedTuple()),
) |> isnothing


end


1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ using ScientificTypesBase, ScientificTypes
using Tables, Distances, CategoricalArrays, InteractiveUtils
import DataFrames: DataFrame
import Markdown
import OrderedCollections

const M = MLJModelInterface
const FI = M.FullInterface
Expand Down

0 comments on commit adb1258

Please sign in to comment.