package sttp.openai.requests.finetuning

import sttp.openai.OpenAIExceptions.OpenAIException.DeserializationOpenAIException
import sttp.openai.json.SnakePickle
import sttp.openai.requests.finetuning.FineTuningRequestBody.Integration.Integration
import sttp.openai.requests.finetuning.FineTuningRequestBody.Method.Method
import ujson.Str

object FineTuningRequestBody {

  /** @param model
    *   The name of the model to fine-tune. You can select one of the supported models
    *   [[https://platform.openai.com/docs/guides/fine-tuning#which-models-can-be-fine-tuned]].
    * @param trainingFile
    *   The ID of an uploaded file that contains training data. See upload file for how to upload a file. Your dataset must be formatted as
    *   a JSONL file. Additionally, you must upload your file with the purpose fine-tune. The contents of the file should differ depending
    *   on whether the model uses the chat or completions format, or whether the fine-tuning method uses the preference format. See the
    *   fine-tuning guide for more details.
    * @param suffix
    *   A string of up to 64 characters that will be added to your fine-tuned model name. For example, a suffix of "custom-model-name"
    *   would produce a model name like ft:gpt-4o-mini:openai:custom-model-name:7p4lURel.
    * @param validationFile
    *   The ID of an uploaded file that contains validation data. If you provide this file, the data is used to generate validation metrics
    *   periodically during fine-tuning. These metrics can be viewed in the fine-tuning results file. The same data should not be present
    *   in both train and validation files. Your dataset must be formatted as a JSONL file. You must upload your file with the purpose
    *   fine-tune. See the fine-tuning guide for more details.
    * @param integrations
    *   A list of integrations to enable for your fine-tuning job.
    * @param seed
    *   The seed controls the reproducibility of the job. Passing in the same seed and job parameters should produce the same results, but
    *   may differ in rare cases. If a seed is not specified, one will be generated for you.
    * @param method
    *   The method used for fine-tuning.
    */
  case class FineTuningRequestBody(
      model: FineTuningModel,
      trainingFile: String,
      suffix: Option[String] = None,
      validationFile: Option[String] = None,
      integrations: Option[Seq[Integration]] = None,
      seed: Option[Int] = None,
      method: Option[Method] = None
  )
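
  // A minimal usage sketch (not part of the original file); "file-abc123" is a hypothetical
  // uploaded-file ID, and the suffix value is illustrative:
  //
  //   val request = FineTuningRequestBody(
  //     model = FineTuningModel.GPT4oMini20240718,
  //     trainingFile = "file-abc123",
  //     suffix = Some("custom-model-name")
  //   )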

  sealed abstract class Type(val value: String)

  object Type {
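    /** Builds a ReadWriter that writes a Type as its raw string value and reads an incoming string back through the given lookup map,
      * throwing a DeserializationOpenAIException for unknown values or non-string JSON.
      */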
    def typeRW(byTypeValue: Map[String, Type]): SnakePickle.ReadWriter[Type] = SnakePickle
      .readwriter[ujson.Value]
      .bimap[Type](
        `type` => SnakePickle.writeJs(`type`.value),
        jsonValue =>
          SnakePickle.read[ujson.Value](jsonValue) match {
            case Str(value) =>
              byTypeValue.get(value) match {
                case Some(t) => t
                case None    => throw DeserializationOpenAIException(new Exception(s"Could not deserialize: $value"))
              }
            case e => throw DeserializationOpenAIException(new Exception(s"Could not deserialize: $e"))
          }
      )
  }

  object Method {

    object MethodType {
      case object Supervised extends Type("supervised")

      case object Dpo extends Type("dpo")

      private val values: Set[Type] = Set(Supervised, Dpo)

      private val byTypeValue = values.map(`type` => `type`.value -> `type`).toMap

      implicit val typeRW: SnakePickle.ReadWriter[Type] = Type.typeRW(byTypeValue)
    }

    /** @param batchSize
      *   Number of examples in each batch. A larger batch size means that model parameters are updated less frequently, but with lower
      *   variance.
      * @param learningRateMultiplier
      *   Scaling factor for the learning rate. A smaller learning rate may be useful to avoid overfitting.
      * @param nEpochs
      *   The number of epochs to train the model for. An epoch refers to one full cycle through the training dataset.
      * @param beta
      *   The beta value for the DPO method. A higher beta value will increase the weight of the penalty between the policy and reference
      *   model.
      */
    case class Hyperparameters(
        batchSize: Option[Int] = None,
        learningRateMultiplier: Option[Float] = None,
        nEpochs: Option[Int] = None,
        beta: Option[Float] = None
    )

    /** @param hyperparameters
      *   The hyperparameters used for the fine-tuning job.
      */
    case class Supervised(
        hyperparameters: Option[Hyperparameters] = None
    )

    /** @param hyperparameters
      *   The hyperparameters used for the fine-tuning job.
      */
    case class Dpo(
        hyperparameters: Option[Hyperparameters] = None
    )

    /** @param `type`
      *   The type of method. Is either supervised or dpo.
      * @param supervised
      *   Configuration for the supervised fine-tuning method.
      * @param dpo
      *   Configuration for the DPO fine-tuning method.
      */
    case class Method(
        `type`: Option[Type] = None,
        supervised: Option[Supervised] = None,
        dpo: Option[Dpo] = None
    )
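
    // A minimal sketch (not part of the original file) of a DPO method configuration; the
    // hyperparameter values below are illustrative, not recommendations:
    //
    //   val dpoMethod = Method(
    //     `type` = Some(MethodType.Dpo),
    //     dpo = Some(Dpo(Some(Hyperparameters(nEpochs = Some(3), beta = Some(0.1f)))))
    //   )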
  }

  object Integration {

    object IntegrationType {
      case object Wandb extends Type("wandb")

      private val values: Set[Type] = Set(Wandb)

      private val byTypeValue = values.map(`type` => `type`.value -> `type`).toMap

      implicit val typeRW: SnakePickle.ReadWriter[Type] = Type.typeRW(byTypeValue)
    }

    /** @param project
      *   The name of the project that the new run will be created under.
      * @param name
      *   A display name to set for the run. If not set, we will use the Job ID as the name.
      * @param entity
      *   The entity to use for the run. This allows you to set the team or username of the WandB user that you would like associated with
      *   the run. If not set, the default entity for the registered WandB API key is used.
      * @param tags
      *   A list of tags to be attached to the newly created run. These tags are passed through directly to WandB. Some default tags are
      *   generated by OpenAI: "openai/finetune", "openai/{base-model}", "openai/{ftjob-abcdef}".
      */
    case class Wandb(
        project: String,
        name: Option[String] = None,
        entity: Option[String] = None,
        tags: Option[Seq[String]] = None
    )

    /** @param `type`
      *   The type of integration to enable. Currently, only "wandb" (Weights and Biases) is supported.
      * @param wandb
      *   The settings for your integration with Weights and Biases. This payload specifies the project that metrics will be sent to.
      *   Optionally, you can set an explicit display name for your run, add tags to your run, and set a default entity (team, username,
      *   etc) to be associated with your run.
      */
    case class Integration(
        `type`: Type,
        wandb: Wandb
    )
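
    // A minimal sketch (not part of the original file) of a Weights and Biases integration;
    // the project name and tag are hypothetical:
    //
    //   val integration = Integration(
    //     `type` = IntegrationType.Wandb,
    //     wandb = Wandb(project = "my-finetune-project", tags = Some(Seq("custom-tag")))
    //   )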
  }

  sealed abstract class FineTuningModel(val value: String)

  object FineTuningModel {

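    /** Writes a model as its raw string value; unrecognized model names are read back as CustomFineTuningModel rather than failing, so
      * newly released models can be used without a library update.
      */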
    implicit val fineTuningModelRW: SnakePickle.ReadWriter[FineTuningModel] = SnakePickle
      .readwriter[ujson.Value]
      .bimap[FineTuningModel](
        model => SnakePickle.writeJs(model.value),
        jsonValue =>
          SnakePickle.read[ujson.Value](jsonValue) match {
            case Str(value) =>
              byFineTuningModelValue.getOrElse(value, CustomFineTuningModel(value))
            case e => throw DeserializationOpenAIException(new Exception(s"Could not deserialize: $e"))
          }
      )

    case object GPT4o20240806 extends FineTuningModel("gpt-4o-2024-08-06")

    case object GPT4oMini20240718 extends FineTuningModel("gpt-4o-mini-2024-07-18")

    case object GPT40613 extends FineTuningModel("gpt-4-0613")

    case object GPT35Turbo0125 extends FineTuningModel("gpt-3.5-turbo-0125")

    case object GPT35Turbo1106 extends FineTuningModel("gpt-3.5-turbo-1106")

    case object GPT35Turbo0613 extends FineTuningModel("gpt-3.5-turbo-0613")

    case class CustomFineTuningModel(customFineTuningModel: String) extends FineTuningModel(customFineTuningModel)

    val values: Set[FineTuningModel] = Set(GPT4o20240806, GPT4oMini20240718, GPT40613, GPT35Turbo0125, GPT35Turbo1106, GPT35Turbo0613)

    private val byFineTuningModelValue = values.map(model => model.value -> model).toMap
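
    // Example (not part of the original file): an unknown model string falls back to
    // CustomFineTuningModel instead of throwing on read:
    //
    //   SnakePickle.read[FineTuningModel](ujson.Str("gpt-future-model"))
    //   // => CustomFineTuningModel("gpt-future-model")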

  }

}