Skip to content

Commit 2f4a6e6

Browse files
authored
Merge pull request #42 from warjiang/feat-langchain
Feat langchain
2 parents 6affa55 + 90099e6 commit 2f4a6e6

File tree

11 files changed

+791
-89
lines changed

11 files changed

+791
-89
lines changed

.gitignore

+1
Original file line numberDiff line numberDiff line change
@@ -19,3 +19,4 @@ bin/
1919
release/
2020
docker-compose.yml
2121
dist/
22+
config/config.yaml

README.md

+7
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ Verified support projects:
1919
| -------------------------------------------------------- | ------ |
2020
| [chatgpt-web](https://github.com/Chanzhaoyu/chatgpt-web) ||
2121
| [chatbox](https://github.com/Bin-Huang/chatbox) ||
22+
| [langchain](https://python.langchain.com/en/latest/) ||
2223

2324
## Get Start
2425

@@ -56,11 +57,17 @@ API Key: This value can be found in the **Keys & Endpoint** section when examini
5657
### Use Docker
5758

5859
````shell
60+
# config by environment
5961
docker run -d -p 8080:8080 --name=azure-openai-proxy \
6062
--env AZURE_OPENAI_ENDPOINT=your_azure_endpoint \
6163
--env AZURE_OPENAI_API_VER=your_azure_api_ver \
6264
--env AZURE_OPENAI_MODEL_MAPPER=your_azure_deploy_mapper \
6365
stulzq/azure-openai-proxy:latest
66+
67+
# config by file
68+
docker run -d -p 8080:8080 --name=azure-openai-proxy \
69+
-v /path/to/config-file.yaml:/app/config/config.yaml \
70+
stulzq/azure-openai-proxy:latest
6471
````
6572

6673
Call API:

azure/init.go

+77-27
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
package azure
22

33
import (
4+
"fmt"
5+
"github.com/spf13/viper"
46
"github.com/stulzq/azure-openai-proxy/constant"
7+
"github.com/stulzq/azure-openai-proxy/util"
58
"log"
69
"net/url"
7-
"os"
8-
"regexp"
910
"strings"
1011
)
1112

@@ -14,43 +15,92 @@ const (
1415
)
1516

1617
var (
17-
AzureOpenAIEndpoint = ""
18-
AzureOpenAIEndpointParse *url.URL
19-
20-
AzureOpenAIAPIVer = ""
21-
22-
AzureOpenAIModelMapper = map[string]string{
23-
"gpt-3.5-turbo": "gpt-35-turbo",
24-
}
25-
fallbackModelMapper = regexp.MustCompile(`[.:]`)
18+
C Config
19+
ModelDeploymentConfig = map[string]DeploymentConfig{}
2620
)
2721

28-
func Init() {
29-
AzureOpenAIAPIVer = os.Getenv(constant.ENV_AZURE_OPENAI_API_VER)
30-
AzureOpenAIEndpoint = os.Getenv(constant.ENV_AZURE_OPENAI_ENDPOINT)
22+
func Init() error {
23+
var (
24+
apiVersion string
25+
endpoint string
26+
openaiModelMapper string
27+
err error
28+
)
3129

32-
if AzureOpenAIAPIVer == "" {
33-
AzureOpenAIAPIVer = "2023-03-15-preview"
30+
apiVersion = viper.GetString(constant.ENV_AZURE_OPENAI_API_VER)
31+
endpoint = viper.GetString(constant.ENV_AZURE_OPENAI_ENDPOINT)
32+
openaiModelMapper = viper.GetString(constant.ENV_AZURE_OPENAI_MODEL_MAPPER)
33+
if endpoint != "" && openaiModelMapper != "" {
34+
if apiVersion == "" {
35+
apiVersion = "2023-03-15-preview"
36+
}
37+
InitFromEnvironmentVariables(apiVersion, endpoint, openaiModelMapper)
38+
} else {
39+
if err = InitFromConfigFile(); err != nil {
40+
return err
41+
}
3442
}
3543

36-
var err error
37-
AzureOpenAIEndpointParse, err = url.Parse(AzureOpenAIEndpoint)
38-
if err != nil {
39-
log.Fatal("parse AzureOpenAIEndpoint error: ", err)
44+
// ensure apiBase likes /v1
45+
apiBase := viper.GetString("api_base")
46+
if !strings.HasPrefix(apiBase, "/") {
47+
apiBase = "/" + apiBase
48+
}
49+
if strings.HasSuffix(apiBase, "/") {
50+
apiBase = apiBase[:len(apiBase)-1]
51+
}
52+
viper.Set("api_base", apiBase)
53+
log.Printf("apiBase is: %s", apiBase)
54+
for _, itemConfig := range C.DeploymentConfig {
55+
u, err := url.Parse(itemConfig.Endpoint)
56+
if err != nil {
57+
return fmt.Errorf("parse endpoint error: %w", err)
58+
}
59+
itemConfig.EndpointUrl = u
60+
ModelDeploymentConfig[itemConfig.ModelName] = itemConfig
4061
}
62+
return err
63+
}
4164

42-
if v := os.Getenv(constant.ENV_AZURE_OPENAI_MODEL_MAPPER); v != "" {
43-
for _, pair := range strings.Split(v, ",") {
65+
func InitFromEnvironmentVariables(apiVersion, endpoint, openaiModelMapper string) {
66+
log.Println("Init from environment variables")
67+
if openaiModelMapper != "" {
68+
// openaiModelMapper example:
69+
// gpt-3.5-turbo=deployment_name_for_gpt_model,text-davinci-003=deployment_name_for_davinci_model
70+
for _, pair := range strings.Split(openaiModelMapper, ",") {
4471
info := strings.Split(pair, "=")
4572
if len(info) != 2 {
4673
log.Fatalf("error parsing %s, invalid value %s", constant.ENV_AZURE_OPENAI_MODEL_MAPPER, pair)
4774
}
48-
49-
AzureOpenAIModelMapper[info[0]] = info[1]
75+
modelName, deploymentName := info[0], info[1]
76+
ModelDeploymentConfig[modelName] = DeploymentConfig{
77+
DeploymentName: deploymentName,
78+
ModelName: modelName,
79+
Endpoint: endpoint,
80+
ApiKey: "",
81+
ApiVersion: apiVersion,
82+
}
5083
}
5184
}
85+
}
86+
87+
func InitFromConfigFile() error {
88+
log.Println("Init from config file")
89+
workDir := util.GetWorkdir()
90+
viper.SetConfigName("config")
91+
viper.SetConfigType("yaml")
92+
viper.AddConfigPath(fmt.Sprintf("%s/config", workDir))
93+
if err := viper.ReadInConfig(); err != nil {
94+
log.Printf("read config file error: %+v\n", err)
95+
return err
96+
}
5297

53-
log.Println("AzureOpenAIAPIVer: ", AzureOpenAIAPIVer)
54-
log.Println("AzureOpenAIEndpoint: ", AzureOpenAIEndpoint)
55-
log.Println("AzureOpenAIModelMapper: ", AzureOpenAIModelMapper)
98+
if err := viper.Unmarshal(&C); err != nil {
99+
log.Printf("unmarshal config file error: %+v\n", err)
100+
return err
101+
}
102+
for _, configItem := range C.DeploymentConfig {
103+
ModelDeploymentConfig[configItem.ModelName] = configItem
104+
}
105+
return nil
56106
}

azure/model.go

+100
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
package azure
2+
3+
import (
4+
"bytes"
5+
"fmt"
6+
"github.com/pkg/errors"
7+
"log"
8+
"net/http"
9+
"net/url"
10+
"path"
11+
"strings"
12+
"text/template"
13+
)
14+
15+
type DeploymentConfig struct {
16+
DeploymentName string `yaml:"deployment_name" json:"deployment_name" mapstructure:"deployment_name"` // azure openai deployment name
17+
ModelName string `yaml:"model_name" json:"model_name" mapstructure:"model_name"` // corresponding model name in openai
18+
Endpoint string `yaml:"endpoint" json:"endpoint" mapstructure:"endpoint"` // deployment endpoint
19+
ApiKey string `yaml:"api_key" json:"api_key" mapstructure:"api_key"` // secret key 1 or 2
20+
ApiVersion string `yaml:"api_version" json:"api_version" mapstructure:"api_version"` // deployment version, not required
21+
EndpointUrl *url.URL // url.URL form deployment endpoint
22+
}
23+
24+
type Config struct {
25+
ApiBase string `yaml:"api_base" mapstructure:"api_base"` // if you use openai, langchain as sdk, it will be useful
26+
DeploymentConfig []DeploymentConfig `yaml:"deployment_config" mapstructure:"deployment_config"` // deployment config
27+
}
28+
29+
type RequestConverter interface {
30+
Name() string
31+
Convert(req *http.Request, config *DeploymentConfig) (*http.Request, error)
32+
}
33+
34+
type StripPrefixConverter struct {
35+
Prefix string
36+
}
37+
38+
func (c *StripPrefixConverter) Name() string {
39+
return "StripPrefix"
40+
}
41+
func (c *StripPrefixConverter) Convert(req *http.Request, config *DeploymentConfig) (*http.Request, error) {
42+
req.Host = config.EndpointUrl.Host
43+
req.URL.Scheme = config.EndpointUrl.Scheme
44+
req.URL.Host = config.EndpointUrl.Host
45+
req.URL.Path = path.Join(fmt.Sprintf("/openai/deployments/%s", config.DeploymentName), strings.Replace(req.URL.Path, c.Prefix+"/", "/", 1))
46+
req.URL.RawPath = req.URL.EscapedPath()
47+
48+
query := req.URL.Query()
49+
query.Add("api-version", config.ApiVersion)
50+
req.URL.RawQuery = query.Encode()
51+
return req, nil
52+
}
53+
func NewStripPrefixConverter(prefix string) *StripPrefixConverter {
54+
return &StripPrefixConverter{
55+
Prefix: prefix,
56+
}
57+
}
58+
59+
type TemplateConverter struct {
60+
Tpl string
61+
Tempalte *template.Template
62+
}
63+
64+
func (c *TemplateConverter) Name() string {
65+
return "Template"
66+
}
67+
func (c *TemplateConverter) Convert(req *http.Request, config *DeploymentConfig) (*http.Request, error) {
68+
data := map[string]interface{}{
69+
"DeploymentName": config.DeploymentName,
70+
"ModelName": config.ModelName,
71+
"Endpoint": config.Endpoint,
72+
"ApiKey": config.ApiKey,
73+
"ApiVersion": config.ApiVersion,
74+
}
75+
buff := new(bytes.Buffer)
76+
if err := c.Tempalte.Execute(buff, data); err != nil {
77+
return req, errors.Wrap(err, "template execute error")
78+
}
79+
80+
req.Host = config.EndpointUrl.Host
81+
req.URL.Scheme = config.EndpointUrl.Scheme
82+
req.URL.Host = config.EndpointUrl.Host
83+
req.URL.Path = buff.String()
84+
req.URL.RawPath = req.URL.EscapedPath()
85+
86+
query := req.URL.Query()
87+
query.Add("api-version", config.ApiVersion)
88+
req.URL.RawQuery = query.Encode()
89+
return req, nil
90+
}
91+
func NewTemplateConverter(tpl string) *TemplateConverter {
92+
_template, err := template.New("template").Parse(tpl)
93+
if err != nil {
94+
log.Fatalf("template parse error: %s", err.Error())
95+
}
96+
return &TemplateConverter{
97+
Tpl: tpl,
98+
Tempalte: _template,
99+
}
100+
}

azure/proxy.go

+43-28
Original file line numberDiff line numberDiff line change
@@ -8,16 +8,21 @@ import (
88
"log"
99
"net/http"
1010
"net/http/httputil"
11-
"path"
1211
"strings"
1312

1413
"github.com/bytedance/sonic"
1514
"github.com/gin-gonic/gin"
1615
"github.com/pkg/errors"
1716
)
1817

18+
func ProxyWithConverter(requestConverter RequestConverter) gin.HandlerFunc {
19+
return func(c *gin.Context) {
20+
Proxy(c, requestConverter)
21+
}
22+
}
23+
1924
// Proxy Azure OpenAI
20-
func Proxy(c *gin.Context) {
25+
func Proxy(c *gin.Context, requestConverter RequestConverter) {
2126
if c.Request.Method == http.MethodOptions {
2227
c.Header("Access-Control-Allow-Origin", "*")
2328
c.Header("Access-Control-Allow-Methods", "GET, OPTIONS, POST")
@@ -34,38 +39,48 @@ func Proxy(c *gin.Context) {
3439
body, _ := io.ReadAll(req.Body)
3540
req.Body = io.NopCloser(bytes.NewBuffer(body))
3641

37-
// get model from body
38-
model, err := sonic.Get(body, "model")
39-
if err != nil {
40-
util.SendError(c, errors.Wrap(err, "get model error"))
41-
return
42+
// get model from url params or body
43+
model := c.Param("model")
44+
if model == "" {
45+
_model, err := sonic.Get(body, "model")
46+
if err != nil {
47+
util.SendError(c, errors.Wrap(err, "get model error"))
48+
return
49+
}
50+
_modelStr, err := _model.String()
51+
if err != nil {
52+
util.SendError(c, errors.Wrap(err, "get model name error"))
53+
return
54+
}
55+
model = _modelStr
4256
}
4357

4458
// get deployment from request
45-
deployment, err := model.String()
59+
deployment, err := GetDeploymentByModel(model)
4660
if err != nil {
47-
util.SendError(c, errors.Wrap(err, "get deployment error"))
61+
util.SendError(c, err)
4862
return
4963
}
50-
deployment = GetDeploymentByModel(deployment)
5164

52-
// get auth token from header
53-
rawToken := req.Header.Get("Authorization")
54-
token := strings.TrimPrefix(rawToken, "Bearer ")
65+
// get auth token from header or deployment config
66+
token := deployment.ApiKey
67+
if token == "" {
68+
rawToken := req.Header.Get("Authorization")
69+
token = strings.TrimPrefix(rawToken, "Bearer ")
70+
}
71+
if token == "" {
72+
util.SendError(c, errors.New("token is empty"))
73+
return
74+
}
5575
req.Header.Set(AuthHeaderKey, token)
5676
req.Header.Del("Authorization")
5777

5878
originURL := req.URL.String()
59-
req.Host = AzureOpenAIEndpointParse.Host
60-
req.URL.Scheme = AzureOpenAIEndpointParse.Scheme
61-
req.URL.Host = AzureOpenAIEndpointParse.Host
62-
req.URL.Path = path.Join(fmt.Sprintf("/openai/deployments/%s", deployment), strings.Replace(req.URL.Path, "/v1/", "/", 1))
63-
req.URL.RawPath = req.URL.EscapedPath()
64-
65-
query := req.URL.Query()
66-
query.Add("api-version", AzureOpenAIAPIVer)
67-
req.URL.RawQuery = query.Encode()
68-
79+
req, err = requestConverter.Convert(req, deployment)
80+
if err != nil {
81+
util.SendError(c, errors.Wrap(err, "convert request error"))
82+
return
83+
}
6984
log.Printf("proxying request [%s] %s -> %s", model, originURL, req.URL.String())
7085
}
7186

@@ -80,10 +95,10 @@ func Proxy(c *gin.Context) {
8095
}
8196
}
8297

83-
func GetDeploymentByModel(model string) string {
84-
if v, ok := AzureOpenAIModelMapper[model]; ok {
85-
return v
98+
func GetDeploymentByModel(model string) (*DeploymentConfig, error) {
99+
deploymentConfig, exist := ModelDeploymentConfig[model]
100+
if !exist {
101+
return nil, errors.New(fmt.Sprintf("deployment config for %s not found", model))
86102
}
87-
88-
return fallbackModelMapper.ReplaceAllString(model, "")
103+
return &deploymentConfig, nil
89104
}

cmd/main.go

+2
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"context"
55
"flag"
66
"fmt"
7+
"github.com/spf13/viper"
78
"github.com/stulzq/azure-openai-proxy/azure"
89
"log"
910
"net/http"
@@ -22,6 +23,7 @@ var (
2223
)
2324

2425
func main() {
26+
viper.AutomaticEnv()
2527
parseFlag()
2628

2729
azure.Init()

0 commit comments

Comments
 (0)