diff --git a/src/libs/Providers/LangChain.Providers.OpenAI/OpenAiModel.Embeddings.cs b/src/libs/Providers/LangChain.Providers.OpenAI/OpenAiModel.Embeddings.cs
index ced84df6..bfdee323 100644
--- a/src/libs/Providers/LangChain.Providers.OpenAI/OpenAiModel.Embeddings.cs
+++ b/src/libs/Providers/LangChain.Providers.OpenAI/OpenAiModel.Embeddings.cs
@@ -1,6 +1,7 @@
 using System.Diagnostics;
 using OpenAI.Constants;
 using OpenAI.Embeddings;
+using static System.Net.Mime.MediaTypeNames;
 
 namespace LangChain.Providers.OpenAI;
 
@@ -57,14 +58,46 @@ public async Task<float[]> EmbedQueryAsync(
         return response.Data[0].Embedding.Select(static x => (float)x).ToArray();
     }
 
+    const int MaxElementsPerRequest = 2048;
+
     /// <inheritdoc/>
     public async Task<float[][]> EmbedDocumentsAsync(
         string[] texts,
         CancellationToken cancellationToken = default)
     {
-        return await Task.WhenAll(
-            texts
-                .Select(text => EmbedQueryAsync(text, cancellationToken))).ConfigureAwait(false);
+        // The API accepts at most 2048 input elements per request,
+        // so the texts are split into batches of that size.
+        // https://platform.openai.com/docs/api-reference/embeddings
+        var watch = Stopwatch.StartNew();
+        List<float[]> result = new();
+        List<string> allTexts = new(texts);
+
+        while (allTexts.Count > 0)
+        {
+            var currentBatch = allTexts.Take(MaxElementsPerRequest).ToArray();
+            allTexts.RemoveRange(0, currentBatch.Length);
+
+            var response = await Api.EmbeddingsEndpoint.CreateEmbeddingAsync(
+                request: new EmbeddingsRequest(
+                    input: currentBatch,
+                    model: EmbeddingModelId,
+                    user: User),
+                cancellationToken).ConfigureAwait(false);
+
+            var usage = GetUsage(response) with
+            {
+                Time = watch.Elapsed,
+            };
+            lock (_usageLock)
+            {
+                TotalUsage += usage;
+            }
+
+            result.AddRange(response.Data.Select(
+                static data => data.Embedding.Select(static x => (float)x).ToArray()));
+        }
+
+        return result.ToArray();
     }
 
     #endregion
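
Below is a minimal standalone sketch of the batching strategy this diff introduces, so the chunking logic can be run and checked in isolation. EmbedBatchAsync is a hypothetical stub standing in for Api.EmbeddingsEndpoint.CreateEmbeddingAsync (it just fabricates fixed-size vectors); the constant, the work-list loop, and the flattening into float[][] mirror the method above.

using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;

internal static class BatchingSketch
{
    // Same per-request input limit as in the diff above.
    const int MaxElementsPerRequest = 2048;

    // Hypothetical stand-in for the real embeddings endpoint:
    // returns one dummy 1536-wide vector per input text.
    static Task<float[][]> EmbedBatchAsync(string[] batch) =>
        Task.FromResult(batch.Select(static _ => new float[1536]).ToArray());

    public static async Task<float[][]> EmbedDocumentsAsync(string[] texts)
    {
        List<float[]> result = new();
        List<string> remaining = new(texts);

        while (remaining.Count > 0)
        {
            // Take at most 2048 elements, then drop them from the work list.
            var batch = remaining.Take(MaxElementsPerRequest).ToArray();
            remaining.RemoveRange(0, batch.Length);
            result.AddRange(await EmbedBatchAsync(batch).ConfigureAwait(false));
        }

        return result.ToArray();
    }

    static async Task Main()
    {
        // 5000 inputs -> 3 sequential requests (2048 + 2048 + 904).
        var texts = Enumerable.Range(0, 5000).Select(static i => $"doc {i}").ToArray();
        var vectors = await EmbedDocumentsAsync(texts);
        Console.WriteLine(vectors.Length); // 5000
    }
}

As in the diff, batches are processed sequentially rather than fired off concurrently; that keeps the usage accounting straightforward and avoids issuing thousands of simultaneous requests against the API.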