Skip to content

Commit

Permalink
remove Nx.Type.merge/2
Browse files Browse the repository at this point in the history
  • Loading branch information
Krsto Proroković committed Apr 7, 2024
1 parent d8abd6e commit 29f984f
Showing 1 changed file with 7 additions and 7 deletions.
14 changes: 7 additions & 7 deletions lib/scholar/naive_bayes/multinomial.ex
Original file line number Diff line number Diff line change
Expand Up @@ -179,10 +179,10 @@ defmodule Scholar.NaiveBayes.Multinomial do

opts = NimbleOptions.validate!(opts, @opts_schema)

x_type = to_float_type(x)
type = to_float_type(x)

{alpha, opts} = Keyword.pop!(opts, :alpha)
alpha = Nx.tensor(alpha, type: x_type)
alpha = Nx.tensor(alpha, type: type)

if Nx.shape(alpha) not in [{}, {num_features}] do
raise ArgumentError,
Expand Down Expand Up @@ -210,7 +210,7 @@ defmodule Scholar.NaiveBayes.Multinomial do
sample_weights_flag = opts[:sample_weights] != nil

{sample_weights, opts} = Keyword.pop(opts, :sample_weights, :nan)
sample_weights = Nx.tensor(sample_weights, type: x_type)
sample_weights = Nx.tensor(sample_weights, type: type)

if sample_weights_flag and Nx.shape(sample_weights) != {num_samples} do
raise ArgumentError,
Expand All @@ -223,6 +223,7 @@ defmodule Scholar.NaiveBayes.Multinomial do
opts =
opts ++
[
type: type,
priors_flag: priors_flag,
sample_weights_flag: sample_weights_flag
]
Expand All @@ -231,23 +232,22 @@ defmodule Scholar.NaiveBayes.Multinomial do
end

defnp fit_n(x, y, class_priors, sample_weights, alpha, opts) do
# TODO: Why not just to_float_type?
x_type = Nx.Type.merge(to_float_type(x), {:f, 32})
type = opts[:type]
num_samples = Nx.axis_size(x, 0)
num_classes = opts[:num_classes]

y_one_hot =
y
|> Nx.new_axis(1)
|> Nx.equal(Nx.iota({num_samples, num_classes}, axis: 1))
|> Nx.as_type(x_type)
|> Nx.as_type(type)

y_weighted =
if opts[:sample_weights_flag],
do: Nx.reshape(sample_weights, {num_samples, 1}) * y_one_hot,
else: y_one_hot

alpha_lower_bound = Nx.tensor(1.0e-10, type: x_type)
alpha_lower_bound = Nx.tensor(1.0e-10, type: type)

alpha =
if opts[:force_alpha], do: alpha, else: Nx.max(alpha, alpha_lower_bound)
Expand Down

0 comments on commit 29f984f

Please sign in to comment.