Skip to content

Commit

Permalink
Fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
oyvindeide committed Dec 11, 2024
1 parent b21ea15 commit c551a82
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 24 deletions.
10 changes: 4 additions & 6 deletions src/ert/config/queue_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,13 @@
import re
import shutil
from abc import abstractmethod
from copy import copy
from typing import Any, Literal, Mapping, Optional, no_type_check
from collections.abc import Mapping
from typing import Any, Literal, no_type_check

import pydantic
from pydantic import Field, field_validator
from pydantic import Field, constr, field_validator
from pydantic.dataclasses import dataclass
from pydantic_core.core_schema import ValidationInfo
from typing_extensions import Annotated

from .parsing import (
BaseModelWithContextSupport,
Expand All @@ -27,7 +26,7 @@

logger = logging.getLogger(__name__)

NonEmptyString = Annotated[str, pydantic.StringConstraints(min_length=1)]
NonEmptyString = constr(min_length=1)


def activate_script() -> str:
Expand Down Expand Up @@ -119,7 +118,6 @@ class LsfQueueOptions(QueueOptions):
lsf_queue: NonEmptyString | None = None
lsf_resource: str | None = None


@property
def driver_options(self) -> dict[str, Any]:
driver_dict = self.model_dump(exclude={"name", "submit_sleep", "max_running"})
Expand Down
31 changes: 20 additions & 11 deletions tests/ert/unit_tests/config/config_dict_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from collections import defaultdict
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Literal
from typing import Any, Literal, get_args, get_origin
from warnings import filterwarnings

import hypothesis.strategies as st
Expand Down Expand Up @@ -142,35 +142,44 @@ def valid_queue_options(queue_system: str):
]


def has_base_type(
field_type, base_type: type[int] | bool | type[str] | type[float]
) -> bool:
if field_type is base_type:
return True
origin = get_origin(field_type)
if origin:
args = get_args(field_type)
if any(arg is base_type for arg in args):
return True
return any(has_base_type(arg, base_type) for arg in args)
return False


queue_options_by_type: dict[str, dict[str, list[str]]] = defaultdict(dict)
for system, options in queue_systems_and_options.items():
queue_options_by_type["string"][system.name] = [
name.upper()
for name, field in options.model_fields.items()
if ("String" in str(field.annotation) or "str" in str(field.annotation))
and "memory" not in name
if has_base_type(field.annotation, str) and "memory" not in name
]
queue_options_by_type["bool"][system.name] = [
name.upper()
for name, field in options.model_fields.items()
if "bool" in str(field.annotation)
if has_base_type(field.annotation, bool)
]
queue_options_by_type["posint"][system.name] = [
name.upper()
for name, field in options.model_fields.items()
if "PositiveInt" in str(field.annotation)
or "NonNegativeInt" in str(field.annotation)
if has_base_type(field.annotation, int)
]
queue_options_by_type["posfloat"][system.name] = [
name.upper()
for name, field in options.model_fields.items()
if "NonNegativeFloat" in str(field.annotation)
or "PositiveFloat" in str(field.annotation)
if has_base_type(field.annotation, float)
]
queue_options_by_type["memory"][system.name] = [
name.upper()
for name, field in options.model_fields.items()
if "memory" in str(field.annotation)
name.upper() for name in options.model_fields if "memory" in name
]


Expand Down
3 changes: 1 addition & 2 deletions tests/ert/unit_tests/config/test_queue_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
from ert.config.queue_config import (
LocalQueueOptions,
LsfQueueOptions,
QueueOptions,
SlurmQueueOptions,
TorqueQueueOptions,
)
Expand Down Expand Up @@ -520,5 +519,5 @@ def test_default_activate_script_generation(expected, monkeypatch, venv):
monkeypatch.setenv("VIRTUAL_ENV", venv)
else:
monkeypatch.delenv("VIRTUAL_ENV", raising=False)
options = QueueOptions(name="local")
options = LocalQueueOptions()
assert options.activate_script == expected
8 changes: 3 additions & 5 deletions tests/everest/test_config_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,13 +141,11 @@ def test_that_max_runtime_errors_only_on_negative():


def test_that_invalid_queue_system_errors():
with pytest.raises(ValueError) as e:
with pytest.raises(
ValueError, match="does not match .*'local',.*'lsf',.*'slurm', .*'torque'"
):
EverestConfig.with_defaults(simulator={"queue_system": {"name": "docal"}})

assert has_error(
e.value, match="does not match .*'lsf', .*'local', .*'slurm', .*'torque'"
)


@pytest.mark.parametrize(
["cores", "expected_error"], [(0, False), (-1, True), (1, False)]
Expand Down

0 comments on commit c551a82

Please sign in to comment.