-
Notifications
You must be signed in to change notification settings - Fork 44
/
Copy pathsignatures.jl
381 lines (269 loc) · 10.9 KB
/
signatures.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
# a signature is just a thin wrapper for what the user knows as a "learning network
# interface"; see constant DOC_NETWORK_INTERFACES below for details.
# # HELPERS
"""
    machines_given_model(node::AbstractNode)

**Private method.**

Return a dictionary of machines, keyed on model, for all machines in the
completed learning network for which `node` is the greatest lower bound. Only
machines bound to symbolic models are included. Values are always vectors,
even if they contain only a single machine.

"""
function machines_given_model(node::AbstractNode)
    dict = LittleDict{Symbol,Any}()
    for mach in machines(node)
        symbolic_model = mach.model
        # skip machines not bound to a symbolic model
        symbolic_model isa Symbol || continue
        machs = get!(dict, symbolic_model) do
            Any[]
        end
        push!(machs, mach)
    end
    return dict
end
# Collapse a one-element vector to its single entry; leave longer vectors alone.
attempt_scalarize(v) = length(v) == 1 ? first(v) : v

"""
    tuple_keyed_on_model(f, machines_given_model; scalarize=true, drop_nothings=true)

**Private method.**

Given a dictionary of machine vectors, keyed on model names (symbols), broadcast
`f` over each vector, and make the result, in the returned named tuple, the
value associated with the corresponding model name as key.

Singleton vector values are scalarized, unless `scalarize = false`.

If a value in the computed named tuple is `nothing`, or a vector of `nothing`s,
then the entry is dropped from the tuple, unless `drop_nothings=false`.

"""
function tuple_keyed_on_model(f, machines_given_model; scalarize=true, drop_nothings=true)
    model_names = collect(keys(machines_given_model))
    raw_values = map(model_names) do name
        mapped = [f(mach) for mach in machines_given_model[name]]
        return scalarize ? attempt_scalarize(mapped) : mapped
    end
    if drop_nothings
        # keep only entries that are not `nothing` and not vectors of `nothing`s
        keep = [
            !(isnothing(v) || (v isa AbstractVector && eltype(v) === Nothing))
            for v in raw_values
        ]
        model_names = model_names[keep]
        raw_values = raw_values[keep]
    end
    return NamedTuple{tuple(model_names...)}(tuple(raw_values...))
end
const ERR_CALL_AND_COPY = ArgumentError(
    "Expected something of `AbstractNode` type in a learning network interface "*
    "but got something else. "
)

"""
    call_and_copy(x)

**Private method.**

If `x` is an `AbstractNode`, then return a deep copy of `x()`. If `x` is a named tuple
`(k1=n1, k2=n2, ...)`, then "broadcast" `call_and_copy` over the values `n1`, `n2`, ...,
to get a new named tuple with the same keys.

"""
call_and_copy(::Any) = throw(ERR_CALL_AND_COPY)
call_and_copy(n::AbstractNode) = deepcopy(n())
function call_and_copy(nt::NamedTuple)
    # deep-copy the node values before calling, as in the scalar case
    node_values = deepcopy(values(nt))
    return NamedTuple{keys(nt)}(call_and_copy.(node_values))
end
# # DOC STRING
# Shared documentation fragment, interpolated into the `Signature` docstring and
# elsewhere. Fixes: missing verb ("keys ... are always one of") and the
# misspelling "immediatately".
const DOC_NETWORK_INTERFACES =
"""
A *learning network interface* is a named tuple declaring certain interface points in
a learning network, to be used when "exporting" the network as a new stand-alone model
type. Examples are
(predict=yhat,)
(transform=Xsmall, acceleration=CPUThreads())
(predict=yhat, transform=W, report=(loss=loss_node,))
Here `yhat`, `Xsmall`, `W` and `loss_node` are nodes in the network.
The keys of the learning network interface are always one of the following:
- The name of an operation, such as `:predict`, `:predict_mode`, `:transform`,
`:inverse_transform`. See "Operation keys" below.
- `:report`, for exposing results of calling a node *with no arguments* in the
composite model report. See "Including report nodes" below.
- `:fitted_params`, for exposing results of calling a node *with no arguments* as
fitted parameters of the composite model. See "Including fitted parameter nodes"
below.
- `:acceleration`, for articulating acceleration mode for training the network, e.g.,
`CPUThreads()`. Corresponding value must be an `AbstractResource`. If not included,
`CPU1()` is used.
### Operation keys
If the key is an operation, then the value must be a node `n` in the network with a
unique origin (`length(origins(n)) === 1`). The intention of a declaration such as
`predict=yhat` is that the exported model type implements `predict`, which, when
applied to new data `Xnew`, should return `yhat(Xnew)`.
#### Including report nodes
If the key is `:report`, then the corresponding value must be a named tuple
(k1=n1, k2=n2, ...)
whose values are all nodes. For each `k=n` pair, the key `k` will appear as a key in
the composite model report, with a corresponding value of `deepcopy(n())`, called
immediately after training or updating the network. For examples, refer to the
"Learning Networks" section of the MLJ manual.
#### Including fitted parameter nodes
If the key is `:fitted_params`, then the behaviour is as for report nodes but results
are exposed as fitted parameters of the composite model instead of the report.
"""
# # SIGNATURES
"""
    Signature(interface::NamedTuple)
**Private type.**
Return a thinly wrapped version of a learning network interface (defined below). Unwrap
with `MLJBase.unwrap`:
```julia
interface = (predict=source(), report=(loss=source(),))
signature = MLJBase.Signature(interface)
@assert MLJBase.unwrap(signature) === interface
```
$DOC_NETWORK_INTERFACES
"""
struct Signature{S<:NamedTuple}
    # `S` is the concrete named-tuple type of the wrapped interface, so the
    # wrapper adds no runtime indirection.
    interface::S
