Skip to content

Commit

Permalink
add type hints
Browse files Browse the repository at this point in the history
  • Loading branch information
patrickvossler18 committed Jul 10, 2024
1 parent b8210b7 commit 9c6df5e
Show file tree
Hide file tree
Showing 16 changed files with 479 additions and 213 deletions.
10 changes: 8 additions & 2 deletions odtlearn/constrained_oct.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,13 @@ class ConstrainedOCT(FlowOCTMultipleSink):
"""

def __init__(
self, solver, _lambda, depth, time_limit, num_threads, verbose
self,
solver: str,
_lambda: float,
depth: int,
time_limit: int,
num_threads: None,
verbose: bool,
) -> None:

super().__init__(solver, _lambda, depth, time_limit, num_threads, verbose)
Expand All @@ -34,6 +40,6 @@ def __init__(
def _define_side_constraints(self):
pass

def _define_constraints(self):
def _define_constraints(self) -> None:
super()._define_constraints()
self._define_side_constraints()
163 changes: 96 additions & 67 deletions odtlearn/fair_oct.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
import warnings
from itertools import combinations
from typing import Union

import numpy as np
import pandas as pd
from numpy import ndarray
from pandas.core.frame import DataFrame
from pandas.core.series import Series
from sklearn.utils.multiclass import unique_labels
from sklearn.utils.validation import check_array, check_is_fitted, check_X_y

Expand Down Expand Up @@ -94,15 +98,15 @@ class FairConstrainedOCT(ConstrainedOCT):

def __init__(
self,
solver,
positive_class,
_lambda,
obj_mode,
fairness_bound,
depth,
time_limit,
num_threads,
verbose,
solver: str,
positive_class: int,
_lambda: float,
obj_mode: str,
fairness_bound: float,
depth: int,
time_limit: int,
num_threads: None,
verbose: bool,
) -> None:
self._positive_class = positive_class
self._fairness_bound = fairness_bound
Expand All @@ -113,7 +117,13 @@ def __init__(
self.weights = None
super().__init__(solver, _lambda, depth, time_limit, num_threads, verbose)

def _extract_metadata(self, X, y, protect_feat):
def _extract_metadata(
self,
X: Union[ndarray, DataFrame],
y: Union[ndarray, Series],
protect_feat: Union[ndarray, DataFrame],
) -> None:

super(ConstrainedOCT, self)._extract_metadata(X, y)
if isinstance(protect_feat, pd.DataFrame):
self._protect_feat_col_labels = protect_feat.columns
Expand All @@ -123,7 +133,7 @@ def _extract_metadata(self, X, y, protect_feat):
[f"P_{i}" for i in np.arange(0, protect_feat.shape[1])]
)

def _add_fairness_constraint(self, p_df, p_prime_df):
def _add_fairness_constraint(self, p_df: DataFrame, p_prime_df: DataFrame) -> bool:
"""
Add the fairness constraint to the MIP problem.
Expand Down Expand Up @@ -194,7 +204,7 @@ def _add_fairness_constraint(self, p_df, p_prime_df):

return constraint_added

def _define_objective(self):
def _define_objective(self) -> None:
# Max sum(sum(zeta[i,n,y(i)]))
obj = self._solver.lin_expr(0)
for n in self._tree.Nodes:
Expand All @@ -211,7 +221,14 @@ def _define_objective(self):

self._solver.set_objective(obj, ODTL.MAXIMIZE)

def fit(self, X, y, protect_feat, legit_factor, weights=None):
def fit(
self,
X: ndarray,
y: ndarray,
protect_feat: ndarray,
legit_factor: ndarray,
weights: None = None,
) -> Union["FairCSPOCT", "FairSPOCT", "FairEOddsOCT", "FairEOppOCT", "FairPEOCT"]:
"""
Fit the Fair Constrained Optimal Classification Tree (FairConstrainedOCT) model to the given training data.
Expand Down Expand Up @@ -341,7 +358,7 @@ def fit(self, X, y, protect_feat, legit_factor, weights=None):
# Return the classifier
return self

def predict(self, X):
def predict(self, X: Union[DataFrame, ndarray]) -> ndarray:
"""
Predict class labels for samples in X using the fitted Fair Constrained Optimal Classification Tree model.
Expand Down Expand Up @@ -433,15 +450,15 @@ class FairSPOCT(FairConstrainedOCT):

def __init__(
self,
solver,
positive_class,
depth=1,
time_limit=60,
_lambda=0,
obj_mode="acc",
fairness_bound=1,
num_threads=None,
verbose=False,
solver: str,
positive_class: int,
depth: int = 1,
time_limit: int = 60,
_lambda: float = 0,
obj_mode: str = "acc",
fairness_bound: float = 1,
num_threads: Union[None, int] = None,
verbose: bool = False,
) -> None:

