
Commit

fix lints
XiaohanZhangCMU committed Feb 5, 2024
1 parent 70d8e8f commit feee52d
Showing 2 changed files with 23 additions and 26 deletions.
9 changes: 7 additions & 2 deletions streaming/base/util.py
@@ -13,8 +13,8 @@
import tempfile
import urllib.parse
from collections import OrderedDict
from multiprocessing.shared_memory import SharedMemory as BuiltinSharedMemory
from multiprocessing import Pool
from multiprocessing.shared_memory import SharedMemory as BuiltinSharedMemory
from pathlib import Path
from time import sleep, time
from typing import Any, Callable, List, Sequence, Tuple, Type, TypeVar, Union, cast, overload
@@ -264,6 +264,7 @@ def _download_url(url_info):
return f'Failed to download index.json: {src} to {dest}: {str(ex)}', ex
return dest, None


def _merge_partition_indices(partition_indices):
"""Function to be executed by each process to merge a subset of partition indices."""
shards = []
@@ -279,12 +280,15 @@ def _merge_partition_indices(partition_indices):
shards.extend(obj['shards'])
return shards


def _parallel_merge_partitions(partitions, n_processes=4):
"""Divide the list of partitions among multiple processes and merge them in parallel."""
with Pool(processes=n_processes) as pool:
# Split the list of partitions into N chunks where N is the number of processes
chunk_size = len(partitions) // n_processes + (len(partitions) % n_processes > 0)
partition_chunks = [partitions[i:i + chunk_size] for i in range(0, len(partitions), chunk_size)]
partition_chunks = [
partitions[i:i + chunk_size] for i in range(0, len(partitions), chunk_size)
]

# Process each chunk in parallel
results = pool.map(_merge_partition_indices, partition_chunks)
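
The chunk-size expression in the hunk above is a ceiling division: each worker receives at most `chunk_size` partitions, and the last chunk absorbs whatever remains. A minimal standalone sketch of that arithmetic, using made-up partition names rather than anything from the diff:

    # 10 partitions over 4 processes: 10 // 4 == 2, plus 1 because 10 % 4 > 0, so chunk_size == 3
    partitions = [f'part_{i}' for i in range(10)]  # hypothetical names
    n_processes = 4
    chunk_size = len(partitions) // n_processes + (len(partitions) % n_processes > 0)
    chunks = [partitions[i:i + chunk_size] for i in range(0, len(partitions), chunk_size)]
    print([len(c) for c in chunks])  # [3, 3, 3, 1] -> one chunk per Pool worker

Each chunk is then handed to `_merge_partition_indices` via `pool.map`, and every worker returns the `shards` entries it collected from its subset of partition index files.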
@@ -386,6 +390,7 @@ def _merge_index_from_list(index_file_urls: Sequence[Union[str, Tuple[str, str]]
if not keep_local:
shutil.rmtree(cu.local, ignore_errors=True)


def _not_merged_index(index_file_path: str, out: str):
"""Check if index_file_path is the merged index at folder out.
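
The body of `_not_merged_index` is cut off by the diff context; going only by the name and docstring, the check presumably answers whether a given index file is anything other than the merged `index.json` that lives directly under `out`. A hedged sketch of that idea (an assumption for illustration, not the actual body in `streaming/base/util.py`):

    import os

    def _not_merged_index(index_file_path: str, out: str):
        # Assumed behaviour: True when the file is NOT <out>/index.json.
        merged = os.path.join(out, 'index.json')  # index.json is the basename as of 0.6.0
        return os.path.abspath(index_file_path) != os.path.abspath(merged)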
40 changes: 16 additions & 24 deletions tests/test_util.py
@@ -7,7 +7,7 @@
import time
import urllib.parse
from multiprocessing.shared_memory import SharedMemory as BuiltinSharedMemory
from typing import List, Optional, Tuple, Union, Sequence
from typing import List, Optional, Sequence, Tuple, Union

import pytest

@@ -194,9 +194,9 @@ def test_format_remote_index_files(scheme: str):
assert obj.scheme == scheme


@pytest.mark.parametrize('index_file_urls_pattern', [1]) # , 2, 3])
@pytest.mark.parametrize('keep_local', [True]) # , False])
@pytest.mark.parametrize('scheme', ['gs://']) # , 's3://', 'oci://'])
@pytest.mark.parametrize('index_file_urls_pattern', [1]) # , 2, 3])
@pytest.mark.parametrize('keep_local', [True]) # , False])
@pytest.mark.parametrize('scheme', ['gs://']) # , 's3://', 'oci://'])
def test_merge_index_from_list_local(local_remote_dir: Tuple[str, str], keep_local: bool,
index_file_urls_pattern: int, scheme: str):
"""Validate the final merge index json for following patterns of index_file_urls:
@@ -206,14 +206,12 @@ def test_merge_index_from_list_local(local_remote_dir: Tuple[str, str], keep_loc
4. All URLs are tuple (local, remote). At least one url is not accessible locally -> download all
5. All URLs are str (remote) -> download all
"""
from decimal import Decimal
import random
import string

from pyspark.sql import SparkSession
from pyspark.sql.types import DecimalType, IntegerType, StringType, StructField, StructType

from streaming.base.converters import dataframeToMDS
import random
import string

def not_merged_index(index_file_path: str, out: str):
"""Check if index_file_path is the merged index at folder out."""
@@ -229,12 +227,12 @@ def not_merged_index(index_file_path: str, out: str):
def random_string(length=1000):
"""Generate a random string of fixed length."""
letters = string.ascii_letters + string.digits + string.punctuation + ' '
return ''.join(random.choice(letters) for i in range(length))
return ''.join(random.choice(letters) for _ in range(length))

# Generate a DataFrame with 10000 rows of random text
num_rows = 100
data = [(i, random_string(),random_string()) for i in range(num_rows)]
df = spark.createDataFrame(data, ["id", "name", "amount"])
data = [(i, random_string(), random_string()) for i in range(num_rows)]
df = spark.createDataFrame(data, ['id', 'name', 'amount'])

mds_kwargs = {'out': mds_out, 'columns': {'id': 'int64', 'name': 'str'}, 'keep_local': True}
dataframeToMDS(df, merge_index=False, mds_kwargs=mds_kwargs)
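
For readers of the docstring above: patterns 1-3 sit outside this hunk, but the two visible patterns come down to the shape of each entry in `index_file_urls`. An illustrative example (bucket and paths are invented for the illustration, not taken from the test):

    # Pattern 5: every entry is a plain remote str -> all index files must be downloaded.
    index_file_urls = [
        'gs://mybucket/dataset/0/index.json',
        'gs://mybucket/dataset/1/index.json',
    ]

    # Pattern 4: every entry is a (local, remote) tuple, but at least one local copy
    # is missing -> the merge likewise falls back to downloading everything.
    index_file_urls = [
        ('/tmp/dataset/0/index.json', 'gs://mybucket/dataset/0/index.json'),
        ('/tmp/does_not_exist/index.json', 'gs://mybucket/dataset/1/index.json'),
    ]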
@@ -343,27 +341,24 @@ def flaky_function():
def _merge_index_from_list_serial(index_file_urls: Sequence[Union[str, Tuple[str, str]]],
out: Union[str, Tuple[str, str]],
keep_local: bool = True,
download_timeout: int = 60,
merge = False) -> None:
download_timeout: int = 60) -> None:
import logging
import shutil
import urllib.parse
from collections import OrderedDict
from pathlib import Path

from streaming.base.format.index import get_index_basename
from streaming.base.storage.download import download_file
from streaming.base.storage.upload import CloudUploader
from streaming.base.util import _not_merged_index, _format_remote_index_files
from streaming.base.format.index import get_index_basename
from collections import OrderedDict
import logging

if not index_file_urls or not out:
logger.warning('Either index_file_urls or out are None. ' +
'Need to specify both `index_file_urls` and `out`. ' + 'No index merged')
return

# This is the index json file name, e.g., it is index.json as of 0.6.0
index_basename = get_index_basename()

print('i am here 1.1')
cu = CloudUploader.get(out, keep_local=True, exist_ok=True)
print('i am here 1.2')

# Remove duplicates, and strip '/' from right if any
index_file_urls = list(OrderedDict.fromkeys(index_file_urls))
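
`OrderedDict.fromkeys` above is an order-preserving de-duplication; combined with the right-strip the comment mentions, the pattern behaves like this small sketch (values are hypothetical):

    from collections import OrderedDict

    urls = ['gs://b/0/index.json/', 'gs://b/1/index.json', 'gs://b/0/index.json/']
    urls = list(OrderedDict.fromkeys(urls))  # drop duplicates, keep first-seen order
    urls = [u.rstrip('/') for u in urls]     # strip a trailing '/' if any
    print(urls)  # ['gs://b/0/index.json', 'gs://b/1/index.json']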
@@ -404,9 +399,6 @@ def _merge_index_from_list_serial(index_file_urls: Sequence[Union[str, Tuple[str

partitions.append(dest)

if not merge:
return

# merge shards from all index files
shards = []
for partition_index in partitions:
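
The diff cuts off inside this final loop, but the serial merge it starts follows the same shape as `_merge_partition_indices` earlier in the commit: read each downloaded partition index, concatenate the `shards` entries, and write one merged index under the uploader's local folder. A hedged sketch of that remaining step (the exact body in `tests/test_util.py` may differ, and the `version` field is an assumption about the index format):

    import json
    import os

    # merge shards from all index files
    shards = []
    for partition_index in partitions:
        with open(partition_index, 'r') as f:
            obj = json.load(f)
        shards.extend(obj['shards'])

    # write the merged index into the CloudUploader's local directory
    merged_index = {'version': 2, 'shards': shards}
    with open(os.path.join(cu.local, index_basename), 'w') as f:
        json.dump(merged_index, f)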
