Skip to content

Commit 57419ad

Browse files
zygiert1990michal.zyga
and
michal.zyga
authored
Add fine tuning api (#283)
Co-authored-by: michal.zyga <michal.zyga@softwaremill.com>
1 parent 4744078 commit 57419ad

11 files changed

+1071
-1
lines changed

core/src/main/scala/sttp/openai/OpenAI.scala

+103-1
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@ import sttp.openai.requests.completions.chat.ChatRequestResponseData.ChatRespons
2424
import sttp.openai.requests.embeddings.EmbeddingsRequestBody.EmbeddingsBody
2525
import sttp.openai.requests.embeddings.EmbeddingsResponseBody.EmbeddingResponse
2626
import sttp.openai.requests.files.FilesResponseData._
27+
import sttp.openai.requests.finetuning
28+
import sttp.openai.requests.finetuning._
2729
import sttp.openai.requests.images.ImageResponseData.ImageResponse
2830
import sttp.openai.requests.images.creation.ImageCreationRequestBody.ImageCreationBody
2931
import sttp.openai.requests.images.edit.ImageEditsConfig
@@ -82,7 +84,7 @@ class OpenAI(authToken: String, baseUri: Uri = OpenAIUris.OpenAIBaseUri) {
8284
* @param completionBody
8385
* Create completion request body.
8486
* @deprecated
85-
* This is marked as Legacy in OpenAI API and might be removed in the future. Please use createChatCompletion instead.
87+
* This is marked as Legacy in OpenAI API and might be removed in the future. Please use [[createChatCompletion]] instead.
8688
*/
8789
def createCompletion(completionBody: CompletionsBody): Request[Either[OpenAIException, CompletionsResponse]] =
8890
openAIAuthRequest
@@ -541,6 +543,100 @@ class OpenAI(authToken: String, baseUri: Uri = OpenAIUris.OpenAIBaseUri) {
541543
}
542544
.response(asJson_parseErrors[AudioResponse])
543545

546+
/** Creates a fine-tuning job which begins the process of creating a new model from a given dataset.
547+
*
548+
* Response includes details of the enqueued job including job status and the name of the fine-tuned models once complete.
549+
*
550+
* [[https://platform.openai.com/docs/api-reference/fine-tuning/create]]
551+
*
552+
* @param fineTuningRequestBody
553+
* Request body that will be used to create a fine-tuning job.
554+
*/
555+
def createFineTuningJob(fineTuningRequestBody: FineTuningJobRequestBody): Request[Either[OpenAIException, FineTuningJobResponse]] =
556+
openAIAuthRequest
557+
.post(openAIUris.FineTuningJobs)
558+
.body(fineTuningRequestBody)
559+
.response(asJson_parseErrors[FineTuningJobResponse])
560+
561+
/** List your organization's fine-tuning jobs
562+
*
563+
* [[https://platform.openai.com/docs/api-reference/fine-tuning/list]]
564+
*/
565+
def listFineTuningJobs(
566+
queryParameters: finetuning.QueryParameters = finetuning.QueryParameters.empty
567+
): Request[Either[OpenAIException, ListFineTuningJobResponse]] = {
568+
val uri = openAIUris.FineTuningJobs
569+
.withParams(queryParameters.toMap)
570+
571+
openAIAuthRequest
572+
.get(uri)
573+
.response(asJson_parseErrors[ListFineTuningJobResponse])
574+
}
575+
576+
/** Get status updates for a fine-tuning job.
577+
*
578+
* [[https://platform.openai.com/docs/api-reference/fine-tuning/list-events]]
579+
*
580+
* @param fineTuningJobId
581+
* The ID of the fine-tuning job to get checkpoints for.
582+
*/
583+
def listFineTuningJobEvents(
584+
fineTuningJobId: String,
585+
queryParameters: finetuning.QueryParameters = finetuning.QueryParameters.empty
586+
): Request[Either[OpenAIException, ListFineTuningJobEventResponse]] = {
587+
val uri = openAIUris
588+
.fineTuningJobEvents(fineTuningJobId)
589+
.withParams(queryParameters.toMap)
590+
591+
openAIAuthRequest
592+
.get(uri)
593+
.response(asJson_parseErrors[ListFineTuningJobEventResponse])
594+
}
595+
596+
/** List checkpoints for a fine-tuning job.
597+
*
598+
* [[https://platform.openai.com/docs/api-reference/fine-tuning/list-checkpoints]]
599+
*
600+
* @param fineTuningJobId
601+
* The ID of the fine-tuning job to get checkpoints for.
602+
*/
603+
def listFineTuningJobCheckpoints(
604+
fineTuningJobId: String,
605+
queryParameters: finetuning.QueryParameters = finetuning.QueryParameters.empty
606+
): Request[Either[OpenAIException, ListFineTuningJobCheckpointResponse]] = {
607+
val uri = openAIUris
608+
.fineTuningJobCheckpoints(fineTuningJobId)
609+
.withParams(queryParameters.toMap)
610+
611+
openAIAuthRequest
612+
.get(uri)
613+
.response(asJson_parseErrors[ListFineTuningJobCheckpointResponse])
614+
}
615+
616+
/** Get info about a fine-tuning job.
617+
*
618+
* [[https://platform.openai.com/docs/api-reference/fine-tuning/retrieve]]
619+
*
620+
* @param fineTuningJobId
621+
* The ID of the fine-tuning job.
622+
*/
623+
def retrieveFineTuningJob(fineTuningJobId: String): Request[Either[OpenAIException, FineTuningJobResponse]] =
624+
openAIAuthRequest
625+
.get(openAIUris.fineTuningJob(fineTuningJobId))
626+
.response(asJson_parseErrors[FineTuningJobResponse])
627+
628+
/** Immediately cancel a fine-tune job.
629+
*
630+
* [[https://platform.openai.com/docs/api-reference/fine-tuning/cancel]]
631+
*
632+
* @param fineTuningJobId
633+
* The ID of the fine-tuning job to cancel.
634+
*/
635+
def cancelFineTuningJob(fineTuningJobId: String): Request[Either[OpenAIException, FineTuningJobResponse]] =
636+
openAIAuthRequest
637+
.post(openAIUris.cancelFineTuningJob(fineTuningJobId))
638+
.response(asJson_parseErrors[FineTuningJobResponse])
639+
544640
/** Gets info about the fine-tune job.
545641
*
546642
* [[https://platform.openai.com/docs/api-reference/embeddings/create]]
@@ -1036,6 +1132,7 @@ private class OpenAIUris(val baseUri: Uri) {
10361132
val Files: Uri = uri"$baseUri/files"
10371133
val Models: Uri = uri"$baseUri/models"
10381134
val Moderations: Uri = uri"$baseUri/moderations"
1135+
val FineTuningJobs: Uri = uri"$baseUri/fine_tuning/jobs"
10391136
val Transcriptions: Uri = audioBase.addPath("transcriptions")
10401137
val Translations: Uri = audioBase.addPath("translations")
10411138
val VariationsImage: Uri = imageBase.addPath("variations")
@@ -1045,6 +1142,11 @@ private class OpenAIUris(val baseUri: Uri) {
10451142
val ThreadsRuns: Uri = uri"$baseUri/threads/runs"
10461143
val VectorStores: Uri = uri"$baseUri/vector_stores"
10471144

1145+
def fineTuningJob(fineTuningJobId: String): Uri = FineTuningJobs.addPath(fineTuningJobId)
1146+
def fineTuningJobEvents(fineTuningJobId: String): Uri = fineTuningJob(fineTuningJobId).addPath("events")
1147+
def fineTuningJobCheckpoints(fineTuningJobId: String): Uri = fineTuningJob(fineTuningJobId).addPath("checkpoints")
1148+
def cancelFineTuningJob(fineTuningJobId: String): Uri = fineTuningJob(fineTuningJobId).addPath("cancel")
1149+
10481150
def file(fileId: String): Uri = Files.addPath(fileId)
10491151
def fileContent(fileId: String): Uri = Files.addPath(fileId, "content")
10501152
def model(modelId: String): Uri = Models.addPath(modelId)

core/src/main/scala/sttp/openai/OpenAISyncClient.scala

+67
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@ import sttp.openai.requests.completions.chat.ChatRequestResponseData.ChatRespons
1616
import sttp.openai.requests.embeddings.EmbeddingsRequestBody.EmbeddingsBody
1717
import sttp.openai.requests.embeddings.EmbeddingsResponseBody.EmbeddingResponse
1818
import sttp.openai.requests.files.FilesResponseData.{DeletedFileData, FileData, FilesResponse}
19+
import sttp.openai.requests.finetuning
20+
import sttp.openai.requests.finetuning._
1921
import sttp.openai.requests.images.ImageResponseData.ImageResponse
2022
import sttp.openai.requests.images.creation.ImageCreationRequestBody.ImageCreationBody
2123
import sttp.openai.requests.images.edit.ImageEditsConfig
@@ -348,6 +350,71 @@ class OpenAISyncClient private (
348350
def createTranscription(transcriptionConfig: TranscriptionConfig): AudioResponse =
349351
sendOrThrow(openAI.createTranscription(transcriptionConfig))
350352

353+
/** Creates a fine-tuning job which begins the process of creating a new model from a given dataset.
354+
*
355+
* Response includes details of the enqueued job including job status and the name of the fine-tuned models once complete.
356+
*
357+
* [[https://platform.openai.com/docs/api-reference/fine-tuning/create]]
358+
*
359+
* @param fineTuningRequestBody
360+
* Request body that will be used to create a fine-tuning job.
361+
*/
362+
def createFineTuningJob(fineTuningRequestBody: FineTuningJobRequestBody): FineTuningJobResponse =
363+
sendOrThrow(openAI.createFineTuningJob(fineTuningRequestBody))
364+
365+
/** List your organization's fine-tuning jobs
366+
*
367+
* [[https://platform.openai.com/docs/api-reference/fine-tuning/list]]
368+
*/
369+
def listFineTuningJobs(queryParameters: finetuning.QueryParameters = finetuning.QueryParameters.empty): ListFineTuningJobResponse =
370+
sendOrThrow(openAI.listFineTuningJobs(queryParameters))
371+
372+
/** Get status updates for a fine-tuning job.
373+
*
374+
* [[https://platform.openai.com/docs/api-reference/fine-tuning/list-events]]
375+
*
376+
* @param fineTuningJobId
377+
* The ID of the fine-tuning job to get checkpoints for.
378+
*/
379+
def listFineTuningJobEvents(
380+
fineTuningJobId: String,
381+
queryParameters: finetuning.QueryParameters = finetuning.QueryParameters.empty
382+
): ListFineTuningJobEventResponse =
383+
sendOrThrow(openAI.listFineTuningJobEvents(fineTuningJobId, queryParameters))
384+
385+
/** List checkpoints for a fine-tuning job.
386+
*
387+
* [[https://platform.openai.com/docs/api-reference/fine-tuning/list-checkpoints]]
388+
*
389+
* @param fineTuningJobId
390+
* The ID of the fine-tuning job to get checkpoints for.
391+
*/
392+
def listFineTuningJobCheckpoints(
393+
fineTuningJobId: String,
394+
queryParameters: finetuning.QueryParameters = finetuning.QueryParameters.empty
395+
): ListFineTuningJobCheckpointResponse =
396+
sendOrThrow(openAI.listFineTuningJobCheckpoints(fineTuningJobId, queryParameters))
397+
398+
/** Get info about a fine-tuning job.
399+
*
400+
* [[https://platform.openai.com/docs/api-reference/fine-tuning/retrieve]]
401+
*
402+
* @param fineTuningJobId
403+
* The ID of the fine-tuning job.
404+
*/
405+
def retrieveFineTuningJob(fineTuningJobId: String): FineTuningJobResponse =
406+
sendOrThrow(openAI.retrieveFineTuningJob(fineTuningJobId))
407+
408+
/** Immediately cancel a fine-tune job.
409+
*
410+
* [[https://platform.openai.com/docs/api-reference/fine-tuning/cancel]]
411+
*
412+
* @param fineTuningJobId
413+
* The ID of the fine-tuning job to cancel.
414+
*/
415+
def cancelFineTuningJob(fineTuningJobId: String): FineTuningJobResponse =
416+
sendOrThrow(openAI.cancelFineTuningJob(fineTuningJobId))
417+
351418
/** Gets info about the fine-tune job.
352419
*
353420
* [[https://platform.openai.com/docs/api-reference/embeddings/create]]
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
package sttp.openai.requests.finetuning
2+
3+
import sttp.openai.OpenAIExceptions.OpenAIException.DeserializationOpenAIException
4+
import sttp.openai.json.SnakePickle
5+
import ujson.Str
6+
7+
/** @param model
8+
* The name of the model to fine-tune. You can select one of the supported models
9+
* [[https://platform.openai.com/docs/guides/fine-tuning#which-models-can-be-fine-tuned]].
10+
* @param trainingFile
11+
* The ID of an uploaded file that contains training data. See upload file for how to upload a file. Your dataset must be formatted as a
12+
* JSONL file. Additionally, you must upload your file with the purpose fine-tune. The contents of the file should differ depending on if
13+
* the model uses the chat, completions format, or if the fine-tuning method uses the preference format. See the fine-tuning guide for
14+
* more details.
15+
* @param suffix
16+
* A string of up to 64 characters that will be added to your fine-tuned model name. For example, a suffix of "custom-model-name" would
17+
* produce a model name like ft:gpt-4o-mini:openai:custom-model-name:7p4lURel.
18+
* @param validationFile
19+
* The ID of an uploaded file that contains validation data. If you provide this file, the data is used to generate validation metrics
20+
* periodically during fine-tuning. These metrics can be viewed in the fine-tuning results file. The same data should not be present in
21+
* both train and validation files. Your dataset must be formatted as a JSONL file. You must upload your file with the purpose fine-tune.
22+
* See the fine-tuning guide for more details.
23+
* @param integrations
24+
* A list of integrations to enable for your fine-tuning job.
25+
* @param seed
26+
* The seed controls the reproducibility of the job. Passing in the same seed and job parameters should produce the same results, but may
27+
* differ in rare cases. If a seed is not specified, one will be generated for you.
28+
* @param method
29+
* The method used for fine-tuning.
30+
*/
31+
case class FineTuningJobRequestBody(
32+
model: FineTuningModel,
33+
trainingFile: String,
34+
suffix: Option[String] = None,
35+
validationFile: Option[String] = None,
36+
integrations: Option[Seq[Integration]] = None,
37+
seed: Option[Int] = None,
38+
method: Option[Method] = None
39+
)
40+
object FineTuningJobRequestBody {
41+
implicit val fineTuningRequestBodyWriter: SnakePickle.Writer[FineTuningJobRequestBody] = SnakePickle.macroW[FineTuningJobRequestBody]
42+
}
43+
44+
sealed abstract class FineTuningModel(val value: String)
45+
46+
object FineTuningModel {
47+
48+
implicit val fineTuningModelRW: SnakePickle.ReadWriter[FineTuningModel] = SnakePickle
49+
.readwriter[ujson.Value]
50+
.bimap[FineTuningModel](
51+
model => SnakePickle.writeJs(model.value),
52+
jsonValue =>
53+
SnakePickle.read[ujson.Value](jsonValue) match {
54+
case Str(value) =>
55+
byFineTuningModelValue.getOrElse(value, CustomFineTuningModel(value))
56+
case e => throw DeserializationOpenAIException(new Exception(s"Could not deserialize: $e"))
57+
}
58+
)
59+
60+
case object GPT4o20240806 extends FineTuningModel("gpt-4o-2024-08-06")
61+
62+
case object GPT4oMini20240718 extends FineTuningModel("gpt-4o-mini-2024-07-18")
63+
64+
case object GPT40613 extends FineTuningModel("gpt-4-0613")
65+
66+
case object GPT35Turbo0125 extends FineTuningModel("gpt-3.5-turbo-0125")
67+
68+
case object GPT35Turbo1106 extends FineTuningModel("gpt-3.5-turbo-1106")
69+
70+
case object GPT35Turbo0613 extends FineTuningModel("gpt-3.5-turbo-0613")
71+
72+
case class CustomFineTuningModel(customFineTuningModel: String) extends FineTuningModel(customFineTuningModel)
73+
74+
val values: Set[FineTuningModel] = Set(GPT4o20240806, GPT4oMini20240718, GPT40613, GPT35Turbo0125, GPT35Turbo1106, GPT35Turbo0613)
75+
76+
private val byFineTuningModelValue = values.map(model => model.value -> model).toMap
77+
78+
}

0 commit comments

Comments
 (0)