diff --git a/amlb/data.py b/amlb/data.py index acca17841..9add33341 100644 --- a/amlb/data.py +++ b/amlb/data.py @@ -10,10 +10,12 @@ which can also be encoded (``y_enc``, ``X_enc``) - **Feature** provides metadata for a given feature/column as well as encoding functions. """ +from __future__ import annotations + from abc import ABC, abstractmethod -from enum import Enum, auto +from enum import Enum import logging -from typing import List, Union +from typing import List, Union, Iterable import numpy as np import pandas as pd @@ -31,7 +33,7 @@ class Feature: - def __init__(self, index, name, data_type, values=None, has_missing_values=False, is_target=False): + def __init__(self, index: int, name: str, data_type: str | None, values: Iterable[str] | None = None, has_missing_values: bool = False, is_target: bool = False): """ :param index: index of the feature in the full data frame. :param name: name of the feature. @@ -43,64 +45,63 @@ def __init__(self, index, name, data_type, values=None, has_missing_values=False self.index = index self.name = name self.data_type = data_type.lower() if data_type is not None else None - self.values = values + self.values = values # type: ignore # https://github.com/python/mypy/issues/3004 self.has_missing_values = has_missing_values self.is_target = is_target - # print(self) - def is_categorical(self, strict=True): + def is_categorical(self, strict: bool = True) -> bool: if strict: return self.data_type == 'category' - else: - return self.data_type is not None and not self.is_numerical() + return self.data_type is not None and not self.is_numerical() - def is_numerical(self): + def is_numerical(self) -> bool: return self.data_type in ['int', 'float', 'number'] @lazy_property - def label_encoder(self): + def label_encoder(self) -> Encoder: return Encoder('label' if self.values is not None else 'no-op', target=self.is_target, encoded_type=int if self.is_target and not self.is_numerical() else float, missing_values=[None, np.nan, pd.NA], missing_policy='mask' if self.has_missing_values else 'ignore', - normalize_fn=self.normalize + normalize_fn=Feature.normalize ).fit(self.values) @lazy_property - def one_hot_encoder(self): + def one_hot_encoder(self) -> Encoder: return Encoder('one-hot' if self.values is not None else 'no-op', target=self.is_target, encoded_type=int if self.is_target and not self.is_numerical() else float, missing_values=[None, np.nan, pd.NA], missing_policy='mask' if self.has_missing_values else 'ignore', - normalize_fn=self.normalize + normalize_fn=Feature.normalize ).fit(self.values) - def normalize(self, arr): + @staticmethod + def normalize(arr: Iterable[str]) -> np.ndarray: return np.char.lower(np.char.strip(np.asarray(arr).astype(str))) @property - def values(self): + def values(self) -> list[str] | None: return self._values @values.setter - def values(self, values): - self._values = self.normalize(values).tolist() if values is not None else None + def values(self, values: Iterable[str]) -> None: + self._values = Feature.normalize(values).tolist() if values is not None else None - def __repr__(self): + def __repr__(self) -> str: return repr_def(self, 'all') class Datasplit(ABC): - def __init__(self, dataset, format): + def __init__(self, dataset: Dataset, file_format: str): """ - :param format: the default format of the data file, obtained through the 'path' property. + :param file_format: the default format of the data file, obtained through the 'path' property. """ super().__init__() self.dataset = dataset - self.format = format + self.format = file_format @property def path(self) -> str: @@ -137,7 +138,7 @@ def y(self) -> DF: """ :return:the target column as a pandas DataFrame: if you need a Series, just call `y.squeeze()`. """ - return self.data.iloc[:, [self.dataset.target.index]] + return self.data.iloc[:, [self.dataset.target.index]] # type: ignore @lazy_property @profile(logger=log) @@ -164,7 +165,7 @@ def y_enc(self) -> AM: return self.data_enc[:, self.dataset.target.index] @profile(logger=log) - def release(self, properties=None): + def release(self, properties: Iterable[str] | None = None) -> None: clear_cache(self, properties) @@ -177,7 +178,7 @@ class DatasetType(Enum): class Dataset(ABC): - def __init__(self): + def __init__(self) -> None: super().__init__() @property @@ -228,11 +229,10 @@ def target(self) -> Feature: pass @profile(logger=log) - def release(self, properties=None): + def release(self) -> None: """ Call this to release cached properties and optimize memory once in-memory data are not needed anymore. - :param properties: """ self.train.release() self.test.release() - clear_cache(self, properties) + clear_cache(self) diff --git a/amlb/datasets/file.py b/amlb/datasets/file.py index abc51fc82..54568f031 100644 --- a/amlb/datasets/file.py +++ b/amlb/datasets/file.py @@ -199,10 +199,10 @@ def __repr__(self): class FileDatasplit(Datasplit): - def __init__(self, dataset: FileDataset, format: str, path: str): - super().__init__(dataset, format) + def __init__(self, dataset: FileDataset, file_format: str, path: str): + super().__init__(dataset, file_format) self._path = path - self._data = {format: path} + self._data = {file_format: path} def data_path(self, format): supported_formats = [cls.format for cls in __file_converters__] @@ -267,7 +267,7 @@ def __init__(self, train_path, test_path, class ArffDatasplit(FileDatasplit): def __init__(self, dataset, path): - super().__init__(dataset, format='arff', path=path) + super().__init__(dataset, file_format='arff', path=path) self._ds = None def _ensure_loaded(self): @@ -419,7 +419,7 @@ def compute_seasonal_error(self): class CsvDatasplit(FileDatasplit): def __init__(self, dataset, path, timestamp_column=None): - super().__init__(dataset, format='csv', path=path) + super().__init__(dataset, file_format='csv', path=path) self._ds = None self.timestamp_column = timestamp_column