Skip to content

Commit

Permalink
Merge IAMconsortium#466 into branch
Browse files Browse the repository at this point in the history
  • Loading branch information
David Almeida committed Jan 29, 2025
2 parents 6884f47 + d400ac5 commit 90fd253
Show file tree
Hide file tree
Showing 2 changed files with 99 additions and 105 deletions.
183 changes: 86 additions & 97 deletions nomenclature/processor/region.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,12 @@
AfterValidator,
BaseModel,
ConfigDict,
Field,
ValidationInfo,
field_validator,
model_validator,
validate_call,
field_serializer,
)
from pydantic.types import DirectoryPath, FilePath
from pydantic_core import PydanticCustomError
Expand Down Expand Up @@ -109,23 +111,11 @@ class RegionAggregationMapping(BaseModel):

model: list[str]
file: FilePath
native_regions: list[NativeRegion] | None = None
common_regions: list[CommonRegion] | None = None
exclude_regions: list[str] | None = None
native_regions: list[NativeRegion] = Field(default_factory=list)
common_regions: list[CommonRegion] = Field(default_factory=list)
exclude_regions: list[str] = Field(default_factory=list)

@model_validator(mode="before")
@classmethod
def check_no_additional_attributes(cls, v):
if illegal_additional_attributes := [
input_attribute
for input_attribute in v.keys()
if input_attribute not in cls.model_fields
]:
raise ValueError(
"Illegal attributes in 'RegionAggregationMapping': "
f"{illegal_additional_attributes} (file {v['file']})"
)
return v
model_config = ConfigDict(extra="forbid")

