diff --git a/lib/scholar/metrics/neighbors.ex b/lib/scholar/metrics/neighbors.ex index ee0841cb..bea1d90f 100644 --- a/lib/scholar/metrics/neighbors.ex +++ b/lib/scholar/metrics/neighbors.ex @@ -1,37 +1,41 @@ defmodule Scholar.Metrics.Neighbors do import Nx.Defn - deftransform recall(graph_true, graph_pred) do - if Nx.rank(graph_true) != 2 do + defn recall(neighbors_true, neighbors_pred) do + if Nx.rank(neighbors_true) != 2 do raise ArgumentError, """ expected true neighbors to have shape {num_samples, num_neighbors}, \ - got tensor with shape: #{inspect(Nx.shape(graph_true))}\ + got tensor with shape: #{inspect(Nx.shape(neighbors_true))}\ """ end - if Nx.rank(graph_pred) != 2 do + if Nx.rank(neighbors_pred) != 2 do raise ArgumentError, """ expected predicted neighbors to have shape {num_samples, num_neighbors}, \ - got tensor with shape: #{inspect(Nx.shape(graph_pred))}\ + got tensor with shape: #{inspect(Nx.shape(neighbors_pred))}\ """ end - if Nx.shape(graph_true) != Nx.shape(graph_pred) do + if Nx.axis_size(neighbors_true, 0) != Nx.axis_size(neighbors_pred, 0) do raise ArgumentError, """ - expected true and predicted neighbors to have the same shape, \ - got #{inspect(Nx.shape(graph_true))} and #{inspect(Nx.shape(graph_pred))}\ + expected true and predicted neighbors to have the same axis 0 size, \ + got #{inspect(Nx.axis_size(neighbors_true, 0))} and #{inspect(Nx.axis_size(neighbors_pred, 0))}\ """ end - recall_n(graph_true, graph_pred) - end + if Nx.axis_size(neighbors_true, 1) != Nx.axis_size(neighbors_pred, 1) do + raise ArgumentError, + """ + expected true and predicted neighbors to have the same axis 1 size, \ + got #{inspect(Nx.axis_size(neighbors_true, 1))} and #{inspect(Nx.axis_size(neighbors_pred, 1))}\ + """ + end - defn recall_n(graph_true, graph_pred) do - {n, k} = Nx.shape(graph_true) - concatenated = Nx.concatenate([graph_true, graph_pred], axis: 1) |> Nx.sort(axis: 1) + {n, k} = Nx.shape(neighbors_true) + concatenated = Nx.concatenate([neighbors_true, neighbors_pred], axis: 1) |> Nx.sort(axis: 1) duplicate_mask = concatenated[[.., 0..(2 * k - 2)]] == concatenated[[.., 1..(2 * k - 1)]] duplicate_mask |> Nx.sum() |> Nx.divide(n * k) end