Skip to content

Commit

Permalink
- adjusted verification for some siws parameters.
Browse files Browse the repository at this point in the history
- mounted the /nonce endpoint to the router.
- switched nonce to random OTP (slashes not allowed in wallet adapters).
- adjusted migration file to add the nonce tables.
  • Loading branch information
Bewinxed committed Jan 29, 2025
1 parent 54fdc0a commit efb21e7
Show file tree
Hide file tree
Showing 9 changed files with 187 additions and 129 deletions.
2 changes: 0 additions & 2 deletions example.env
Original file line number Diff line number Diff line change
Expand Up @@ -170,8 +170,6 @@ GOTRUE_EXTERNAL_ZOOM_REDIRECT_URI="http://localhost:9999/callback"

# EIP-4361 OAuth config
GOTRUE_EXTERNAL_WEB3_ENABLED="true"
GOTRUE_EXTERNAL_WEB3_STATEMENT="Sign this message to verify your identity"
GOTRUE_EXTERNAL_WEB3_VERSION="1"
GOTRUE_EXTERNAL_WEB3_TIMEOUT="300s"
GOTRUE_EXTERNAL_WEB3_DOMAIN="localhost:9999"

Expand Down
4 changes: 4 additions & 0 deletions internal/api/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -133,8 +133,12 @@ func NewAPIWithVersion(globalConfig *conf.GlobalConfiguration, db *storage.Conne
})

