Skip to content

Commit 4d37eed

Browse files
committed
make pipelines support Unsupervised with target in fit
1 parent cb15208 commit 4d37eed

File tree

2 files changed

+51
-2
lines changed

2 files changed

+51
-2
lines changed

src/composition/models/pipelines.jl

+16-2
Original file line numberDiff line numberDiff line change
@@ -225,6 +225,15 @@ implements it (some clustering models). Similarly, calling `transform`
225225
on a supervised pipeline calls `transform` on the supervised
226226
component.
227227
228+
### Transformers that need a target in training
229+
230+
Some transformers that have type `Unsupervised` (so that the output of `transform` is
231+
propagated in pipelines) also see a target variable in training. An example are so-called
232+
target encoders (which transform categorical input features, based on some target
233+
observations). Provided they appear before any `Supervised` component in the pipelines,
234+
such models are supported. Of course a target must be provided whenever training such a
235+
pipeline, whether or not it contains a `Supervised` component.
236+
228237
### Optional key-word arguments
229238
230239
- `prediction_type` -
@@ -444,9 +453,13 @@ function extend(front::Front{Pred}, ::Static, name, cache, args...)
444453
Front(transform(mach, active(front)), front.transform, Pred())
445454
end
446455

447-
function extend(front::Front{Trans}, component::Unsupervised, name, cache, args...)
456+
function extend(front::Front{Trans}, component::Unsupervised, name, cache, ::Any, sources...)
448457
a = active(front)
449-
mach = machine(name, a; cache=cache)
458+
if target_in_fit(component)
459+
mach = machine(name, a, first(sources); cache=cache)
460+
else
461+
mach = machine(name, a; cache=cache)
462+
end
450463
Front(predict(mach, a), transform(mach, a), Trans())
451464
end
452465

@@ -598,6 +611,7 @@ function MMI.iteration_parameter(pipe::SupervisedPipeline)
598611
end
599612

600613
MMI.target_scitype(p::SupervisedPipeline) = target_scitype(supervised_component(p))
614+
MMI.target_in_fit(p::SomePipeline) = any(target_in_fit, components(p))
601615

602616
MMI.package_name(::Type{<:SomePipeline}) = "MLJBase"
603617
MMI.load_path(::Type{<:SomePipeline}) = "MLJBase.Pipeline"

test/composition/models/pipelines.jl

+35
Original file line numberDiff line numberDiff line change
@@ -544,6 +544,7 @@ end
544544
# inverse transform:
545545
p = Pipeline(UnivariateBoxCoxTransformer,
546546
UnivariateStandardizer)
547+
@test !target_in_fit(p)
547548
xtrain = rand(rng, 10)
548549
mach = machine(p, xtrain)
549550
fit!(mach, verbosity=0)
@@ -702,6 +703,40 @@ end
702703
@test Set(features) == Set(keys(X))
703704
end
704705

706+
struct SupervisedTransformer <: Unsupervised end
707+
708+
MLJBase.fit(::SupervisedTransformer, verbosity, X, y) = (mean(y), nothing, nothing)
709+
MLJBase.transform(::SupervisedTransformer, fitresult, X) =
710+
fitresult*MLJBase.matrix(X) |> MLJBase.table
711+
MLJBase.target_in_fit(::Type{<:SupervisedTransformer}) = true
712+
713+
struct DummyTransformer <: Unsupervised end
714+
MLJBase.fit(::DummyTransformer, verbosity, X) = (nothing, nothing, nothing)
715+
MLJBase.transform(::DummyTransformer, fitresult, X) = X
716+
717+
@testset "supervised transformers in a pipeline" begin
718+
X = MLJBase.table((a=fill(10.0, 3),))
719+
y = fill(2, 3)
720+
pipe = SupervisedTransformer() |> DeterministicConstantRegressor()
721+
@test target_in_fit(pipe)
722+
mach = machine(pipe, X, y)
723+
fit!(mach, verbosity=0)
724+
@test predict(mach, X) == fill(2.0, 3)
725+
726+
pipe2 = DummyTransformer |> pipe
727+
@test target_in_fit(pipe2)
728+
mach = machine(pipe2, X, y)
729+
fit!(mach, verbosity=0)
730+
@test predict(mach, X) == fill(2.0, 3)
731+
732+
pipe3 = DummyTransformer |> SupervisedTransformer |> DummyTransformer
733+
@test target_in_fit(pipe3)
734+
mach = machine(pipe3, X, y)
735+
fit!(mach, verbosity=0)
736+
@test transform(mach, X).x1 == fill(20.0, 3)
737+
end
738+
739+
705740
end # module
706741

707742
true

0 commit comments

Comments
 (0)