Skip to content

Commit 5739a73

Browse files
authored
Merge pull request #982 from JuliaAI/dev
For a 1.5.0 release
2 parents ffe0ac2 + 370b3da commit 5739a73

File tree

6 files changed

+104
-7
lines changed

6 files changed

+104
-7
lines changed

Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "MLJBase"
22
uuid = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
33
authors = ["Anthony D. Blaom <anthony.blaom@gmail.com>"]
4-
version = "1.4.0"
4+
version = "1.5.0"
55

66
[deps]
77
CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597"

src/MLJBase.jl

+2-1
Original file line numberDiff line numberDiff line change
@@ -248,7 +248,8 @@ export params
248248
# -------------------------------------------------------------------
249249
# exports from this module, MLJBase
250250

251-
# computational_resources.jl:
251+
# get/set global constants:
252+
export default_logger
252253
export default_resource
253254

254255
# one_dimensional_ranges.jl:

src/init.jl

+1
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ function __init__()
33
global DEFAULT_RESOURCE = Ref{AbstractResource}(CPU1())
44
global DEFAULT_SCITYPE_CHECK_LEVEL = Ref{Int}(1)
55
global SHOW_COLOR = Ref{Bool}(true)
6+
global DEFAULT_LOGGER = Ref{Any}(nothing)
67

78
# for testing asynchronous training of learning networks:
89
global TESTING = parse(Bool, get(ENV, "TEST_MLJBASE", "false"))

src/machines.jl

+19
Original file line numberDiff line numberDiff line change
@@ -1088,6 +1088,25 @@ function save(file::Union{String,IO}, mach::Machine)
10881088
serialize(file, smach)
10891089
end
10901090

1091+
const ERR_INVALID_DEFAULT_LOGGER = ArgumentError(
1092+
"You have attempted to save a machine to the default logger "*
1093+
"but `default_logger()` is currently `nothing`. "*
1094+
"Either specify an explicit logger, path or stream to save to, "*
1095+
"or use `default_logger(logger)` "*
1096+
"to change the default logger. "
1097+
)
1098+
1099+
"""
1100+
MLJ.save(mach)
1101+
MLJBase.save(mach)
1102+
1103+
Save the current machine as an artifact at the location associated with
1104+
`default_logger`](@ref).
1105+
1106+
"""
1107+
MLJBase.save(mach::Machine) = MLJBase.save(default_logger(), mach)
1108+
MLJBase.save(::Nothing, ::Machine) = throw(ERR_INVALID_DEFAULT_LOGGER)
1109+
10911110
report_for_serialization(mach) = mach.report
10921111

10931112
# NOTE. there is also a specialization of `report_for_serialization` for `Composite`

src/resampling.jl

+64-5
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# # TYPE ALIASES
1+
# TYPE ALIASES
22

33
const AbstractRow = Union{AbstractVector{<:Integer}, Colon}
44
const TrainTestPair = Tuple{AbstractRow,AbstractRow}
@@ -747,6 +747,64 @@ Base.show(io::IO, e::CompactPerformanceEvaluation) =
747747
print(io, "CompactPerformanceEvaluation$(_summary(e))")
748748

749749

750+
751+
# ===============================================================
752+
## USER CONTROL OF DEFAULT LOGGING
753+
754+
const DOC_DEFAULT_LOGGER =
755+
"""
756+
757+
The default logger is used in calls to [`evaluate!`](@ref) and [`evaluate`](@ref), and
758+
in the constructors `TunedModel` and `IteratedModel`, unless the `logger` keyword is
759+
explicitly specified.
760+
761+
!!! note
762+
763+
Prior to MLJ v0.20.7 (and MLJBase 1.5) the default logger was always `nothing`.
764+
765+
"""
766+
767+
"""
768+
default_logger()
769+
770+
Return the current value of the default logger for use with supported machine learning
771+
tracking platforms, such as [MLflow](https://mlflow.org/docs/latest/index.html).
772+
773+
$DOC_DEFAULT_LOGGER
774+
775+
When MLJBase is first loaded, the default logger is `nothing`. To reset the logger, see
776+
beow.
777+
778+
"""
779+
default_logger() = DEFAULT_LOGGER[]
780+
781+
"""
782+
default_logger(logger)
783+
784+
Reset the default logger.
785+
786+
# Example
787+
788+
Suppose an [MLflow](https://mlflow.org/docs/latest/index.html) tracking service is running
789+
on a local server at `http://127.0.0.1:500`. Then every in every `evaluate` call in which
790+
`logger` is not specified, as in the example below, the peformance evaluation is
791+
automatically logged to the service.
792+
793+
```julia-repl
794+
using MLJ
795+
logger = MLJFlow.Logger("http://127.0.0.1:5000/api")
796+
default_logger(logger)
797+
798+
X, y = make_moons()
799+
model = ConstantClassifier()
800+
evaluate(model, X, y, measures=[log_loss, accuracy)])
801+
802+
"""
803+
function default_logger(logger)
804+
DEFAULT_LOGGER[] = logger
805+
end
806+
807+
750808
# ===============================================================
751809
## EVALUATION METHODS
752810

@@ -1068,7 +1126,8 @@ Although `evaluate!` is mutating, `mach.model` and `mach.args` are not mutated.
10681126
`false` the `per_observation` field of the returned object is populated with
10691127
`missing`s. Setting to `false` may reduce compute time and allocations.
10701128
1071-
- `logger` - a logger object (see [`MLJBase.log_evaluation`](@ref))
1129+
- `logger=default_logger()` - a logger object for forwarding results to a machine learning
1130+
tracking platform; see [`default_logger`](@ref) for details.
10721131
10731132
- `compact=false` - if `true`, the returned evaluation object excludes these fields:
10741133
`fitted_params_per_fold`, `report_per_fold`, `train_test_rows`.
@@ -1093,7 +1152,7 @@ function evaluate!(
10931152
check_measure=true,
10941153
per_observation=true,
10951154
verbosity=1,
1096-
logger=nothing,
1155+
logger=default_logger(),
10971156
compact=false,
10981157
)
10991158

@@ -1544,7 +1603,7 @@ end
15441603
acceleration=default_resource(),
15451604
check_measure=true,
15461605
per_observation=true,
1547-
logger=nothing,
1606+
logger=default_logger(),
15481607
compact=false,
15491608
)
15501609
@@ -1624,7 +1683,7 @@ function Resampler(
16241683
repeats=1,
16251684
cache=true,
16261685
per_observation=true,
1627-
logger=nothing,
1686+
logger=default_logger(),
16281687
compact=false,
16291688
)
16301689
resampler = Resampler(

test/resampling.jl

+17
Original file line numberDiff line numberDiff line change
@@ -935,4 +935,21 @@ end
935935
end
936936
end
937937

938+
# DUMMY LOGGER
939+
940+
struct DummyLogger end
941+
942+
MLJBase.save(logger::DummyLogger, mach::Machine) = mach.model
943+
944+
@testset "default logger" begin
945+
@test isnothing(default_logger())
946+
model = ConstantClassifier()
947+
mach = machine(model, make_moons(10)...)
948+
fit!(mach, verbosity=0)
949+
@test_throws MLJBase.ERR_INVALID_DEFAULT_LOGGER MLJBase.save(mach)
950+
default_logger(DummyLogger())
951+
@test default_logger() == DummyLogger()
952+
@test MLJBase.save(mach) == model
953+
end
954+
938955
true

0 commit comments

Comments
 (0)