Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Generalize HF datasets to a collection of HF datasets via hf_datasets #1090

Closed
wants to merge 11 commits into from
22 changes: 20 additions & 2 deletions llms/mlx_lm/LORA.md
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,7 @@ it on the command line. For example, pass `--data mlx-community/wikisql` to
train on the pre-formatted WikiSQL data.

Otherwise, provide a mapping of keys in the dataset to the features MLX LM
expects. Use a YAML config to specify the Hugging Face dataset arguments. For
expects. Use a YAML config to specify the Hugging Face (HF) dataset arguments. For
example:

```
Expand All @@ -279,11 +279,29 @@ hf_dataset:

- Use `prompt_feature` and `completion_feature` to specify keys for a
`completions` dataset. Use `text_feature` to specify the key for a `text`
dataset.
dataset. Use `chat_feature` to specify the key for a chat dataset.

- To specify the train, valid, or test splits, set the corresponding
`{train,valid,test}_split` argument.

You can specify a list of HF datasets with the `hf_datasets` (plural) key, whose value is a list of
records, each with the same structure as a single `hf_dataset` entry. For example:

```yaml
hf_datasets:
- hf_dataset:
name: "Open-Orca/OpenOrca"
train_split: "train[:90%]"
valid_split: "train[-10%:]"
prompt_feature: "question"
completion_feature: "response"
- hf_dataset:
name: "trl-lib/ultrafeedback_binarized"
train_split: "train[:90%]"
valid_split: "train[-10%:]"
chat_feature: "chosen"
```

- Arguments specified in `config` will be passed as keyword arguments to
[`datasets.load_dataset`](https://huggingface.co/docs/datasets/v2.20.0/en/package_reference/loading_methods#datasets.load_dataset).

Expand Down
165 changes: 144 additions & 21 deletions llms/mlx_lm/tuner/datasets.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import json
from pathlib import Path
from typing import Dict, List
from typing import Callable, Dict, List, Union

from transformers import PreTrainedTokenizer

Expand Down Expand Up @@ -29,12 +29,18 @@ class ChatDataset(Dataset):
https://platform.openai.com/docs/guides/fine-tuning/example-format
"""

def __init__(
    self,
    data: List[Dict[str, str]],
    tokenizer: PreTrainedTokenizer,
    chat_key: str = "messages",
):
    """Wrap chat-formatted records for fine-tuning.

    Args:
        data: Records, each holding a conversation under ``chat_key``
            (and optionally a ``tools`` entry).
        tokenizer: Tokenizer providing ``apply_chat_template``.
        chat_key: Key under which each record stores its messages.
            Defaults to ``"messages"`` for backward compatibility.
    """
    super().__init__(data)
    self._tokenizer = tokenizer
    self._chat_key = chat_key

def __getitem__(self, idx: int):
messages = self._data[idx]["messages"]
messages = self._data[idx][self._chat_key]
text = self._tokenizer.apply_chat_template(
messages,
tools=self._data[idx].get("tools", None),
Expand Down Expand Up @@ -75,6 +81,45 @@ def __getitem__(self, idx: int):
)
return text

def get_prompt_and_completion(self, idx: int):
    """Return the raw (prompt, completion) pair stored at index ``idx``."""
    record = self._data[idx]
    return record[self._prompt_key], record[self._completion_key]


class CompletionsDatasetCollection:
    """Expose a list of datasets as one concatenated, index-addressable dataset.

    Index 0 is the first item of the first dataset; an index past the end of
    one dataset continues into the next. Negative indices count from the end
    of the whole collection.
    """

    def __init__(self, data: List[Union["ChatDataset", "CompletionsDataset"]]):
        self.collection = data

    def __fetch_and_process_item__(self, idx: int, handler_fn: Callable):
        """Locate the sub-dataset containing ``idx`` and apply ``handler_fn``.

        Raises:
            IndexError: if ``idx`` is out of range for the whole collection.
        """
        # Normalize negative indices against the total length. The previous
        # implementation passed them straight to the first sub-dataset,
        # returning the wrong element.
        if idx < 0:
            idx += len(self)
        remaining = idx
        if remaining >= 0:
            for dataset in self.collection:
                size = len(dataset)
                if remaining < size:
                    return handler_fn(dataset, remaining)
                remaining -= size
        raise IndexError(idx)

    def __getitem__(self, idx: int):
        return self.__fetch_and_process_item__(idx, lambda dataset, i: dataset[i])

    def get_prompt_and_completion(self, idx: int):
        return self.__fetch_and_process_item__(
            idx, lambda dataset, i: dataset.get_prompt_and_completion(i)
        )

    def __len__(self):
        # Total size is the sum of all sub-dataset sizes.
        return sum(len(dataset) for dataset in self.collection)


def create_dataset(data, tokenizer: PreTrainedTokenizer = None):
sample = data[0]
Expand Down Expand Up @@ -127,46 +172,124 @@ def load_hf_dataset(data_id: str, tokenizer: PreTrainedTokenizer):
def load_custom_hf_dataset(args, tokenizer: PreTrainedTokenizer):
    """Load custom Hugging Face dataset(s) described in ``args``.

    Reads either ``args.hf_dataset`` (a single dataset record) or
    ``args.hf_datasets`` (a list of ``{"hf_dataset": {...}}`` records).
    Each record names the dataset and maps its feature keys
    (``text_feature``, ``prompt_feature``/``completion_feature``, or
    ``chat_feature``) onto the dataset types MLX LM expects.

    Returns:
        ``(train, valid, test)``. With ``hf_datasets``, each element is a
        ``CompletionsDatasetCollection`` concatenating the corresponding
        split of every configured dataset; splits not requested via
        ``args.train`` / ``args.test`` are empty lists.
    """
    import datasets

    def create_hf_dataset(
        dataset_name,
        text_feature,
        prompt_feature,
        completion_feature,
        chat_feature,
        config,
        split=None,
    ):
        # Load one split and wrap it in the dataset type implied by the
        # configured feature keys. ``config`` is forwarded verbatim to
        # ``datasets.load_dataset`` as keyword arguments.
        ds = datasets.load_dataset(
            dataset_name,
            split=split,
            **config,
        )
        if prompt_feature and completion_feature:
            return CompletionsDataset(ds, tokenizer, prompt_feature, completion_feature)
        elif chat_feature:
            return ChatDataset(ds, tokenizer, chat_key=chat_feature)
        elif text_feature:
            return Dataset(ds, text_key=text_feature)
        else:
            raise ValueError(
                "Specify either a prompt and completion feature, a chat "
                "feature, or a text feature for the Hugging Face dataset."
            )

    def get_hf_custom_features(hf_args):
        # All feature keys are optional; absent ones come back as None.
        return (
            hf_args.get("text_feature"),
            hf_args.get("prompt_feature"),
            hf_args.get("completion_feature"),
            hf_args.get("chat_feature"),
        )

    def get_train_and_valid_splits(hf_args, ds_name):
        # By default carve train/valid out of the dataset's "train" split.
        train_split = hf_args.get("train_split", "train[:80%]")
        valid_split = hf_args.get("valid_split", "train[-10%:]")
        text_f, prompt_f, completion_f, chat_f = get_hf_custom_features(hf_args)
        config = hf_args.get("config", {})
        train = create_hf_dataset(
            dataset_name=ds_name,
            text_feature=text_f,
            prompt_feature=prompt_f,
            completion_feature=completion_f,
            chat_feature=chat_f,
            config=config,
            split=train_split,
        )
        valid = create_hf_dataset(
            dataset_name=ds_name,
            text_feature=text_f,
            prompt_feature=prompt_f,
            completion_feature=completion_f,
            chat_feature=chat_f,
            config=config,
            split=valid_split,
        )
        return train, valid

    def _load_one(hf_args):
        # Shared train/valid/test loading for a single dataset record.
        dataset_name = hf_args["name"]
        print(f"Loading Hugging Face dataset {dataset_name}.")
        text_f, prompt_f, completion_f, chat_f = get_hf_custom_features(hf_args)
        if args.train:
            train, valid = get_train_and_valid_splits(hf_args, dataset_name)
        else:
            train, valid = [], []
        if args.test:
            test = create_hf_dataset(
                dataset_name=dataset_name,
                text_feature=text_f,
                prompt_feature=prompt_f,
                completion_feature=completion_f,
                chat_feature=chat_f,
                config=hf_args.get("config", {}),
                split=hf_args.get("test_split"),
            )
        else:
            test = []
        return train, valid, test

    # getattr with a default so configs that never define the plural key
    # don't raise AttributeError.
    if getattr(args, "hf_datasets", None):
        train_collection = []
        valid_collection = []
        test_collection = []
        for record in args.hf_datasets:
            train, valid, test = _load_one(record["hf_dataset"])
            train_collection.append(train)
            valid_collection.append(valid)
            test_collection.append(test)
        return (
            CompletionsDatasetCollection(train_collection),
            CompletionsDatasetCollection(valid_collection),
            CompletionsDatasetCollection(test_collection),
        )

    return _load_one(args.hf_dataset)


def load_dataset(args, tokenizer: PreTrainedTokenizer):
if getattr(args, "hf_dataset", None) is not None:
if getattr(args, "hf_dataset", None) is not None or getattr(args, "hf_datasets"):
train, valid, test = load_custom_hf_dataset(args, tokenizer)
else:
data_path = Path(args.data)
Expand Down