Documentation updates #1077

Merged
merged 3 commits into from
Jan 9, 2024
42 changes: 26 additions & 16 deletions docs/src/adding_models_for_general_use.md
@@ -1,4 +1,4 @@
# Adding Models for General Use
# Adding Models for General Use

!!! note

@@ -975,7 +975,7 @@ also appears in the EvoTrees.jl package.
Here "user-supplied data" is what the MLJ user supplies when
constructing a machine, as in `machine(models, args...)`, which
coincides with the arguments expected by `fit(model, verbosity,
args...)` when `reformat` is not overloaded.
args...)` when `reformat` is not overloaded.

Overloading `reformat` is permitted for any `Model`
subtype, except for subtypes of `Static`. Here is a complete list of
@@ -992,7 +992,7 @@ responsibilities for such an implementation, for some
serving as a data front-end for operations like `predict`. It must
always hold that `reformat(model, args...)[1] = reformat(model,
args[1])`.

The fallback is `reformat(model, args...) = args` (i.e., slurps provided data).

*Important.* `reformat(model::SomeModelType, args...)` must always return a tuple, even if
@@ -1204,7 +1204,7 @@ Your document string must include the following components, in order:
- A closing *"See also"* sentence which includes a `@ref` link to the raw model type (if you are wrapping one).


## Unsupervised models
## Unsupervised

Unsupervised models implement the MLJ model interface in a very
similar fashion. The main differences are:
@@ -1214,28 +1214,30 @@ similar fashion. The main differences are:
although this is not a hard requirement. For example, a feature selection tool (wrapping
some supervised model) might also include a target `y` as input. Furthermore, in the
case of models that subtype `Static <: Unsupervised` (see also [Static
transformers](@ref) `fit` has no training arguments at all, but does not need to be
transformers](@ref)) `fit` has no training arguments at all, but does not need to be
implemented as a fallback returns `(nothing, nothing, nothing)`.

- A `transform` and/or `predict` method is implemented, and has the same signature as
`predict` does in the supervised case, as in `MLJModelInterface.transform(model,
fitresult, Xnew)`. However, it may only have one data argument `Xnew`, unless `model <:
Static`, in which case there is no restriction. A use-case for `predict` is K-means
`MLJModelInterface.predict(model, fitresult, Xnew)`. A use-case is
clustering that `predict`s labels and `transform`s
input features into a space of lower dimension. See [Transformers
that also predict](@ref) for an example.
Static`, in which case there is no restriction. A use-case for `predict` is K-means
clustering that `predict`s labels and `transform`s input features into a space of lower
dimension. See [Transformers that also predict](@ref) for an example.

- The `target_scitype` trait continues to refer to the output of `predict`, if
implemented, while a trait, `output_scitype`, is for the output of `transform`.
- The `target_scitype` refers to the output of `predict`, if implemented. A new trait,
`output_scitype`, is for the output of `transform`. Unless the model is `Static` (see
below) the trait `input_scitype` is for the single data argument of `transform` (and
`predict`, if implemented). If `fit` has more than one data argument, you must overload
the trait `fit_data_scitype`, which bounds the allowed `data` passed to `fit(model,
verbosity, data...)` and will always be a `Tuple` type.

- An `inverse_transform` can be optionally implemented. The signature
is the same as `transform`, as in
`MLJModelInterface.inverse_transform(model, fitresult, Xout)`, which:
- must make sense for any `Xout` for which `scitype(Xout) <:
output_scitype(SomeSupervisedModel)` (see below); and
output_scitype(SomeSupervisedModel)` (see below); and
- must return an object `Xin` satisfying `scitype(Xin) <:
input_scitype(SomeSupervisedModel)`.
input_scitype(SomeSupervisedModel)`.
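
As a rough illustration of the points above, here is a minimal sketch of an unsupervised transformer with `fit`, `transform` and `inverse_transform`. The model type `Centerer` is hypothetical (it is not part of MLJ), and the trait declarations assume the scitype names are reachable through `MLJModelInterface`, as is usual in interface packages.

```julia
import MLJModelInterface
const MMI = MLJModelInterface

# Hypothetical model for illustration only: subtracts the training mean.
mutable struct Centerer <: MMI.Unsupervised end

# `fit` takes a single training argument `X`; there is no target.
function MMI.fit(::Centerer, verbosity, X::AbstractVector{<:Real})
    fitresult = sum(X) / length(X)   # the learned mean
    cache = nothing
    report = nothing
    return fitresult, cache, report
end

# Same signature as `predict` in the supervised case: one data argument.
MMI.transform(::Centerer, fitresult, Xnew) = Xnew .- fitresult

# Optional inverse, mapping transformed output back to the input space.
MMI.inverse_transform(::Centerer, fitresult, Xout) = Xout .+ fitresult

# `input_scitype` is for the data argument of `transform`;
# `output_scitype` is for the output of `transform`.
MMI.input_scitype(::Type{<:Centerer}) = AbstractVector{MMI.Continuous}
MMI.output_scitype(::Type{<:Centerer}) = AbstractVector{MMI.Continuous}

fitresult, _, _ = MMI.fit(Centerer(), 0, [1.0, 2.0, 3.0])
Xout = MMI.transform(Centerer(), fitresult, [1.0, 2.0, 3.0])
Xin = MMI.inverse_transform(Centerer(), fitresult, Xout)
```

Here `Xin` recovers the original input, which is the round-trip behaviour one would expect of an `inverse_transform`.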

For sample implementations, see MLJ's [built-in
transformers](https://github.com/JuliaAI/MLJModels.jl/blob/dev/src/builtins/Transformers.jl)
@@ -1245,8 +1247,16 @@ and the clustering models at

## Static models (models that do not generalize)

See [Static transformers](@ref) for basic implementation of models that do not generalize
to new data but do have hyperparameters.
A model type subtypes `Static <: Unsupervised` if it does not generalize to new data but
nevertheless has hyperparameters. See [Static transformers](@ref) for examples. In the
`Static` case, `transform` can have multiple arguments and `input_scitype` refers to the
allowed scitype of the slurped data, *even if there is only a single argument.* For
example, if the signature is `transform(static_model, X1, X2)`, then the allowed
`input_scitype` might be `Tuple{Table(Continuous), Table(Continuous)}`; if the signature
is `transform(static_model, X)`, the allowed `input_scitype` might be
`Tuple{Table(Continuous)}`. The other traits are as for regular `Unsupervised` models, as
described above.
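
The following is a minimal sketch of the `Static` case just described, not a definitive implementation. The model type `Averager` and its hyperparameter `mix` are chosen here for illustration, and vector scitypes are used in place of tables to keep the example self-contained. Note that `input_scitype` is declared as a `Tuple` type matching the slurped data arguments of `transform`.

```julia
import MLJModelInterface
const MMI = MLJModelInterface

# Illustrative static model: nothing is learned, but there is a hyperparameter.
mutable struct Averager <: MMI.Static
    mix::Float64
end

# No `fit` is needed (a fallback exists). For `Static` models `transform`
# may take several data arguments; the second positional argument is the
# (unused) fitresult.
MMI.transform(a::Averager, _, y1, y2) = (1 - a.mix) .* y1 .+ a.mix .* y2

# One tuple entry per data argument of `transform`.
MMI.input_scitype(::Type{<:Averager}) =
    Tuple{AbstractVector{MMI.Continuous}, AbstractVector{MMI.Continuous}}
MMI.output_scitype(::Type{<:Averager}) = AbstractVector{MMI.Continuous}

averaged = MMI.transform(Averager(0.5), nothing, [1.0, 3.0], [3.0, 5.0])
```

With a single-argument signature such as `transform(static_model, X)`, the same pattern would apply with a one-element `Tuple` type.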


### Reporting byproducts of a static transformation

2 changes: 2 additions & 0 deletions src/MLJ.jl
@@ -26,6 +26,8 @@ loaded separately:
- MLJBalancing.jl: Incorporation of oversampling/undersampling methods in pipelines, via
the `BalancedModel` wrapper

- MLJFlow.jl: Integration with MLflow workflow tracking

- OpenML.jl: Tool for grabbing datasets from OpenML.org

"""