Skip to content

Commit

Permalink
feat: add custom access token hook (#1332)
Browse files Browse the repository at this point in the history
## What kind of change does this PR introduce?

Allow developers to customize access token by adding claims or altering
metadata

Other affected changes:
- `generateAccessToken` is moved to a method on `API` and now takes in a
`context`. In tests we use a default `context.Background()`

---------

Co-authored-by: joel@joellee.org <joel@joellee.org>
Co-authored-by: Kang Ming <kang.ming1996@gmail.com>
  • Loading branch information
3 people authored Dec 6, 2023
1 parent 4254873 commit 312f871
Show file tree
Hide file tree
Showing 16 changed files with 332 additions and 21 deletions.
7 changes: 5 additions & 2 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,10 @@ require (

require (
github.com/bits-and-blooms/bitset v1.10.0 // indirect
github.com/bits-and-blooms/bloom/v3 v3.6.0 // indirect
github.com/go-jose/go-jose/v3 v3.0.1 // indirect
github.com/gobuffalo/nulls v0.4.2 // indirect
github.com/supabase/hibp v0.0.0-20231124125943-d225752ae869 // indirect
github.com/xeipuuv/gojsonpointer v0.0.0-20180127040702-4e3ac2762d5f // indirect
github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415 // indirect
)

require (
Expand All @@ -66,12 +66,15 @@ require (
)

require (
github.com/bits-and-blooms/bloom/v3 v3.6.0
github.com/crewjam/saml v0.4.14
github.com/deepmap/oapi-codegen v1.12.4
github.com/fatih/structs v1.1.0
github.com/gobuffalo/pop/v6 v6.1.1
github.com/jackc/pgx/v4 v4.17.2
github.com/supabase/hibp v0.0.0-20231124125943-d225752ae869
github.com/supabase/mailme v0.0.0-20230628061017-01f68480c747
github.com/xeipuuv/gojsonschema v1.2.0
)

require (
Expand Down
7 changes: 7 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -491,7 +491,14 @@ github.com/supabase/hibp v0.0.0-20231124125943-d225752ae869 h1:VDuRtwen5Z7QQ5ctu
github.com/supabase/hibp v0.0.0-20231124125943-d225752ae869/go.mod h1:eHX5nlSMSnyPjUrbYzeqrA8snCe2SKyfizKjU3dkfOw=
github.com/supabase/mailme v0.0.0-20230628061017-01f68480c747 h1:FIUdLV4o5JLsJno4Poum157kAxDKINeJo6liBfauLrI=
github.com/supabase/mailme v0.0.0-20230628061017-01f68480c747/go.mod h1:kWsnmPfUBZTavlXYkfJrE9unzmmRAIi/kqsxXfEWEY8=
github.com/twmb/murmur3 v1.1.6 h1:mqrRot1BRxm+Yct+vavLMou2/iJt0tNVTTC0QoIjaZg=
github.com/twmb/murmur3 v1.1.6/go.mod h1:Qq/R7NUyOfr65zD+6Q5IHKsJLwP7exErjN6lyyq3OSQ=
github.com/xeipuuv/gojsonpointer v0.0.0-20180127040702-4e3ac2762d5f h1:J9EGpcZtP0E/raorCMxlFGSTBrsSlaDGf3jU/qvAE2c=
github.com/xeipuuv/gojsonpointer v0.0.0-20180127040702-4e3ac2762d5f/go.mod h1:N2zxlSyiKSe5eX1tZViRH5QA0qijqEDrYZiPEAiq3wU=
github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415 h1:EzJWgHovont7NscjpAxXsDA8S8BMYve8Y5+7cuRE7R0=
github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415/go.mod h1:GwrjFmJcFw6At/Gs6z4yjiIwzuJ1/+UwLxMQDVQXShQ=
github.com/xeipuuv/gojsonschema v1.2.0 h1:LhYJRs+L4fBtjZUfuSZIKGeVu0QRy8e5Xi7D17UxZ74=
github.com/xeipuuv/gojsonschema v1.2.0/go.mod h1:anYRn/JVcOK2ZgGU+IjEV4nwlhoK5sQluxsYJ78Id3Y=
github.com/yuin/goldmark v1.1.25/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
github.com/yuin/goldmark v1.1.32/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
Expand Down
3 changes: 2 additions & 1 deletion internal/api/audit_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package api

import (
"context"
"encoding/json"
"fmt"
"net/http"
Expand Down Expand Up @@ -48,7 +49,7 @@ func (ts *AuditTestSuite) makeSuperAdmin(email string) string {
u.Role = "supabase_admin"

var token string
token, _, err = generateAccessToken(ts.API.db, u, nil, &ts.Config.JWT)
token, _, err = ts.API.generateAccessToken(context.Background(), ts.API.db, u, nil, models.PasswordGrant)
require.NoError(ts.T(), err, "Error generating access token")

p := jwt.Parser{ValidMethods: []string{jwt.SigningMethodHS256.Name}}
Expand Down
3 changes: 2 additions & 1 deletion internal/api/invite_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package api

import (
"bytes"
"context"
"encoding/json"
"fmt"
"net/http"
Expand Down Expand Up @@ -59,7 +60,7 @@ func (ts *InviteTestSuite) makeSuperAdmin(email string) string {
u.Role = "supabase_admin"

var token string
token, _, err = generateAccessToken(ts.API.db, u, nil, &ts.Config.JWT)
token, _, err = ts.API.generateAccessToken(context.Background(), ts.API.db, u, nil, models.Invite)

require.NoError(ts.T(), err, "Error generating access token")

Expand Down
3 changes: 2 additions & 1 deletion internal/api/logout_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package api

import (
"context"
"fmt"
"net/http"
"net/http/httptest"
Expand Down Expand Up @@ -42,7 +43,7 @@ func (ts *LogoutTestSuite) SetupTest() {

// generate access token to use for logout
var t string
t, _, err = generateAccessToken(ts.API.db, u, nil, &ts.Config.JWT)
t, _, err = ts.API.generateAccessToken(context.Background(), ts.API.db, u, nil, models.PasswordGrant)
require.NoError(ts.T(), err)
ts.token = t
}
Expand Down
39 changes: 39 additions & 0 deletions internal/api/mfa.go
Original file line number Diff line number Diff line change
Expand Up @@ -289,6 +289,45 @@ func (a *API) invokeHook(ctx context.Context, input, output any) error {
}

return nil
case *hooks.CustomAccessTokenInput:
hookOutput, ok := output.(*hooks.CustomAccessTokenOutput)
if !ok {
panic("output should be *hooks.CustomAccessTokenOutput")
}

if _, err := a.runHook(ctx, config.Hook.CustomAccessToken.HookName, input, output); err != nil {
return internalServerError("Error invoking access token hook.").WithInternalError(err)
}

if hookOutput.IsError() {
httpCode := hookOutput.HookError.HTTPCode

if httpCode == 0 {
httpCode = http.StatusInternalServerError
}

httpError := &HTTPError{
Code: httpCode,
Message: hookOutput.HookError.Message,
}

return httpError.WithInternalError(&hookOutput.HookError)
}
if err := validateTokenClaims(hookOutput.Claims); err != nil {
httpCode := hookOutput.HookError.HTTPCode

if httpCode == 0 {
httpCode = http.StatusInternalServerError
}

httpError := &HTTPError{
Code: httpCode,
Message: err.Error(),
}

return httpError
}
return nil

default:
panic("unknown hook input type")
Expand Down
4 changes: 2 additions & 2 deletions internal/api/mfa_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ func (ts *MFATestSuite) SetupTest() {
}

func (ts *MFATestSuite) generateToken(user *models.User, sessionId *uuid.UUID) string {
token, _, err := generateAccessToken(ts.API.db, user, sessionId, &ts.Config.JWT)
token, _, err := ts.API.generateAccessToken(context.Background(), ts.API.db, user, sessionId, models.TOTPSignIn)
require.NoError(ts.T(), err, "Error generating access token")
return token
}
Expand All @@ -96,7 +96,7 @@ func (ts *MFATestSuite) TestEnrollFactor() {
testFriendlyName := "bob"
alternativeFriendlyName := "john"

token, _, err := generateAccessToken(ts.API.db, ts.TestUser, nil, &ts.Config.JWT)
token, _, err := ts.API.generateAccessToken(context.Background(), ts.API.db, ts.TestUser, nil, models.TOTPSignIn)
require.NoError(ts.T(), err)

var cases = []struct {
Expand Down
2 changes: 1 addition & 1 deletion internal/api/phone_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ func (ts *PhoneTestSuite) TestMissingSmsProviderConfig() {
require.NoError(ts.T(), ts.API.db.Update(u), "Error updating new test user")

var token string
token, _, err = generateAccessToken(ts.API.db, u, nil, &ts.Config.JWT)
token, _, err = ts.API.generateAccessToken(context.Background(), ts.API.db, u, nil, models.OTP)
require.NoError(ts.T(), err)

cases := []struct {
Expand Down
67 changes: 57 additions & 10 deletions internal/api/token.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,10 @@ import (
"strconv"
"time"

"fmt"
"github.com/gofrs/uuid"
"github.com/golang-jwt/jwt"
"github.com/xeipuuv/gojsonschema"

"github.com/supabase/gotrue/internal/conf"
"github.com/supabase/gotrue/internal/hooks"
Expand Down Expand Up @@ -285,7 +287,8 @@ func (a *API) PKCE(ctx context.Context, w http.ResponseWriter, r *http.Request)
return sendJSON(w, http.StatusOK, token)
}

func generateAccessToken(tx *storage.Connection, user *models.User, sessionId *uuid.UUID, config *conf.JWTConfiguration) (string, int64, error) {
func (a *API) generateAccessToken(ctx context.Context, tx *storage.Connection, user *models.User, sessionId *uuid.UUID, authenticationMethod models.AuthenticationMethod) (string, int64, error) {
config := a.config
aal, amr := models.AAL1.String(), []models.AMREntry{}
sid := ""
if sessionId != nil {
Expand All @@ -301,15 +304,15 @@ func generateAccessToken(tx *storage.Connection, user *models.User, sessionId *u
}

issuedAt := time.Now().UTC()
expiresAt := issuedAt.Add(time.Second * time.Duration(config.Exp)).Unix()
expiresAt := issuedAt.Add(time.Second * time.Duration(config.JWT.Exp)).Unix()

claims := &GoTrueClaims{
claims := &hooks.GoTrueClaims{
StandardClaims: jwt.StandardClaims{
Subject: user.ID.String(),
Audience: user.Aud,
IssuedAt: issuedAt.Unix(),
ExpiresAt: expiresAt,
Issuer: config.Issuer,
Issuer: config.JWT.Issuer,
},
Email: user.GetEmail(),
Phone: user.GetPhone(),
Expand All @@ -321,17 +324,37 @@ func generateAccessToken(tx *storage.Connection, user *models.User, sessionId *u
AuthenticationMethodReference: amr,
}

token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
var token *jwt.Token
if config.Hook.CustomAccessToken.Enabled {
input := hooks.CustomAccessTokenInput{
UserID: user.ID,
Claims: claims,
AuthenticationMethod: authenticationMethod.String(),
}

output := hooks.CustomAccessTokenOutput{}

err := a.invokeHook(ctx, &input, &output)
if err != nil {
return "", 0, err
}
goTrueClaims := jwt.MapClaims(output.Claims)

if config.KeyID != "" {
token = jwt.NewWithClaims(jwt.SigningMethodHS256, goTrueClaims)

} else {
token = jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
}

if config.JWT.KeyID != "" {
if token.Header == nil {
token.Header = make(map[string]interface{})
}

token.Header["kid"] = config.KeyID
token.Header["kid"] = config.JWT.KeyID
}

signed, err := token.SignedString([]byte(config.Secret))
signed, err := token.SignedString([]byte(config.JWT.Secret))
if err != nil {
return "", 0, err
}
Expand Down Expand Up @@ -362,7 +385,7 @@ func (a *API) issueRefreshToken(ctx context.Context, conn *storage.Connection, u
return terr
}

tokenString, expiresAt, terr = generateAccessToken(tx, user, refreshToken.SessionId, &config.JWT)
tokenString, expiresAt, terr = a.generateAccessToken(ctx, tx, user, refreshToken.SessionId, authenticationMethod)
if terr != nil {
return internalServerError("error generating jwt token").WithInternalError(terr)
}
Expand Down Expand Up @@ -422,7 +445,7 @@ func (a *API) updateMFASessionAndClaims(r *http.Request, tx *storage.Connection,
return err
}

tokenString, expiresAt, terr = generateAccessToken(tx, user, &sessionId, &config.JWT)
tokenString, expiresAt, terr = a.generateAccessToken(ctx, tx, user, &sessionId, models.TOTPSignIn)

if terr != nil {
return internalServerError("error generating jwt token").WithInternalError(terr)
Expand Down Expand Up @@ -494,3 +517,27 @@ func (a *API) clearCookieToken(config *conf.GlobalConfiguration, name string, w
Domain: config.Cookie.Domain,
})
}

func validateTokenClaims(outputClaims map[string]interface{}) error {
schemaLoader := gojsonschema.NewStringLoader(hooks.MinimumViableTokenSchema)

documentLoader := gojsonschema.NewGoLoader(outputClaims)

result, err := gojsonschema.Validate(schemaLoader, documentLoader)
if err != nil {
return err
}

if !result.Valid() {
var errorMessages string

for _, desc := range result.Errors() {
errorMessages += fmt.Sprintf("- %s\n", desc)
fmt.Printf("- %s\n", desc)
}
return fmt.Errorf("output claims do not conform to the expected schema: \n%s", errorMessages)

}

return nil
}
2 changes: 1 addition & 1 deletion internal/api/token_refresh.go
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,7 @@ func (a *API) RefreshTokenGrant(ctx context.Context, w http.ResponseWriter, r *h
issuedToken = newToken
}

tokenString, expiresAt, terr = generateAccessToken(tx, user, issuedToken.SessionId, &config.JWT)
tokenString, expiresAt, terr = a.generateAccessToken(ctx, tx, user, issuedToken.SessionId, models.TokenRefresh)
if terr != nil {
return internalServerError("error generating jwt token").WithInternalError(terr)
}
Expand Down
Loading

0 comments on commit 312f871

Please sign in to comment.