@field_validator("model", mode="before")
@classmethod
Expand Down Expand Up @@ -188,7 +178,7 @@ def check_native_or_common_regions(
cls, v: "RegionAggregationMapping"
) -> "RegionAggregationMapping":
# Check that we have at least one of the two: native and common regions
if v.native_regions is None and v.common_regions is None:
if not v.native_regions and not v.common_regions:
raise ValueError(
"At least one of 'native_regions' and 'common_regions' must be "
f"provided in {v.file}"
Expand All @@ -201,9 +191,7 @@ def check_illegal_renaming(
cls, v: "RegionAggregationMapping"
) -> "RegionAggregationMapping":
"""Check if any renaming overlaps with common regions"""
# Skip if only either native-regions or common-regions are specified
if v.native_regions is None or v.common_regions is None:
return v

native_region_names = {nr.target_native_region for nr in v.native_regions}
common_region_names = {cr.name for cr in v.common_regions}
overlap = list(native_region_names & common_region_names)
Expand Down Expand Up @@ -423,28 +411,35 @@ def check_unexpected_regions(self, df: IamDataFrame) -> None:
def __eq__(self, other: "RegionAggregationMapping") -> bool:
return self.model_dump(exclude={"file"}) == other.model_dump(exclude={"file"})

@field_serializer("model", when_used="json")
def serialize_model(self, model) -> str | list[str]:
return model[0] if len(model) == 1 else model

@field_serializer("native_regions", when_used="json")
def serialize_native_regions(self, native_regions) -> list:
return [
(
{native_region.name: native_region.rename}
if native_region.rename
else native_region.name
)
for native_region in native_regions
]

@field_serializer("common_regions", when_used="json")
def serialize_common_regions(self, common_regions) -> list:
return [
{common_region.name: common_region.constituent_regions}
for common_region in common_regions
]

def to_yaml(self, file) -> None:
dict_representation = {
"model": self.model[0] if len(self.model) == 1 else self.model
}
if self.native_regions:
dict_representation["native_regions"] = [
(
{native_region.name: native_region.rename}
if native_region.rename
else native_region.name
)
for native_region in self.native_regions
]
if self.common_regions:
dict_representation["common_regions"] = [
{common_region.name: common_region.constituent_regions}
for common_region in self.common_regions
]
if self.exclude_regions:
dict_representation["exclude_regions"] = self.exclude_regions
with open(file, "w", encoding="utf-8") as f:
yaml.dump(dict_representation, f, sort_keys=False)
yaml.dump(
self.model_dump(mode="json", exclude_defaults=True, exclude={"file"}),
f,
sort_keys=False,
)


def validate_with_definition(v: RegionAggregationMapping, info: ValidationInfo):
Expand Down Expand Up @@ -649,70 +644,64 @@ def _apply_region_processing(
# silence pyam's empty filter warnings
with adjust_log_level(logger="pyam", level="ERROR"):
# rename native regions
if self.mappings[model].native_regions is not None:
_df = model_df.filter(
region=self.mappings[model].model_native_region_names
_df = model_df.filter(region=self.mappings[model].model_native_region_names)
if not _df.empty:
_processed_data.append(
_df.rename(region=self.mappings[model].rename_mapping)._data
)
if not _df.empty:
_processed_data.append(
_df.rename(region=self.mappings[model].rename_mapping)._data
)

# aggregate common regions
if self.mappings[model].common_regions is not None:
for common_region in self.mappings[model].common_regions:
# if a common region is consists of a single native region, rename
if common_region.is_single_constituent_region:
_df = model_df.filter(
region=common_region.constituent_regions[0]
).rename(region=common_region.rename_dict)
if not _df.empty:
_processed_data.append(_df._data)
continue
for common_region in self.mappings[model].common_regions:
# if a common region is consists of a single native region, rename
if common_region.is_single_constituent_region:
_df = model_df.filter(
region=common_region.constituent_regions[0]
).rename(region=common_region.rename_dict)
if not _df.empty:
_processed_data.append(_df._data)
continue

# if there are multiple constituent regions, aggregate
regions = [common_region.name, common_region.constituent_regions]
# if there are multiple constituent regions, aggregate
regions = [common_region.name, common_region.constituent_regions]

# first, perform 'simple' aggregation (no arguments)
simple_vars = [
var
for var in self.variable_codelist.vars_default_args(
model_df.variable
)
]
_df = model_df.aggregate_region(
simple_vars,
*regions,
# first, perform 'simple' aggregation (no arguments)
simple_vars = [
var
for var in self.variable_codelist.vars_default_args(
model_df.variable
)
if _df is not None and not _df.empty:
_processed_data.append(_df._data)

# second, special weighted aggregation
for var in self.variable_codelist.vars_kwargs(model_df.variable):
if var.region_aggregation is None:
_df = _aggregate_region(
model_df,
var.name,
*regions,
**var.pyam_agg_kwargs,
)
if _df is not None and not _df.empty:
_processed_data.append(_df._data)
else:
for rename_var in var.region_aggregation:
for _rename, _kwargs in rename_var.items():
_df = _aggregate_region(
model_df,
var.name,
*regions,
**_kwargs,
]
_df = model_df.aggregate_region(
simple_vars,
*regions,
)
if _df is not None and not _df.empty:
_processed_data.append(_df._data)

# second, special weighted aggregation
for var in self.variable_codelist.vars_kwargs(model_df.variable):
if var.region_aggregation is None:
_df = _aggregate_region(
model_df,
var.name,
*regions,
**var.pyam_agg_kwargs,
)
if _df is not None and not _df.empty:
_processed_data.append(_df._data)
else:
for rename_var in var.region_aggregation:
for _rename, _kwargs in rename_var.items():
_df = _aggregate_region(
model_df,
var.name,
*regions,
**_kwargs,
)
if _df is not None and not _df.empty:
_processed_data.append(
_df.rename(variable={var.name: _rename})._data
)
if _df is not None and not _df.empty:
_processed_data.append(
_df.rename(
variable={var.name: _rename}
)._data
)

common_region_df = model_df.filter(
region=self.mappings[model].common_region_names,
Expand Down
21 changes: 13 additions & 8 deletions tests/test_region_aggregation.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,18 +43,14 @@ def test_mapping():
"constituent_regions": ["region_c"],
},
],
"exclude_regions": None,
"exclude_regions": [],
}
assert obs.model_dump() == exp


@pytest.mark.parametrize(
"file, error_msg_pattern",
[
(
"illegal_mapping_illegal_attribute.yaml",
"Illegal attributes in 'RegionAggregationMapping'",
),
(
"illegal_mapping_conflict_regions.yaml",
"Name collision in native and common regions.*common_region_1",
Expand Down Expand Up @@ -96,6 +92,15 @@ def test_illegal_mappings(file, error_msg_pattern):
RegionAggregationMapping.from_file(TEST_FOLDER_REGION_AGGREGATION / file)


def test_illegal_additional_attribute():
with pytest.raises(
pydantic.ValidationError, match="Extra inputs are not permitted"
):
RegionAggregationMapping.from_file(
TEST_FOLDER_REGION_AGGREGATION / "illegal_mapping_illegal_attribute.yaml"
)


def test_mapping_parsing_error():
with pytest.raises(ValueError, match="string indices must be integers"):
RegionAggregationMapping.from_file(
Expand Down Expand Up @@ -123,15 +128,15 @@ def test_region_processor_working(region_processor_path, simple_definition):
"native_regions": [
{"name": "World", "rename": None},
],
"common_regions": None,
"exclude_regions": None,
"common_regions": [],
"exclude_regions": [],
},
{
"model": ["model_b"],
"file": (
TEST_FOLDER_REGION_PROCESSING / "regionprocessor_working/mapping_2.yaml"
).relative_to(Path.cwd()),
"native_regions": None,
"native_regions": [],
"common_regions": [
{
"name": "World",
Expand Down

0 comments on commit 90fd253

Please sign in to comment.