limit time and memory consumption (#264)
* limit time and memory

* separate tests

* lrl1 can't be limited by limit_resource

* free memory when possible

* passthrough=False when ensemble fails;
retrain when trained_estimator is None

* use callback for resource limit (see the illustrative sketch following the flaml/automl.py diff)

* handle lower version of xgb with no callback

* free mem ratio

* reduce verbosity

* retrain_final when max_iter==1

* remove trained_estimator from result

* model_history

* wheel

* retrain time as best_config_train_time

* ci: libomp version for xgboost on macos

* limit_resource not working in windows

* test pickle load

* mute forecaster

* notebook update

* check hard

* preventive callback

* add use_ray (usage sketch below)
sonichi authored Nov 4, 2021
1 parent 6c66cd6 commit 549a0df
Showing 12 changed files with 1,731 additions and 1,376 deletions.
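For orientation (not part of the commit), here is a minimal sketch of how the new use_ray option introduced in this commit might be used. The dataset, time budget, and printed attributes are illustrative, and use_ray=True assumes ray is installed (pip install flaml[ray]):

from sklearn.datasets import load_iris
from flaml import AutoML

# toy data; any tabular dataset works the same way
X_train, y_train = load_iris(return_X_y=True)

automl = AutoML()
automl.fit(
    X_train,
    y_train,
    task="classification",
    time_budget=60,  # seconds; this commit also tolerates time_budget=None
    use_ray=True,    # run each trial in a separate Ray process to contain memory use
)
print(automl.best_estimator, automl.best_config_train_time)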
8 changes: 5 additions & 3 deletions .github/workflows/python-package.yml
@@ -24,9 +24,11 @@ jobs:
uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}
- name: If mac, install libomp to facilitate lgbm install
- name: If mac, install libomp to facilitate lgbm and xgboost install
if: matrix.os == 'macOS-latest'
run: |
# remove libomp version constraint after xgboost works with libomp>11.1.0
wget https://raw.githubusercontent.com/Homebrew/homebrew-core/679923b4eb48a8dc7ecc1f05d06063cd79b3fc00/Formula/libomp.rb -O $(find $(brew --repository) -name libomp.rb)
brew install libomp
export CC=/usr/bin/clang
export CXX=/usr/bin/clang++
@@ -36,7 +38,7 @@
export LDFLAGS="$LDFLAGS -Wl,-rpath,/usr/local/opt/libomp/lib -L/usr/local/opt/libomp/lib -lomp"
- name: Install packages and dependencies
run: |
python -m pip install --upgrade pip
python -m pip install --upgrade pip wheel
pip install -e .[test]
- name: If linux or mac, install ray
if: (matrix.os == 'macOS-latest' || matrix.os == 'ubuntu-latest') && matrix.python-version != '3.9'
@@ -65,7 +67,7 @@ jobs:
with:
file: ./coverage.xml
flags: unittests

docs:

runs-on: ubuntu-latest
150 changes: 113 additions & 37 deletions flaml/automl.py
@@ -248,7 +248,7 @@ def _compute_with_config_base(self, estimator, config_w_resource):
"wall_clock_time": time.time() - self._start_time_flag,
"metric_for_logging": metric_for_logging,
"val_loss": val_loss,
"trained_estimator": trained_estimator,
"trained_estimator": trained_estimator if self.save_model_history else None,
}
if sampled_weight is not None:
self.fit_kwargs["sample_weight"] = weight
@@ -403,9 +403,10 @@ def best_loss(self):

@property
def best_config_train_time(self):
"""A float of the seconds taken by training the
best config."""
return self._search_states[self._best_estimator].best_config_train_time
"""A float of the seconds taken by training the best config."""
return getattr(
self._search_states[self._best_estimator], "best_config_train_time", None
)

@property
def classes_(self):
@@ -529,8 +530,9 @@ def _validate_data(
self._nrow, self._ndim = X_train_all.shape
if self._state.task == TS_FORECAST:
X_train_all = pd.DataFrame(X_train_all)
assert X_train_all[X_train_all.columns[0]].dtype.name == 'datetime64[ns]', (
f"For '{TS_FORECAST}' task, the first column must contain timestamp values.")
assert (
X_train_all[X_train_all.columns[0]].dtype.name == "datetime64[ns]"
), f"For '{TS_FORECAST}' task, the first column must contain timestamp values."
X, y = X_train_all, y_train_all
elif dataframe is not None and label is not None:
assert isinstance(
@@ -539,8 +541,9 @@
assert label in dataframe.columns, "label must be a column name in dataframe"
self._df = True
if self._state.task == TS_FORECAST:
assert dataframe[dataframe.columns[0]].dtype.name == 'datetime64[ns]', (
f"For '{TS_FORECAST}' task, the first column must contain timestamp values.")
assert (
dataframe[dataframe.columns[0]].dtype.name == "datetime64[ns]"
), f"For '{TS_FORECAST}' task, the first column must contain timestamp values."
X = dataframe.drop(columns=label)
self._nrow, self._ndim = X.shape
y = dataframe[label]
@@ -584,7 +587,9 @@ def _validate_data(
else:
self._state.X_val = X_val
if self._label_transformer:
self._state.y_val = self._label_transformer.transform(y_val, self._state.task)
self._state.y_val = self._label_transformer.transform(
y_val, self._state.task
)
else:
self._state.y_val = y_val
else:
@@ -1064,7 +1069,8 @@ def _decide_eval_method(self, time_budget):
return "holdout"
nrow, dim = self._nrow, self._ndim
if (
nrow * dim / 0.9 < SMALL_LARGE_THRES * (time_budget / 3600)
time_budget is None
or nrow * dim / 0.9 < SMALL_LARGE_THRES * (time_budget / 3600)
and nrow < CV_HOLDOUT_THRESHOLD
):
# time allows or sampling can be used and cv is necessary
@@ -1301,6 +1307,7 @@ def fit(
append_log=False,
auto_augment=True,
min_sample_size=MIN_SAMPLE_TRAIN,
use_ray=False,
**fit_kwargs,
):
"""Find a model for a given task
@@ -1414,7 +1421,9 @@ def custom_metric(
In the following code example, we get starting_points from the
automl_experiment and use them in the new_automl_experiment.
e.g.,
.. code-block:: python
from flaml import AutoML
automl_experiment = AutoML()
X_train, y_train = load_iris(return_X_y=True)
@@ -1440,6 +1449,10 @@ def custom_metric(
augment rare classes.
min_sample_size: int, default=MIN_SAMPLE_TRAIN | the minimal sample
size when sample=True.
use_ray: boolean, default=False | Whether to use ray to run the training
in separate processes. This can be used to prevent OOM for large
datasets, but will incur more overhead in time. Only use it if
you run into OOM failures.
**fit_kwargs: Other key word arguments to pass to fit() function of
the searched learners, such as sample_weight. Include period as
a key word argument for 'ts_forecast' task.
@@ -1483,8 +1496,10 @@ def custom_metric(
)
self._retrain_final = (
retrain_full is True
and (eval_method == "holdout" and self._state.X_val is None)
or (eval_method == "cv")
and eval_method == "holdout"
and self._state.X_val is None
or eval_method == "cv"
or max_iter == 1
)
self._auto_augment = auto_augment
self._min_sample_size = min_sample_size
@@ -1564,7 +1579,7 @@ def custom_metric(
logger.info("List of ML learners in AutoML Run: {}".format(estimator_list))
self.estimator_list = estimator_list
self._hpo_method = hpo_method or ("cfo" if n_concurrent_trials == 1 else "bs")
self._state.time_budget = time_budget
self._state.time_budget = time_budget or 1e10
self._active_estimators = estimator_list.copy()
self._ensemble = ensemble
self._max_iter = max_iter
@@ -1573,10 +1588,11 @@ def custom_metric(
self._state.train_time_limit = train_time_limit
self._log_type = log_type
self.split_ratio = split_ratio
self._save_model_history = model_history
self._state.save_model_history = model_history
self._state.n_jobs = n_jobs
self._n_concurrent_trials = n_concurrent_trials
self._early_stop = early_stop
self._use_ray = use_ray or self._n_concurrent_trials > 1
if log_file_name:
with training_log_writer(log_file_name, append_log) as save_helper:
self._training_log = save_helper
@@ -1627,7 +1643,7 @@ def _search_parallel(self):
from ray.tune.suggest import ConcurrencyLimiter
except (ImportError, AssertionError):
raise ImportError(
"n_concurrent_trial > 1 requires installation of ray. "
"n_concurrent_trial>1 or use_ray=True requires installation of ray. "
"Please run pip install flaml[ray]"
)
if self._hpo_method in ("cfo", "grid"):
@@ -1693,7 +1709,8 @@ def _search_parallel(self):
resources_per_trial=resources_per_trial,
time_budget_s=self._state.time_budget,
num_samples=self._max_iter,
verbose=self.verbose,
verbose=max(self.verbose - 3, 0),
raise_on_failed_trial=False,
)
# logger.info([trial.last_result for trial in analysis.trials])
trials = sorted(
@@ -1712,7 +1729,7 @@ def _search_parallel(self):
config = result["config"]
estimator = config.get("ml", config)["learner"]
search_state = self._search_states[estimator]
search_state.update(result, 0, self._save_model_history)
search_state.update(result, 0, self._state.save_model_history)
if result["wall_clock_time"] is not None:
self._state.time_from_start = result["wall_clock_time"]
if search_state.sample_size == self._state.data_size:
@@ -1727,7 +1744,7 @@ def _search_parallel(self):
config,
self._time_taken_best_iter,
)
if self._save_model_history:
if self._state.save_model_history:
self._model_history[
_track_iter
] = search_state.trained_estimator
@@ -1902,7 +1919,7 @@ def _search_sequential(self):
search_state.update(
result,
time_used=time_used,
save_model_history=self._save_model_history,
save_model_history=self._state.save_model_history,
)
if self._estimator_index is None:
# update init eci estimate
@@ -1945,18 +1962,27 @@ def _search_sequential(self):
search_state.best_config,
self._state.time_from_start,
)
if self._save_model_history:
if self._state.save_model_history:
self._model_history[
self._track_iter
] = search_state.trained_estimator
elif self._trained_estimator:
del self._trained_estimator
self._trained_estimator = None
self._trained_estimator = search_state.trained_estimator
if not self._retrain_final:
self._trained_estimator = search_state.trained_estimator
self._best_iteration = self._track_iter
self._time_taken_best_iter = self._state.time_from_start
better = True
next_trial_time = search_state.time2eval_best
if search_state.trained_estimator and not (
self._state.save_model_history or self._ensemble
):
# free RAM
if search_state.trained_estimator != self._trained_estimator:
search_state.trained_estimator.cleanup()
del search_state.trained_estimator
search_state.trained_estimator = None
if better or self._log_type == "all":
if self._training_log:
self._training_log.append(
@@ -2049,7 +2075,9 @@ def _search_sequential(self):
logger.info(
"retrain {} for {:.1f}s".format(self._best_estimator, retrain_time)
)
self._retrained_config[best_config_sig] = retrain_time
self._retrained_config[
best_config_sig
] = state.best_config_train_time = retrain_time
est_retrain_time = 0
self._state.time_from_start = time.time() - self._start_time_flag
if (
Expand Down Expand Up @@ -2083,7 +2111,7 @@ def _search(self):
self._selected = None
self.modelcount = 0

if self._n_concurrent_trials == 1:
if not self._use_ray:
self._search_sequential()
else:
self._search_parallel()
@@ -2103,12 +2131,29 @@ def _search(self):
"regression",
):
search_states = list(
x for x in self._search_states.items() if x[1].trained_estimator
x for x in self._search_states.items() if x[1].best_config
)
search_states.sort(key=lambda x: x[1].best_loss)
estimators = [(x[0], x[1].trained_estimator) for x in search_states[:2]]
estimators = [
(
x[0],
x[1].learner_class(
task=self._state.task,
n_jobs=self._state.n_jobs,
**x[1].best_config,
),
)
for x in search_states[:2]
]
estimators += [
(x[0], x[1].trained_estimator)
(
x[0],
x[1].learner_class(
task=self._state.task,
n_jobs=self._state.n_jobs,
**x[1].best_config,
),
)
for x in search_states[2:]
if x[1].best_loss < 4 * self._selected.best_loss
]
@@ -2135,19 +2180,49 @@ def _search(self):
)
if self._sample_weight_full is not None:
self._state.fit_kwargs["sample_weight"] = self._sample_weight_full
stacker.fit(
self._X_train_all, self._y_train_all, **self._state.fit_kwargs
)
logger.info(f"ensemble: {stacker}")
self._trained_estimator = stacker
self._trained_estimator.model = stacker
for e in estimators:
e[1].__class__.init()
try:
stacker.fit(
self._X_train_all, self._y_train_all, **self._state.fit_kwargs
)
logger.info(f"ensemble: {stacker}")
self._trained_estimator = stacker
self._trained_estimator.model = stacker
except ValueError as e:
if passthrough:
logger.warning(
"Using passthrough=False for ensemble because the data contain categorical features."
)
stacker = Stacker(
estimators,
final_estimator,
n_jobs=self._state.n_jobs,
passthrough=False,
)
stacker.fit(
self._X_train_all,
self._y_train_all,
**self._state.fit_kwargs,
)
logger.info(f"ensemble: {stacker}")
self._trained_estimator = stacker
self._trained_estimator.model = stacker
else:
raise e
elif self._retrain_final:
# reset time budget for retraining
self._state.time_from_start -= self._state.time_budget
if self._state.task == TS_FORECAST or (
self._state.time_budget - self._state.time_from_start
> self._selected.est_retrain_time(self.data_size_full)
and self._selected.best_config_sample_size == self._state.data_size
if self._max_iter > 1:
self._state.time_from_start -= self._state.time_budget
if (
self._state.task == TS_FORECAST
or self._trained_estimator is None
or (
self._state.time_budget - self._state.time_from_start
> self._selected.est_retrain_time(self.data_size_full)
and self._selected.best_config_sample_size
== self._state.data_size
)
):
state = self._search_states[self._best_estimator]
(
@@ -2163,6 +2238,7 @@ def _search(self):
self._best_estimator, retrain_time
)
)
state.best_config_train_time = retrain_time
if self._trained_estimator:
logger.info(f"retrained model: {self._trained_estimator.model}")
else:
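The callback-based resource limiting mentioned in the commit message is sketched below for orientation only; it is not the code from this commit. It assumes xgboost>=1.3 (older releases lack the TrainingCallback API, which is why the commit handles lower xgb versions separately) and psutil for memory checks, and the free_mem_ratio name simply mirrors the commit bullet:

import time
import numpy as np
import psutil
import xgboost as xgb


class ResourceLimit(xgb.callback.TrainingCallback):
    """Stop boosting when a time budget or a free-memory threshold is exceeded."""

    def __init__(self, time_budget_s, free_mem_ratio=0.2):
        self.deadline = time.time() + time_budget_s
        self.free_mem_ratio = free_mem_ratio  # minimum fraction of RAM to keep free

    def after_iteration(self, model, epoch, evals_log):
        mem = psutil.virtual_memory()
        out_of_time = time.time() > self.deadline
        low_memory = mem.available / mem.total < self.free_mem_ratio
        return out_of_time or low_memory  # returning True stops training early


# toy usage with the low-level xgb.train API, whose callbacks argument is stable
X = np.random.rand(200, 5)
y = (X[:, 0] > 0.5).astype(int)
dtrain = xgb.DMatrix(X, label=y)
booster = xgb.train(
    {"objective": "binary:logistic"},
    dtrain,
    num_boost_round=1000,
    callbacks=[ResourceLimit(time_budget_s=30)],
)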
5 changes: 2 additions & 3 deletions flaml/data.py
@@ -275,9 +275,8 @@ def fit_transform(self, X, y, task):
X[column] = X[column].map(datetime.toordinal)
datetime_columns.append(column)
del tmp_dt
else:
X[column] = X[column].fillna(np.nan)
num_columns.append(column)
X[column] = X[column].fillna(np.nan)
num_columns.append(column)
X = X[cat_columns + num_columns]
if task == TS_FORECAST:
X.insert(0, TS_TIMESTAMP_COL, ds_col)