
Commit

generalize
dorian-K committed Mar 7, 2025
1 parent dacc342 commit 4536e91
users/dorian_koch/datasets/MixingDataset.py (72 changes: 46 additions, 26 deletions)
@@ -32,6 +32,7 @@ def get(self, i: int) -> bool:
     def __len__(self):
         return self.size
 
+THandleEndOfData = Literal["exception", "wrap_around", "early_exit"]
 
 class MixingDataset(CachedDataset2):
     """
@@ -58,9 +59,8 @@ def __init__(
         left_dataset: Dict[str, Any],
         right_dataset: Dict[str, Any],
         mixing_ratio: float = 0.5,
-        how_to_handle_end_of_data_from_one_dataset: Union[
-            Literal["exception"], Literal["wrap_around"], Literal["early_exit"]
-        ] = "wrap_around",
+        how_to_handle_end_of_data: Optional[List[THandleEndOfData]] = None,
+        how_to_handle_end_of_data_from_one_dataset: Optional[THandleEndOfData] = "wrap_around",  # deprecated
         *,
         data_key: str = "data",
         control_dataset: str = "left",
@@ -71,7 +71,7 @@ def __init__(
         :param right_dataset:
         :param mixing_ratio: probability to choose the right dataset
         :param data_key: key for the mixing process; mixing considers the size of the data
-        :param how_to_handle_end_of_data_from_one_dataset: what to do when one dataset is exhausted
+        :param how_to_handle_end_of_data_from_one_dataset: (deprecated) what to do when one dataset is exhausted
             exception: raise an exception; this should practically never be used in training
             wrap_around: wrap around to the beginning of the exhausted dataset; terminate when both datasets have been exhausted at least once
             early_exit: end the epoch when one dataset has been exhausted
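For illustration, the two ways of selecting a policy after this change. This is a hypothetical usage sketch, not from the commit: `left_cfg`/`right_cfg` are placeholder child-dataset configs, while the keyword arguments are the ones from the diff.

```python
left_cfg = {"class": "SomeDataset"}   # hypothetical child-dataset configs
right_cfg = {"class": "SomeDataset"}

# Deprecated: a single policy applied to both child datasets.
ds = MixingDataset(left_dataset=left_cfg, right_dataset=right_cfg,
                   how_to_handle_end_of_data_from_one_dataset="wrap_around")

# Generalized: one policy per child dataset, e.g. wrap the left dataset
# around but end the epoch as soon as the right one is exhausted.
# Note: the deprecated kwarg defaults to "wrap_around", so it must be set
# to None explicitly before the new list is accepted.
ds = MixingDataset(left_dataset=left_cfg, right_dataset=right_cfg,
                   how_to_handle_end_of_data_from_one_dataset=None,
                   how_to_handle_end_of_data=["wrap_around", "early_exit"])
```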
@@ -80,7 +80,14 @@ def __init__(
         super().__init__(**kwargs)
         assert 0.0 <= mixing_ratio <= 1.0
         self.mixing_ratio = mixing_ratio
-        self.how_to_handle_end_of_data_from_one_dataset = how_to_handle_end_of_data_from_one_dataset
+        if how_to_handle_end_of_data_from_one_dataset is not None:
+            assert how_to_handle_end_of_data is None
+            how_to_handle_end_of_data = [how_to_handle_end_of_data_from_one_dataset] * 2
+        else:
+            assert how_to_handle_end_of_data is not None
+            assert len(how_to_handle_end_of_data) == 2
+
+        self.how_to_handle_end_of_data = how_to_handle_end_of_data
         self.left_dataset = init_dataset(left_dataset, parent_dataset=self)
         self.right_dataset = init_dataset(right_dataset, parent_dataset=self)
         self.control_dataset = self.left_dataset if control_dataset == "left" else self.right_dataset
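The constructor now folds the deprecated scalar argument into the new per-dataset list. A minimal standalone sketch of that normalization, with the same types as in the diff (the helper name is ours, not the class's):

```python
from typing import List, Literal, Optional

THandleEndOfData = Literal["exception", "wrap_around", "early_exit"]


def normalize_policies(
    new: Optional[List[THandleEndOfData]],
    deprecated: Optional[THandleEndOfData],
) -> List[THandleEndOfData]:
    """Fold the deprecated scalar option into the per-dataset list."""
    if deprecated is not None:
        # Old call sites: one policy, applied to both child datasets.
        assert new is None, "pass either the new list or the deprecated scalar, not both"
        return [deprecated] * 2
    assert new is not None and len(new) == 2, "need exactly one policy per child dataset"
    return list(new)


assert normalize_policies(None, "wrap_around") == ["wrap_around", "wrap_around"]
assert normalize_policies(["wrap_around", "early_exit"], None) == ["wrap_around", "early_exit"]
```

Folding the old scalar into a two-element list up front keeps all downstream code on a single representation.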
@@ -99,18 +106,25 @@ def _reset_params(self):
         assert not (0 < self.right_dataset.num_seqs < 10) and not (
             0 < self.left_dataset.num_seqs < 10
         ), "mixing can go wrong when one dataset has very few seqs"
-        # left finishes first
-        lff = self.right_dataset.num_seqs * (1 + (1 - self.mixing_ratio) / (self.mixing_ratio))
-        # right finishes first
-        rff = self.left_dataset.num_seqs * (1 + (self.mixing_ratio) / (1 - self.mixing_ratio))
-        if self.how_to_handle_end_of_data_from_one_dataset in ["exception", "early_exit"]:
-            assert 0.0 < self.mixing_ratio < 1.0, "not implemented"
-            self.total_num_seqs_upper_bound = math.ceil(min(lff, rff))  # only one needs to finish
-        elif self.how_to_handle_end_of_data_from_one_dataset == "wrap_around":
-            assert 0.0 < self.mixing_ratio < 1.0, "not implemented"
-            self.total_num_seqs_upper_bound = math.ceil(max(lff, rff))  # both need to finish
-        else:
-            assert False
+        # left terminates epoch
+        lff = self.left_dataset.num_seqs * (1 + (self.mixing_ratio) / (1 - self.mixing_ratio))
+        # right terminates epoch
+        rff = self.right_dataset.num_seqs * (1 + (1 - self.mixing_ratio) / (self.mixing_ratio))
+        finish_seqs_arr = [lff, rff]
+
+        assert len(self.how_to_handle_end_of_data) == 2, "this logic goes wrong with more than two datasets"
+
+        if all(how in ["exception", "early_exit"] for how in self.how_to_handle_end_of_data):
+            # epoch terminates if any dataset finishes
+            self.total_num_seqs_upper_bound = math.ceil(min(finish_seqs_arr))
+        elif all(how == "wrap_around" for how in self.how_to_handle_end_of_data):
+            # epoch terminates if all datasets finish
+            self.total_num_seqs_upper_bound = math.ceil(max(finish_seqs_arr))
+        else:  # mix
+            for i, how in enumerate(self.how_to_handle_end_of_data):
+                if how == "wrap_around":
+                    finish_seqs_arr[i] = float("inf")
+            self.total_num_seqs_upper_bound = math.ceil(min(finish_seqs_arr))
 
         assert not math.isnan(self.total_num_seqs_upper_bound) and not math.isinf(self.total_num_seqs_upper_bound)
         # for good measure
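To make the bound concrete, here is a small mirror of the computation with assumed sizes (1000 left seqs, 500 right, mixing_ratio 0.5, i.e. each draw picks the right dataset with probability 0.5). Note that `num_left * (1 + p/(1-p))` simplifies to `num_left / (1-p)`, the expected number of draws until the left dataset alone is exhausted:

```python
import math


def upper_bound(num_left, num_right, mixing_ratio, policies):
    """Simplified mirror of the _reset_params computation above."""
    lff = num_left * (1 + mixing_ratio / (1 - mixing_ratio))    # == num_left / (1 - mixing_ratio)
    rff = num_right * (1 + (1 - mixing_ratio) / mixing_ratio)   # == num_right / mixing_ratio
    finish = [lff, rff]
    if all(p in ("exception", "early_exit") for p in policies):
        return math.ceil(min(finish))   # epoch ends when any dataset finishes
    if all(p == "wrap_around" for p in policies):
        return math.ceil(max(finish))   # epoch ends when all datasets finish
    for i, p in enumerate(policies):
        if p == "wrap_around":
            finish[i] = float("inf")    # a wrapping dataset never ends the epoch
    return math.ceil(min(finish))


assert upper_bound(1000, 500, 0.5, ["early_exit", "early_exit"]) == 1000
assert upper_bound(1000, 500, 0.5, ["wrap_around", "wrap_around"]) == 2000
assert upper_bound(1000, 500, 0.5, ["wrap_around", "early_exit"]) == 1000
```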
@@ -202,7 +216,7 @@ def _run_seq_idx(self, seq_idx):
         ), "This assert fails only when the two datasets are very unbalanced, in the sense that one dataset has many long sequences while the other mostly has shorter ones. Keep them at equal lengths on average, please! Otherwise you need to somehow increase this upper bound (which will not cause issues, just eat more RAM)"
         if self.is_chooser_done:
             raise Exception(
-                "chooser is done. change attribute 'how_to_handle_end_of_data_from_one_dataset' to 'exception' if you want to know why (probably because early_exit)"
+                "chooser is done. change attribute 'how_to_handle_end_of_data' to 'exception' if you want to know why (probably because early_exit)"
             )
 
         child_lens = [self.left_dataset.num_seqs, self.right_dataset.num_seqs]
@@ -229,23 +243,23 @@ def _run_seq_idx(self, seq_idx):
                 f"MixingDataset: optimal mixing ratio = {(self.datalens[1] / c1) / max(1, self.datalens[0]/c0 + self.datalens[1]/c1)} (assuming uniform random distribution)",
                 file=log.v4,
             )
-            if self.how_to_handle_end_of_data_from_one_dataset == "exception":
+            if self.how_to_handle_end_of_data[dataset_index] == "exception":
                 self.is_chooser_done = True
                 raise Exception("MixingDataset: end of dataset %d %r" % (dataset_index, chosen_dataset))
-            elif self.how_to_handle_end_of_data_from_one_dataset == "early_exit":
+            elif self.how_to_handle_end_of_data[dataset_index] == "early_exit":
                 # the last decision is invalid (beyond the end of the dataset),
                 # but hopefully the other functions notice that we exited early and don't use the decision ...
                 self.is_chooser_done = True
                 break
-            elif self.how_to_handle_end_of_data_from_one_dataset == "wrap_around":
+            elif self.how_to_handle_end_of_data[dataset_index] == "wrap_around":
                 # I'm not sure of the logic inside the child datasets and whether they keep already-loaded data around indefinitely,
                 # so just start loading them from the beginning again
                 if all(self.datasets_exhausted):
                     self.is_chooser_done = True
                     break
                 # the modulo operator below will wrap around
             else:
-                assert False, f"{self.how_to_handle_end_of_data_from_one_dataset} not implemented"
+                assert False, f"{self.how_to_handle_end_of_data[dataset_index]} not implemented"
 
         self._make_sure_idx_is_loaded_in_child_ds(
             dataset_index, self.chooser_childindices[dataset_index] % child_lens[dataset_index]
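A self-contained toy version of the chooser loop's exhaustion handling. Plain random choice stands in for the actual size-based mixing, and every name outside the policy strings is ours, not the class's:

```python
import random


def simulate_chooser(sizes, mixing_ratio, policies, seed=0):
    """Toy chooser: pick a child dataset per step and apply its
    end-of-data policy once the pick runs past that dataset's end."""
    rng = random.Random(seed)
    indices = [0, 0]                  # next local seq idx per dataset
    exhausted = [False, False]
    decisions = []                    # (dataset_index, local_seq_idx)
    while True:
        i = 1 if rng.random() < mixing_ratio else 0   # 1 = right dataset
        if indices[i] >= sizes[i] and not exhausted[i]:
            exhausted[i] = True
            if policies[i] == "exception":
                raise Exception(f"end of dataset {i}")
            elif policies[i] == "early_exit":
                break                 # end the epoch right here
            elif policies[i] == "wrap_around":
                if all(exhausted):
                    break             # every dataset finished at least once
                # otherwise the modulo below wraps this dataset around
            else:
                assert False, f"{policies[i]} not implemented"
        decisions.append((i, indices[i] % sizes[i]))
        indices[i] += 1
    return decisions


dec = simulate_chooser([1000, 500], 0.5, ["wrap_around", "early_exit"])
print(len(dec))  # roughly the 1000-seq upper bound computed above
```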
@@ -412,7 +426,13 @@ def get_complete_frac(
             return 1.0  # we are done
         frac_left = indices[0] / self.left_dataset.num_seqs
         frac_right = indices[1] / self.right_dataset.num_seqs
-        if self.how_to_handle_end_of_data_from_one_dataset == "wrap_around":
-            return min(frac_left, frac_right)
-        # "early_exit" or "exception"
-        return min(1.0, max(frac_left, frac_right))
+        fracs = [frac_left, frac_right]
+        if all(how == "wrap_around" for how in self.how_to_handle_end_of_data):
+            return min(fracs)
+        if all(how in ["exception", "early_exit"] for how in self.how_to_handle_end_of_data):
+            return min(1.0, max(fracs))
+        # mix
+        for i, how in enumerate(self.how_to_handle_end_of_data):
+            if how == "wrap_around":
+                fracs[i] = 0.0
+        return min(1.0, max(fracs))
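The same policy-aware progress logic as a standalone function, with a few spot checks on assumed numbers (positions 500/1000 left, 100/500 right):

```python
def complete_frac(indices, sizes, policies):
    """Fraction of the epoch done, given per-dataset positions and policies."""
    fracs = [indices[i] / sizes[i] for i in range(2)]
    if all(p == "wrap_around" for p in policies):
        # Both wrap: the epoch ends when the slower dataset has finished once.
        return min(fracs)
    if all(p in ("exception", "early_exit") for p in policies):
        # The epoch ends as soon as the faster dataset finishes.
        return min(1.0, max(fracs))
    # Mixed: only the non-wrapping datasets can end the epoch.
    fracs = [0.0 if p == "wrap_around" else f for p, f in zip(policies, fracs)]
    return min(1.0, max(fracs))


assert complete_frac([500, 100], [1000, 500], ["wrap_around", "wrap_around"]) == 0.2
assert complete_frac([500, 100], [1000, 500], ["early_exit", "early_exit"]) == 0.5
assert complete_frac([500, 100], [1000, 500], ["wrap_around", "early_exit"]) == 0.2
```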
