Skip to content

Commit

Permalink
added patch for test
Browse files Browse the repository at this point in the history
  • Loading branch information
shivam6862 committed Apr 24, 2024
1 parent 6e19f06 commit 3bf5584
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 1 deletion.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -163,3 +163,6 @@ examples/on-chain_mnist/cairo/lofi_mnst_2
examples/on-chain_mnist/cairo/soft
examples/on-chain_mnist/cairo/mnist_sierra
examples/on-chain_mnist/contracts/out

# cache files
tmp
2 changes: 1 addition & 1 deletion giza_actions/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ def __init__(
self.framework = self.version.framework
self.uri = self._retrieve_uri()
self.endpoint_id = self._get_endpoint_id()
self._cache = Cache(os.path.join(os.getcwd(), "tmp", "cachedir"))
self.session = self._set_session()
if output_path is not None:
self._output_path = output_path
Expand All @@ -94,7 +95,6 @@ def __init__(
f"{self.model_id}_{self.version_id}_{self.model.name}",
)
self._download_model()
self._cache = Cache(os.path.join(os.getcwd(), "tmp", "cachedir"))

def _get_endpoint_id(self):
"""
Expand Down
16 changes: 16 additions & 0 deletions tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,19 @@ def test_predict_success_with_file(*args):
assert req_id == "123"


@patch("giza_actions.model.GizaModel._get_credentials")
@patch("giza_actions.model.GizaModel._get_model", return_value=Model(id=50))
@patch(
"giza_actions.model.GizaModel._get_version",
return_value=Version(
version=2,
framework="CAIRO",
size=1,
status="COMPLETED",
created_date="2022-01-01T00:00:00Z",
last_update="2022-01-01T00:00:00Z",
),
)
@patch("giza_actions.model.GizaModel._set_session")
@patch("giza_actions.model.GizaModel._get_output_dtype")
@patch("giza_actions.model.GizaModel._retrieve_uri")
Expand All @@ -119,5 +132,8 @@ def test_cache_implementation(*args):
assert cache_size_after_first_call == cache_size_after_second_call

result3 = model._get_output_dtype()
cache_size_after_third_call = len(model._cache)
result4 = model._get_output_dtype()
cache_size_after_fourth_call = len(model._cache)
assert result3 == result4
assert cache_size_after_third_call == cache_size_after_fourth_call

0 comments on commit 3bf5584

Please sign in to comment.