refactor is_valid_dataset to raise exceptions, separate out s3 bucket and client fixtures
rxu17 committed Jun 24, 2024
1 parent d4f6a52 commit 13dc8f8
Showing 3 changed files with 190 additions and 155 deletions.
47 changes: 18 additions & 29 deletions src/glue/jobs/compare_parquet_datasets.py
@@ -104,7 +104,6 @@ def read_args() -> dict:
             "cfn-bucket",
         ],
     )
-
     for arg in args:
         validate_args(args[arg])
     return args
@@ -617,32 +616,30 @@ def add_additional_msg_to_comparison_report(
     return updated_comparison_report
 
 
-def is_valid_dataset(dataset: pd.DataFrame, namespace: str) -> dict:
+def check_for_valid_dataset(dataset: pd.DataFrame, namespace: str) -> None:
     """Checks whether the individual dataset is valid under the following criteria:
         - no duplicated columns
-        - dataset is not empty (aka has columns)
+        - dataset is not empty
     before it can go through the comparison
     Args:
         dataset (pd.DataFrame): dataset to be validated
         namespace (str): namespace for the dataset
-    Returns:
-        dict: containing boolean of the validation result and string message
+    Raises:
+        ValueError: When dataset is empty (no columns, no rows or no rows, columns)
+        ValueError: When dataset has duplicated columns
     """
-    # Check that datasets have no emptiness, duplicated columns, or have columns in common
-    if len(dataset.columns) == 0:
-        msg = f"{namespace} dataset has no data. Comparison cannot continue."
-        return {"result": False, "msg": msg}
+    if dataset.empty:
+        raise ValueError(
+            f"The {namespace} dataset is empty. Comparison cannot continue."
+        )
     elif get_duplicated_columns(dataset):
-        msg = (
+        raise ValueError(
             f"{namespace} dataset has duplicated columns. Comparison cannot continue.\n"
             f"Duplicated columns:{str(get_duplicated_columns(dataset))}"
         )
-        return {"result": False, "msg": msg}
-    else:
-        msg = f"{namespace} dataset has been validated."
-        return {"result": True, "msg": msg}
 
 
 def compare_datasets_by_data_type(
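
Because the refactored check_for_valid_dataset raises instead of returning a result/msg dict, its failure modes can be exercised directly with pytest.raises. A minimal test sketch, assuming the module is importable as src.glue.jobs.compare_parquet_datasets (the import path and test names below are illustrative, not part of this commit):

import pandas as pd
import pytest

# Import path assumed from the file listing above; adjust to the project's packaging.
from src.glue.jobs.compare_parquet_datasets import check_for_valid_dataset


def test_check_for_valid_dataset_raises_on_empty_dataset():
    # pd.DataFrame() has no rows and no columns, so dataset.empty is True
    with pytest.raises(ValueError, match="dataset is empty"):
        check_for_valid_dataset(pd.DataFrame(), "staging")


def test_check_for_valid_dataset_raises_on_duplicated_columns():
    # two columns share the name "col_a", which get_duplicated_columns should flag
    duplicated = pd.DataFrame([[1, 2]], columns=["col_a", "col_a"])
    with pytest.raises(ValueError, match="duplicated columns"):
        check_for_valid_dataset(duplicated, "staging")
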
@@ -667,6 +664,9 @@ def compare_datasets_by_data_type(
         s3_filesystem (fs.S3FileSystem): filesystem instantiated by aws credentials
         data_type (str): data type to be compared for the given datasets
+    Raises:
+        ValueError: When the staging and main datasets have no columns in common
     Returns:
         dict:
             compare_obj: the datacompy.Compare obj on the two datasets
@@ -697,23 +697,12 @@
         s3_filesystem=s3_filesystem,
     )
     # go through specific validation for each dataset prior to comparison
-    staging_is_valid_result = is_valid_dataset(staging_dataset, staging_namespace)
-    main_is_valid_result = is_valid_dataset(main_dataset, main_namespace)
-    if (
-        staging_is_valid_result["result"] == False
-        or main_is_valid_result["result"] == False
-    ):
-        comparison_report = (
-            f"{staging_is_valid_result['msg']}\n{main_is_valid_result['msg']}"
-        )
-        compare = None
+    check_for_valid_dataset(staging_dataset, staging_namespace)
+    check_for_valid_dataset(main_dataset, main_namespace)
+
     # check that they have columns in common to compare
-    elif not has_common_cols(staging_dataset, main_dataset):
-        comparison_report = (
-            f"{staging_namespace} dataset and {main_namespace} dataset have no columns in common."
-            f" Comparison cannot continue."
-        )
-        compare = None
+    if not has_common_cols(staging_dataset, main_dataset):
+        raise ValueError("Datasets have no common columns to merge on.")
     else:
         logger.info(
             f"{staging_namespace} dataset memory usage:"
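
With this hunk, compare_datasets_by_data_type no longer builds a comparison_report string and a compare = None sentinel for invalid inputs, so callers handle a single exception path instead of branching on the returned dict. A minimal sketch of that calling pattern, assuming only the function signatures visible in this diff (the validate_for_comparison helper and the toy DataFrames are hypothetical):

import logging

import pandas as pd

# Import path assumed from the file listing above; adjust to the project's packaging.
from src.glue.jobs.compare_parquet_datasets import (
    check_for_valid_dataset,
    has_common_cols,
)

logger = logging.getLogger(__name__)


def validate_for_comparison(
    staging_dataset: pd.DataFrame,
    main_dataset: pd.DataFrame,
    staging_namespace: str,
    main_namespace: str,
) -> None:
    """Hypothetical helper mirroring the checks the refactored
    compare_datasets_by_data_type performs before comparing."""
    check_for_valid_dataset(staging_dataset, staging_namespace)
    check_for_valid_dataset(main_dataset, main_namespace)
    if not has_common_cols(staging_dataset, main_dataset):
        raise ValueError("Datasets have no common columns to merge on.")


try:
    # Toy frames with no shared columns, so the common-columns check fails.
    validate_for_comparison(
        pd.DataFrame({"a": [1]}), pd.DataFrame({"b": [2]}), "staging", "main"
    )
except ValueError as err:
    # Previously this text was returned inside the comparison report with
    # compare set to None; now the caller decides how to surface it.
    logger.error(str(err))
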
13 changes: 9 additions & 4 deletions tests/conftest.py
@@ -50,14 +50,19 @@ def dataset_fixture(request):
 
 
 @pytest.fixture
-def mock_s3_bucket():
+def mock_s3_environment(mock_s3_bucket):
+    """This allows us to persist the bucket and s3 client
+    """
     with mock_s3():
         s3 = boto3.client('s3', region_name='us-east-1')
-        bucket_name = 'test-bucket'
-        s3.create_bucket(Bucket=bucket_name)
-        yield s3, bucket_name
+        s3.create_bucket(Bucket=mock_s3_bucket)
+        yield s3
+
+
+@pytest.fixture
+def mock_s3_bucket():
+    bucket_name = 'test-bucket'
+    yield bucket_name
 
 
 @pytest.fixture()
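
The split fixtures can be requested independently or together; when a test asks for both, pytest resolves mock_s3_bucket once, so the client yielded by mock_s3_environment already has that bucket created. A short usage sketch (the test name and object key are hypothetical, not part of this commit):

def test_can_write_and_list_objects(mock_s3_environment, mock_s3_bucket):
    # mock_s3_environment yields a moto-backed boto3 S3 client with the
    # bucket named by mock_s3_bucket already created.
    mock_s3_environment.put_object(
        Bucket=mock_s3_bucket,
        Key="staging/parquet/dataset_example/part0.parquet",  # hypothetical key
        Body=b"dummy bytes",
    )
    response = mock_s3_environment.list_objects_v2(Bucket=mock_s3_bucket)
    assert response["KeyCount"] == 1
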