diff --git a/llms/mlx_lm/LORA.md b/llms/mlx_lm/LORA.md
index 156763607..6de3c10c6 100644
--- a/llms/mlx_lm/LORA.md
+++ b/llms/mlx_lm/LORA.md
@@ -267,7 +267,7 @@ it on the command line. For example, pass `--data mlx-community/wikisql` to
 train on the pre-formatted WikiSQL data.
 
 Otherwise, provide a mapping of keys in the dataset to the features MLX LM
-expects. Use a YAML config to specify the Hugging Face dataset arguments. For
+expects. Use a YAML config to specify the Hugging Face (HF) dataset arguments. For
 example:
 
 ```
@@ -279,11 +279,29 @@ hf_dataset:
 
 - Use `prompt_feature` and `completion_feature` to specify keys for a
   `completions` dataset. Use `text_feature` to specify the key for a `text`
-  dataset.
+  dataset. Use `chat_feature` to specify the key for a chat dataset.
 
 - To specify the train, valid, or test splits, set the corresponding
   `{train,valid,test}_split` argument.
 
+To train on several HF datasets at once, set the `hf_datasets` (plural) key to a list of records,
+each with the same structure as `hf_dataset` above. For example:
+
+```yaml
+hf_datasets:
+- hf_dataset:
+    name: "Open-Orca/OpenOrca"
+    train_split: "train[:90%]"
+    valid_split: "train[-10%:]"
+    prompt_feature: "question"
+    completion_feature: "response"
+- hf_dataset:
+    name: "trl-lib/ultrafeedback_binarized"
+    train_split: "train[:90%]"
+    valid_split: "train[-10%:]"
+    chat_feature: "chosen"
+```
+
 - Arguments specified in `config` will be passed as keyword arguments to
   [`datasets.load_dataset`](https://huggingface.co/docs/datasets/v2.20.0/en/package_reference/loading_methods#datasets.load_dataset).
 
diff --git a/llms/mlx_lm/tuner/datasets.py b/llms/mlx_lm/tuner/datasets.py
index 20b32effd..70551dcd0 100644
--- a/llms/mlx_lm/tuner/datasets.py
+++ b/llms/mlx_lm/tuner/datasets.py
@@ -1,6 +1,6 @@
 import json
 from pathlib import Path
-from typing import Dict, List
+from typing import Callable, Dict, List, Union
 
 from transformers import PreTrainedTokenizer
 
@@ -29,12 +29,18 @@ class ChatDataset(Dataset):
     https://platform.openai.com/docs/guides/fine-tuning/example-format
     """
 
-    def __init__(self, data: List[Dict[str, str]], tokenizer: PreTrainedTokenizer):
+    def __init__(
+        self,
+        data: List[Dict[str, str]],
+        tokenizer: PreTrainedTokenizer,
+        chat_key: str = "messages",
+    ):
         super().__init__(data)
         self._tokenizer = tokenizer
+        self._chat_key = chat_key
 
     def __getitem__(self, idx: int):
-        messages = self._data[idx]["messages"]
+        messages = self._data[idx][self._chat_key]
         text = self._tokenizer.apply_chat_template(
             messages,
             tools=self._data[idx].get("tools", None),
@@ -75,6 +81,45 @@ def __getitem__(self, idx: int):
         )
         return text
 
+    def get_prompt_and_completion(self, idx: int):
+        return self._data[idx][self._prompt_key], self._data[idx][self._completion_key]
+
+
+class CompletionsDatasetCollection:
+    def __init__(self, data: List[Union[ChatDataset, CompletionsDataset]]):
+        self.collection = data
+
+    def __fetch_and_process_item__(self, idx: int, handler_fn: Callable):
+        iteration = iter(self.collection)
+        item = next(iteration)
+
+        curr_idx = idx
+
+        while True:
+            try:
+                if (curr_idx + 1) <= len(item):
+                    return handler_fn(item, curr_idx)
+                else:
+                    curr_idx -= len(item)
+                    item = next(iteration)
+            except StopIteration:
+                raise IndexError(idx)
+
+    def __getitem__(self, idx: int):
+        def getitem(dataset: CompletionsDataset, index: int):
+            return dataset[index]
+
+        return self.__fetch_and_process_item__(idx, getitem)
+
+    def get_prompt_and_completion(self, idx: int):
+        def getitem(dataset: CompletionsDataset, index: int):
+            return dataset.get_prompt_and_completion(index)
+
+        return self.__fetch_and_process_item__(idx, getitem)
+
+    def __len__(self):
+        return sum(map(len, self.collection))
+
 
 def create_dataset(data, tokenizer: PreTrainedTokenizer = None):
     sample = data[0]
@@ -127,14 +172,14 @@ def load_hf_dataset(data_id: str, tokenizer: PreTrainedTokenizer):
 def load_custom_hf_dataset(args, tokenizer: PreTrainedTokenizer):
     import datasets
 
-    hf_args = args.hf_dataset
-    dataset_name = hf_args["name"]
-    print(f"Loading Hugging Face dataset {dataset_name}.")
-    text_feature = hf_args.get("text_feature")
-    prompt_feature = hf_args.get("prompt_feature")
-    completion_feature = hf_args.get("completion_feature")
-
-    def create_hf_dataset(split: str = None):
+    def create_hf_dataset(
+        dataset_name: Union[None, str],
+        text_feature: Union[None, str],
+        prompt_feature: Union[None, str],
+        completion_feature: Union[None, str],
+        chat_feature: Union[None, str],
+        split: str = None,
+    ):
         ds = datasets.load_dataset(
             dataset_name,
             split=split,
@@ -142,31 +187,109 @@ def create_hf_dataset(split: str = None):
         )
         if prompt_feature and completion_feature:
             return CompletionsDataset(ds, tokenizer, prompt_feature, completion_feature)
+        elif chat_feature:
+            return ChatDataset(ds, tokenizer, chat_key=chat_feature)
         elif text_feature:
-            return Dataset(train_ds, text_key=text_feature)
+            return Dataset(ds, text_key=text_feature)
         else:
             raise ValueError(
                 "Specify either a prompt and completion feature or a text "
                 "feature for the Hugging Face dataset."
             )
 
-    if args.train:
+    def get_hf_custom_features(hf_args):
+        return (
+            hf_args.get("text_feature"),
+            hf_args.get("prompt_feature"),
+            hf_args.get("completion_feature"),
+            hf_args.get("chat_feature"),
+        )
+
+    def get_train_and_valid_splits(hf_args, ds_name):
         train_split = hf_args.get("train_split", "train[:80%]")
         valid_split = hf_args.get("valid_split", "train[-10%:]")
-        train = create_hf_dataset(split=train_split)
-        valid = create_hf_dataset(split=valid_split)
-    else:
-        train, valid = [], []
-    if args.test:
-        test = create_hf_dataset(split=hf_args.get("test_split"))
+        text_f, prompt_f, completion_f, chat_f = get_hf_custom_features(hf_args)
+        train = create_hf_dataset(
+            dataset_name=ds_name,
+            text_feature=text_f,
+            prompt_feature=prompt_f,
+            completion_feature=completion_f,
+            chat_feature=chat_f,
+            split=train_split,
+        )
+        valid = create_hf_dataset(
+            dataset_name=ds_name,
+            text_feature=text_f,
+            prompt_feature=prompt_f,
+            completion_feature=completion_f,
+            chat_feature=chat_f,
+            split=valid_split,
+        )
+        return train, valid
+
+    if getattr(args, "hf_datasets", None):
+        dataset_collection = args.hf_datasets
+        train_collection = []
+        valid_collection = []
+        test_collection = []
+        for ds in dataset_collection:
+            hf_args = ds["hf_dataset"]
+            dataset_name = hf_args["name"]
+            print(f"Loading Hugging Face dataset {dataset_name}.")
+            text_feature, prompt_feature, completion_feature, chat_f = (
+                get_hf_custom_features(hf_args)
+            )
+            if args.train:
+                train, valid = get_train_and_valid_splits(hf_args, dataset_name)
+            else:
+                train, valid = [], []
+            if args.test:
+                test = create_hf_dataset(
+                    dataset_name=dataset_name,
+                    text_feature=text_feature,
+                    prompt_feature=prompt_feature,
+                    completion_feature=completion_feature,
+                    chat_feature=chat_f,
+                    split=hf_args.get("test_split"),
+                )
+            else:
+                test = []
+            train_collection.append(train)
+            valid_collection.append(valid)
+            test_collection.append(test)
+        return (
+            CompletionsDatasetCollection(train_collection),
+            CompletionsDatasetCollection(valid_collection),
+            CompletionsDatasetCollection(test_collection),
+        )
     else:
-        test = []
+        hf_args = args.hf_dataset
+        dataset_name = hf_args["name"]
+        print(f"Loading Hugging Face dataset {dataset_name}.")
+        text_feature, prompt_feature, completion_feature, chat_feature = (
+            get_hf_custom_features(hf_args)
+        )
+        if args.train:
+            train, valid = get_train_and_valid_splits(hf_args, dataset_name)
+        else:
+            train, valid = [], []
+        if args.test:
+            test = create_hf_dataset(
+                dataset_name,
+                text_feature,
+                prompt_feature,
+                completion_feature,
+                chat_feature,
+                split=hf_args.get("test_split"),
+            )
+        else:
+            test = []
     return train, valid, test
 
 
 def load_dataset(args, tokenizer: PreTrainedTokenizer):
-    if getattr(args, "hf_dataset", None) is not None:
+    if getattr(args, "hf_dataset", None) is not None or getattr(args, "hf_datasets", None):
         train, valid, test = load_custom_hf_dataset(args, tokenizer)
     else:
         data_path = Path(args.data)
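
For review purposes, here is a minimal, self-contained sketch of the index-chaining behavior that `CompletionsDatasetCollection.__fetch_and_process_item__` implements. It is not part of the change; the `ToyDataset` class, the `fetch` helper, and the sample data are hypothetical stand-ins for `ChatDataset`/`CompletionsDataset`, used only to show how a single global index is walked across the concatenated datasets.

```python
# Illustrative sketch only (hypothetical ToyDataset/fetch, not code from the PR):
# mirrors how CompletionsDatasetCollection resolves a global index by subtracting
# each member dataset's length until the index falls inside the current one.
from typing import Callable, List


class ToyDataset:
    """Stand-in for ChatDataset / CompletionsDataset: sized and indexable."""

    def __init__(self, rows: List[str]):
        self._rows = rows

    def __len__(self) -> int:
        return len(self._rows)

    def __getitem__(self, idx: int) -> str:
        return self._rows[idx]


def fetch(collection: List[ToyDataset], idx: int, handler_fn: Callable):
    # Walk datasets in order; stop when the remaining index fits in the
    # current dataset, otherwise subtract its length and move on.
    iteration = iter(collection)
    item = next(iteration)
    curr_idx = idx
    while True:
        try:
            if curr_idx < len(item):
                return handler_fn(item, curr_idx)
            curr_idx -= len(item)
            item = next(iteration)
        except StopIteration:
            raise IndexError(idx)


if __name__ == "__main__":
    parts = [ToyDataset(["a", "b"]), ToyDataset(["c", "d", "e"])]
    # len(parts[0]) == 2, so global index 3 maps to parts[1][1] -> "d".
    print(fetch(parts, 3, lambda ds, i: ds[i]))
    print(sum(map(len, parts)))  # total length, as in the collection's __len__ -> 5
```

This also illustrates why `__len__` on the collection is simply the sum of the member lengths: the trainer can treat several HF datasets as one flat dataset while each lookup is delegated to the member that owns that index.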