Skip to content

Commit 928a7d1

Browse files
committed
Merge branch 'dev' into default-logger-docstring
2 parents ecb4322 + 7254ee1 commit 928a7d1

File tree

13 files changed

+163
-96
lines changed

13 files changed

+163
-96
lines changed

Project.toml

+3-3
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "MLJBase"
22
uuid = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
33
authors = ["Anthony D. Blaom <anthony.blaom@gmail.com>"]
4-
version = "1.5.0"
4+
version = "1.6"
55

66
[deps]
77
CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597"
@@ -47,7 +47,7 @@ DelimitedFiles = "1"
4747
Distributions = "0.25.3"
4848
InvertedIndices = "1"
4949
LearnAPI = "0.1"
50-
MLJModelInterface = "1.10"
50+
MLJModelInterface = "1.11"
5151
Missings = "0.4, 1"
5252
OrderedCollections = "1.1"
5353
Parameters = "0.12"
@@ -58,7 +58,7 @@ Reexport = "1.2"
5858
ScientificTypes = "3"
5959
StatisticalMeasures = "0.1.1"
6060
StatisticalMeasuresBase = "0.1.1"
61-
StatisticalTraits = "3.3"
61+
StatisticalTraits = "3.4"
6262
Statistics = "1"
6363
StatsBase = "0.32, 0.33, 0.34"
6464
Tables = "0.2, 1.0"

src/composition/learning_networks/nodes.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -277,7 +277,7 @@ function _formula(stream, X::Node, depth, indent)
277277
if X.machine !== nothing
278278
print(stream, crind(indent + length(operation_name) - anti))
279279
printstyled(IOContext(stream, :color=>SHOW_COLOR[]),
280-
# handle(X.machine),
280+
#handle(X.machine),
281281
X.machine,
282282
bold=SHOW_COLOR[])
283283
n_args == 0 || print(stream, ", ")

src/composition/learning_networks/signatures.jl

