improve column dtype in flatfile IO
rizac committed Dec 4, 2024
1 parent 7c77a20 commit 37739ac
56 changes: 34 additions & 22 deletions egsim/smtk/flatfile.py
@@ -40,10 +40,10 @@
 def read_flatfile(
         filepath_or_buffer: Union[str, IOBase],
         rename: dict[str, str] = None,
-        dtypes: dict[str, Union[str, list, ColumnDtype, pd.CategoricalDtype]] = None,
+        dtypes: dict[str, Union[str, list]] = None,
         defaults: dict[str, Any] = None,
         csv_sep: str = None,
-        **kwargs) -> pd.DataFrame:
+        **kwargs: dict[str, Any]) -> pd.DataFrame:
     """
     Read a flatfile from either a comma-separated values (CSV) or HDF file,
     returning the corresponding pandas DataFrame.
@@ -56,15 +56,14 @@ def read_flatfile(
         for renaming columns to standard flatfile names, delegating all data type
         checks to this function (see also `dtypes` and `defaults` for info)
     :param dtypes: dict of file column names mapped to user-defined data types, to
-        check and cast data after the data is read. Standard flatfile columns do not
-        need to be present. If they are, the value here will overwrite the default dtype,
+        check and cast column data. Standard flatfile columns should not be present,
+        otherwise the value provided in this dict will overwrite the registered dtype,
         if set. Columns in `dtypes` not present in the file will be ignored.
         Dict values can be either 'int', 'bool', 'float', 'str', 'datetime', 'category',
-        list, pandas `CategoricalDtype`: the last three denote data that can
-        take only a limited amount of possible values and should be mostly used with
-        string data as it might save a lot of memory (with "category", pandas will infer
-        the possible values from the data. In this case, note that with CSV files each
-        category will be of type `str`).
+        list: 'category' and lists denote data that can take only a limited amount of
+        possible values and should be mostly used with string data for saving space
+        (with "category", pandas will infer the possible values from the data. In this
+        case, note that with CSV files each category will be of type `str`).
     :param defaults: dict of file column names mapped to user-defined defaults to
         replace missing values. Because 'int' and 'bool' columns do not support missing
         values, with CSV files a default should be provided (e.g. 0 or False) to avoid
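
For reference, a call using the new signature might look like this minimal sketch (the file name, column names and values are hypothetical illustrations, not part of the commit):

```python
from egsim.smtk.flatfile import read_flatfile

# Hypothetical flatfile and columns, for illustration only.
# 'site_class' takes few string values, so a list (i.e. a categorical
# dtype) saves memory; 'vs30measured' is bool, and CSV bool columns
# cannot hold missing values, hence the default:
dfr = read_flatfile(
    'my_flatfile.csv',
    dtypes={'site_class': ['A', 'B', 'C'], 'vs30measured': 'bool'},
    defaults={'vs30measured': False}
)
```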
@@ -99,8 +98,29 @@
     else:
         kwargs['sep'] = csv_sep
 
-    kwargs.setdefault('dtype', {})
-    kwargs['dtype'] |= dtypes or {}
+    # harmonize dtypes to hold only ColumnDtype enums or pd.CategoricalDtype objects;
+    # also put in kwargs['dtype'] the associated dtypes compatible with `read_csv`:
+    kwargs['dtype'] = kwargs.get('dtype') or {}
+    dtypes_raw = dtypes or {}
+    dtypes = {}
+    for c, v in dtypes_raw.items():
+        if not isinstance(v, str):
+            try:
+                v = pd.CategoricalDtype(v)
+                assert get_dtype_of(v.categories) is not None
+            except (AssertionError, TypeError, ValueError):
+                raise ValueError(f'{c}: categories must be of the same type')
+        else:
+            try:
+                v = ColumnDtype[v]
+            except KeyError:
+                raise ValueError(f'{c}: invalid dtype {v}')
+        dtypes[c] = v
+        # ignore bool, int and date-times: they will be parsed later
+        if v in (ColumnDtype.bool, ColumnDtype.int, ColumnDtype.datetime):
+            continue
+        kwargs['dtype'][c] = v.name if isinstance(v, ColumnDtype) else v
 
     try:
         dfr = pd.read_csv(filepath_or_buffer, **kwargs)
     except ValueError as exc:
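
Worth noting: the `assert get_dtype_of(v.categories) is not None` line does real work here, because pandas itself accepts mixed-type categories. A standalone sketch (plain pandas, no eGSIM code) illustrating why the extra check is needed:

```python
import pandas as pd

# pandas happily builds a CategoricalDtype from mixed-type categories;
# the underlying Index simply falls back to dtype `object`:
mixed = pd.CategoricalDtype(['a', 1, 2.5])
print(mixed.categories.dtype)  # object

# ... and so does a homogeneous str example, so telling the two cases
# apart requires inspecting the category values themselves, which is
# what the assert above delegates to `get_dtype_of`:
homogeneous = pd.CategoricalDtype(['PGA', 'PGV', 'SA'])
print(homogeneous.categories.dtype)  # object too (on pandas 2.x)
```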
Expand Down Expand Up @@ -160,7 +180,7 @@ def _read_csv_get_header(filepath_or_buffer: IOBase, sep=None, **kwargs) -> list

 def validate_flatfile_dataframe(
         dfr: pd.DataFrame,
-        extra_dtypes: dict[str, Union[str, ColumnDtype, pd.CategoricalDtype, list]] = None,  # noqa
+        extra_dtypes: dict[str, Union[ColumnDtype, pd.CategoricalDtype]] = None,  # noqa
         extra_defaults: dict[str, Any] = None,
         mixed_dtype_categorical='raise'):
     """Validate the flatfile dataframe checking data types, conflicting column names,
@@ -170,7 +190,8 @@ def validate_flatfile_dataframe(
     :param dfr: the flatfile, as pandas DataFrame
     :param extra_dtypes: dict of column names mapped to the desired data type.
         Standard flatfile columns should not be present (unless for some reason
-        their dtype must be overwritten)
+        their dtype must be overwritten). pd.CategoricalDtype categories must all
+        be of the same type (this is supposed to have been checked beforehand)
     :param extra_defaults: dict of column names mapped to the desired default value
         to replace missing data. Standard flatfile columns do not need to be present
         (unless for some reason their dtype must be overwritten)
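
Under the new contract, a caller would therefore pass already-normalized values, e.g. (a sketch; the column names are hypothetical, and `read_flatfile` performs this normalization for CSV input):

```python
import pandas as pd

# values must already be ColumnDtype members or pd.CategoricalDtype
# instances; plain strings or lists are no longer parsed here:
extra_dtypes = {
    'my_flag': ColumnDtype.bool,                      # hypothetical column
    'my_class': pd.CategoricalDtype(['A', 'B', 'C'])  # hypothetical column
}
validate_flatfile_dataframe(dfr, extra_dtypes,
                            extra_defaults={'my_flag': False})
```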
@@ -190,15 +211,6 @@
     for col in dfr.columns:
         if col in extra_dtypes:
             xp_dtype = extra_dtypes[col]
-            if not isinstance(xp_dtype, ColumnDtype):
-                try:
-                    xp_dtype = ColumnDtype[extra_dtypes[col]]
-                except KeyError:
-                    try:
-                        xp_dtype = pd.CategoricalDtype(xp_dtype)
-                    except (TypeError, ValueError):
-                        invalid_columns.append(col)
-                        continue
         else:
             xp_dtype = FlatfileMetadata.get_dtype(col)
         if xp_dtype == ColumnDtype.category:
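
The branch removed above parsed strings and lists into `ColumnDtype` / `pd.CategoricalDtype` on the fly; it is now redundant because `read_flatfile` performs that harmonization before validation, so values reaching this loop are already normalized. In short (hypothetical column name, and assuming, consistently with the docstring, that `ColumnDtype` has a `float` member):

```python
extra_dtypes = {'my_col': ColumnDtype.float}  # expected after this commit
# extra_dtypes = {'my_col': 'float'}          # previously parsed here; no longer
```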