[ENH] Replace prts metrics #2400

Open
wants to merge 46 commits into base: main
Changes from 9 commits (of 46 commits)
301641d
Pre-commit fixes
aryanpola Nov 26, 2024
c4db216
Merge branch 'aeon-toolkit:main' into recall
aryanpola Nov 26, 2024
cc1101a
Position parameter in calculate_bias
aryanpola Nov 26, 2024
31ed73d
Merge remote-tracking branch 'origin/recall' into recall
aryanpola Nov 26, 2024
1028942
Added recall metric
aryanpola Nov 30, 2024
d4dc5ca
merged into into one file
aryanpola Dec 3, 2024
4db8027
test added
aryanpola Dec 20, 2024
4baaec7
Merge branch 'main' into recall
aryanpola Dec 20, 2024
43cd9ac
Merge branch 'main' into recall
aryanpola Dec 23, 2024
c098731
Changes in test and range_metrics
aryanpola Dec 29, 2024
497362f
list of list running but error!
aryanpola Dec 29, 2024
ab87680
flattening lists, all cases passed
aryanpola Dec 30, 2024
446e058
Merge branch 'main' into recall
aryanpola Dec 30, 2024
c18af4f
Empty-Commit
aryanpola Dec 30, 2024
010d994
Merge remote-tracking branch 'origin/recall' into recall
aryanpola Dec 30, 2024
9c23582
changes
aryanpola Jan 14, 2025
df42934
Protected functions
aryanpola Jan 14, 2025
dfa9046
Merge branch 'main' into recall
aryanpola Jan 14, 2025
b5bfab4
Changes in documentation
aryanpola Jan 15, 2025
576aaae
Merge remote-tracking branch 'origin/recall' into recall
aryanpola Jan 15, 2025
da81823
Changed test cases into seperate functions
aryanpola Jan 15, 2025
f9732eb
test cases added and added range recall
aryanpola Jan 17, 2025
48238f3
udf_gamma removed from precision
aryanpola Jan 17, 2025
0561981
changes
aryanpola Jan 17, 2025
4f4f617
more changes
aryanpola Jan 17, 2025
26b5029
recommended changes
aryanpola Jan 20, 2025
fa60406
changes
aryanpola Jan 20, 2025
c48d426
Added Parameters
aryanpola Jan 23, 2025
b13ba4a
removed udf_gamma from precision
aryanpola Jan 24, 2025
d729b5d
Added binary to range
aryanpola Jan 30, 2025
3be2947
error fixing
aryanpola Jan 30, 2025
678b48b
Merge branch 'main' into recall
MatthewMiddlehurst Jan 31, 2025
843db42
test comparing prts and range_metrics
aryanpola Feb 3, 2025
00da9fb
Merge remote-tracking branch 'origin/recall' into recall
aryanpola Feb 3, 2025
6771051
Beta parameter added in fscore
aryanpola Feb 13, 2025
51d8653
Added udf_gamma function
aryanpola Feb 13, 2025
a5c1514
Merge branch 'main' into recall
aryanpola Feb 19, 2025
6610194
f-score failing when comparing against prts
aryanpola Feb 19, 2025
6fe8c41
Merge remote-tracking branch 'origin/recall' into recall
aryanpola Feb 19, 2025
d9bb9d0
fixed f-score output
aryanpola Feb 21, 2025
a399819
alpha usage
aryanpola Feb 21, 2025
5565b07
Empty-Commit
aryanpola Feb 21, 2025
9f0e7ae
added test case to use range-based input for metrics
aryanpola Feb 21, 2025
694d643
soft dependency added
aryanpola Feb 21, 2025
0487b04
doc update
aryanpola Feb 25, 2025
96b3247
Merge branch 'main' into recall
SebastianSchmidl Feb 26, 2025
8 changes: 8 additions & 0 deletions aeon/benchmarking/metrics/anomaly_detection/__init__.py
@@ -14,6 +14,9 @@
"range_pr_auc_score",
"range_pr_vus_score",
"range_roc_vus_score",
"ts_precision",
"ts_recall",
"ts_fscore",
]

from aeon.benchmarking.metrics.anomaly_detection._binary import (
@@ -35,3 +38,8 @@
range_roc_auc_score,
range_roc_vus_score,
)
from aeon.benchmarking.metrics.anomaly_detection.range_metrics import (
ts_fscore,
ts_precision,
ts_recall,
)
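
With these exports in place, the new range-based metrics can be imported directly from the anomaly detection metrics package. A minimal usage sketch (example ranges follow the PR's tests; outputs assume the default gamma="one" and bias_type="flat"):

from aeon.benchmarking.metrics.anomaly_detection import (
    ts_fscore,
    ts_precision,
    ts_recall,
)

y_pred = [(1, 4)]  # predicted anomalous ranges as inclusive (start, end) tuples
y_real = [(2, 6)]  # labelled anomalous ranges

print(ts_precision(y_pred, y_real))  # 0.75
print(ts_recall(y_pred, y_real))     # 0.6
print(ts_fscore(y_pred, y_real))     # ≈ 0.6667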
281 changes: 281 additions & 0 deletions aeon/benchmarking/metrics/anomaly_detection/range_metrics.py
@@ -0,0 +1,281 @@
"""Calculate Precision, Recall, and F1-Score for time series anomaly detection."""

__all__ = ["ts_precision", "ts_recall", "ts_fscore"]


def _validate_parameters(bias="flat", alpha=0.0, gamma="one"):
    """Validate the shared bias, alpha and gamma parameters."""
    if gamma not in ("reciprocal", "one", "udf_gamma"):
        raise ValueError(f"Invalid gamma type: {gamma}")
    if bias not in ("flat", "front", "middle", "back"):
        raise ValueError(f"Invalid bias type: {bias}")
    return bias, alpha, gamma


def calculate_bias(position, length, bias_type="flat"):
"""Calculate bias value based on position and length.

Parameters
----------
    position : int
        Current position in the range (1-based).
    length : int
        Total length of the range.
    bias_type : str
        Type of bias to apply. Should be one of ["flat", "front", "middle", "back"].
        (default: "flat")

    Returns
    -------
    float
        Bias weight for the given position.
    """
if bias_type == "flat":
return 1.0
elif bias_type == "front":
return 1.0 - (position - 1) / length
elif bias_type == "middle":
return 1.0 - abs(2 * (position - 1) / (length - 1) - 1) if length > 1 else 1.0
elif bias_type == "back":
return position / length
else:
raise ValueError(f"Invalid bias type: {bias_type}")
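
# Illustrative bias weights (a sketch based on the formulas above): for a
# range of length 4,
#   calculate_bias(1, 4, "flat")   -> 1.0    (every position weighted equally)
#   calculate_bias(1, 4, "front")  -> 1.0    (early positions weighted most)
#   calculate_bias(4, 4, "front")  -> 0.25
#   calculate_bias(2, 4, "middle") -> 0.667  (approximately; centre positions weighted most)
#   calculate_bias(4, 4, "back")   -> 1.0    (late positions weighted most)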


def gamma_select(cardinality, gamma, udf_gamma=None):
"""Select a gamma value based on the cardinality type."""
if gamma == "one":
return 1.0
elif gamma == "reciprocal":
return 1 / cardinality if cardinality > 1 else 1.0
elif gamma == "udf_gamma":
if udf_gamma is not None:
return 1.0 / udf_gamma
else:
raise ValueError("udf_gamma must be provided for 'udf_gamma' gamma type.")
else:
raise ValueError("Invalid gamma type.")
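
# Illustrative gamma weights (a sketch): for a predicted range overlapping three
# real ranges (cardinality 3),
#   gamma_select(3, "one")                    -> 1.0
#   gamma_select(3, "reciprocal")             -> 1/3, penalising fragmented overlap
#   gamma_select(3, "udf_gamma", udf_gamma=4) -> 0.25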


def calculate_overlap_reward_precision(pred_range, overlap_set, bias_type):
"""Overlap Reward for y_pred.

Parameters
----------
pred_range : tuple
The predicted range.
overlap_set : set
The set of overlapping positions.
bias_type : str
        Type of bias to apply. Should be one of ["flat", "front", "middle", "back"].

Returns
-------
float
The weighted value for overlapping positions only.
"""
start, end = pred_range
length = end - start + 1

max_value = 0 # Total possible weighted value for all positions.
my_value = 0 # Weighted value for overlapping positions only.

for i in range(1, length + 1):
global_position = start + i - 1
bias_value = calculate_bias(i, length, bias_type)
max_value += bias_value

if global_position in overlap_set:
my_value += bias_value

return my_value / max_value if max_value > 0 else 0.0


def calculate_overlap_reward_recall(real_range, overlap_set, bias_type):
"""Overlap Reward for y_real.

Parameters
----------
real_range : tuple
The real range.
overlap_set : set
The set of overlapping positions.
bias_type : str
        Type of bias to apply. Should be one of ["flat", "front", "middle", "back"].

Returns
-------
float
The weighted value for overlapping positions only.
"""
start, end = real_range
length = end - start + 1

max_value = 0.0 # Total possible weighted value for all positions.
my_value = 0.0 # Weighted value for overlapping positions only.

for i in range(1, length + 1):
global_position = start + i - 1
bias_value = calculate_bias(i, length, bias_type)
max_value += bias_value

if global_position in overlap_set:
my_value += bias_value

return my_value / max_value if max_value > 0 else 0.0


def ts_precision(y_pred, y_real, gamma="one", bias_type="flat", udf_gamma=None):
    """Calculate Precision for one or multiple sets of predicted ranges.

Parameters
----------
y_pred : list of tuples or list of list of tuples
The predicted ranges.
y_real : list of tuples
The real ranges.
gamma : str
Cardinality type. Should be one of ["reciprocal", "one", "udf_gamma"].
(default: "one")
bias_type : str
Type of bias to apply. Should be one of ["flat", "front", "middle", "back"].
(default: "flat")
udf_gamma : int or None
User-defined gamma value. (default: None)

Returns
-------
float
        Range-based precision
    """
    # Example (single set of predicted ranges):
    #   y_pred = [(1, 3), (5, 7)]
    #   y_real = [(2, 6), (8, 10)]
# Check if the input is a single set of predicted ranges or multiple sets
if isinstance(y_pred[0], tuple):
# y_pred is a single set of predicted ranges
total_overlap_reward = 0.0
total_cardinality = 0

for pred_range in y_pred:
overlap_set = set()
cardinality = 0

for real_start, real_end in y_real:
overlap_start = max(pred_range[0], real_start)
overlap_end = min(pred_range[1], real_end)

if overlap_start <= overlap_end:
overlap_set.update(range(overlap_start, overlap_end + 1))
cardinality += 1

overlap_reward = calculate_overlap_reward_precision(
pred_range, overlap_set, bias_type
)
gamma_value = gamma_select(cardinality, gamma, udf_gamma)

total_overlap_reward += gamma_value * overlap_reward
total_cardinality += 1

return (
total_overlap_reward / total_cardinality if total_cardinality > 0 else 0.0
)

    else:
        # Example (multiple sets of predicted ranges):
        #   y_pred = [[(1, 3), (5, 7)], [(10, 12)]]
        #   y_real = [(2, 6), (8, 10)]
# y_pred as multiple sets of predicted ranges
total_precision = 0.0
total_ranges = 0

for pred_ranges in y_pred: # Iterate over all sets of predicted ranges
precision = ts_precision(
pred_ranges, y_real, gamma, bias_type, udf_gamma
) # Recursive call for single sets
total_precision += precision * len(pred_ranges)
total_ranges += len(pred_ranges)

return total_precision / total_ranges if total_ranges > 0 else 0.0
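
# Worked example (flat bias, gamma="one") for the ranges in the example above:
# pred (1, 3) overlaps real (2, 6) at {2, 3} -> reward 2/3, and pred (5, 7)
# overlaps it at {5, 6} -> reward 2/3, so precision = (2/3 + 2/3) / 2 = 2/3.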


def ts_recall(y_pred, y_real, gamma="one", bias_type="flat", alpha=0.0, udf_gamma=None):
"""Calculate Recall for time series anomaly detection.

Parameters
----------
y_pred : list of tuples or list of list of tuples
The predicted ranges.
y_real : list of tuples or list of list of tuples
The real ranges.
gamma : str
Cardinality type. Should be one of ["reciprocal", "one", "udf_gamma"].
(default: "one")
bias_type : str
Type of bias to apply. Should be one of ["flat", "front", "middle", "back"].
(default: "flat")
alpha : float
Weight for existence reward in recall calculation. (default: 0.0)
udf_gamma : int or None
User-defined gamma value. (default: None)

Returns
-------
float
Range-based recall
"""
if isinstance(y_real[0], tuple): # Single set of real ranges
total_overlap_reward = 0.0

for real_range in y_real:
overlap_set = set()
cardinality = 0

for pred_range in y_pred:
overlap_start = max(real_range[0], pred_range[0])
overlap_end = min(real_range[1], pred_range[1])

if overlap_start <= overlap_end:
overlap_set.update(range(overlap_start, overlap_end + 1))
cardinality += 1

# Existence Reward
existence_reward = 1.0 if overlap_set else 0.0

if overlap_set:
overlap_reward = calculate_overlap_reward_recall(
real_range, overlap_set, bias_type
)
gamma_value = gamma_select(cardinality, gamma, udf_gamma)
overlap_reward *= gamma_value
else:
overlap_reward = 0.0

# Total Recall Score
recall_score = alpha * existence_reward + (1 - alpha) * overlap_reward
total_overlap_reward += recall_score

return total_overlap_reward / len(y_real) if y_real else 0.0

elif isinstance(y_real[0], list): # Multiple sets of real ranges
total_recall = 0.0
total_real = 0

for real_ranges in y_real: # Iterate over all sets of real ranges
recall = ts_recall(y_pred, real_ranges, gamma, bias_type, alpha, udf_gamma)
total_recall += recall * len(real_ranges)
total_real += len(real_ranges)

return total_recall / total_real if total_real > 0 else 0.0
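
# Worked example (flat bias, gamma="one", alpha=0.0): with y_pred = [(1, 4)]
# and y_real = [(2, 6)], positions {2, 3, 4} of the real range 2..6 are
# detected, so recall = 3/5 = 0.6.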


def ts_fscore(y_pred, y_real, gamma="one", bias_type="flat", alpha=0.0, udf_gamma=None):
    """Calculate the F1-Score for time series anomaly detection.

    The score is the harmonic mean of the range-based precision and recall
    obtained from ts_precision and ts_recall using the given gamma, bias_type,
    alpha and udf_gamma.
    """
precision = ts_precision(y_pred, y_real, gamma, bias_type, udf_gamma)
recall = ts_recall(y_pred, y_real, gamma, bias_type, alpha, udf_gamma)

if precision + recall > 0:
fscore = 2 * (precision * recall) / (precision + recall)
else:
fscore = 0.0

return fscore
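
# Worked example: for y_pred = [(1, 4)] and y_real = [(2, 6)], precision = 0.75
# and recall = 0.6, so the F1-Score is the harmonic mean
# 2 * 0.75 * 0.6 / (0.75 + 0.6) = 0.9 / 1.35 ≈ 0.6667.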
@@ -0,0 +1,69 @@
"""Test cases for metrics."""

import numpy as np

from aeon.benchmarking.metrics.anomaly_detection.range_metrics import (
ts_fscore,
ts_precision,
ts_recall,
)

# Single Overlapping Range
y_pred = [(1, 4)]
y_real = [(2, 6)]

precision = ts_precision(y_pred, y_real, gamma="one", bias_type="flat")
recall = ts_recall(y_pred, y_real, gamma="one", bias_type="flat", alpha=0.0)
f1_score = ts_fscore(y_pred, y_real, gamma="one", bias_type="flat", alpha=0.0)

np.testing.assert_almost_equal(precision, 0.750000, decimal=6)
np.testing.assert_almost_equal(recall, 0.600000, decimal=6)
np.testing.assert_almost_equal(f1_score, 0.666667, decimal=6)
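# Hand check: the prediction (1, 4) overlaps the real range (2, 6) at positions
# {2, 3, 4}, so precision = 3/4 = 0.75, recall = 3/5 = 0.6 and
# f1 = 2 * 0.75 * 0.6 / 1.35 ≈ 0.666667.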

# Multiple Non-Overlapping Ranges
y_pred = [(1, 2), (7, 8)]
y_real = [(3, 4), (9, 10)]

precision = ts_precision(y_pred, y_real, gamma="one", bias_type="flat")
recall = ts_recall(y_pred, y_real, gamma="one", bias_type="flat", alpha=0.0)
f1_score = ts_fscore(y_pred, y_real, gamma="one", bias_type="flat", alpha=0.0)

np.testing.assert_almost_equal(precision, 0.000000, decimal=6)
np.testing.assert_almost_equal(recall, 0.000000, decimal=6)
np.testing.assert_almost_equal(f1_score, 0.000000, decimal=6)

# Multiple Overlapping Ranges
y_pred = [(1, 3), (5, 7)]
y_real = [(2, 6), (8, 10)]

precision = ts_precision(y_pred, y_real, gamma="one", bias_type="flat")
recall = ts_recall(y_pred, y_real, gamma="one", bias_type="flat", alpha=0.0)
f1_score = ts_fscore(y_pred, y_real, gamma="one", bias_type="flat", alpha=0.0)

np.testing.assert_almost_equal(precision, 0.666667, decimal=6)
np.testing.assert_almost_equal(recall, 0.5, decimal=6)
np.testing.assert_almost_equal(f1_score, 0.571429, decimal=6)

# Nested Lists of Predictions
y_pred = [[(1, 3), (5, 7)], [(10, 12)]]
y_real = [(2, 6), (8, 10)]

precision = ts_precision(y_pred, y_real, gamma="one", bias_type="flat")
recall = ts_recall(y_pred, y_real, gamma="one", bias_type="flat", alpha=0.0)
f1_score = ts_fscore(y_pred, y_real, gamma="one", bias_type="flat", alpha=0.0)

np.testing.assert_almost_equal(precision, 0.555556, decimal=6)
np.testing.assert_almost_equal(recall, 0.555556, decimal=6)
np.testing.assert_almost_equal(f1_score, 0.555556, decimal=6)

# All Encompassing Range
y_pred = [(1, 10)]
y_real = [(2, 3), (5, 6), (8, 9)]

precision = ts_precision(y_pred, y_real, gamma="one", bias_type="flat")
recall = ts_recall(y_pred, y_real, gamma="one", bias_type="flat", alpha=0.0)
f1_score = ts_fscore(y_pred, y_real, gamma="one", bias_type="flat", alpha=0.0)

np.testing.assert_almost_equal(precision, 0.600000, decimal=6)
np.testing.assert_almost_equal(recall, 1.000000, decimal=6)
np.testing.assert_almost_equal(f1_score, 0.75, decimal=6)
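# Hand check: the prediction (1, 10) covers all six anomalous positions
# {2, 3, 5, 6, 8, 9} out of its 10 positions, so precision = 6/10 = 0.6; every
# real range is fully detected, so recall = 1.0 and f1 = 2 * 0.6 / 1.6 = 0.75.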