Skip to content

Commit

Permalink
Finally, JSON -> JSONL.
Browse files Browse the repository at this point in the history
  • Loading branch information
knighton committed Dec 15, 2023
1 parent 7cf0ef3 commit 3b742c0
Show file tree
Hide file tree
Showing 13 changed files with 162 additions and 132 deletions.
4 changes: 2 additions & 2 deletions benchmarks/backends/write.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from wurlitzer import pipes

from benchmarks.backends.datagen import generate
from streaming import CSVWriter, JSONWriter, MDSWriter
from streaming import CSVWriter, JSONLWriter, MDSWriter
from streaming.util.tabulation import Tabulator


Expand Down Expand Up @@ -108,7 +108,7 @@ def _write_jsonl(nums: List[int],
'num': 'int',
'txt': 'str',
}
with JSONWriter(out=root, columns=columns, size_limit=size_limit) as out:
with JSONLWriter(out=root, columns=columns, size_limit=size_limit) as out:
each_sample = zip(nums, txts)
if show_progress:
each_sample = tqdm(each_sample, total=len(nums), leave=False)
Expand Down
4 changes: 2 additions & 2 deletions benchmarks/samples/bench_and_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from numpy.typing import DTypeLike, NDArray
from tqdm import trange

from streaming import CSVWriter, JSONWriter, MDSWriter, StreamingDataset
from streaming import CSVWriter, JSONLWriter, MDSWriter, StreamingDataset


