diff --git a/app/api/compatible.go b/app/api/compatible.go index 3074902..abc0f96 100644 --- a/app/api/compatible.go +++ b/app/api/compatible.go @@ -4,7 +4,6 @@ import ( "github.com/alioth-center/akasha-whisper/app/service" "github.com/alioth-center/infrastructure/network/http" "github.com/alioth-center/infrastructure/thirdparty/openai" - "github.com/gin-gonic/gin" ) var CompatibleApi compatibleApiImpl @@ -17,10 +16,6 @@ func (impl compatibleApiImpl) CompleteChat() http.Chain[*openai.CompleteChatRequ return http.NewChain(impl.service.ChatComplete) } -func (impl compatibleApiImpl) StreamingCompleteChat() []gin.HandlerFunc { - return []gin.HandlerFunc{impl.service.StreamingChatCompletion} -} - func (impl compatibleApiImpl) Embedding() http.Chain[*openai.EmbeddingRequestBody, *openai.EmbeddingResponseBody] { return http.NewChain(impl.service.EmbeddingAuthorize, impl.service.Embedding) } diff --git a/app/router/compatible.go b/app/router/compatible.go index 67df718..9281447 100644 --- a/app/router/compatible.go +++ b/app/router/compatible.go @@ -11,7 +11,7 @@ var compatibleRouter = http.NewRouter("v1") var OpenAiCompatibleRouterGroup = []http.EndPointInterface{ http.NewEndPointBuilder[*openai.CompleteChatRequestBody, *openai.CompleteChatResponseBody](). SetNecessaryHeaders("Authorization"). - SetGinMiddlewares(api.CompatibleApi.StreamingCompleteChat()...). + SetCustomRender(true). SetHandlerChain(api.CompatibleApi.CompleteChat()). SetAllowMethods(http.POST). SetRouter(compatibleRouter.Group("/chat/completions")). diff --git a/app/service/compatible.go b/app/service/compatible.go index d11d59c..1f742b7 100644 --- a/app/service/compatible.go +++ b/app/service/compatible.go @@ -1,16 +1,11 @@ package service import ( - "bytes" "context" "encoding/json" - "io" "strings" "time" - "github.com/gin-contrib/sse" - "github.com/gin-gonic/gin" - "github.com/alioth-center/akasha-whisper/app/global" "github.com/alioth-center/akasha-whisper/app/model" "github.com/alioth-center/infrastructure/logger" @@ -18,6 +13,7 @@ import ( "github.com/alioth-center/infrastructure/thirdparty/openai" "github.com/alioth-center/infrastructure/trace" "github.com/alioth-center/infrastructure/utils/values" + "github.com/gin-contrib/sse" "github.com/pkg/errors" "github.com/shopspring/decimal" ) @@ -44,21 +40,22 @@ func (srv *CompatibleService) ChatComplete(ctx http.Context[*openai.CompleteChat ctx.Abort() return } else if getErr != nil { + global.Logger.Error(logger.NewFields(ctx).WithMessage("get available client failed").WithData(getErr)) ctx.SetStatusCode(http.StatusInternalServerError) ctx.SetResponse(srv.buildErrorChatCompleteResponse(ctx, values.BuildStrings("internal server error: ", getErr.Error()))) ctx.Abort() return } - // complete chat - response, executeErr := client.CompleteChat(ctx, openai.CompleteChatRequest{ + realPromptToken, realCompletionToken, requestID := int64(0), int64(0), "" + openaiRequest := openai.CompleteChatRequest{ Body: openai.CompleteChatRequestBody{ Model: request.Model, Messages: request.Messages, Temperature: request.Temperature, TopP: request.TopP, N: request.N, - Stream: false, + Stream: request.Stream, MaxTokens: min(request.MaxTokens, global.Config.App.MaxToken), PresencePenalty: request.PresencePenalty, FrequencyPenalty: request.FrequencyPenalty, @@ -72,148 +69,115 @@ func (srv *CompatibleService) ChatComplete(ctx http.Context[*openai.CompleteChat Tools: request.Tools, ToolChoice: request.ToolChoice, }, - }) - if executeErr != nil { - ctx.SetStatusCode(http.StatusInternalServerError) - ctx.SetResponse(srv.buildErrorChatCompleteResponse(ctx, values.BuildStrings("internal server error: ", executeErr.Error()))) - ctx.Abort() - return } - // consume success, update balances - promptCostAmount := metadata.ModelPromptPrice.Mul(decimal.NewFromInt(int64(response.Usage.PromptTokens))).Div(decimal.NewFromInt(global.Config.App.PriceTokenUnit)) - completionCostAmount := metadata.ModelCompletionPrice.Mul(decimal.NewFromInt(int64(response.Usage.CompletionTokens))).Div(decimal.NewFromInt(global.Config.App.PriceTokenUnit)) - balanceCost := promptCostAmount.Add(completionCostAmount).Mul(decimal.NewFromInt(-1)) - _, updateClientBalanceErr := global.OpenaiClientBalanceDatabaseInstance.CreateBalanceRecord(ctx, metadata.ClientID, balanceCost, model.OpenaiClientBalanceActionConsumption) - _, updateUserBalanceErr := global.WhisperUserBalanceDatabaseInstance.CreateBalanceRecord(ctx, metadata.UserID, balanceCost, model.WhisperUserBalanceActionConsumption) - updateRequestErr := global.OpenaiRequestDatabaseInstance.CreateOpenaiRequestRecord(ctx, &model.OpenaiRequest{ - ClientID: int64(metadata.ClientID), - ModelID: int64(metadata.ModelID), - UserID: int64(metadata.UserID), - RequestIP: ctx.ExtraParams().GetString(http.RemoteIPKey), - RequestID: response.ID, - TraceID: trace.GetTid(ctx), - PromptTokenUsage: response.Usage.PromptTokens, - CompletionTokenUsage: response.Usage.CompletionTokens, - BalanceCost: balanceCost.Abs(), - }) - for _, err := range []error{updateClientBalanceErr, updateUserBalanceErr, updateRequestErr} { - if err != nil { - global.Logger.Error(logger.NewFields(ctx).WithMessage("update response result failed").WithData(err)) + if !request.Stream { + // complete chat without text stream + response, executeErr := client.CompleteChat(ctx, openaiRequest) + if executeErr != nil { + global.Logger.Error(logger.NewFields(ctx).WithMessage("complete chat failed").WithData(executeErr)) + ctx.SetStatusCode(http.StatusInternalServerError) + ctx.SetResponse(srv.buildErrorChatCompleteResponse(ctx, values.BuildStrings("internal server error: ", executeErr.Error()))) + ctx.Abort() + return } - } - - // return openai response - ctx.SetStatusCode(http.StatusOK) - ctx.SetResponse(&response) -} - -func (srv *CompatibleService) StreamingChatCompletion(ctx *gin.Context) { - // check api key available - apiKey := ctx.GetHeader(http.HeaderAuthorization) - exist, allowIPs, err := CheckApiKeyAvailable(ctx, apiKey) - if err != nil { - global.Logger.Error(logger.NewFields(ctx).WithMessage("check api key available failed").WithData(err)) - ctx.AbortWithStatusJSON(http.StatusInternalServerError, srv.buildErrorChatCompleteResponse(ctx, values.BuildStrings("internal server error: ", err.Error()))) - return - } - if !exist { - ctx.AbortWithStatusJSON(http.StatusUnauthorized, srv.buildErrorChatCompleteResponse(ctx, "unauthorized")) - return - } - // check allow ip - if !CheckAllowIP(ctx, ctx.ClientIP(), strings.Split(allowIPs, ",")) { - ctx.AbortWithStatusJSON(http.StatusForbidden, srv.buildErrorChatCompleteResponse(ctx, "ip forbidden")) - return - } + realPromptToken, realCompletionToken, requestID = int64(response.Usage.PromptTokens), int64(response.Usage.CompletionTokens), response.ID - request := &openai.CompleteChatRequestBody{} - bindErr := ctx.ShouldBindJSON(request) - if bindErr != nil { - ctx.JSON(http.StatusBadRequest, gin.H{"error": bindErr.Error()}) - } - writeBack, _ := json.Marshal(request) - ctx.Request.Body = io.NopCloser(bytes.NewBuffer(writeBack)) - - if request.Stream { - defer ctx.Abort() - inputMessages := make([]string, len(request.Messages)) - for i, message := range request.Messages { - inputMessages[i] = message.GetStringContent() + // marshal response to json + responseJson, marshalErr := json.Marshal(response) + if marshalErr != nil { + global.Logger.Error(logger.NewFields(ctx).WithMessage("marshal response failed").WithData(marshalErr)) + ctx.SetStatusCode(http.StatusInternalServerError) + ctx.SetResponse(srv.buildErrorChatCompleteResponse(ctx, values.BuildStrings("internal server error: ", marshalErr.Error()))) + ctx.Abort() + return } - promptToken := CalculatePromptToken(inputMessages...) - // get available openai client - client, metadata, getErr := GetAvailableClient(ctx, apiKey, request.Model, promptToken, "chat") - if getErr != nil && errors.Is(getErr, ErrorNoAvailableClient) { - ctx.AbortWithStatusJSON(http.StatusForbidden, srv.buildErrorChatCompleteResponse(ctx, "no available client")) - return - } else if getErr != nil { - ctx.AbortWithStatusJSON(http.StatusInternalServerError, srv.buildErrorChatCompleteResponse(ctx, values.BuildStrings("internal server error: ", getErr.Error()))) + // set response header + ctx.CustomRender().Header().Set("Cache-Control", "no-cache") + ctx.CustomRender().Header().Set(http.HeaderContentType, http.ContentTypeJson) + ctx.CustomRender().WriteHeaderNow() + + // write response + _, writeErr := ctx.CustomRender().Write(responseJson) + if writeErr != nil { + global.Logger.Error(logger.NewFields(ctx).WithMessage("write response failed").WithData(writeErr)) + ctx.SetStatusCode(http.StatusInternalServerError) + ctx.SetResponse(srv.buildErrorChatCompleteResponse(ctx, values.BuildStrings("internal server error: ", writeErr.Error()))) + ctx.Abort() return } + } else { + // complete chat with text stream + openaiRequest.Body.StreamOptions = json.RawMessage(`{"include_usage": true}`) - request.StreamOptions = json.RawMessage(`{"include_usage": true}`) - response, executeErr := client.CompleteStreamingChat(ctx, openai.CompleteChatRequest{Body: *request}) + response, executeErr := client.CompleteStreamingChat(ctx, openaiRequest) if executeErr != nil { - ctx.AbortWithStatusJSON(http.StatusInternalServerError, srv.buildErrorChatCompleteResponse(ctx, values.BuildStrings("internal server error: ", executeErr.Error()))) + global.Logger.Error(logger.NewFields(ctx).WithMessage("complete streaming chat failed").WithData(executeErr)) + ctx.SetStatusCode(http.StatusInternalServerError) + ctx.SetResponse(srv.buildErrorChatCompleteResponse(ctx, values.BuildStrings("internal server error: ", executeErr.Error()))) + ctx.Abort() return } - ctx.Header(http.HeaderContentType, "text/event-stream") - ctx.Header("Cache-Control", "no-cache") - ctx.Header("Transfer-Encoding", "chunked") - ctx.Header("Connection", "keep-alive") - - completionToken, requestID := int64(0), "" + // set response header + ctx.CustomRender().Header().Set(http.HeaderContentType, "text/event-stream") + ctx.CustomRender().Header().Set("Cache-Control", "no-cache") + ctx.CustomRender().Header().Set("Transfer-Encoding", "chunked") + ctx.CustomRender().Header().Set("Connection", "keep-alive") + ctx.CustomRender().WriteHeaderNow() + // parse streaming response for object := range response { - encodeErr := sse.Encode(ctx.Writer, sse.Event{Data: object}) + if object.Usage != nil { + realPromptToken, realCompletionToken, requestID = int64(object.Usage.PromptTokens), int64(object.Usage.CompletionTokens), object.Id + } + + encodeErr := sse.Encode(ctx.CustomRender(), sse.Event{Data: object}) if encodeErr != nil { global.Logger.Error(logger.NewFields(ctx).WithMessage("encode response failed").WithData(encodeErr)) continue } - ctx.Writer.Flush() - if object.Usage != nil { - completionToken = int64(object.Usage.CompletionTokens) - promptToken = int64(object.Usage.PromptTokens) - requestID = object.Id - } + ctx.CustomRender().Flush() } // send done message - encodeErr := sse.Encode(ctx.Writer, sse.Event{Data: "[DONE]"}) + encodeErr := sse.Encode(ctx.CustomRender(), sse.Event{Data: "[DONE]"}) if encodeErr != nil { global.Logger.Error(logger.NewFields(ctx).WithMessage("encode response failed").WithData(encodeErr)) } - ctx.Writer.Flush() + ctx.CustomRender().Flush() + } - // parse streaming response - promptCostAmount := metadata.ModelPromptPrice.Mul(decimal.NewFromInt(promptToken)).Div(decimal.NewFromInt(global.Config.App.PriceTokenUnit)) - completionCostAmount := metadata.ModelCompletionPrice.Mul(decimal.NewFromInt(completionToken)).Div(decimal.NewFromInt(global.Config.App.PriceTokenUnit)) - balanceCost := promptCostAmount.Add(completionCostAmount).Mul(decimal.NewFromInt(-1)) - _, updateClientBalanceErr := global.OpenaiClientBalanceDatabaseInstance.CreateBalanceRecord(ctx, metadata.ClientID, balanceCost, model.OpenaiClientBalanceActionConsumption) - _, updateUserBalanceErr := global.WhisperUserBalanceDatabaseInstance.CreateBalanceRecord(ctx, metadata.UserID, balanceCost, model.WhisperUserBalanceActionConsumption) - updateRequestErr := global.OpenaiRequestDatabaseInstance.CreateOpenaiRequestRecord(ctx, &model.OpenaiRequest{ - ClientID: int64(metadata.ClientID), - ModelID: int64(metadata.ModelID), - UserID: int64(metadata.UserID), - RequestIP: ctx.ClientIP(), - RequestID: requestID, - TraceID: trace.GetTid(ctx), - PromptTokenUsage: int(promptToken), - CompletionTokenUsage: int(completionToken), - BalanceCost: balanceCost.Abs(), - }) - - for _, err := range []error{updateClientBalanceErr, updateUserBalanceErr, updateRequestErr} { - if err != nil { - global.Logger.Error(logger.NewFields(ctx).WithMessage("update response result failed").WithData(err)) - } + // consume success, update balances + promptCostAmount := metadata.ModelPromptPrice.Mul(decimal.NewFromInt(realPromptToken)).Div(decimal.NewFromInt(global.Config.App.PriceTokenUnit)) + completionCostAmount := metadata.ModelCompletionPrice.Mul(decimal.NewFromInt(realCompletionToken)).Div(decimal.NewFromInt(global.Config.App.PriceTokenUnit)) + balanceCost := promptCostAmount.Add(completionCostAmount).Mul(decimal.NewFromInt(-1)) + global.Logger.Info(logger.NewFields(ctx).WithMessage("costs calculated").WithData(map[string]any{"prompt_cost": promptCostAmount, "completion_cost": completionCostAmount, "balance_cost": balanceCost})) + + _, updateClientBalanceErr := global.OpenaiClientBalanceDatabaseInstance.CreateBalanceRecord(ctx, metadata.ClientID, balanceCost, model.OpenaiClientBalanceActionConsumption) + _, updateUserBalanceErr := global.WhisperUserBalanceDatabaseInstance.CreateBalanceRecord(ctx, metadata.UserID, balanceCost, model.WhisperUserBalanceActionConsumption) + updateRequestErr := global.OpenaiRequestDatabaseInstance.CreateOpenaiRequestRecord(ctx, &model.OpenaiRequest{ + ClientID: int64(metadata.ClientID), + ModelID: int64(metadata.ModelID), + UserID: int64(metadata.UserID), + RequestIP: ctx.ExtraParams().GetString(http.RemoteIPKey), + RequestID: requestID, + TraceID: trace.GetTid(ctx), + PromptTokenUsage: int(realPromptToken), + CompletionTokenUsage: int(realCompletionToken), + BalanceCost: balanceCost.Abs(), + }) + for _, err := range []error{updateClientBalanceErr, updateUserBalanceErr, updateRequestErr} { + if err != nil { + global.Logger.Error(logger.NewFields(ctx).WithMessage("update response result failed").WithData(err)) } } + + // return openai response + ctx.SetStatusCode(http.StatusOK) } func (srv *CompatibleService) ListModel(ctx http.Context[*openai.ListModelRequest, *openai.ListModelResponseBody]) { diff --git a/go.mod b/go.mod index f214c5b..848b892 100644 --- a/go.mod +++ b/go.mod @@ -3,7 +3,7 @@ module github.com/alioth-center/akasha-whisper go 1.22.4 require ( - github.com/alioth-center/infrastructure v1.2.19 + github.com/alioth-center/infrastructure v1.2.20-0.20241012063141-64711b036676 github.com/bits-and-blooms/bloom/v3 v3.7.0 github.com/gin-contrib/sse v0.1.0 github.com/gin-gonic/gin v1.10.0 diff --git a/go.sum b/go.sum index a084f4f..f99c641 100644 --- a/go.sum +++ b/go.sum @@ -1,7 +1,8 @@ filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA= filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4= -github.com/alioth-center/infrastructure v1.2.19 h1:lPMy8clZPA0S3VD/wNpufMHIC+0pkq2R26NmlHEk1ww= github.com/alioth-center/infrastructure v1.2.19/go.mod h1:QMr9jurGWQ30p0wG2IcatILimuSypnF0YtmyzD2PFmE= +github.com/alioth-center/infrastructure v1.2.20-0.20241012063141-64711b036676 h1:mOx56QD63D56yJdwQvl9MfePw8qF2JQORGfe4Rcj8Q0= +github.com/alioth-center/infrastructure v1.2.20-0.20241012063141-64711b036676/go.mod h1:QMr9jurGWQ30p0wG2IcatILimuSypnF0YtmyzD2PFmE= github.com/bits-and-blooms/bitset v1.10.0/go.mod h1:7hO7Gc7Pp1vODcmWvKMRA9BNmbv6a/7QIWpPxHddWR8= github.com/bits-and-blooms/bitset v1.14.2 h1:YXVoyPndbdvcEVcseEovVfp0qjJp7S+i5+xgp/Nfbdc= github.com/bits-and-blooms/bitset v1.14.2/go.mod h1:7hO7Gc7Pp1vODcmWvKMRA9BNmbv6a/7QIWpPxHddWR8=