Skip to content

Commit da15fb3

Browse files
authored
Merge pull request #201 from JuliaAI/explicit-better-checks
Add prediction type check for Explicit strategy
2 parents f5256c5 + 53700b3 commit da15fb3

File tree

5 files changed

+68
-9
lines changed

5 files changed

+68
-9
lines changed

Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "MLJTuning"
22
uuid = "03970b2e-30c4-11ea-3135-d1576263f10f"
33
authors = ["Anthony D. Blaom <anthony.blaom@gmail.com>"]
4-
version = "0.8.0"
4+
version = "0.8.1"
55

66
[deps]
77
ComputationalResources = "ed09eef8-17a6-5b46-8889-db040fac31e3"

src/strategies/explicit.jl

+24-6
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,22 @@
1+
const WARN_INCONSISTENT_PREDICTION_TYPE =
2+
"Not all models to be evaluated have the same prediction type, and this may "*
3+
"cause problems for some measures. For example, a probabilistic metric "*
4+
"like `log_loss` cannot be applied to a model making point (deterministic) "*
5+
"predictions. Inspect the prediction type with "*
6+
"`prediction_type(model)`. "
7+
18
mutable struct Explicit <: TuningStrategy end
29

310
struct ExplicitState{R, N}
411
range::R # a model-generating iterator
5-
next::N # to hold output of `iterate(range)`
12+
next::N # to hold output of `iterate(range)`
13+
prediction_type::Symbol
14+
user_warned::Bool
615
end
716

8-
ExplictState(r::R, n::N) where {R,N} = ExplicitState{R, Union{Nothing, N}}(r, n)
9-
1017
function MLJTuning.setup(tuning::Explicit, model, range, n, verbosity)
1118
next = iterate(range)
12-
return ExplicitState(range, next)
19+
return ExplicitState(range, next, MLJBase.prediction_type(model), false)
1320
end
1421

1522
# models! returns as many models as possible but no more than `n_remaining`:
@@ -20,11 +27,21 @@ function MLJTuning.models(tuning::Explicit,
2027
n_remaining,
2128
verbosity)
2229

23-
range, next = state.range, state.next
30+
range, next, prediction_type, user_warned =
31+
state.range, state.next, state.prediction_type, state.user_warned
32+
33+
function check(m)
34+
if !user_warned && verbosity > -1 && MLJBase.prediction_type(m) != prediction_type
35+
@warn WARN_INCONSISTENT_PREDICTION_TYPE
36+
user_warned = true
37+
end
38+
end
2439

2540
next === nothing && return nothing, state
2641

2742
m, s = next
43+
check(m)
44+
2845
models = Any[m, ] # types not known until run-time
2946

3047
next = iterate(range, s)
@@ -33,12 +50,13 @@ function MLJTuning.models(tuning::Explicit,
3350
while i < n_remaining
3451
next === nothing && break
3552
m, s = next
53+
check(m)
3654
push!(models, m)
3755
i += 1
3856
next = iterate(range, s)
3957
end
4058

41-
new_state = ExplicitState(range, next)
59+
new_state = ExplicitState(range, next, prediction_type, user_warned)
4260

4361
return models, new_state
4462

src/tuned_models.jl

+3-1
Original file line numberDiff line numberDiff line change
@@ -431,9 +431,11 @@ function event!(metamodel,
431431
state)
432432
model = _first(metamodel)
433433
metadata = _last(metamodel)
434+
force = typeof(resampling_machine.model.model) !=
435+
typeof(model)
434436
resampling_machine.model.model = model
435437
verb = (verbosity >= 2 ? verbosity - 3 : verbosity - 1)
436-
fit!(resampling_machine, verbosity=verb)
438+
fit!(resampling_machine; verbosity=verb, force)
437439
E = evaluate(resampling_machine)
438440
entry0 = (model = model,
439441
measure = E.measure,

test/strategies/explicit.jl

+39
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
good = KNNClassifier(K=2)
22
bad = KNNClassifier(K=10)
33
ugly = ConstantClassifier()
4+
evil = DeterministicConstantClassifier()
45

56
r = [good, bad, ugly]
67

@@ -44,4 +45,42 @@ X, y = make_blobs(rng=rng)
4445
@test_throws ArgumentError TunedModel(; models=[dcc, dcc])
4546
end
4647

48+
r = [good, bad, evil, ugly]
49+
50+
@testset "inconsistent prediction types" begin
51+
# case where different predictions types is actually okay (but still
52+
# a warning is issued):
53+
tmodel = TunedModel(
54+
models=r,
55+
resampling = Holdout(),
56+
measure=accuracy,
57+
)
58+
@test_logs(
59+
(:warn, MLJTuning.WARN_INCONSISTENT_PREDICTION_TYPE),
60+
MLJBase.fit(tmodel, 0, X, y),
61+
);
62+
63+
# verbosity = -1 suppresses the warning:
64+
@test_logs(
65+
MLJBase.fit(tmodel, -1, X, y),
66+
);
67+
68+
# case where there really is a problem with different prediction types:
69+
tmodel = TunedModel(
70+
models=r,
71+
resampling = Holdout(),
72+
measure=log_loss,
73+
)
74+
@test_logs(
75+
(:warn, MLJTuning.WARN_INCONSISTENT_PREDICTION_TYPE),
76+
(:error,),
77+
(:info,),
78+
(:info,),
79+
@test_throws(
80+
ArgumentError, # indicates the problem is with incompatible measure
81+
MLJBase.fit(tmodel, 0, X, y),
82+
)
83+
)
84+
end
85+
4786
true

test/tuned_models.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ results = [(evaluate(model, X, y,
7979
tm = TunedModel(
8080
models=r,
8181
resampling=CV(nfolds=2),
82-
measures=cross_entropy
82+
measures=cross_entropy,
8383
)
8484
@test_logs((:error, r"Problem"),
8585
(:info, r""),

0 commit comments

Comments
 (0)