Skip to content

Commit

Permalink
Format type hinting to 3.10+ (IAMconsortium#433)
Browse files Browse the repository at this point in the history
* Format type hinting to 3.10+

* Change test docstring

Co-authored-by: Daniel Huppmann <dh@dergelbesalon.at>

---------

Co-authored-by: Daniel Huppmann <dh@dergelbesalon.at>
  • Loading branch information
dc-almeida and danielhuppmann authored Nov 27, 2024
1 parent f210213 commit a32eca9
Show file tree
Hide file tree
Showing 11 changed files with 119 additions and 125 deletions.
23 changes: 11 additions & 12 deletions nomenclature/cli.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from pathlib import Path
from typing import List, Optional
import importlib.util
import sys

Expand Down Expand Up @@ -55,10 +54,10 @@ def cli_valid_yaml(path: Path):
def cli_valid_project(
path: Path,
definitions: str,
mappings: Optional[str],
required_data: Optional[str],
validate_data: Optional[str],
dimensions: Optional[List[str]],
mappings: str | None,
required_data: str | None,
validate_data: str | None,
dimensions: list[str] | None,
):
"""Assert that `path` is a valid project nomenclature
Expand All @@ -74,7 +73,7 @@ def cli_valid_project(
Name of folder for 'required data' criteria, default to "required_data"
validate_data: str, optional
Name of folder for data validation criteria, default to "validate_data"
dimensions : List[str], optional
dimensions : list[str], optional
Dimensions to be checked, defaults to all sub-folders of `definitions`
Example
Expand Down Expand Up @@ -125,8 +124,8 @@ def check_region_aggregation(
workflow_directory: Path,
definitions: str,
mappings: str,
processed_data: Optional[Path],
differences: Optional[Path],
processed_data: Path | None,
differences: Path | None,
):
"""Perform region processing and compare aggregated and original data
Expand All @@ -141,10 +140,10 @@ def check_region_aggregation(
Definitions folder inside workflow_directory, by default "definitions"
mappings : str
Model mapping folder inside workflow_directory, by default "mappings"
processed_data : Optional[Path]
processed_data : Path, optional
If given, exports the results from region processing to a file called
`processed_data`, by default "results.xlsx"
differences : Optional[Path]
differences : Path, optional
If given, exports the differences between aggregated and model native data to a
file called `differences`, by default None
Expand Down Expand Up @@ -295,7 +294,7 @@ def cli_run_workflow(
multiple=True,
default=None,
)
def cli_validate_scenarios(input_file: Path, definitions: Path, dimensions: List[str]):
def cli_validate_scenarios(input_file: Path, definitions: Path, dimensions: list[str]):
"""Validate a scenario file against the codelists of a project
Example
Expand All @@ -312,7 +311,7 @@ def cli_validate_scenarios(input_file: Path, definitions: Path, dimensions: List
Input data file, must be IAMC format, .xlsx or .csv
definitions : Path
Definitions folder with codelists, by default "definitions"
dimensions : List[str], optional
dimensions : list[str], optional
Dimensions to be checked, defaults to all sub-folders of `definitions`
Raises
Expand Down
38 changes: 19 additions & 19 deletions nomenclature/code.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import re
from keyword import iskeyword
from pathlib import Path
from typing import Any, Dict, List, Set, Union, Optional
from typing import Any
from pydantic import (
field_validator,
field_serializer,
Expand All @@ -24,8 +24,8 @@ class Code(BaseModel):

name: str
description: str | None = None
file: Union[str, Path] | None = None
extra_attributes: Dict[str, Any] = {}
file: str | Path | None = None
extra_attributes: dict[str, Any] = {}
repository: str | None = None

def __eq__(self, other) -> bool:
Expand All @@ -34,8 +34,8 @@ def __eq__(self, other) -> bool:
@field_validator("extra_attributes")
@classmethod
def check_attribute_names(
cls, v: Dict[str, Any], info: ValidationInfo
) -> Dict[str, Any]:
cls, v: dict[str, Any], info: ValidationInfo
) -> dict[str, Any]:
# Check that attributes only contains keys which are valid identifiers
if illegal_keys := [
key for key in v.keys() if not key.isidentifier() or iskeyword(key)
Expand Down Expand Up @@ -79,7 +79,7 @@ def from_dict(cls, mapping) -> "Code":
)

@classmethod
def named_attributes(cls) -> Set[str]:
def named_attributes(cls) -> set[str]:
return {a for a in cls.model_fields if a != "extra_attributes"}

@property
Expand Down Expand Up @@ -181,18 +181,18 @@ def __setattr__(self, name, value):


class VariableCode(Code):
unit: Union[str, List[str]] = Field(...)
unit: str | list[str] = Field(...)
tier: int | str | None = None
weight: str | None = None
region_aggregation: List[Dict[str, Dict]] | None = Field(
region_aggregation: list[dict[str, dict]] | None = Field(
default=None, alias="region-aggregation"
)
skip_region_aggregation: bool | None = Field(
default=False, alias="skip-region-aggregation"
)
method: str | None = None
check_aggregate: bool | None = Field(default=False, alias="check-aggregate")
components: Union[List[str], Dict[str, list[str]]] | None = None
components: list[str] | dict[str, list[str]] | None = None
drop_negative_weights: bool | None = None
model_config = ConfigDict(populate_by_name=True)

Expand Down Expand Up @@ -225,17 +225,17 @@ def convert_str_to_none_for_writing(self, v):
return v if v != "" else None

@property
def units(self) -> List[str]:
def units(self) -> list[str]:
return self.unit if isinstance(self.unit, list) else [self.unit]

@classmethod
def named_attributes(cls) -> Set[str]:
def named_attributes(cls) -> set[str]:
return (
super().named_attributes().union(f.alias for f in cls.model_fields.values())
)

@property
def pyam_agg_kwargs(self) -> Dict[str, Any]:
def pyam_agg_kwargs(self) -> dict[str, Any]:
# return a dict of all not None pyam aggregation properties
return {
field: getattr(self, field)
Expand All @@ -249,7 +249,7 @@ def pyam_agg_kwargs(self) -> Dict[str, Any]:
}

@property
def agg_kwargs(self) -> Dict[str, Any]:
def agg_kwargs(self) -> dict[str, Any]:
return (
{**self.pyam_agg_kwargs, **{"region_aggregation": self.region_aggregation}}
if self.region_aggregation is not None
Expand All @@ -274,11 +274,11 @@ class RegionCode(Code):
"""

hierarchy: str = None
countries: Optional[List[str]] = None
iso3_codes: Optional[Union[List[str], str]] = None
countries: list[str] | None = None
iso3_codes: list[str] | str | None = None

@field_validator("countries", mode="before")
def check_countries(cls, v: List[str], info: ValidationInfo) -> List[str]:
def check_countries(cls, v: list[str], info: ValidationInfo) -> list[str]:
"""Verifies that each country name is defined in `nomenclature.countries`."""
v = to_list(v)
if invalid_country_names := set(v) - set(countries.names):
Expand All @@ -291,7 +291,7 @@ def check_countries(cls, v: List[str], info: ValidationInfo) -> List[str]:
return v

@field_validator("iso3_codes")
def check_iso3_codes(cls, v: List[str], info: ValidationInfo) -> List[str]:
def check_iso3_codes(cls, v: list[str], info: ValidationInfo) -> list[str]:
"""Verifies that each ISO3 code is valid according to pycountry library."""
errors = ErrorCollector()
if invalid_iso3_codes := [
Expand All @@ -315,9 +315,9 @@ class MetaCode(Code):
Attributes
----------
allowed_values : Optional(list[any])
allowed_values : list[Any], optional
An optional list of allowed values
"""

allowed_values: List[Any] | None = None
allowed_values: list[Any] | None = None
52 changes: 26 additions & 26 deletions nomenclature/codelist.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import logging
from pathlib import Path
from textwrap import indent
from typing import ClassVar, Dict, List
from typing import ClassVar

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -34,7 +34,7 @@ class CodeList(BaseModel):
"""

name: str
mapping: Dict[str, Code] = {}
mapping: dict[str, Code] = {}

# class variable
validation_schema: ClassVar[str] = "generic"
Expand All @@ -46,8 +46,8 @@ def __eq__(self, other):
@field_validator("mapping")
@classmethod
def check_stray_tag(
cls, v: Dict[str, Code], info: ValidationInfo
) -> Dict[str, Code]:
cls, v: dict[str, Code], info: ValidationInfo
) -> dict[str, Code]:
"""Check that no stray tags are left in codes after tag replacement"""
forbidden = ["{", "}"]

Expand Down Expand Up @@ -75,8 +75,8 @@ def _check_string(value):
@field_validator("mapping")
@classmethod
def check_end_whitespace(
cls, v: Dict[str, Code], info: ValidationInfo
) -> Dict[str, Code]:
cls, v: dict[str, Code], info: ValidationInfo
) -> dict[str, Code]:
"""Check that no code ends with a whitespace"""
for code in v:
if code.endswith(" "):
Expand Down Expand Up @@ -127,7 +127,7 @@ def validate_data(
return False
return True

def validate_items(self, items: List[str]) -> List[str]:
def validate_items(self, items: list[str]) -> list[str]:
"""Validate that a list of items are valid codes
Returns
Expand All @@ -140,9 +140,9 @@ def validate_items(self, items: List[str]) -> List[str]:

@classmethod
def replace_tags(
cls, code_list: List[Code], tag_name: str, tags: List[Code]
) -> List[Code]:
_code_list: List[Code] = []
cls, code_list: list[Code], tag_name: str, tags: list[Code]
) -> list[Code]:
_code_list: list[Code] = []

for code in code_list:
if "{" + tag_name + "}" in code.name:
Expand All @@ -155,15 +155,15 @@ def replace_tags(
@classmethod
def _parse_and_replace_tags(
cls,
code_list: List[Code],
code_list: list[Code],
path: Path,
file_glob_pattern: str = "**/*",
) -> List[Code]:
) -> list[Code]:
"""Cast, validate and replace tags into list of codes for one dimension
Parameters
----------
code_list : List[Code]
code_list : list[Code]
List of Code to modify
path : :class:`pathlib.Path` or path-like
Directory with the codelist files
Expand All @@ -173,10 +173,10 @@ def _parse_and_replace_tags(
Returns
-------
Dict[str, Code] :class: `nomenclature.Code`
dict[str, Code] :class: `nomenclature.Code`
"""
tag_dict: Dict[str, List[Code]] = {}
tag_dict: dict[str, list[Code]] = {}

for yaml_file in (
f
Expand Down Expand Up @@ -240,7 +240,7 @@ def from_directory(
)
code_list.extend(repo.filter_list_of_codes(repository_code_list))
errors = ErrorCollector()
mapping: Dict[str, Code] = {}
mapping: dict[str, Code] = {}
for code in code_list:
if code.name in mapping:
errors.append(
Expand Down Expand Up @@ -301,7 +301,7 @@ def _parse_codelist_dir(
file_glob_pattern: str = "**/*",
repository: str | None = None,
):
code_list: List[Code] = []
code_list: list[Code] = []
for yaml_file in (
f
for f in path.glob(file_glob_pattern)
Expand Down Expand Up @@ -457,7 +457,7 @@ def to_excel(
with pd.ExcelWriter(excel_writer, **kwargs) as writer:
write_sheet(writer, sheet_name, self.to_pandas(sort_by_code))

def codelist_repr(self, json_serialized=False) -> Dict:
def codelist_repr(self, json_serialized=False) -> dict:
"""Cast a CodeList into corresponding dictionary"""

nice_dict = {}
Expand Down Expand Up @@ -590,7 +590,7 @@ def check_weight_in_vars(cls, v):
)
return v

def vars_default_args(self, variables: List[str]) -> List[VariableCode]:
def vars_default_args(self, variables: list[str]) -> list[VariableCode]:
"""return subset of variables which does not feature any special pyam
aggregation arguments and where skip_region_aggregation is False"""
return [
Expand All @@ -599,7 +599,7 @@ def vars_default_args(self, variables: List[str]) -> List[VariableCode]:
if not self[var].agg_kwargs and not self[var].skip_region_aggregation
]

def vars_kwargs(self, variables: List[str]) -> List[VariableCode]:
def vars_kwargs(self, variables: list[str]) -> list[VariableCode]:
# return subset of variables which features special pyam aggregation arguments
# and where skip_region_aggregation is False
return [
Expand Down Expand Up @@ -713,7 +713,7 @@ def from_directory(
"""

code_list: List[RegionCode] = []
code_list: list[RegionCode] = []

# initializing from general configuration
# adding all countries
Expand Down Expand Up @@ -763,7 +763,7 @@ def from_directory(
)

# translate to mapping
mapping: Dict[str, RegionCode] = {}
mapping: dict[str, RegionCode] = {}

errors = ErrorCollector()
for code in code_list:
Expand All @@ -783,12 +783,12 @@ def from_directory(
return cls(name=name, mapping=mapping)

@property
def hierarchy(self) -> List[str]:
def hierarchy(self) -> list[str]:
"""Return the hierarchies defined in the RegionCodeList
Returns
-------
List[str]
list[str]
"""
return sorted(list({v.hierarchy for v in self.mapping.values()}))
Expand All @@ -799,9 +799,9 @@ def _parse_region_code_dir(
path: Path,
file_glob_pattern: str = "**/*",
repository: str | None = None,
) -> List[RegionCode]:
) -> list[RegionCode]:
""""""
code_list: List[RegionCode] = []
code_list: list[RegionCode] = []
for yaml_file in (
f
for f in path.glob(file_glob_pattern)
Expand Down
Loading

0 comments on commit a32eca9

Please sign in to comment.