From 7c7c717200726c28adf9465cf093c98aec374cff Mon Sep 17 00:00:00 2001
From: Tristan Crockett
Date: Mon, 20 Feb 2017 15:56:51 -0600
Subject: [PATCH] Allow yielding of trained models [Resolves #11]

- Implement ModelTrainer.generate_trained_models as a generator interface
  for grid training
- Implement Store.delete() to enable projects to clean up their stored
  models, either by removing them from storage or by dropping references
  for garbage collection
- Implement Predictor.delete_model to make Store.delete() available from
  the model_id rather than the model hash (future: implement id-awareness
  into storage, or return hashes alongside ids from trainer?)
---
 tests/test_model_trainers.py |  8 +++++++
 tests/test_predictors.py     |  4 ++++
 tests/test_storage.py        | 18 +++++++++++++--
 triage/model_trainers.py     | 43 +++++++++++++++++++++++++++++-------
 triage/predictors.py         | 23 ++++++++++++++++---
 triage/storage.py            |  9 ++++++++
 6 files changed, 92 insertions(+), 13 deletions(-)

diff --git a/tests/test_model_trainers.py b/tests/test_model_trainers.py
index 12f7eb5a8..7aa84c174 100644
--- a/tests/test_model_trainers.py
+++ b/tests/test_model_trainers.py
@@ -122,3 +122,11 @@ def test_model_trainer():
         engine.execute('select * from results.feature_importances')
     ]
     assert len(records) == 4 * 3  # maybe exclude entity_id?
+
+    # 7. that the generator interface works the same way
+    new_model_ids = trainer.generate_trained_models(
+        grid_config=grid_config,
+        misc_db_parameters=dict()
+    )
+    assert expected_model_ids == \
+        sorted([model_id for model_id in new_model_ids])
diff --git a/tests/test_predictors.py b/tests/test_predictors.py
index 14b44ca57..197c178db 100644
--- a/tests/test_predictors.py
+++ b/tests/test_predictors.py
@@ -86,3 +86,7 @@ def test_predictor():
             join results.models using (model_id)''')
     ]
     assert len(records) == 4
+
+    # 6. That we can delete the model once we are done predicting with it
+    predictor.delete_model(model_id)
+    assert predictor.load_model(model_id) == None
diff --git a/tests/test_storage.py b/tests/test_storage.py
index ecbe5d2ca..3f2a3af8c 100644
--- a/tests/test_storage.py
+++ b/tests/test_storage.py
@@ -1,4 +1,4 @@
-from triage.storage import S3Store, FSStore
+from triage.storage import S3Store, FSStore, MemoryStore
 from moto import mock_s3
 import boto3
 import os
@@ -19,6 +19,8 @@ def test_S3Store():
     assert store.exists()
     newVal = store.load()
     assert newVal.val == 'val'
+    store.delete()
+    assert not store.exists()
 
 
 def test_FSStore():
@@ -30,4 +32,16 @@ def test_FSStore():
     assert store.exists()
     newVal = store.load()
     assert newVal.val == 'val'
-    os.remove('tmpfile')
+    store.delete()
+    assert not store.exists()
+
+
+def test_MemoryStore():
+    store = MemoryStore(None)
+    assert not store.exists()
+    store.write(SomeClass('val'))
+    assert store.exists()
+    newVal = store.load()
+    assert newVal.val == 'val'
+    store.delete()
+    assert not store.exists()
diff --git a/triage/model_trainers.py b/triage/model_trainers.py
index 22768975c..bad428e09 100644
--- a/triage/model_trainers.py
+++ b/triage/model_trainers.py
@@ -267,13 +267,13 @@ def _get_model_group_id(
         logging.debug('Model_group_id = {}'.format(model_group_id))
         return model_group_id
 
-    def train_models(
+    def generate_trained_models(
         self,
         grid_config,
         misc_db_parameters,
         replace=False
     ):
-        """Train and store configured models
+        """Train and store configured models, yielding the ids one by one
 
         Args:
             grid_config (dict) of format {classpath: hyperparameter dicts}
@@ -286,10 +286,8 @@ def train_models(
             misc_db_parameters (dict) params to pass through to the database
             replace (optional, False): whether to replace already cached models
 
-        Returns:
-            (list) of model ids
+        Yields: (int) model ids
         """
-        model_ids = []
         misc_db_parameters = copy.deepcopy(misc_db_parameters)
         misc_db_parameters['batch_run_time'] = datetime.datetime.now().isoformat()
         for class_path, parameters in self._generate_model_configs(grid_config):
@@ -304,7 +302,7 @@ def train_models(
                     model_store,
                     misc_db_parameters
                 )
-                model_ids.append(model_id)
+                yield model_id
             else:
                 logging.info('Skipping %s/%s', class_path, parameters)
                 session = self.sessionmaker()
@@ -327,6 +325,35 @@ def train_models(
                     )
                 else:
                     model_id = saved.model_id
-                model_ids.append(model_id)
+                yield model_id
 
+    def train_models(
+        self,
+        grid_config,
+        misc_db_parameters,
+        replace=False
+    ):
+        """Train and store configured models
+
+        Args:
+            grid_config (dict) of format {classpath: hyperparameter dicts}
+                example: { 'sklearn.ensemble.RandomForestClassifier': {
+                    'n_estimators': [1,10,100,1000,10000],
+                    'max_depth': [1,5,10,20,50,100],
+                    'max_features': ['sqrt','log2'],
+                    'min_samples_split': [2,5,10]
+                } }
+            misc_db_parameters (dict) params to pass through to the database
+            replace (optional, False): whether to replace already cached models
+
+        Returns:
+            (list) of model ids
+        """
+        return [
+            model_id for model_id in self.generate_trained_models(
+                grid_config,
+                misc_db_parameters,
+                replace
+            )
+        ]
-        return model_ids
diff --git a/triage/predictors.py b/triage/predictors.py
index 8c022b84a..9f626f4be 100644
--- a/triage/predictors.py
+++ b/triage/predictors.py
@@ -4,6 +4,10 @@
 import logging
 
 
+class ModelNotFoundError(ValueError):
+    pass
+
+
 class Predictor(object):
     def __init__(self, project_path, model_storage_engine, db_engine):
         """Encapsulates the task of generating predictions on an arbitrary
@@ -20,7 +24,7 @@ def __init__(self, project_path, model_storage_engine, db_engine):
         if self.db_engine:
             self.sessionmaker = sessionmaker(bind=self.db_engine)
 
-    def _load_model(self, model_id):
+    def load_model(self, model_id):
         """Downloads the cached model associated with a given model id
 
         Args:
@@ -31,7 +35,18 @@ def _load_model(self, model_id):
         """
         model_hash = self.sessionmaker().query(Model).get(model_id).model_hash
         model_store = self.model_storage_engine.get_store(model_hash)
-        return model_store.load()
+        if model_store.exists():
+            return model_store.load()
+
+    def delete_model(self, model_id):
+        """Deletes the cached model associated with a given model id
+
+        Args:
+            model_id (int) The id of a given model in the database
+        """
+        model_hash = self.sessionmaker().query(Model).get(model_id).model_hash
+        model_store = self.model_storage_engine.get_store(model_hash)
+        model_store.delete()
 
     def _write_to_db(self, model_id, as_of_date, entity_ids, predictions,
                      labels, misc_db_parameters):
         """Writes given predictions to database
@@ -84,7 +99,9 @@ def predict(self, model_id, matrix_store, misc_db_parameters):
         Returns:
             (numpy.Array) the generated prediction values
         """
-        model = self._load_model(model_id)
+        model = self.load_model(model_id)
+        if not model:
+            raise ModelNotFoundError('Model id {} not found'.format(model_id))
         labels = matrix_store.labels()
         as_of_date = matrix_store.metadata['end_time']
         predictions = model.predict(matrix_store.matrix)
diff --git a/triage/storage.py b/triage/storage.py
index d5dffab76..75b208aa4 100644
--- a/triage/storage.py
+++ b/triage/storage.py
@@ -30,6 +30,9 @@ def write(self, obj):
     def load(self):
         return download_object(self.path)
 
+    def delete(self):
+        self.path.delete()
+
 
 class FSStore(Store):
     def exists(self):
@@ -43,6 +46,9 @@ def load(self):
         with open(self.path, 'rb') as f:
             return pickle.load(f)
 
+    def delete(self):
+        os.remove(self.path)
+
 
 class MemoryStore(Store):
     store = None
@@ -56,6 +62,9 @@ def write(self, obj):
     def load(self):
         return self.store
 
+    def delete(self):
+        self.store = None
+
 
 class ModelStorageEngine(object):
     def __init__(self, project_path):
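
Usage sketch (illustrative, not part of the patch): the generator interface and
the delete hooks above are meant to let a pipeline hold one trained model at a
time rather than materializing the whole grid. The snippet assumes a `trainer`,
`predictor`, and `matrix_store` already constructed the way the project's test
fixtures build them; the grid shown is a made-up example.

    # Illustrative driver loop; assumes trainer, predictor, and matrix_store
    # exist, built as in the test fixtures.
    grid_config = {
        'sklearn.ensemble.RandomForestClassifier': {
            'n_estimators': [10, 100],
            'max_depth': [5, 10],
        }
    }

    for model_id in trainer.generate_trained_models(
        grid_config=grid_config,
        misc_db_parameters=dict()
    ):
        # Predict with the freshly trained model...
        predictor.predict(model_id, matrix_store, misc_db_parameters=dict())
        # ...then drop the cached copy so stored models do not accumulate.
        predictor.delete_model(model_id)
        assert predictor.load_model(model_id) is None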