diff --git a/streaming/base/util.py b/streaming/base/util.py
index 6c3437299..1804b373c 100644
--- a/streaming/base/util.py
+++ b/streaming/base/util.py
@@ -13,8 +13,8 @@
 import tempfile
 import urllib.parse
 from collections import OrderedDict
-from multiprocessing.shared_memory import SharedMemory as BuiltinSharedMemory
 from multiprocessing import Pool
+from multiprocessing.shared_memory import SharedMemory as BuiltinSharedMemory
 from pathlib import Path
 from time import sleep, time
 from typing import Any, Callable, List, Sequence, Tuple, Type, TypeVar, Union, cast, overload
@@ -264,6 +264,7 @@ def _download_url(url_info):
         return f'Failed to download index.json: {src} to {dest}: {str(ex)}', ex
     return dest, None

+
 def _merge_partition_indices(partition_indices):
     """Function to be executed by each process to merge a subset of partition indices."""
     shards = []
@@ -279,12 +280,15 @@ def _merge_partition_indices(partition_indices):
         shards.extend(obj['shards'])
     return shards

+
 def _parallel_merge_partitions(partitions, n_processes=4):
     """Divide the list of partitions among multiple processes and merge them in parallel."""
     with Pool(processes=n_processes) as pool:
         # Split the list of partitions into N chunks where N is the number of processes
         chunk_size = len(partitions) // n_processes + (len(partitions) % n_processes > 0)
-        partition_chunks = [partitions[i:i + chunk_size] for i in range(0, len(partitions), chunk_size)]
+        partition_chunks = [
+            partitions[i:i + chunk_size] for i in range(0, len(partitions), chunk_size)
+        ]

         # Process each chunk in parallel
         results = pool.map(_merge_partition_indices, partition_chunks)
@@ -386,6 +390,7 @@ def _merge_index_from_list(index_file_urls: Sequence[Union[str, Tuple[str, str]]
         if not keep_local:
             shutil.rmtree(cu.local, ignore_errors=True)

+
 def _not_merged_index(index_file_path: str, out: str):
     """Check if index_file_path is the merged index at folder out.

diff --git a/tests/test_util.py b/tests/test_util.py
index ab024fb05..b05da8612 100644
--- a/tests/test_util.py
+++ b/tests/test_util.py
@@ -7,7 +7,7 @@
 import time
 import urllib.parse
 from multiprocessing.shared_memory import SharedMemory as BuiltinSharedMemory
-from typing import List, Optional, Tuple, Union, Sequence
+from typing import List, Optional, Sequence, Tuple, Union

 import pytest

@@ -194,9 +194,9 @@ def test_format_remote_index_files(scheme: str):
     assert obj.scheme == scheme


-@pytest.mark.parametrize('index_file_urls_pattern', [1]) # , 2, 3])
-@pytest.mark.parametrize('keep_local', [True]) # , False])
-@pytest.mark.parametrize('scheme', ['gs://']) # , 's3://', 'oci://'])
+@pytest.mark.parametrize('index_file_urls_pattern', [1])  # , 2, 3])
+@pytest.mark.parametrize('keep_local', [True])  # , False])
+@pytest.mark.parametrize('scheme', ['gs://'])  # , 's3://', 'oci://'])
 def test_merge_index_from_list_local(local_remote_dir: Tuple[str, str], keep_local: bool,
                                      index_file_urls_pattern: int, scheme: str):
     """Validate the final merge index json for following patterns of index_file_urls:
@@ -206,14 +206,12 @@ def test_merge_index_from_list_local(local_remote_dir: Tuple[str, str], keep_loc
     4. All URLs are tuple (local, remote). At least one url is not accessible locally -> download all
     5. All URLs are str (remote) -> download all
     """
-    from decimal import Decimal
+    import random
+    import string

     from pyspark.sql import SparkSession
-    from pyspark.sql.types import DecimalType, IntegerType, StringType, StructField, StructType

     from streaming.base.converters import dataframeToMDS
-    import random
-    import string

     def not_merged_index(index_file_path: str, out: str):
         """Check if index_file_path is the merged index at folder out."""
@@ -229,12 +227,12 @@ def not_merged_index(index_file_path: str, out: str):
     def random_string(length=1000):
         """Generate a random string of fixed length."""
         letters = string.ascii_letters + string.digits + string.punctuation + ' '
-        return ''.join(random.choice(letters) for i in range(length))
+        return ''.join(random.choice(letters) for _ in range(length))

     # Generate a DataFrame with 10000 rows of random text
     num_rows = 100
-    data = [(i, random_string(),random_string()) for i in range(num_rows)]
-    df = spark.createDataFrame(data, ["id", "name", "amount"])
+    data = [(i, random_string(), random_string()) for i in range(num_rows)]
+    df = spark.createDataFrame(data, ['id', 'name', 'amount'])

     mds_kwargs = {'out': mds_out, 'columns': {'id': 'int64', 'name': 'str'}, 'keep_local': True}
     dataframeToMDS(df, merge_index=False, mds_kwargs=mds_kwargs)
@@ -343,27 +341,24 @@ def flaky_function():
 def _merge_index_from_list_serial(index_file_urls: Sequence[Union[str, Tuple[str, str]]],
                                   out: Union[str, Tuple[str, str]],
                                   keep_local: bool = True,
-                                  download_timeout: int = 60,
-                                  merge = False) -> None:
+                                  download_timeout: int = 60) -> None:
+    import logging
+    import shutil
     import urllib.parse
+    from collections import OrderedDict
+    from pathlib import Path
+
+    from streaming.base.format.index import get_index_basename
     from streaming.base.storage.download import download_file
     from streaming.base.storage.upload import CloudUploader
-    from streaming.base.util import _not_merged_index, _format_remote_index_files
-    from streaming.base.format.index import get_index_basename
-    from collections import OrderedDict
-    import logging

     if not index_file_urls or not out:
-        logger.warning('Either index_file_urls or out are None. ' +
-                       'Need to specify both `index_file_urls` and `out`. ' + 'No index merged')
         return

     # This is the index json file name, e.g., it is index.json as of 0.6.0
     index_basename = get_index_basename()

-    print('i am here 1.1')
     cu = CloudUploader.get(out, keep_local=True, exist_ok=True)
-    print('i am here 1.2')

     # Remove duplicates, and strip '/' from right if any
     index_file_urls = list(OrderedDict.fromkeys(index_file_urls))
@@ -404,9 +399,6 @@ def _merge_index_from_list_serial(index_file_urls: Sequence[Union[str, Tuple[str

         partitions.append(dest)

-    if not merge:
-        return
-
     # merge shards from all index files
     shards = []
     for partition_index in partitions:
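Reviewer note on the new `_parallel_merge_partitions` helper: the chunk-size expression `len(partitions) // n_processes + (len(partitions) % n_processes > 0)` is a ceiling division (the boolean remainder test coerces to 0 or 1), so the list splits into at most `n_processes` chunks, every partition lands in exactly one chunk, and the last chunk absorbs any remainder. A minimal standalone sketch of the same scheme, runnable outside the patch; `merge_chunk` and `parallel_merge` are hypothetical stand-ins, not names from the patch:

    from multiprocessing import Pool

    def merge_chunk(chunk):
        # Stand-in for _merge_partition_indices: each worker merges its own
        # subset of partition index files and returns its shard entries.
        return [f'merged:{partition}' for partition in chunk]

    def parallel_merge(partitions, n_processes=4):
        # Ceiling division: rounds up whenever there is a remainder, so no
        # partition is dropped and at most n_processes chunks are created.
        chunk_size = len(partitions) // n_processes + (len(partitions) % n_processes > 0)
        chunks = [partitions[i:i + chunk_size] for i in range(0, len(partitions), chunk_size)]
        with Pool(processes=n_processes) as pool:
            results = pool.map(merge_chunk, chunks)
        # Flatten the per-process results back into one ordered list.
        return [shard for result in results for shard in result]

    if __name__ == '__main__':
        # 10 partitions across 4 processes -> chunk_size 3 -> chunks of 3, 3, 3, 1.
        print(parallel_merge([f'part{i}' for i in range(10)]))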
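For context on what each worker does with its chunk: `_merge_partition_indices` opens every partition's index.json and extends a running `shards` list from `obj['shards']`. The sketch below shows the shape of that serial merge-and-write step under the assumption (based on the existing MDS index format, not established by this patch) that a merged index is `{'version': 2, 'shards': [...]}`; `merge_index_files` and `out_dir` are illustrative names:

    import json
    import os
    from typing import List

    def merge_index_files(partition_index_paths: List[str], out_dir: str) -> str:
        """Combine per-partition index.json files into one merged index."""
        shards = []
        for path in partition_index_paths:
            with open(path) as f:
                obj = json.load(f)
            # Each partition index contributes its shard entries in order.
            shards.extend(obj['shards'])
        merged = {'version': 2, 'shards': shards}  # assumed MDS index layout
        merged_path = os.path.join(out_dir, 'index.json')
        with open(merged_path, 'w') as f:
            json.dump(merged, f)
        return merged_path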