Skip to content

Commit

Permalink
use default hparams when num_iters is 1
Browse files Browse the repository at this point in the history
  • Loading branch information
jpgard committed Aug 2, 2024
1 parent bb7f720 commit 9884a6b
Showing 1 changed file with 7 additions and 6 deletions.
13 changes: 7 additions & 6 deletions rtfm/tree_baselines.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,15 +195,15 @@ def tune_catboost(
else:
raise ValueError(f"got unexpected number of classes: {n_classes}")

if len(X) > 1 and (all(y.value_counts() > cv)) and n_iter > 1:
# Define the model
if n_iter > 1 and len(X) > 1 and all(y.value_counts() > cv):
# Define the model for hyperparameter tuning
model = CatBoostClassifier(
iterations=catboost_iterations,
early_stopping_rounds=catboost_early_stopping_rounds,
random_state=42,
verbose=1,
task_type=task_type, # Use GPU if available
devices="0", # Use first available GPU
task_type=task_type,
devices="0",
eval_metric=eval_metric,
loss_function=loss_function,
)
Expand All @@ -224,14 +224,15 @@ def tune_catboost(
plot=False,
)
else:
# Use default hyperparameters when n_iter is 1 or other conditions are not met
model = CatBoostClassifier(
iterations=catboost_iterations,
early_stopping_rounds=catboost_early_stopping_rounds,
random_state=42,
verbose=1,
cat_features=cat_features,
task_type=task_type, # Use GPU if available
devices="0", # Use first available GPU
task_type=task_type,
devices="0",
eval_metric=eval_metric,
loss_function=loss_function,
)
Expand Down

0 comments on commit 9884a6b

Please sign in to comment.