Skip to content

Commit

Permalink
fix for OpenAI to do embeddings in batches
Browse files Browse the repository at this point in the history
  • Loading branch information
TesAnti committed Jan 17, 2024
1 parent c6fa507 commit 44f3b2d
Showing 1 changed file with 36 additions and 3 deletions.
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
using System.Diagnostics;
using OpenAI.Constants;
using OpenAI.Embeddings;
using static System.Net.Mime.MediaTypeNames;

namespace LangChain.Providers.OpenAI;

Expand Down Expand Up @@ -57,14 +58,46 @@ public async Task<float[]> EmbedQueryAsync(
return response.Data[0].Embedding.Select(static x => (float)x).ToArray();
}

// The OpenAI embeddings API accepts at most 2048 input elements per request:
// https://platform.openai.com/docs/api-reference/embeddings
const int MaxElementsPerRequest = 2048;

/// <inheritdoc/>
public async Task<float[][]> EmbedDocumentsAsync(
    string[] texts,
    CancellationToken cancellationToken = default)
{
    texts = texts ?? throw new ArgumentNullException(nameof(texts));

    // The API limits each request to MaxElementsPerRequest inputs, so
    // split the texts into sequential batches. Slicing by offset avoids
    // repeatedly shifting a working list with RemoveRange (O(n) per batch).
    var result = new List<float[]>(texts.Length);
    var watch = Stopwatch.StartNew();

    for (var offset = 0; offset < texts.Length; offset += MaxElementsPerRequest)
    {
        var currentBatch = texts
            .Skip(offset)
            .Take(MaxElementsPerRequest)
            .ToArray();

        var response = await Api.EmbeddingsEndpoint.CreateEmbeddingAsync(
            request: new EmbeddingsRequest(
                input: currentBatch,
                model: EmbeddingModelId,
                user: User),
            cancellationToken).ConfigureAwait(false);

        // Record this batch's latency only: restart the stopwatch so later
        // batches do not report cumulative elapsed time.
        var usage = GetUsage(response) with
        {
            Time = watch.Elapsed,
        };
        watch.Restart();

        lock (_usageLock)
        {
            TotalUsage += usage;
        }

        // Distinct inner/outer lambda parameter names: reusing "x" in the
        // nested lambda collides with the enclosing lambda scope (CS0136).
        result.AddRange(response.Data.Select(
            static item => item.Embedding.Select(static value => (float)value).ToArray()));
    }

    return result.ToArray();
}

#endregion
Expand Down

0 comments on commit 44f3b2d

Please sign in to comment.