From 9df7bbbe3a252af2fa0b538821050dc5418c4312 Mon Sep 17 00:00:00 2001 From: Chime Ogbuji Date: Sun, 3 Nov 2024 19:11:54 -0500 Subject: [PATCH 1/9] Generalize HF datasets to a collection of HF datasets via `datasets`, adds support for custom chat HF datasets (#1088), and fixes (#1087) --- llms/mlx_lm/tuner/datasets.py | 137 +++++++++++++++++++++++++++++----- 1 file changed, 117 insertions(+), 20 deletions(-) diff --git a/llms/mlx_lm/tuner/datasets.py b/llms/mlx_lm/tuner/datasets.py index 20b32effd..a0c16a280 100644 --- a/llms/mlx_lm/tuner/datasets.py +++ b/llms/mlx_lm/tuner/datasets.py @@ -1,6 +1,6 @@ import json from pathlib import Path -from typing import Dict, List +from typing import Dict, List, Union from transformers import PreTrainedTokenizer @@ -29,12 +29,18 @@ class ChatDataset(Dataset): https://platform.openai.com/docs/guides/fine-tuning/example-format """ - def __init__(self, data: List[Dict[str, str]], tokenizer: PreTrainedTokenizer): + def __init__( + self, + data: List[Dict[str, str]], + tokenizer: PreTrainedTokenizer, + chat_key: str = "messages", + ): super().__init__(data) self._tokenizer = tokenizer + self._chat_key = chat_key def __getitem__(self, idx: int): - messages = self._data[idx]["messages"] + messages = self._data[idx][self._chat_key] text = self._tokenizer.apply_chat_template( messages, tools=self._data[idx].get("tools", None), @@ -76,6 +82,29 @@ def __getitem__(self, idx: int): return text +class CompletionsDatasetCollection: + def __init__(self, data: List[Union[ChatDataset, CompletionsDataset]]): + self.collection = data + + def __getitem__(self, idx: int): + item = next(self.collection) + + curr_idx = idx + + while True: + try: + if (curr_idx + 1) < len(item): + return item[curr_idx] + else: + curr_idx -= len(item) + item = next(self.collection) + except StopIteration: + raise IndexError(idx) + + def __len__(self): + return sum(map(len, self.collection)) + + def create_dataset(data, tokenizer: PreTrainedTokenizer = None): sample = 
data[0] @@ -127,14 +156,14 @@ def load_hf_dataset(data_id: str, tokenizer: PreTrainedTokenizer): def load_custom_hf_dataset(args, tokenizer: PreTrainedTokenizer): import datasets - hf_args = args.hf_dataset - dataset_name = hf_args["name"] - print(f"Loading Hugging Face dataset {dataset_name}.") - text_feature = hf_args.get("text_feature") - prompt_feature = hf_args.get("prompt_feature") - completion_feature = hf_args.get("completion_feature") - - def create_hf_dataset(split: str = None): + def create_hf_dataset( + dataset_name: Union[None, str], + text_feature: Union[None, str], + prompt_feature: Union[None, str], + completion_feature: Union[None, str], + chat_feature: Union[None, str], + split: str = None, + ): ds = datasets.load_dataset( dataset_name, split=split, @@ -142,25 +171,93 @@ def create_hf_dataset(split: str = None): ) if prompt_feature and completion_feature: return CompletionsDataset(ds, tokenizer, prompt_feature, completion_feature) + elif chat_feature: + return ChatDataset(ds, tokenizer, chat_key=chat_feature) elif text_feature: - return Dataset(train_ds, text_key=text_feature) + return Dataset(ds, text_key=text_feature) else: raise ValueError( "Specify either a prompt and completion feature or a text " "feature for the Hugging Face dataset." 
) - if args.train: + def get_hf_custom_features(hf_args): + return ( + hf_args.get("text_feature"), + hf_args.get("prompt_feature"), + hf_args.get("completion_feature"), + hf_args.get("chat_feature"), + ) + + def get_train_and_valid_splits(hf_args, ds_name): train_split = hf_args.get("train_split", "train[:80%]") valid_split = hf_args.get("valid_split", "train[-10%:]") - train = create_hf_dataset(split=train_split) - valid = create_hf_dataset(split=valid_split) - else: - train, valid = [], [] - if args.test: - test = create_hf_dataset(split=hf_args.get("test_split")) + text_f, prompt_f, completion_f, chat_f = get_hf_custom_features(hf_args) + train = create_hf_dataset( + ds_name, text_f, prompt_f, completion_f, chat_f, split=train_split + ) + valid = create_hf_dataset( + ds_name, text_f, prompt_f, completion_f, chat_f, split=valid_split + ) + return train, valid + + if args.datasets: + dataset_collection = args.hf_datasets + train_collection = [] + valid_collection = [] + test_collection = [] + for ds in dataset_collection: + hf_args = ds["hf_dataset"] + dataset_name = hf_args["name"] + print(f"Loading Hugging Face dataset {dataset_name}.") + text_feature, prompt_feature, completion_feature, chat_f = ( + get_hf_custom_features(hf_args) + ) + if args.train: + train, valid = get_train_and_valid_splits(hf_args, dataset_name) + else: + train, valid = [], [] + if args.test: + test = create_hf_dataset( + dataset_name, + text_feature, + prompt_feature, + completion_feature, + chat_f, + split=hf_args.get("test_split"), + ) + else: + test = [] + train_collection.append(train) + valid_collection.append(valid) + test_collection.append(test) + return ( + CompletionsDatasetCollection(train_collection), + CompletionsDatasetCollection(valid_collection), + CompletionsDatasetCollection(test_collection), + ) else: - test = [] + hf_args = args.hf_dataset + dataset_name = hf_args["name"] + print(f"Loading Hugging Face dataset {dataset_name}.") + text_feature, prompt_feature, 
completion_feature, chat_feature = ( + get_hf_custom_features(hf_args) + ) + if args.train: + train, valid = get_train_and_valid_splits(hf_args, dataset_name) + else: + train, valid = [], [] + if args.test: + test = create_hf_dataset( + dataset_name, + text_feature, + prompt_feature, + completion_feature, + chat_feature, + split=hf_args.get("test_split"), + ) + else: + test = [] return train, valid, test From 1f6c370690ada46df7f1d53b1e3002aad0dbd93d Mon Sep 17 00:00:00 2001 From: Chime Ogbuji Date: Sun, 3 Nov 2024 19:41:09 -0500 Subject: [PATCH 2/9] Updates to LoRA documentation --- llms/mlx_lm/LORA.md | 23 +++++++++++++++++++++-- 1 file changed, 21 insertions(+), 2 deletions(-) diff --git a/llms/mlx_lm/LORA.md b/llms/mlx_lm/LORA.md index 156763607..b67d87b81 100644 --- a/llms/mlx_lm/LORA.md +++ b/llms/mlx_lm/LORA.md @@ -267,7 +267,7 @@ it on the command line. For example, pass `--data mlx-community/wikisql` to train on the pre-formatted WikiwSQL data. Otherwise, provide a mapping of keys in the dataset to the features MLX LM -expects. Use a YAML config to specify the Hugging Face dataset arguments. For +expects. Use a YAML config to specify the Hugging Face (HF) dataset arguments. For example: ``` @@ -279,11 +279,30 @@ hf_dataset: - Use `prompt_feature` and `completion_feature` to specify keys for a `completions` dataset. Use `text_feature` to specify the key for a `text` - dataset. + dataset. Use `chat_feature` to specify the key for a chat dataset. - To specify the train, valid, or test splits, set the corresponding `{train,valid,test}_split` argument. +You can specify a list of HF datasets using the `hf_datasets` (plural) configuration, which is a list of records +each with the same structure as above. 
For example: + +```yaml +hf_datasets: [ + "hf_dataset": + name: "Open-Orca/OpenOrca" + train_split: "train[:90%]" + valid_split: "train[-10%:]" + prompt_feature: "question" + completion_feature: "response", + "hf_dataset": + name: "trl-lib/ultrafeedback_binarized" + train_split: "train[:90%]" + valid_split: "train[-10%:]" + chat_feature: "chosen" +] +``` + - Arguments specified in `config` will be passed as keyword arguments to [`datasets.load_dataset`](https://huggingface.co/docs/datasets/v2.20.0/en/package_reference/loading_methods#datasets.load_dataset). From c72122064af1d6a510203c87d025a81979d55bd4 Mon Sep 17 00:00:00 2001 From: Chime Ogbuji Date: Sun, 3 Nov 2024 20:00:35 -0500 Subject: [PATCH 3/9] Fixes to config format in documentation --- llms/mlx_lm/LORA.md | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/llms/mlx_lm/LORA.md b/llms/mlx_lm/LORA.md index b67d87b81..6de3c10c6 100644 --- a/llms/mlx_lm/LORA.md +++ b/llms/mlx_lm/LORA.md @@ -288,19 +288,18 @@ You can specify a list of HF datasets using the `hf_datasets` (plural) configura each with the same structure as above. 
For example: ```yaml -hf_datasets: [ - "hf_dataset": +hf_datasets: +- hf_dataset: name: "Open-Orca/OpenOrca" train_split: "train[:90%]" valid_split: "train[-10%:]" prompt_feature: "question" - completion_feature: "response", - "hf_dataset": + completion_feature: "response" +- hf_dataset: name: "trl-lib/ultrafeedback_binarized" train_split: "train[:90%]" valid_split: "train[-10%:]" chat_feature: "chosen" -] ``` - Arguments specified in `config` will be passed as keyword arguments to From 04cf93df55c2b578f3e2ffe94f7d2732101fc5b7 Mon Sep 17 00:00:00 2001 From: Chime Ogbuji Date: Sun, 3 Nov 2024 20:04:15 -0500 Subject: [PATCH 4/9] Fixes to references to hf_datasets --- llms/mlx_lm/tuner/datasets.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/llms/mlx_lm/tuner/datasets.py b/llms/mlx_lm/tuner/datasets.py index a0c16a280..3b442c6ab 100644 --- a/llms/mlx_lm/tuner/datasets.py +++ b/llms/mlx_lm/tuner/datasets.py @@ -201,7 +201,7 @@ def get_train_and_valid_splits(hf_args, ds_name): ) return train, valid - if args.datasets: + if args.hf_datasets: dataset_collection = args.hf_datasets train_collection = [] valid_collection = [] @@ -263,7 +263,7 @@ def get_train_and_valid_splits(hf_args, ds_name): def load_dataset(args, tokenizer: PreTrainedTokenizer): - if getattr(args, "hf_dataset", None) is not None: + if getattr(args, "hf_dataset", None) is not None or getattr(args, "hf_datasets"): train, valid, test = load_custom_hf_dataset(args, tokenizer) else: data_path = Path(args.data) From e477060a00e7d2fdc5f6d44f629a6fd17eef684b Mon Sep 17 00:00:00 2001 From: Chime Ogbuji Date: Sun, 3 Nov 2024 20:26:15 -0500 Subject: [PATCH 5/9] Fix keyword argument invocation --- llms/mlx_lm/tuner/datasets.py | 24 +++++++++++++++++------- 1 file changed, 17 insertions(+), 7 deletions(-) diff --git a/llms/mlx_lm/tuner/datasets.py b/llms/mlx_lm/tuner/datasets.py index 3b442c6ab..c75171e5c 100644 --- a/llms/mlx_lm/tuner/datasets.py +++ b/llms/mlx_lm/tuner/datasets.py @@ 
-194,10 +194,20 @@ def get_train_and_valid_splits(hf_args, ds_name): valid_split = hf_args.get("valid_split", "train[-10%:]") text_f, prompt_f, completion_f, chat_f = get_hf_custom_features(hf_args) train = create_hf_dataset( - ds_name, text_f, prompt_f, completion_f, chat_f, split=train_split + dataset_name=ds_name, + text_feature=text_f, + prompt_feature=prompt_f, + completion_feature=completion_f, + chat_feature=chat_f, + split=train_split, ) valid = create_hf_dataset( - ds_name, text_f, prompt_f, completion_f, chat_f, split=valid_split + dataset_name=ds_name, + text_feature=text_f, + prompt_feature=prompt_f, + completion_feature=completion_f, + chat_feature=chat_f, + split=valid_split, ) return train, valid @@ -219,11 +229,11 @@ def get_train_and_valid_splits(hf_args, ds_name): train, valid = [], [] if args.test: test = create_hf_dataset( - dataset_name, - text_feature, - prompt_feature, - completion_feature, - chat_f, + dataset_name=dataset_name, + text_feature=text_feature, + prompt_feature=prompt_feature, + completion_feature=completion_feature, + chat_feature=chat_f, split=hf_args.get("test_split"), ) else: From 24f40c3b8d3fbf2d8da76d69ce1c853bf8449058 Mon Sep 17 00:00:00 2001 From: Chime Ogbuji Date: Sun, 3 Nov 2024 20:30:47 -0500 Subject: [PATCH 6/9] Fix iteration over HF dataset collection --- llms/mlx_lm/tuner/datasets.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/llms/mlx_lm/tuner/datasets.py b/llms/mlx_lm/tuner/datasets.py index c75171e5c..1a86fb9fb 100644 --- a/llms/mlx_lm/tuner/datasets.py +++ b/llms/mlx_lm/tuner/datasets.py @@ -87,7 +87,8 @@ def __init__(self, data: List[Union[ChatDataset, CompletionsDataset]]): self.collection = data def __getitem__(self, idx: int): - item = next(self.collection) + iteration = iter(self.collection) + item = next(iteration) curr_idx = idx @@ -97,7 +98,7 @@ def __getitem__(self, idx: int): return item[curr_idx] else: curr_idx -= len(item) - item = next(self.collection) + item = 
next(iteration) except StopIteration: raise IndexError(idx) From 78b24a2375bb84408b2ca71172e3f0c68afb245e Mon Sep 17 00:00:00 2001 From: Chime Ogbuji Date: Sun, 3 Nov 2024 20:36:55 -0500 Subject: [PATCH 7/9] Fix index calculation --- llms/mlx_lm/tuner/datasets.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/llms/mlx_lm/tuner/datasets.py b/llms/mlx_lm/tuner/datasets.py index 1a86fb9fb..da2e6f220 100644 --- a/llms/mlx_lm/tuner/datasets.py +++ b/llms/mlx_lm/tuner/datasets.py @@ -94,7 +94,7 @@ def __getitem__(self, idx: int): while True: try: - if (curr_idx + 1) < len(item): + if (curr_idx + 1) <= len(item): return item[curr_idx] else: curr_idx -= len(item) From e45ce38f8699961141a38d4b7399e8c184bd78a8 Mon Sep 17 00:00:00 2001 From: Chime Ogbuji Date: Wed, 6 Nov 2024 12:53:54 -0500 Subject: [PATCH 8/9] Add ability to fetch raw prompt and completion text from completion datasets --- llms/mlx_lm/tuner/datasets.py | 21 ++++++++++++++++++--- 1 file changed, 18 insertions(+), 3 deletions(-) diff --git a/llms/mlx_lm/tuner/datasets.py b/llms/mlx_lm/tuner/datasets.py index da2e6f220..3130d9f96 100644 --- a/llms/mlx_lm/tuner/datasets.py +++ b/llms/mlx_lm/tuner/datasets.py @@ -1,6 +1,6 @@ import json from pathlib import Path -from typing import Dict, List, Union +from typing import Callable, Dict, List, Union from transformers import PreTrainedTokenizer @@ -81,12 +81,15 @@ def __getitem__(self, idx: int): ) return text + def get_prompt_and_completion(self, idx: int): + return self._data[idx][self._prompt_key], self._data[idx][self._completion_key] + class CompletionsDatasetCollection: def __init__(self, data: List[Union[ChatDataset, CompletionsDataset]]): self.collection = data - def __getitem__(self, idx: int): + def __fetch_and_process_item__(self, idx: int, handler_fn: Callable): iteration = iter(self.collection) item = next(iteration) @@ -95,13 +98,25 @@ def __getitem__(self, idx: int): while True: try: if (curr_idx + 1) <= len(item): - return 
item[curr_idx] + return handler_fn(item, curr_idx) else: curr_idx -= len(item) item = next(iteration) except StopIteration: raise IndexError(idx) + def __getitem__(self, idx: int): + def getitem(dataset: CompletionsDataset, index: int): + return dataset[index] + + return self.__fetch_and_process_item__(idx, getitem) + + def get_prompt_and_completion(self, idx: int): + def getitem(dataset: CompletionsDataset, index: int): + dataset.get_prompt_and_completion(index) + + return self.__fetch_and_process_item__(idx, getitem) + def __len__(self): return sum(map(len, self.collection)) From 90e2da881ceff53b1220bda39bee39f78718ec29 Mon Sep 17 00:00:00 2001 From: Chime Ogbuji Date: Wed, 6 Nov 2024 12:58:00 -0500 Subject: [PATCH 9/9] Minor fix --- llms/mlx_lm/tuner/datasets.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/llms/mlx_lm/tuner/datasets.py b/llms/mlx_lm/tuner/datasets.py index 3130d9f96..70551dcd0 100644 --- a/llms/mlx_lm/tuner/datasets.py +++ b/llms/mlx_lm/tuner/datasets.py @@ -113,7 +113,7 @@ def getitem(dataset: CompletionsDataset, index: int): def get_prompt_and_completion(self, idx: int): def getitem(dataset: CompletionsDataset, index: int): - dataset.get_prompt_and_completion(index) + return dataset.get_prompt_and_completion(index) return self.__fetch_and_process_item__(idx, getitem)