diff --git a/src/ert/config/ensemble_config.py b/src/ert/config/ensemble_config.py index edcae8e6f13..2647d4bec59 100644 --- a/src/ert/config/ensemble_config.py +++ b/src/ert/config/ensemble_config.py @@ -3,7 +3,7 @@ import logging import os from collections import Counter -from dataclasses import dataclass +from dataclasses import dataclass, field from datetime import datetime from typing import ( Any, @@ -72,23 +72,20 @@ def all_dates(self) -> List[datetime]: return [self.start_date] + list(self.dates) +@dataclass class EnsembleConfig: - def __init__( - self, - grid_file: Optional[str] = None, - response_configs: Optional[Dict[str, ResponseConfig]] = None, - parameter_configs: Optional[Dict[str, ParameterConfig]] = None, - refcase: Optional[Refcase] = None, - ) -> None: - parameter_configs = parameter_configs if parameter_configs is not None else {} - response_configs = response_configs if response_configs is not None else {} + grid_file: Optional[str] = None + response_configs: Optional[Dict[str, ResponseConfig]] = field(default_factory=dict) + parameter_configs: Optional[Dict[str, ParameterConfig]] = field( + default_factory=dict + ) + refcase: Optional[Refcase] = None + + def __post_init__(self): self._check_for_duplicate_names( - list(parameter_configs.values()), list(response_configs.values()) + list(self.parameter_configs.values()), list(self.response_configs.values()) ) - self.parameter_configs = parameter_configs - self.response_configs = response_configs - self._grid_file = _get_abs_path(grid_file) - self.refcase = refcase + self.grid_file = _get_abs_path(self.grid_file) @staticmethod def _check_for_duplicate_names( @@ -225,10 +222,6 @@ def get_keylist_gen_kw(self) -> List[str]: if isinstance(val, GenKwConfig) ] - @property - def grid_file(self) -> Optional[str]: - return self._grid_file - @property def parameters(self) -> List[str]: return list(self.parameter_configs) @@ -244,18 +237,6 @@ def keys(self) -> List[str]: def __contains__(self, key: str) -> bool: return key in self.keys - def __eq__(self, other: object) -> bool: - if not isinstance(other, EnsembleConfig): - return False - - return ( - self.keys == other.keys - and self._grid_file == other._grid_file - and self.parameter_configs == other.parameter_configs - and self.response_configs == other.response_configs - and self.refcase == other.refcase - ) - @property def parameter_configuration(self) -> List[ParameterConfig]: return list(self.parameter_configs.values()) diff --git a/src/ert/config/ert_config.py b/src/ert/config/ert_config.py index 5998b404789..fdae03ab607 100644 --- a/src/ert/config/ert_config.py +++ b/src/ert/config/ert_config.py @@ -105,8 +105,8 @@ class ErtConfig: user_config_file: str = "no_config" config_path: str = field(init=False) observation_config: Optional[ - Dict[str, Union[HistoryValues, SummaryValues, GenObsValues]] - ] = None + List[Tuple[str, Union[HistoryValues, SummaryValues, GenObsValues]]] + ] = field(default_factory=list) def __eq__(self, other: object) -> bool: if not isinstance(other, ErtConfig): diff --git a/src/ert/config/observations.py b/src/ert/config/observations.py index 44e42184b81..fda323bb561 100644 --- a/src/ert/config/observations.py +++ b/src/ert/config/observations.py @@ -1,4 +1,5 @@ import os +from dataclasses import dataclass from datetime import datetime, timedelta from typing import TYPE_CHECKING, Dict, Iterator, List, Optional, Tuple, Union @@ -35,10 +36,12 @@ def history_key(key: str) -> str: return ":".join([keyword + "H"] + rest) +@dataclass class EnkfObs: - def __init__(self, obs_vectors: Dict[str, ObsVector], obs_time: List[datetime]): - self.obs_vectors = obs_vectors - self.obs_time = obs_time + obs_vectors: Dict[str, ObsVector] + obs_time: List[datetime] + + def __post_init__(self): self.datasets: Dict[str, xr.Dataset] = { name: obs.to_dataset([]) for name, obs in sorted(self.obs_vectors.items()) } diff --git a/tests/unit_tests/config/test_ert_config.py b/tests/unit_tests/config/test_ert_config.py index 9c4eb565dc6..40083dd5fcc 100644 --- a/tests/unit_tests/config/test_ert_config.py +++ b/tests/unit_tests/config/test_ert_config.py @@ -1558,7 +1558,7 @@ def test_that_multiple_errors_are_shown_when_validating_observation_config(): continue print(line, end="") - with pytest.raises(ObservationConfigError) as err: + with pytest.raises(ConfigValidationError) as err: _ = ErtConfig.from_file("snake_oil.ert") expected_errors = [