forked from snowflakedb/gosnowflake
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathpriv_key_test.go
117 lines (106 loc) · 3.62 KB
/
priv_key_test.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
// Copyright (c) 2017-2022 Snowflake Computing Inc. All rights reserved.
package gosnowflake
// To keep the package compiling, any variable or function newly added here must also be
// added, with the same name and signature but default or empty content, in priv_key_test.go (see addParseDSNTest).
import (
"bytes"
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"database/sql"
"encoding/base64"
"encoding/pem"
"fmt"
"os"
"testing"
)
// generatePKCS8StringSupress returns key marshaled to PKCS#8 DER and then
// encoded with URL-safe, padded base64 (base64.URLEncoding).
//
// The marshaling error is intentionally dropped (hence "Supress"):
// x509.MarshalPKCS8PrivateKey only fails for unsupported key types, and an
// *rsa.PrivateKey is always supported.
func generatePKCS8StringSupress(key *rsa.PrivateKey) string {
	der, _ := x509.MarshalPKCS8PrivateKey(key) // error ignored: RSA keys are a supported type
	return base64.URLEncoding.EncodeToString(der)
}
// generatePKCS1String returns key marshaled to PKCS#1 ASN.1 DER and then
// encoded with URL-safe, padded base64 (base64.URLEncoding).
func generatePKCS1String(key *rsa.PrivateKey) string {
	return base64.URLEncoding.EncodeToString(x509.MarshalPKCS1PrivateKey(key))
}
// setupPrivateKey initializes the package-level testPrivKey and
// customPrivateKey used by the JWT tests.
//
// If SNOWFLAKE_TEST_PRIVATE_KEY is unset or empty, a fresh 2048-bit RSA key is
// generated and customPrivateKey is set to false (TestJWTAuthentication then
// registers the matching public key with the server itself). Otherwise the
// variable is the path to a PEM file whose first block must be of type
// "PRIVATE KEY" and contain a PKCS#8-encoded RSA key; customPrivateKey is set
// to true.
//
// The function has no error return, so every failure panics with a message
// naming the cause, rather than leaving testPrivKey nil or half-initialized.
func setupPrivateKey() {
	privKeyPath := os.Getenv("SNOWFLAKE_TEST_PRIVATE_KEY")
	if privKeyPath == "" {
		customPrivateKey = false
		key, err := rsa.GenerateKey(rand.Reader, 2048)
		if err != nil {
			panic(fmt.Sprintf("failed to generate RSA private key: %v", err))
		}
		testPrivKey = key
		return
	}

	// Path to a PEM-encoded PKCS#8 private key file.
	customPrivateKey = true
	data, err := os.ReadFile(privKeyPath)
	if err != nil {
		panic(fmt.Sprintf("failed to read private key file %v: %v", privKeyPath, err))
	}
	block, _ := pem.Decode(data)
	if block == nil || block.Type != "PRIVATE KEY" {
		panic(fmt.Sprintf("%v is not a private key in PEM format.", privKeyPath))
	}
	parsed, err := x509.ParsePKCS8PrivateKey(block.Bytes)
	if err != nil {
		panic(fmt.Sprintf("failed to parse PKCS#8 private key from %v: %v", privKeyPath, err))
	}
	// Checked assertion: PKCS#8 may also hold ECDSA/Ed25519 keys, which the
	// JWT tests cannot use.
	rsaKey, ok := parsed.(*rsa.PrivateKey)
	if !ok {
		panic(fmt.Sprintf("%v does not contain an RSA private key (got %T)", privKeyPath, parsed))
	}
	testPrivKey = rsaKey
}
// appendPrivateKeyString returns a copy of *dsn with JWT authentication
// parameters appended: authenticator=<AuthTypeJwt> and privateKey=<key as
// PKCS#8 DER, URL-safe base64>. *dsn itself is not modified.
//
// The first parameter is joined with "&" when *dsn already contains a '?'
// (assumed to start its query string) and with "?" otherwise, so a DSN without
// parameters still yields a well-formed query.
func appendPrivateKeyString(dsn *string, key *rsa.PrivateKey) string {
	var b bytes.Buffer
	b.WriteString(*dsn)
	sep := "&"
	if !bytes.ContainsRune(b.Bytes(), '?') {
		sep = "?"
	}
	fmt.Fprintf(&b, "%sauthenticator=%v", sep, AuthTypeJwt.String())
	fmt.Fprintf(&b, "&privateKey=%s", generatePKCS8StringSupress(key))
	return b.String()
}
// TestJWTAuthentication is an integration test for key-pair (JWT)
// authentication: a connection using testPrivKey must succeed, and one using
// a freshly generated, never-registered key must fail.
func TestJWTAuthentication(t *testing.T) {
	// A key generated on the fly is unknown to the server, so register its
	// public half with the test user first. The closure scopes the deferred
	// Close so the admin connection is released even if t.Fatalf fires.
	if !customPrivateKey {
		func() {
			db := openDB(t)
			defer db.Close()
			pubKeyByte, err := x509.MarshalPKIXPublicKey(testPrivKey.Public())
			if err != nil {
				t.Fatalf("error marshaling public key: %s", err.Error())
			}
			if _, err = db.Exec("USE ROLE ACCOUNTADMIN"); err != nil {
				t.Fatalf("error changing role: %s", err.Error())
			}
			encodedKey := base64.StdEncoding.EncodeToString(pubKeyByte)
			if _, err = db.Exec(fmt.Sprintf("ALTER USER %v set rsa_public_key='%v'", username, encodedKey)); err != nil {
				t.Fatalf("error setting server's public key: %s", err.Error())
			}
		}()
	}

	// A valid private key must authenticate.
	validDB, err := sql.Open("snowflake", appendPrivateKeyString(&dsn, testPrivKey))
	if err != nil {
		t.Fatalf("error creating a connection object: %s", err.Error())
	}
	defer validDB.Close()
	if _, err = validDB.Exec("SELECT 1"); err != nil {
		t.Fatalf("error executing: %s", err.Error())
	}

	// A key whose public half was never registered must be rejected.
	// Failures here are fatal: continuing with a nil key or nil *sql.DB would
	// only panic later.
	invalidPrivateKey, err := rsa.GenerateKey(rand.Reader, 2048)
	if err != nil {
		t.Fatalf("error generating invalid private key: %s", err.Error())
	}
	invalidDB, err := sql.Open("snowflake", appendPrivateKeyString(&dsn, invalidPrivateKey))
	if err != nil {
		t.Fatalf("error creating a connection object: %s", err.Error())
	}
	defer invalidDB.Close()
	if _, err = invalidDB.Exec("SELECT 1"); err == nil {
		t.Fatalf("An invalid jwt token can pass")
	}
}