Skip to content

Commit 140fb7b

Browse files
committed
add catch for missing target in resampling
1 parent 2209563 commit 140fb7b

File tree

2 files changed

+21
-1
lines changed

2 files changed

+21
-1
lines changed

src/resampling.jl

+15-1
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ err_ambiguous_operation(model, measure) = ArgumentError(
6868
"Possible value(s) are: $PREDICT_OPERATIONS_STRING. "
6969
)
7070

71-
ERR_UNSUPPORTED_PREDICTION_TYPE = ArgumentError(
71+
const ERR_UNSUPPORTED_PREDICTION_TYPE = ArgumentError(
7272
"""
7373
7474
The `prediction_type` of your model needs to be one of: `:deterministic`,
@@ -81,6 +81,18 @@ ERR_UNSUPPORTED_PREDICTION_TYPE = ArgumentError(
8181
"""
8282
)
8383

84+
const ERR_NEED_TARGET = ArgumentError(
85+
"""
86+
87+
To evaluate a model's performance you must provide a target variable `y`, as in
88+
`evaluate(model, X, y; options...)` or
89+
90+
mach = machine(model, X, y)
91+
evaluate!(mach; options...)
92+
93+
"""
94+
)
95+
8496
# ==================================================================
8597
## RESAMPLING STRATEGIES
8698

@@ -1170,6 +1182,8 @@ function evaluate!(
11701182
# weights, measures, operations, and dispatches a
11711183
# strategy-specific `evaluate!`
11721184

1185+
length(mach.args) > 1 || throw(ERR_NEED_TARGET)
1186+
11731187
repeats > 0 || error("Need `repeats > 0`. ")
11741188

11751189
if resampling isa TrainTestPairs

test/resampling.jl

+6
Original file line numberDiff line numberDiff line change
@@ -948,14 +948,20 @@ end
948948

949949
struct PredictingTransformer <:Unsupervised end
950950
MLJBase.fit(::PredictingTransformer, verbosity, X, y) = (mean(y), nothing, nothing)
951+
MLJBase.fit(::PredictingTransformer, verbosity, X) = (nothing, nothing, nothing)
951952
MLJBase.predict(::PredictingTransformer, fitresult, X) = fill(fitresult, nrows(X))
953+
MLJBase.predict(::PredictingTransformer, ::Nothing, X) = nothing
952954
MLJBase.prediction_type(::Type{<:PredictingTransformer}) = :deterministic
953955

954956
@testset "`Unsupervised` model with a predict" begin
955957
X = rand(10)
956958
y = fill(42.0, 10)
957959
e = evaluate(PredictingTransformer(), X, y, resampling=Holdout(), measure=l2)
958960
@test e.measurement[1] 0
961+
@test_throws(
962+
MLJBase.ERR_NEED_TARGET,
963+
evaluate(PredictingTransformer(), X, measure=l2),
964+
)
959965
end
960966

961967

0 commit comments

Comments
 (0)