Skip to content

Commit dc8f74f

Browse files
author
michal.zyga
committed
add FineTuningRequestBody
1 parent 5289841 commit dc8f74f

File tree

1 file changed

+203
-0
lines changed

1 file changed

+203
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,203 @@
1+
package sttp.openai.requests.finetuning
2+
3+
import sttp.openai.OpenAIExceptions.OpenAIException.DeserializationOpenAIException
4+
import sttp.openai.json.SnakePickle
5+
import sttp.openai.requests.finetuning.FineTuningRequestBody.Integration.Integration
6+
import sttp.openai.requests.finetuning.FineTuningRequestBody.Method.Method
7+
import ujson.Str
8+
9+
object FineTuningRequestBody {

  /** Request body for creating a fine-tuning job.
    *
    * @param model
    *   The name of the model to fine-tune. You can select one of the supported models
    *   [[https://platform.openai.com/docs/guides/fine-tuning#which-models-can-be-fine-tuned]].
    * @param trainingFile
    *   The ID of an uploaded file that contains training data. See upload file for how to upload a file. Your dataset must be formatted as
    *   a JSONL file. Additionally, you must upload your file with the purpose fine-tune. The contents of the file should differ depending
    *   on if the model uses the chat, completions format, or if the fine-tuning method uses the preference format. See the fine-tuning
    *   guide for more details.
    * @param suffix
    *   A string of up to 64 characters that will be added to your fine-tuned model name. For example, a suffix of "custom-model-name" would
    *   produce a model name like ft:gpt-4o-mini:openai:custom-model-name:7p4lURel.
    * @param validationFile
    *   The ID of an uploaded file that contains validation data. If you provide this file, the data is used to generate validation metrics
    *   periodically during fine-tuning. These metrics can be viewed in the fine-tuning results file. The same data should not be present in
    *   both train and validation files. Your dataset must be formatted as a JSONL file. You must upload your file with the purpose
    *   fine-tune. See the fine-tuning guide for more details.
    * @param integrations
    *   A list of integrations to enable for your fine-tuning job.
    * @param seed
    *   The seed controls the reproducibility of the job. Passing in the same seed and job parameters should produce the same results, but
    *   may differ in rare cases. If a seed is not specified, one will be generated for you.
    * @param method
    *   The method used for fine-tuning.
    */
  case class FineTuningRequestBody(
      model: FineTuningModel,
      trainingFile: String,
      suffix: Option[String] = None,
      validationFile: Option[String] = None,
      integrations: Option[Seq[Integration]] = None,
      seed: Option[Int] = None,
      method: Option[Method] = None
  )

  /** Base class for string-valued discriminators (e.g. method type, integration type); serialized as [[value]]. */
  sealed abstract class Type(val value: String)

  object Type {

    /** Builds a [[SnakePickle.ReadWriter]] for a [[Type]] hierarchy.
      *
      * Serializes a [[Type]] as its plain string `value`; deserializes by looking the string up in `byTypeValue`.
      *
      * @param byTypeValue
      *   mapping from wire value to the corresponding [[Type]] case object.
      * @throws DeserializationOpenAIException
      *   if the JSON value is not a string, or the string is not present in `byTypeValue`.
      */
    def typeRW(byTypeValue: Map[String, Type]): SnakePickle.ReadWriter[Type] = SnakePickle
      .readwriter[ujson.Value]
      .bimap[Type](
        `type` => SnakePickle.writeJs(`type`.value),
        jsonValue =>
          SnakePickle.read[ujson.Value](jsonValue) match {
            case Str(value) =>
              byTypeValue.get(value) match {
                case Some(t) => t
                case None    => throw DeserializationOpenAIException(new Exception(s"Could not deserialize: $value"))
              }
            case e => throw DeserializationOpenAIException(new Exception(s"Could not deserialize: $e"))
          }
      )
  }

  object Method {

    /** Supported fine-tuning method discriminators: "supervised" and "dpo". */
    object MethodType {
      case object Supervised extends Type("supervised")

      case object Dpo extends Type("dpo")

      private val values: Set[Type] = Set(Supervised, Dpo)

      private val byTypeValue = values.map(`type` => `type`.value -> `type`).toMap

      implicit val typeRW: SnakePickle.ReadWriter[Type] = Type.typeRW(byTypeValue)
    }

    /** @param batchSize
      *   Number of examples in each batch. A larger batch size means that model parameters are updated less frequently, but with lower
      *   variance.
      * @param learningRateMultiplier
      *   Scaling factor for the learning rate. A smaller learning rate may be useful to avoid overfitting.
      * @param nEpochs
      *   The number of epochs to train the model for. An epoch refers to one full cycle through the training dataset.
      * @param beta
      *   The beta value for the DPO method. A higher beta value will increase the weight of the penalty between the policy and reference
      *   model.
      */
    case class Hyperparameters(
        batchSize: Option[Int] = None,
        learningRateMultiplier: Option[Float] = None,
        nEpochs: Option[Int] = None,
        beta: Option[Float] = None
    )

