Skip to content

Commit

Permalink
Convert to dataclass
Browse files Browse the repository at this point in the history
  • Loading branch information
oyvindeide committed Aug 20, 2024
1 parent 9d9dc7c commit e30c33b
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 37 deletions.
43 changes: 12 additions & 31 deletions src/ert/config/ensemble_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import logging
import os
from collections import Counter
from dataclasses import dataclass
from dataclasses import dataclass, field
from datetime import datetime
from typing import (
Any,
Expand Down Expand Up @@ -72,23 +72,20 @@ def all_dates(self) -> List[datetime]:
return [self.start_date] + list(self.dates)


@dataclass
class EnsembleConfig:
def __init__(
self,
grid_file: Optional[str] = None,
response_configs: Optional[Dict[str, ResponseConfig]] = None,
parameter_configs: Optional[Dict[str, ParameterConfig]] = None,
refcase: Optional[Refcase] = None,
) -> None:
parameter_configs = parameter_configs if parameter_configs is not None else {}
response_configs = response_configs if response_configs is not None else {}
grid_file: Optional[str] = None
response_configs: Optional[Dict[str, ResponseConfig]] = field(default_factory=dict)
parameter_configs: Optional[Dict[str, ParameterConfig]] = field(
default_factory=dict
)
refcase: Optional[Refcase] = None

def __post_init__(self):
self._check_for_duplicate_names(
list(parameter_configs.values()), list(response_configs.values())
list(self.parameter_configs.values()), list(self.response_configs.values())
)
self.parameter_configs = parameter_configs
self.response_configs = response_configs
self._grid_file = _get_abs_path(grid_file)
self.refcase = refcase
self.grid_file = _get_abs_path(self.grid_file)

@staticmethod
def _check_for_duplicate_names(
Expand Down Expand Up @@ -225,10 +222,6 @@ def get_keylist_gen_kw(self) -> List[str]:
if isinstance(val, GenKwConfig)
]

@property
def grid_file(self) -> Optional[str]:
return self._grid_file

@property
def parameters(self) -> List[str]:
return list(self.parameter_configs)
Expand All @@ -244,18 +237,6 @@ def keys(self) -> List[str]:
def __contains__(self, key: str) -> bool:
return key in self.keys

def __eq__(self, other: object) -> bool:
if not isinstance(other, EnsembleConfig):
return False

return (
self.keys == other.keys
and self._grid_file == other._grid_file
and self.parameter_configs == other.parameter_configs
and self.response_configs == other.response_configs
and self.refcase == other.refcase
)

@property
def parameter_configuration(self) -> List[ParameterConfig]:
return list(self.parameter_configs.values())
Expand Down
4 changes: 2 additions & 2 deletions src/ert/config/ert_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,8 +105,8 @@ class ErtConfig:
user_config_file: str = "no_config"
config_path: str = field(init=False)
observation_config: Optional[
Dict[str, Union[HistoryValues, SummaryValues, GenObsValues]]
] = None
List[Tuple[str, Union[HistoryValues, SummaryValues, GenObsValues]]]
] = field(default_factory=list)

def __eq__(self, other: object) -> bool:
if not isinstance(other, ErtConfig):
Expand Down
9 changes: 6 additions & 3 deletions src/ert/config/observations.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
from dataclasses import dataclass
from datetime import datetime, timedelta
from typing import TYPE_CHECKING, Dict, Iterator, List, Optional, Tuple, Union

Expand Down Expand Up @@ -35,10 +36,12 @@ def history_key(key: str) -> str:
return ":".join([keyword + "H"] + rest)


@dataclass
class EnkfObs:
def __init__(self, obs_vectors: Dict[str, ObsVector], obs_time: List[datetime]):
self.obs_vectors = obs_vectors
self.obs_time = obs_time
obs_vectors: Dict[str, ObsVector]
obs_time: List[datetime]

def __post_init__(self):
self.datasets: Dict[str, xr.Dataset] = {
name: obs.to_dataset([]) for name, obs in sorted(self.obs_vectors.items())
}
Expand Down
2 changes: 1 addition & 1 deletion tests/unit_tests/config/test_ert_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -1558,7 +1558,7 @@ def test_that_multiple_errors_are_shown_when_validating_observation_config():
continue
print(line, end="")

with pytest.raises(ObservationConfigError) as err:
with pytest.raises(ConfigValidationError) as err:
_ = ErtConfig.from_file("snake_oil.ert")

expected_errors = [
Expand Down

0 comments on commit e30c33b

Please sign in to comment.