def parse_args() -> Namespace:
Expand Down Expand Up @@ -244,7 +244,7 @@ def bench(args: Namespace, bench_name: str, desc: str, generate: Callable,

format_infos = [
('mds', MDSWriter, args.mds_color),
('jsonl', JSONWriter, args.jsonl_color),
('jsonl', JSONLWriter, args.jsonl_color),
('csv', CSVWriter, args.csv_color),
]
format_infos = list(filter(lambda info: info[0] in formats, format_infos))
Expand Down
6 changes: 3 additions & 3 deletions streaming/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,12 @@
from streaming._version import __version__
from streaming.dataloader import StreamingDataLoader
from streaming.dataset import StreamingDataset
from streaming.format import CSVWriter, JSONWriter, MDSWriter, TSVWriter, XSVWriter
from streaming.format import CSVWriter, JSONLWriter, MDSWriter, TSVWriter, XSVWriter
from streaming.local import LocalDataset
from streaming.stream import Stream
from streaming.util import clean_stale_shared_memory

__all__ = [
'StreamingDataLoader', 'Stream', 'StreamingDataset', 'CSVWriter', 'JSONWriter', 'LocalDataset',
'MDSWriter', 'TSVWriter', 'XSVWriter', 'clean_stale_shared_memory'
'StreamingDataLoader', 'Stream', 'StreamingDataset', 'CSVWriter', 'JSONLWriter',
'LocalDataset', 'MDSWriter', 'TSVWriter', 'XSVWriter', 'clean_stale_shared_memory'
]
21 changes: 17 additions & 4 deletions streaming/format/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,25 +6,38 @@
from typing import Any, Dict, Optional

from streaming.format.index import get_index_basename
from streaming.format.json import JSONShard, JSONWriter
from streaming.format.jsonl import JSONLShard, JSONLWriter
from streaming.format.mds import MDSShard, MDSWriter
from streaming.format.shard import FileInfo, Shard
from streaming.format.xsv import CSVShard, CSVWriter, TSVShard, TSVWriter, XSVShard, XSVWriter

__all__ = [
'CSVWriter', 'FileInfo', 'get_index_basename', 'JSONWriter', 'MDSWriter', 'Shard',
'CSVWriter', 'FileInfo', 'get_index_basename', 'JSONLWriter', 'MDSWriter', 'Shard',
'shard_from_json', 'TSVWriter', 'XSVWriter'
]

# Mapping of shard metadata dict "format" field to what type of Shard it is.
_shards = {
'csv': CSVShard,
'json': JSONShard,
'jsonl': JSONLShard,
'mds': MDSShard,
'tsv': TSVShard,
'xsv': XSVShard,
}


def _get_shard_class(format_name: str) -> Shard:
"""Get the associated Shard class given a Shard format name.
Args:
format_name (str): Shard format name.
"""
# JSONL shards were originally called JSON shards (while containing JSONL).
if format_name == 'json':
format_name = 'jsonl'
return _shards[format_name]


def shard_from_json(dirname: str, split: Optional[str], obj: Dict[str, Any]) -> Shard:
"""Create a shard from a JSON config.
Expand All @@ -37,5 +50,5 @@ def shard_from_json(dirname: str, split: Optional[str], obj: Dict[str, Any]) ->
Shard: The loaded Shard.
"""
assert obj['version'] == 2
cls = _shards[obj['format']]
cls = _get_shard_class(obj['format'])
return cls.from_json(dirname, split, obj)
9 changes: 0 additions & 9 deletions streaming/format/json/__init__.py

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -7,22 +7,22 @@ Example:
"words": "str"
},
"compression": "zstd:7",
"format": "json",
"format": "jsonl",
"hashes": [
"sha1",
"xxh3_64"
],
"newline": "\n",
"raw_data": {
"basename": "shard.00000.json",
"basename": "shard.00000.jsonl",
"bytes": 1048546,
"hashes": {
"sha1": "bfb6509ba6f041726943ce529b36a1cb74e33957",
"xxh3_64": "0eb102a981b299eb"
}
},
"raw_meta": {
"basename": "shard.00000.json.meta",
"basename": "shard.00000.jsonl.meta",
"bytes": 53590,
"hashes": {
"sha1": "15ae80e002fe625b0b18f1a45058532ee867fa9b",
Expand All @@ -33,15 +33,15 @@ Example:
"size_limit": 1048576,
"version": 2,
"zip_data": {
"basename": "shard.00000.json.zstd",
"basename": "shard.00000.jsonl.zstd",
"bytes": 149268,
"hashes": {
"sha1": "7d45c600a71066ca8d43dbbaa2ffce50a91b735e",
"xxh3_64": "3d338d4826d4b5ac"
}
},
"zip_meta": {
"basename": "shard.00000.json.meta.zstd",
"basename": "shard.00000.jsonl.meta.zstd",
"bytes": 42180,
"hashes": {
"sha1": "f64477cca5d27fc3a0301eeb4452ef7310cbf670",
Expand Down
9 changes: 9 additions & 0 deletions streaming/format/jsonl/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# Copyright 2023 MosaicML Streaming authors
# SPDX-License-Identifier: Apache-2.0

"""Streaming JSONL shards."""

from streaming.format.jsonl.shard import JSONLShard
from streaming.format.jsonl.writer import JSONLWriter

__all__ = ['JSONLShard', 'JSONLWriter']
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
# Copyright 2023 MosaicML Streaming authors
# SPDX-License-Identifier: Apache-2.0

"""Check whether sample encoding is of supported JSON types."""
"""Check whether sample encoding is of supported JSONL types."""

from abc import ABC, abstractmethod
from typing import Any

__all__ = ['is_json_encoded', 'is_json_encoding']
__all__ = ['is_jsonl_encoded', 'is_jsonl_encoding']


class Encoding(ABC):
"""Encoding of an object of JSON type."""
"""Encoding of an object of JSONL type."""

@classmethod
@abstractmethod
Expand Down Expand Up @@ -60,7 +60,7 @@ def is_encoded(cls, obj: Any) -> bool:
_encodings = {'str': Str, 'int': Int, 'float': Float}


def is_json_encoded(encoding: str, value: Any) -> bool:
def is_jsonl_encoded(encoding: str, value: Any) -> bool:
"""Get whether the given object is of this encoding type.
Args:
Expand All @@ -74,7 +74,7 @@ def is_json_encoded(encoding: str, value: Any) -> bool:
return cls.is_encoded(value)


def is_json_encoding(encoding: str) -> bool:
def is_jsonl_encoding(encoding: str) -> bool:
"""Get whether the given encoding is supported.
Args:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Copyright 2023 MosaicML Streaming authors
# SPDX-License-Identifier: Apache-2.0

"""Streaming JSON shard reading."""
"""Streaming JSONL shard reading."""

import json
import os
Expand All @@ -13,11 +13,11 @@

from streaming.format.shard import DualShard, FileInfo

__all__ = ['JSONShard']
__all__ = ['JSONLShard']


class JSONShard(DualShard):
"""Provides random access to the samples of a JSON shard.
class JSONLShard(DualShard):
"""Provides random access to the samples of a JSONL shard.
Args:
dirname (str): Local dataset directory.
Expand Down Expand Up @@ -68,7 +68,7 @@ def from_json(cls, dirname: str, split: Optional[str], obj: Dict[str, Any]) -> S
obj (Dict[str, Any]): JSON object to load.
Returns:
Self: Loaded JSONShard.
Self: Loaded JSONLShard.
"""
args = deepcopy(obj)
# Version check.
Expand All @@ -77,9 +77,9 @@ def from_json(cls, dirname: str, split: Optional[str], obj: Dict[str, Any]) -> S
f'Expected version 2.')
del args['version']
# Check format.
if args['format'] != 'json':
raise ValueError(f'Unsupported data format: {args["format"]}. ' +
f'Expected to be `json`.')
if args['format'] not in {'json', 'jsonl'}:
raise ValueError(f'Unsupported data format: got {args["format"]}, but expected ' +
f'"jsonl" (or "json").')
del args['format']
args['dirname'] = dirname
args['split'] = split
Expand Down
Original file line number Diff line number Diff line change
@@ -1,21 +1,21 @@
# Copyright 2023 MosaicML Streaming authors
# SPDX-License-Identifier: Apache-2.0

"""Streaming JSON shard writing."""
"""Streaming JSONL shard writing."""

import json
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np

from streaming.format.json.encodings import is_json_encoded, is_json_encoding
from streaming.format.jsonl.encodings import is_jsonl_encoded, is_jsonl_encoding
from streaming.format.writer import DualWriter

__all__ = ['JSONWriter']
__all__ = ['JSONLWriter']


class JSONWriter(DualWriter):
r"""Writes a streaming JSON dataset.
class JSONLWriter(DualWriter):
r"""Writes a streaming JSONL dataset.
Args:
columns (Dict[str, str]): Sample columns.
Expand Down Expand Up @@ -47,7 +47,7 @@ class JSONWriter(DualWriter):
file to a remote location. Default to ``min(32, (os.cpu_count() or 1) + 4)``.
"""

format = 'json'
format = 'jsonl'

def __init__(self,
*,
Expand All @@ -66,7 +66,7 @@ def __init__(self,
size_limit=size_limit,
**kwargs)
for encoding in columns.values():
assert is_json_encoding(encoding)
assert is_jsonl_encoding(encoding)

self.columns = columns
self.newline = newline
Expand All @@ -83,7 +83,7 @@ def encode_sample(self, sample: Dict[str, Any]) -> bytes:
obj = {}
for key, encoding in self.columns.items():
value = sample[key]
assert is_json_encoded(encoding, value)
assert is_jsonl_encoded(encoding, value)
obj[key] = value
text = json.dumps(obj, sort_keys=True) + self.newline
return text.encode('utf-8')
Expand Down
17 changes: 17 additions & 0 deletions streaming/format/shard.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
from dataclasses import dataclass
from typing import Any, Dict, Iterator, List, Optional, Set, Union

from typing_extensions import Self

from streaming.array import Array
from streaming.util.shorthand import normalize_bytes

Expand Down Expand Up @@ -61,6 +63,21 @@ def __init__(

self.file_pairs = []

@abstractmethod
@classmethod
def from_json(cls, dirname: str, split: Optional[str], obj: Dict[str, Any]) -> Self:
"""Initialize from JSON object.
Args:
dirname (str): Local directory containing shards.
split (str, optional): Which dataset split to use, if any.
obj (Dict[str, Any]): JSON object to load.
Returns:
Self: Loaded Shard.
"""
raise NotImplementedError

def validate(self, allow_unsafe_types: bool) -> None:
"""Check whether this shard is acceptable to be part of some Stream.
Expand Down
Loading

0 comments on commit 3b742c0

Please sign in to comment.