Skip to content

Commit

Permalink
move recall completely within defn
Browse files Browse the repository at this point in the history
  • Loading branch information
Krsto Proroković committed Jan 19, 2024
1 parent e6917aa commit 6b7ccf8
Showing 1 changed file with 17 additions and 13 deletions.
30 changes: 17 additions & 13 deletions lib/scholar/metrics/neighbors.ex
Original file line number Diff line number Diff line change
@@ -1,37 +1,41 @@
defmodule Scholar.Metrics.Neighbors do
import Nx.Defn

deftransform recall(graph_true, graph_pred) do
if Nx.rank(graph_true) != 2 do
defn recall(neighbors_true, neighbors_pred) do
if Nx.rank(neighbors_true) != 2 do
raise ArgumentError,
"""
expected true neighbors to have shape {num_samples, num_neighbors}, \
got tensor with shape: #{inspect(Nx.shape(graph_true))}\
got tensor with shape: #{inspect(Nx.shape(neighbors_true))}\
"""
end

if Nx.rank(graph_pred) != 2 do
if Nx.rank(neighbors_pred) != 2 do
raise ArgumentError,
"""
expected predicted neighbors to have shape {num_samples, num_neighbors}, \
got tensor with shape: #{inspect(Nx.shape(graph_pred))}\
got tensor with shape: #{inspect(Nx.shape(neighbors_pred))}\
"""
end

if Nx.shape(graph_true) != Nx.shape(graph_pred) do
if Nx.axis_size(neighbors_true, 0) != Nx.axis_size(neighbors_pred, 0) do
raise ArgumentError,
"""
expected true and predicted neighbors to have the same shape, \
got #{inspect(Nx.shape(graph_true))} and #{inspect(Nx.shape(graph_pred))}\
expected true and predicted neighbors to have the same axis 0 size, \
got #{inspect(Nx.axis_size(neighbors_true, 0))} and #{inspect(Nx.axis_size(neighbors_pred, 0))}\
"""
end

recall_n(graph_true, graph_pred)
end
if Nx.axis_size(neighbors_true, 1) != Nx.axis_size(neighbors_pred, 1) do
raise ArgumentError,
"""
expected true and predicted neighbors to have the same axis 1 size, \
got #{inspect(Nx.axis_size(neighbors_true, 1))} and #{inspect(Nx.axis_size(neighbors_pred, 1))}\
"""
end

defn recall_n(graph_true, graph_pred) do
{n, k} = Nx.shape(graph_true)
concatenated = Nx.concatenate([graph_true, graph_pred], axis: 1) |> Nx.sort(axis: 1)
{n, k} = Nx.shape(neighbors_true)
concatenated = Nx.concatenate([neighbors_true, neighbors_pred], axis: 1) |> Nx.sort(axis: 1)
duplicate_mask = concatenated[[.., 0..(2 * k - 2)]] == concatenated[[.., 1..(2 * k - 1)]]
duplicate_mask |> Nx.sum() |> Nx.divide(n * k)
end
Expand Down

0 comments on commit 6b7ccf8

Please sign in to comment.