Skip to content

Commit 55c23d0

Browse files
author
michal.zyga
committed
add list fine tuning job checkpoints method
1 parent 157a868 commit 55c23d0

File tree

5 files changed

+164
-8
lines changed

5 files changed

+164
-8
lines changed

core/src/main/scala/sttp/openai/OpenAI.scala

+24
Original file line numberDiff line numberDiff line change
@@ -576,6 +576,9 @@ class OpenAI(authToken: String, baseUri: Uri = OpenAIUris.OpenAIBaseUri) {
576576
/** Get status updates for a fine-tuning job.
577577
*
578578
* [[https://platform.openai.com/docs/api-reference/fine-tuning/list-events]]
579+
*
580+
* @param fineTuningJobId
581+
* The ID of the fine-tuning job to get checkpoints for.
579582
*/
580583
def listFineTuningJobEvents(
581584
fineTuningJobId: String,
@@ -590,6 +593,26 @@ class OpenAI(authToken: String, baseUri: Uri = OpenAIUris.OpenAIBaseUri) {
590593
.response(asJson_parseErrors[ListFineTuningJobEventResponse])
591594
}
592595

596+
/** List checkpoints for a fine-tuning job.
597+
*
598+
* [[https://platform.openai.com/docs/api-reference/fine-tuning/list-checkpoints]]
599+
*
600+
* @param fineTuningJobId
601+
* The ID of the fine-tuning job to get checkpoints for.
602+
*/
603+
def listFineTuningJobCheckpoints(
604+
fineTuningJobId: String,
605+
queryParameters: finetuning.QueryParameters = finetuning.QueryParameters.empty
606+
): Request[Either[OpenAIException, ListFineTuningJobCheckpointResponse]] = {
607+
val uri = openAIUris
608+
.fineTuningJobCheckpoints(fineTuningJobId)
609+
.withParams(queryParameters.toMap)
610+
611+
openAIAuthRequest
612+
.get(uri)
613+
.response(asJson_parseErrors[ListFineTuningJobCheckpointResponse])
614+
}
615+
593616
/** Gets info about the fine-tune job.
594617
*
595618
* [[https://platform.openai.com/docs/api-reference/embeddings/create]]
@@ -1097,6 +1120,7 @@ private class OpenAIUris(val baseUri: Uri) {
10971120

10981121
def fineTuningJob(fineTuningJobId: String): Uri = FineTuningJobs.addPath(fineTuningJobId)
10991122
def fineTuningJobEvents(fineTuningJobId: String): Uri = fineTuningJob(fineTuningJobId).addPath("events")
1123+
def fineTuningJobCheckpoints(fineTuningJobId: String): Uri = fineTuningJob(fineTuningJobId).addPath("checkpoints")
11001124

11011125
def file(fileId: String): Uri = Files.addPath(fileId)
11021126
def fileContent(fileId: String): Uri = Files.addPath(fileId, "content")

core/src/main/scala/sttp/openai/OpenAISyncClient.scala

+17-6
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,7 @@ import sttp.openai.requests.embeddings.EmbeddingsRequestBody.EmbeddingsBody
1717
import sttp.openai.requests.embeddings.EmbeddingsResponseBody.EmbeddingResponse
1818
import sttp.openai.requests.files.FilesResponseData.{DeletedFileData, FileData, FilesResponse}
1919
import sttp.openai.requests.finetuning
20-
import sttp.openai.requests.finetuning.{
21-
FineTuningJobRequestBody,
22-
FineTuningJobResponse,
23-
ListFineTuningJobEventResponse,
24-
ListFineTuningJobResponse
25-
}
20+
import sttp.openai.requests.finetuning._
2621
import sttp.openai.requests.images.ImageResponseData.ImageResponse
2722
import sttp.openai.requests.images.creation.ImageCreationRequestBody.ImageCreationBody
2823
import sttp.openai.requests.images.edit.ImageEditsConfig
@@ -377,13 +372,29 @@ class OpenAISyncClient private (
377372
/** Get status updates for a fine-tuning job.
378373
*
379374
* [[https://platform.openai.com/docs/api-reference/fine-tuning/list-events]]
375+
*
376+
* @param fineTuningJobId
377+
* The ID of the fine-tuning job to get checkpoints for.
380378
*/
381379
def listFineTuningJobEvents(
382380
fineTuningJobId: String,
383381
queryParameters: finetuning.QueryParameters = finetuning.QueryParameters.empty
384382
): ListFineTuningJobEventResponse =
385383
sendOrThrow(openAI.listFineTuningJobEvents(fineTuningJobId, queryParameters))
386384

385+
/** List checkpoints for a fine-tuning job.
386+
*
387+
* [[https://platform.openai.com/docs/api-reference/fine-tuning/list-checkpoints]]
388+
*
389+
* @param fineTuningJobId
390+
* The ID of the fine-tuning job to get checkpoints for.
391+
*/
392+
def listFineTuningJobCheckpoints(
393+
fineTuningJobId: String,
394+
queryParameters: finetuning.QueryParameters = finetuning.QueryParameters.empty
395+
): ListFineTuningJobCheckpointResponse =
396+
sendOrThrow(openAI.listFineTuningJobCheckpoints(fineTuningJobId, queryParameters))
397+
387398
/** Gets info about the fine-tune job.
388399
*
389400
* [[https://platform.openai.com/docs/api-reference/embeddings/create]]

core/src/main/scala/sttp/openai/requests/finetuning/FineTuningJobResponse.scala

+65-2
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,9 @@ import sttp.openai.OpenAIExceptions.OpenAIException.DeserializationOpenAIExcepti
44
import sttp.openai.json.SnakePickle
55
import ujson.Str
66

7-
/** @param id
7+
/** The fine_tuning.job object represents a fine-tuning job that has been created through the API.
8+
*
9+
* @param id
810
* The object identifier, which can be referenced in the API endpoints.
911
* @param createdAt
1012
* The Unix timestamp (in seconds) for when the fine-tuning job was created.
@@ -129,7 +131,9 @@ object ListFineTuningJobResponse {
129131
implicit val listFineTuningResponseR: SnakePickle.Reader[ListFineTuningJobResponse] = SnakePickle.macroR[ListFineTuningJobResponse]
130132
}
131133

132-
/** @param `object`
134+
/** Fine-tuning job event object
135+
*
136+
* @param `object`
133137
* The object type, which is always "fine_tuning.job.event".
134138
* @param id
135139
* The object identifier.
@@ -168,3 +172,62 @@ object ListFineTuningJobEventResponse {
168172
implicit val listFineTuningJobEventResponseR: SnakePickle.Reader[ListFineTuningJobEventResponse] =
169173
SnakePickle.macroR[ListFineTuningJobEventResponse]
170174
}
175+
176+
/** The fine_tuning.job.checkpoint object represents a model checkpoint for a fine-tuning job that is ready to use.
177+
*
178+
* @param id
179+
* The checkpoint identifier, which can be referenced in the API endpoints.
180+
* @param createdAt
181+
* The Unix timestamp (in seconds) for when the checkpoint was created.
182+
* @param fineTunedModelCheckpoint
183+
* The name of the fine-tuned checkpoint model that is created.
184+
* @param stepNumber
185+
* The step number that the checkpoint was created at.
186+
* @param metrics
187+
* Metrics at the step number during the fine-tuning job.
188+
* @param fineTuningJobId
189+
* The name of the fine-tuning job that this checkpoint was created from.
190+
* @param `object`
191+
* The object type, which is always "fine_tuning.job.checkpoint".
192+
*/
193+
case class FineTuningJobCheckpointResponse(
194+
id: String,
195+
createdAt: Int,
196+
fineTunedModelCheckpoint: String,
197+
stepNumber: Int,
198+
metrics: Metrics,
199+
fineTuningJobId: String,
200+
`object`: String = "fine_tuning.job.checkpoint"
201+
)
202+
203+
object FineTuningJobCheckpointResponse {
204+
implicit val fineTuningJobCheckpointResponseR: SnakePickle.Reader[FineTuningJobCheckpointResponse] =
205+
SnakePickle.macroR[FineTuningJobCheckpointResponse]
206+
}
207+
208+
case class ListFineTuningJobCheckpointResponse(
209+
`object`: String = "list",
210+
data: Seq[FineTuningJobCheckpointResponse],
211+
firstId: String,
212+
lastId: String,
213+
hasMore: Boolean
214+
)
215+
216+
object ListFineTuningJobCheckpointResponse {
217+
implicit val listFineTuningJobCheckpointResponseR: SnakePickle.Reader[ListFineTuningJobCheckpointResponse] =
218+
SnakePickle.macroR[ListFineTuningJobCheckpointResponse]
219+
}
220+
221+
case class Metrics(
222+
step: Float,
223+
trainLoss: Float,
224+
trainMeanTokenAccuracy: Float,
225+
validLoss: Float,
226+
validMeanTokenAccuracy: Float,
227+
fullValidLoss: Float,
228+
fullValidMeanTokenAccuracy: Float
229+
)
230+
231+
object Metrics {
232+
implicit val metricsR: SnakePickle.Reader[Metrics] = SnakePickle.macroR[Metrics]
233+
}

core/src/test/scala/sttp/openai/fixtures/FineTuningJobFixture.scala

+26
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,32 @@ object FineTuningJobFixture {
149149
| "has_more": true
150150
|}""".stripMargin
151151

152+
val jsonListFineTuningJobCheckpointsResponse: String = """{
153+
| "object": "list",
154+
| "data": [
155+
| {
156+
| "object": "fine_tuning.job.checkpoint",
157+
| "id": "ftckpt_zc4Q7MP6XxulcVzj4MZdwsAB",
158+
| "created_at": 1721764867,
159+
| "fine_tuned_model_checkpoint": "ft:gpt-4o-mini-2024-07-18:my-org:custom-suffix:96olL566:ckpt-step-2000",
160+
| "metrics": {
161+
| "full_valid_loss": 0.134,
162+
| "full_valid_mean_token_accuracy": 0.874,
163+
| "step": 0.123,
164+
| "train_loss": 0.346,
165+
| "train_mean_token_accuracy": 0.736,
166+
| "valid_loss": 0.654,
167+
| "valid_mean_token_accuracy": 0.738
168+
| },
169+
| "fine_tuning_job_id": "ftjob-abc123",
170+
| "step_number": 2000
171+
| }
172+
| ],
173+
| "first_id": "ftckpt_zc4Q7MP6XxulcVzj4MZdwsAB",
174+
| "last_id": "ftckpt_enQCFmOTGj3syEpYVhBRLTSy",
175+
| "has_more": true
176+
|}""".stripMargin
177+
152178
val fineTuningJobResponse: FineTuningJobResponse = FineTuningJobResponse(
153179
id = "ft-id",
154180
createdAt = 1000,

core/src/test/scala/sttp/openai/requests/finetuning/FineTuningDataSpec.scala

+32
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,39 @@ class FineTuningDataSpec extends AnyFlatSpec with Matchers with EitherValues {
131131
// when
132132
val deserializedJsonResponse: Either[Exception, ListFineTuningJobEventResponse] =
133133
SttpUpickleApiExtension.deserializeJsonSnake[ListFineTuningJobEventResponse].apply(jsonResponse)
134+
// then
135+
deserializedJsonResponse.value shouldBe expectedResponse
136+
}
134137

138+
"Given list fine tuning job checkpoints response as Json" should "be properly deserialized to case class" in {
139+
// given
140+
val jsonResponse = FineTuningJobFixture.jsonListFineTuningJobCheckpointsResponse
141+
val expectedResponse: ListFineTuningJobCheckpointResponse = ListFineTuningJobCheckpointResponse(
142+
data = Seq(
143+
FineTuningJobCheckpointResponse(
144+
id = "ftckpt_zc4Q7MP6XxulcVzj4MZdwsAB",
145+
createdAt = 1721764867,
146+
fineTunedModelCheckpoint = "ft:gpt-4o-mini-2024-07-18:my-org:custom-suffix:96olL566:ckpt-step-2000",
147+
metrics = Metrics(
148+
fullValidLoss = 0.134f,
149+
fullValidMeanTokenAccuracy = 0.874f,
150+
step = 0.123f,
151+
trainLoss = 0.346f,
152+
trainMeanTokenAccuracy = 0.736f,
153+
validLoss = 0.654f,
154+
validMeanTokenAccuracy = 0.738f
155+
),
156+
fineTuningJobId = "ftjob-abc123",
157+
stepNumber = 2000
158+
)
159+
),
160+
firstId = "ftckpt_zc4Q7MP6XxulcVzj4MZdwsAB",
161+
lastId = "ftckpt_enQCFmOTGj3syEpYVhBRLTSy",
162+
hasMore = true
163+
)
164+
// when
165+
val deserializedJsonResponse: Either[Exception, ListFineTuningJobCheckpointResponse] =
166+
SttpUpickleApiExtension.deserializeJsonSnake[ListFineTuningJobCheckpointResponse].apply(jsonResponse)
135167
// then
136168
deserializedJsonResponse.value shouldBe expectedResponse
137169
}

0 commit comments

Comments
 (0)