    /** Configuration for the supervised fine-tuning method.
      *
      * @param hyperparameters
      *   The hyperparameters used for the fine-tuning job.
      */
    case class Supervised(
        hyperparameters: Option[Hyperparameters] = None
    )

    /** Configuration for the DPO fine-tuning method.
      *
      * @param hyperparameters
      *   The hyperparameters used for the fine-tuning job.
      */
    case class Dpo(
        hyperparameters: Option[Hyperparameters] = None
    )

    /** @param `type`
      *   The type of method. Is either supervised or dpo.
      * @param supervised
      *   Configuration for the supervised fine-tuning method.
      * @param dpo
      *   Configuration for the DPO fine-tuning method.
      */
    case class Method(
        `type`: Option[Type] = None,
        supervised: Option[Supervised] = None,
        dpo: Option[Dpo] = None
    )
  }

  object Integration {

    /** Supported integration discriminators; currently only "wandb". */
    object IntegrationType {
      case object Wandb extends Type("wandb")

      private val values: Set[Type] = Set(Wandb)

      private val byTypeValue = values.map(`type` => `type`.value -> `type`).toMap

      implicit val typeRW: SnakePickle.ReadWriter[Type] = Type.typeRW(byTypeValue)
    }

    /** @param project
      *   The name of the project that the new run will be created under.
      * @param name
      *   A display name to set for the run. If not set, we will use the Job ID as the name.
      * @param entity
      *   The entity to use for the run. This allows you to set the team or username of the WandB user that you would like associated with
      *   the run. If not set, the default entity for the registered WandB API key is used.
      * @param tags
      *   A list of tags to be attached to the newly created run. These tags are passed through directly to WandB. Some default tags are
      *   generated by OpenAI: "openai/finetune", "openai/{base-model}", "openai/{ftjob-abcdef}".
      */
    case class Wandb(
        project: String,
        name: Option[String] = None,
        entity: Option[String] = None,
        // tags is optional in the OpenAI API; default to None like every other optional field here
        tags: Option[Seq[String]] = None
    )

    /** @param `type`
      *   The type of integration to enable. Currently, only "wandb" (Weights and Biases) is supported.
      * @param wandb
      *   The settings for your integration with Weights and Biases. This payload specifies the project that metrics will be sent to.
      *   Optionally, you can set an explicit display name for your run, add tags to your run, and set a default entity (team, username,
      *   etc) to be associated with your run.
      */
    case class Integration(
        `type`: Type,
        wandb: Wandb
    )
  }

  /** A fine-tunable model identifier; serialized as its plain string [[value]]. */
  sealed abstract class FineTuningModel(val value: String)

  object FineTuningModel {

    /** Serializes a model as its string value. On read, unknown strings fall back to [[CustomFineTuningModel]]
      * (rather than failing) so that newly released model names remain deserializable.
      *
      * @throws DeserializationOpenAIException
      *   if the JSON value is not a string.
      */
    implicit val fineTuningModelRW: SnakePickle.ReadWriter[FineTuningModel] = SnakePickle
      .readwriter[ujson.Value]
      .bimap[FineTuningModel](
        model => SnakePickle.writeJs(model.value),
        jsonValue =>
          SnakePickle.read[ujson.Value](jsonValue) match {
            case Str(value) =>
              byFineTuningModelValue.getOrElse(value, CustomFineTuningModel(value))
            case e => throw DeserializationOpenAIException(new Exception(s"Could not deserialize: $e"))
          }
      )

    case object GPT4o20240806 extends FineTuningModel("gpt-4o-2024-08-06")

    case object GPT4oMini20240718 extends FineTuningModel("gpt-4o-mini-2024-07-18")

    case object GPT40613 extends FineTuningModel("gpt-4-0613")

    case object GPT35Turbo0125 extends FineTuningModel("gpt-3.5-turbo-0125")

    case object GPT35Turbo1106 extends FineTuningModel("gpt-3.5-turbo-1106")

    case object GPT35Turbo0613 extends FineTuningModel("gpt-3.5-turbo-0613")

    /** Escape hatch for model names not covered by the predefined case objects. */
    case class CustomFineTuningModel(customFineTuningModel: String) extends FineTuningModel(customFineTuningModel)

    val values: Set[FineTuningModel] = Set(GPT4o20240806, GPT4oMini20240718, GPT40613, GPT35Turbo0125, GPT35Turbo1106, GPT35Turbo0613)

    private val byFineTuningModelValue = values.map(model => model.value -> model).toMap

  }

}

0 commit comments

Comments
 (0)