Skip to content

Commit

Permalink
(pytest) cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
Anastasios Zouzias committed Nov 17, 2024
1 parent c17007e commit 146ea88
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 9 deletions.
2 changes: 1 addition & 1 deletion python-package/tests/test_bin_class_titanic.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@


def load_titanic():
""" Load Titanic dataset """
"""Load Titanic dataset"""

pwd = os.path.dirname(os.path.abspath(__file__))
df = pd.read_csv(os.path.join(pwd, "../../data/titanic.csv"))
Expand Down
16 changes: 8 additions & 8 deletions python-package/tests/test_regressor_boston.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@
from sklearn.impute import SimpleImputer
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import mean_squared_error, r2_score
from sklearn.metrics import root_mean_squared_error

from sklearn.datasets import load_boston
from sklearn.datasets import fetch_california_housing


RANDOM_SEED = 123
Expand All @@ -24,10 +24,10 @@
}


def _load_boston():
""" Load Boston regression dataset """
def _load_california_housing():
"""Load California housing regression dataset"""

data, target = load_boston(return_X_y=True)
data, target = fetch_california_housing(return_X_y=True)

print("Input dataset dimensions {}".format(data.shape))
print("Target dims: {}".format(target.shape))
Expand All @@ -42,11 +42,11 @@ def _load_boston():
return X_train, X_valid, y_train, y_valid


def test_microgbt_boston_rmse():
def test_microgbt_housing_rmse():
num_iters = 100
early_stopping_rounds = 10

X_train, X_valid, y_train, y_valid = _load_boston()
X_train, X_valid, y_train, y_valid = _load_california_housing()

# Train
gbt = microgbtpy.GBT(params)
Expand All @@ -57,4 +57,4 @@ def test_microgbt_boston_rmse():
for x in X_valid:
y_valid_preds.append(gbt.predict(x, gbt.best_iteration()))

assert mean_squared_error(y_valid, y_valid_preds, squared=False) < 10.0
assert root_mean_squared_error(y_valid, y_valid_preds) < 4

0 comments on commit 146ea88

Please sign in to comment.