diff --git a/controller/rest/auth.go b/controller/rest/auth.go index 4fdacbc..9a89dd8 100644 --- a/controller/rest/auth.go +++ b/controller/rest/auth.go @@ -113,7 +113,7 @@ func CheckAuth(c *gin.Context) { c.Set(PasswordKey, info.Password) return } - if strings.HasPrefix(auth, "Basic") { + if strings.HasPrefix(auth, "Basic") && len(auth) > 6 { user, password, err := tools.DecodeBasic(auth[6:]) if err != nil { UnAuthResponse(c, logger, httperror.HTTP_INVALID_BASIC_AUTH) @@ -129,7 +129,7 @@ func CheckAuth(c *gin.Context) { }) c.Set(UserKey, user) c.Set(PasswordKey, password) - } else if strings.HasPrefix(auth, "Taosd") { + } else if strings.HasPrefix(auth, "Taosd") && len(auth) > 6 { user, password, err := DecodeDes(auth[6:]) if err != nil { UnAuthResponse(c, logger, httperror.HTTP_INVALID_TAOSD_AUTH) diff --git a/plugin/auth.go b/plugin/auth.go index acb9903..b5fbc87 100644 --- a/plugin/auth.go +++ b/plugin/auth.go @@ -1,7 +1,6 @@ package plugin import ( - "encoding/base64" "errors" "net/http" "strings" @@ -41,62 +40,13 @@ func Auth(errHandler func(c *gin.Context, code int, err error)) func(c *gin.Cont c.Set(PasswordKey, info.Password) return } - if strings.HasPrefix(auth, "Basic") { - b, err := base64.StdEncoding.DecodeString(auth[6:]) + if strings.HasPrefix(auth, "Basic") && len(auth) > 6 { + user, password, err := tools.DecodeBasic(auth[6:]) if err != nil { errHandler(c, http.StatusUnauthorized, err) c.Abort() return } - var user, password string - sl := strings.Split(string(b), ":") - - if len(sl) == 2 { - user = sl[0] - password = sl[1] - } else if len(sl) == 3 { - if sl[2] == "a" { - encodeData, err := base64.StdEncoding.DecodeString(sl[0]) - if err != nil { - errHandler(c, http.StatusUnauthorized, err) - c.Abort() - return - } - key, err := base64.StdEncoding.DecodeString(sl[1]) - if err != nil { - errHandler(c, http.StatusUnauthorized, err) - c.Abort() - return - } - if len(key) != 16 { - errHandler(c, http.StatusUnauthorized, errors.New("parse error")) - c.Abort() - return - } - authBytes, err := tools.AesDecrypt(encodeData, key) - if err != nil { - errHandler(c, http.StatusUnauthorized, err) - c.Abort() - return - } - a := strings.Split(string(authBytes), ":") - if len(a) != 2 { - errHandler(c, http.StatusUnauthorized, errors.New("parse error")) - c.Abort() - return - } - user = a[0] - password = a[1] - } else { - errHandler(c, http.StatusUnauthorized, errors.New("unknown auth type")) - c.Abort() - return - } - } else { - errHandler(c, http.StatusUnauthorized, errors.New("parse error")) - c.Abort() - return - } authCache.SetDefault(auth, &authInfo{ User: user, Password: password, @@ -121,31 +71,7 @@ func RegisterGenerateAuth(r gin.IRouter) { b.WriteString(user) b.WriteByte(':') b.WriteString(password) - keyBytes := make([]byte, 16) - maxLen := len(key) - if maxLen > 16 { - maxLen = 16 - } - for i := 0; i < maxLen; i++ { - keyBytes[i] = key[i] - } - d, err := tools.AesEncrypt(b.Bytes(), keyBytes) - if err != nil { - c.AbortWithStatus(http.StatusBadRequest) - return - } - l1 := make([]byte, base64.StdEncoding.EncodedLen(len(d))) - base64.StdEncoding.Encode(l1, d) - l2 := make([]byte, base64.StdEncoding.EncodedLen(len(keyBytes))) - base64.StdEncoding.Encode(l2, keyBytes) - buf := pool.BytesPoolGet() - buf.Write(l1) - buf.WriteByte(':') - buf.Write(l2) - buf.WriteByte(':') - buf.WriteString("a") - c.String(http.StatusOK, buf.String()) - pool.BytesPoolPut(buf) + c.String(http.StatusOK, b.String()) }) } diff --git a/plugin/influxdb/plugin.go b/plugin/influxdb/plugin.go index cfcd62e..82b7198 100644 --- a/plugin/influxdb/plugin.go +++ b/plugin/influxdb/plugin.go @@ -239,7 +239,7 @@ func getAuth(c *gin.Context) { auth := c.GetHeader("Authorization") if len(auth) != 0 { auth = strings.TrimSpace(auth) - if strings.HasPrefix(auth, "Basic") { + if strings.HasPrefix(auth, "Basic") && len(auth) > 6 { user, password, err := tools.DecodeBasic(auth[6:]) if err == nil { c.Set(plugin.UserKey, user) diff --git a/tools/basic.go b/tools/basic.go index f36c6e1..d9bdb7e 100644 --- a/tools/basic.go +++ b/tools/basic.go @@ -11,7 +11,7 @@ func DecodeBasic(auth string) (user, password string, err error) { if err != nil { return "", "", err } - sl := strings.Split(string(b), ":") + sl := strings.SplitN(string(b), ":", 2) if len(sl) != 2 { return "", "", errors.New("wrong basic auth") } diff --git a/tools/basic_test.go b/tools/basic_test.go index 3da0ae4..28b4548 100644 --- a/tools/basic_test.go +++ b/tools/basic_test.go @@ -34,7 +34,8 @@ func TestDecodeBasic(t *testing.T) { wantUser: "root", wantPassword: "taosdata", wantErr: false, - }, { + }, + { name: "wrong base64", args: args{ auth: "wrong base64", @@ -42,7 +43,8 @@ func TestDecodeBasic(t *testing.T) { wantUser: "", wantPassword: "", wantErr: true, - }, { + }, + { name: "wrong split", args: args{ auth: "cm9vdHRhb3NkYXRh", @@ -51,6 +53,15 @@ func TestDecodeBasic(t *testing.T) { wantPassword: "", wantErr: true, }, + { + name: "special char", + args: args{ + auth: "dGVzdDoxIXFAIyQlXiYqKCktXys9W117fTo7Pjw/fH4sLg==", + }, + wantUser: "test", + wantPassword: "1!q@#$%^&*()-_+=[]{}:;>