r.Route("/", func(r *router) {


r.Use(api.isValidExternalHost)

r.Get("/nonce", api.GetNonce)

r.Get("/settings", api.Settings)

r.Get("/authorize", api.ExternalProviderRedirect)
Expand Down
4 changes: 3 additions & 1 deletion internal/api/provider/eip4361.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"errors"
"fmt"
"time"
"log"

"github.com/supabase/auth/internal/conf"
"github.com/supabase/auth/internal/crypto"
Expand Down Expand Up @@ -72,6 +73,7 @@ func (p *Web3Provider) GetUserData(ctx context.Context, tok *oauth2.Token) (*Use
// VerifySignedMessage verifies a signed Web3 message based on the blockchain
func (p *Web3Provider) VerifySignedMessage(msg *SignedMessage) (*UserProvidedData, error) {
chain, ok := p.chains[msg.Chain]
log.Printf("Verifying supported blockchain: %s", msg.Chain)
if !ok {
return nil, fmt.Errorf("unsupported blockchain: %s", msg.Chain)
}
Expand Down Expand Up @@ -129,7 +131,7 @@ func (p *Web3Provider) verifySolanaSignature(msg *SignedMessage) error {
}

if err := crypto.VerifySIWS(msg.Message, sigBytes, parsedMessage, params); err != nil {
return fmt.Errorf("SIWS verification failed: %w", err)
return err
}

return nil
Expand Down
69 changes: 44 additions & 25 deletions internal/api/token.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,14 @@ package api

import (
"context"
"database/sql"
"fmt"
"log"
"net/http"
"net/url"
"strconv"
"time"

"fmt"

"github.com/gofrs/uuid"
"github.com/golang-jwt/jwt/v5"
"github.com/xeipuuv/gojsonschema"
Expand Down Expand Up @@ -314,12 +315,12 @@ func (a *API) PKCE(ctx context.Context, w http.ResponseWriter, r *http.Request)
}

type StoredNonce struct {
ID uuid.UUID `db:"id"`
Nonce string `db:"nonce"`
Address string `db:"address"` // Optional: can be empty until signature verification
CreatedAt time.Time `db:"created_at"`
ExpiresAt time.Time `db:"expires_at"`
Used bool `db:"used"`
ID uuid.UUID `db:"id"`
Nonce string `db:"nonce"`
Address sql.NullString `db:"address"` // Changed this line
CreatedAt time.Time `db:"created_at"`
ExpiresAt time.Time `db:"expires_at"`
Used bool `db:"used"`
}

const NonceExpiration = 5 * time.Minute
Expand All @@ -329,7 +330,7 @@ func (a *API) GetNonce(w http.ResponseWriter, r *http.Request) error {
ctx := r.Context()
db := a.db.WithContext(ctx)

nonce := crypto.SecureToken()
nonce := crypto.GenerateOtp(12)

storedNonce := &StoredNonce{
ID: uuid.Must(uuid.NewV4()),
Expand All @@ -353,28 +354,38 @@ func (a *API) GetNonce(w http.ResponseWriter, r *http.Request) error {
return internalServerError("Error storing nonce").WithInternalError(err)
}

log.Printf("Generated nonce: %s", nonce)


return sendJSON(w, http.StatusOK, map[string]interface{}{
"nonce": nonce,
"expiresAt": storedNonce.ExpiresAt,
})
}

func (a *API) verifyAndConsumeNonce(ctx context.Context, nonce string, address string) error {
db := a.db.WithContext(ctx)

var storedNonce StoredNonce
err := db.Transaction(func(tx *storage.Connection) error {
// Find the nonce
err := tx.TX.QueryRow(`
SELECT id, nonce, address, created_at, expires_at, used
FROM auth.nonces
WHERE nonce = $1 AND used = false
`, nonce).Scan(&storedNonce.ID, &storedNonce.Nonce,
&storedNonce.Address, &storedNonce.CreatedAt,
&storedNonce.ExpiresAt, &storedNonce.Used)
if err != nil {
return err
}
log.Printf("Starting nonce verification for: %s", nonce)
db := a.db.WithContext(ctx)

var storedNonce StoredNonce
err := db.Transaction(func(tx *storage.Connection) error {
// Find the nonce
log.Printf("Executing query for nonce: %s", nonce)
err := tx.TX.QueryRow(`
SELECT id, nonce, address, created_at, expires_at, used
FROM auth.nonces
WHERE nonce = $1 AND used = false
`, nonce).Scan(&storedNonce.ID, &storedNonce.Nonce,
&storedNonce.Address, &storedNonce.CreatedAt,
&storedNonce.ExpiresAt, &storedNonce.Used)
if err != nil {
log.Printf("Error scanning nonce: %v", err)
return err
}

log.Printf("Found nonce in DB: %+v", storedNonce)


// Check expiration
if time.Now().After(storedNonce.ExpiresAt) {
Expand All @@ -386,7 +397,7 @@ func (a *API) verifyAndConsumeNonce(ctx context.Context, nonce string, address s
UPDATE auth.nonces
SET used = true, address = $1
WHERE id = $2
`, address, storedNonce.ID)
`, sql.NullString{String: address, Valid: true}, storedNonce.ID)
return err
})

Expand All @@ -402,8 +413,16 @@ func (a *API) Web3Grant(ctx context.Context, w http.ResponseWriter, r *http.Requ
return err
}

parsedMessage, err := siws.ParseSIWSMessage(params.Message)

if err != nil {
return siws.ErrorMalformedMessage
}



// Verify and consume nonce first
if err := a.verifyAndConsumeNonce(ctx, params.Nonce, params.Address); err != nil {
if err := a.verifyAndConsumeNonce(ctx, parsedMessage.Nonce, parsedMessage.Address); err != nil {
return siws.ErrorCodeInvalidNonce
}

Expand Down
64 changes: 53 additions & 11 deletions internal/crypto/crypto.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"encoding/json"
"fmt"
"io"
"log"
"math"
"math/big"
"net/url"
Expand Down Expand Up @@ -174,96 +175,137 @@ func VerifySIWS(
msg *siws.SIWSMessage,
params siws.SIWSVerificationParams,
) error {
log.Printf("[DEBUG] Starting SIWS verification - Signature length: %d", len(signature))

// 1) Basic input validation
if rawMessage == "" {
log.Printf("[ERROR] Empty raw message")
return siws.ErrEmptyRawMessage
}
if len(signature) == 0 {
log.Printf("[ERROR] Empty signature")
return siws.ErrEmptySignature
}
if msg == nil {
log.Printf("[ERROR] Nil message")
return siws.ErrNilMessage
}

log.Printf("[DEBUG] Basic validation passed - Message length: %d", len(rawMessage))

// 2) Domain validation
log.Printf("[DEBUG] Validating domain - Expected: %s, Actual: %s", params.ExpectedDomain, msg.Domain)
if params.ExpectedDomain == "" {
log.Printf("[ERROR] Missing expected domain")
return siws.ErrMissingDomain
}
if !siws.IsValidDomain(msg.Domain) {
log.Printf("[ERROR] Invalid domain format: %s", msg.Domain)
return siws.ErrInvalidDomainFormat
}
if msg.Domain != params.ExpectedDomain {
log.Printf("[ERROR] Domain mismatch - Expected: %s, Got: %s", params.ExpectedDomain, msg.Domain)
return siws.ErrDomainMismatch
}

// 3) Address/Public Key validation (combined checks)
// 3) Address/Public Key validation
pubKey := base58.Decode(msg.Address)
log.Printf("[DEBUG] Validating public key - Address: %s, Decoded length: %d", msg.Address, len(pubKey))
if !siws.IsBase58PubKey(pubKey) {
log.Printf("[ERROR] Invalid public key size: %d", len(pubKey))
return siws.ErrInvalidPubKeySize
}

// 4) Version validation
log.Printf("[DEBUG] Checking version: %s", msg.Version)
if msg.Version != "1" {
log.Printf("[ERROR] Invalid version: %s", msg.Version)
return siws.ErrInvalidVersion
}

// 5) Chain ID validation (using helper)
// 5) Chain ID validation
if msg.ChainID != "" {
if !siws.IsValidSolanaNetwork(msg.ChainID) {

log.Printf("[DEBUG] Validating chain ID: %s", msg.ChainID)
if !siws.IsValidSolanaNetwork(msg.ChainID) {
log.Printf("[ERROR] Invalid chain ID: %s", msg.ChainID)
return siws.ErrInvalidChainID
}
}

// 6) Nonce validation (consolidated)
// 6) Nonce validation
if msg.Nonce != "" {
log.Printf("[DEBUG] Checking nonce length: %d", len(msg.Nonce))
if len(msg.Nonce) < 8 {
log.Printf("[ERROR] Nonce too short: %d chars", len(msg.Nonce))
return siws.ErrNonceTooShort
}
}

// 7) URI and Resources validation
// 7) URI validation
if msg.URI != "" {
log.Printf("[DEBUG] Validating URI: %s", msg.URI)
if _, err := url.Parse(msg.URI); err != nil {
log.Printf("[ERROR] Invalid URI: %s - %v", msg.URI, err)
return siws.ErrInvalidURI
}
}

// Resources validation
for _, resource := range msg.Resources {
log.Printf("[DEBUG] Validating resource URI: %s", resource)
if _, err := url.Parse(resource); err != nil {
log.Printf("[ERROR] Invalid resource URI: %s - %v", resource, err)
return siws.ErrInvalidResourceURI
}
}

// 8) Signature verification
log.Printf("[DEBUG] Verifying ed25519 signature")
log.Printf("[DEBUG] Verification inputs - Message bytes: %v", []byte(rawMessage))
log.Printf("[DEBUG] Verification inputs - Signature bytes: %v", signature)
if !ed25519.Verify(pubKey, []byte(rawMessage), signature) {
log.Printf("[ERROR] Signature verification failed")
return siws.ErrSignatureVerification
}

// 9) Time validations (consolidated)
// 9) Time validations
now := time.Now().UTC()
log.Printf("[DEBUG] Time validation - Current time: %s", now)

if !msg.IssuedAt.IsZero() {
log.Printf("[DEBUG] Checking issuedAt: %s", msg.IssuedAt)
if now.Before(msg.IssuedAt) {
log.Printf("[ERROR] Message from future - IssuedAt: %s", msg.IssuedAt)
return siws.ErrFutureMessage
}

if params.CheckTime && params.TimeDuration > 0 {
expiry := msg.IssuedAt.Add(params.TimeDuration)
log.Printf("[DEBUG] Checking message expiry - Expiry: %s", expiry)
if now.After(expiry) {
log.Printf("[ERROR] Message expired - Expiry: %s", expiry)
return siws.ErrMessageExpired
}
}
}

if !msg.NotBefore.IsZero() && now.Before(msg.NotBefore) {
return siws.ErrNotYetValid
if !msg.NotBefore.IsZero() {
log.Printf("[DEBUG] Checking notBefore: %s", msg.NotBefore)
if now.Before(msg.NotBefore) {
log.Printf("[ERROR] Message not yet valid - NotBefore: %s", msg.NotBefore)
return siws.ErrNotYetValid
}
}

if !msg.ExpirationTime.IsZero() && now.After(msg.ExpirationTime) {
return siws.ErrMessageExpired
if !msg.ExpirationTime.IsZero() {
log.Printf("[DEBUG] Checking expirationTime: %s", msg.ExpirationTime)
if now.After(msg.ExpirationTime) {
log.Printf("[ERROR] Message expired - ExpirationTime: %s", msg.ExpirationTime)
return siws.ErrMessageExpired
}
}

log.Printf("[INFO] SIWS verification successful")
return nil
}

Expand Down
2 changes: 0 additions & 2 deletions internal/reloader/testdata/50_example.env
Original file line number Diff line number Diff line change
Expand Up @@ -170,8 +170,6 @@ GOTRUE_EXTERNAL_ZOOM_REDIRECT_URI="http://localhost:9999/callback"

# EIP-4361 OAuth config
GOTRUE_EXTERNAL_WEB3_ENABLED="true"
GOTRUE_EXTERNAL_WEB3_STATEMENT="Sign this message to verify your identity"
GOTRUE_EXTERNAL_WEB3_VERSION="1"
GOTRUE_EXTERNAL_WEB3_TIMEOUT="300s"
GOTRUE_EXTERNAL_WEB3_DOMAIN="localhost:9999"

Expand Down
9 changes: 5 additions & 4 deletions internal/utilities/solana/helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ var (
ErrorCodeInvalidSignature = NewSIWSError("invalid signature", http.StatusBadRequest)
ErrorMalformedMessage = NewSIWSError("malformed message", http.StatusBadRequest)
ErrInvalidDomainFormat = NewSIWSError("invalid domain format", http.StatusBadRequest)
ErrMessageDomainMismatch = NewSIWSError("domain's header domain and body domain are mismatched.", http.StatusBadRequest)
ErrInvalidStatementFormat = NewSIWSError("invalid statement format", http.StatusBadRequest)
ErrInvalidIssuedAtFormat = NewSIWSError("invalid issued at format", http.StatusBadRequest)
ErrInvalidExpirationTimeFormat = NewSIWSError("invalid expiration time format", http.StatusBadRequest)
Expand All @@ -73,10 +74,10 @@ func GenerateNonce() (string, error) {
// ValidateDomain checks if a domain is valid or not. This can be expanded with
// real domain validation logic. Here, we do a simple parse check.
func IsValidDomain(domain string) bool {
// Regular expression to validate domain name
regex := `^(?:[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?\.)+[a-zA-Z]{2,}$`
match, _ := regexp.MatchString(regex, domain)
return match
// Regular expression to validate domain name including localhost and ports
regex := `^(localhost|(?:[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?\.)+[a-zA-Z]{2,})(?::\d{1,5})?$`
match, _ := regexp.MatchString(regex, domain)
return match
}

// IsBase58PubKey checks if the input is a plausible base58 Solana public key.
Expand Down
Loading

0 comments on commit efb21e7

Please sign in to comment.