Skip to content


K-NN Classifier (#263)
Browse files Browse the repository at this point in the history
* Major update, submitting a PR

* Update doc

* Update doc

* Add distance to KDTree.predict/2

* Update doc

* Update doc

* Add metric to RandomProjectionForest and LargeVis, more unit-tests, etc

* Add predict_proba/2

* Rename predict_proba to predict_probability, fix a bug inside of it

* Remove Nx.Type.merge in predict_probability

Co-authored-by: José Valim <>


Co-authored-by: Krsto Proroković <>
Co-authored-by: José Valim <>
  • Loading branch information
3 people authored May 14, 2024
1 parent 96c4e5b commit ffaac87
Show file tree
Hide file tree
Showing 8 changed files with 583 additions and 27 deletions.
2 changes: 1 addition & 1 deletion lib/scholar/neighbors/brute_knn.ex
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ defmodule Scholar.Neighbors.BruteKNN do
type: {:or, [{:custom, Scholar.Options, :metric, []}, {:fun, 2}]},
default: {:minkowski, 2},
doc: ~S"""
The function that measures distance between two points. Possible values:
The function that measures the distance between two points. Possible values:
* `{:minkowski, p}` - Minkowski metric. By changing value of `p` parameter (a positive number or `:infinity`)
we can set Manhattan (`1`), Euclidean (`2`), Chebyshev (`:infinity`), or any arbitrary $L_p$ metric.
Expand Down
15 changes: 12 additions & 3 deletions lib/scholar/neighbors/kd_tree.ex
Original file line number Diff line number Diff line change
Expand Up @@ -318,11 +318,20 @@ defmodule Scholar.Neighbors.KDTree do
deftransform predict(tree, data) do
if Nx.rank(data) != 2 do
raise ArgumentError, "Input data must be a 2D tensor"
raise ArgumentError,
expected query tensor to have shape {num_queries, num_features}, \
got tensor with shape: #{inspect(Nx.shape(data))}

if Nx.axis_size(data, -1) != Nx.axis_size(, -1) do
raise ArgumentError, "Input data must have the same number of features as the training data"
if Nx.axis_size(, 1) != Nx.axis_size(data, 1) do
raise ArgumentError,
expected query tensor to have same number of features as tensor used to fit the tree, \
got #{inspect(Nx.axis_size(data, 1))} \
and #{inspect(Nx.axis_size(, 1))}

predict_n(tree, data)
Expand Down
287 changes: 287 additions & 0 deletions lib/scholar/neighbors/knn_classifier.ex
Original file line number Diff line number Diff line change
@@ -0,0 +1,287 @@
defmodule Scholar.Neighbors.KNNClassifier do
@moduledoc """
K-Nearest Neighbors Classifier.
Performs classification by computing the (weighted) majority voting among k-nearest neighbors.

import Nx.Defn
import Scholar.Shared
require Nx

@derive {Nx.Container, keep: [:num_classes, :weights], containers: [:algorithm, :labels]}
defstruct [:algorithm, :num_classes, :weights, :labels]

opts = [
algorithm: [
type: :atom,
default: :brute,
doc: """
Algorithm used to compute the k-nearest neighbors. Possible values:
* `:brute` - Brute-force search. See `Scholar.Neighbors.BruteKNN` for more details.
* `:kd_tree` - k-d tree. See `Scholar.Neighbors.KDTree` for more details.
* `:random_projection_forest` - Random projection forest. See `Scholar.Neighbors.RandomProjectionForest` for more details.
* Module implementing `fit(data, opts)` and `predict(model, query)`. predict/2 must return tuple containing indices
of k-nearest neighbors of query points as well as distances between query points and their k-nearest neighbors.
num_classes: [
required: true,
type: :pos_integer,
doc: "The number of possible classes."
weights: [
type: {:in, [:uniform, :distance]},
default: :uniform,
doc: """
Weight function used in prediction. Possible values:
* `:uniform` - uniform weights. All points in each neighborhood are weighted equally.
* `:distance` - weight points by the inverse of their distance. In this case, closer neighbors of
a query point will have a greater influence than neighbors which are further away.


@doc """
Fits a k-NN classifier model.
## Options
Algorithm-specific options (e.g. `:num_neighbors`, `:metric`) should be provided together with the classifier options.
## Examples
iex> x = Nx.tensor([[1, 2], [2, 3], [3, 4], [4, 5], [5, 6]])
iex> y = Nx.tensor([0, 0, 0, 1, 1])
iex> model = Scholar.Neighbors.KNNClassifier.fit(x, y, num_neighbors: 3, num_classes: 2)
iex> model.algorithm
Scholar.Neighbors.BruteKNN.fit(x, num_neighbors: 3)
iex> model.labels
Nx.tensor([0, 0, 0, 1, 1])
iex> x = Nx.tensor([[1, 2], [2, 3], [3, 4], [4, 5], [5, 6]])
iex> y = Nx.tensor([0, 0, 0, 1, 1])
iex> model = Scholar.Neighbors.KNNClassifier.fit(x, y, algorithm: :kd_tree, num_neighbors: 3, metric: {:minkowski, 1}, num_classes: 2)
iex> model.algorithm
Scholar.Neighbors.KDTree.fit(x, num_neighbors: 3, metric: {:minkowski, 1})
iex> model.labels
Nx.tensor([0, 0, 0, 1, 1])
iex> x = Nx.tensor([[1, 2], [2, 3], [3, 4], [4, 5], [5, 6]])
iex> y = Nx.tensor([0, 0, 0, 1, 1])
iex> key = Nx.Random.key(12)
iex> model = Scholar.Neighbors.KNNClassifier.fit(x, y, algorithm: :random_projection_forest, num_neighbors: 2, num_classes: 2, num_trees: 4, key: key)
iex> model.algorithm
Scholar.Neighbors.RandomProjectionForest.fit(x, num_neighbors: 2, num_trees: 4, key: key)
iex> model.labels
Nx.tensor([0, 0, 0, 1, 1])
deftransform fit(x, y, opts) do
# Fits the classifier: validates the shapes of `x` and `y`, separates the
# classifier options from the algorithm options, fits the selected
# neighbor-search algorithm and keeps it together with the labels `y`.
#
# Raises ArgumentError when `x` is not rank 2, when `y` is not rank 1, or
# when their first dimensions differ.
if Nx.rank(x) != 2 do
raise ArgumentError,
expected x to have shape {num_samples, num_features}, \
got tensor with shape: #{inspect(Nx.shape(x))}

if Nx.rank(y) != 1 do
raise ArgumentError,
expected y to have shape {num_samples}, \
got tensor with shape: #{inspect(Nx.shape(y))}

if Nx.axis_size(x, 0) != Nx.axis_size(y, 0) do
raise ArgumentError,
expected x and y to have the same first dimension, \
got #{Nx.axis_size(x, 0)} and #{Nx.axis_size(y, 0)}

# Only :algorithm, :num_classes and :weights are classifier options and are
# validated against the schema; every other key ends up in `algorithm_opts`.
{opts, algorithm_opts} = Keyword.split(opts, [:algorithm, :num_classes, :weights])
opts = NimbleOptions.validate!(opts, @opts_schema)

# NOTE(review): the clause bodies below and the callee on the `algorithm =`
# line are missing in this view. Presumably each atom resolves to the matching
# Scholar.Neighbors module, whose fit is then called with `algorithm_opts` —
# confirm against the repository.
algorithm_module =
case opts[:algorithm] do
:brute ->

:kd_tree ->

:random_projection_forest ->

module when is_atom(module) ->

algorithm =, algorithm_opts)

# Fields of the resulting model; they match the struct keys declared for this module.
algorithm: algorithm,
num_classes: opts[:num_classes],
labels: y,
weights: opts[:weights]

@doc """
Predicts classes using a k-NN classifier model.
## Examples
iex> x_train = Nx.tensor([[1, 2], [2, 3], [3, 4], [4, 5], [5, 6]])
iex> y_train = Nx.tensor([0, 0, 0, 1, 1])
iex> model = Scholar.Neighbors.KNNClassifier.fit(x_train, y_train, num_neighbors: 3, num_classes: 2)
iex> x = Nx.tensor([[1, 3], [4, 2], [3, 6]])
iex> Scholar.Neighbors.KNNClassifier.predict(model, x)
Nx.tensor([0, 0, 1])
defn predict(model, x) do
# Predicts one class per query row of `x` by voting over the labels of the
# neighbors returned by the fitted algorithm.
{neighbors, distances} = compute_knn(model.algorithm, x)
# Replace each neighbor index with the label stored for that training sample.
neighbor_labels = Nx.take(model.labels, neighbors)

# :uniform takes the plain per-row mode of the neighbor labels;
# :distance weights each vote with check_weights(distances), which inverts the distances.
case model.weights do
:uniform -> Nx.mode(neighbor_labels, axis: 1)
:distance -> weighted_mode(neighbor_labels, check_weights(distances))

@doc """
Predicts class probabilities using a k-NN classifier model.
## Examples
iex> x_train = Nx.tensor([[1, 2], [2, 3], [3, 4], [4, 5], [5, 6]])
iex> y_train = Nx.tensor([0, 0, 0, 1, 1])
iex> model = Scholar.Neighbors.KNNClassifier.fit(x_train, y_train, num_neighbors: 3, num_classes: 2)
iex> x = Nx.tensor([[1, 3], [4, 2], [3, 6]])
iex> Scholar.Neighbors.KNNClassifier.predict_probability(model, x)
Nx.tensor(
[
[1.0, 0.0],
[1.0, 0.0],
[0.3333333432674408, 0.6666666865348816]
]
)
defn predict_probability(model, x) do
# Computes a {num_samples, num_classes} tensor of per-class scores for the
# query rows of `x`, each row divided by its own sum.
num_samples = Nx.axis_size(x, 0)
type = to_float_type(x)
{neighbors, distances} = compute_knn(model.algorithm, x)
neighbor_labels = Nx.take(model.labels, neighbors)
# Zero-filled accumulator in the float type derived from `x`.
proba = Nx.broadcast(Nx.tensor(0.0, type: type), {num_samples, model.num_classes})

# One weight per neighbor: all ones for :uniform, inverted distances
# (via check_weights/1) for :distance.
weights =
case model.weights do
:uniform -> Nx.broadcast(1.0, neighbors)
:distance -> check_weights(distances)

# NOTE(review): the opening of this expression is missing in this view. What
# remains combines a row-index iota with `neighbor_labels` along axis 2 and
# flattens the first two axes — presumably yielding one {row, class} pair per
# neighbor for the indexed_add below; confirm against the repository.
indices =
[Nx.iota(Nx.shape(neighbor_labels), axis: 0), neighbor_labels],
axis: 2
|> Nx.flatten(axes: [0, 1])

# Add each neighbor's weight into `proba` at the position given by `indices`.
proba = Nx.indexed_add(proba, indices, Nx.flatten(weights))
normalizer = Nx.sum(proba, axes: [1])
# NOTE(review): the next line is truncated in this view; it looks like a
# select that swaps a zero normalizer for 1 to avoid dividing by zero — confirm.
normalizer = == 0, 1, normalizer)
proba / Nx.new_axis(normalizer, 1)

deftransformp compute_knn(algorithm, x) do
# Dispatches to `predict/2` of whichever module defines the `algorithm`
# struct, so the classifier does not need to know the concrete algorithm.
# `predict/2` and `predict_probability/2` destructure the result as
# `{neighbors, distances}`.
algorithm.__struct__.predict(algorithm, x)

defnp check_weights(weights) do
# Turns neighbor distances into vote weights by inverting them (1 / weights),
# with special handling for entries that are exactly zero.
# `zero_mask` marks the zero entries; `zero_rows` marks every position of a
# row that contains at least one zero, broadcast to the shape of `weights`.
zero_mask = weights == 0
zero_rows = zero_mask |> Nx.any(axes: [1], keep_axes: true) |> Nx.broadcast(weights)
# NOTE(review): the two lines below are truncated in this view. They appear to
# replace zeros with 1 before the division and then, for rows containing a
# zero, pick 1 for the zero entries and 0 for the rest — confirm against the repository.
weights =, 1, weights)
weights_inv = 1 / weights,, 1, 0), weights_inv)

defnp weighted_mode(tensor, weights) do
# Weighted mode of `tensor` along axis 1, with `weights` supplying the vote
# weights (the caller passes check_weights(distances)).
tensor_size = Nx.size(tensor)

cond do
# A tensor holding a single element is returned with axis 1 squeezed away;
# the weights are not consulted.
tensor_size == 1 ->
Nx.squeeze(tensor, axes: [1])

# Every other size goes through the general sort-and-group implementation.
true ->
weighted_mode_general(tensor, weights)

defnp weighted_mode_general(tensor, weights) do
# General weighted mode along axis 1: sort each row, assign a group id to its
# sorted entries, sum the weights per group and return the value of the
# heaviest group for every row.
{num_samples, num_features} = tensor_shape = Nx.shape(tensor)

# Sort every row and keep the permutation so the weights can be reordered to match.
indices = Nx.argsort(tensor, axis: 1)

sorted = Nx.take_along_axis(tensor, indices, axis: 1)

size_to_broadcast = {num_samples, 1}

# NOTE(review): the wrapping calls of this expression are missing in this view.
# It starts each row with a 0 column, combines it with the two shifted slices
# of `sorted` (adjacent elements) and takes a cumulative sum — presumably a
# running group id that increases whenever the sorted value changes; confirm.
group_indices =
Nx.broadcast(0, size_to_broadcast),
Nx.slice_along_axis(sorted, 0, Nx.axis_size(sorted, 1) - 1, axis: 1),
Nx.slice_along_axis(sorted, 1, Nx.axis_size(sorted, 1) - 1, axis: 1)
axis: 1
|> Nx.cumulative_sum(axis: 1)

num_elements = Nx.size(tensor_shape)

# NOTE(review): the head of this expression is missing in this view. The
# remaining pieces concatenate a row-index column with the flattened group ids,
# i.e. one {row, group} pair per element — confirm against the repository.
counting_indices =
|> Nx.iota(axis: 0)
|> Nx.reshape({num_elements, 1}),
Nx.reshape(group_indices, {num_elements, 1})
|> Nx.concatenate(axis: 1)

to_add = Nx.flatten(weights)

# Turn the per-row sort positions into offsets into the flattened weights
# (row * num_features + column) so the weights follow the sorted order.
indices =
(indices + num_features * Nx.iota(tensor_shape, axis: 0))
|> Nx.flatten()

weights = Nx.take(to_add, indices)

# Sum the weights at each `counting_indices` position, then pick the position
# with the largest total in every row.
largest_group_indices =
Nx.broadcast(0, sorted)
|> Nx.indexed_add(counting_indices, weights)
|> Nx.argmax(axis: 1, keep_axis: true)

# NOTE(review): the first operand of this pipeline is missing in this view
# (presumably `largest_group_indices`). The remaining steps locate the first
# position in each row whose group id equals it.
indices =
|> Nx.broadcast(group_indices)
|> Nx.equal(group_indices)
|> Nx.argmax(axis: 1, keep_axis: true)

# Read the sorted value at that position and drop the kept axis.
res = Nx.take_along_axis(sorted, indices, axis: 1)
Nx.squeeze(res, axes: [1])

0 comments on commit ffaac87

Please sign in to comment.