From 73ba37d15135657013ed9441eb2c010908ff184e Mon Sep 17 00:00:00 2001 From: Vivek Shankar Date: Sat, 5 Aug 2023 12:08:53 +0000 Subject: [PATCH 1/5] feat: [#631] JWT encryption and context-based strategy --- token/jwt/strategy_jwt.go | 180 +++++++++++++++++++++++++++++++++ token/jwt/strategy_jwt_test.go | 87 ++++++++++++++++ 2 files changed, 267 insertions(+) create mode 100644 token/jwt/strategy_jwt.go create mode 100644 token/jwt/strategy_jwt_test.go diff --git a/token/jwt/strategy_jwt.go b/token/jwt/strategy_jwt.go new file mode 100644 index 000000000..070131abb --- /dev/null +++ b/token/jwt/strategy_jwt.go @@ -0,0 +1,180 @@ +// Copyright © 2023 Ory Corp +// SPDX-License-Identifier: Apache-2.0 + +package jwt + +import ( + "context" + "crypto/ecdsa" + "crypto/rsa" + "fmt" + + "github.com/go-jose/go-jose/v3" +) + +// KeyContext contains context that is used to sign, validation, encrypt and decrypt tokens. +// It is populated in different ways depending on the operation. For example - +// +// 1. Validate : the SigningKeyID and SigningAlgorithm is based on the JWT header of the incoming token +// 2. Decrypt : the EncryptionKeyID, EncryptionAlgorithm and EncryptionContentAlgorithm is based on the JWT header of the incoming token +// 3. Generate : all the properties may be populated. The JWT strategy implementation may sign the token, then optionally encrypt it +type KeyContext struct { + SigningKeyID string + SigningAlgorithm string + EncryptionKeyID string + EncryptionAlgorithm string + EncryptionContentAlgorithm string + Extra map[string]interface{} +} + +// Strategy provides the overall strategy interface to sign (generate), encrypt (part of generate), decrypt and validate JWTs. +type Strategy interface { + Signer + + // GenerateWithSettings signs and optionally encrypts the token based on the context provided + GenerateWithSettings(ctx context.Context, settings *KeyContext, claims MapClaims, header Mapper) (string, string, error) + + // DecryptWithSettings decrypts the token provided. If the token is not encrypted, the function should return an error. + DecryptWithSettings(ctx context.Context, settings *KeyContext, token string) (string, error) + + // ValidateWithSettings validates the signed token. If the token is not signed, the function should return an error. + ValidateWithSettings(ctx context.Context, settings *KeyContext, token string) (string, error) +} + +type GetPrivateKeyWithContextFunc func(ctx context.Context, context *KeyContext) (interface{}, error) + +// DefaultStrategy is responsible for generating (signing and optionally encrypting), decrypting and validating JWT challenges +type DefaultStrategy struct { + *DefaultSigner + GetPrivateKey GetPrivateKeyWithContextFunc +} + +func NewDefaultStrategy(GetPrivateKey GetPrivateKeyWithContextFunc) Strategy { + return &DefaultStrategy{ + DefaultSigner: &DefaultSigner{ + GetPrivateKey: func(ctx context.Context) (interface{}, error) { + return GetPrivateKey(ctx, nil) + }, + }, + GetPrivateKey: GetPrivateKey, + } +} + +// GenerateWithSettings signs and optionally encrypts the token based on the context provided +func (s *DefaultStrategy) GenerateWithSettings(ctx context.Context, settings *KeyContext, claims MapClaims, header Mapper) (string, string, error) { + // ignoring the signing alg and kid for this implementation and just using the DefaultSigner implementation + rawToken, sig, err := s.DefaultSigner.Generate(ctx, claims, header) + if err != nil { + return "", "", err + } + + if settings.EncryptionAlgorithm == "" { + return rawToken, sig, err + } + + key, err := s.GetPrivateKey(ctx, settings) + if err != nil { + return "", "", err + } + + if t, ok := key.(*jose.JSONWebKey); ok { + key = t.Key + } + + var pubKey interface{} + switch t := key.(type) { + case *rsa.PrivateKey: + pubKey = &t.PublicKey + case *ecdsa.PrivateKey: + pubKey = &t.PublicKey + case jose.OpaqueSigner: + pubKey = t.Public() + default: + return "", "", fmt.Errorf("unable to decode token. Invalid PrivateKey type %T", key) + } + + eo := &jose.EncrypterOptions{} + eo = eo.WithContentType("JWT").WithType("JWT") + enc, err := jose.NewEncrypter( + jose.ContentEncryption(settings.EncryptionContentAlgorithm), + jose.Recipient{ + Algorithm: jose.KeyAlgorithm(settings.EncryptionAlgorithm), + Key: pubKey, + KeyID: settings.EncryptionKeyID, + }, + eo) + + if err != nil { + return "", "", fmt.Errorf("unable to build encrypter; err=%v", err) + } + + // Encrypt the token + o, err := enc.Encrypt([]byte(rawToken)) + if err != nil { + return "", "", fmt.Errorf("encrypting the token failed. err=%v", err) + } + + // Serialize the encrypted token + rawToken, err = o.CompactSerialize() + if err != nil { + return "", "", fmt.Errorf("serializing the encrypted token failed. err=%v", err) + } + + return rawToken, sig, err +} + +// DecryptWithSettings decrypts the token provided. If the token is not encrypted, the function should return an error. +func (s *DefaultStrategy) DecryptWithSettings(ctx context.Context, settings *KeyContext, token string) (string, error) { + + parsedToken, err := jose.ParseEncrypted(token) + if err != nil { + return "", fmt.Errorf("unable to parse the token") + } + + if settings == nil { + h := parsedToken.Header + enc, _ := h.ExtraHeaders[jose.HeaderKey("enc")].(string) + settings = &KeyContext{ + EncryptionKeyID: h.KeyID, + EncryptionAlgorithm: h.Algorithm, + EncryptionContentAlgorithm: enc, + } + } + + key, err := s.GetPrivateKey(ctx, settings) + var privateKey interface{} + switch t := key.(type) { + case *jose.JSONWebKey: + privateKey = t.Key + case jose.JSONWebKey: + privateKey = t.Key + case *rsa.PrivateKey: + privateKey = t + case *ecdsa.PrivateKey: + privateKey = t + case jose.OpaqueSigner: + switch tt := t.Public().Key.(type) { + case *rsa.PrivateKey: + privateKey = t + case *ecdsa.PrivateKey: + privateKey = t + default: + return "", fmt.Errorf("unsupported private / public key pairs: %T, %T", t, tt) + } + default: + return "", fmt.Errorf("unsupported private key type: %T", t) + } + + decrypted, err := parsedToken.Decrypt(privateKey) + if err != nil { + return "", err + } + + return string(decrypted), nil +} + +// ValidateWithSettings validates the signed token. If the token is not signed, the function should return an error. +func (s *DefaultStrategy) ValidateWithSettings(ctx context.Context, settings *KeyContext, token string) (string, error) { + // ignoring the signing alg and kid for this implementation and just using the DefaultSigner implementation + return s.DefaultSigner.Validate(ctx, token) +} diff --git a/token/jwt/strategy_jwt_test.go b/token/jwt/strategy_jwt_test.go new file mode 100644 index 000000000..395fd3678 --- /dev/null +++ b/token/jwt/strategy_jwt_test.go @@ -0,0 +1,87 @@ +// Copyright © 2023 Ory Corp +// SPDX-License-Identifier: Apache-2.0 + +package jwt + +import ( + "context" + "fmt" + "testing" + "time" + + "github.com/ory/fosite/internal/gen" + "github.com/stretchr/testify/require" +) + +func TestEncryptJWT(t *testing.T) { + key := gen.MustRSAKey() + encryptKey := gen.MustRSAKey() + for k, tc := range []struct { + d string + keyContext *KeyContext + strategy Strategy + resetKey func(strategy Strategy) + }{ + { + d: "SameKeyStrategy", + keyContext: &KeyContext{ + EncryptionAlgorithm: "RSA-OAEP", + EncryptionContentAlgorithm: "A256GCM", + EncryptionKeyID: "samekey", + }, + strategy: NewDefaultStrategy(func(_ context.Context, context *KeyContext) (interface{}, error) { + return key, nil + }), + resetKey: func(strategy Strategy) { + key = gen.MustRSAKey() + }, + }, + { + d: "EncryptionKeyStrategy", + keyContext: &KeyContext{ + EncryptionAlgorithm: "RSA-OAEP", + EncryptionContentAlgorithm: "A256GCM", + EncryptionKeyID: "enc_key", + }, + strategy: NewDefaultStrategy(func(_ context.Context, context *KeyContext) (interface{}, error) { + if context == nil { + return key, nil + } + + if context.EncryptionKeyID == "enc_key" { + return encryptKey, nil + } + + return key, nil + }), + resetKey: func(strategy Strategy) { + key = gen.MustRSAKey() + encryptKey = gen.MustRSAKey() + }, + }, + } { + t.Run(fmt.Sprintf("case=%d/strategy=%s", k, tc.d), func(t *testing.T) { + ctx := context.Background() + + // Reset private key + tc.resetKey(tc.strategy) + + claims := &JWTClaims{ + ExpiresAt: time.Now().UTC().Add(time.Hour), + } + + token, sig, err := tc.strategy.GenerateWithSettings(ctx, tc.keyContext, claims.ToMapClaims(), header) + require.NoError(t, err) + require.NotNil(t, token, "Token could not be generated") + + signedToken, err := tc.strategy.DecryptWithSettings(ctx, tc.keyContext, token) + require.NoError(t, err) + require.NotNil(t, signedToken, "Token could not be decrypted; token=%s", token) + + derivedSig, err := tc.strategy.Validate(ctx, signedToken) + require.NoError(t, err) + + require.EqualValues(t, sig, derivedSig, "Signature does not match") + }) + } +} From adca18baf9feb9357c3851b8f882bd836124f99a Mon Sep 17 00:00:00 2001 From: Vivek Shankar Date: Sat, 5 Aug 2023 14:57:34 +0000 Subject: [PATCH 2/5] feat: [#631] Support for client_assertion JWE --- client.go | 11 ++ client_authentication.go | 129 +++++++------- client_authentication_test.go | 67 ++++++- fosite.go | 2 + fosite_jwt.go | 320 ++++++++++++++++++++++++++++++++++ 5 files changed, 457 insertions(+), 72 deletions(-) create mode 100644 fosite_jwt.go diff --git a/client.go b/client.go index 9b4ce02cb..427fae438 100644 --- a/client.go +++ b/client.go @@ -71,6 +71,17 @@ type OpenIDConnectClient interface { GetTokenEndpointAuthSigningAlgorithm() string } +// ClientWithAllowedVerificationKeys adds a security control to the client configuration to only allow +// specific verification keys. This ensures that a key that is valid for client X can't be used for client Y +// unless allowed. This becomes especially important for cases where the clients are controlled by third-parties +// and are issued specific keys from a central organization, which may be the OP's org or a central regulatory authority, +// and the security controls of the clients cannot be guaranteed. +type ClientWithAllowedVerificationKeys interface { + // AllowedVerificationKeys provides a list of key IDs that can be used in the JWT + // header for private_key_jwt authentication and for JWT bearer grant flow + AllowedVerificationKeys() []string +} + // ResponseModeClient represents a client capable of handling response_mode type ResponseModeClient interface { // GetResponseMode returns the response modes that client is allowed to send diff --git a/client_authentication.go b/client_authentication.go index 685e0311d..f80c19305 100644 --- a/client_authentication.go +++ b/client_authentication.go @@ -8,7 +8,6 @@ import ( "crypto/ecdsa" "crypto/rsa" "encoding/json" - "fmt" "net/http" "net/url" "time" @@ -16,7 +15,6 @@ import ( "github.com/ory/x/errorsx" "github.com/go-jose/go-jose/v3" - "github.com/pkg/errors" "github.com/ory/fosite/token/jwt" ) @@ -54,7 +52,7 @@ func (f *Fosite) findClientPublicJWK(ctx context.Context, oidcClient OpenIDConne } // AuthenticateClient authenticates client requests using the configured strategy -// `Fosite.ClientAuthenticationStrategy`, if nil it uses `Fosite.DefaultClientAuthenticationStrategy` +// `ClientAuthenticationStrategy`, if nil it uses `DefaultClientAuthenticationStrategy` func (f *Fosite) AuthenticateClient(ctx context.Context, r *http.Request, form url.Values) (Client, error) { if s := f.Config.GetClientAuthenticationStrategy(ctx); s != nil { return s(ctx, r, form) @@ -71,81 +69,80 @@ func (f *Fosite) DefaultClientAuthenticationStrategy(ctx context.Context, r *htt return nil, errorsx.WithStack(ErrInvalidRequest.WithHintf("The client_assertion request parameter must be set when using client_assertion_type of '%s'.", clientAssertionJWTBearerType)) } - var clientID string - var client Client - - token, err := jwt.ParseWithClaims(assertion, jwt.MapClaims{}, func(t *jwt.Token) (interface{}, error) { - var err error - clientID, _, err = clientCredentialsFromRequestBody(form, false) - if err != nil { - return nil, err + // for backward compatibility + if f.JWTHelper == nil { + f.JWTHelper = &JWTHelper{ + JWTStrategy: nil, + Config: f.Config, } + } - if clientID == "" { - claims := t.Claims - if sub, ok := claims["sub"].(string); !ok { - return nil, errorsx.WithStack(ErrInvalidClient.WithHint("The claim 'sub' from the client_assertion JSON Web Token is undefined.")) - } else { - clientID = sub - } - } + // Parse the assertion + token, parsedToken, isJWE, err := f.newToken(assertion, "client_assertion", ErrInvalidClient) + if err != nil { + return nil, errorsx.WithStack(ErrInvalidClient.WithHint("Unable to parse the client_assertion").WithWrap(err).WithDebug(err.Error())) + } - client, err = f.Store.GetClient(ctx, clientID) - if err != nil { - return nil, errorsx.WithStack(ErrInvalidClient.WithWrap(err).WithDebug(err.Error())) - } + claims := token.Claims - oidcClient, ok := client.(OpenIDConnectClient) - if !ok { - return nil, errorsx.WithStack(ErrInvalidRequest.WithHint("The server configuration does not support OpenID Connect specific authentication methods.")) - } + // Validate client + clientID, _, err := clientCredentialsFromRequestBody(form, false) + if err != nil { + return nil, err + } - switch oidcClient.GetTokenEndpointAuthMethod() { - case "private_key_jwt": - break - case "none": - return nil, errorsx.WithStack(ErrInvalidClient.WithHint("This requested OAuth 2.0 client does not support client authentication, however 'client_assertion' was provided in the request.")) - case "client_secret_post": - fallthrough - case "client_secret_basic": - return nil, errorsx.WithStack(ErrInvalidClient.WithHintf("This requested OAuth 2.0 client only supports client authentication method '%s', however 'client_assertion' was provided in the request.", oidcClient.GetTokenEndpointAuthMethod())) - case "client_secret_jwt": - fallthrough - default: - return nil, errorsx.WithStack(ErrInvalidClient.WithHintf("This requested OAuth 2.0 client only supports client authentication method '%s', however that method is not supported by this server.", oidcClient.GetTokenEndpointAuthMethod())) + if clientID == "" { + if isJWE { + return nil, errorsx.WithStack(ErrInvalidClient.WithHint("The 'client_id' must be part of the request when encrypted client_assertion is used.")) } - if oidcClient.GetTokenEndpointAuthSigningAlgorithm() != fmt.Sprintf("%s", t.Header["alg"]) { - return nil, errorsx.WithStack(ErrInvalidClient.WithHintf("The 'client_assertion' uses signing algorithm '%s' but the requested OAuth 2.0 Client enforces signing algorithm '%s'.", t.Header["alg"], oidcClient.GetTokenEndpointAuthSigningAlgorithm())) + if sub, ok := claims["sub"].(string); !ok { + return nil, errorsx.WithStack(ErrInvalidClient.WithHint("The claim 'sub' from the client_assertion JSON Web Token is undefined.")) + } else { + clientID = sub } - switch t.Method { - case jose.RS256, jose.RS384, jose.RS512: - return f.findClientPublicJWK(ctx, oidcClient, t, true) - case jose.ES256, jose.ES384, jose.ES512: - return f.findClientPublicJWK(ctx, oidcClient, t, false) - case jose.PS256, jose.PS384, jose.PS512: - return f.findClientPublicJWK(ctx, oidcClient, t, true) - case jose.HS256, jose.HS384, jose.HS512: - return nil, errorsx.WithStack(ErrInvalidClient.WithHint("This authorization server does not support client authentication method 'client_secret_jwt'.")) - default: - return nil, errorsx.WithStack(ErrInvalidClient.WithHintf("The 'client_assertion' request parameter uses unsupported signing algorithm '%s'.", t.Header["alg"])) - } - }) + } + + client, err := f.Store.GetClient(ctx, clientID) if err != nil { - // Do not re-process already enhanced errors - var e *jwt.ValidationError - if errors.As(err, &e) { - if e.Inner != nil { - return nil, e.Inner - } - return nil, errorsx.WithStack(ErrInvalidClient.WithHint("Unable to verify the integrity of the 'client_assertion' value.").WithWrap(err).WithDebug(err.Error())) - } + return nil, errorsx.WithStack(ErrInvalidClient.WithHint("The requested OAuth 2.0 Client could not be authenticated.").WithWrap(err).WithDebug(err.Error())) + } + + oidcClient, ok := client.(OpenIDConnectClient) + if !ok { + return nil, errorsx.WithStack(ErrInvalidRequest.WithHint("The server configuration does not support OpenID Connect specific authentication methods.")) + } + + switch oidcClient.GetTokenEndpointAuthMethod() { + case "private_key_jwt": + break + case "none": + return nil, errorsx.WithStack(ErrInvalidClient.WithHint("This requested OAuth 2.0 client does not support client authentication, however 'client_assertion' was provided in the request.")) + case "client_secret_post": + fallthrough + case "client_secret_basic": + return nil, errorsx.WithStack(ErrInvalidClient.WithHintf("This requested OAuth 2.0 client only supports client authentication method '%s', however 'client_assertion' was provided in the request.", oidcClient.GetTokenEndpointAuthMethod())) + case "client_secret_jwt": + fallthrough + default: + return nil, errorsx.WithStack(ErrInvalidClient.WithHintf("This requested OAuth 2.0 client only supports client authentication method '%s', however that method is not supported by this server.", oidcClient.GetTokenEndpointAuthMethod())) + } + + // Validate signature + if !isJWE && oidcClient.GetTokenEndpointAuthSigningAlgorithm() != "" && oidcClient.GetTokenEndpointAuthSigningAlgorithm() != parsedToken.Headers[0].Algorithm { + return nil, errorsx.WithStack(ErrInvalidClient.WithHintf("The client_assertion uses signing algorithm '%s', but the requested OAuth 2.0 Client enforces signing algorithm '%s'.", parsedToken.Headers[0].Algorithm, oidcClient.GetTokenEndpointAuthSigningAlgorithm())) + } + + if token, parsedToken, err = f.ValidateParsedAssertionWithClient(ctx, "client_assertion", assertion, token, parsedToken, oidcClient, false, ErrInvalidClient); err != nil { return nil, err - } else if err := token.Claims.Valid(); err != nil { - return nil, errorsx.WithStack(ErrInvalidClient.WithHint("Unable to verify the request object because its claims could not be validated, check if the expiry time is set correctly.").WithWrap(err).WithDebug(err.Error())) } - claims := token.Claims + if isJWE && oidcClient.GetTokenEndpointAuthSigningAlgorithm() != "" && oidcClient.GetTokenEndpointAuthSigningAlgorithm() != parsedToken.Headers[0].Algorithm { + return nil, errorsx.WithStack(ErrInvalidClient.WithHintf("The client_assertion uses signing algorithm '%s', but the requested OAuth 2.0 Client enforces signing algorithm '%s'.", parsedToken.Headers[0].Algorithm, oidcClient.GetTokenEndpointAuthSigningAlgorithm())) + } + + claims = token.Claims + var jti string if !claims.VerifyIssuer(clientID, true) { return nil, errorsx.WithStack(ErrInvalidClient.WithHint("Claim 'iss' from 'client_assertion' must match the 'client_id' of the OAuth 2.0 Client.")) diff --git a/client_authentication_test.go b/client_authentication_test.go index c93073ddc..c8f951629 100644 --- a/client_authentication_test.go +++ b/client_authentication_test.go @@ -31,6 +31,31 @@ import ( "github.com/ory/fosite/storage" ) +func encryptAssertionWithRSAKey(t *testing.T, token string, pubKey *rsa.PublicKey) string { + eo := &jose.EncrypterOptions{} + eo = eo.WithContentType("JWT").WithType("JWT") + enc, err := jose.NewEncrypter( + jose.ContentEncryption("A256GCM"), + jose.Recipient{ + Algorithm: jose.KeyAlgorithm("RSA-OAEP"), + Key: pubKey, + KeyID: "enc_key", + }, + eo) + + require.NoError(t, err, "unable to build encrypter; err=%v", err) + + // Encrypt the token + o, err := enc.Encrypt([]byte(token)) + require.NoError(t, err, "encrypting the token failed. err=%v", err) + + // Serialize the encrypted token + token, err = o.CompactSerialize() + require.NoError(t, err, "serializing the encrypted token failed. err=%v", err) + + return token +} + func mustGenerateRSAAssertion(t *testing.T, claims jwt.MapClaims, key *rsa.PrivateKey, kid string) string { token := jwt.NewWithClaims(jose.RS256, claims) token.Header["kid"] = kid @@ -75,13 +100,23 @@ func TestAuthenticateClient(t *testing.T) { const at = "urn:ietf:params:oauth:client-assertion-type:jwt-bearer" hasher := &BCrypt{Config: &Config{HashCost: 6}} + encKey := gen.MustRSAKey() + + config := &Config{ + JWKSFetcherStrategy: NewDefaultJWKSFetcherStrategy(), + ClientSecretsHasher: hasher, + TokenURL: "token-url", + HTTPClient: retryablehttp.NewClient(), + } + f := &Fosite{ - Store: storage.NewMemoryStore(), - Config: &Config{ - JWKSFetcherStrategy: NewDefaultJWKSFetcherStrategy(), - ClientSecretsHasher: hasher, - TokenURL: "token-url", - HTTPClient: retryablehttp.NewClient(), + Store: storage.NewMemoryStore(), + Config: config, + JWTHelper: &JWTHelper{ + Config: config, + JWTStrategy: jwt.NewDefaultStrategy(func(ctx context.Context, context *jwt.KeyContext) (interface{}, error) { + return encKey, nil + }), }, } @@ -300,6 +335,26 @@ func TestAuthenticateClient(t *testing.T) { }, rsaKey, "kid-foo")}, "client_assertion_type": []string{at}}, r: new(http.Request), }, + { + d: "should pass with proper encrypted RSA assertion when JWKs are set within the client and client_id is set in the request", + client: &DefaultOpenIDConnectClient{DefaultClient: &DefaultClient{ID: "bar", Secret: barSecret}, JSONWebKeys: rsaJwks, TokenEndpointAuthMethod: "private_key_jwt"}, + form: url.Values{ + "client_id": []string{"bar"}, + "client_assertion": { + encryptAssertionWithRSAKey(t, + mustGenerateRSAAssertion(t, jwt.MapClaims{ + "sub": "bar", + "exp": time.Now().Add(time.Hour).Unix(), + "iss": "bar", + "jti": "12345", + "aud": "token-url", + }, rsaKey, "kid-foo"), + &encKey.PublicKey), + }, + "client_assertion_type": []string{at}, + }, + r: new(http.Request), + }, { d: "should pass with proper ECDSA assertion when JWKs are set within the client and client_id is not set in the request", client: &DefaultOpenIDConnectClient{DefaultClient: &DefaultClient{ID: "bar", Secret: barSecret}, JSONWebKeys: ecdsaJwks, TokenEndpointAuthMethod: "private_key_jwt", TokenEndpointAuthSigningAlgorithm: "ES256"}, diff --git a/fosite.go b/fosite.go index 445213629..f73ad3a95 100644 --- a/fosite.go +++ b/fosite.go @@ -142,6 +142,8 @@ type Fosite struct { Store Storage Config Configurator + + *JWTHelper } // GetMinParameterEntropy returns MinParameterEntropy if set. Defaults to fosite.MinParameterEntropy. diff --git a/fosite_jwt.go b/fosite_jwt.go new file mode 100644 index 000000000..f3236bf95 --- /dev/null +++ b/fosite_jwt.go @@ -0,0 +1,320 @@ +package fosite + +import ( + "context" + "crypto/ecdsa" + "crypto/rsa" + "encoding/json" + "fmt" + "reflect" + "strings" + + "github.com/go-jose/go-jose/v3" + "github.com/go-jose/go-jose/v3/jwt" + fjwt "github.com/ory/fosite/token/jwt" + "github.com/ory/x/errorsx" +) + +// JWTHelper provides JWT helper functions that is used across +type JWTHelper struct { + // JWTStrategy is the strategy used to build and validate JWTs + JWTStrategy fjwt.Strategy + Config Configurator +} + +// ValidateParsedAssertionWithClient validates the parsed assertion based on the jwks_uri, jwks etc. configured on the client +func (f *JWTHelper) ValidateParsedAssertionWithClient(ctx context.Context, assertionType string, assertion string, token *fjwt.Token, parsedToken *jwt.JSONWebToken, oidcClient OpenIDConnectClient, isNoneAlgAllowed bool, baseError *RFC6749Error) ( + *fjwt.Token, *jwt.JSONWebToken, error) { + jwksURI := oidcClient.GetJSONWebKeysURI() + jwks := oidcClient.GetJSONWebKeys() + allowedKeys := []string{} + if oidcClientEx, ok := oidcClient.(ClientWithAllowedVerificationKeys); ok { + allowedKeys = oidcClientEx.AllowedVerificationKeys() + } + + return f.ValidateParsedAssertion(ctx, assertionType, assertion, token, parsedToken, jwksURI, jwks, allowedKeys, isNoneAlgAllowed, baseError) +} + +// ValidateParsedAssertion validates the parsed assertion based on the jwks_uri, jwks etc. that is passed in +func (f *JWTHelper) ValidateParsedAssertion(ctx context.Context, assertionType string, assertion string, token *fjwt.Token, parsedToken *jwt.JSONWebToken, jwksURI string, jwks *jose.JSONWebKeySet, allowedKeys []string, isNoneAlgAllowed bool, baseError *RFC6749Error) ( + *fjwt.Token, *jwt.JSONWebToken, error) { + + var err error + + if f.JWTStrategy != nil && len(token.Method) == 0 { // JWE + alg, _ := token.Header["alg"].(string) + enc, _ := token.Header["enc"].(string) + assertion, err = f.JWTStrategy.DecryptWithSettings(ctx, + &fjwt.KeyContext{ + EncryptionKeyID: parsedToken.Headers[0].KeyID, + EncryptionAlgorithm: alg, + EncryptionContentAlgorithm: enc, + }, + assertion) + if err != nil { + return nil, nil, errorsx.WithStack(baseError.WithHintf("Unable to verify the integrity of the '%s' value.", assertionType).WithWrap(err).WithDebug(err.Error())) + } + + var mapClaims fjwt.MapClaims = fjwt.MapClaims{} + + if cty, ok := token.Header["cty"].(string); ok && strings.ToUpper(cty) == "JWT" { // Nested JWT + + parsedToken, err = jwt.ParseSigned(assertion) + if err != nil { + return nil, nil, errorsx.WithStack(baseError.WithHintf("Unable to verify the integrity of the '%s' value.", assertionType).WithWrap(err).WithDebug(err.Error())) + } + + if err := parsedToken.UnsafeClaimsWithoutVerification(&mapClaims); err != nil { + return nil, nil, errorsx.WithStack(baseError.WithHintf("Unable to verify the integrity of the '%s' value.", assertionType).WithWrap(err).WithDebug(err.Error())) + } + token.Claims = mapClaims + token.Method = jose.SignatureAlgorithm(parsedToken.Headers[0].Algorithm) + token.Header["kid"] = parsedToken.Headers[0].KeyID // When using jwks, the `kid` is read from token object + + } else { // Only encrypted, not signed + if err := json.Unmarshal([]byte(assertion), &mapClaims); err != nil { + return nil, nil, errorsx.WithStack(baseError.WithHintf("Unable to verify the integrity of the '%s' value.", assertionType).WithWrap(err).WithDebug(err.Error())) + } + token.Claims = mapClaims + err = f.validateJWTClaims(ctx, mapClaims, assertionType, baseError) + if err != nil { + return nil, nil, err + } + + return token, parsedToken, nil + } + } + + if token.Method == fjwt.SigningMethodNone { + if !isNoneAlgAllowed { + return nil, nil, errorsx.WithStack(baseError.WithHintf("'none' is disallowed as a signing method of the '%s'.", assertionType)) + } + + return token, parsedToken, nil + } + + claims := token.Claims + if !f.verificationKeyAllowed(allowedKeys, parsedToken.Headers[0].KeyID) { + return nil, nil, errorsx.WithStack(baseError.WithHintf("The 'kid' used in the '%s' is not allowed.", assertionType)) + } + + // Validate signature + if jwksURI == "" && jwks == nil { + if f.JWTStrategy == nil { + return nil, nil, errorsx.WithStack(baseError.WithHintf("Unable to verify the integrity of the '%s' value.", assertionType).WithWrap(err).WithDebug(err.Error())) + } + + _, err := f.JWTStrategy.ValidateWithSettings(ctx, + &fjwt.KeyContext{ + SigningKeyID: parsedToken.Headers[0].KeyID, + SigningAlgorithm: parsedToken.Headers[0].Algorithm, + }, + assertion) + if err != nil { + return nil, nil, errorsx.WithStack(baseError.WithHintf("Unable to verify the integrity of the '%s' value.", assertionType).WithWrap(err).WithDebug(err.Error())) + } + } else { + var key interface{} + var err error + switch token.Method { + case jose.RS256, jose.RS384, jose.RS512: + key, err = f.findPublicJWK(ctx, token, jwksURI, jwks, true, baseError) + if err != nil { + return nil, nil, wrapSigningKeyFailure( + baseError.WithHint("Unable to retrieve RSA signing key from the JSON Web Key Set."), err) + } + case jose.ES256, jose.ES384, jose.ES512: + key, err = f.findPublicJWK(ctx, token, jwksURI, jwks, false, baseError) + if err != nil { + return nil, nil, wrapSigningKeyFailure( + baseError.WithHint("Unable to retrieve ECDSA signing key from the JSON Web Key Set."), err) + } + case jose.PS256, jose.PS384, jose.PS512: + key, err = f.findPublicJWK(ctx, token, jwksURI, jwks, true, baseError) + if err != nil { + return nil, nil, wrapSigningKeyFailure( + baseError.WithHint("Unable to retrieve RSA signing key from the JSON Web Key Set."), err) + } + default: + return nil, nil, errorsx.WithStack(baseError.WithHintf("The '%s' uses unsupported signing algorithm '%s'.", assertionType, token.Method)) + } + + // To verify signature go-jose requires a pointer to + // public key instead of the public key value. + // The pointer values provides that pointer. + // E.g. transform rsa.PublicKey -> *rsa.PublicKey + key = pointer(key) + + // verify signature with returned key + if err := parsedToken.Claims(key, &claims); err != nil { + return nil, nil, errorsx.WithStack(baseError.WithHintf("Unable to verify the integrity of the '%s' value.", assertionType).WithWrap(err).WithDebug(err.Error())) + } + } + + err = f.validateJWTClaims(ctx, claims, assertionType, baseError) + if err != nil { + return nil, nil, err + } + + return token, parsedToken, nil +} + +func (f *JWTHelper) validateJWTClaims(ctx context.Context, claims fjwt.MapClaims, assertionType string, baseError *RFC6749Error) error { + // Validate claims + // This validation is performed to be backwards compatible + // with jwt-go library behavior + if err := claims.Valid(); err != nil { + if e, ok := err.(*fjwt.ValidationError); ok { + // return a more precise error + if e.Has(fjwt.ValidationErrorExpired) { + return errorsx.WithStack(baseError.WithHintf("The '%s' has expired.", assertionType).WithWrap(err).WithDebug(err.Error())) + } + + if e.Has(fjwt.ValidationErrorIssuedAt) { + return errorsx.WithStack(baseError.WithHintf("The 'iat' claim in '%s' is in the future.", assertionType).WithWrap(err).WithDebug(err.Error())) + } + + if e.Has(fjwt.ValidationErrorNotValidYet) { + return errorsx.WithStack(baseError.WithHintf("The '%s' is not valid yet.", assertionType).WithWrap(err).WithDebug(err.Error())) + } + } + + return errorsx.WithStack(baseError.WithHintf("Invalid claims in the '%s'.", assertionType).WithWrap(err).WithDebug(err.Error())) + } + + return nil +} + +// verificationKeyAllowed checks if the key ID is allowed +func (f *JWTHelper) verificationKeyAllowed(allowedKeys []string, kid string) bool { + if len(kid) == 0 || len(allowedKeys) == 0 { + return true // nothing to verify + } + + for _, allowedKey := range allowedKeys { + if strings.EqualFold(allowedKey, kid) { + return true // found match, the kid is allowed + } + } + + return false +} + +func (f *JWTHelper) newToken(assertion string, assertionType string, baseError *RFC6749Error) (*fjwt.Token, *jwt.JSONWebToken, bool, error) { + var err error + var parsedToken *jwt.JSONWebToken + + isJWE := false // assume it's signed + parsedToken, err = jwt.ParseSigned(assertion) + if err != nil { + parsedToken, err = jwt.ParseEncrypted(assertion) // probably it's encrypted + if err != nil { + return nil, nil, false, errorsx.WithStack(baseError.WithHintf("Unable to verify the integrity of the '%s' value.", assertionType).WithWrap(err).WithDebug(err.Error())) + } + + isJWE = true + } + + token := &fjwt.Token{ + Header: map[string]interface{}{}, + Method: "", + } + + if !isJWE { + var claims fjwt.MapClaims = fjwt.MapClaims{} + if err := parsedToken.UnsafeClaimsWithoutVerification(&claims); err != nil { + return nil, nil, false, errorsx.WithStack(baseError.WithHintf("Unable to verify the integrity of the '%s' value.", assertionType).WithWrap(err).WithDebug(err.Error())) + } + token.Claims = claims + } + + if len(parsedToken.Headers) != 1 { + return nil, nil, false, errorsx.WithStack(baseError.WithHintf("The '%s' value is expected to contain only one header.", assertionType)) + } + + // copy headers + h := parsedToken.Headers[0] + token.Header["alg"] = h.Algorithm + if h.KeyID != "" { + token.Header["kid"] = h.KeyID + } + for k, v := range h.ExtraHeaders { + token.Header[string(k)] = v + } + + if !isJWE { + token.Method = jose.SignatureAlgorithm(h.Algorithm) + } + + return token, parsedToken, isJWE, nil +} + +func (f *JWTHelper) findPublicKey(t *fjwt.Token, set *jose.JSONWebKeySet, expectsRSAKey bool, baseError *RFC6749Error) (interface{}, error) { + keys := set.Keys + if len(keys) == 0 { + return nil, errorsx.WithStack(baseError.WithHint("The retrieved JSON Web Key Set does not contain any keys")) + } + + kid, ok := t.Header["kid"].(string) + if ok { + keys = set.Key(kid) + } + + if len(keys) == 0 { + return nil, errorsx.WithStack(baseError.WithHintf("The JSON Web Token uses signing key with kid '%s', which could not be found.", kid)) + } + + for _, key := range keys { + if key.Use != "sig" { + continue + } + if expectsRSAKey { + if k, ok := key.Key.(*rsa.PublicKey); ok { + return k, nil + } + } else { + if k, ok := key.Key.(*ecdsa.PublicKey); ok { + return k, nil + } + } + } + + if expectsRSAKey { + return nil, errorsx.WithStack(baseError.WithHintf("Unable to find RSA public key with use='sig' for kid '%s' in JSON Web Key Set.", kid)) + } + + return nil, errorsx.WithStack(baseError.WithHintf("Unable to find ECDSA public key with use='sig' for kid '%s' in JSON Web Key Set.", kid)) +} + +func (f *JWTHelper) findPublicJWK(ctx context.Context, t *fjwt.Token, jwksURI string, jwks *jose.JSONWebKeySet, expectsRSAKey bool, baseError *RFC6749Error) (interface{}, error) { + if jwks != nil { + return f.findPublicKey(t, jwks, expectsRSAKey, baseError) + } + + keys, err := f.Config.GetJWKSFetcherStrategy(ctx).Resolve(ctx, jwksURI, false) + if err != nil { + return nil, err + } + + if key, err := f.findPublicKey(t, keys, expectsRSAKey, baseError); err == nil { + return key, nil + } + + keys, err = f.Config.GetJWKSFetcherStrategy(ctx).Resolve(ctx, jwksURI, true) + if err != nil { + return nil, errorsx.WithStack(baseError.WithHintf(fmt.Sprintf("%s", err))) + } + + return f.findPublicKey(t, keys, expectsRSAKey, baseError) +} + +// if underline value of v is not a pointer +// it creates a pointer of it and returns it +func pointer(v interface{}) interface{} { + if reflect.ValueOf(v).Kind() != reflect.Ptr { + value := reflect.New(reflect.ValueOf(v).Type()) + value.Elem().Set(reflect.ValueOf(v)) + return value.Interface() + } + return v +} From f9be2260041f4c324c1e543a283b12d87834716d Mon Sep 17 00:00:00 2001 From: Vivek Shankar Date: Sat, 5 Aug 2023 15:01:01 +0000 Subject: [PATCH 3/5] fix: copyright --- fosite_jwt.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/fosite_jwt.go b/fosite_jwt.go index f3236bf95..91a93a095 100644 --- a/fosite_jwt.go +++ b/fosite_jwt.go @@ -1,3 +1,6 @@ +// Copyright © 2023 Ory Corp +// SPDX-License-Identifier: Apache-2.0 + package fosite import ( From 7d4dfe96b7a8e39b4815e775cafc9811a60f18df Mon Sep 17 00:00:00 2001 From: Vivek Shankar Date: Sun, 6 Aug 2023 01:52:58 +0000 Subject: [PATCH 4/5] refactor: introduce a struct for JWT verification config --- client_authentication.go | 4 --- fosite_jwt.go | 61 +++++++++++++++++++++++++--------------- 2 files changed, 38 insertions(+), 27 deletions(-) diff --git a/client_authentication.go b/client_authentication.go index f80c19305..2bc5e03e6 100644 --- a/client_authentication.go +++ b/client_authentication.go @@ -137,10 +137,6 @@ func (f *Fosite) DefaultClientAuthenticationStrategy(ctx context.Context, r *htt return nil, err } - if isJWE && oidcClient.GetTokenEndpointAuthSigningAlgorithm() != "" && oidcClient.GetTokenEndpointAuthSigningAlgorithm() != parsedToken.Headers[0].Algorithm { - return nil, errorsx.WithStack(ErrInvalidClient.WithHintf("The client_assertion uses signing algorithm '%s', but the requested OAuth 2.0 Client enforces signing algorithm '%s'.", parsedToken.Headers[0].Algorithm, oidcClient.GetTokenEndpointAuthSigningAlgorithm())) - } - claims = token.Claims var jti string diff --git a/fosite_jwt.go b/fosite_jwt.go index 91a93a095..83dc084e2 100644 --- a/fosite_jwt.go +++ b/fosite_jwt.go @@ -18,6 +18,24 @@ import ( "github.com/ory/x/errorsx" ) +// JWTValidationConfig provides configuration to validate JWTs +type JWTValidationConfig struct { + // AllowedSigningKeys are the key IDs that are allowed to verify a JWT + AllowedSigningKeys []string `json:"kids"` + + // AllowedSigningAlgs are the algorithms allowed for a signed JWT + AllowedSigningAlgs []string `json:"algs"` + + // JSONWebKeysURI is the remote URI from which the JWKS is fetched + JSONWebKeysURI string `json:"jwks_uri"` + + // JSONWebKeys in place of JSONWebKeyURI + JSONWebKeys *jose.JSONWebKeySet `json:"jwks"` + + // NoneAlgAllowed indicates if the signing algorithm can be "none" + NoneAlgAllowed bool `json:"none"` +} + // JWTHelper provides JWT helper functions that is used across type JWTHelper struct { // JWTStrategy is the strategy used to build and validate JWTs @@ -35,11 +53,17 @@ func (f *JWTHelper) ValidateParsedAssertionWithClient(ctx context.Context, asser allowedKeys = oidcClientEx.AllowedVerificationKeys() } - return f.ValidateParsedAssertion(ctx, assertionType, assertion, token, parsedToken, jwksURI, jwks, allowedKeys, isNoneAlgAllowed, baseError) + return f.ValidateParsedAssertion(ctx, assertionType, assertion, token, parsedToken, &JWTValidationConfig{ + AllowedSigningKeys: allowedKeys, + AllowedSigningAlgs: []string{oidcClient.GetTokenEndpointAuthSigningAlgorithm()}, + JSONWebKeysURI: jwksURI, + JSONWebKeys: jwks, + NoneAlgAllowed: isNoneAlgAllowed, + }, baseError) } // ValidateParsedAssertion validates the parsed assertion based on the jwks_uri, jwks etc. that is passed in -func (f *JWTHelper) ValidateParsedAssertion(ctx context.Context, assertionType string, assertion string, token *fjwt.Token, parsedToken *jwt.JSONWebToken, jwksURI string, jwks *jose.JSONWebKeySet, allowedKeys []string, isNoneAlgAllowed bool, baseError *RFC6749Error) ( +func (f *JWTHelper) ValidateParsedAssertion(ctx context.Context, assertionType string, assertion string, token *fjwt.Token, parsedToken *jwt.JSONWebToken, config *JWTValidationConfig, baseError *RFC6749Error) ( *fjwt.Token, *jwt.JSONWebToken, error) { var err error @@ -89,7 +113,7 @@ func (f *JWTHelper) ValidateParsedAssertion(ctx context.Context, assertionType s } if token.Method == fjwt.SigningMethodNone { - if !isNoneAlgAllowed { + if !config.NoneAlgAllowed { return nil, nil, errorsx.WithStack(baseError.WithHintf("'none' is disallowed as a signing method of the '%s'.", assertionType)) } @@ -97,12 +121,18 @@ func (f *JWTHelper) ValidateParsedAssertion(ctx context.Context, assertionType s } claims := token.Claims - if !f.verificationKeyAllowed(allowedKeys, parsedToken.Headers[0].KeyID) { + signingAlg := parsedToken.Headers[0].Algorithm + if len(config.AllowedSigningAlgs) > 0 && !Arguments(config.AllowedSigningAlgs).Has(signingAlg) { + return nil, nil, errorsx.WithStack(baseError.WithHintf("The 'alg' used in the '%s' is not allowed.", assertionType)) + } + + signingKey := parsedToken.Headers[0].KeyID + if len(config.AllowedSigningKeys) > 0 && !Arguments(config.AllowedSigningKeys).Has(signingKey) { return nil, nil, errorsx.WithStack(baseError.WithHintf("The 'kid' used in the '%s' is not allowed.", assertionType)) } // Validate signature - if jwksURI == "" && jwks == nil { + if config.JSONWebKeysURI == "" && config.JSONWebKeys == nil { if f.JWTStrategy == nil { return nil, nil, errorsx.WithStack(baseError.WithHintf("Unable to verify the integrity of the '%s' value.", assertionType).WithWrap(err).WithDebug(err.Error())) } @@ -121,19 +151,19 @@ func (f *JWTHelper) ValidateParsedAssertion(ctx context.Context, assertionType s var err error switch token.Method { case jose.RS256, jose.RS384, jose.RS512: - key, err = f.findPublicJWK(ctx, token, jwksURI, jwks, true, baseError) + key, err = f.findPublicJWK(ctx, token, config.JSONWebKeysURI, config.JSONWebKeys, true, baseError) if err != nil { return nil, nil, wrapSigningKeyFailure( baseError.WithHint("Unable to retrieve RSA signing key from the JSON Web Key Set."), err) } case jose.ES256, jose.ES384, jose.ES512: - key, err = f.findPublicJWK(ctx, token, jwksURI, jwks, false, baseError) + key, err = f.findPublicJWK(ctx, token, config.JSONWebKeysURI, config.JSONWebKeys, false, baseError) if err != nil { return nil, nil, wrapSigningKeyFailure( baseError.WithHint("Unable to retrieve ECDSA signing key from the JSON Web Key Set."), err) } case jose.PS256, jose.PS384, jose.PS512: - key, err = f.findPublicJWK(ctx, token, jwksURI, jwks, true, baseError) + key, err = f.findPublicJWK(ctx, token, config.JSONWebKeysURI, config.JSONWebKeys, true, baseError) if err != nil { return nil, nil, wrapSigningKeyFailure( baseError.WithHint("Unable to retrieve RSA signing key from the JSON Web Key Set."), err) @@ -188,21 +218,6 @@ func (f *JWTHelper) validateJWTClaims(ctx context.Context, claims fjwt.MapClaims return nil } -// verificationKeyAllowed checks if the key ID is allowed -func (f *JWTHelper) verificationKeyAllowed(allowedKeys []string, kid string) bool { - if len(kid) == 0 || len(allowedKeys) == 0 { - return true // nothing to verify - } - - for _, allowedKey := range allowedKeys { - if strings.EqualFold(allowedKey, kid) { - return true // found match, the kid is allowed - } - } - - return false -} - func (f *JWTHelper) newToken(assertion string, assertionType string, baseError *RFC6749Error) (*fjwt.Token, *jwt.JSONWebToken, bool, error) { var err error var parsedToken *jwt.JSONWebToken From 6e3faeca9b02d6c768cf9e1d5072d245f283eca8 Mon Sep 17 00:00:00 2001 From: Vivek Shankar Date: Sun, 6 Aug 2023 09:58:19 +0000 Subject: [PATCH 5/5] refactor: eliminate JWTHelper --- client_authentication.go | 79 +++---------------------------- client_authentication_test.go | 10 ++-- config.go | 6 +++ config_default.go | 9 ++++ context.go | 3 ++ fosite.go | 2 - fosite_jwt.go | 87 +++++++++++++++++++++-------------- 7 files changed, 81 insertions(+), 115 deletions(-) diff --git a/client_authentication.go b/client_authentication.go index 2bc5e03e6..975b5b633 100644 --- a/client_authentication.go +++ b/client_authentication.go @@ -5,8 +5,6 @@ package fosite import ( "context" - "crypto/ecdsa" - "crypto/rsa" "encoding/json" "net/http" "net/url" @@ -14,8 +12,6 @@ import ( "github.com/ory/x/errorsx" - "github.com/go-jose/go-jose/v3" - "github.com/ory/fosite/token/jwt" ) @@ -26,29 +22,11 @@ type ClientAuthenticationStrategy func(context.Context, *http.Request, url.Value const clientAssertionJWTBearerType = "urn:ietf:params:oauth:client-assertion-type:jwt-bearer" func (f *Fosite) findClientPublicJWK(ctx context.Context, oidcClient OpenIDConnectClient, t *jwt.Token, expectsRSAKey bool) (interface{}, error) { - if set := oidcClient.GetJSONWebKeys(); set != nil { - return findPublicKey(t, set, expectsRSAKey) - } - - if location := oidcClient.GetJSONWebKeysURI(); len(location) > 0 { - keys, err := f.Config.GetJWKSFetcherStrategy(ctx).Resolve(ctx, location, false) - if err != nil { - return nil, err - } - - if key, err := findPublicKey(t, keys, expectsRSAKey); err == nil { - return key, nil - } - - keys, err = f.Config.GetJWKSFetcherStrategy(ctx).Resolve(ctx, location, true) - if err != nil { - return nil, err - } - - return findPublicKey(t, keys, expectsRSAKey) + if oidcClient.GetJSONWebKeys() == nil && oidcClient.GetJSONWebKeysURI() == "" { + return nil, errorsx.WithStack(ErrInvalidClient.WithHint("The OAuth 2.0 Client has no JSON Web Keys set registered, but they are needed to complete the request.")) } - return nil, errorsx.WithStack(ErrInvalidClient.WithHint("The OAuth 2.0 Client has no JSON Web Keys set registered, but they are needed to complete the request.")) + return findPublicJWK(ctx, f.Config, t, oidcClient.GetJSONWebKeysURI(), oidcClient.GetJSONWebKeys(), expectsRSAKey, ErrInvalidClient) } // AuthenticateClient authenticates client requests using the configured strategy @@ -69,16 +47,8 @@ func (f *Fosite) DefaultClientAuthenticationStrategy(ctx context.Context, r *htt return nil, errorsx.WithStack(ErrInvalidRequest.WithHintf("The client_assertion request parameter must be set when using client_assertion_type of '%s'.", clientAssertionJWTBearerType)) } - // for backward compatibility - if f.JWTHelper == nil { - f.JWTHelper = &JWTHelper{ - JWTStrategy: nil, - Config: f.Config, - } - } - // Parse the assertion - token, parsedToken, isJWE, err := f.newToken(assertion, "client_assertion", ErrInvalidClient) + token, parsedToken, isJWE, err := newToken(assertion, "client_assertion", ErrInvalidClient) if err != nil { return nil, errorsx.WithStack(ErrInvalidClient.WithHint("Unable to parse the client_assertion").WithWrap(err).WithDebug(err.Error())) } @@ -133,7 +103,9 @@ func (f *Fosite) DefaultClientAuthenticationStrategy(ctx context.Context, r *htt return nil, errorsx.WithStack(ErrInvalidClient.WithHintf("The client_assertion uses signing algorithm '%s', but the requested OAuth 2.0 Client enforces signing algorithm '%s'.", parsedToken.Headers[0].Algorithm, oidcClient.GetTokenEndpointAuthSigningAlgorithm())) } - if token, parsedToken, err = f.ValidateParsedAssertionWithClient(ctx, "client_assertion", assertion, token, parsedToken, oidcClient, false, ErrInvalidClient); err != nil { + ctx = context.WithValue(ctx, AssertionTypeContextKey, "client_assertion") + ctx = context.WithValue(ctx, BaseErrorContextKey, ErrInvalidClient) + if token, parsedToken, err = ValidateParsedAssertionWithClient(ctx, f.Config, assertion, token, parsedToken, oidcClient, false); err != nil { return nil, err } @@ -248,43 +220,6 @@ func (f *Fosite) checkClientSecret(ctx context.Context, client Client, clientSec return err } -func findPublicKey(t *jwt.Token, set *jose.JSONWebKeySet, expectsRSAKey bool) (interface{}, error) { - keys := set.Keys - if len(keys) == 0 { - return nil, errorsx.WithStack(ErrInvalidRequest.WithHintf("The retrieved JSON Web Key Set does not contain any key.")) - } - - kid, ok := t.Header["kid"].(string) - if ok { - keys = set.Key(kid) - } - - if len(keys) == 0 { - return nil, errorsx.WithStack(ErrInvalidRequest.WithHintf("The JSON Web Token uses signing key with kid '%s', which could not be found.", kid)) - } - - for _, key := range keys { - if key.Use != "sig" { - continue - } - if expectsRSAKey { - if k, ok := key.Key.(*rsa.PublicKey); ok { - return k, nil - } - } else { - if k, ok := key.Key.(*ecdsa.PublicKey); ok { - return k, nil - } - } - } - - if expectsRSAKey { - return nil, errorsx.WithStack(ErrInvalidRequest.WithHintf("Unable to find RSA public key with use='sig' for kid '%s' in JSON Web Key Set.", kid)) - } else { - return nil, errorsx.WithStack(ErrInvalidRequest.WithHintf("Unable to find ECDSA public key with use='sig' for kid '%s' in JSON Web Key Set.", kid)) - } -} - func clientCredentialsFromRequest(r *http.Request, form url.Values) (clientID, clientSecret string, err error) { if id, secret, ok := r.BasicAuth(); !ok { return clientCredentialsFromRequestBody(form, true) diff --git a/client_authentication_test.go b/client_authentication_test.go index c8f951629..8dabd83e9 100644 --- a/client_authentication_test.go +++ b/client_authentication_test.go @@ -107,17 +107,15 @@ func TestAuthenticateClient(t *testing.T) { ClientSecretsHasher: hasher, TokenURL: "token-url", HTTPClient: retryablehttp.NewClient(), + JWTStrategy: jwt.NewDefaultStrategy( + func(ctx context.Context, context *jwt.KeyContext) (interface{}, error) { + return encKey, nil + }), } f := &Fosite{ Store: storage.NewMemoryStore(), Config: config, - JWTHelper: &JWTHelper{ - Config: config, - JWTStrategy: jwt.NewDefaultStrategy(func(ctx context.Context, context *jwt.KeyContext) (interface{}, error) { - return encKey, nil - }), - }, } barSecret, err := hasher.Hash(context.TODO(), []byte("bar")) diff --git a/config.go b/config.go index 1b50eb70c..aed2374a1 100644 --- a/config.go +++ b/config.go @@ -300,3 +300,9 @@ type PushedAuthorizeRequestConfigProvider interface { // must contain the PAR request_uri. EnforcePushedAuthorize(ctx context.Context) bool } + +// JWTStrategyProvider returns the provider for configuring the JWT strategy. +type JWTStrategyProvider interface { + // GetJWTStrategy returns the JWT strategy. + GetJWTStrategy(ctx context.Context) jwt.Strategy +} diff --git a/config_default.go b/config_default.go index 7f2e2487e..8a3bd7aa1 100644 --- a/config_default.go +++ b/config_default.go @@ -62,6 +62,7 @@ var ( _ RevocationHandlersProvider = (*Config)(nil) _ PushedAuthorizeRequestHandlersProvider = (*Config)(nil) _ PushedAuthorizeRequestConfigProvider = (*Config)(nil) + _ JWTStrategyProvider = (*Config)(nil) ) type Config struct { @@ -212,6 +213,9 @@ type Config struct { // IsPushedAuthorizeEnforced enforces pushed authorization request for /authorize IsPushedAuthorizeEnforced bool + + // JWTStrategy is used to provide additional JWT encrypt/decrypt/sign/verify capabilities + JWTStrategy jwt.Strategy } func (c *Config) GetGlobalSecret(ctx context.Context) ([]byte, error) { @@ -488,3 +492,8 @@ func (c *Config) GetPushedAuthorizeContextLifespan(ctx context.Context) time.Dur func (c *Config) EnforcePushedAuthorize(ctx context.Context) bool { return c.IsPushedAuthorizeEnforced } + +// GetJWTStrategy returns the JWT strategy. +func (c *Config) GetJWTStrategy(ctx context.Context) jwt.Strategy { + return c.JWTStrategy +} diff --git a/context.go b/context.go index d8b2bc3fd..3d72a3a73 100644 --- a/context.go +++ b/context.go @@ -19,4 +19,7 @@ const ( AuthorizeResponseContextKey = ContextKey("authorizeResponse") // PushedAuthorizeResponseContextKey is the response context PushedAuthorizeResponseContextKey = ContextKey("pushedAuthorizeResponse") + + AssertionTypeContextKey = ContextKey("assertionType") + BaseErrorContextKey = ContextKey("baseError") ) diff --git a/fosite.go b/fosite.go index f73ad3a95..445213629 100644 --- a/fosite.go +++ b/fosite.go @@ -142,8 +142,6 @@ type Fosite struct { Store Storage Config Configurator - - *JWTHelper } // GetMinParameterEntropy returns MinParameterEntropy if set. Defaults to fosite.MinParameterEntropy. diff --git a/fosite_jwt.go b/fosite_jwt.go index 83dc084e2..cd0526cf3 100644 --- a/fosite_jwt.go +++ b/fosite_jwt.go @@ -36,42 +36,43 @@ type JWTValidationConfig struct { NoneAlgAllowed bool `json:"none"` } -// JWTHelper provides JWT helper functions that is used across -type JWTHelper struct { - // JWTStrategy is the strategy used to build and validate JWTs - JWTStrategy fjwt.Strategy - Config Configurator -} - // ValidateParsedAssertionWithClient validates the parsed assertion based on the jwks_uri, jwks etc. configured on the client -func (f *JWTHelper) ValidateParsedAssertionWithClient(ctx context.Context, assertionType string, assertion string, token *fjwt.Token, parsedToken *jwt.JSONWebToken, oidcClient OpenIDConnectClient, isNoneAlgAllowed bool, baseError *RFC6749Error) ( +func ValidateParsedAssertionWithClient(ctx context.Context, config Configurator, assertion string, token *fjwt.Token, parsedToken *jwt.JSONWebToken, oidcClient OpenIDConnectClient, isNoneAlgAllowed bool) ( *fjwt.Token, *jwt.JSONWebToken, error) { + jwksURI := oidcClient.GetJSONWebKeysURI() jwks := oidcClient.GetJSONWebKeys() allowedKeys := []string{} - if oidcClientEx, ok := oidcClient.(ClientWithAllowedVerificationKeys); ok { - allowedKeys = oidcClientEx.AllowedVerificationKeys() + if c, ok := oidcClient.(ClientWithAllowedVerificationKeys); ok { + allowedKeys = c.AllowedVerificationKeys() } - return f.ValidateParsedAssertion(ctx, assertionType, assertion, token, parsedToken, &JWTValidationConfig{ + return ValidateParsedAssertion(ctx, config, assertion, token, parsedToken, &JWTValidationConfig{ AllowedSigningKeys: allowedKeys, AllowedSigningAlgs: []string{oidcClient.GetTokenEndpointAuthSigningAlgorithm()}, JSONWebKeysURI: jwksURI, JSONWebKeys: jwks, NoneAlgAllowed: isNoneAlgAllowed, - }, baseError) + }) } // ValidateParsedAssertion validates the parsed assertion based on the jwks_uri, jwks etc. that is passed in -func (f *JWTHelper) ValidateParsedAssertion(ctx context.Context, assertionType string, assertion string, token *fjwt.Token, parsedToken *jwt.JSONWebToken, config *JWTValidationConfig, baseError *RFC6749Error) ( +func ValidateParsedAssertion(ctx context.Context, config Configurator, assertion string, token *fjwt.Token, parsedToken *jwt.JSONWebToken, verificationConfig *JWTValidationConfig) ( *fjwt.Token, *jwt.JSONWebToken, error) { var err error + baseError := getBaseError(ctx) + assertionType := getAssertionType(ctx) - if f.JWTStrategy != nil && len(token.Method) == 0 { // JWE + var jwtStrategy fjwt.Strategy + if c, ok := config.(JWTStrategyProvider); ok { + jwtStrategy = c.GetJWTStrategy(ctx) + } + + if jwtStrategy != nil && len(token.Method) == 0 { // JWE alg, _ := token.Header["alg"].(string) enc, _ := token.Header["enc"].(string) - assertion, err = f.JWTStrategy.DecryptWithSettings(ctx, + assertion, err = jwtStrategy.DecryptWithSettings(ctx, &fjwt.KeyContext{ EncryptionKeyID: parsedToken.Headers[0].KeyID, EncryptionAlgorithm: alg, @@ -103,7 +104,7 @@ func (f *JWTHelper) ValidateParsedAssertion(ctx context.Context, assertionType s return nil, nil, errorsx.WithStack(baseError.WithHintf("Unable to verify the integrity of the '%s' value.", assertionType).WithWrap(err).WithDebug(err.Error())) } token.Claims = mapClaims - err = f.validateJWTClaims(ctx, mapClaims, assertionType, baseError) + err = validateJWTClaims(ctx, mapClaims, assertionType, baseError) if err != nil { return nil, nil, err } @@ -113,7 +114,7 @@ func (f *JWTHelper) ValidateParsedAssertion(ctx context.Context, assertionType s } if token.Method == fjwt.SigningMethodNone { - if !config.NoneAlgAllowed { + if !verificationConfig.NoneAlgAllowed { return nil, nil, errorsx.WithStack(baseError.WithHintf("'none' is disallowed as a signing method of the '%s'.", assertionType)) } @@ -122,22 +123,22 @@ func (f *JWTHelper) ValidateParsedAssertion(ctx context.Context, assertionType s claims := token.Claims signingAlg := parsedToken.Headers[0].Algorithm - if len(config.AllowedSigningAlgs) > 0 && !Arguments(config.AllowedSigningAlgs).Has(signingAlg) { + if len(verificationConfig.AllowedSigningAlgs) > 0 && !Arguments(verificationConfig.AllowedSigningAlgs).Has(signingAlg) { return nil, nil, errorsx.WithStack(baseError.WithHintf("The 'alg' used in the '%s' is not allowed.", assertionType)) } signingKey := parsedToken.Headers[0].KeyID - if len(config.AllowedSigningKeys) > 0 && !Arguments(config.AllowedSigningKeys).Has(signingKey) { + if len(verificationConfig.AllowedSigningKeys) > 0 && !Arguments(verificationConfig.AllowedSigningKeys).Has(signingKey) { return nil, nil, errorsx.WithStack(baseError.WithHintf("The 'kid' used in the '%s' is not allowed.", assertionType)) } // Validate signature - if config.JSONWebKeysURI == "" && config.JSONWebKeys == nil { - if f.JWTStrategy == nil { + if verificationConfig.JSONWebKeysURI == "" && verificationConfig.JSONWebKeys == nil { + if jwtStrategy == nil { return nil, nil, errorsx.WithStack(baseError.WithHintf("Unable to verify the integrity of the '%s' value.", assertionType).WithWrap(err).WithDebug(err.Error())) } - _, err := f.JWTStrategy.ValidateWithSettings(ctx, + _, err := jwtStrategy.ValidateWithSettings(ctx, &fjwt.KeyContext{ SigningKeyID: parsedToken.Headers[0].KeyID, SigningAlgorithm: parsedToken.Headers[0].Algorithm, @@ -151,19 +152,19 @@ func (f *JWTHelper) ValidateParsedAssertion(ctx context.Context, assertionType s var err error switch token.Method { case jose.RS256, jose.RS384, jose.RS512: - key, err = f.findPublicJWK(ctx, token, config.JSONWebKeysURI, config.JSONWebKeys, true, baseError) + key, err = findPublicJWK(ctx, config, token, verificationConfig.JSONWebKeysURI, verificationConfig.JSONWebKeys, true, baseError) if err != nil { return nil, nil, wrapSigningKeyFailure( baseError.WithHint("Unable to retrieve RSA signing key from the JSON Web Key Set."), err) } case jose.ES256, jose.ES384, jose.ES512: - key, err = f.findPublicJWK(ctx, token, config.JSONWebKeysURI, config.JSONWebKeys, false, baseError) + key, err = findPublicJWK(ctx, config, token, verificationConfig.JSONWebKeysURI, verificationConfig.JSONWebKeys, false, baseError) if err != nil { return nil, nil, wrapSigningKeyFailure( baseError.WithHint("Unable to retrieve ECDSA signing key from the JSON Web Key Set."), err) } case jose.PS256, jose.PS384, jose.PS512: - key, err = f.findPublicJWK(ctx, token, config.JSONWebKeysURI, config.JSONWebKeys, true, baseError) + key, err = findPublicJWK(ctx, config, token, verificationConfig.JSONWebKeysURI, verificationConfig.JSONWebKeys, true, baseError) if err != nil { return nil, nil, wrapSigningKeyFailure( baseError.WithHint("Unable to retrieve RSA signing key from the JSON Web Key Set."), err) @@ -184,7 +185,7 @@ func (f *JWTHelper) ValidateParsedAssertion(ctx context.Context, assertionType s } } - err = f.validateJWTClaims(ctx, claims, assertionType, baseError) + err = validateJWTClaims(ctx, claims, assertionType, baseError) if err != nil { return nil, nil, err } @@ -192,7 +193,7 @@ func (f *JWTHelper) ValidateParsedAssertion(ctx context.Context, assertionType s return token, parsedToken, nil } -func (f *JWTHelper) validateJWTClaims(ctx context.Context, claims fjwt.MapClaims, assertionType string, baseError *RFC6749Error) error { +func validateJWTClaims(ctx context.Context, claims fjwt.MapClaims, assertionType string, baseError *RFC6749Error) error { // Validate claims // This validation is performed to be backwards compatible // with jwt-go library behavior @@ -218,7 +219,7 @@ func (f *JWTHelper) validateJWTClaims(ctx context.Context, claims fjwt.MapClaims return nil } -func (f *JWTHelper) newToken(assertion string, assertionType string, baseError *RFC6749Error) (*fjwt.Token, *jwt.JSONWebToken, bool, error) { +func newToken(assertion string, assertionType string, baseError *RFC6749Error) (*fjwt.Token, *jwt.JSONWebToken, bool, error) { var err error var parsedToken *jwt.JSONWebToken @@ -267,7 +268,7 @@ func (f *JWTHelper) newToken(assertion string, assertionType string, baseError * return token, parsedToken, isJWE, nil } -func (f *JWTHelper) findPublicKey(t *fjwt.Token, set *jose.JSONWebKeySet, expectsRSAKey bool, baseError *RFC6749Error) (interface{}, error) { +func findPublicKey(t *fjwt.Token, set *jose.JSONWebKeySet, expectsRSAKey bool, baseError *RFC6749Error) (interface{}, error) { keys := set.Keys if len(keys) == 0 { return nil, errorsx.WithStack(baseError.WithHint("The retrieved JSON Web Key Set does not contain any keys")) @@ -304,26 +305,26 @@ func (f *JWTHelper) findPublicKey(t *fjwt.Token, set *jose.JSONWebKeySet, expect return nil, errorsx.WithStack(baseError.WithHintf("Unable to find ECDSA public key with use='sig' for kid '%s' in JSON Web Key Set.", kid)) } -func (f *JWTHelper) findPublicJWK(ctx context.Context, t *fjwt.Token, jwksURI string, jwks *jose.JSONWebKeySet, expectsRSAKey bool, baseError *RFC6749Error) (interface{}, error) { +func findPublicJWK(ctx context.Context, config Configurator, t *fjwt.Token, jwksURI string, jwks *jose.JSONWebKeySet, expectsRSAKey bool, baseError *RFC6749Error) (interface{}, error) { if jwks != nil { - return f.findPublicKey(t, jwks, expectsRSAKey, baseError) + return findPublicKey(t, jwks, expectsRSAKey, baseError) } - keys, err := f.Config.GetJWKSFetcherStrategy(ctx).Resolve(ctx, jwksURI, false) + keys, err := config.GetJWKSFetcherStrategy(ctx).Resolve(ctx, jwksURI, false) if err != nil { return nil, err } - if key, err := f.findPublicKey(t, keys, expectsRSAKey, baseError); err == nil { + if key, err := findPublicKey(t, keys, expectsRSAKey, baseError); err == nil { return key, nil } - keys, err = f.Config.GetJWKSFetcherStrategy(ctx).Resolve(ctx, jwksURI, true) + keys, err = config.GetJWKSFetcherStrategy(ctx).Resolve(ctx, jwksURI, true) if err != nil { return nil, errorsx.WithStack(baseError.WithHintf(fmt.Sprintf("%s", err))) } - return f.findPublicKey(t, keys, expectsRSAKey, baseError) + return findPublicKey(t, keys, expectsRSAKey, baseError) } // if underline value of v is not a pointer @@ -336,3 +337,19 @@ func pointer(v interface{}) interface{} { } return v } + +func getBaseError(ctx context.Context) *RFC6749Error { + if e, ok := ctx.Value(BaseErrorContextKey).(*RFC6749Error); ok { + return e + } + + return ErrInvalidClient +} + +func getAssertionType(ctx context.Context) string { + if at, ok := ctx.Value(AssertionTypeContextKey).(string); ok { + return at + } + + return "assertion" +}