diff --git a/src/openeo_aggregator/backend.py b/src/openeo_aggregator/backend.py index 453201b0..8d05260a 100644 --- a/src/openeo_aggregator/backend.py +++ b/src/openeo_aggregator/backend.py @@ -71,6 +71,7 @@ CONNECTION_TIMEOUT_JOB_START, CONNECTION_TIMEOUT_RESULT, AggregatorConfig, + get_backend_config, ) from openeo_aggregator.connection import ( BackendConnection, @@ -353,7 +354,6 @@ def __init__( self._memoizer = memoizer_from_config(config=config, namespace="Processing") self.backends.on_connections_change.add(self._memoizer.invalidate) self._catalog = catalog - self._stream_chunk_size = config.streaming_chunk_size def get_process_registry( self, api_version: Union[str, ComparableVersion] @@ -537,7 +537,7 @@ def evaluate(self, process_graph: dict, env: EvalEnv = None): _log.error(f"Failed to process synchronously on backend {con.id}: {e!r}", exc_info=True) raise OpenEOApiException(message=f"Failed to process synchronously on backend {con.id}: {e!r}") - return streaming_flask_response(backend_response, chunk_size=self._stream_chunk_size) + return streaming_flask_response(backend_response, chunk_size=get_backend_config().streaming_chunk_size) def preprocess_process_graph(self, process_graph: FlatPG, backend_id: str) -> dict: def preprocess(node: Any) -> Any: diff --git a/src/openeo_aggregator/config.py b/src/openeo_aggregator/config.py index f61a9218..415cd0ce 100644 --- a/src/openeo_aggregator/config.py +++ b/src/openeo_aggregator/config.py @@ -44,8 +44,6 @@ class AggregatorConfig(dict): # Dictionary mapping backend id to backend url aggregator_backends = dict_item() - streaming_chunk_size = dict_item(default=STREAM_CHUNK_SIZE_DEFAULT) - # TODO: add validation/normalization to make sure we have a real list of OidcProvider objects? 
configured_oidc_providers: List[OidcProvider] = dict_item(default=[]) auth_entitlement_check: Union[bool, dict] = dict_item(default=False) @@ -63,6 +61,9 @@ class AggregatorConfig(dict): # List of collection ids to cover with the aggregator (when None: support union of all upstream collections) collection_whitelist = dict_item(default=None) + # Just a config field for test purposes (while we're stripping down this config class) + test_dummy = dict_item(default="alice") + @staticmethod def from_py_file(path: Union[str, Path]) -> 'AggregatorConfig': """Load config from Python file.""" @@ -134,6 +135,9 @@ class AggregatorBackendConfig(OpenEoBackendConfig): packages=["openeo", "openeo_driver", "openeo_aggregator"], ) + streaming_chunk_size: int = STREAM_CHUNK_SIZE_DEFAULT + + _config_getter = ConfigGetter(expected_class=AggregatorBackendConfig) diff --git a/src/openeo_aggregator/testing.py b/src/openeo_aggregator/testing.py index db134123..c9b04c74 100644 --- a/src/openeo_aggregator/testing.py +++ b/src/openeo_aggregator/testing.py @@ -1,19 +1,19 @@ -import collections import dataclasses import datetime import itertools import json import pathlib -import time from typing import Any, Dict, List, Optional, Tuple, Union from unittest import mock import kazoo import kazoo.exceptions +import openeo_driver.testing import pytest from openeo.util import rfc3339 import openeo_aggregator.about +import openeo_aggregator.config from openeo_aggregator.utils import Clock @@ -290,3 +290,32 @@ def processes(self, *args) -> dict: processes.append(process) return {"processes": processes, "links": []} + + + def config_overrides(**kwargs): """ *Only to be used in unit tests* `mock.patch` based mocker to override the config returned by `get_backend_config()` at run time Can be used as context manager >>> with config_overrides(id="foobar"): ... ... in a fixture (as context manager): >>> @pytest.fixture ... def custom_setup(): ... 
with config_overrides(id="foobar"): + ... yield + + or as test function decorator + + >>> @config_overrides(id="foobar") + ... def test_stuff(): + """ + return openeo_driver.testing.config_overrides( + config_getter=openeo_aggregator.config._config_getter, + **kwargs, + ) diff --git a/tests/test_config.py b/tests/test_config.py index 46f9eb33..acd6790a 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -21,7 +21,7 @@ def _get_config_content(config_var_name: str = "config"): {config_var_name} = AggregatorConfig( config_source=__file__, aggregator_backends={{"b1": "https://b1.test"}}, - streaming_chunk_size=123 + test_dummy="bob", ) """ ) @@ -31,7 +31,7 @@ def test_config_defaults(): config = AggregatorConfig() with pytest.raises(KeyError): _ = config.aggregator_backends - assert config.streaming_chunk_size == STREAM_CHUNK_SIZE_DEFAULT + assert config.test_dummy == "alice" def test_config_aggregator_backends(): @@ -48,7 +48,7 @@ def test_config_from_py_file(tmp_path, config_var_name): config = AggregatorConfig.from_py_file(path) assert config.config_source == str(path) assert config.aggregator_backends == {"b1": "https://b1.test"} - assert config.streaming_chunk_size == 123 + assert config.test_dummy == "bob" def test_config_from_py_file_wrong_config_var_name(tmp_path): @@ -71,7 +71,7 @@ def test_get_config_py_file_path(tmp_path, convertor): config = get_config(convertor(config_path)) assert config.config_source == str(config_path) assert config.aggregator_backends == {"b1": "https://b1.test"} - assert config.streaming_chunk_size == 123 + assert config.test_dummy == "bob" def test_get_config_env_py_file(tmp_path): @@ -82,4 +82,4 @@ def test_get_config_env_py_file(tmp_path): config = get_config() assert config.config_source == str(path) assert config.aggregator_backends == {"b1": "https://b1.test"} - assert config.streaming_chunk_size == 123 + assert config.test_dummy == "bob" diff --git a/tests/test_views.py b/tests/test_views.py index 041b44b3..5618f224 
100644 --- a/tests/test_views.py +++ b/tests/test_views.py @@ -33,7 +33,7 @@ STAC_PROPERTY_FEDERATION_BACKENDS, STAC_PROPERTY_PROVIDER_BACKEND, ) -from openeo_aggregator.testing import clock_mock +from openeo_aggregator.testing import clock_mock, config_overrides from .conftest import assert_dict_subset, get_api100, get_flask_app @@ -800,7 +800,6 @@ def test_result_basic_math_error(self, api100, requests_mock, backend1, backend2 @pytest.mark.parametrize(["chunk_size"], [(16,), (128,)]) def test_result_large_response_streaming(self, config, chunk_size, requests_mock, backend1, backend2): - config.streaming_chunk_size = chunk_size api100 = get_api100(get_flask_app(config)) def post_result(request: requests.Request, context): @@ -813,7 +812,9 @@ def post_result(request: requests.Request, context): api100.set_auth_bearer_token(token=TEST_USER_BEARER_TOKEN) pg = {"large": {"process_id": "large", "arguments": {}, "result": True}} request = {"process": {"process_graph": pg}} - res = api100.post("/result", json=request).assert_status_code(200) + + with config_overrides(streaming_chunk_size=chunk_size): + res = api100.post("/result", json=request).assert_status_code(200) assert res.response.is_streamed chunks = res.response.iter_encoded()