diff --git a/.coveragerc b/.coveragerc index 39de6da..876c825 100644 --- a/.coveragerc +++ b/.coveragerc @@ -1,3 +1,4 @@ +# type: ignore # .coveragerc to control coverage.py [run] branch = True @@ -31,3 +32,4 @@ omit = */algorithms_educational.py */algorithms_old.py */smawk_old.py + */utils_for_test.py diff --git a/.pylintrc b/.pylintrc index b9bac9a..ed4e228 100644 --- a/.pylintrc +++ b/.pylintrc @@ -1,8 +1,10 @@ [MASTER] disable= missing-module-docstring, missing-class-docstring, missing-function-docstring, missing-final-newline, C0325 +ignore-paths=.coveragerc +ignore=.coveragerc [FORMAT] max-line-length=200 # allow 1 or 2 length variable names good-names-rgxs=^[_a-zGW][_a-z0-9L]?$ -good-names=R,D,T,D_row,S,A,F,H,F_vals, F_val, H_vals, M, SMALL_VAL, LARGE_VAL, MicroaggWilberCalculator_edu, N_vals, N +good-names=R,D,T,D_row,S,A,F,H,F_vals, F_val, H_vals, M, SMALL_VAL, LARGE_VAL, MicroaggWilberCalculator_edu, N_vals, N, setUp, tearDown diff --git a/setup.cfg b/setup.cfg index 8efbbb6..8b05f1e 100644 --- a/setup.cfg +++ b/setup.cfg @@ -98,6 +98,10 @@ testpaths = tests # slow: mark tests as slow (deselect with '-m "not slow"') # system: mark end-to-end system tests +[tool.pylint.MASTER] +ignore = .coveragerc +ignore-paths = .coveragerc + [devpi:upload] # Options for the devpi: PyPI server and packaging tool # VCS export must be deactivated since we are using setuptools-scm @@ -162,3 +166,4 @@ omit = */algorithms_educational.py */algorithms_old.py */smawk_old.py + */utils_for_test.py diff --git a/src/microagg1d/__init__.py b/src/microagg1d/__init__.py index 8cdec0d..8d3c0e1 100644 --- a/src/microagg1d/__init__.py +++ b/src/microagg1d/__init__.py @@ -1,3 +1,3 @@ -from microagg1d.main import optimal_univariate_microaggregation_1d +from microagg1d.main import univariate_microaggregation -microagg1d = optimal_univariate_microaggregation_1d # shorthand for the above +microagg1d = univariate_microaggregation # shorthand for the above diff --git a/src/microagg1d/cost_maxdist.py b/src/microagg1d/cost_maxdist.py index 78d58dc..572c6a2 100644 --- a/src/microagg1d/cost_maxdist.py +++ b/src/microagg1d/cost_maxdist.py @@ -29,7 +29,7 @@ def __init__(self, arr, k, F_vals): self.SMALL_VAL = (arr[-1] - arr[0]) * n self.LARGE_VAL = self.SMALL_VAL * (1 + n) - def calc(self, i, j): # i <-> j interchanged is not a bug! + def calc(self, i, j): """This function computes the w_{ij} values introduced""" if j <= i: return np.inf diff --git a/src/microagg1d/cost_round.py b/src/microagg1d/cost_round.py index a5df3d7..93bccef 100644 --- a/src/microagg1d/cost_round.py +++ b/src/microagg1d/cost_round.py @@ -33,7 +33,7 @@ def __init__(self, arr, k, F_vals): self.SMALL_VAL = (arr[-1] - arr[0]) * n self.LARGE_VAL = self.SMALL_VAL * (1 + n) - def calc(self, i, j): # i <-> j interchanged is not a bug! + def calc(self, i, j): """This function computes the w_{ij} values introduced""" if j <= i: return np.inf @@ -88,7 +88,7 @@ def __init__(self, arr, k, F_vals): self.SMALL_VAL = (arr[-1] - arr[0]) * n self.LARGE_VAL = self.SMALL_VAL * (1 + n) - def calc(self, i, j): # i <-> j interchanged is not a bug! + def calc(self, i, j): """This function computes the w_{ij} values introduced""" if j <= i: return np.inf diff --git a/src/microagg1d/cost_sae.py b/src/microagg1d/cost_sae.py index b330269..c0aac1a 100644 --- a/src/microagg1d/cost_sae.py +++ b/src/microagg1d/cost_sae.py @@ -44,7 +44,7 @@ def __init__(self, arr, k, F_vals): self.SMALL_VAL = (arr[-1] - arr[0]) * n self.LARGE_VAL = self.SMALL_VAL * (1 + n) - def calc(self, i, j): # i <-> j interchanged is not a bug! + def calc(self, i, j): """This function computes the w_{ij} values introduced""" if j <= i: return np.inf diff --git a/src/microagg1d/cost_sse.py b/src/microagg1d/cost_sse.py index 2151516..7eeb409 100644 --- a/src/microagg1d/cost_sse.py +++ b/src/microagg1d/cost_sse.py @@ -279,7 +279,7 @@ def compute_sse_sorted_stable(v, clusters_sorted): r = 0 while r < len(v): r = left - while clusters_sorted[left] == clusters_sorted[r] and r < len(v): + while r < len(v) and clusters_sorted[left] == clusters_sorted[r]: r += 1 # r-=1 mean = np.mean(v[left:r]) diff --git a/src/microagg1d/main.py b/src/microagg1d/main.py index be46dc7..2a0c6e7 100644 --- a/src/microagg1d/main.py +++ b/src/microagg1d/main.py @@ -22,7 +22,7 @@ def undo_argsort(sorted_arr, sort_order): return sorted_arr[revert] -def optimal_univariate_microaggregation_1d(x, k, method="auto", stable=1, cost="sse"): +def univariate_microaggregation(x, k, method="auto", stable=1, cost="sse"): """Performs optimal 1d univariate microaggregation""" x = np.squeeze(np.asarray(x)) assert len(x.shape) == 1, "provided array is not 1d" @@ -70,4 +70,6 @@ def optimal_univariate_microaggregation_1d(x, k, method="auto", stable=1, cost=" clusters = _rounddown_user(x, k, method) elif cost == "maxdist": clusters = _maxdist_user(x, k, method) + else: + raise NotImplementedError("Should not be reachable") return undo_argsort(clusters, order) diff --git a/src/microagg1d/utils_for_test.py b/src/microagg1d/utils_for_test.py index 663377c..2acb035 100644 --- a/src/microagg1d/utils_for_test.py +++ b/src/microagg1d/utils_for_test.py @@ -1,172 +1,420 @@ -import numba import numpy as np +DEBUG = False + def is_jitlass(func): return str(type(func)) == "" -def remove_numba(func, seen=None, allowed_packages=tuple()): - # print(func, type(func)) - - if seen is None: - seen = {} - if hasattr(func, "py_func"): - # clean_up["self"] = func - seen[func] = func.py_func - func = func.py_func - - if isinstance(func, type(remove_numba)): - to_iter = func.__globals__ - - def set_func(key, value): - to_iter[key] = value +class RemoveNumbaSettings: + def __init__(self, allowed_packes=tuple()): + self.allowed_packages = tuple(allowed_packes) + ("__main__",) + self.search_deep = True + self.depth = 0 + self.max_depth = 1000 - elif is_jitlass(func): - new_methods = {} - to_iter = func.class_type.jit_methods + def is_allowed_module_string(self, module_string): + if module_string.startswith("numba.experimental.jitclass.base"): + return True + if module_string == "": + return True + return any( + package_str in module_string for package_str in self.allowed_packages + ) - def set_func(key, value): - new_methods[key] = value - elif str(type(func)) == "" and func.__package__ in allowed_packages: +DefaultSettings = RemoveNumbaSettings() - def set_func(key, value): - setattr(func, key, value) - to_iter = {key: getattr(func, key) for key in dir(func)} +def track_namespace(namespace, seen_namespaces): + if namespace.id() in seen_namespaces: + return True else: - raise NotImplementedError(type(func)) + seen_namespaces[namespace.id()] = namespace + return False - clean_up = iter_children(to_iter, seen, set_func, allowed_packages) - if is_jitlass(func): - return type(func.class_type.class_name, (), new_methods), clean_up - return func, clean_up +def is_right_module(obj, settings): + if str(type(obj)) == "" and not settings.is_allowed_module_string( + obj.__package__ + ): + return False + if repr(obj).startswith("" - and maybe_func.__package__ in allowed_packages - ): - # print("module", maybe_func) - if maybe_func in seen: - continue - seen[maybe_func] = None - # print(seen) - non_numba_handle, handle_cleanup = remove_numba( - maybe_func, seen, allowed_packages - ) - clean_up["__module__" + key] = handle_cleanup - continue - - if not ( - isinstance( - maybe_func, (type(remove_numba), numba.core.registry.CPUDispatcher) - ) - or is_jitlass(maybe_func) - ): - continue - if maybe_func in seen: - # print("Seen") - # print(maybe_func) - clean_up[key] = maybe_func - clean_up["__children__" + key] = {} - set_func(key, seen[maybe_func]) - continue - if hasattr(maybe_func, "py_func") or is_jitlass(maybe_func): - non_numba_handle, handle_cleanup = remove_numba( - maybe_func, seen, allowed_packages - ) - clean_up[key] = maybe_func - clean_up["__children__" + key] = handle_cleanup - set_func(key, non_numba_handle) - return clean_up +class DictWrapper: + def __init__(self, d, parents, name=None): + self.d = d + self.parents = parents + if name is None: + self._name = self.d["__name__"] + else: + self._name = name + def id(self): + return id(self.d) -def restore_numba(func, clean_up, parent=None): - # print(func) - if hasattr(func, "py_func"): - func = func.py_func - # print(func) - if isinstance(func, type(restore_numba)): + def set(self, key, value): + self.d[key] = value - def set_func(key, value): - func.__globals__[key] = value + def get(self, key): + return self.d[key] - def get_func(key): - return func.__globals__[key] + def items(self): + return self.d.items() - elif str(type(func)) == "": + def __eq__(self, other): + return self.d is other.d - def set_func(key, value): - setattr(func, key, value) + def __repr__(self): + return f"DictWrapper({self.parents},{filter_globals(self.d)})" - def get_func(key): - return getattr(func, key) + def name(self): + return self._name - elif str(type(func)) == "": - def set_func(key, value): # pylint: disable=unused-argument - pass - # print(key, value) - # parent.class_type.jit_methods[key]=value - # setattr(func, key, value) +class ModuleWrapper: + def __init__(self, module, parents): + self.module = module + self.parents = parents - def get_func(key): - return parent.class_type.jit_methods[key] + def id(self): + return id(self.module) + + def set(self, key, value): + setattr(self.module, key, value) + + def get(self, key): + return getattr(self.module, key) + + def items(self): + return {key: getattr(self.module, key) for key in dir(self.module)} + + def __repr__(self): + return f"ModuleWrapper({self.parents},{filter_globals(self.items())})" + + def name(self): + return repr(self.module) + + +class ClassWrapper: + def __init__(self, cls, parents): + self.d = cls.__dict__ + self.parents = parents + self._name = str(type(cls)).split("'")[1] + + def id(self): + return id(self.d) + + def set(self, key, value): + self.d[key] = value + + def get(self, key): + return self.d[key] + + def items(self): + return self.d.items() + + def __repr__(self): + return f"ClassWrapper({self.parents},{filter_globals(dict(self.items()))})" + + def name(self): + return self._name + + +class TypeWrapper: + def __init__(self, cls, parents): + self.cls = cls + self.parents = parents + self._name = str(cls).split("'")[1] + + def id(self): + return id(self.cls) + + def set(self, key, value): + setattr(self.cls, key, value) + + def get(self, key): + return getattr(self.cls, key) + + def items(self): + return {key: getattr(self.cls, key) for key in dir(self.cls)}.items() + + def __repr__(self): + return f"TypeWrapper({self.parents},{filter_globals(dict(self.items()))})" + + def name(self): + return self._name + + +def collect_unseen_namespaces(obj, seen_namespaces, settings, parents=tuple()): + def my_print(*args): + print(" " * settings.depth, *args) + # my_print(obj, type(obj)) + # need to do modules firs because they have no __module__ + if str(type(obj)) == "" and not settings.is_allowed_module_string( + obj.__package__ + ): + return + # my_print(" ",obj, obj.__module__) + # my_print(" ",obj.__init__.__globals__) + if not settings.is_allowed_module_string(obj.__module__): + # print("MODULE", obj.__module__) + return + # my_print(dir(obj)) + if hasattr(obj, "py_func"): # is @njit function + namespace = DictWrapper(obj.py_func.__globals__, parents) + if track_namespace(namespace, seen_namespaces): + return + elif isinstance(obj, type(is_right_module)): + # my_print("MODULE", obj.__module__) + namespace = DictWrapper(obj.__globals__, parents) + if track_namespace(namespace, seen_namespaces): + return + elif is_jitlass(obj): + # print("JITCLASS") + namespace = DictWrapper(obj.class_type.jit_methods, parents, name=str(obj)) + # print(obj.class_type.jit_methods) + if track_namespace(namespace, seen_namespaces): + return + elif ( + str(type(obj)) == "" + and obj.__package__ in settings.allowed_packges + ): + namespace = ModuleWrapper(obj, parents) + if track_namespace(ModuleWrapper, seen_namespaces): + return + elif str(type(obj)) == "": + namespace = TypeWrapper(obj, parents) + elif str(type(obj)).startswith("", + "", + "", + ): continue - if get_func(key) is value and not hasattr(value, "py_func"): + if isinstance(new_obj, (list, dict, np.ndarray, tuple)): continue - restore_numba(get_func(key), clean_up["__children__" + key], value) - set_func(key, value) - - -def remove_from_class(cls, allowed_packages=tuple()): - clean_ups = {} - seen = {} - for key, value in cls.__dict__.items(): - if hasattr(value, "__module__"): - _, handle_cleanup = remove_numba(value, seen, allowed_packages) - clean_ups[key] = (value, handle_cleanup) - # print(key) - # print(value.__module__) - return clean_ups - - -def restore_to_class(clean_ups): - for _, (handle, clean_up) in clean_ups.items(): - restore_numba(handle, clean_up) - # cls.__dict__[key]=handle - - -class NoNumbaTestCase: - def setUp(self): - for cls in self.__class__.__bases__: - if cls is NoNumbaTestCase: + if isinstance(new_obj, (int, float, str)): + continue + if new_obj is None: + continue + if str(new_obj).startswith("ModuleSpec("): + continue + # print("BB", key, type(new_obj)) + if not is_right_module(new_obj, settings): + continue + # if key in ("_smawk_iter",): + # print(key, type(new_obj), parents) + # my_print("into", key, type(new_obj)) + settings.depth += 1 + collect_unseen_namespaces( + new_obj, + seen_namespaces, + settings, + parents + (namespace.name() + "." + key,), + ) + settings.depth -= 1 + + +class UndoNjitFunction: + def __init__(self, name): + self.name = name + self.njit_function = None + + def remove_numba(self, namespace): + self.njit_function = namespace.get(self.name) + namespace.set(self.name, self.njit_function.py_func) + + def add_numba(self, namespace): + namespace.set(self.name, self.njit_function) + + def __repr__(self): + return f"UndoNjitFunction({self.name})" + + def __eq__(self, other): + if not isinstance(other, UndoNjitFunction): + return False + return self.name == other.name + + +class UndoJitclassFunction: + def __init__(self, name): + self.name = name + self.njit_class = None + self.python_class = None + + def remove_numba(self, namespace): + self.njit_class = namespace.get(self.name) + if self.njit_class is None: + raise ValueError() + # print(self.njit_class.class_type.jit_methods) + new_methods = { + key: getattr(value, "py_func", value) + for key, value in self.njit_class.class_type.jit_methods.items() + } + + self.python_class = type(self.njit_class.class_type.class_name, (), new_methods) + namespace.set(self.name, self.python_class) + + def add_numba(self, namespace): + namespace.set(self.name, self.njit_class) + + def __eq__(self, other): + if not isinstance(other, UndoJitclassFunction): + return False + return self.name == other.name + + def __repr__(self): + if self.njit_class is None: + return f"UndoJitclassFunction({self.name})" + return f"UndoJitclassFunction({self.njit_class.class_type.class_name})" + + +def is_trivial_object(key, obj): + if key in ("__loader__", "__builtins__", "__dict__", "__weakref__"): + return True + if isinstance(obj, (list, dict, np.ndarray, tuple)): + return True + if isinstance(obj, (int, float, str)): + return True + if obj is None: + return True + return False + + +def remove_numba_from_namespaces(namespaces): + list_of_undos = [] + for namespace in namespaces.values(): + undo_list = [] + list_of_undos.append(undo_list) + for key, obj in namespace.items(): + if is_trivial_object(key, obj): continue - self.cleanup = remove_from_class(cls, allowed_packages=["fast1dkmeans"]) - break - - def tearDown(self) -> None: - restore_to_class(self.cleanup) + if hasattr(obj, "py_func"): + # print() + # print(filter_globals(namespace)) + # print(key, obj) + undo_list.append(UndoNjitFunction(key)) + elif is_jitlass(obj): + undo_list.append(UndoJitclassFunction(key)) + else: + pass + # print(obj, type(obj)) + for changer in undo_list: + # print(changer) + # print(namespace) + changer.remove_numba(namespace) + # print() + # print(changer) + # print(namespace) + return list_of_undos + + +def filter_globals(d): + out = {} + for key, value in d.items(): + if key.endswith("Error"): + continue + if key in ( + "quit", + "copyright", + "np", + "__spec__", + "__file__", + "__loader__", + "__cached__", + ): + continue + if key in ("__builtins__", "unittest"): + continue + out[key] = value + return out + + +def remove_from_class(cls, settings=None, allowed_packages=None): + if settings is None: + if allowed_packages is None: + raise ValueError() + # print("CLS MODULE", cls.__module__) + settings = RemoveNumbaSettings( + allowed_packes=tuple(allowed_packages) + (cls.__module__,) + ) + namespaces = {} + collect_unseen_namespaces(cls, namespaces, settings) + # print(len(namespaces)) + list_of_undos = remove_numba_from_namespaces(namespaces) + + namespaces2 = {} + collect_unseen_namespaces(cls, namespaces2, settings) + list_of_undos2 = remove_numba_from_namespaces(namespaces2) + + if DEBUG: + print("-----------------") + print([namespace.name() for namespace in namespaces.values()]) + print("-----------------") + # print(len(namespaces)) + + for (_, namespace), l in zip(namespaces2.items(), list_of_undos2): + if not len(l) == 0: + # print() + # print("?????") + # print(namespace) + # print(l) + assert False + + return namespaces, list_of_undos + + +def restore_to_class(stuff): + namespaces, list_of_undos = stuff + for (_, namespace), undos in zip(namespaces.items(), list_of_undos): + for undo in undos: + # print(undo) + undo.add_numba(namespace) + + +def namespaces_equal(namespaces1, namespaces2): + return list(namespaces1.keys()) == list(namespaces2.keys()) and all( + [ns1 == ns2 for ns1, ns2 in zip(namespaces1.values(), namespaces2.values())] + ) + + +def undos_equal(undos1, undos2): + for l1, l2 in zip(undos1, undos2): + if not all([u1 == u2 for u1, u2 in zip(l1, l2)]): + return False + return True + + +def cleanups_equal(cleanup1, cleanup2): + """Checks whether two cleanups are equal""" + namespaces1, undos1 = cleanup1 + namespaces2, undos2 = cleanup2 + + return namespaces_equal(namespaces1, namespaces2) and undos_equal(undos1, undos2) diff --git a/tests/test_algorithms_maxdist.py b/tests/test_algorithms_maxdist.py index f330021..b8d8c11 100644 --- a/tests/test_algorithms_maxdist.py +++ b/tests/test_algorithms_maxdist.py @@ -5,7 +5,7 @@ from microagg1d.common import compute_cluster_cost_sorted from microagg1d.cost_maxdist import AdaptedMaxDistCostCalculator, MaxDistCostCalculator -from microagg1d.main import optimal_univariate_microaggregation_1d +from microagg1d.main import univariate_microaggregation from microagg1d.user_facing import _maxdist_user from microagg1d.utils_for_test import remove_from_class, restore_to_class @@ -63,9 +63,7 @@ def test_maxdist_staggered(self): my_test_algorithm(self, partial(_maxdist_user, algorithm="staggered")) def test_maxdist_main(self): - my_test_algorithm( - self, partial(optimal_univariate_microaggregation_1d, cost="maxdist") - ) + my_test_algorithm(self, partial(univariate_microaggregation, cost="maxdist")) class Test7Elements(Test8Elements): diff --git a/tests/test_algorithms_sse.py b/tests/test_algorithms_sse.py index 57d9e5c..a40c449 100644 --- a/tests/test_algorithms_sse.py +++ b/tests/test_algorithms_sse.py @@ -6,7 +6,7 @@ from microagg1d.algorithms_educational import conventional_algorithm from microagg1d.common import compute_cluster_cost_sorted from microagg1d.cost_sse import SSECostCalculator -from microagg1d.main import optimal_univariate_microaggregation_1d +from microagg1d.main import univariate_microaggregation from microagg1d.user_facing import ( _sse_galil_park2, _sse_simple_dynamic_program, @@ -126,23 +126,19 @@ def test_sse_galil_park2_stable_0(self): # test main def test_optimal_univariate_microaggregation_simple(self): - my_test_algorithm( - self, partial(optimal_univariate_microaggregation_1d, method="simple") - ) + my_test_algorithm(self, partial(univariate_microaggregation, method="simple")) def test_optimal_univariate_microaggregation_wilber(self): - my_test_algorithm( - self, partial(optimal_univariate_microaggregation_1d, method="wilber") - ) + my_test_algorithm(self, partial(univariate_microaggregation, method="wilber")) def test_optimal_univariate_microaggregation_galil_park(self): my_test_algorithm( - self, partial(optimal_univariate_microaggregation_1d, method="galil_park") + self, partial(univariate_microaggregation, method="galil_park") ) def test_optimal_univariate_microaggregation_staggered(self): my_test_algorithm( - self, partial(optimal_univariate_microaggregation_1d, method="staggered") + self, partial(univariate_microaggregation, method="staggered") ) diff --git a/tests/test_generation.py b/tests/test_generation.py index 816dcb0..1a814c4 100644 --- a/tests/test_generation.py +++ b/tests/test_generation.py @@ -4,7 +4,12 @@ import numpy as np from numpy.testing import assert_array_equal -from microagg1d.generation import create_pair_arange, create_pair_known_sizes +from microagg1d.generation import ( + create_pair_arange, + create_pair_const_size, + create_pair_known_sizes, +) +from microagg1d.utils_for_test import remove_from_class, restore_to_class class TestArangeGeneration(unittest.TestCase): @@ -122,6 +127,23 @@ def test_create_pair_known_sizes_raises_large_epsilon(self): with self.assertRaises(ValueError): create_pair_known_sizes([4, 4], k=4, epsilon=0.2) + def test_create_pair_cost_size(self): + arr, _ = create_pair_const_size(10, 3, 0.1) + assert_array_equal(arr, np.array([0, 0.1, 0.2, 1, 1.1, 1.2, 2, 2.1, 2.2, 2.3])) + + arr, _ = create_pair_const_size(9, 3, 0.1) + assert_array_equal(arr, np.array([0, 0.1, 0.2, 1, 1.1, 1.2, 2, 2.1, 2.2])) + + +class TestArangeGenerationNonCompiled(TestArangeGeneration): + def setUp(self): + self.cleanup = remove_from_class( + self.__class__.__bases__[0], allowed_packages=["microagg1d"] + ) + + def tearDown(self) -> None: + restore_to_class(self.cleanup) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_main.py b/tests/test_main.py index d71b45f..2f7398b 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -4,7 +4,9 @@ import numpy as np from numpy.testing import assert_array_equal -from microagg1d.main import optimal_univariate_microaggregation_1d, undo_argsort +from microagg1d.generation import create_pair_arange +from microagg1d.main import undo_argsort, univariate_microaggregation +from microagg1d.utils_for_test import remove_from_class, restore_to_class # use with k=5 interesting_arr = np.array( @@ -36,7 +38,14 @@ def get_random_arr(seed, n): return x, x_sorted, order -class RegularizedKmeans(unittest.TestCase): +class TestMainLarger(unittest.TestCase): + def test_microagg_larger_input(self): + arr, solution = create_pair_arange(150, 22) + result = univariate_microaggregation(arr, 22) + np.testing.assert_array_equal(solution, result) + + +class TestMain(unittest.TestCase): def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) self.arr = np.array([1, 1, 1, 1.1, 5, 5, 5]) @@ -56,44 +65,62 @@ def test_undo_argsort_random(self): def test_microagg_raises(self): with self.assertRaises(AssertionError): - optimal_univariate_microaggregation_1d(np.random.rand(4, 4), 1) + univariate_microaggregation(np.random.rand(4, 4), 1) with self.assertRaises(AssertionError): - optimal_univariate_microaggregation_1d(self.arr, 0) - - def test_microagg(self): + univariate_microaggregation(self.arr, 0) + + def test_microagg_combinations(self): + for method in ("auto", "wilber", "galil_park", "staggered"): + for cost in ("sse", "sae", "roundup", "rounddown", "maxdist"): + with self.subTest(msg=f"{method}_{cost}"): + for k, solution in zip(range(1, len(self.arr) + 1), self.solutions): + result = univariate_microaggregation( + self.arr.copy(), k, method=method, cost=cost + ) + np.testing.assert_array_equal(solution, result, f"k={k}") + + def test_microagg_no_arguments(self): for k, solution in zip(range(1, len(self.arr) + 1), self.solutions): - result = optimal_univariate_microaggregation_1d(self.arr.copy(), k) + result = univariate_microaggregation(self.arr.copy(), k) np.testing.assert_array_equal(solution, result, f"k={k}") def test_example_usage(self): - import microagg1d # pylint: disable=import-outside-toplevel + # pylint: disable=redefined-outer-name,reimported,import-outside-toplevel + from microagg1d import univariate_microaggregation - x = [5, 1, 1, 1.1, 5, 1, 5] + x = [5, 1, 1, 1.1, 5, 1, 5.1] - clusters = microagg1d.optimal_univariate_microaggregation_1d(x, k=3) + clusters = univariate_microaggregation(x, k=3) print(clusters) # [1 0 0 0 1 0 1] - np.testing.assert_array_equal(clusters, [1, 0, 0, 0, 1, 0, 1], f"k={3}") - clusters2 = microagg1d.optimal_univariate_microaggregation_1d( - x, k=3, method="wilber" - ) # explicitly choose method + # explicitly choose method / algorithm + clusters2 = univariate_microaggregation(x, k=3, method="wilber") print(clusters2) # [1 0 0 0 1 0 1] - # may opt to get increased speed at cost of stability - # this is usually not a problem on small datasets such as here - clusters3 = microagg1d.optimal_univariate_microaggregation_1d( - x, k=3, stable=False - ) + # choose a different cost (sae / sse / roundup / rounddown / maxdist) + clusters3 = univariate_microaggregation(x, k=3, cost="sae") print(clusters3) # [1 0 0 0 1 0 1] + np.testing.assert_array_equal(clusters, [1, 0, 0, 0, 1, 0, 1], f"k={3}") + np.testing.assert_array_equal(clusters2, [1, 0, 0, 0, 1, 0, 1], f"k={3}") np.testing.assert_array_equal(clusters3, [1, 0, 0, 0, 1, 0, 1], f"k={3}") +class TestMainNonCompiled(TestMain): + def setUp(self): + self.cleanup = remove_from_class( + self.__class__.__bases__[0], allowed_packages=["microagg1d"] + ) + + def tearDown(self) -> None: + restore_to_class(self.cleanup) + + if __name__ == "__main__": unittest.main() diff --git a/tests/test_other_cost.py b/tests/test_other_cost.py new file mode 100644 index 0000000..0a71ca1 --- /dev/null +++ b/tests/test_other_cost.py @@ -0,0 +1,72 @@ +import unittest + +import numpy as np + +from microagg1d.cost_maxdist import AdaptedMaxDistCostCalculator, MaxDistCostCalculator +from microagg1d.cost_round import ( + AdaptedRoundDownCostCalculator, + AdaptedRoundUpCostCalculator, + RoundDownCostCalculator, + RoundUpCostCalculator, +) +from microagg1d.cost_sae import ( + AdaptedSAECostCalculator, + SAECostCalculator, + calc_sorted_median, +) +from microagg1d.utils_for_test import remove_from_class, restore_to_class + +x = np.arange(10, dtype=np.float64) +F_vals = np.zeros(10) + + +invalid_val = np.inf + + +class BasicTests(unittest.TestCase): + def test_invalid_value(self): + sae_calculators = [ + AdaptedSAECostCalculator(x, 2, F_vals), + SAECostCalculator(x), + RoundDownCostCalculator(x), + AdaptedRoundUpCostCalculator(x, 2, F_vals), + RoundUpCostCalculator(x), + AdaptedRoundDownCostCalculator(x, 2, F_vals), + MaxDistCostCalculator(x), + AdaptedMaxDistCostCalculator(x, 2, F_vals), + ] + for calculator in sae_calculators: + self.assertAlmostEqual(calculator.calc(10, 1), invalid_val) + self.assertAlmostEqual(calculator.calc(1, 0), invalid_val) + + +class BasicTestsNonCompiled(BasicTests): + def setUp(self): + self.cleanup = remove_from_class( + self.__class__.__bases__[0], allowed_packages=["microagg1d"] + ) + + def tearDown(self) -> None: + restore_to_class(self.cleanup) + + +class SAETests(unittest.TestCase): + def test_sorted_median(self): + self.assertAlmostEqual(calc_sorted_median(x), 4.5) + self.assertAlmostEqual(calc_sorted_median(x, lb=1), 5) + # one element case + self.assertAlmostEqual(calc_sorted_median(x, ub=1), 0) + + +class SAETestsNonCompiled(SAETests): + def setUp(self): + self.cleanup = remove_from_class( + self.__class__.__bases__[0], allowed_packages=["microagg1d"] + ) + + def tearDown(self) -> None: + restore_to_class(self.cleanup) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_smwak.py b/tests/test_smwak.py index d0749c5..d4ca26b 100644 --- a/tests/test_smwak.py +++ b/tests/test_smwak.py @@ -49,6 +49,10 @@ def test_smawk_array_1(self): solution = np.argmin(test_arr, axis=1) assert_array_equal(solution, smawk_iter_array(test_arr)) + def test_smawk_no_input(self): + smawk_iter_array(np.empty((0, 10))) + smawk_iter_array(np.empty((10, 0))) + class TestSMAWKIterNonCompiled(TestSMAWKIter): def setUp(self):