From 3bf5584c778e8d3cf530f07a176111b938dc2606 Mon Sep 17 00:00:00 2001 From: shivam6862 Date: Wed, 24 Apr 2024 15:04:12 +0530 Subject: [PATCH] added patch for test --- .gitignore | 3 +++ giza_actions/model.py | 2 +- tests/test_model.py | 16 ++++++++++++++++ 3 files changed, 20 insertions(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index da9e4e0..b36b664 100644 --- a/.gitignore +++ b/.gitignore @@ -163,3 +163,6 @@ examples/on-chain_mnist/cairo/lofi_mnst_2 examples/on-chain_mnist/cairo/soft examples/on-chain_mnist/cairo/mnist_sierra examples/on-chain_mnist/contracts/out + +# cache files +tmp \ No newline at end of file diff --git a/giza_actions/model.py b/giza_actions/model.py index b051921..6e47a76 100644 --- a/giza_actions/model.py +++ b/giza_actions/model.py @@ -85,6 +85,7 @@ def __init__( self.framework = self.version.framework self.uri = self._retrieve_uri() self.endpoint_id = self._get_endpoint_id() + self._cache = Cache(os.path.join(os.getcwd(), "tmp", "cachedir")) self.session = self._set_session() if output_path is not None: self._output_path = output_path @@ -94,7 +95,6 @@ def __init__( f"{self.model_id}_{self.version_id}_{self.model.name}", ) self._download_model() - self._cache = Cache(os.path.join(os.getcwd(), "tmp", "cachedir")) def _get_endpoint_id(self): """ diff --git a/tests/test_model.py b/tests/test_model.py index 6bfe8d7..3be8137 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -104,6 +104,19 @@ def test_predict_success_with_file(*args): assert req_id == "123" +@patch("giza_actions.model.GizaModel._get_credentials") +@patch("giza_actions.model.GizaModel._get_model", return_value=Model(id=50)) +@patch( + "giza_actions.model.GizaModel._get_version", + return_value=Version( + version=2, + framework="CAIRO", + size=1, + status="COMPLETED", + created_date="2022-01-01T00:00:00Z", + last_update="2022-01-01T00:00:00Z", + ), +) @patch("giza_actions.model.GizaModel._set_session") @patch("giza_actions.model.GizaModel._get_output_dtype") @patch("giza_actions.model.GizaModel._retrieve_uri") @@ -119,5 +132,8 @@ def test_cache_implementation(*args): assert cache_size_after_first_call == cache_size_after_second_call result3 = model._get_output_dtype() + cache_size_after_third_call = len(model._cache) result4 = model._get_output_dtype() + cache_size_after_fourth_call = len(model._cache) assert result3 == result4 + assert cache_size_after_third_call == cache_size_after_fourth_call