From 1e4536e6e06fd2f71be179d0cb4eaefe431fb13b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=98yvind=20Eide?= Date: Fri, 12 Jan 2024 15:57:02 +0100 Subject: [PATCH] Move inversion types to analysis module --- src/ert/analysis/_es_update.py | 30 +++------------------ src/ert/config/analysis_config.py | 5 +++- src/ert/config/analysis_module.py | 22 ++++++++++++--- tests/unit_tests/analysis/test_es_update.py | 8 +++--- 4 files changed, 29 insertions(+), 36 deletions(-) diff --git a/src/ert/analysis/_es_update.py b/src/ert/analysis/_es_update.py index 3cbd40f56ed..065264396bc 100644 --- a/src/ert/analysis/_es_update.py +++ b/src/ert/analysis/_es_update.py @@ -585,21 +585,12 @@ def analysis_ES( num_obs = len(observation_values) - inversion_types = {0: "exact", 1: "subspace", 2: "subspace", 3: "subspace"} - try: - inversion_type = inversion_types[module.ies_inversion] - except KeyError as e: - raise ErtAnalysisError( - f"Mismatched inversion type for: " - f"Specified: {module.ies_inversion}, with possible: {inversion_types}" - ) from e - smoother_es = ies.ESMDA( covariance=observation_errors**2, observations=observation_values, alpha=1, # The user is responsible for scaling observation covariance (esmda usage) seed=rng, - inversion=inversion_type, + inversion=module.inversion.name, ) truncation = module.enkf_truncation @@ -731,7 +722,7 @@ def analysis_ES( observation_errors=observation_errors, observation_values=observation_values, truncation=truncation, - inversion_type=inversion_type, + inversion_type=module.inversion.name, progress_callback=progress_callback, rng=rng, ) @@ -761,21 +752,6 @@ def analysis_IES( # This is needed for the SIES library masking_of_initial_parameters = ens_mask[initial_mask] - # Map paper (current in ERT) inversion-types to SIES inversion-types - inversion_types = { - 0: "direct", - 1: "subspace_exact", - 2: "subspace_projected", - 3: "subspace_projected", - } - try: - inversion_type = inversion_types[analysis_config.ies_inversion] - except KeyError as e: - raise ErtAnalysisError( - f"Mismatched inversion type for: " - f"Specified: {analysis_config.ies_inversion}, with possible: {inversion_types}" - ) from e - # It is not the iterations relating to IES or ESMDA. # It is related to functionality for turning on/off groups of parameters and observations. for update_step in update_config: @@ -824,7 +800,7 @@ def analysis_IES( covariance=observation_errors**2, observations=observation_values, seed=rng, - inversion=inversion_type, + inversion=analysis_config.inversion.name, truncation=analysis_config.enkf_truncation, ) diff --git a/src/ert/config/analysis_config.py b/src/ert/config/analysis_config.py index 31494062627..78c27880b3e 100644 --- a/src/ert/config/analysis_config.py +++ b/src/ert/config/analysis_config.py @@ -60,8 +60,11 @@ def __init__( if var_name == "ENKF_FORCE_NCOMP": continue if var_name == "INVERSION": - value = str(inversion_str_map[value]) + value = inversion_str_map[value] var_name = "IES_INVERSION" + if var_name == "IES_INVERSION": + value = int(value) + var_name = "inversion" key = var_name.lower() options[module_name][key] = value try: diff --git a/src/ert/config/analysis_module.py b/src/ert/config/analysis_module.py index c7e50f541f2..3bec8281526 100644 --- a/src/ert/config/analysis_module.py +++ b/src/ert/config/analysis_module.py @@ -1,5 +1,6 @@ import logging import math +from enum import IntEnum from typing import TYPE_CHECKING, Optional, Type, TypedDict, Union from pydantic import BaseModel, Extra, Field @@ -23,14 +24,10 @@ class VariableInfo(TypedDict): DEFAULT_IES_MIN_STEPLENGTH = 0.30 DEFAULT_IES_DEC_STEPLENGTH = 2.50 DEFAULT_ENKF_TRUNCATION = 0.98 -DEFAULT_IES_INVERSION = 0 DEFAULT_LOCALIZATION = False class BaseSettings(BaseModel): - ies_inversion: Annotated[ - int, Field(ge=0, le=3, title="Inversion algorithm") - ] = DEFAULT_IES_INVERSION enkf_truncation: Annotated[ float, Field(gt=0.0, le=1.0, title="Singular value truncation"), @@ -41,7 +38,21 @@ class Config: validate_assignment = True +class InversionTypeES(IntEnum): + exact = 0 + subspace = 1 + + +class InversionTypeIES(IntEnum): + direct = 0 + subspace_exact = 1 + subspace_projected = 2 + + class ESSettings(BaseSettings): + inversion: Annotated[ + InversionTypeES, Field(title="Inversion algorithm") + ] = InversionTypeES.exact localization: Annotated[bool, Field(title="Adaptive localization")] = False localization_correlation_threshold: Annotated[ Optional[float], @@ -69,6 +80,9 @@ class IESSettings(BaseSettings): """A good start is max steplength of 0.6, min steplength of 0.3, and decline of 2.5", A steplength of 1.0 and one iteration results in ES update""" + inversion: Annotated[ + InversionTypeIES, Field(title="Inversion algorithm") + ] = InversionTypeIES.subspace_exact ies_max_steplength: Annotated[ float, Field(ge=0.1, le=1.0, title="Gauss–Newton maximum steplength"), diff --git a/tests/unit_tests/analysis/test_es_update.py b/tests/unit_tests/analysis/test_es_update.py index e6da9b7900a..8050c9faa06 100644 --- a/tests/unit_tests/analysis/test_es_update.py +++ b/tests/unit_tests/analysis/test_es_update.py @@ -101,7 +101,7 @@ def test_update_report( ert_config.ensemble_config.parameters, ), UpdateSettings(misfit_preprocess=misfit_preprocess), - ESSettings(ies_inversion=1), + ESSettings(inversion=1), log_path=Path("update_log"), ) log_file = Path(ert_config.analysis_config.log_path) / "id.txt" @@ -237,7 +237,7 @@ def test_update_snapshot( run_id="id", update_config=update_configuration, update_settings=UpdateSettings(), - analysis_config=IESSettings(ies_inversion=1), + analysis_config=IESSettings(inversion=1), sies_step_length=sies_step_length, initial_mask=initial_mask, rng=rng, @@ -249,7 +249,7 @@ def test_update_snapshot( "id", update_configuration, UpdateSettings(), - ESSettings(ies_inversion=1), + ESSettings(inversion=1), rng=rng, ) @@ -357,7 +357,7 @@ def test_localization( "id", update_config, UpdateSettings(), - ESSettings(ies_inversion=1), + ESSettings(inversion=1), rng=np.random.default_rng(42), log_path=Path("update_log"), )