Skip to content

Commit

Permalink
Add roc_auc_score + tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Chuvalniy committed Feb 13, 2024
1 parent 90cf485 commit b3c76a6
Show file tree
Hide file tree
Showing 4 changed files with 71 additions and 1 deletion.
Binary file modified src/metrics/__pycache__/classification.cpython-311.pyc
Binary file not shown.
19 changes: 18 additions & 1 deletion src/metrics/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,6 @@ def f1_score(y_true: np.ndarray, y_pred: np.ndarray) -> float:
def roc_curve(y_true: np.ndarray, y_pred: np.ndarray) -> (np.ndarray, np.ndarray):
"""
Calculate ROC curve values.
:param y_true: Target labels (n_samples, ).
:param y_pred: Predictions probability (n_samples, ).
:return:
Expand Down Expand Up @@ -130,3 +129,21 @@ def roc_curve(y_true: np.ndarray, y_pred: np.ndarray) -> (np.ndarray, np.ndarray
fpr[i] = fp / (tn + fp)

return tpr, fpr, thresholds


def roc_auc_score(y_true: np.ndarray, y_pred: np.ndarray) -> float:
"""
Calculate ROC-AUC score using trapezoidal rule.
:param y_true: Target labels (n_samples, ).
:param y_pred: Predictions probability (n_samples, ).
:return: ROC-AUC score.
"""

tpr, fpr, _ = roc_curve(y_true, y_pred)

shifted_tpr = np.roll(tpr, shift=1)
shifted_tpr[0] = 0

height = np.diff(fpr, n=1, prepend=0)
auc = np.sum((tpr + shifted_tpr) / 2 * height)
return auc.item()
Binary file not shown.
53 changes: 53 additions & 0 deletions tests/metrics/classifciation/test_roc_auc_score.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import numpy as np

from src.metrics import roc_auc_score


def test_roc_auc_identical_labels():
y_true = np.array([1, 0, 1, 0, 1])
y_pred = np.array([1, 0, 1, 0, 1])

expected_roc_auc = 1.0
roc_auc = roc_auc_score(y_true, y_pred)

assert np.isclose(expected_roc_auc, roc_auc, atol=1e-5, rtol=1e-5)


def test_roc_auc_reversed_labels():
y_true = np.array([1, 0, 1, 0, 1])
y_pred = np.array([0, 1, 0, 1, 0])

expected_roc_auc = 0.0
roc_auc = roc_auc_score(y_true, y_pred)

assert np.isclose(expected_roc_auc, roc_auc, atol=1e-5, rtol=1e-5)


def test_roc_auc_all_true():
y_true = np.array([1, 0, 0, 0, 1])
y_pred = np.array([1, 1, 1, 1, 1])

expected_roc_auc = 0.5
roc_auc = roc_auc_score(y_true, y_pred)

np.isclose(expected_roc_auc, roc_auc, atol=1e-5, rtol=1e-5)


def test_roc_auc_equal_prob():
y_true = np.array([1, 0, 0, 0, 1])
y_pred = np.array([0.5, 0.5, 0.5, 0.5, 0.5])

expected_roc_auc = 0.5
roc_auc = roc_auc_score(y_true, y_pred)

np.isclose(expected_roc_auc, roc_auc, atol=1e-5, rtol=1e-5)


def test_roc_auc_close_to_target():
y_true = np.array([1, 0, 0, 0, 1])
y_pred = np.array([0.9, 0.1, 0.1, 0.1, 0.9])

expected_roc_auc = 1.0
roc_auc = roc_auc_score(y_true, y_pred)

np.isclose(expected_roc_auc, roc_auc, atol=1e-5, rtol=1e-5)

0 comments on commit b3c76a6

Please sign in to comment.