Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Add type annotations for data module #634

Merged
merged 1 commit into from
Aug 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 27 additions & 27 deletions amlb/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,12 @@
which can also be encoded (``y_enc``, ``X_enc``)
- **Feature** provides metadata for a given feature/column as well as encoding functions.
"""
from __future__ import annotations

from abc import ABC, abstractmethod
from enum import Enum, auto
from enum import Enum
import logging
from typing import List, Union
from typing import List, Union, Iterable

import numpy as np
import pandas as pd
Expand All @@ -31,7 +33,7 @@

class Feature:

def __init__(self, index, name, data_type, values=None, has_missing_values=False, is_target=False):
def __init__(self, index: int, name: str, data_type: str | None, values: Iterable[str] | None = None, has_missing_values: bool = False, is_target: bool = False):
"""
:param index: index of the feature in the full data frame.
:param name: name of the feature.
Expand All @@ -43,64 +45,63 @@ def __init__(self, index, name, data_type, values=None, has_missing_values=False
self.index = index
self.name = name
self.data_type = data_type.lower() if data_type is not None else None
self.values = values
self.values = values # type: ignore # https://github.com/python/mypy/issues/3004
self.has_missing_values = has_missing_values
self.is_target = is_target
# print(self)

def is_categorical(self, strict=True):
def is_categorical(self, strict: bool = True) -> bool:
if strict:
return self.data_type == 'category'
else:
return self.data_type is not None and not self.is_numerical()
return self.data_type is not None and not self.is_numerical()

def is_numerical(self):
def is_numerical(self) -> bool:
return self.data_type in ['int', 'float', 'number']

@lazy_property
def label_encoder(self):
def label_encoder(self) -> Encoder:
return Encoder('label' if self.values is not None else 'no-op',
target=self.is_target,
encoded_type=int if self.is_target and not self.is_numerical() else float,
missing_values=[None, np.nan, pd.NA],
missing_policy='mask' if self.has_missing_values else 'ignore',
normalize_fn=self.normalize
normalize_fn=Feature.normalize
).fit(self.values)

@lazy_property
def one_hot_encoder(self):
def one_hot_encoder(self) -> Encoder:
return Encoder('one-hot' if self.values is not None else 'no-op',
target=self.is_target,
encoded_type=int if self.is_target and not self.is_numerical() else float,
missing_values=[None, np.nan, pd.NA],
missing_policy='mask' if self.has_missing_values else 'ignore',
normalize_fn=self.normalize
normalize_fn=Feature.normalize
).fit(self.values)

def normalize(self, arr):
@staticmethod
def normalize(arr: Iterable[str]) -> np.ndarray:
return np.char.lower(np.char.strip(np.asarray(arr).astype(str)))

@property
def values(self):
def values(self) -> list[str] | None:
return self._values

@values.setter
def values(self, values):
self._values = self.normalize(values).tolist() if values is not None else None
def values(self, values: Iterable[str]) -> None:
self._values = Feature.normalize(values).tolist() if values is not None else None

def __repr__(self):
def __repr__(self) -> str:
return repr_def(self, 'all')


class Datasplit(ABC):

def __init__(self, dataset, format):
def __init__(self, dataset: Dataset, file_format: str):
"""
:param format: the default format of the data file, obtained through the 'path' property.
:param file_format: the default format of the data file, obtained through the 'path' property.
"""
super().__init__()
self.dataset = dataset
self.format = format
self.format = file_format

@property
def path(self) -> str:
Expand Down Expand Up @@ -137,7 +138,7 @@ def y(self) -> DF:
"""
:return: the target column as a pandas DataFrame: if you need a Series, just call `y.squeeze()`.
"""
return self.data.iloc[:, [self.dataset.target.index]]
return self.data.iloc[:, [self.dataset.target.index]] # type: ignore

@lazy_property
@profile(logger=log)
Expand All @@ -164,7 +165,7 @@ def y_enc(self) -> AM:
return self.data_enc[:, self.dataset.target.index]

@profile(logger=log)
def release(self, properties=None):
def release(self, properties: Iterable[str] | None = None) -> None:
clear_cache(self, properties)


Expand All @@ -177,7 +178,7 @@ class DatasetType(Enum):

class Dataset(ABC):

def __init__(self):
def __init__(self) -> None:
super().__init__()

@property
Expand Down Expand Up @@ -228,11 +229,10 @@ def target(self) -> Feature:
pass

@profile(logger=log)
def release(self, properties=None):
def release(self) -> None:
"""
Call this to release cached properties and optimize memory once in-memory data are not needed anymore.
:param properties:
"""
self.train.release()
self.test.release()
clear_cache(self, properties)
clear_cache(self)
10 changes: 5 additions & 5 deletions amlb/datasets/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,10 +199,10 @@ def __repr__(self):

class FileDatasplit(Datasplit):

def __init__(self, dataset: FileDataset, format: str, path: str):
super().__init__(dataset, format)
def __init__(self, dataset: FileDataset, file_format: str, path: str):
super().__init__(dataset, file_format)
self._path = path
self._data = {format: path}
self._data = {file_format: path}

def data_path(self, format):
supported_formats = [cls.format for cls in __file_converters__]
Expand Down Expand Up @@ -267,7 +267,7 @@ def __init__(self, train_path, test_path,
class ArffDatasplit(FileDatasplit):

def __init__(self, dataset, path):
super().__init__(dataset, format='arff', path=path)
super().__init__(dataset, file_format='arff', path=path)
self._ds = None

def _ensure_loaded(self):
Expand Down Expand Up @@ -419,7 +419,7 @@ def compute_seasonal_error(self):
class CsvDatasplit(FileDatasplit):

def __init__(self, dataset, path, timestamp_column=None):
super().__init__(dataset, format='csv', path=path)
super().__init__(dataset, file_format='csv', path=path)
self._ds = None
self.timestamp_column = timestamp_column

Expand Down
Loading