Skip to content

Commit 4e8c087

Browse files
authored
Merge pull request #988 from JuliaAI/dev
For a 1.7 release
2 parents 0849be7 + d65ed1f commit 4e8c087

File tree

3 files changed

+67
-13
lines changed

3 files changed

+67
-13
lines changed

Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "MLJBase"
22
uuid = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
33
authors = ["Anthony D. Blaom <anthony.blaom@gmail.com>"]
4-
version = "1.6"
4+
version = "1.7.0"
55

66
[deps]
77
CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597"

src/resampling.jl

+35-11
Original file line numberDiff line numberDiff line change
@@ -31,10 +31,6 @@ const ERR_INVALID_OPERATION = ArgumentError(
3131
_ambiguous_operation(model, measure) =
3232
"`$measure` does not support a `model` with "*
3333
"`prediction_type(model) == :$(prediction_type(model))`. "
34-
err_ambiguous_operation(model, measure) = ArgumentError(
35-
_ambiguous_operation(model, measure)*
36-
"\nUnable to infer an appropriate operation for `$measure`. "*
37-
"Explicitly specify `operation=...` or `operations=...`. ")
3834
err_incompatible_prediction_types(model, measure) = ArgumentError(
3935
_ambiguous_operation(model, measure)*
4036
"If your model is truly making probabilistic predictions, try explicitly "*
@@ -65,11 +61,37 @@ ERR_MEASURES_DETERMINISTIC(measure) = ArgumentError(
6561
"and so is not supported by `$measure`. "*LOG_AVOID
6662
)
6763

68-
# ==================================================================
69-
## MODEL TYPES THAT CAN BE EVALUATED
64+
err_ambiguous_operation(model, measure) = ArgumentError(
65+
_ambiguous_operation(model, measure)*
66+
"\nUnable to infer an appropriate operation for `$measure`. "*
67+
"Explicitly specify `operation=...` or `operations=...`. "*
68+
"Possible value(s) are: $PREDICT_OPERATIONS_STRING. "
69+
)
70+
71+
const ERR_UNSUPPORTED_PREDICTION_TYPE = ArgumentError(
72+
"""
7073
71-
# not exported:
72-
const Measurable = Union{Supervised, Annotator}
74+
The `prediction_type` of your model needs to be one of: `:deterministic`,
75+
`:probabilistic`, or `:interval`. Does your model implement one of these operations:
76+
$PREDICT_OPERATIONS_STRING? If so, you can try explicitly specifying `operation=...`
77+
or `operations=...` (and consider posting an issue to have the model review it's
78+
definition of `MLJModelInterface.prediction_type`). Otherwise, performance
79+
evaluation is not supported.
80+
81+
"""
82+
)
83+
84+
const ERR_NEED_TARGET = ArgumentError(
85+
"""
86+
87+
To evaluate a model's performance you must provide a target variable `y`, as in
88+
`evaluate(model, X, y; options...)` or
89+
90+
mach = machine(model, X, y)
91+
evaluate!(mach; options...)
92+
93+
"""
94+
)
7395

7496
# ==================================================================
7597
## RESAMPLING STRATEGIES
@@ -987,7 +1009,7 @@ function _actual_operations(operation::Nothing,
9871009
throw(err_ambiguous_operation(model, m))
9881010
end
9891011
else
990-
throw(err_ambiguous_operation(model, m))
1012+
throw(ERR_UNSUPPORTED_PREDICTION_TYPE)
9911013
end
9921014
end
9931015
end
@@ -1137,7 +1159,7 @@ See also [`evaluate`](@ref), [`PerformanceEvaluation`](@ref),
11371159
11381160
"""
11391161
function evaluate!(
1140-
mach::Machine{<:Measurable};
1162+
mach::Machine;
11411163
resampling=CV(),
11421164
measures=nothing,
11431165
measure=measures,
@@ -1160,6 +1182,8 @@ function evaluate!(
11601182
# weights, measures, operations, and dispatches a
11611183
# strategy-specific `evaluate!`
11621184

1185+
length(mach.args) > 1 || throw(ERR_NEED_TARGET)
1186+
11631187
repeats > 0 || error("Need `repeats > 0`. ")
11641188

11651189
if resampling isa TrainTestPairs
@@ -1235,7 +1259,7 @@ Returns a [`PerformanceEvaluation`](@ref) object.
12351259
See also [`evaluate!`](@ref).
12361260
12371261
"""
1238-
evaluate(model::Measurable, args...; cache=true, kwargs...) =
1262+
evaluate(model::Model, args...; cache=true, kwargs...) =
12391263
evaluate!(machine(model, args...; cache=cache); kwargs...)
12401264

12411265
# -------------------------------------------------------------------

test/resampling.jl

+31-1
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@ end
2525
struct DummyInterval <: Interval end
2626
dummy_interval=DummyInterval()
2727

28+
struct GoofyTransformer <: Unsupervised end
29+
2830
dummy_measure_det(yhat, y) = 42
2931
API.@trait(
3032
typeof(dummy_measure_det),
@@ -115,6 +117,12 @@ API.@trait(
115117
MLJBase.err_ambiguous_operation(dummy_interval, LogLoss()),
116118
MLJBase._actual_operations(nothing,
117119
[LogLoss(), ], dummy_interval, 1))
120+
121+
# model does not have a valid `prediction_type`:
122+
@test_throws(
123+
MLJBase.ERR_UNSUPPORTED_PREDICTION_TYPE,
124+
MLJBase._actual_operations(nothing, [LogLoss(),], GoofyTransformer(), 0),
125+
)
118126
end
119127

120128
@everywhere begin
@@ -935,7 +943,29 @@ end
935943
end
936944
end
937945

938-
# DUMMY LOGGER
946+
947+
# # TRANSFORMER WITH PREDICT
948+
949+
struct PredictingTransformer <:Unsupervised end
950+
MLJBase.fit(::PredictingTransformer, verbosity, X, y) = (mean(y), nothing, nothing)
951+
MLJBase.fit(::PredictingTransformer, verbosity, X) = (nothing, nothing, nothing)
952+
MLJBase.predict(::PredictingTransformer, fitresult, X) = fill(fitresult, nrows(X))
953+
MLJBase.predict(::PredictingTransformer, ::Nothing, X) = nothing
954+
MLJBase.prediction_type(::Type{<:PredictingTransformer}) = :deterministic
955+
956+
@testset "`Unsupervised` model with a predict" begin
957+
X = rand(10)
958+
y = fill(42.0, 10)
959+
e = evaluate(PredictingTransformer(), X, y, resampling=Holdout(), measure=l2)
960+
@test e.measurement[1] 0
961+
@test_throws(
962+
MLJBase.ERR_NEED_TARGET,
963+
evaluate(PredictingTransformer(), X, measure=l2),
964+
)
965+
end
966+
967+
968+
# # DUMMY LOGGER
939969

940970
struct DummyLogger end
941971

0 commit comments

Comments
 (0)