Skip to content

Commit

Permalink
mix format
Browse files Browse the repository at this point in the history
  • Loading branch information
Krsto Proroković committed Jul 31, 2024
1 parent cf81723 commit 32b98fd
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 3 deletions.
4 changes: 3 additions & 1 deletion lib/scholar/linear/logistic_regression.ex
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,7 @@ defmodule Scholar.Linear.LogisticRegression do
iterations = opts[:iterations]
num_classes = opts[:num_classes]
optimizer_update_fn = opts[:optimizer_update_fn]

y_one_hot =
y
|> Nx.new_axis(1)
Expand All @@ -173,7 +174,8 @@ defmodule Scholar.Linear.LogisticRegression do
has_converged = Nx.sum(Nx.abs(loss)) < Nx.size(x) * opts[:eps]

{{coef, bias},
{x, iterations, y_one_hot, coef_optimizer_state, bias_optimizer_state, has_converged, iter + 1}}
{x, iterations, y_one_hot, coef_optimizer_state, bias_optimizer_state, has_converged,
iter + 1}}
end

%__MODULE__{
Expand Down
22 changes: 20 additions & 2 deletions lib/scholar/preprocessing/one_hot_encoder.ex
Original file line number Diff line number Diff line change
Expand Up @@ -101,16 +101,19 @@ defmodule Scholar.Preprocessing.OneHotEncoder do
defn transform(%__MODULE__{ordinal_encoder: ordinal_encoder}, tensor) do
num_categories = Nx.size(ordinal_encoder.categories)
num_samples = Nx.size(tensor)

encoded =
ordinal_encoder
|> Scholar.Preprocessing.OrdinalEncoder.transform(tensor)
|> Nx.new_axis(1)
|> Nx.broadcast({num_samples, num_categories})

encoded == Nx.iota({num_samples, num_categories}, axis: 1)
end

@doc """
Apply encoding on the provided tensor directly. It's equivalent to `fit/2` and then `transform/2` on the same data.
Appl
encoding on the provided tensor directly. It's equivalent to `fit/2` and then `transform/2` on the same data.
## Examples
Expand All @@ -129,14 +132,29 @@ defmodule Scholar.Preprocessing.OneHotEncoder do
]
>
"""
defn fit_transform(tensor, opts) do
deftransform fit_transform(tensor, opts) do
if Nx.rank(tensor) != 1 do
raise ArgumentError,
"""
expected input tensor to have shape {num_samples}, \
got tensor with shape: #{inspect(Nx.shape(tensor))}
"""
end

opts = NimbleOptions.validate!(opts, @encode_schema)
fit_transform_n(tensor, opts)
end

defnp fit_transform_n(tensor, opts) do
num_samples = Nx.size(tensor)
num_categories = opts[:num_categories]

encoded =
tensor
|> Scholar.Preprocessing.OrdinalEncoder.fit_transform()
|> Nx.new_axis(1)
|> Nx.broadcast({num_samples, num_categories})

encoded == Nx.iota({num_samples, num_categories}, axis: 1)
end
end

0 comments on commit 32b98fd

Please sign in to comment.