diff --git a/lib/scholar/neighbors/nn_descent.ex b/lib/scholar/neighbors/nn_descent.ex index 56729b11..c8e3f27b 100644 --- a/lib/scholar/neighbors/nn_descent.ex +++ b/lib/scholar/neighbors/nn_descent.ex @@ -102,13 +102,16 @@ defmodule Scholar.Neighbors.NNDescent do } """ deftransform fit(tensor, opts \\ []) do - opts = if Keyword.has_key?(opts, :max_candidates) do - opts - else - Keyword.put(opts, :max_candidates, min(60, opts[:num_neighbors])) - end + opts = + if Keyword.has_key?(opts, :max_candidates) do + opts + else + Keyword.put(opts, :max_candidates, min(60, opts[:num_neighbors])) + end + opts = NimbleOptions.validate!(opts, @opts_schema) sum_samples = Nx.axis_size(tensor, 0) + if opts[:num_neighbors] > sum_samples do raise ArgumentError, """ @@ -117,6 +120,7 @@ defmodule Scholar.Neighbors.NNDescent do #{sum_samples} """ end + if opts[:max_candidates] > sum_samples do raise ArgumentError, """ diff --git a/test/scholar/neighbors/nn_descent_test.exs b/test/scholar/neighbors/nn_descent_test.exs index a4887e97..eeeea48c 100644 --- a/test/scholar/neighbors/nn_descent_test.exs +++ b/test/scholar/neighbors/nn_descent_test.exs @@ -7,6 +7,7 @@ defmodule Scholar.Neighbors.NNDescentTest do key = Nx.Random.key(12) {tensor, key} = Nx.Random.uniform(key, shape: {10, 5}) size = Nx.axis_size(tensor, 0) + %NNDescent{nearest_neighbors: nearest_neighbors, distances: distances} = NNDescent.fit(tensor, num_neighbors: 1,