Skip to content

Commit

Permalink
Format
Browse files Browse the repository at this point in the history
  • Loading branch information
msluszniak committed Jan 25, 2024
1 parent 0325126 commit 4c12ac0
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 5 deletions.
14 changes: 9 additions & 5 deletions lib/scholar/neighbors/nn_descent.ex
Original file line number Diff line number Diff line change
Expand Up @@ -102,13 +102,16 @@ defmodule Scholar.Neighbors.NNDescent do
}
"""
deftransform fit(tensor, opts \\ []) do
opts = if Keyword.has_key?(opts, :max_candidates) do
opts
else
Keyword.put(opts, :max_candidates, min(60, opts[:num_neighbors]))
end
opts =
if Keyword.has_key?(opts, :max_candidates) do
opts
else
Keyword.put(opts, :max_candidates, min(60, opts[:num_neighbors]))
end

opts = NimbleOptions.validate!(opts, @opts_schema)
sum_samples = Nx.axis_size(tensor, 0)

if opts[:num_neighbors] > sum_samples do
raise ArgumentError,
"""
Expand All @@ -117,6 +120,7 @@ defmodule Scholar.Neighbors.NNDescent do
#{sum_samples}
"""
end

if opts[:max_candidates] > sum_samples do
raise ArgumentError,
"""
Expand Down
1 change: 1 addition & 0 deletions test/scholar/neighbors/nn_descent_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ defmodule Scholar.Neighbors.NNDescentTest do
key = Nx.Random.key(12)
{tensor, key} = Nx.Random.uniform(key, shape: {10, 5})
size = Nx.axis_size(tensor, 0)

%NNDescent{nearest_neighbors: nearest_neighbors, distances: distances} =
NNDescent.fit(tensor,
num_neighbors: 1,
Expand Down

0 comments on commit 4c12ac0

Please sign in to comment.