end

"""
    unwrap(signature::Signature)

**Private method.**

Return the learning network interface wrapped by `signature`.

See also [`MLJBase.Signature`](@ref).

"""
unwrap(signature::Signature) = signature.interface
# # METHODS
"""
    operation_nodes(signature)

**Private method.**

Return the operation nodes of `signature`, as a named tuple keyed on operation names.

See also [`MLJBase.Signature`](@ref).

"""
function operation_nodes(signature::Signature)
    nt = unwrap(signature)
    # restrict the interface keys to recognized operation names
    op_keys = filter(k -> k in OPERATIONS, keys(nt))
    return NamedTuple{op_keys}(Tuple(getproperty(nt, k) for k in op_keys))
end
"""
    report_nodes(signature)

**Private method.**

Return the report nodes of `signature`, as a named tuple.

See also [`MLJBase.Signature`](@ref).

"""
function report_nodes(signature::Signature)
    nt = unwrap(signature)
    return haskey(nt, :report) ? nt.report : NamedTuple()
end
"""
    fitted_params_nodes(signature)

**Private method.**

Return the fitted parameter nodes of `signature`, as a named tuple.

See also [`MLJBase.Signature`](@ref).

"""
function fitted_params_nodes(signature::Signature)
    nt = unwrap(signature)
    return haskey(nt, :fitted_params) ? nt.fitted_params : NamedTuple()
end
"""
    acceleration(signature)

**Private method.**

Return the acceleration mode of `signature`, defaulting to `CPU1()` when the
interface does not declare one.

See also [`MLJBase.Signature`](@ref).

"""
function acceleration(signature::Signature)
    nt = unwrap(signature)
    return haskey(nt, :acceleration) ? nt.acceleration : CPU1()
end
"""
    operations(signature)

**Private method.**

Return the names of all operations in `signature`.

See also [`MLJBase.Signature`](@ref).

"""
operations(signature::Signature) = signature |> operation_nodes |> keys
"""
    glb(signature::Signature)

**Private method.**

Return the greatest lower bound of all operation nodes, report nodes and fitted parameter
nodes associated with `signature`.

See also [`MLJBase.Signature`](@ref).

"""
function glb(signature::Signature)
    all_nodes = vcat(
        collect(values(operation_nodes(signature))),
        collect(values(report_nodes(signature))),
        collect(values(fitted_params_nodes(signature))),
    )
    return glb(all_nodes...)
end
"""
    age(signature::Signature)

**Private method.**

Return the sum of the ages of all machines in the underlying network of `signature`.

See also [`MLJBase.Signature`](@ref).

"""
age(signature::Signature) = mapreduce(age, +, machines(glb(signature)))
"""
    report_supplement(signature)

**Private method.**

Generate a deep copy of the supplementary report defined by the signature (that part of
the composite model report coming from report nodes in the signature). This is a named
tuple.

See also [`MLJBase.Signature`](@ref).

"""
function report_supplement(signature::Signature)
    return call_and_copy(report_nodes(signature))
end
"""
    fitted_params_supplement(signature)

**Private method.**

Generate a deep copy of the supplementary fitted parameters defined by the signature (that
part of the composite model fitted parameters coming from fitted parameter nodes in the
signature). This is a named tuple.

See also [`MLJBase.Signature`](@ref).

"""
function fitted_params_supplement(signature::Signature)
    return call_and_copy(fitted_params_nodes(signature))
end
"""
    report(signature; supplement=true)

**Private method.**

Generate a report for the learning network associated with `signature`, including the
supplementary report.

Suppress calling of the report nodes of `signature`, and exclude their contribution to
the output, by specifying `supplement=false`.

See also [`MLJBase.report_supplement`](@ref).

See also [`MLJBase.Signature`](@ref).

"""
function report(signature::Signature; supplement=true)
    lower_bound = glb(signature)
    extra = supplement ? MLJBase.report_supplement(signature) : NamedTuple()
    machine_dict = MLJBase.machines_given_model(lower_bound)
    internal = MLJBase.tuple_keyed_on_model(report, machine_dict)
    return merge(internal, extra)
end
"""
    fitted_params(signature; supplement=true)

**Private method.**

Generate a fitted_params for the learning network associated with `signature`, including
the supplementary fitted_params.

Suppress calling of the fitted_params nodes of `signature`, and exclude their
contribution to the output, by specifying `supplement=false`.

See also [`MLJBase.fitted_params_supplement`](@ref).

See also [`MLJBase.Signature`](@ref).

"""
function fitted_params(signature::Signature; supplement=true)
    lower_bound = glb(signature)
    extra = supplement ? MLJBase.fitted_params_supplement(signature) : NamedTuple()
    machine_dict = MLJBase.machines_given_model(lower_bound)
    internal = MLJBase.tuple_keyed_on_model(fitted_params, machine_dict)
    return merge(internal, extra)
end
"""
    output_and_report(signature, operation, Xnew...)

**Private method.**

Duplicate `signature` and return appropriate output for the specified `operation` (a key
of `signature`) applied to the duplicate, together with the operational report. Report
nodes of `signature` are not called, and they make no contribution to that report.

Return value has the form `(output, report)`.

See also [`MLJBase.Signature`](@ref).

"""
function output_and_report(signature, operation, Xnew)
    # shallow duplication; sources are shared, not deep-copied
    clone = replace(signature, copy_unspecified_deeply=false)
    out = getproperty(MLJBase.unwrap(clone), operation)(Xnew)
    rep = MLJBase.report(clone; supplement=false)
    return out, rep
end

# special case for static transformers with multiple inputs:
output_and_report(signature, operation, Xnew...) =
    output_and_report(signature, operation, Xnew)