+2-1
Original file line numberDiff line numberDiff line change
@@ -307,7 +307,8 @@ See also [`MLJBase.Signature`](@ref).
307307
"""
308308
fitted_params_supplement(signature::Signature) = call_and_copy(fitted_params_nodes(signature))
309309

310-
""" report(signature; supplement=true)
310+
"""
311+
report(signature; supplement=true)
311312
312313
**Private method.**
313314

src/composition/models/pipelines.jl

+17-3
Original file line numberDiff line numberDiff line change
@@ -182,7 +182,7 @@ or what `transform` returns if it is `Unsupervised`.
182182
Names for the component fields are automatically generated unless
183183
explicitly specified, as in
184184
185-
```
185+
```julia
186186
Pipeline(encoder=ContinuousEncoder(drop_last=false),
187187
stand=Standardizer())
188188
```
@@ -225,6 +225,15 @@ implements it (some clustering models). Similarly, calling `transform`
225225
on a supervised pipeline calls `transform` on the supervised
226226
component.
227227
228+
### Transformers that need a target in training
229+
230+
Some transformers that have type `Unsupervised` (so that the output of `transform` is
231+
propagated in pipelines) may require a target variable for training. Examples are
232+
so-called target encoders (which transform categorical input features, based on some
233+
target observations). Provided they appear before any `Supervised` component in the
234+
pipeline, such models are supported. Of course, a target must be provided whenever
235+
training such a pipeline, whether or not it contains a `Supervised` component.
236+
228237
### Optional key-word arguments
229238
230239
- `prediction_type` -
@@ -444,9 +453,13 @@ function extend(front::Front{Pred}, ::Static, name, cache, args...)
444453
Front(transform(mach, active(front)), front.transform, Pred())
445454
end
446455

447-
function extend(front::Front{Trans}, component::Unsupervised, name, cache, args...)
456+
function extend(front::Front{Trans}, component::Unsupervised, name, cache, ::Any, sources...)
448457
a = active(front)
449-
mach = machine(name, a; cache=cache)
458+
if target_in_fit(component)
459+
mach = machine(name, a, first(sources); cache=cache)
460+
else
461+
mach = machine(name, a; cache=cache)
462+
end
450463
Front(predict(mach, a), transform(mach, a), Trans())
451464
end
452465

@@ -598,6 +611,7 @@ function MMI.iteration_parameter(pipe::SupervisedPipeline)
598611
end
599612

600613
MMI.target_scitype(p::SupervisedPipeline) = target_scitype(supervised_component(p))
614+
MMI.target_in_fit(p::SomePipeline) = any(target_in_fit, components(p))
601615

602616
MMI.package_name(::Type{<:SomePipeline}) = "MLJBase"
603617
MMI.load_path(::Type{<:SomePipeline}) = "MLJBase.Pipeline"

src/data/data.jl

+6-4
Original file line numberDiff line numberDiff line change
@@ -401,12 +401,18 @@ _isnan(x::Number) = isnan(x)
401401

402402
skipnan(x) = Iterators.filter(!_isnan, x)
403403

404+
isinvalid(x) = ismissing(x) || _isnan(x)
405+
404406
"""
405407
skipinvalid(itr)
406408
407409
Return an iterator over the elements in `itr` skipping `missing` and
408410
`NaN` values. Behaviour is similar to [`skipmissing`](@ref).
409411
412+
"""
413+
skipinvalid(v) = v |> skipmissing |> skipnan
414+
415+
"""
410416
skipinvalid(A, B)
411417
412418
For vectors `A` and `B` of the same length, return a tuple of vectors
@@ -417,10 +423,6 @@ always returns a vector. Does not remove `Missing` from the element
417423
types if present in the original iterators.
418424
419425
"""
420-
skipinvalid(v) = v |> skipmissing |> skipnan
421-
422-
isinvalid(x) = ismissing(x) || _isnan(x)
423-
424426
function skipinvalid(yhat, y)
425427
mask = .!(isinvalid.(yhat) .| isinvalid.(y))
426428
return yhat[mask], y[mask]

src/data/datasets.jl

+5-4
Original file line numberDiff line numberDiff line change
@@ -199,7 +199,7 @@ function load_smarket()
199199
end
200200

201201
"""Load a well-known sunspot time series (table with one column).
202-
[https://www.sws.bom.gov.au/Educational/2/3/6]](https://www.sws.bom.gov.au/Educational/2/3/6)
202+
<https://www.sws.bom.gov.au/Educational/2/3/6>
203203
"""
204204
load_sunspots() = load_dataset("sunspots.csv", COERCE_SUNSPOTS)
205205

@@ -250,9 +250,10 @@ macro load_crabs()
250250
end
251251
end
252252

253-
""" Load S&P Stock Market dataset, as used in (An Introduction to
254-
Statistical Learning with applications in
255-
R)[https://rdrr.io/cran/ISLR/man/Smarket.html](https://rdrr.io/cran/ISLR/man/Smarket.html),
253+
"""
254+
Load S&P Stock Market dataset, as used in
255+
[An Introduction to Statistical Learning with applications in
256+
R](https://rdrr.io/cran/ISLR/man/Smarket.html),
256257
by Witten et al (2013), Springer-Verlag, New York."""
257258
macro load_smarket()
258259
quote

src/data/datasets_synthetic.jl

+13-9
Original file line numberDiff line numberDiff line change
@@ -21,12 +21,12 @@ Internal function to finalize the `make_*` functions.
2121
function finalize_Xy(X, y, shuffle, as_table, eltype, rng; clf::Bool=true)
2222
# Shuffle the rows if required
2323
if shuffle
24-
X, y = shuffle_rows(X, y; rng=rng)
25-
end
26-
if eltype != Float64
27-
X = convert.(eltype, X)
28-
end
29-
# return as matrix if as_table=false
24+
X, y = shuffle_rows(X, y; rng=rng)
25+
end
26+
if eltype != Float64
27+
X = convert.(eltype, X)
28+
end
29+
# return as matrix if as_table=false
3030
as_table || return X, y
3131
clf && return MLJBase.table(X), categorical(y)
3232
if length(size(y)) > 1
@@ -172,7 +172,6 @@ membership to the smaller or larger circle, respectively.
172172
* `noise=0`: standard deviation of the Gaussian noise added to the data,
173173
174174
* `factor=0.8`: ratio of the smaller radius over the larger one,
175-
176175
$(EXTRA_KW_MAKE*EXTRA_CLASSIFICATION)
177176
178177
### Example
@@ -318,7 +317,12 @@ Make portion `s` of vector `θ` exactly 0.
318317
"""
319318
sparsify!(rng, θ, s) =
    # `rand(rng, length(θ)) .< s` is a Boolean mask with one uniform draw per entry;
    # multiplying by it in place zeroes every entry whose draw is not below `s`, and
    # the broadcast assignment returns the mutated `θ`.
    # NOTE(review): as written, roughly a fraction `s` of the entries is *kept* rather
    # than zeroed — confirm this against the docstring's "portion `s` ... exactly 0".
    (θ .*= (rand(rng, length(θ)) .< s))
320319

321-
"""Add outliers to portion s of vector."""
320+
"""
321+
outlify!(rng, y, s)
322+
323+
Add outliers to portion `s` of vector `y`.
324+
325+
"""
322326
outlify!(rng, y, s) =
323327
(n = length(y); y .+= 20 * randn(rng, n) .* (rand(rng, n) .< s))
324328

@@ -329,7 +333,7 @@ const SIGMOID_32 = log(Float32(1)/eps(Float32) - Float32(1))
329333
sigmoid(x)
330334
331335
Return the sigmoid computed in a numerically stable way:
332-
``σ(x) = 1/(1+exp(-x))``
336+
``σ(x) = 1/(1+\\exp(-x))``
333337
334338
"""
335339
function sigmoid(x::Float64)

src/hyperparam/one_dimensional_range_methods.jl

+25-25
Original file line numberDiff line numberDiff line change
@@ -66,31 +66,31 @@ In the first case iteration is over all `values` stored in the range
6666
iteration is over approximately `n` ordered values, generated as
6767
follows:
6868
69-
(i) First, exactly `n` values are generated between `U` and `L`, with a
70-
spacing determined by `r.scale` (uniform if `scale=:linear`) where `U`
71-
and `L` are given by the following table:
72-
73-
| `r.lower` | `r.upper` | `L` | `U` |
74-
|-------------|------------|---------------------|---------------------|
75-
| finite | finite | `r.lower` | `r.upper` |
76-
| `-Inf` | finite | `r.upper - 2r.unit` | `r.upper` |
77-
| finite | `Inf` | `r.lower` | `r.lower + 2r.unit` |
78-
| `-Inf` | `Inf` | `r.origin - r.unit` | `r.origin + r.unit` |
79-
80-
(ii) If a callable `f` is provided as `scale`, then a uniform spacing
81-
is always applied in (i) but `f` is broadcast over the results. (Unlike
82-
ordinary scales, this alters the effective range of values generated,
83-
instead of just altering the spacing.)
84-
85-
(iii) If `r` is a discrete numeric range (`r isa NumericRange{<:Integer}`)
86-
then the values are additionally rounded, with any duplicate values
87-
removed. Otherwise all the values are used (and there are exacltly `n`
88-
of them).
89-
90-
(iv) Finally, if a random number generator `rng` is specified, then the values are
91-
returned in random order (sampling without replacement), and otherwise
92-
they are returned in numeric order, or in the order provided to the
93-
range constructor, in the case of a `NominalRange`.
69+
1. First, exactly `n` values are generated between `U` and `L`, with a
70+
spacing determined by `r.scale` (uniform if `scale=:linear`) where `U`
71+
and `L` are given by the following table:
72+
73+
| `r.lower` | `r.upper` | `L` | `U` |
74+
|-------------|------------|---------------------|---------------------|
75+
| finite | finite | `r.lower` | `r.upper` |
76+
| `-Inf` | finite | `r.upper - 2r.unit` | `r.upper` |
77+
| finite | `Inf` | `r.lower` | `r.lower + 2r.unit` |
78+
| `-Inf` | `Inf` | `r.origin - r.unit` | `r.origin + r.unit` |
79+
80+
2. If a callable `f` is provided as `scale`, then a uniform spacing
81+
is always applied in (1) but `f` is broadcast over the results. (Unlike
82+
ordinary scales, this alters the effective range of values generated,
83+
instead of just altering the spacing.)
84+
85+
3. If `r` is a discrete numeric range (`r isa NumericRange{<:Integer}`)
86+
then the values are additionally rounded, with any duplicate values
87+
removed. Otherwise all the values are used (and there are exactly `n`
88+
of them).
89+
90+
4. Finally, if a random number generator `rng` is specified, then the values are
91+
returned in random order (sampling without replacement), and otherwise
92+
they are returned in numeric order, or in the order provided to the
93+
range constructor, in the case of a `NominalRange`.
9494
9595
"""
9696
iterator(rng::AbstractRNG, r::ParamRange, args...) =

src/machines.jl

+31-31
Original file line numberDiff line numberDiff line change
@@ -529,7 +529,7 @@ err_missing_model(model) = ErrorException(
529529
)
530530

531531
"""
532-
last_model(mach::Machine)
532+
last_model(mach::Machine)
533533
534534
Return the last model used to train the machine `mach`. This is a bona fide model, even if
535535
`mach.model` is a symbol.
@@ -572,31 +572,31 @@ the true model given by `getproperty(composite, model)`. See also [`machine`](@r
572572
For the action to be a no-operation, either `mach.frozen == true` or
573573
none of the following apply:
574574
575-
- (i) `mach` has never been trained (`mach.state == 0`).
575+
1. `mach` has never been trained (`mach.state == 0`).
576576
577-
- (ii) `force == true`.
577+
2. `force == true`.
578578
579-
- (iii) The `state` of some other machine on which `mach` depends has
580-
changed since the last time `mach` was trained (ie, the last time
581-
`mach.state` was last incremented).
579+
3. The `state` of some other machine on which `mach` depends has
580+
changed since the last time `mach` was trained (ie, the last time
581+
`mach.state` was last incremented).
582582
583-
- (iv) The specified `rows` have changed since the last retraining and
584-
`mach.model` does not have `Static` type.
583+
4. The specified `rows` have changed since the last retraining and
584+
`mach.model` does not have `Static` type.
585585
586-
- (v) `mach.model` is a model and different from the last model used for training, but has
587-
the same type.
586+
5. `mach.model` is a model and different from the last model used for training, but has
587+
the same type.
588588
589-
- (vi) `mach.model` is a model but has a type different from the last model used for
590-
training.
589+
6. `mach.model` is a model but has a type different from the last model used for
590+
training.
591591
592-
- (vii) `mach.model` is a symbol and `(composite, mach.model)` is different from the last
593-
model used for training, but has the same type.
592+
7. `mach.model` is a symbol and `(composite, mach.model)` is different from the last
593+
model used for training, but has the same type.
594594
595-
- (viii) `mach.model` is a symbol and `(composite, mach.model)` has a different type from
596-
the last model used for training.
595+
8. `mach.model` is a symbol and `(composite, mach.model)` has a different type from
596+
the last model used for training.
597597
598-
In any of the cases (i) - (iv), (vi), or (viii), `mach` is trained ab initio. If (v) or
599-
(vii) is true, then a training update is applied.
598+
In any of the cases (1) - (4), (6), or (8), `mach` is trained ab initio.
599+
If (5) or (7) is true, then a training update is applied.
600600
601601
To freeze or unfreeze `mach`, use `freeze!(mach)` or `thaw!(mach)`.
602602
@@ -658,7 +658,7 @@ function fit_only!(
658658
rows === nothing && (rows = (:))
659659
rows_is_new = !isdefined(mach, :old_rows) || rows != mach.old_rows
660660

661-
condition_iv = rows_is_new && !(mach.model isa Static)
661+
condition_4 = rows_is_new && !(mach.model isa Static)
662662

663663
upstream_has_changed = mach.old_upstream_state != upstream_state
664664

@@ -672,16 +672,16 @@ function fit_only!(
672672

673673
# build or update cached `resampled_data` if necessary (`mach.data` is already defined
674674
# above if needed here):
675-
if cache_data && (!data_is_valid || condition_iv)
675+
if cache_data && (!data_is_valid || condition_4)
676676
mach.resampled_data = selectrows(model, rows, mach.data...)
677677
end
678678

679679
# `fit`, `update`, or return untouched:
680-
if mach.state == 0 || # condition (i)
681-
force == true || # condition (ii)
682-
upstream_has_changed || # condition (iii)
683-
condition_iv || # condition (iv)
684-
modeltype_changed # conditions (vi) or (vii)
680+
if mach.state == 0 || # condition (1)
681+
force == true || # condition (2)
682+
upstream_has_changed || # condition (3)
683+
condition_4 || # condition (4)
684+
modeltype_changed # conditions (6) or (8)
685685

686686
isdefined(mach, :report) || (mach.report = LittleDict{Symbol,Any}())
687687

@@ -709,7 +709,7 @@ function fit_only!(
709709
rethrow()
710710
end
711711

712-
elseif model != mach.old_model # condition (v)
712+
elseif model != mach.old_model # condition (5)
713713

714714
# update the model:
715715
fitlog(mach, :update, verbosity)
@@ -1044,9 +1044,10 @@ To serialise using a different format, see [`serializable`](@ref).
10441044
Machines are deserialized using the `machine` constructor as shown in
10451045
the example below.
10461046
1047-
> The implementation of `save` for machines changed in MLJ 0.18
1048-
> (MLJBase 0.20). You can only restore a machine saved using older
1049-
> versions of MLJ using an older version.
1047+
!!! note
1048+
The implementation of `save` for machines changed in MLJ 0.18
1049+
(MLJBase 0.20). You can only restore a machine saved using older
1050+
versions of MLJ using an older version.
10501051
10511052
### Example
10521053
@@ -1073,8 +1074,7 @@ predict(predict_only_mach, X)
10731074
general purpose serialization formats, can allow for arbitrary code
10741075
execution during loading. This means it is possible for someone
10751076
to use a JLS file that looks like a serialized MLJ machine as a
1076-
[Trojan
1077-
horse](https://en.wikipedia.org/wiki/Trojan_horse_(computing)).
1077+
[Trojan horse](https://en.wikipedia.org/wiki/Trojan_horse_(computing)).
10781078
10791079
See also [`serializable`](@ref), [`machine`](@ref).
10801080

0 commit comments

Comments
 (0)