super().__init__(
Expand All @@ -456,7 +473,7 @@ def __init__(
verbose,
)

def _define_side_constraints(self):
def _define_side_constraints(self) -> None:
# Loop through all possible combinations of the protected feature
for protected_feature in self._P_col_labels:
for combo in combinations(self._X_p[protected_feature].unique(), 2):
Expand All @@ -467,7 +484,9 @@ def _define_side_constraints(self):
p_prime_df = self._X_p[self._X_p[protected_feature] == p_prime]
self._add_fairness_constraint(p_df, p_prime_df)

def calc_metric(self, protect_feat, y):
def calc_metric(
self, protect_feat: Union[DataFrame, ndarray], y: Union[Series, ndarray]
):
"""
Calculate the statistical parity metric for the given data.
Expand Down Expand Up @@ -553,15 +572,15 @@ class FairCSPOCT(FairConstrainedOCT):

def __init__(
self,
solver,
positive_class,
depth=1,
time_limit=60,
_lambda=0,
obj_mode="acc",
fairness_bound=1,
num_threads=None,
verbose=False,
solver: str,
positive_class: int,
depth: int = 1,
time_limit: int = 60,
_lambda: float = 0,
obj_mode: str = "acc",
fairness_bound: float = 1,
num_threads: Union[None, int] = None,
verbose: bool = False,
) -> None:

super().__init__(
Expand All @@ -576,7 +595,7 @@ def __init__(
verbose,
)

def _define_side_constraints(self):
def _define_side_constraints(self) -> None:
# Loop through all possible combinations of the protected feature
for protected_feature in self._P_col_labels:
for combo in combinations(self._X_p[protected_feature].unique(), 2):
Expand All @@ -593,7 +612,12 @@ def _define_side_constraints(self):
]
self._add_fairness_constraint(p_df, p_prime_df)

def calc_metric(self, protect_feat, legit_factor, y):
def calc_metric(
self,
protect_feat: Union[DataFrame, ndarray],
legit_factor: Union[DataFrame, ndarray],
y: Union[Series, ndarray],
):
"""
Calculate the conditional statistical parity metric for the given data.
Expand Down Expand Up @@ -697,15 +721,15 @@ class FairPEOCT(FairConstrainedOCT):

def __init__(
self,
solver,
positive_class,
depth=1,
time_limit=60,
_lambda=0,
obj_mode="acc",
fairness_bound=1,
num_threads=None,
verbose=False,
solver: str,
positive_class: int,
depth: int = 1,
time_limit: int = 60,
_lambda: float = 0,
obj_mode: str = "acc",
fairness_bound: float = 1,
num_threads: Union[None, int] = None,
verbose: bool = False,
) -> None:

super().__init__(
Expand All @@ -720,7 +744,7 @@ def __init__(
verbose,
)

def _define_side_constraints(self):
def _define_side_constraints(self) -> None:
# Loop through all possible combinations of the protected feature
for protected_feature in self._P_col_labels:
for combo in combinations(self._X_p[protected_feature].unique(), 2):
Expand All @@ -736,7 +760,12 @@ def _define_side_constraints(self):
]
self._add_fairness_constraint(p_df, p_prime_df)

def calc_metric(self, protect_feat, y, y_pred):
def calc_metric(
self,
protect_feat: Union[DataFrame, ndarray],
y: Union[Series, ndarray],
y_pred: Union[Series, ndarray],
):
"""
Calculate the predictive equality metric for the given data.
Expand Down Expand Up @@ -841,15 +870,15 @@ class FairEOppOCT(FairConstrainedOCT):

def __init__(
self,
solver,
positive_class,
depth=1,
time_limit=60,
_lambda=0,
obj_mode="acc",
fairness_bound=1,
num_threads=None,
verbose=False,
solver: str,
positive_class: int,
depth: int = 1,
time_limit: int = 60,
_lambda: float = 0,
obj_mode: str = "acc",
fairness_bound: float = 1,
num_threads: Union[None, int] = None,
verbose: bool = False,
) -> None:

super().__init__(
Expand All @@ -864,7 +893,7 @@ def __init__(
verbose,
)

def _define_side_constraints(self):
def _define_side_constraints(self) -> None:
# Loop through all possible combinations of the protected feature
for protected_feature in self._P_col_labels:
for combo in combinations(self._X_p[protected_feature].unique(), 2):
Expand Down Expand Up @@ -918,15 +947,15 @@ class FairEOddsOCT(FairConstrainedOCT):

def __init__(
self,
solver,
positive_class,
depth=1,
time_limit=60,
_lambda=0,
obj_mode="acc",
fairness_bound=1,
num_threads=None,
verbose=False,
solver: str,
positive_class: int,
depth: int = 1,
time_limit: int = 60,
_lambda: float = 0,
obj_mode: str = "acc",
fairness_bound: float = 1,
num_threads: Union[None, int] = None,
verbose: bool = False,
) -> None:

super().__init__(
Expand All @@ -941,7 +970,7 @@ def __init__(
verbose,
)

def _define_side_constraints(self):
def _define_side_constraints(self) -> None:
# Loop through all possible combinations of the protected feature
for protected_feature in self._P_col_labels:
for combo in combinations(self._X_p[protected_feature].unique(), 2):
Expand Down
Loading

0 comments on commit 9c6df5e

Please sign in to comment.