
Commit 634f6cd

Author: michal.zyga
Parent: d2591f8

add create fine-tuning and list fine-tuning to sync client

5 files changed: +147, -21 lines


core/src/main/scala/sttp/openai/OpenAI.scala (+33)

@@ -24,6 +24,8 @@ import sttp.openai.requests.completions.chat.ChatRequestResponseData.ChatRespons
 import sttp.openai.requests.embeddings.EmbeddingsRequestBody.EmbeddingsBody
 import sttp.openai.requests.embeddings.EmbeddingsResponseBody.EmbeddingResponse
 import sttp.openai.requests.files.FilesResponseData._
+import sttp.openai.requests.finetuning
+import sttp.openai.requests.finetuning._
 import sttp.openai.requests.images.ImageResponseData.ImageResponse
 import sttp.openai.requests.images.creation.ImageCreationRequestBody.ImageCreationBody
 import sttp.openai.requests.images.edit.ImageEditsConfig

@@ -541,6 +543,36 @@ class OpenAI(authToken: String, baseUri: Uri = OpenAIUris.OpenAIBaseUri) {
       }
       .response(asJson_parseErrors[AudioResponse])

+  /** Creates a fine-tuning job which begins the process of creating a new model from a given dataset.
+    *
+    * Response includes details of the enqueued job including job status and the name of the fine-tuned models once complete.
+    *
+    * [[https://platform.openai.com/docs/api-reference/fine-tuning/create]]
+    *
+    * @param fineTuningRequestBody
+    *   Request body that will be used to create a fine-tuning job.
+    */
+  def createFineTuningJob(fineTuningRequestBody: FineTuningRequestBody): Request[Either[OpenAIException, FineTuningResponse]] =
+    openAIAuthRequest
+      .post(openAIUris.FineTuning)
+      .body(fineTuningRequestBody)
+      .response(asJson_parseErrors[FineTuningResponse])
+
+  /** List your organization's fine-tuning jobs
+    *
+    * [[https://platform.openai.com/docs/api-reference/fine-tuning/list]]
+    */
+  def listFineTuningJobs(
+      queryParameters: finetuning.QueryParameters = finetuning.QueryParameters.empty
+  ): Request[Either[OpenAIException, ListFineTuningResponse]] = {
+    val uri = openAIUris.FineTuning
+      .withParams(queryParameters.toMap)
+
+    openAIAuthRequest
+      .get(uri)
+      .response(asJson_parseErrors[ListFineTuningResponse])
+  }
+
   /** Gets info about the fine-tune job.
     *
     * [[https://platform.openai.com/docs/api-reference/embeddings/create]]

@@ -1036,6 +1068,7 @@ private class OpenAIUris(val baseUri: Uri) {
   val Files: Uri = uri"$baseUri/files"
   val Models: Uri = uri"$baseUri/models"
   val Moderations: Uri = uri"$baseUri/moderations"
+  val FineTuning: Uri = uri"$baseUri/fine_tuning/jobs"
   val Transcriptions: Uri = audioBase.addPath("transcriptions")
   val Translations: Uri = audioBase.addPath("translations")
   val VariationsImage: Uri = imageBase.addPath("variations")
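A minimal usage sketch of the new request-building API (not part of this commit), assuming sttp client4's DefaultSyncBackend and a placeholder API key; FineTuningRequestBody's fields are outside this diff, so only the list call is spelled out:

import sttp.client4.DefaultSyncBackend
import sttp.openai.OpenAI
import sttp.openai.requests.finetuning

val backend = DefaultSyncBackend()
val openAI = new OpenAI("your-api-key") // placeholder key

// Build a GET request for up to 50 fine-tuning jobs and send it through the backend.
val response = openAI
  .listFineTuningJobs(finetuning.QueryParameters(after = None, limit = Some(50)))
  .send(backend)

// The response body is an Either, as with the other request builders in OpenAI.scala.
response.body match {
  case Right(jobs) => jobs.data.foreach(job => println(s"${job.id}: ${job.status}"))
  case Left(error) => println(s"listing fine-tuning jobs failed: $error")
}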

core/src/main/scala/sttp/openai/OpenAISyncClient.scala (+21)

@@ -16,6 +16,8 @@ import sttp.openai.requests.completions.chat.ChatRequestResponseData.ChatRespons
 import sttp.openai.requests.embeddings.EmbeddingsRequestBody.EmbeddingsBody
 import sttp.openai.requests.embeddings.EmbeddingsResponseBody.EmbeddingResponse
 import sttp.openai.requests.files.FilesResponseData.{DeletedFileData, FileData, FilesResponse}
+import sttp.openai.requests.finetuning
+import sttp.openai.requests.finetuning.{FineTuningRequestBody, FineTuningResponse, ListFineTuningResponse}
 import sttp.openai.requests.images.ImageResponseData.ImageResponse
 import sttp.openai.requests.images.creation.ImageCreationRequestBody.ImageCreationBody
 import sttp.openai.requests.images.edit.ImageEditsConfig

@@ -348,6 +350,25 @@ class OpenAISyncClient private (
   def createTranscription(transcriptionConfig: TranscriptionConfig): AudioResponse =
     sendOrThrow(openAI.createTranscription(transcriptionConfig))

+  /** Creates a fine-tuning job which begins the process of creating a new model from a given dataset.
+    *
+    * Response includes details of the enqueued job including job status and the name of the fine-tuned models once complete.
+    *
+    * [[https://platform.openai.com/docs/api-reference/fine-tuning/create]]
+    *
+    * @param fineTuningRequestBody
+    *   Request body that will be used to create a fine-tuning job.
+    */
+  def createFineTuningJob(fineTuningRequestBody: FineTuningRequestBody): FineTuningResponse =
+    sendOrThrow(openAI.createFineTuningJob(fineTuningRequestBody))
+
+  /** List your organization's fine-tuning jobs
+    *
+    * [[https://platform.openai.com/docs/api-reference/fine-tuning/list]]
+    */
+  def listFineTuningJobs(queryParameters: finetuning.QueryParameters = finetuning.QueryParameters.empty): ListFineTuningResponse =
+    sendOrThrow(openAI.listFineTuningJobs(queryParameters))
+
   /** Gets info about the fine-tune job.
     *
     * [[https://platform.openai.com/docs/api-reference/embeddings/create]]
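The blocking client wraps the same calls in sendOrThrow, so failures surface as a thrown OpenAIException instead of an Either. A usage sketch (not part of this commit), assuming the companion's apply(authToken) factory that complements the private constructor:

import sttp.openai.OpenAISyncClient
import sttp.openai.requests.finetuning

val client = OpenAISyncClient("your-api-key") // placeholder key

// createFineTuningJob needs a FineTuningRequestBody, whose fields are outside this diff,
// so only listing is shown; any OpenAIException is thrown rather than returned.
val jobs = client.listFineTuningJobs(finetuning.QueryParameters(limit = Some(10)))
println(s"fetched ${jobs.data.size} job(s), hasMore=${jobs.hasMore}")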

core/src/main/scala/sttp/openai/requests/finetuning/FineTuningResponse.scala (+54 -6)

@@ -4,6 +4,45 @@ import sttp.openai.OpenAIExceptions.OpenAIException.DeserializationOpenAIExcepti
 import sttp.openai.json.SnakePickle
 import ujson.Str

+/** @param id
+  *   The object identifier, which can be referenced in the API endpoints.
+  * @param createdAt
+  *   The Unix timestamp (in seconds) for when the fine-tuning job was created.
+  * @param error
+  *   For fine-tuning jobs that have failed, this will contain more information on the cause of the failure.
+  * @param fineTunedModel
+  *   The name of the fine-tuned model that is being created. The value will be null if the fine-tuning job is still running.
+  * @param finishedAt
+  *   The Unix timestamp (in seconds) for when the fine-tuning job was finished. The value will be null if the fine-tuning job is still
+  *   running.
+  * @param hyperparameters
+  *   The hyperparameters used for the fine-tuning job. This value will only be returned when running supervised jobs.
+  * @param model
+  *   The base model that is being fine-tuned.
+  * @param `object`
+  *   The object type, which is always "fine_tuning.job".
+  * @param organizationId
+  *   The organization that owns the fine-tuning job.
+  * @param resultFiles
+  *   The compiled results file ID(s) for the fine-tuning job. You can retrieve the results with the Files API.
+  * @param status
+  *   The current status of the fine-tuning job, which can be either validating_files, queued, running, succeeded, failed, or cancelled.
+  * @param trainedTokens
+  *   The total number of billable tokens processed by this fine-tuning job. The value will be null if the fine-tuning job is still running.
+  * @param trainingFile
+  *   The file ID used for training. You can retrieve the training data with the Files API.
+  * @param validationFile
+  *   The file ID used for validation. You can retrieve the validation results with the Files API.
+  * @param integrations
+  *   A list of integrations to enable for this fine-tuning job.
+  * @param seed
+  *   The seed used for the fine-tuning job.
+  * @param estimatedFinish
+  *   The Unix timestamp (in seconds) for when the fine-tuning job is estimated to finish. The value will be null if the fine-tuning job is
+  *   not running.
+  * @param method
+  *   The method used for fine-tuning.
+  */
 case class FineTuningResponse(
     id: String,
     createdAt: Int,

@@ -53,12 +92,11 @@ object Status {

   implicit val statusRW: SnakePickle.Reader[Status] = SnakePickle
     .reader[ujson.Value]
-    .map[Status](
-      jsonValue =>
-        SnakePickle.read[ujson.Value](jsonValue) match {
-          case Str(value) => byStatusValue.getOrElse(value, CustomStatus(value))
-          case e => throw DeserializationOpenAIException(new Exception(s"Could not deserialize: $e"))
-        }
+    .map[Status](jsonValue =>
+      SnakePickle.read[ujson.Value](jsonValue) match {
+        case Str(value) => byStatusValue.getOrElse(value, CustomStatus(value))
+        case e => throw DeserializationOpenAIException(new Exception(s"Could not deserialize: $e"))
+      }
     )

   case object ValidatingFiles extends Status("validating_files")

@@ -80,3 +118,13 @@ object Status {
   private val byStatusValue = values.map(status => status.value -> status).toMap

 }
+
+case class ListFineTuningResponse(
+    `object`: String = "list",
+    data: Seq[FineTuningResponse],
+    hasMore: Boolean
+)
+
+object ListFineTuningResponse {
+  implicit val listFineTuningResponseR: SnakePickle.Reader[ListFineTuningResponse] = SnakePickle.macroR[ListFineTuningResponse]
+}
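A small deserialization sketch (not part of this commit) of the readers above, assuming SnakePickle mirrors uPickle's read API (the diff itself calls SnakePickle.read) and that Status is reachable through the finetuning package:

import sttp.openai.json.SnakePickle
import sttp.openai.requests.finetuning._

// has_more maps onto the hasMore field through the snake_case key conversion.
val listed = SnakePickle.read[ListFineTuningResponse]("""{ "object": "list", "data": [], "has_more": false }""")
assert(!listed.hasMore)

// Status strings outside the known set do not fail deserialization:
// byStatusValue.getOrElse falls back to CustomStatus("paused").
val status = SnakePickle.read[Status](ujson.Str("paused"))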

core/src/main/scala/sttp/openai/requests/finetuning/Integration.scala (+15 -15)

@@ -27,22 +27,22 @@ object Integration {

This hunk is an indentation-only change: the Wandb scaladoc and constructor parameters are re-indented to the standard style, with the text itself unchanged. The resulting code:

}

/** @param project
  *   The name of the project that the new run will be created under.
  * @param name
  *   A display name to set for the run. If not set, we will use the Job ID as the name.
  * @param entity
  *   The entity to use for the run. This allows you to set the team or username of the WandB user that you would like associated with the
  *   run. If not set, the default entity for the registered WandB API key is used.
  * @param tags
  *   A list of tags to be attached to the newly created run. These tags are passed through directly to WandB. Some default tags are
  *   generated by OpenAI: "openai/finetune", "openai/{base-model}", "openai/{ftjob-abcdef}".
  */
case class Wandb(
    project: String,
    name: Option[String] = None,
    entity: Option[String] = None,
    tags: Option[Seq[String]]
)

object Wandb {
  implicit val wandbRW: SnakePickle.ReadWriter[Wandb] = SnakePickle.macroRW[Wandb]
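Although this hunk only touches formatting, the Wandb parameters it documents are how a Weights & Biases run is configured for a fine-tuning job. A construction sketch (not part of this commit), with hypothetical project and tag names, assuming Wandb is importable from the finetuning package as the re-indentation suggests:

import sttp.openai.requests.finetuning.Wandb

// Only `project` is required; `tags` has no default value, so pass None explicitly when unused.
val wandb = Wandb(
  project = "my-finetune-experiments",
  name = Some("gpt-run-1"),     // display name; the Job ID is used when not set
  entity = None,                // fall back to the default entity of the registered WandB API key
  tags = Some(Seq("team-nlp"))  // OpenAI attaches its default openai/* tags as well
)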
core/src/main/scala/sttp/openai/requests/finetuning/QueryParameters.scala (new file, +24)

@@ -0,0 +1,24 @@
+package sttp.openai.requests.finetuning
+
+/** @param after
+  *   A cursor for use in pagination. after is an object ID that defines your place in the list. For instance, if you make a list request
+  *   and receive 100 objects, ending with obj_foo, your subsequent call can include after=obj_foo in order to fetch the next page of the
+  *   list.
+  * @param limit
+  *   A limit on the number of objects to be returned. Limit can range between 1 and 100; the default is 20.
+  */
+case class QueryParameters(
+    after: Option[String] = None,
+    limit: Option[Int] = None
+) {
+
+  def toMap: Map[String, String] = {
+    val queryParams = after.map("after" -> _) ++
+      limit.map(_.toString).map("limit" -> _)
+    queryParams.toMap
+  }
+}
+
+object QueryParameters {
+  val empty: QueryParameters = QueryParameters(None, None)
+}
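A quick sketch (not part of this commit) of how the pagination parameters translate into the query string; the job id is hypothetical:

import sttp.openai.requests.finetuning.QueryParameters

// Fetch up to 50 jobs that come after a given job id.
val params = QueryParameters(after = Some("ftjob-abc123"), limit = Some(50))

// Unset options are dropped; set ones become query-string entries:
// Map("after" -> "ftjob-abc123", "limit" -> "50")
val asMap: Map[String, String] = params.toMap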
