Skip to content

Commit

Permalink
feat: add sample function (#447)
Browse files Browse the repository at this point in the history
* add sample function

* `eval_only` supports summarization
  • Loading branch information
Dobiichi-Origami authored Apr 15, 2024
1 parent 683aef4 commit b4379c4
Show file tree
Hide file tree
Showing 4 changed files with 90 additions and 4 deletions.
2 changes: 1 addition & 1 deletion python/qianfan/dataset/data_source/chunk_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def __init__(
self.file_path = file_path
self.fd = open(file_path, mode="r", encoding=encoding())

self.ijson_object = ijson.items(self.fd, element_json_path)
self.ijson_object = ijson.items(self.fd, element_json_path, use_float=True)

def _get_an_element(self, index: int) -> Any:
return next(self.ijson_object)
Expand Down
32 changes: 30 additions & 2 deletions python/qianfan/dataset/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -1028,6 +1028,36 @@ def take_slice(
"""
return super().take_slice(start, end, should_create_new_obj, **kwargs)

@_online_except_decorator
def sample(
self,
sample_number: int,
start: int = 0,
end: int = -1,
should_create_new_obj: bool = False,
**kwargs: Any,
) -> Self:
"""
take random slice in dataset
Args:
sample_number (int):
how many entries should be sampled
start (int):
where the sample part starts
end (int):
where the sample part ends
should_create_new_obj (bool):
should a new object be created when mapping terminates.
Default to False. In some cases, you may want to set
this value to True
**kwargs (Any):
other arguments
"""
return super().sample(
sample_number, start, end, should_create_new_obj, **kwargs
)

def __getitem__(self, key: Any) -> Any:
if (
isinstance(key, int)
Expand Down Expand Up @@ -2020,8 +2050,6 @@ def show_processed_statistics(
MeanMethod(),
MinMethod(),
MaxMethod(),
QuantileMethod(q=0.2),
QuantileMethod(q=0.5),
QuantileMethod(q=0.8),
QuantileMethod(q=0.9),
]
Expand Down
45 changes: 45 additions & 0 deletions python/qianfan/dataset/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
"""
wrapper for pyarrow.Table
"""
import random
from typing import (
Any,
Callable,
Expand Down Expand Up @@ -519,6 +520,38 @@ def take_slice(self, start: int = 0, end: int = -1) -> PyarrowTable:

return self.table.slice(start, end - start + 1)

def sample(
self,
sample_number: int,
start: int = 0,
end: int = -1,
) -> PyarrowTable:
if start < 0:
err_msg = f"start index is smaller than 0: {start}"
log_error(err_msg)
raise ValueError(err_msg)

if end >= self.table.num_rows:
err_msg = (
f"end index {end} is bigger than table size: {self.table.num_rows}"
)
log_error(err_msg)
raise ValueError(err_msg)

if end < 0:
end = self.table.num_rows - 1

if sample_number < 0:
err_msg = f"can't sample {sample_number} entries"
log_error(err_msg)
raise ValueError(err_msg)

if sample_number >= self.table.num_rows:
return self.table

numbers = random.sample(range(start, end + 1), sample_number)
return self.table.take(numbers)


class _PyarrowColumnManipulator(BaseModel, Addable, Listable, Processable):
"""handler for processing of pyarrow table column"""
Expand Down Expand Up @@ -1193,6 +1226,18 @@ def take_slice(
result_ds = manipulator.take_slice(start, end)
return self._create_new_obj(result_ds, should_create_new_obj)

def sample(
self,
sample_number: int,
start: int = 0,
end: int = -1,
should_create_new_obj: bool = False,
**kwargs: Any,
) -> Self:
manipulator = self._row_op()
result_ds = manipulator.sample(sample_number, start, end)
return self._create_new_obj(result_ds, should_create_new_obj)

def col_map(
self,
op: Callable[[Any], Any],
Expand Down
15 changes: 14 additions & 1 deletion python/qianfan/evaluation/evaluation_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -421,7 +421,20 @@ def eval_only(
tmp_ds = Dataset.create_from_pyobj(
self._run_evaluator_locally(dataset, **kwargs)
)
return EvaluationResult(result_dataset=dataset.col_append(tmp_ds.col_list()))

assert self.local_evaluators

summarization_dict: Dict[str, Any] = {}

for evaluator in self.local_evaluators:
summarization = evaluator.summarize(tmp_ds)
if summarization:
summarization_dict.update(summarization)

return EvaluationResult(
metrics=summarization_dict,
result_dataset=dataset.col_append(tmp_ds.col_list()),
)

def eval(
self,
Expand Down

0 comments on commit b4379c4

Please sign in to comment.