diff --git a/test/scholar/neighbors/brute_knn_test.exs b/test/scholar/neighbors/brute_knn_test.exs index 052f02bb..1cac3d35 100644 --- a/test/scholar/neighbors/brute_knn_test.exs +++ b/test/scholar/neighbors/brute_knn_test.exs @@ -119,5 +119,10 @@ defmodule Scholar.Neighbors.BruteKNNTest do assert distances_pred == distances_true end + + test "custom metric" do + model = BruteKNN.fit(data(), num_neighbors: 3, batch_size: 1, metric: :cosine) + assert {_, _} = BruteKNN.predict(model, query()) + end end end