Skip to content

Commit

Permalink
feat(compatible): 优化 event-stream 实现,使用新版 http engine
Browse files Browse the repository at this point in the history
  • Loading branch information
sunist-c committed Oct 12, 2024
1 parent 2188234 commit aa7af56
Show file tree
Hide file tree
Showing 5 changed files with 88 additions and 128 deletions.
5 changes: 0 additions & 5 deletions app/api/compatible.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"github.com/alioth-center/akasha-whisper/app/service"
"github.com/alioth-center/infrastructure/network/http"
"github.com/alioth-center/infrastructure/thirdparty/openai"
"github.com/gin-gonic/gin"
)

var CompatibleApi compatibleApiImpl
Expand All @@ -17,10 +16,6 @@ func (impl compatibleApiImpl) CompleteChat() http.Chain[*openai.CompleteChatRequ
return http.NewChain(impl.service.ChatComplete)
}

func (impl compatibleApiImpl) StreamingCompleteChat() []gin.HandlerFunc {
return []gin.HandlerFunc{impl.service.StreamingChatCompletion}
}

// Embedding returns the handler chain for the OpenAI-compatible embedding
// endpoint: EmbeddingAuthorize runs first, then Embedding performs the call.
func (impl compatibleApiImpl) Embedding() http.Chain[*openai.EmbeddingRequestBody, *openai.EmbeddingResponseBody] {
return http.NewChain(impl.service.EmbeddingAuthorize, impl.service.Embedding)
}
Expand Down
2 changes: 1 addition & 1 deletion app/router/compatible.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ var compatibleRouter = http.NewRouter("v1")
var OpenAiCompatibleRouterGroup = []http.EndPointInterface{
http.NewEndPointBuilder[*openai.CompleteChatRequestBody, *openai.CompleteChatResponseBody]().
SetNecessaryHeaders("Authorization").
SetGinMiddlewares(api.CompatibleApi.StreamingCompleteChat()...).
SetCustomRender(true).
SetHandlerChain(api.CompatibleApi.CompleteChat()).
SetAllowMethods(http.POST).
SetRouter(compatibleRouter.Group("/chat/completions")).
Expand Down
204 changes: 84 additions & 120 deletions app/service/compatible.go
Original file line number Diff line number Diff line change
@@ -1,23 +1,19 @@
package service

import (
"bytes"
"context"
"encoding/json"
"io"
"strings"
"time"

"github.com/gin-contrib/sse"
"github.com/gin-gonic/gin"

"github.com/alioth-center/akasha-whisper/app/global"
"github.com/alioth-center/akasha-whisper/app/model"
"github.com/alioth-center/infrastructure/logger"
"github.com/alioth-center/infrastructure/network/http"
"github.com/alioth-center/infrastructure/thirdparty/openai"
"github.com/alioth-center/infrastructure/trace"
"github.com/alioth-center/infrastructure/utils/values"
"github.com/gin-contrib/sse"
"github.com/pkg/errors"
"github.com/shopspring/decimal"
)
Expand All @@ -44,21 +40,22 @@ func (srv *CompatibleService) ChatComplete(ctx http.Context[*openai.CompleteChat
ctx.Abort()
return
} else if getErr != nil {
global.Logger.Error(logger.NewFields(ctx).WithMessage("get available client failed").WithData(getErr))
ctx.SetStatusCode(http.StatusInternalServerError)
ctx.SetResponse(srv.buildErrorChatCompleteResponse(ctx, values.BuildStrings("internal server error: ", getErr.Error())))
ctx.Abort()
return
}

// complete chat
response, executeErr := client.CompleteChat(ctx, openai.CompleteChatRequest{
realPromptToken, realCompletionToken, requestID := int64(0), int64(0), ""
openaiRequest := openai.CompleteChatRequest{
Body: openai.CompleteChatRequestBody{
Model: request.Model,
Messages: request.Messages,
Temperature: request.Temperature,
TopP: request.TopP,
N: request.N,
Stream: false,
Stream: request.Stream,
MaxTokens: min(request.MaxTokens, global.Config.App.MaxToken),
PresencePenalty: request.PresencePenalty,
FrequencyPenalty: request.FrequencyPenalty,
Expand All @@ -72,148 +69,115 @@ func (srv *CompatibleService) ChatComplete(ctx http.Context[*openai.CompleteChat
Tools: request.Tools,
ToolChoice: request.ToolChoice,
},
})
if executeErr != nil {
ctx.SetStatusCode(http.StatusInternalServerError)
ctx.SetResponse(srv.buildErrorChatCompleteResponse(ctx, values.BuildStrings("internal server error: ", executeErr.Error())))
ctx.Abort()
return
}

// consume success, update balances
promptCostAmount := metadata.ModelPromptPrice.Mul(decimal.NewFromInt(int64(response.Usage.PromptTokens))).Div(decimal.NewFromInt(global.Config.App.PriceTokenUnit))
completionCostAmount := metadata.ModelCompletionPrice.Mul(decimal.NewFromInt(int64(response.Usage.CompletionTokens))).Div(decimal.NewFromInt(global.Config.App.PriceTokenUnit))
balanceCost := promptCostAmount.Add(completionCostAmount).Mul(decimal.NewFromInt(-1))
_, updateClientBalanceErr := global.OpenaiClientBalanceDatabaseInstance.CreateBalanceRecord(ctx, metadata.ClientID, balanceCost, model.OpenaiClientBalanceActionConsumption)
_, updateUserBalanceErr := global.WhisperUserBalanceDatabaseInstance.CreateBalanceRecord(ctx, metadata.UserID, balanceCost, model.WhisperUserBalanceActionConsumption)
updateRequestErr := global.OpenaiRequestDatabaseInstance.CreateOpenaiRequestRecord(ctx, &model.OpenaiRequest{
ClientID: int64(metadata.ClientID),
ModelID: int64(metadata.ModelID),
UserID: int64(metadata.UserID),
RequestIP: ctx.ExtraParams().GetString(http.RemoteIPKey),
RequestID: response.ID,
TraceID: trace.GetTid(ctx),
PromptTokenUsage: response.Usage.PromptTokens,
CompletionTokenUsage: response.Usage.CompletionTokens,
BalanceCost: balanceCost.Abs(),
})
for _, err := range []error{updateClientBalanceErr, updateUserBalanceErr, updateRequestErr} {
if err != nil {
global.Logger.Error(logger.NewFields(ctx).WithMessage("update response result failed").WithData(err))
if !request.Stream {
// complete chat without text stream
response, executeErr := client.CompleteChat(ctx, openaiRequest)
if executeErr != nil {
global.Logger.Error(logger.NewFields(ctx).WithMessage("complete chat failed").WithData(executeErr))
ctx.SetStatusCode(http.StatusInternalServerError)
ctx.SetResponse(srv.buildErrorChatCompleteResponse(ctx, values.BuildStrings("internal server error: ", executeErr.Error())))
ctx.Abort()
return
}
}

// return openai response
ctx.SetStatusCode(http.StatusOK)
ctx.SetResponse(&response)
}

func (srv *CompatibleService) StreamingChatCompletion(ctx *gin.Context) {
// check api key available
apiKey := ctx.GetHeader(http.HeaderAuthorization)
exist, allowIPs, err := CheckApiKeyAvailable(ctx, apiKey)
if err != nil {
global.Logger.Error(logger.NewFields(ctx).WithMessage("check api key available failed").WithData(err))
ctx.AbortWithStatusJSON(http.StatusInternalServerError, srv.buildErrorChatCompleteResponse(ctx, values.BuildStrings("internal server error: ", err.Error())))
return
}
if !exist {
ctx.AbortWithStatusJSON(http.StatusUnauthorized, srv.buildErrorChatCompleteResponse(ctx, "unauthorized"))
return
}

// check allow ip
if !CheckAllowIP(ctx, ctx.ClientIP(), strings.Split(allowIPs, ",")) {
ctx.AbortWithStatusJSON(http.StatusForbidden, srv.buildErrorChatCompleteResponse(ctx, "ip forbidden"))
return
}
realPromptToken, realCompletionToken, requestID = int64(response.Usage.PromptTokens), int64(response.Usage.CompletionTokens), response.ID

request := &openai.CompleteChatRequestBody{}
bindErr := ctx.ShouldBindJSON(request)
if bindErr != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": bindErr.Error()})
}
writeBack, _ := json.Marshal(request)
ctx.Request.Body = io.NopCloser(bytes.NewBuffer(writeBack))

if request.Stream {
defer ctx.Abort()
inputMessages := make([]string, len(request.Messages))
for i, message := range request.Messages {
inputMessages[i] = message.GetStringContent()
// marshal response to json
responseJson, marshalErr := json.Marshal(response)
if marshalErr != nil {
global.Logger.Error(logger.NewFields(ctx).WithMessage("marshal response failed").WithData(marshalErr))
ctx.SetStatusCode(http.StatusInternalServerError)
ctx.SetResponse(srv.buildErrorChatCompleteResponse(ctx, values.BuildStrings("internal server error: ", marshalErr.Error())))
ctx.Abort()
return
}
promptToken := CalculatePromptToken(inputMessages...)

// get available openai client
client, metadata, getErr := GetAvailableClient(ctx, apiKey, request.Model, promptToken, "chat")
if getErr != nil && errors.Is(getErr, ErrorNoAvailableClient) {
ctx.AbortWithStatusJSON(http.StatusForbidden, srv.buildErrorChatCompleteResponse(ctx, "no available client"))
return
} else if getErr != nil {
ctx.AbortWithStatusJSON(http.StatusInternalServerError, srv.buildErrorChatCompleteResponse(ctx, values.BuildStrings("internal server error: ", getErr.Error())))
// set response header
ctx.CustomRender().Header().Set("Cache-Control", "no-cache")
ctx.CustomRender().Header().Set(http.HeaderContentType, http.ContentTypeJson)
ctx.CustomRender().WriteHeaderNow()

// write response
_, writeErr := ctx.CustomRender().Write(responseJson)
if writeErr != nil {
global.Logger.Error(logger.NewFields(ctx).WithMessage("write response failed").WithData(writeErr))
ctx.SetStatusCode(http.StatusInternalServerError)
ctx.SetResponse(srv.buildErrorChatCompleteResponse(ctx, values.BuildStrings("internal server error: ", writeErr.Error())))
ctx.Abort()
return
}
} else {
// complete chat with text stream
openaiRequest.Body.StreamOptions = json.RawMessage(`{"include_usage": true}`)

request.StreamOptions = json.RawMessage(`{"include_usage": true}`)
response, executeErr := client.CompleteStreamingChat(ctx, openai.CompleteChatRequest{Body: *request})
response, executeErr := client.CompleteStreamingChat(ctx, openaiRequest)
if executeErr != nil {
ctx.AbortWithStatusJSON(http.StatusInternalServerError, srv.buildErrorChatCompleteResponse(ctx, values.BuildStrings("internal server error: ", executeErr.Error())))
global.Logger.Error(logger.NewFields(ctx).WithMessage("complete streaming chat failed").WithData(executeErr))
ctx.SetStatusCode(http.StatusInternalServerError)
ctx.SetResponse(srv.buildErrorChatCompleteResponse(ctx, values.BuildStrings("internal server error: ", executeErr.Error())))
ctx.Abort()
return
}

ctx.Header(http.HeaderContentType, "text/event-stream")
ctx.Header("Cache-Control", "no-cache")
ctx.Header("Transfer-Encoding", "chunked")
ctx.Header("Connection", "keep-alive")

completionToken, requestID := int64(0), ""
// set response header
ctx.CustomRender().Header().Set(http.HeaderContentType, "text/event-stream")
ctx.CustomRender().Header().Set("Cache-Control", "no-cache")
ctx.CustomRender().Header().Set("Transfer-Encoding", "chunked")
ctx.CustomRender().Header().Set("Connection", "keep-alive")
ctx.CustomRender().WriteHeaderNow()

// parse streaming response
for object := range response {
encodeErr := sse.Encode(ctx.Writer, sse.Event{Data: object})
if object.Usage != nil {
realPromptToken, realCompletionToken, requestID = int64(object.Usage.PromptTokens), int64(object.Usage.CompletionTokens), object.Id
}

encodeErr := sse.Encode(ctx.CustomRender(), sse.Event{Data: object})
if encodeErr != nil {
global.Logger.Error(logger.NewFields(ctx).WithMessage("encode response failed").WithData(encodeErr))
continue
}
ctx.Writer.Flush()

if object.Usage != nil {
completionToken = int64(object.Usage.CompletionTokens)
promptToken = int64(object.Usage.PromptTokens)
requestID = object.Id
}
ctx.CustomRender().Flush()
}

// send done message
encodeErr := sse.Encode(ctx.Writer, sse.Event{Data: "[DONE]"})
encodeErr := sse.Encode(ctx.CustomRender(), sse.Event{Data: "[DONE]"})
if encodeErr != nil {
global.Logger.Error(logger.NewFields(ctx).WithMessage("encode response failed").WithData(encodeErr))
}
ctx.Writer.Flush()
ctx.CustomRender().Flush()
}

// parse streaming response
promptCostAmount := metadata.ModelPromptPrice.Mul(decimal.NewFromInt(promptToken)).Div(decimal.NewFromInt(global.Config.App.PriceTokenUnit))
completionCostAmount := metadata.ModelCompletionPrice.Mul(decimal.NewFromInt(completionToken)).Div(decimal.NewFromInt(global.Config.App.PriceTokenUnit))
balanceCost := promptCostAmount.Add(completionCostAmount).Mul(decimal.NewFromInt(-1))
_, updateClientBalanceErr := global.OpenaiClientBalanceDatabaseInstance.CreateBalanceRecord(ctx, metadata.ClientID, balanceCost, model.OpenaiClientBalanceActionConsumption)
_, updateUserBalanceErr := global.WhisperUserBalanceDatabaseInstance.CreateBalanceRecord(ctx, metadata.UserID, balanceCost, model.WhisperUserBalanceActionConsumption)
updateRequestErr := global.OpenaiRequestDatabaseInstance.CreateOpenaiRequestRecord(ctx, &model.OpenaiRequest{
ClientID: int64(metadata.ClientID),
ModelID: int64(metadata.ModelID),
UserID: int64(metadata.UserID),
RequestIP: ctx.ClientIP(),
RequestID: requestID,
TraceID: trace.GetTid(ctx),
PromptTokenUsage: int(promptToken),
CompletionTokenUsage: int(completionToken),
BalanceCost: balanceCost.Abs(),
})

for _, err := range []error{updateClientBalanceErr, updateUserBalanceErr, updateRequestErr} {
if err != nil {
global.Logger.Error(logger.NewFields(ctx).WithMessage("update response result failed").WithData(err))
}
// consume success, update balances
promptCostAmount := metadata.ModelPromptPrice.Mul(decimal.NewFromInt(realPromptToken)).Div(decimal.NewFromInt(global.Config.App.PriceTokenUnit))
completionCostAmount := metadata.ModelCompletionPrice.Mul(decimal.NewFromInt(realCompletionToken)).Div(decimal.NewFromInt(global.Config.App.PriceTokenUnit))
balanceCost := promptCostAmount.Add(completionCostAmount).Mul(decimal.NewFromInt(-1))
global.Logger.Info(logger.NewFields(ctx).WithMessage("costs calculated").WithData(map[string]any{"prompt_cost": promptCostAmount, "completion_cost": completionCostAmount, "balance_cost": balanceCost}))

_, updateClientBalanceErr := global.OpenaiClientBalanceDatabaseInstance.CreateBalanceRecord(ctx, metadata.ClientID, balanceCost, model.OpenaiClientBalanceActionConsumption)
_, updateUserBalanceErr := global.WhisperUserBalanceDatabaseInstance.CreateBalanceRecord(ctx, metadata.UserID, balanceCost, model.WhisperUserBalanceActionConsumption)
updateRequestErr := global.OpenaiRequestDatabaseInstance.CreateOpenaiRequestRecord(ctx, &model.OpenaiRequest{
ClientID: int64(metadata.ClientID),
ModelID: int64(metadata.ModelID),
UserID: int64(metadata.UserID),
RequestIP: ctx.ExtraParams().GetString(http.RemoteIPKey),
RequestID: requestID,
TraceID: trace.GetTid(ctx),
PromptTokenUsage: int(realPromptToken),
CompletionTokenUsage: int(realCompletionToken),
BalanceCost: balanceCost.Abs(),
})
for _, err := range []error{updateClientBalanceErr, updateUserBalanceErr, updateRequestErr} {
if err != nil {
global.Logger.Error(logger.NewFields(ctx).WithMessage("update response result failed").WithData(err))
}
}

// return openai response
ctx.SetStatusCode(http.StatusOK)
}

func (srv *CompatibleService) ListModel(ctx http.Context[*openai.ListModelRequest, *openai.ListModelResponseBody]) {
Expand Down
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ module github.com/alioth-center/akasha-whisper
go 1.22.4

require (
github.com/alioth-center/infrastructure v1.2.19
github.com/alioth-center/infrastructure v1.2.20-0.20241012063141-64711b036676
github.com/bits-and-blooms/bloom/v3 v3.7.0
github.com/gin-contrib/sse v0.1.0
github.com/gin-gonic/gin v1.10.0
Expand Down
3 changes: 2 additions & 1 deletion go.sum
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA=
filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4=
github.com/alioth-center/infrastructure v1.2.19 h1:lPMy8clZPA0S3VD/wNpufMHIC+0pkq2R26NmlHEk1ww=
github.com/alioth-center/infrastructure v1.2.19/go.mod h1:QMr9jurGWQ30p0wG2IcatILimuSypnF0YtmyzD2PFmE=
github.com/alioth-center/infrastructure v1.2.20-0.20241012063141-64711b036676 h1:mOx56QD63D56yJdwQvl9MfePw8qF2JQORGfe4Rcj8Q0=
github.com/alioth-center/infrastructure v1.2.20-0.20241012063141-64711b036676/go.mod h1:QMr9jurGWQ30p0wG2IcatILimuSypnF0YtmyzD2PFmE=
github.com/bits-and-blooms/bitset v1.10.0/go.mod h1:7hO7Gc7Pp1vODcmWvKMRA9BNmbv6a/7QIWpPxHddWR8=
github.com/bits-and-blooms/bitset v1.14.2 h1:YXVoyPndbdvcEVcseEovVfp0qjJp7S+i5+xgp/Nfbdc=
github.com/bits-and-blooms/bitset v1.14.2/go.mod h1:7hO7Gc7Pp1vODcmWvKMRA9BNmbv6a/7QIWpPxHddWR8=
Expand Down

0 comments on commit aa7af56

Please sign in to comment.