Fix get_complete_frac monotonicity
dorian-K committed Mar 1, 2025
1 parent e1b1516 commit ad56ad5
Showing 1 changed file with 72 additions and 59 deletions.
users/dorian_koch/datasets/MixingDataset.py (131 changes: 72 additions & 59 deletions)
@@ -47,7 +47,7 @@ class MixingDataset(CachedDataset2):
This means that, under some configurations, an epoch of one dataset may be seen many times.
    If this is problematic, maybe wrap it in a MultiEpochDataset? (does it support num_seqs? idk)
    # TODO I overcomplicated some things in the design of this,
    1. I hyper-optimized for memory usage, which makes the code very messy
    2. Because of 1, this doesn't scale well at all inside a MultiProcDataset
3. This supports random access, but I had to hack some stuff together because apparently other Datasets don't support that?
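    (Illustration, assuming equal average sequence lengths: with mixing_ratio=0.5, a left
    dataset of 100 seqs and a right dataset of 10 seqs, the right dataset is cycled
    roughly 10 times per MixingDataset epoch under "wrap_around".)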
@@ -58,7 +58,9 @@ def __init__(
left_dataset: Dict[str, Any],
right_dataset: Dict[str, Any],
mixing_ratio: float = 0.5,
        how_to_handle_end_of_data_from_one_dataset: Union[
            Literal["exception"], Literal["wrap_around"], Literal["early_exit"]
        ] = "wrap_around",
*,
data_key: str = "data",
control_dataset: str = "left",
@@ -82,7 +84,9 @@ def __init__(
self.left_dataset = init_dataset(left_dataset, parent_dataset=self)
self.right_dataset = init_dataset(right_dataset, parent_dataset=self)
self.control_dataset = self.left_dataset if control_dataset == "left" else self.right_dataset
        self.num_inputs = make_hashable(
            self.left_dataset.num_inputs
        )  # make_hashable normalizes lists/tuples to just tuples
self.num_outputs = make_hashable(self.left_dataset.num_outputs)
self.labels = self.left_dataset.labels
self.data_key = data_key
@@ -92,17 +96,19 @@ def __init__(
self._reset_params()

def _reset_params(self):
        assert not (0 < self.right_dataset.num_seqs < 10) and not (
            0 < self.left_dataset.num_seqs < 10
        ), "mixing can go wrong when one dataset has very few seqs"
# left finishes first
        lff = self.right_dataset.num_seqs * (1 + (1 - self.mixing_ratio) / (self.mixing_ratio))
# right finishes first
        rff = self.left_dataset.num_seqs * (1 + (self.mixing_ratio) / (1 - self.mixing_ratio))
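        # Sketch of where these bounds come from (inferred from the formulas above): if roughly
        # a fraction `mixing_ratio` of the chosen sequences comes from the right dataset, the
        # right dataset runs out after about right.num_seqs / mixing_ratio total sequences,
        # i.e. right.num_seqs * (1 + (1 - mixing_ratio) / mixing_ratio); `rff` is the
        # symmetric estimate for the left dataset running out first.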
if self.how_to_handle_end_of_data_from_one_dataset in ["exception", "early_exit"]:
assert 0.0 < self.mixing_ratio < 1.0, "not implemented"
            self.total_num_seqs_upper_bound = math.ceil(min(lff, rff))  # only one needs to finish
elif self.how_to_handle_end_of_data_from_one_dataset == "wrap_around":
assert 0.0 < self.mixing_ratio < 1.0, "not implemented"
            self.total_num_seqs_upper_bound = math.ceil(max(lff, rff))  # both need to finish
else:
assert False

@@ -112,23 +118,24 @@ def _reset_params(self):
self.total_num_seqs_upper_bound *= 2

if self.total_num_seqs_upper_bound > 0:
print(f"MixingDataset init: {self.left_dataset.num_seqs} + {self.right_dataset.num_seqs}, upperbound={self.total_num_seqs_upper_bound}, mixingratio={self.mixing_ratio}", file=log.v4)
print(
f"MixingDataset init: {self.left_dataset.num_seqs} + {self.right_dataset.num_seqs}, upperbound={self.total_num_seqs_upper_bound}, mixingratio={self.mixing_ratio}",
file=log.v4,
)
else:
print("MixingDataset init: both datasets are empty", file=log.v4)
self._estimated_num_seqs = self.total_num_seqs_upper_bound
assert self.total_num_seqs_upper_bound < 2**31, "sequences do not fit into int32"
# 0 means left, 1 means right
self.bitset_chooser = Bitarray(self.total_num_seqs_upper_bound)
# cache indices to both datasets at a particular sequence index
        self.index_cache = numpy.zeros(((self.total_num_seqs_upper_bound + 1023) // 1024, 2), dtype=numpy.int32)
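        # How the two structures interact (as read from _get_childindices_at_seq_idx below):
        # one (left, right) index pair is checkpointed per 1024 sequence indices, and the
        # per-sequence left/right decisions in `bitset_chooser` are replayed from the nearest
        # checkpoint, so random access does not have to replay the chooser from sequence 0.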
# up until which point we have chosen
self.chooser_index = 0
self.is_chooser_done = False
self.chooser_childindices = [0, 0]
self.datasets_exhausted = [False, False]
        self.datasets_loaded_until = [0, 0]  # we need to _load_seqs the datasets
        # we will get out of balance while choosing; we correct this by biasing the next choice
self.bias = 0.0
self.datalens = [0, 0]
@@ -141,20 +148,14 @@ def init_seq_order(self, epoch=None, seq_list=None, seq_order=None):
:param list[int]|None seq_order: List of corpus sequence indices, to set a predefined order.
"""
need_reinit = self.epoch is None or self.epoch != epoch
        super().init_seq_order(epoch=epoch, seq_list=seq_list, seq_order=seq_order)
if not need_reinit:
return False

if seq_order is not None:
            raise NotImplementedError("Predefined order via sequence indices for MixingDataset")
if seq_list is not None:
            raise NotImplementedError("Predefined order via sequence tags for MixingDataset")
elif self.seq_ordering != "default":
raise NotImplementedError("seq_ordering %s" % self.seq_ordering)

@@ -196,14 +197,15 @@ def _make_sure_idx_is_loaded_in_child_ds(self, dataset_index, seq_idx):
def _run_seq_idx(self, seq_idx):
if seq_idx < self.chooser_index:
raise Exception("seq_idx < chooser_index")
        assert (
            seq_idx < self.total_num_seqs_upper_bound
        ), "This assert fails only when the two datasets are very unbalanced, in the sense that one dataset has many long sequences while the other mostly has shorter ones. Keep them at equal lengths on average, please! Otherwise you need to somehow increase this upper bound (which will not cause issues, just eat more RAM)"
if self.is_chooser_done:
raise Exception("chooser is done. change attribute 'how_to_handle_end_of_data_from_one_dataset' to 'exception' if you want to know why (probably because early_exit)")
raise Exception(
"chooser is done. change attribute 'how_to_handle_end_of_data_from_one_dataset' to 'exception' if you want to know why (probably because early_exit)"
)

        child_lens = [self.left_dataset.num_seqs, self.right_dataset.num_seqs]
while seq_idx >= self.chooser_index:
# we need to choose more
chooseRight = self.bias >= 0 and self.mixing_ratio > 0
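            # (Reading of the decision rule above: `bias` tracks how far the right dataset
            # lags behind its target share, so a non-negative bias picks right and a negative
            # bias picks left; see the bias update at the end of the loop body.)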
@@ -214,20 +216,24 @@ def _run_seq_idx(self, seq_idx):
dataset_index = 1 if chooseRight else 0
chosen_dataset = self.right_dataset if chooseRight else self.left_dataset

            if (
                self.chooser_childindices[dataset_index] % child_lens[dataset_index] == 0
                and self.chooser_childindices[dataset_index] > 0
            ):
self.datasets_exhausted[dataset_index] = True
print(f"MixingDataset: ({dataset_index}) exhausted", file=log.v4)
self._print_progress()
c0 = self.chooser_childindices[0] / max(1, child_lens[0])
c1 = self.chooser_childindices[1] / max(1, child_lens[1])
print(f"MixingDataset: optimal mixing ratio = {(self.datalens[1] / c1) / max(1, self.datalens[0]/c0 + self.datalens[1]/c1)} (assuming uniform random distribution)", file=log.v4)
print(
f"MixingDataset: optimal mixing ratio = {(self.datalens[1] / c1) / max(1, self.datalens[0]/c0 + self.datalens[1]/c1)} (assuming uniform random distribution)",
file=log.v4,
)
if self.how_to_handle_end_of_data_from_one_dataset == "exception":
self.is_chooser_done = True
                    raise Exception("MixingDataset: end of dataset %d %r" % (dataset_index, chosen_dataset))
elif self.how_to_handle_end_of_data_from_one_dataset == "early_exit":
                    # the last decision is invalid (beyond the end of the dataset),
                    # but hopefully the other functions notice that we exited early and don't use the decision ...
self.is_chooser_done = True
break
@@ -241,21 +247,21 @@ def _run_seq_idx(self, seq_idx):
else:
assert False, f"{self.how_to_handle_end_of_data_from_one_dataset} not implemented"

            self._make_sure_idx_is_loaded_in_child_ds(
                dataset_index, self.chooser_childindices[dataset_index] % child_lens[dataset_index]
            )
            datalen = MixingDataset._data_metric(
                chosen_dataset.get_data(
                    self.chooser_childindices[dataset_index] % child_lens[dataset_index], self.data_key
                )
            )
            # print(f"({dataset_index}) datalen={datalen} shape={data.shape}")
            self.bias -= ((1 - self.mixing_ratio) if chooseRight else -self.mixing_ratio) * max(datalen, 1)
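            # (Sketch of the accounting: a right pick decreases `bias` by
            # (1 - mixing_ratio) * datalen, a left pick increases it by mixing_ratio * datalen;
            # the drift is zero exactly when the right dataset contributes a `mixing_ratio`
            # share of the total data length, so the mix is balanced by length, not by
            # sequence count.)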
self.datalens[dataset_index] += datalen
self.chooser_childindices[dataset_index] += 1
self.chooser_index += 1

        assert not math.isnan(self.bias) and not math.isinf(self.bias)  # this should never ever happen

if self.is_chooser_done:
return None
@@ -271,7 +277,7 @@ def _get_childindices_at_seq_idx(self, seq_idx):
if seq_idx >= self.chooser_index:
ran_ids = self._run_seq_idx(seq_idx)
if seq_idx >= self.chooser_index or ran_ids is None:
                return None  # we could not progress to the desired seq_idx, maybe early exit or exhaustion?

assert self.chooser_index == seq_idx + 1
# reverse last decision to get actual indices
@@ -290,15 +296,15 @@ def _get_childindices_at_seq_idx(self, seq_idx):
restore_from_idx = try_seq
restore_indices = result
break
        # convert to list to avoid changing the index cache elements
restore_indices = list(restore_indices)

# replay the steps
while restore_from_idx < seq_idx:
if self.bitset_chooser.get(restore_from_idx):
                restore_indices[1] += 1  # right
else:
                restore_indices[0] += 1  # left
restore_from_idx += 1

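        # (Note on the modulo below: the replayed counters grow without bound; the modulo
        # folds them back into each child dataset's range, which is what makes "wrap_around"
        # work for random access.)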
return (restore_indices[0] % self.left_dataset.num_seqs, restore_indices[1] % self.right_dataset.num_seqs)
@@ -319,9 +325,7 @@ def _collect_single_seq(self, seq_idx):
dataset = self.left_dataset if dataset_idx == 0 else self.right_dataset
self._make_sure_idx_is_loaded_in_child_ds(dataset_idx, dataset_seq_idx)
seq_tag = dataset.get_tag(dataset_seq_idx)
        features = {k: dataset.get_data(dataset_seq_idx, k) for k in dataset.get_data_keys()}
return DatasetSeq(seq_idx=seq_idx, seq_tag=seq_tag, features=features)

def is_less_than_num_seqs(self, seq_idx: int):
@@ -347,18 +351,24 @@ def get_target_list(self):
:rtype: list[str]
"""
return self.control_dataset.get_target_list()

def get_data_keys(self) -> List[str]:
"""data keys"""
return self.control_dataset.get_data_keys()

def _print_progress(self):
if self.left_dataset.num_seqs > 0:
print(f"MixingDataset: Left dataset: {self.chooser_childindices[0]}/{self.left_dataset.num_seqs} ({self.chooser_childindices[0] / self.left_dataset.num_seqs * 100}%) exhausted={self.datasets_exhausted[0]}, avg_datalen={self.datalens[0]/max(1, self.chooser_childindices[0])}", file=log.v4)
print(
f"MixingDataset: Left dataset: {self.chooser_childindices[0]}/{self.left_dataset.num_seqs} ({self.chooser_childindices[0] / self.left_dataset.num_seqs * 100}%) exhausted={self.datasets_exhausted[0]}, avg_datalen={self.datalens[0]/max(1, self.chooser_childindices[0])}",
file=log.v4,
)
else:
print("MixingDataset: Left dataset: empty", file=log.v4)
if self.right_dataset.num_seqs > 0:
print(f"MixingDataset: Right dataset: {self.chooser_childindices[1]}/{self.right_dataset.num_seqs} ({self.chooser_childindices[1] / self.right_dataset.num_seqs * 100}%) exhausted={self.datasets_exhausted[1]}, avg_datalen={self.datalens[1]/max(1, self.chooser_childindices[1])}", file=log.v4)
print(
f"MixingDataset: Right dataset: {self.chooser_childindices[1]}/{self.right_dataset.num_seqs} ({self.chooser_childindices[1] / self.right_dataset.num_seqs * 100}%) exhausted={self.datasets_exhausted[1]}, avg_datalen={self.datalens[1]/max(1, self.chooser_childindices[1])}",
file=log.v4,
)
else:
print("MixingDataset: Right dataset: empty", file=log.v4)

@@ -386,16 +396,19 @@ def get_data_dtype(self, key: str) -> str:
def is_data_sparse(self, key: str) -> bool:
"""is data sparse"""
return self.control_dataset.is_data_sparse(key)

    def get_complete_frac(
        self, sorted_seq_idx: int, *, allow_only_lr_suitable: bool = False, **kwargs
    ) -> Optional[float]:
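        # (Intent per the commit title, "Fix get_complete_frac monotonicity": the returned
        # fraction should be non-decreasing in sorted_seq_idx even though an exhausted
        # child's index wraps back to 0; the checks below pin such cases to 1.0.)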
assert self.left_dataset.num_seqs > 0 and self.right_dataset.num_seqs > 0
indices = self._get_childindices_at_seq_idx(sorted_seq_idx)
if indices is None:
            return 1.0  # we are done
frac_left = indices[0] / self.left_dataset.num_seqs
frac_right = indices[1] / self.right_dataset.num_seqs
if self.how_to_handle_end_of_data_from_one_dataset == "wrap_around":
return min(frac_left, frac_right)
# "early_exit" or "exception"
if any([self.datasets_exhausted[i] and indices[i] == 0 for i in range(2)]):
            return 1.0  # the exhausted dataset's index wrapped back to 0, so we just return 1.0
return max(frac_left, frac_right)
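
A minimal usage sketch (not part of this commit; the HDF file names and the 0.3 ratio are made-up placeholders, and the import path simply mirrors this repository's layout). Per the bias accounting in _run_seq_idx, mixing_ratio is read here as the right dataset's target share of the total data length:

    from users.dorian_koch.datasets.MixingDataset import MixingDataset

    mixing_dataset = MixingDataset(
        left_dataset={"class": "HDFDataset", "files": ["left.hdf"]},
        right_dataset={"class": "HDFDataset", "files": ["right.hdf"]},
        mixing_ratio=0.3,  # hypothetical value: ~30% right data, ~70% left data
        how_to_handle_end_of_data_from_one_dataset="wrap_around",
    )
    mixing_dataset.init_seq_order(epoch=1)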
