Skip to content

Commit

Permalink
[BUG] Fixed incorrect bounding enforcement (#1871)
Browse files (browse the repository at this point in the history)
* fixed incorrect warping windows

* fixed distance tests that use window
  • Loading branch information
chrisholder authored Aug 1, 2024
1 parent ae06ae3 commit a2960ad
Show file tree
Hide file tree
Showing 8 changed files with 41 additions and 63 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
"euclidean": 19,
"dtw": 21,
"wdtw": 21,
"msm": 10,
"msm": 20,
"erp": 19,
"edr": 20,
"lcss": 12,
Expand Down
2 changes: 1 addition & 1 deletion aeon/clustering/tests/test_k_medoids.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,6 @@ def test_custom_distance_params():
# Test passing distance param
default_dist = _get_model_centres(data, distance="msm")
custom_params_dist = _get_model_centres(
data, distance="msm", distance_params={"window": 0.2}
data, distance="msm", distance_params={"window": 0.01}
)
assert not np.array_equal(default_dist, custom_params_dist)
25 changes: 13 additions & 12 deletions aeon/distances/_edr.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,7 @@
from numba import njit
from numba.typed import List as NumbaList

from aeon.distances._alignment_paths import (
_add_inf_to_out_of_bounds_cost_matrix,
compute_min_return_path,
)
from aeon.distances._alignment_paths import compute_min_return_path
from aeon.distances._bounding_matrix import create_bounding_matrix
from aeon.distances._euclidean import _univariate_euclidean_distance
from aeon.distances._utils import _convert_to_list, _is_multivariate
Expand Down Expand Up @@ -206,7 +203,15 @@ def _edr_cost_matrix(
if epsilon is None:
epsilon = float(max(np.std(x), np.std(y))) / 4

cost_matrix = np.zeros((x_size + 1, y_size + 1))
cost_matrix = np.full((x_size + 1, y_size + 1), np.inf)

for i in range(1, x_size + 1):
if bounding_matrix[i - 1, 0]:
cost_matrix[i, 0] = 0
for j in range(y_size):
if bounding_matrix[0, j - 1]:
cost_matrix[0, j] = 0
cost_matrix[0, 0] = 0

for i in range(1, x_size + 1):
for j in range(1, y_size + 1):
Expand Down Expand Up @@ -419,12 +424,8 @@ def edr_alignment_path(
>>> edr_alignment_path(x, y)
([(0, 0), (1, 1), (2, 2), (3, 3)], 0.25)
"""
x_size = x.shape[-1]
y_size = y.shape[-1]
bounding_matrix = create_bounding_matrix(x_size, y_size, window, itakura_max_slope)
cost_matrix = edr_cost_matrix(x, y, window, epsilon, itakura_max_slope)
# Needed because out-of-bounds cells of the cost matrix contain 0s rather than inf
cost_matrix = _add_inf_to_out_of_bounds_cost_matrix(cost_matrix, bounding_matrix)
return compute_min_return_path(cost_matrix), float(
cost_matrix[x_size - 1, y_size - 1] / max(x_size, y_size)
return (
compute_min_return_path(cost_matrix),
cost_matrix[x.shape[-1] - 1, y.shape[-1] - 1] / max(x.shape[-1], y.shape[-1]),
)
16 changes: 4 additions & 12 deletions aeon/distances/_erp.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,7 @@
from numba import njit
from numba.typed import List as NumbaList

from aeon.distances._alignment_paths import (
_add_inf_to_out_of_bounds_cost_matrix,
compute_min_return_path,
)
from aeon.distances._alignment_paths import compute_min_return_path
from aeon.distances._bounding_matrix import create_bounding_matrix
from aeon.distances._euclidean import _univariate_euclidean_distance
from aeon.distances._utils import _convert_to_list, _is_multivariate
Expand Down Expand Up @@ -209,13 +206,13 @@ def _erp_cost_matrix(
x_size = x.shape[1]
y_size = y.shape[1]

cost_matrix = np.zeros((x_size + 1, y_size + 1))

cost_matrix = np.full((x_size + 1, y_size + 1), np.inf)
gx_distance, x_sum = _precompute_g(x, g, g_arr)
gy_distance, y_sum = _precompute_g(y, g, g_arr)

cost_matrix[1:, 0] = x_sum
cost_matrix[0, 1:] = y_sum
cost_matrix[0, 0] = 0.0

for i in range(1, x_size + 1):
for j in range(1, y_size + 1):
Expand Down Expand Up @@ -458,12 +455,7 @@ def erp_alignment_path(
>>> erp_alignment_path(x, y)
([(0, 0), (1, 1), (2, 2), (3, 3)], 2.0)
"""
bounding_matrix = create_bounding_matrix(
x.shape[-1], y.shape[-1], window, itakura_max_slope
)
cost_matrix = _add_inf_to_out_of_bounds_cost_matrix(
erp_cost_matrix(x, y, window, g, g_arr), bounding_matrix
)
cost_matrix = erp_cost_matrix(x, y, window, g, g_arr)
return (
compute_min_return_path(cost_matrix),
cost_matrix[x.shape[-1] - 1, y.shape[-1] - 1],
Expand Down
20 changes: 7 additions & 13 deletions aeon/distances/_msm.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,7 @@
from numba import njit
from numba.typed import List as NumbaList

from aeon.distances._alignment_paths import (
_add_inf_to_out_of_bounds_cost_matrix,
compute_min_return_path,
)
from aeon.distances._alignment_paths import compute_min_return_path
from aeon.distances._bounding_matrix import create_bounding_matrix
from aeon.distances._squared import _univariate_squared_distance
from aeon.distances._utils import _convert_to_list, _is_multivariate
Expand Down Expand Up @@ -262,7 +259,7 @@ def _independent_cost_matrix(
) -> np.ndarray:
x_size = x.shape[0]
y_size = y.shape[0]
cost_matrix = np.zeros((x_size, y_size))
cost_matrix = np.full((x_size, y_size), np.inf)
cost_matrix[0, 0] = np.abs(x[0] - y[0])

for i in range(1, x_size):
Expand Down Expand Up @@ -293,7 +290,7 @@ def _msm_dependent_cost_matrix(
) -> np.ndarray:
x_size = x.shape[1]
y_size = y.shape[1]
cost_matrix = np.zeros((x_size, y_size))
cost_matrix = np.full((x_size, y_size), np.inf)
cost_matrix[0, 0] = np.sum(np.abs(x[:, 0] - y[:, 0]))

for i in range(1, x_size):
Expand Down Expand Up @@ -546,11 +543,8 @@ def msm_alignment_path(
>>> msm_alignment_path(x, y)
([(0, 0), (1, 1), (2, 2), (3, 3)], 2.0)
"""
x_size = x.shape[-1]
y_size = y.shape[-1]
bounding_matrix = create_bounding_matrix(x_size, y_size, window, itakura_max_slope)
cost_matrix = msm_cost_matrix(x, y, window, independent, c, itakura_max_slope)

# Needed because out-of-bounds cells of the cost matrix contain 0s rather than inf
cost_matrix = _add_inf_to_out_of_bounds_cost_matrix(cost_matrix, bounding_matrix)
return compute_min_return_path(cost_matrix), cost_matrix[x_size - 1, y_size - 1]
return (
compute_min_return_path(cost_matrix),
cost_matrix[x.shape[-1] - 1, y.shape[-1] - 1],
)
15 changes: 3 additions & 12 deletions aeon/distances/_twe.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,7 @@
from numba import njit
from numba.typed import List as NumbaList

from aeon.distances._alignment_paths import (
_add_inf_to_out_of_bounds_cost_matrix,
compute_min_return_path,
)
from aeon.distances._alignment_paths import compute_min_return_path
from aeon.distances._bounding_matrix import create_bounding_matrix
from aeon.distances._euclidean import _univariate_euclidean_distance
from aeon.distances._utils import _convert_to_list, _is_multivariate
Expand Down Expand Up @@ -200,9 +197,8 @@ def _twe_cost_matrix(
) -> np.ndarray:
x_size = x.shape[1]
y_size = y.shape[1]
cost_matrix = np.zeros((x_size, y_size))
cost_matrix[0, 1:] = np.inf
cost_matrix[1:, 0] = np.inf
cost_matrix = np.full((x_size, y_size), np.inf)
cost_matrix[0, 0] = 0.0

del_add = nu + lmbda

Expand Down Expand Up @@ -461,12 +457,7 @@ def twe_alignment_path(
>>> twe_alignment_path(x, y)
([(0, 0), (1, 1), (2, 2), (3, 3)], 2.0)
"""
bounding_matrix = create_bounding_matrix(
x.shape[-1], y.shape[-1], window, itakura_max_slope
)
cost_matrix = twe_cost_matrix(x, y, window, nu, lmbda, itakura_max_slope)
# Needed because out-of-bounds cells of the cost matrix contain 0s rather than inf
cost_matrix = _add_inf_to_out_of_bounds_cost_matrix(cost_matrix, bounding_matrix)
return (
compute_min_return_path(cost_matrix),
cost_matrix[x.shape[-1] - 1, y.shape[-1] - 1],
Expand Down
10 changes: 5 additions & 5 deletions aeon/distances/tests/test_distance_correctness.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,12 +50,12 @@
"squared": 384147.0,
"dtw": [384147.0, 315012.0, 275854.0],
"wdtw": [137927.0, 68406.15849, 2.2296],
"erp": [168.0, 1107.0, 2275.0],
"erp": [2275.0, 2275.0, 2275.0],
"lcss": [1.0, 0.45833, 0.08333],
"edr": [1.0, 0.58333, 0.125],
"ddtw": [80806.0, 76289.0625, 76289.0625],
"wddtw": [38144.53125, 19121.4927, 1.34957],
"twe": [137.001, 567.0029999999999, 3030.036000000001],
"twe": [4536.0, 3192.0220, 3030.036000000001],
"msm_ind": [1515.0, 1517.8000000000004, 1557.0], # msm with independent distance
"msm_dep": [1897.0, 1898.6000000000001, 1921.0], # msm with dependent distance
}
Expand All @@ -65,16 +65,16 @@
"dtw": [757.259719, 330.834497, 330.834497],
"wdtw": [165.41724, 3.308425, 0],
"msm": [70.014828, 89.814828, 268.014828],
"erp": [0.2086269, 2.9942540, 102.097904],
"erp": [169.3715, 102.0979, 102.097904],
"edr": [1.0, 0.26, 0.07],
"lcss": [1.0, 0.26, 0.05],
"ddtw": [297.18771, 160.51311645984856, 160.29823],
"wddtw": [80.149117, 1.458858, 0.0],
"twe": [1.001, 12.620531031063596, 173.3596688781867],
"twe": [338.4842162018424, 173.35966887818674, 173.3596688781867],
# msm with independent distance
"msm_ind": [84.36021099999999, 140.13788899999997, 262.6939920000001],
# msm with dependent distance
"msm_dep": [33.068257441993, 71.14080843368329, 190.73978686253804],
"msm_dep": [33.06825, 71.1408, 190.7397],
}


Expand Down
14 changes: 7 additions & 7 deletions aeon/testing/expected_results/expected_distance_results.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,16 +130,16 @@
[0.3409987716261595, 2.7187979513671015],
],
"lcss": [[0.30000000000000004, 1.0], [0.4, 1.0], [0.4, 1.0]],
"edr": [[0.3, 0.3], [0.1, 0.1], [0.5, 1.0]],
"edr": [[0.7, 1.0], [1.0, 1.0], [0.5, 1.0]],
"twe": [
[5.087449975445656, 15.161815735222117],
[1.1499446039354893, 5.995665808293953],
[20.507374214378885, 78.81976840746147],
[21.48930550350685, 82.55907793607852],
[15.243281318807819, 77.81976840746147],
[27.97089924329228, 83.97624505343292],
],
"msm": [
[4.080245996952201, 43.583053575960584],
[1.0, 15.829914369482566],
[12.099213975730216, 92.75733240032741],
[12.153142672084059, 102.90914530531768],
[12.023580258367444, 88.80013932627139],
[7.115130579734542, 61.80633627614831],
],
Expand All @@ -150,8 +150,8 @@
],
"sbd": [[0.13378563362841267, 0.12052110294129567]],
"erp": [
[6.1963403666089425, 23.958805888780923],
[2.2271884807416047, 9.205416143392629],
[13.782010409379064, 44.3099600330504],
[13.782010409379064, 44.3099600330504],
[12.782010409379064, 44.3099600330504],
[15.460501609188993, 44.3099600330504],
],
Expand Down

0 comments on commit a2960ad

Please sign in to comment.