1
1
package azure
2
2
3
3
import (
4
+ "fmt"
5
+ "github.com/spf13/viper"
4
6
"github.com/stulzq/azure-openai-proxy/constant"
7
+ "github.com/stulzq/azure-openai-proxy/util"
5
8
"log"
6
9
"net/url"
7
- "os"
8
- "regexp"
9
10
"strings"
10
11
)
11
12
@@ -14,43 +15,92 @@ const (
14
15
)
15
16
16
17
var (
17
- AzureOpenAIEndpoint = ""
18
- AzureOpenAIEndpointParse * url.URL
19
-
20
- AzureOpenAIAPIVer = ""
21
-
22
- AzureOpenAIModelMapper = map [string ]string {
23
- "gpt-3.5-turbo" : "gpt-35-turbo" ,
24
- }
25
- fallbackModelMapper = regexp .MustCompile (`[.:]` )
18
+ C Config
19
+ ModelDeploymentConfig = map [string ]DeploymentConfig {}
26
20
)
27
21
28
- func Init () {
29
- AzureOpenAIAPIVer = os .Getenv (constant .ENV_AZURE_OPENAI_API_VER )
30
- AzureOpenAIEndpoint = os .Getenv (constant .ENV_AZURE_OPENAI_ENDPOINT )
22
+ func Init () error {
23
+ var (
24
+ apiVersion string
25
+ endpoint string
26
+ openaiModelMapper string
27
+ err error
28
+ )
31
29
32
- if AzureOpenAIAPIVer == "" {
33
- AzureOpenAIAPIVer = "2023-03-15-preview"
30
+ apiVersion = viper .GetString (constant .ENV_AZURE_OPENAI_API_VER )
31
+ endpoint = viper .GetString (constant .ENV_AZURE_OPENAI_ENDPOINT )
32
+ openaiModelMapper = viper .GetString (constant .ENV_AZURE_OPENAI_MODEL_MAPPER )
33
+ if endpoint != "" && openaiModelMapper != "" {
34
+ if apiVersion == "" {
35
+ apiVersion = "2023-03-15-preview"
36
+ }
37
+ InitFromEnvironmentVariables (apiVersion , endpoint , openaiModelMapper )
38
+ } else {
39
+ if err = InitFromConfigFile (); err != nil {
40
+ return err
41
+ }
34
42
}
35
43
36
- var err error
37
- AzureOpenAIEndpointParse , err = url .Parse (AzureOpenAIEndpoint )
38
- if err != nil {
39
- log .Fatal ("parse AzureOpenAIEndpoint error: " , err )
44
+ // ensure apiBase likes /v1
45
+ apiBase := viper .GetString ("api_base" )
46
+ if ! strings .HasPrefix (apiBase , "/" ) {
47
+ apiBase = "/" + apiBase
48
+ }
49
+ if strings .HasSuffix (apiBase , "/" ) {
50
+ apiBase = apiBase [:len (apiBase )- 1 ]
51
+ }
52
+ viper .Set ("api_base" , apiBase )
53
+ log .Printf ("apiBase is: %s" , apiBase )
54
+ for _ , itemConfig := range C .DeploymentConfig {
55
+ u , err := url .Parse (itemConfig .Endpoint )
56
+ if err != nil {
57
+ return fmt .Errorf ("parse endpoint error: %w" , err )
58
+ }
59
+ itemConfig .EndpointUrl = u
60
+ ModelDeploymentConfig [itemConfig .ModelName ] = itemConfig
40
61
}
62
+ return err
63
+ }
41
64
42
- if v := os .Getenv (constant .ENV_AZURE_OPENAI_MODEL_MAPPER ); v != "" {
43
- for _ , pair := range strings .Split (v , "," ) {
65
+ func InitFromEnvironmentVariables (apiVersion , endpoint , openaiModelMapper string ) {
66
+ log .Println ("Init from environment variables" )
67
+ if openaiModelMapper != "" {
68
+ // openaiModelMapper example:
69
+ // gpt-3.5-turbo=deployment_name_for_gpt_model,text-davinci-003=deployment_name_for_davinci_model
70
+ for _ , pair := range strings .Split (openaiModelMapper , "," ) {
44
71
info := strings .Split (pair , "=" )
45
72
if len (info ) != 2 {
46
73
log .Fatalf ("error parsing %s, invalid value %s" , constant .ENV_AZURE_OPENAI_MODEL_MAPPER , pair )
47
74
}
48
-
49
- AzureOpenAIModelMapper [info [0 ]] = info [1 ]
75
+ modelName , deploymentName := info [0 ], info [1 ]
76
+ ModelDeploymentConfig [modelName ] = DeploymentConfig {
77
+ DeploymentName : deploymentName ,
78
+ ModelName : modelName ,
79
+ Endpoint : endpoint ,
80
+ ApiKey : "" ,
81
+ ApiVersion : apiVersion ,
82
+ }
50
83
}
51
84
}
85
+ }
86
+
87
+ func InitFromConfigFile () error {
88
+ log .Println ("Init from config file" )
89
+ workDir := util .GetWorkdir ()
90
+ viper .SetConfigName ("config" )
91
+ viper .SetConfigType ("yaml" )
92
+ viper .AddConfigPath (fmt .Sprintf ("%s/config" , workDir ))
93
+ if err := viper .ReadInConfig (); err != nil {
94
+ log .Printf ("read config file error: %+v\n " , err )
95
+ return err
96
+ }
52
97
53
- log .Println ("AzureOpenAIAPIVer: " , AzureOpenAIAPIVer )
54
- log .Println ("AzureOpenAIEndpoint: " , AzureOpenAIEndpoint )
55
- log .Println ("AzureOpenAIModelMapper: " , AzureOpenAIModelMapper )
98
+ if err := viper .Unmarshal (& C ); err != nil {
99
+ log .Printf ("unmarshal config file error: %+v\n " , err )
100
+ return err
101
+ }
102
+ for _ , configItem := range C .DeploymentConfig {
103
+ ModelDeploymentConfig [configItem .ModelName ] = configItem
104
+ }
105
+ return nil
56
106
}
0 commit comments