From 335b4828f0facd40f49bc5d27f2dfa07e0e528f6 Mon Sep 17 00:00:00 2001 From: Pedro Costa <550684+pnmcosta@users.noreply.github.com> Date: Tue, 4 Feb 2025 19:53:21 +0000 Subject: [PATCH] WIP auth manager --- client/cmd/testutil_test.go | 2 +- client/internal/engine_test.go | 2 +- client/server/server_test.go | 2 +- go.mod | 2 +- go.sum | 4 +- management/client/client_test.go | 2 +- management/cmd/management.go | 28 +- management/server/account.go | 320 ++++++---------- management/server/account_test.go | 231 ++++-------- .../{jwtclaims => auth/jwt}/extractor.go | 94 ++--- management/server/auth/jwt/validator.go | 303 +++++++++++++++ management/server/auth/manager.go | 168 +++++++++ management/server/auth/manager_mock.go | 54 +++ management/server/auth/manager_test.go | 206 +++++++++++ management/server/config.go | 7 - management/server/context/auth.go | 56 +++ management/server/grpcserver.go | 56 +-- management/server/http/handler.go | 63 ++-- .../handlers/accounts/accounts_handler.go | 37 +- .../accounts/accounts_handler_test.go | 23 +- .../http/handlers/dns/dns_settings_handler.go | 36 +- .../handlers/dns/dns_settings_handler_test.go | 20 +- .../http/handlers/dns/nameservers_handler.go | 48 ++- .../handlers/dns/nameservers_handler_test.go | 20 +- .../http/handlers/events/events_handler.go | 25 +- .../handlers/events/events_handler_test.go | 20 +- .../http/handlers/groups/groups_handler.go | 43 ++- .../handlers/groups/groups_handler_test.go | 30 +- .../server/http/handlers/networks/handler.go | 53 ++- .../handlers/networks/resources_handler.go | 51 ++- .../http/handlers/networks/routers_handler.go | 45 +-- .../http/handlers/peers/peers_handler.go | 31 +- .../http/handlers/peers/peers_handler_test.go | 38 +- .../policies/geolocation_handler_test.go | 24 +- .../handlers/policies/geolocations_handler.go | 19 +- .../handlers/policies/policies_handler.go | 42 +-- .../policies/policies_handler_test.go | 24 +- .../policies/posture_checks_handler.go | 38 +- .../policies/posture_checks_handler_test.go | 24 +- .../http/handlers/routes/routes_handler.go | 39 +- .../handlers/routes/routes_handler_test.go | 48 +-- .../handlers/setup_keys/setupkeys_handler.go | 36 +- .../setup_keys/setupkeys_handler_test.go | 25 +- .../server/http/handlers/users/pat_handler.go | 32 +- .../http/handlers/users/pat_handler_test.go | 23 +- .../http/handlers/users/users_handler.go | 46 +-- .../http/handlers/users/users_handler_test.go | 41 +- .../server/http/middleware/access_control.go | 26 +- .../server/http/middleware/auth_middleware.go | 156 ++++---- .../http/middleware/auth_middleware_test.go | 64 ++-- .../http/testing/testing_tools/tools.go | 10 +- management/server/jwtclaims/claims.go | 2 +- management/server/jwtclaims/extractor_test.go | 227 ------------ management/server/jwtclaims/jwtValidator.go | 349 ------------------ management/server/management_proto_test.go | 2 +- management/server/management_test.go | 4 +- management/server/mock_server/account_mock.go | 41 +- management/server/user.go | 28 +- management/server/user_test.go | 10 +- 59 files changed, 1652 insertions(+), 1848 deletions(-) rename management/server/{jwtclaims => auth/jwt}/extractor.go (65%) create mode 100644 management/server/auth/jwt/validator.go create mode 100644 management/server/auth/manager.go create mode 100644 management/server/auth/manager_mock.go create mode 100644 management/server/auth/manager_test.go create mode 100644 management/server/context/auth.go delete mode 100644 management/server/jwtclaims/extractor_test.go delete mode 100644 management/server/jwtclaims/jwtValidator.go diff --git a/client/cmd/testutil_test.go b/client/cmd/testutil_test.go index e3e644357e7..e0d78404873 100644 --- a/client/cmd/testutil_test.go +++ b/client/cmd/testutil_test.go @@ -95,7 +95,7 @@ func startManagement(t *testing.T, config *mgmt.Config, testFile string) (*grpc. } secretsManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay) - mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, settings.NewManager(store), peersUpdateManager, secretsManager, nil, nil) + mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, settings.NewManager(store), peersUpdateManager, secretsManager, nil, nil, nil) if err != nil { t.Fatal(err) } diff --git a/client/internal/engine_test.go b/client/internal/engine_test.go index ca49eca09f6..e32e262b925 100644 --- a/client/internal/engine_test.go +++ b/client/internal/engine_test.go @@ -1226,7 +1226,7 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri } secretsManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay) - mgmtServer, err := server.NewServer(context.Background(), config, accountManager, settings.NewManager(store), peersUpdateManager, secretsManager, nil, nil) + mgmtServer, err := server.NewServer(context.Background(), config, accountManager, settings.NewManager(store), peersUpdateManager, secretsManager, nil, nil, nil) if err != nil { return nil, "", err } diff --git a/client/server/server_test.go b/client/server/server_test.go index 128de8e020f..d6b651a795e 100644 --- a/client/server/server_test.go +++ b/client/server/server_test.go @@ -134,7 +134,7 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve } secretsManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay) - mgmtServer, err := server.NewServer(context.Background(), config, accountManager, settings.NewManager(store), peersUpdateManager, secretsManager, nil, nil) + mgmtServer, err := server.NewServer(context.Background(), config, accountManager, settings.NewManager(store), peersUpdateManager, secretsManager, nil, nil, nil) if err != nil { return nil, "", err } diff --git a/go.mod b/go.mod index 77d570662fa..8bd5ef861a4 100644 --- a/go.mod +++ b/go.mod @@ -60,7 +60,7 @@ require ( github.com/miekg/dns v1.1.59 github.com/mitchellh/hashstructure/v2 v2.0.2 github.com/nadoo/ipset v0.5.0 - github.com/netbirdio/management-integrations/integrations v0.0.0-20250115083837-a09722b8d2a6 + github.com/netbirdio/management-integrations/integrations v0.0.0-20250207083520-39827e937b0f github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241010133937-e0df50df217d github.com/okta/okta-sdk-golang/v2 v2.18.0 github.com/oschwald/maxminddb-golang v1.12.0 diff --git a/go.sum b/go.sum index 4b9e90eba3d..26b2a33be87 100644 --- a/go.sum +++ b/go.sum @@ -529,8 +529,8 @@ github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944 h1:TDtJKmM6S github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ= github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e h1:PURA50S8u4mF6RrkYYCAvvPCixhqqEiEy3Ej6avh04c= github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e/go.mod h1:YMLU7qbKfVjmEv7EoZPIVEI+kNYxWCdPK3VS0BU+U4Q= -github.com/netbirdio/management-integrations/integrations v0.0.0-20250115083837-a09722b8d2a6 h1:I/ODkZ8rSDOzlJbhEjD2luSI71zl+s5JgNvFHY0+mBU= -github.com/netbirdio/management-integrations/integrations v0.0.0-20250115083837-a09722b8d2a6/go.mod h1:izUUs1NT7ja+PwSX3kJ7ox8Kkn478tboBJSjL4kU6J0= +github.com/netbirdio/management-integrations/integrations v0.0.0-20250207083520-39827e937b0f h1:yqx1esCmLMntVbjG4Oha6nTLGS5jFhoZ5TNNBkVtQ4c= +github.com/netbirdio/management-integrations/integrations v0.0.0-20250207083520-39827e937b0f/go.mod h1:izUUs1NT7ja+PwSX3kJ7ox8Kkn478tboBJSjL4kU6J0= github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8= github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM= github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241010133937-e0df50df217d h1:bRq5TKgC7Iq20pDiuC54yXaWnAVeS5PdGpSokFTlR28= diff --git a/management/client/client_test.go b/management/client/client_test.go index 3e498a5eae9..cb2cf5cb46d 100644 --- a/management/client/client_test.go +++ b/management/client/client_test.go @@ -78,7 +78,7 @@ func startManagement(t *testing.T) (*grpc.Server, net.Listener) { } secretsManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay) - mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, settings.NewManager(store), peersUpdateManager, secretsManager, nil, nil) + mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, settings.NewManager(store), peersUpdateManager, secretsManager, nil, nil, nil) if err != nil { t.Fatal(err) } diff --git a/management/cmd/management.go b/management/cmd/management.go index 1c8fca8dceb..9712f04aac1 100644 --- a/management/cmd/management.go +++ b/management/cmd/management.go @@ -39,13 +39,12 @@ import ( "github.com/netbirdio/netbird/formatter" mgmtProto "github.com/netbirdio/netbird/management/proto" "github.com/netbirdio/netbird/management/server" + "github.com/netbirdio/netbird/management/server/auth" nbContext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/geolocation" "github.com/netbirdio/netbird/management/server/groups" nbhttp "github.com/netbirdio/netbird/management/server/http" - "github.com/netbirdio/netbird/management/server/http/configs" "github.com/netbirdio/netbird/management/server/idp" - "github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/metrics" "github.com/netbirdio/netbird/management/server/networks" "github.com/netbirdio/netbird/management/server/networks/resources" @@ -255,24 +254,13 @@ var ( tlsEnabled = true } - jwtValidator, err := jwtclaims.NewJWTValidator( - ctx, + authManager := auth.NewManager(store, config.HttpConfig.AuthIssuer, - config.GetAuthAudiences(), + config.HttpConfig.AuthAudience, config.HttpConfig.AuthKeysLocation, - config.HttpConfig.IdpSignKeyRefreshEnabled, - ) - if err != nil { - return fmt.Errorf("failed creating JWT validator: %v", err) - } - - httpAPIAuthCfg := configs.AuthCfg{ - Issuer: config.HttpConfig.AuthIssuer, - Audience: config.HttpConfig.AuthAudience, - UserIDClaim: config.HttpConfig.AuthUserIDClaim, - KeysLocation: config.HttpConfig.AuthKeysLocation, - } - + config.HttpConfig.AuthUserIDClaim, + config.GetAuthAudiences(), + config.HttpConfig.IdpSignKeyRefreshEnabled) userManager := users.NewManager(store) settingsManager := settings.NewManager(store) permissionsManager := permissions.NewManager(userManager, settingsManager) @@ -281,7 +269,7 @@ var ( routersManager := routers.NewManager(store, permissionsManager, accountManager) networksManager := networks.NewManager(store, permissionsManager, resourcesManager, routersManager, accountManager) - httpAPIHandler, err := nbhttp.NewAPIHandler(ctx, accountManager, networksManager, resourcesManager, routersManager, groupsManager, geo, jwtValidator, appMetrics, httpAPIAuthCfg, integratedPeerValidator) + httpAPIHandler, err := nbhttp.NewAPIHandler(ctx, accountManager, networksManager, resourcesManager, routersManager, groupsManager, geo, authManager, appMetrics, config, integratedPeerValidator) if err != nil { return fmt.Errorf("failed creating HTTP API handler: %v", err) } @@ -290,7 +278,7 @@ var ( ephemeralManager.LoadInitialPeers(ctx) gRPCAPIHandler := grpc.NewServer(gRPCOpts...) - srv, err := server.NewServer(ctx, config, accountManager, settingsManager, peersUpdateManager, secretsManager, appMetrics, ephemeralManager) + srv, err := server.NewServer(ctx, config, accountManager, settingsManager, peersUpdateManager, secretsManager, appMetrics, ephemeralManager, authManager) if err != nil { return fmt.Errorf("failed creating gRPC API handler: %v", err) } diff --git a/management/server/account.go b/management/server/account.go index 0f3b5e6eb7b..ed3497868c2 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -2,11 +2,8 @@ package server import ( "context" - "crypto/sha256" - b64 "encoding/base64" "errors" "fmt" - "hash/crc32" "math/rand" "net" "net/netip" @@ -24,10 +21,10 @@ import ( log "github.com/sirupsen/logrus" "golang.org/x/exp/maps" - "github.com/netbirdio/netbird/base62" nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/management/domain" "github.com/netbirdio/netbird/management/server/activity" + nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/geolocation" "github.com/netbirdio/netbird/management/server/idp" "github.com/netbirdio/netbird/management/server/integrated_validator" @@ -77,13 +74,14 @@ type AccountManager interface { GetAccountByID(ctx context.Context, accountID string, userID string) (*types.Account, error) AccountExists(ctx context.Context, accountID string) (bool, error) GetAccountIDByUserID(ctx context.Context, userID, domain string) (string, error) + // deprecated, use GetAccountIDFromUserAuth instead GetAccountIDFromToken(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) - CheckUserAccessByJWTGroups(ctx context.Context, claims jwtclaims.AuthorizationClaims) error - GetPATInfo(ctx context.Context, token string) (*types.User, *types.PersonalAccessToken, string, string, error) + GetAccountIDFromUserAuth(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) DeleteAccount(ctx context.Context, accountID, userID string) error - MarkPATUsed(ctx context.Context, tokenID string) error GetUserByID(ctx context.Context, id string) (*types.User, error) + // deprecated, use GetUserFromUserAuth instead GetUser(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*types.User, error) + GetUserFromUserAuth(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error) ListUsers(ctx context.Context, accountID string) ([]*types.User, error) GetPeers(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error) MarkPeerConnected(ctx context.Context, peerKey string, connected bool, realIP net.IP, accountID string) error @@ -149,6 +147,7 @@ type AccountManager interface { GetAccountSettings(ctx context.Context, accountID string, userID string) (*types.Settings, error) DeleteSetupKey(ctx context.Context, accountID, userID, keyID string) error UpdateAccountPeers(ctx context.Context, accountID string) + SyncUserJWTGroups(ctx context.Context, userAuth nbcontext.UserAuth) error } type DefaultAccountManager struct { @@ -937,11 +936,11 @@ func (am *DefaultAccountManager) removeUserFromCache(ctx context.Context, accoun } // updateAccountDomainAttributesIfNotUpToDate updates the account domain attributes if they are not up to date and then, saves the account changes -func (am *DefaultAccountManager) updateAccountDomainAttributesIfNotUpToDate(ctx context.Context, accountID string, claims jwtclaims.AuthorizationClaims, +func (am *DefaultAccountManager) updateAccountDomainAttributesIfNotUpToDate(ctx context.Context, accountID string, userAuth nbcontext.UserAuth, primaryDomain bool, ) error { - if claims.Domain == "" { - log.WithContext(ctx).Errorf("claims don't contain a valid domain, skipping domain attributes update. Received claims: %v", claims) + if userAuth.Domain == "" { + log.WithContext(ctx).Errorf("claims don't contain a valid domain, skipping domain attributes update. Received claims: %v", userAuth) return nil } @@ -954,11 +953,11 @@ func (am *DefaultAccountManager) updateAccountDomainAttributesIfNotUpToDate(ctx return err } - if domainIsUpToDate(accountDomain, domainCategory, claims) { + if domainIsUpToDate(accountDomain, domainCategory, userAuth) { return nil } - user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, claims.UserId) + user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userAuth.UserId) if err != nil { log.WithContext(ctx).Errorf("error getting user: %v", err) return err @@ -967,13 +966,13 @@ func (am *DefaultAccountManager) updateAccountDomainAttributesIfNotUpToDate(ctx newDomain := accountDomain newCategoty := domainCategory - lowerDomain := strings.ToLower(claims.Domain) + lowerDomain := strings.ToLower(userAuth.Domain) if accountDomain != lowerDomain && user.HasAdminPower() { newDomain = lowerDomain } if accountDomain == lowerDomain { - newCategoty = claims.DomainCategory + newCategoty = userAuth.DomainCategory } return am.Store.UpdateAccountDomainAttributes(ctx, accountID, newDomain, newCategoty, primaryDomain) @@ -989,16 +988,16 @@ func (am *DefaultAccountManager) handleExistingUserAccount( ctx context.Context, userAccountID string, domainAccountID string, - claims jwtclaims.AuthorizationClaims, + userAuth nbcontext.UserAuth, ) error { primaryDomain := domainAccountID == "" || userAccountID == domainAccountID - err := am.updateAccountDomainAttributesIfNotUpToDate(ctx, userAccountID, claims, primaryDomain) + err := am.updateAccountDomainAttributesIfNotUpToDate(ctx, userAccountID, userAuth, primaryDomain) if err != nil { return err } // we should register the account ID to this user's metadata in our IDP manager - err = am.addAccountIDToIDPAppMeta(ctx, claims.UserId, userAccountID) + err = am.addAccountIDToIDPAppMeta(ctx, userAuth.UserId, userAccountID) if err != nil { return err } @@ -1008,20 +1007,20 @@ func (am *DefaultAccountManager) handleExistingUserAccount( // addNewPrivateAccount validates if there is an existing primary account for the domain, if so it adds the new user to that account, // otherwise it will create a new account and make it primary account for the domain. -func (am *DefaultAccountManager) addNewPrivateAccount(ctx context.Context, domainAccountID string, claims jwtclaims.AuthorizationClaims) (string, error) { - if claims.UserId == "" { +func (am *DefaultAccountManager) addNewPrivateAccount(ctx context.Context, domainAccountID string, userAuth nbcontext.UserAuth) (string, error) { + if userAuth.UserId == "" { return "", fmt.Errorf("user ID is empty") } - lowerDomain := strings.ToLower(claims.Domain) + lowerDomain := strings.ToLower(userAuth.Domain) - newAccount, err := am.newAccount(ctx, claims.UserId, lowerDomain) + newAccount, err := am.newAccount(ctx, userAuth.UserId, lowerDomain) if err != nil { return "", err } newAccount.Domain = lowerDomain - newAccount.DomainCategory = claims.DomainCategory + newAccount.DomainCategory = userAuth.DomainCategory newAccount.IsDomainPrimaryAccount = true err = am.Store.SaveAccount(ctx, newAccount) @@ -1029,33 +1028,33 @@ func (am *DefaultAccountManager) addNewPrivateAccount(ctx context.Context, domai return "", err } - err = am.addAccountIDToIDPAppMeta(ctx, claims.UserId, newAccount.Id) + err = am.addAccountIDToIDPAppMeta(ctx, userAuth.UserId, newAccount.Id) if err != nil { return "", err } - am.StoreEvent(ctx, claims.UserId, claims.UserId, newAccount.Id, activity.UserJoined, nil) + am.StoreEvent(ctx, userAuth.UserId, userAuth.UserId, newAccount.Id, activity.UserJoined, nil) return newAccount.Id, nil } -func (am *DefaultAccountManager) addNewUserToDomainAccount(ctx context.Context, domainAccountID string, claims jwtclaims.AuthorizationClaims) (string, error) { +func (am *DefaultAccountManager) addNewUserToDomainAccount(ctx context.Context, domainAccountID string, userAuth nbcontext.UserAuth) (string, error) { unlockAccount := am.Store.AcquireWriteLockByUID(ctx, domainAccountID) defer unlockAccount() - newUser := types.NewRegularUser(claims.UserId) + newUser := types.NewRegularUser(userAuth.UserId) newUser.AccountID = domainAccountID err := am.Store.SaveUser(ctx, store.LockingStrengthUpdate, newUser) if err != nil { return "", err } - err = am.addAccountIDToIDPAppMeta(ctx, claims.UserId, domainAccountID) + err = am.addAccountIDToIDPAppMeta(ctx, userAuth.UserId, domainAccountID) if err != nil { return "", err } - am.StoreEvent(ctx, claims.UserId, claims.UserId, domainAccountID, activity.UserJoined, nil) + am.StoreEvent(ctx, userAuth.UserId, userAuth.UserId, domainAccountID, activity.UserJoined, nil) return domainAccountID, nil } @@ -1095,76 +1094,11 @@ func (am *DefaultAccountManager) redeemInvite(ctx context.Context, accountID str return nil } -// MarkPATUsed marks a personal access token as used -func (am *DefaultAccountManager) MarkPATUsed(ctx context.Context, tokenID string) error { - return am.Store.MarkPATUsed(ctx, store.LockingStrengthUpdate, tokenID) -} - // GetAccount returns an account associated with this account ID. func (am *DefaultAccountManager) GetAccount(ctx context.Context, accountID string) (*types.Account, error) { return am.Store.GetAccount(ctx, accountID) } -// GetPATInfo retrieves user, personal access token, domain, and category details from a personal access token. -func (am *DefaultAccountManager) GetPATInfo(ctx context.Context, token string) (user *types.User, pat *types.PersonalAccessToken, domain string, category string, err error) { - user, pat, err = am.extractPATFromToken(ctx, token) - if err != nil { - return nil, nil, "", "", err - } - - domain, category, err = am.Store.GetAccountDomainAndCategory(ctx, store.LockingStrengthShare, user.AccountID) - if err != nil { - return nil, nil, "", "", err - } - - return user, pat, domain, category, nil -} - -// extractPATFromToken validates the token structure and retrieves associated User and PAT. -func (am *DefaultAccountManager) extractPATFromToken(ctx context.Context, token string) (*types.User, *types.PersonalAccessToken, error) { - if len(token) != types.PATLength { - return nil, nil, fmt.Errorf("token has incorrect length") - } - - prefix := token[:len(types.PATPrefix)] - if prefix != types.PATPrefix { - return nil, nil, fmt.Errorf("token has wrong prefix") - } - secret := token[len(types.PATPrefix) : len(types.PATPrefix)+types.PATSecretLength] - encodedChecksum := token[len(types.PATPrefix)+types.PATSecretLength : len(types.PATPrefix)+types.PATSecretLength+types.PATChecksumLength] - - verificationChecksum, err := base62.Decode(encodedChecksum) - if err != nil { - return nil, nil, fmt.Errorf("token checksum decoding failed: %w", err) - } - - secretChecksum := crc32.ChecksumIEEE([]byte(secret)) - if secretChecksum != verificationChecksum { - return nil, nil, fmt.Errorf("token checksum does not match") - } - - hashedToken := sha256.Sum256([]byte(token)) - encodedHashedToken := b64.StdEncoding.EncodeToString(hashedToken[:]) - - var user *types.User - var pat *types.PersonalAccessToken - - err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { - pat, err = transaction.GetPATByHashedToken(ctx, store.LockingStrengthShare, encodedHashedToken) - if err != nil { - return err - } - - user, err = transaction.GetUserByPATID(ctx, store.LockingStrengthShare, pat.ID) - return err - }) - if err != nil { - return nil, nil, err - } - - return user, pat, nil -} - // GetAccountByID returns an account associated with this account ID. func (am *DefaultAccountManager) GetAccountByID(ctx context.Context, accountID string, userID string) (*types.Account, error) { user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID) @@ -1179,58 +1113,65 @@ func (am *DefaultAccountManager) GetAccountByID(ctx context.Context, accountID s return am.Store.GetAccount(ctx, accountID) } -// GetAccountIDFromToken returns an account ID associated with this token. -func (am *DefaultAccountManager) GetAccountIDFromToken(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) { - if claims.UserId == "" { +func (am *DefaultAccountManager) GetAccountIDFromUserAuth(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) { + if userAuth.UserId == "" { return "", "", errors.New(emptyUserID) } if am.singleAccountMode && am.singleAccountModeDomain != "" { // This section is mostly related to self-hosted installations. // We override incoming domain claims to group users under a single account. - claims.Domain = am.singleAccountModeDomain - claims.DomainCategory = types.PrivateCategory + userAuth.Domain = am.singleAccountModeDomain + userAuth.DomainCategory = types.PrivateCategory log.WithContext(ctx).Debugf("overriding JWT Domain and DomainCategory claims since single account mode is enabled") } - accountID, err := am.getAccountIDWithAuthorizationClaims(ctx, claims) + accountID, err := am.getAccountIDWithAuthorizationClaims(ctx, userAuth) if err != nil { return "", "", err } - user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, claims.UserId) + user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userAuth.UserId) if err != nil { // this is not really possible because we got an account by user ID - return "", "", status.Errorf(status.NotFound, "user %s not found", claims.UserId) + return "", "", status.Errorf(status.NotFound, "user %s not found", userAuth.UserId) + } + + if userAuth.IsChild { + return accountID, user.Id, nil } if user.AccountID != accountID { - return "", "", status.Errorf(status.PermissionDenied, "user %s is not part of the account %s", claims.UserId, accountID) + return "", "", status.Errorf(status.PermissionDenied, "user %s is not part of the account %s", userAuth.UserId, accountID) } - if !user.IsServiceUser && claims.Invited { + if !user.IsServiceUser && userAuth.Invited { err = am.redeemInvite(ctx, accountID, user.Id) if err != nil { return "", "", err } } - if err = am.syncJWTGroups(ctx, accountID, claims); err != nil { + return accountID, user.Id, nil +} + +// GetAccountIDFromToken returns an account ID associated with this token. +func (am *DefaultAccountManager) GetAccountIDFromToken(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) { + userAuth, err := nbcontext.GetUserAuthFromContext(ctx) + if err != nil { return "", "", err } - - return accountID, user.Id, nil + return am.GetAccountIDFromUserAuth(ctx, userAuth) } // syncJWTGroups processes the JWT groups for a user, updates the account based on the groups, // and propagates changes to peers if group propagation is enabled. -func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID string, claims jwtclaims.AuthorizationClaims) error { - if claim, exists := claims.Raw[jwtclaims.IsToken]; exists { - if isToken, ok := claim.(bool); ok && isToken { - return nil - } +// requires userAuth to have been ValidateAndParseToken and EnsureUserAccessByJWTGroups by the AuthManager +func (am *DefaultAccountManager) SyncUserJWTGroups(ctx context.Context, userAuth nbcontext.UserAuth) error { + if userAuth.IsChild || userAuth.IsPAT { + return nil } - settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) + settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, userAuth.AccountId) if err != nil { return err } @@ -1244,9 +1185,7 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st return nil } - jwtGroupsNames := extractJWTGroups(ctx, settings.JWTGroupsClaimName, claims) - - unlockAccount := am.Store.AcquireWriteLockByUID(ctx, accountID) + unlockAccount := am.Store.AcquireWriteLockByUID(ctx, userAuth.AccountId) defer func() { if unlockAccount != nil { unlockAccount() @@ -1258,17 +1197,17 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st var hasChanges bool var user *types.User err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { - user, err = transaction.GetUserByUserID(ctx, store.LockingStrengthShare, claims.UserId) + user, err = transaction.GetUserByUserID(ctx, store.LockingStrengthShare, userAuth.UserId) if err != nil { return fmt.Errorf("error getting user: %w", err) } - groups, err := transaction.GetAccountGroups(ctx, store.LockingStrengthShare, accountID) + groups, err := transaction.GetAccountGroups(ctx, store.LockingStrengthShare, userAuth.AccountId) if err != nil { return fmt.Errorf("error getting account groups: %w", err) } - changed, updatedAutoGroups, newGroupsToCreate, err := am.getJWTGroupsChanges(user, groups, jwtGroupsNames) + changed, updatedAutoGroups, newGroupsToCreate, err := am.getJWTGroupsChanges(user, groups, userAuth.Groups) if err != nil { return fmt.Errorf("error getting JWT groups changes: %w", err) } @@ -1293,7 +1232,7 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st // Propagate changes to peers if group propagation is enabled if settings.GroupsPropagationEnabled { - groups, err = transaction.GetAccountGroups(ctx, store.LockingStrengthShare, accountID) + groups, err = transaction.GetAccountGroups(ctx, store.LockingStrengthShare, userAuth.AccountId) if err != nil { return fmt.Errorf("error getting account groups: %w", err) } @@ -1303,7 +1242,7 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st groupsMap[group.ID] = group } - peers, err := transaction.GetUserPeers(ctx, store.LockingStrengthShare, accountID, claims.UserId) + peers, err := transaction.GetUserPeers(ctx, store.LockingStrengthShare, userAuth.AccountId, userAuth.UserId) if err != nil { return fmt.Errorf("error getting user peers: %w", err) } @@ -1317,7 +1256,7 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st return fmt.Errorf("error saving groups: %w", err) } - if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil { + if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, userAuth.AccountId); err != nil { return fmt.Errorf("error incrementing network serial: %w", err) } } @@ -1335,45 +1274,45 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st } for _, g := range addNewGroups { - group, err := am.Store.GetGroupByID(ctx, store.LockingStrengthShare, accountID, g) + group, err := am.Store.GetGroupByID(ctx, store.LockingStrengthShare, userAuth.AccountId, g) if err != nil { - log.WithContext(ctx).Debugf("group %s not found while saving user activity event of account %s", g, accountID) + log.WithContext(ctx).Debugf("group %s not found while saving user activity event of account %s", g, userAuth.AccountId) } else { meta := map[string]any{ "group": group.Name, "group_id": group.ID, "is_service_user": user.IsServiceUser, "user_name": user.ServiceUserName, } - am.StoreEvent(ctx, user.Id, user.Id, accountID, activity.GroupAddedToUser, meta) + am.StoreEvent(ctx, user.Id, user.Id, userAuth.AccountId, activity.GroupAddedToUser, meta) } } for _, g := range removeOldGroups { - group, err := am.Store.GetGroupByID(ctx, store.LockingStrengthShare, accountID, g) + group, err := am.Store.GetGroupByID(ctx, store.LockingStrengthShare, userAuth.AccountId, g) if err != nil { - log.WithContext(ctx).Debugf("group %s not found while saving user activity event of account %s", g, accountID) + log.WithContext(ctx).Debugf("group %s not found while saving user activity event of account %s", g, userAuth.AccountId) } else { meta := map[string]any{ "group": group.Name, "group_id": group.ID, "is_service_user": user.IsServiceUser, "user_name": user.ServiceUserName, } - am.StoreEvent(ctx, user.Id, user.Id, accountID, activity.GroupRemovedFromUser, meta) + am.StoreEvent(ctx, user.Id, user.Id, userAuth.AccountId, activity.GroupRemovedFromUser, meta) } } if settings.GroupsPropagationEnabled { - removedGroupAffectsPeers, err := areGroupChangesAffectPeers(ctx, am.Store, accountID, removeOldGroups) + removedGroupAffectsPeers, err := areGroupChangesAffectPeers(ctx, am.Store, userAuth.AccountId, removeOldGroups) if err != nil { return err } - newGroupsAffectsPeers, err := areGroupChangesAffectPeers(ctx, am.Store, accountID, addNewGroups) + newGroupsAffectsPeers, err := areGroupChangesAffectPeers(ctx, am.Store, userAuth.AccountId, addNewGroups) if err != nil { return err } if removedGroupAffectsPeers || newGroupsAffectsPeers { - log.WithContext(ctx).Tracef("user %s: JWT group membership changed, updating account peers", claims.UserId) - am.UpdateAccountPeers(ctx, accountID) + log.WithContext(ctx).Tracef("user %s: JWT group membership changed, updating account peers", userAuth.UserId) + am.UpdateAccountPeers(ctx, userAuth.AccountId) } } @@ -1398,24 +1337,34 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st // Existing user + Existing account + Existing Indexed Domain -> Nothing changes // // Existing user + Existing account + Existing domain reclassified Domain as private -> Nothing changes (index domain) -func (am *DefaultAccountManager) getAccountIDWithAuthorizationClaims(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, error) { +// +// Impersonated UserAuth -> checks that account exists +func (am *DefaultAccountManager) getAccountIDWithAuthorizationClaims(ctx context.Context, userAuth nbcontext.UserAuth) (string, error) { log.WithContext(ctx).Tracef("getting account with authorization claims. User ID: \"%s\", Account ID: \"%s\", Domain: \"%s\", Domain Category: \"%s\"", - claims.UserId, claims.AccountId, claims.Domain, claims.DomainCategory) + userAuth.UserId, userAuth.AccountId, userAuth.Domain, userAuth.DomainCategory) - if claims.UserId == "" { + if userAuth.UserId == "" { return "", errors.New(emptyUserID) } - if claims.DomainCategory != types.PrivateCategory || !isDomainValid(claims.Domain) { - return am.GetAccountIDByUserID(ctx, claims.UserId, claims.Domain) + if userAuth.IsChild { + exists, err := am.Store.AccountExists(ctx, store.LockingStrengthShare, userAuth.AccountId) + if err != nil || !exists { + return "", err + } + return userAuth.AccountId, nil } - if claims.AccountId != "" { - return am.handlePrivateAccountWithIDFromClaim(ctx, claims) + if userAuth.DomainCategory != types.PrivateCategory || !isDomainValid(userAuth.Domain) { + return am.GetAccountIDByUserID(ctx, userAuth.UserId, userAuth.Domain) + } + + if userAuth.AccountId != "" { + return am.handlePrivateAccountWithIDFromClaim(ctx, userAuth) } // We checked if the domain has a primary account already - domainAccountID, cancel, err := am.getPrivateDomainWithGlobalLock(ctx, claims.Domain) + domainAccountID, cancel, err := am.getPrivateDomainWithGlobalLock(ctx, userAuth.Domain) if cancel != nil { defer cancel() } @@ -1423,14 +1372,14 @@ func (am *DefaultAccountManager) getAccountIDWithAuthorizationClaims(ctx context return "", err } - userAccountID, err := am.Store.GetAccountIDByUserID(ctx, store.LockingStrengthShare, claims.UserId) + userAccountID, err := am.Store.GetAccountIDByUserID(ctx, store.LockingStrengthShare, userAuth.UserId) if handleNotFound(err) != nil { log.WithContext(ctx).Errorf("error getting account ID by user ID: %v", err) return "", err } if userAccountID != "" { - if err = am.handleExistingUserAccount(ctx, userAccountID, domainAccountID, claims); err != nil { + if err = am.handleExistingUserAccount(ctx, userAccountID, domainAccountID, userAuth); err != nil { return "", err } @@ -1438,10 +1387,10 @@ func (am *DefaultAccountManager) getAccountIDWithAuthorizationClaims(ctx context } if domainAccountID != "" { - return am.addNewUserToDomainAccount(ctx, domainAccountID, claims) + return am.addNewUserToDomainAccount(ctx, domainAccountID, userAuth) } - return am.addNewPrivateAccount(ctx, domainAccountID, claims) + return am.addNewPrivateAccount(ctx, domainAccountID, userAuth) } func (am *DefaultAccountManager) getPrivateDomainWithGlobalLock(ctx context.Context, domain string) (string, context.CancelFunc, error) { domainAccountID, err := am.Store.GetAccountIDByPrivateDomain(ctx, store.LockingStrengthShare, domain) @@ -1469,40 +1418,40 @@ func (am *DefaultAccountManager) getPrivateDomainWithGlobalLock(ctx context.Cont return domainAccountID, cancel, nil } -func (am *DefaultAccountManager) handlePrivateAccountWithIDFromClaim(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, error) { - userAccountID, err := am.Store.GetAccountIDByUserID(ctx, store.LockingStrengthShare, claims.UserId) +func (am *DefaultAccountManager) handlePrivateAccountWithIDFromClaim(ctx context.Context, userAuth nbcontext.UserAuth) (string, error) { + userAccountID, err := am.Store.GetAccountIDByUserID(ctx, store.LockingStrengthShare, userAuth.UserId) if err != nil { log.WithContext(ctx).Errorf("error getting account ID by user ID: %v", err) return "", err } - if userAccountID != claims.AccountId { - return "", fmt.Errorf("user %s is not part of the account id %s", claims.UserId, claims.AccountId) + if userAccountID != userAuth.AccountId { + return "", fmt.Errorf("user %s is not part of the account id %s", userAuth.UserId, userAuth.AccountId) } - accountDomain, domainCategory, err := am.Store.GetAccountDomainAndCategory(ctx, store.LockingStrengthShare, claims.AccountId) + accountDomain, domainCategory, err := am.Store.GetAccountDomainAndCategory(ctx, store.LockingStrengthShare, userAuth.AccountId) if handleNotFound(err) != nil { log.WithContext(ctx).Errorf("error getting account domain and category: %v", err) return "", err } - if domainIsUpToDate(accountDomain, domainCategory, claims) { - return claims.AccountId, nil + if domainIsUpToDate(accountDomain, domainCategory, userAuth) { + return userAuth.AccountId, nil } // We checked if the domain has a primary account already - domainAccountID, err := am.Store.GetAccountIDByPrivateDomain(ctx, store.LockingStrengthShare, claims.Domain) + domainAccountID, err := am.Store.GetAccountIDByPrivateDomain(ctx, store.LockingStrengthShare, userAuth.Domain) if handleNotFound(err) != nil { log.WithContext(ctx).Errorf(errorGettingDomainAccIDFmt, err) return "", err } - err = am.handleExistingUserAccount(ctx, claims.AccountId, domainAccountID, claims) + err = am.handleExistingUserAccount(ctx, userAuth.AccountId, domainAccountID, userAuth) if err != nil { return "", err } - return claims.AccountId, nil + return userAuth.AccountId, nil } func handleNotFound(err error) error { @@ -1517,8 +1466,8 @@ func handleNotFound(err error) error { return nil } -func domainIsUpToDate(domain string, domainCategory string, claims jwtclaims.AuthorizationClaims) bool { - return domainCategory == types.PrivateCategory || claims.DomainCategory != types.PrivateCategory || domain != claims.Domain +func domainIsUpToDate(domain string, domainCategory string, userAuth nbcontext.UserAuth) bool { + return domainCategory == types.PrivateCategory || userAuth.DomainCategory != types.PrivateCategory || domain != userAuth.Domain } func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) { @@ -1600,34 +1549,6 @@ func (am *DefaultAccountManager) GetDNSDomain() string { return am.dnsDomain } -// CheckUserAccessByJWTGroups checks if the user has access, particularly in cases where the admin enabled JWT -// group propagation and set the list of groups with access permissions. -func (am *DefaultAccountManager) CheckUserAccessByJWTGroups(ctx context.Context, claims jwtclaims.AuthorizationClaims) error { - accountID, _, err := am.GetAccountIDFromToken(ctx, claims) - if err != nil { - return err - } - - settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) - if err != nil { - return err - } - - // Ensures JWT group synchronization to the management is enabled before, - // filtering access based on the allowed groups. - if settings != nil && settings.JWTGroupsEnabled { - if allowedGroups := settings.JWTAllowGroups; len(allowedGroups) > 0 { - userJWTGroups := extractJWTGroups(ctx, settings.JWTGroupsClaimName, claims) - - if !userHasAllowedGroup(allowedGroups, userJWTGroups) { - return fmt.Errorf("user does not belong to any of the allowed JWT groups") - } - } - } - - return nil -} - func (am *DefaultAccountManager) onPeersInvalidated(ctx context.Context, accountID string) { log.WithContext(ctx).Debugf("validated peers has been invalidated for account %s", accountID) am.UpdateAccountPeers(ctx, accountID) @@ -1785,39 +1706,6 @@ func newAccountWithId(ctx context.Context, accountID, userID, domain string) *ty return acc } -// extractJWTGroups extracts the group names from a JWT token's claims. -func extractJWTGroups(ctx context.Context, claimName string, claims jwtclaims.AuthorizationClaims) []string { - userJWTGroups := make([]string, 0) - - if claim, ok := claims.Raw[claimName]; ok { - if claimGroups, ok := claim.([]interface{}); ok { - for _, g := range claimGroups { - if group, ok := g.(string); ok { - userJWTGroups = append(userJWTGroups, group) - } else { - log.WithContext(ctx).Debugf("JWT claim %q contains a non-string group (type: %T): %v", claimName, g, g) - } - } - } - } else { - log.WithContext(ctx).Debugf("JWT claim %q is not a string array", claimName) - } - - return userJWTGroups -} - -// userHasAllowedGroup checks if a user belongs to any of the allowed groups. -func userHasAllowedGroup(allowedGroups []string, userGroups []string) bool { - for _, userGroup := range userGroups { - for _, allowedGroup := range allowedGroups { - if userGroup == allowedGroup { - return true - } - } - } - return false -} - // separateGroups separates user's auto groups into non-JWT and JWT groups. // Returns the list of standard auto groups and a map of JWT auto groups, // where the keys are the group names and the values are the group IDs. diff --git a/management/server/account_test.go b/management/server/account_test.go index 0a7f9119bfd..848987655d5 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -2,8 +2,6 @@ package server import ( "context" - "crypto/sha256" - b64 "encoding/base64" "encoding/json" "fmt" "io" @@ -15,8 +13,6 @@ import ( "testing" "time" - "github.com/golang-jwt/jwt" - "github.com/netbirdio/netbird/management/server/util" resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" @@ -30,7 +26,7 @@ import ( nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/management/server/activity" - "github.com/netbirdio/netbird/management/server/jwtclaims" + nbcontext "github.com/netbirdio/netbird/management/server/context" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/posture" "github.com/netbirdio/netbird/management/server/store" @@ -437,7 +433,7 @@ func TestAccountManager_GetOrCreateAccountByUser(t *testing.T) { } func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) { - type initUserParams jwtclaims.AuthorizationClaims + type initUserParams nbcontext.UserAuth var ( publicDomain = "public.com" @@ -460,7 +456,7 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) { testCases := []struct { name string - inputClaims jwtclaims.AuthorizationClaims + inputClaims nbcontext.UserAuth inputInitUserParams initUserParams inputUpdateAttrs bool inputUpdateClaimAccount bool @@ -475,7 +471,7 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) { }{ { name: "New User With Public Domain", - inputClaims: jwtclaims.AuthorizationClaims{ + inputClaims: nbcontext.UserAuth{ Domain: publicDomain, UserId: "pub-domain-user", DomainCategory: types.PublicCategory, @@ -492,7 +488,7 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) { }, { name: "New User With Unknown Domain", - inputClaims: jwtclaims.AuthorizationClaims{ + inputClaims: nbcontext.UserAuth{ Domain: unknownDomain, UserId: "unknown-domain-user", DomainCategory: types.UnknownCategory, @@ -509,7 +505,7 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) { }, { name: "New User With Private Domain", - inputClaims: jwtclaims.AuthorizationClaims{ + inputClaims: nbcontext.UserAuth{ Domain: privateDomain, UserId: "pvt-domain-user", DomainCategory: types.PrivateCategory, @@ -526,7 +522,7 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) { }, { name: "New Regular User With Existing Private Domain", - inputClaims: jwtclaims.AuthorizationClaims{ + inputClaims: nbcontext.UserAuth{ Domain: privateDomain, UserId: "new-pvt-domain-user", DomainCategory: types.PrivateCategory, @@ -544,7 +540,7 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) { }, { name: "Existing User With Existing Reclassified Private Domain", - inputClaims: jwtclaims.AuthorizationClaims{ + inputClaims: nbcontext.UserAuth{ Domain: defaultInitAccount.Domain, UserId: defaultInitAccount.UserId, DomainCategory: types.PrivateCategory, @@ -561,7 +557,7 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) { }, { name: "Existing Account Id With Existing Reclassified Private Domain", - inputClaims: jwtclaims.AuthorizationClaims{ + inputClaims: nbcontext.UserAuth{ Domain: defaultInitAccount.Domain, UserId: defaultInitAccount.UserId, DomainCategory: types.PrivateCategory, @@ -579,7 +575,7 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) { }, { name: "User With Private Category And Empty Domain", - inputClaims: jwtclaims.AuthorizationClaims{ + inputClaims: nbcontext.UserAuth{ Domain: "", UserId: "pvt-domain-user", DomainCategory: types.PrivateCategory, @@ -608,7 +604,7 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) { require.NoError(t, err, "get init account failed") if testCase.inputUpdateAttrs { - err = manager.updateAccountDomainAttributesIfNotUpToDate(context.Background(), initAccount.Id, jwtclaims.AuthorizationClaims{UserId: testCase.inputInitUserParams.UserId, Domain: testCase.inputInitUserParams.Domain, DomainCategory: testCase.inputInitUserParams.DomainCategory}, true) + err = manager.updateAccountDomainAttributesIfNotUpToDate(context.Background(), initAccount.Id, nbcontext.UserAuth{UserId: testCase.inputInitUserParams.UserId, Domain: testCase.inputInitUserParams.Domain, DomainCategory: testCase.inputInitUserParams.DomainCategory}, true) require.NoError(t, err, "update init user failed") } @@ -616,7 +612,7 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) { testCase.inputClaims.AccountId = initAccount.Id } - accountID, _, err = manager.GetAccountIDFromToken(context.Background(), testCase.inputClaims) + accountID, _, err = manager.GetAccountIDFromUserAuth(context.Background(), testCase.inputClaims) require.NoError(t, err, "support function failed") account, err := manager.Store.GetAccount(context.Background(), accountID) @@ -635,14 +631,12 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) { } } -func TestDefaultAccountManager_GetGroupsFromTheToken(t *testing.T) { +func TestDefaultAccountManager_SyncUserJWTGroups(t *testing.T) { userId := "user-id" domain := "test.domain" - _ = newAccountWithId(context.Background(), "", userId, domain) manager, err := createManager(t) require.NoError(t, err, "unable to create account manager") - accountID, err := manager.GetAccountIDByUserID(context.Background(), userId, domain) require.NoError(t, err, "create init user failed") // as initAccount was created without account id we have to take the id after account initialization @@ -650,65 +644,50 @@ func TestDefaultAccountManager_GetGroupsFromTheToken(t *testing.T) { // it is important to set the id as it help to avoid creating additional account with empty Id and re-pointing indices to it initAccount, err := manager.Store.GetAccount(context.Background(), accountID) require.NoError(t, err, "get init account failed") - - claims := jwtclaims.AuthorizationClaims{ + claims := nbcontext.UserAuth{ AccountId: accountID, // is empty as it is based on accountID right after initialization of initAccount Domain: domain, UserId: userId, DomainCategory: "test-category", - Raw: jwt.MapClaims{"idp-groups": []interface{}{"group1", "group2"}}, + Groups: []string{"group1", "group2"}, } - t.Run("JWT groups disabled", func(t *testing.T) { - accountID, _, err := manager.GetAccountIDFromToken(context.Background(), claims) - require.NoError(t, err, "get account by token failed") - + err := manager.SyncUserJWTGroups(context.Background(), claims) + require.NoError(t, err, "synt user jwt groups failed") account, err := manager.Store.GetAccount(context.Background(), accountID) require.NoError(t, err, "get account failed") - require.Len(t, account.Groups, 1, "only ALL group should exists") }) - t.Run("JWT groups enabled without claim name", func(t *testing.T) { initAccount.Settings.JWTGroupsEnabled = true err := manager.Store.SaveAccount(context.Background(), initAccount) require.NoError(t, err, "save account failed") require.Len(t, manager.Store.GetAllAccounts(context.Background()), 1, "only one account should exist") - - accountID, _, err := manager.GetAccountIDFromToken(context.Background(), claims) - require.NoError(t, err, "get account by token failed") - + err = manager.SyncUserJWTGroups(context.Background(), claims) + require.NoError(t, err, "synt user jwt groups failed") account, err := manager.Store.GetAccount(context.Background(), accountID) require.NoError(t, err, "get account failed") - require.Len(t, account.Groups, 1, "if group claim is not set no group added from JWT") }) - t.Run("JWT groups enabled", func(t *testing.T) { initAccount.Settings.JWTGroupsEnabled = true initAccount.Settings.JWTGroupsClaimName = "idp-groups" err := manager.Store.SaveAccount(context.Background(), initAccount) require.NoError(t, err, "save account failed") require.Len(t, manager.Store.GetAllAccounts(context.Background()), 1, "only one account should exist") - - accountID, _, err := manager.GetAccountIDFromToken(context.Background(), claims) - require.NoError(t, err, "get account by token failed") - + err = manager.SyncUserJWTGroups(context.Background(), claims) + require.NoError(t, err, "synt user jwt groups failed") account, err := manager.Store.GetAccount(context.Background(), accountID) require.NoError(t, err, "get account failed") - require.Len(t, account.Groups, 3, "groups should be added to the account") - groupsByNames := map[string]*types.Group{} for _, g := range account.Groups { groupsByNames[g.Name] = g } - g1, ok := groupsByNames["group1"] require.True(t, ok, "group1 should be added to the account") require.Equal(t, g1.Name, "group1", "group1 name should match") require.Equal(t, g1.Issued, types.GroupIssuedJWT, "group1 issued should match") - g2, ok := groupsByNames["group2"] require.True(t, ok, "group2 should be added to the account") require.Equal(t, g2.Name, "group2", "group2 name should match") @@ -716,88 +695,6 @@ func TestDefaultAccountManager_GetGroupsFromTheToken(t *testing.T) { }) } -func TestAccountManager_GetAccountFromPAT(t *testing.T) { - store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir()) - if err != nil { - t.Fatalf("Error when creating store: %s", err) - } - t.Cleanup(cleanup) - account := newAccountWithId(context.Background(), "account_id", "testuser", "") - - token := "nbp_9999EUDNdkeusjentDLSJEn1902u84390W6W" - hashedToken := sha256.Sum256([]byte(token)) - encodedHashedToken := b64.StdEncoding.EncodeToString(hashedToken[:]) - account.Users["someUser"] = &types.User{ - Id: "someUser", - PATs: map[string]*types.PersonalAccessToken{ - "tokenId": { - ID: "tokenId", - UserID: "someUser", - HashedToken: encodedHashedToken, - }, - }, - } - err = store.SaveAccount(context.Background(), account) - if err != nil { - t.Fatalf("Error when saving account: %s", err) - } - - am := DefaultAccountManager{ - Store: store, - } - - user, pat, _, _, err := am.GetPATInfo(context.Background(), token) - if err != nil { - t.Fatalf("Error when getting Account from PAT: %s", err) - } - - assert.Equal(t, "account_id", user.AccountID) - assert.Equal(t, "someUser", user.Id) - assert.Equal(t, account.Users["someUser"].PATs["tokenId"].ID, pat.ID) -} - -func TestDefaultAccountManager_MarkPATUsed(t *testing.T) { - store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir()) - if err != nil { - t.Fatalf("Error when creating store: %s", err) - } - t.Cleanup(cleanup) - - account := newAccountWithId(context.Background(), "account_id", "testuser", "") - - token := "nbp_9999EUDNdkeusjentDLSJEn1902u84390W6W" - hashedToken := sha256.Sum256([]byte(token)) - encodedHashedToken := b64.StdEncoding.EncodeToString(hashedToken[:]) - account.Users["someUser"] = &types.User{ - Id: "someUser", - PATs: map[string]*types.PersonalAccessToken{ - "tokenId": { - ID: "tokenId", - HashedToken: encodedHashedToken, - }, - }, - } - err = store.SaveAccount(context.Background(), account) - if err != nil { - t.Fatalf("Error when saving account: %s", err) - } - - am := DefaultAccountManager{ - Store: store, - } - - err = am.MarkPATUsed(context.Background(), "tokenId") - if err != nil { - t.Fatalf("Error when marking PAT used: %s", err) - } - - account, err = am.Store.GetAccount(context.Background(), "account_id") - if err != nil { - t.Fatalf("Error when getting account: %s", err) - } - assert.True(t, !account.Users["someUser"].PATs["tokenId"].GetLastUsed().IsZero()) -} - func TestAccountManager_PrivateAccount(t *testing.T) { manager, err := createManager(t) if err != nil { @@ -962,13 +859,13 @@ func TestAccountManager_DeleteAccount(t *testing.T) { } func BenchmarkTest_GetAccountWithclaims(b *testing.B) { - claims := jwtclaims.AuthorizationClaims{ + claims := nbcontext.UserAuth{ Domain: "example.com", UserId: "pvt-domain-user", DomainCategory: types.PrivateCategory, } - publicClaims := jwtclaims.AuthorizationClaims{ + publicClaims := nbcontext.UserAuth{ Domain: "test.com", UserId: "public-domain-user", DomainCategory: types.PublicCategory, @@ -2683,11 +2580,13 @@ func TestAccount_SetJWTGroups(t *testing.T) { assert.NoError(t, manager.Store.SaveAccount(context.Background(), account), "unable to save account") t.Run("skip sync for token auth type", func(t *testing.T) { - claims := jwtclaims.AuthorizationClaims{ - UserId: "user1", - Raw: jwt.MapClaims{"groups": []interface{}{"group3"}, "is_token": true}, + claims := nbcontext.UserAuth{ + UserId: "user1", + AccountId: "accountID", + Groups: []string{"group3"}, + IsPAT: true, } - err = manager.syncJWTGroups(context.Background(), "accountID", claims) + err = manager.SyncUserJWTGroups(context.Background(), claims) assert.NoError(t, err, "unable to sync jwt groups") user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user1") @@ -2696,11 +2595,12 @@ func TestAccount_SetJWTGroups(t *testing.T) { }) t.Run("empty jwt groups", func(t *testing.T) { - claims := jwtclaims.AuthorizationClaims{ - UserId: "user1", - Raw: jwt.MapClaims{"groups": []interface{}{}}, + claims := nbcontext.UserAuth{ + UserId: "user1", + AccountId: "accountID", + Groups: []string{}, } - err := manager.syncJWTGroups(context.Background(), "accountID", claims) + err := manager.SyncUserJWTGroups(context.Background(), claims) assert.NoError(t, err, "unable to sync jwt groups") user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user1") @@ -2709,11 +2609,12 @@ func TestAccount_SetJWTGroups(t *testing.T) { }) t.Run("jwt match existing api group", func(t *testing.T) { - claims := jwtclaims.AuthorizationClaims{ - UserId: "user1", - Raw: jwt.MapClaims{"groups": []interface{}{"group1"}}, + claims := nbcontext.UserAuth{ + UserId: "user1", + AccountId: "accountID", + Groups: []string{"group1"}, } - err := manager.syncJWTGroups(context.Background(), "accountID", claims) + err := manager.SyncUserJWTGroups(context.Background(), claims) assert.NoError(t, err, "unable to sync jwt groups") user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user1") @@ -2729,11 +2630,12 @@ func TestAccount_SetJWTGroups(t *testing.T) { account.Users["user1"].AutoGroups = []string{"group1"} assert.NoError(t, manager.Store.SaveUser(context.Background(), store.LockingStrengthUpdate, account.Users["user1"])) - claims := jwtclaims.AuthorizationClaims{ - UserId: "user1", - Raw: jwt.MapClaims{"groups": []interface{}{"group1"}}, + claims := nbcontext.UserAuth{ + UserId: "user1", + AccountId: "accountID", + Groups: []string{"group1"}, } - err = manager.syncJWTGroups(context.Background(), "accountID", claims) + err = manager.SyncUserJWTGroups(context.Background(), claims) assert.NoError(t, err, "unable to sync jwt groups") user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user1") @@ -2746,11 +2648,12 @@ func TestAccount_SetJWTGroups(t *testing.T) { }) t.Run("add jwt group", func(t *testing.T) { - claims := jwtclaims.AuthorizationClaims{ - UserId: "user1", - Raw: jwt.MapClaims{"groups": []interface{}{"group1", "group2"}}, + claims := nbcontext.UserAuth{ + UserId: "user1", + AccountId: "accountID", + Groups: []string{"group1", "group2"}, } - err = manager.syncJWTGroups(context.Background(), "accountID", claims) + err = manager.SyncUserJWTGroups(context.Background(), claims) assert.NoError(t, err, "unable to sync jwt groups") user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user1") @@ -2759,11 +2662,12 @@ func TestAccount_SetJWTGroups(t *testing.T) { }) t.Run("existed group not update", func(t *testing.T) { - claims := jwtclaims.AuthorizationClaims{ - UserId: "user1", - Raw: jwt.MapClaims{"groups": []interface{}{"group2"}}, + claims := nbcontext.UserAuth{ + UserId: "user1", + AccountId: "accountID", + Groups: []string{"group2"}, } - err = manager.syncJWTGroups(context.Background(), "accountID", claims) + err = manager.SyncUserJWTGroups(context.Background(), claims) assert.NoError(t, err, "unable to sync jwt groups") user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user1") @@ -2772,11 +2676,12 @@ func TestAccount_SetJWTGroups(t *testing.T) { }) t.Run("add new group", func(t *testing.T) { - claims := jwtclaims.AuthorizationClaims{ - UserId: "user2", - Raw: jwt.MapClaims{"groups": []interface{}{"group1", "group3"}}, + claims := nbcontext.UserAuth{ + UserId: "user2", + AccountId: "accountID", + Groups: []string{"group1", "group3"}, } - err = manager.syncJWTGroups(context.Background(), "accountID", claims) + err = manager.SyncUserJWTGroups(context.Background(), claims) assert.NoError(t, err, "unable to sync jwt groups") groups, err := manager.Store.GetAccountGroups(context.Background(), store.LockingStrengthShare, "accountID") @@ -2789,11 +2694,12 @@ func TestAccount_SetJWTGroups(t *testing.T) { }) t.Run("remove all JWT groups when list is empty", func(t *testing.T) { - claims := jwtclaims.AuthorizationClaims{ - UserId: "user1", - Raw: jwt.MapClaims{"groups": []interface{}{}}, + claims := nbcontext.UserAuth{ + UserId: "user1", + AccountId: "accountID", + Groups: []string{}, } - err = manager.syncJWTGroups(context.Background(), "accountID", claims) + err = manager.SyncUserJWTGroups(context.Background(), claims) assert.NoError(t, err, "unable to sync jwt groups") user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user1") @@ -2803,11 +2709,12 @@ func TestAccount_SetJWTGroups(t *testing.T) { }) t.Run("remove all JWT groups when claim does not exist", func(t *testing.T) { - claims := jwtclaims.AuthorizationClaims{ - UserId: "user2", - Raw: jwt.MapClaims{}, + claims := nbcontext.UserAuth{ + UserId: "user2", + AccountId: "accountID", + Groups: []string{}, } - err = manager.syncJWTGroups(context.Background(), "accountID", claims) + err = manager.SyncUserJWTGroups(context.Background(), claims) assert.NoError(t, err, "unable to sync jwt groups") user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user2") diff --git a/management/server/jwtclaims/extractor.go b/management/server/auth/jwt/extractor.go similarity index 65% rename from management/server/jwtclaims/extractor.go rename to management/server/auth/jwt/extractor.go index 18214b43454..a3fbe868883 100644 --- a/management/server/jwtclaims/extractor.go +++ b/management/server/auth/jwt/extractor.go @@ -1,10 +1,13 @@ -package jwtclaims +package jwt import ( - "net/http" + "errors" "time" "github.com/golang-jwt/jwt" + log "github.com/sirupsen/logrus" + + nbcontext "github.com/netbirdio/netbird/management/server/context" ) const ( @@ -26,15 +29,14 @@ const ( IsToken = "is_token" ) -// ExtractClaims Extract function type -type ExtractClaims func(r *http.Request) AuthorizationClaims +var ( + errUserIDClaimEmpty = errors.New("user ID claim token value is empty") +) // ClaimsExtractor struct that holds the extract function type ClaimsExtractor struct { authAudience string userIDClaim string - - FromRequestContext ExtractClaims } // ClaimsExtractorOption is a function that configures the ClaimsExtractor @@ -54,13 +56,6 @@ func WithUserIDClaim(userIDClaim string) ClaimsExtractorOption { } } -// WithFromRequestContext sets the function that extracts claims from the request context -func WithFromRequestContext(ec ExtractClaims) ClaimsExtractorOption { - return func(c *ClaimsExtractor) { - c.FromRequestContext = ec - } -} - // NewClaimsExtractor returns an extractor, and if provided with a function with ExtractClaims signature, // then it will use that logic. Uses ExtractClaimsFromRequestContext by default func NewClaimsExtractor(options ...ClaimsExtractorOption) *ClaimsExtractor { @@ -68,65 +63,74 @@ func NewClaimsExtractor(options ...ClaimsExtractorOption) *ClaimsExtractor { for _, option := range options { option(ce) } - if ce.FromRequestContext == nil { - ce.FromRequestContext = ce.fromRequestContext - } + if ce.userIDClaim == "" { ce.userIDClaim = UserIDClaim } return ce } -// FromToken extracts claims from the token (after auth) -func (c *ClaimsExtractor) FromToken(token *jwt.Token) AuthorizationClaims { - claims := token.Claims.(jwt.MapClaims) - jwtClaims := AuthorizationClaims{ - Raw: claims, +func parseTime(timeString string) time.Time { + if timeString == "" { + return time.Time{} } + parsedTime, err := time.Parse(time.RFC3339, timeString) + if err != nil { + return time.Time{} + } + return parsedTime +} + +func (c *ClaimsExtractor) ToUserAuth(token *jwt.Token) (nbcontext.UserAuth, error) { + claims := token.Claims.(jwt.MapClaims) + userAuth := nbcontext.UserAuth{} + userID, ok := claims[c.userIDClaim].(string) if !ok { - return jwtClaims + return userAuth, errUserIDClaimEmpty } - jwtClaims.UserId = userID + userAuth.UserId = userID accountIDClaim, ok := claims[c.authAudience+AccountIDSuffix] if ok { - jwtClaims.AccountId = accountIDClaim.(string) + userAuth.AccountId = accountIDClaim.(string) } domainClaim, ok := claims[c.authAudience+DomainIDSuffix] if ok { - jwtClaims.Domain = domainClaim.(string) + userAuth.Domain = domainClaim.(string) } domainCategoryClaim, ok := claims[c.authAudience+DomainCategorySuffix] if ok { - jwtClaims.DomainCategory = domainCategoryClaim.(string) + userAuth.DomainCategory = domainCategoryClaim.(string) } LastLoginClaimString, ok := claims[c.authAudience+LastLoginSuffix] if ok { - jwtClaims.LastLogin = parseTime(LastLoginClaimString.(string)) + userAuth.LastLogin = parseTime(LastLoginClaimString.(string)) } invitedBool, ok := claims[c.authAudience+Invited] if ok { - jwtClaims.Invited = invitedBool.(bool) + userAuth.Invited = invitedBool.(bool) } - return jwtClaims -} -func parseTime(timeString string) time.Time { - if timeString == "" { - return time.Time{} - } - parsedTime, err := time.Parse(time.RFC3339, timeString) - if err != nil { - return time.Time{} - } - return parsedTime + return userAuth, nil } -// fromRequestContext extracts claims from the request context previously filled by the JWT token (after auth) -func (c *ClaimsExtractor) fromRequestContext(r *http.Request) AuthorizationClaims { - if r.Context().Value(TokenUserProperty) == nil { - return AuthorizationClaims{} +func (c *ClaimsExtractor) ToGroups(token *jwt.Token, claimName string) []string { + claims := token.Claims.(jwt.MapClaims) + userJWTGroups := make([]string, 0) + + if claim, ok := claims[claimName]; ok { + if claimGroups, ok := claim.([]interface{}); ok { + for _, g := range claimGroups { + if group, ok := g.(string); ok { + userJWTGroups = append(userJWTGroups, group) + } else { + log.Debugf("JWT claim %q contains a non-string group (type: %T): %v", claimName, g, g) + } + } + } + } else { + log.Debugf("JWT claim %q is not a string array", claimName) } - token := r.Context().Value(TokenUserProperty).(*jwt.Token) - return c.FromToken(token) + + return userJWTGroups } diff --git a/management/server/auth/jwt/validator.go b/management/server/auth/jwt/validator.go new file mode 100644 index 00000000000..bbafb2f4d41 --- /dev/null +++ b/management/server/auth/jwt/validator.go @@ -0,0 +1,303 @@ +package jwt + +import ( + "context" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rsa" + "encoding/base64" + "encoding/json" + "errors" + "fmt" + "math/big" + "net/http" + "net/url" + "strconv" + "strings" + "sync" + "time" + + "github.com/golang-jwt/jwt" + + log "github.com/sirupsen/logrus" +) + +// Jwks is a collection of JSONWebKey obtained from Config.HttpServerConfig.AuthKeysLocation +type Jwks struct { + Keys []JSONWebKey `json:"keys"` + expiresInTime time.Time +} + +// The supported elliptic curves types +const ( + // p256 represents a cryptographic elliptical curve type. + p256 = "P-256" + + // p384 represents a cryptographic elliptical curve type. + p384 = "P-384" + + // p521 represents a cryptographic elliptical curve type. + p521 = "P-521" +) + +// JSONWebKey is a representation of a Jason Web Key +type JSONWebKey struct { + Kty string `json:"kty"` + Kid string `json:"kid"` + Use string `json:"use"` + N string `json:"n"` + E string `json:"e"` + Crv string `json:"crv"` + X string `json:"x"` + Y string `json:"y"` + X5c []string `json:"x5c"` +} + +type Validator struct { + lock sync.Mutex + issuer string + audienceList []string + keysLocation string + idpSignkeyRefreshEnabled bool + keys *Jwks +} + +var ( + errKeyNotFound = errors.New("unable to find appropriate key") + errInvalidAudience = errors.New("invalid audience") + errInvalidIssuer = errors.New("invalid issuer") + errTokenEmpty = errors.New("required authorization token not found") + errTokenInvalid = errors.New("token is invalid") + errTokenParsing = errors.New("token could not be parsed") +) + +func NewValidator(issuer string, audienceList []string, keysLocation string, idpSignkeyRefreshEnabled bool) *Validator { + keys, err := getPemKeys(keysLocation) + if err != nil { + log.WithField("keysLocation", keysLocation).Errorf("could not get keys from location: %s", err) + } + + return &Validator{ + keys: keys, + issuer: issuer, + audienceList: audienceList, + keysLocation: keysLocation, + idpSignkeyRefreshEnabled: idpSignkeyRefreshEnabled, + } +} + +func (v *Validator) getKeyFunc(ctx context.Context) jwt.Keyfunc { + return func(token *jwt.Token) (interface{}, error) { + // Verify 'aud' claim + var checkAud bool + for _, audience := range v.audienceList { + checkAud = token.Claims.(jwt.MapClaims).VerifyAudience(audience, false) + if checkAud { + break + } + } + if !checkAud { + return token, errInvalidAudience + } + + // Verify 'issuer' claim + checkIss := token.Claims.(jwt.MapClaims).VerifyIssuer(v.issuer, false) + if !checkIss { + return token, errInvalidIssuer + } + + // If keys are rotated, verify the keys prior to token validation + if v.idpSignkeyRefreshEnabled { + // If the keys are invalid, retrieve new ones + // @todo propose a separate go routine to regularly check these to prevent blocking when actually + // validating the token + if !v.keys.stillValid() { + v.lock.Lock() + defer v.lock.Unlock() + + refreshedKeys, err := getPemKeys(v.keysLocation) + if err != nil { + log.WithContext(ctx).Debugf("cannot get JSONWebKey: %v, falling back to old keys", err) + refreshedKeys = v.keys + } + + log.WithContext(ctx).Debugf("keys refreshed, new UTC expiration time: %s", refreshedKeys.expiresInTime.UTC()) + + v.keys = refreshedKeys + } + } + + publicKey, err := getPublicKey(token, v.keys) + if err == nil { + return publicKey, nil + } + + msg := fmt.Sprintf("getPublicKey error: %s", err) + if errors.Is(err, errKeyNotFound) && !v.idpSignkeyRefreshEnabled { + msg = fmt.Sprintf("getPublicKey error: %s. You can enable key refresh by setting HttpServerConfig.IdpSignKeyRefreshEnabled to true in your management.json file and restart the service", err) + } + + log.WithContext(ctx).Error(msg) + + return nil, err + } +} + +// ValidateAndParse validates the token and returns the parsed token +func (m *Validator) ValidateAndParse(ctx context.Context, token string) (*jwt.Token, error) { + // If the token is empty... + if token == "" { + // If we get here, the required token is missing + log.WithContext(ctx).Debugf(" Error: No credentials found (CredentialsOptional=false)") + return nil, errTokenEmpty + } + + // Now parse the token + parsedToken, err := jwt.Parse(token, m.getKeyFunc(ctx)) + + // Check if there was an error in parsing... + if err != nil { + err = fmt.Errorf("%w: %s", errTokenParsing, err) + log.WithContext(ctx).Error(err.Error()) + return nil, err + } + + // Check if the parsed token is valid... + if !parsedToken.Valid { + log.WithContext(ctx).Debug(errTokenInvalid.Error()) + return nil, errTokenInvalid + } + + return parsedToken, nil +} + +// stillValid returns true if the JSONWebKey still valid and have enough time to be used +func (jwks *Jwks) stillValid() bool { + return !jwks.expiresInTime.IsZero() && time.Now().Add(5*time.Second).Before(jwks.expiresInTime) +} + +func getPemKeys(keysLocation string) (*Jwks, error) { + jwks := &Jwks{} + + url, err := url.ParseRequestURI(keysLocation) + if err != nil { + return jwks, err + } + + resp, err := http.Get(url.String()) + if err != nil { + return jwks, err + } + defer resp.Body.Close() + + err = json.NewDecoder(resp.Body).Decode(jwks) + if err != nil { + return jwks, err + } + + cacheControlHeader := resp.Header.Get("Cache-Control") + expiresIn := getMaxAgeFromCacheHeader(cacheControlHeader) + jwks.expiresInTime = time.Now().Add(time.Duration(expiresIn) * time.Second) + + return jwks, nil +} + +func getPublicKey(token *jwt.Token, jwks *Jwks) (interface{}, error) { + // todo as we load the jkws when the server is starting, we should build a JKS map with the pem cert at the boot time + for k := range jwks.Keys { + if token.Header["kid"] != jwks.Keys[k].Kid { + continue + } + + if len(jwks.Keys[k].X5c) != 0 { + cert := "-----BEGIN CERTIFICATE-----\n" + jwks.Keys[k].X5c[0] + "\n-----END CERTIFICATE-----" + return jwt.ParseRSAPublicKeyFromPEM([]byte(cert)) + } + + if jwks.Keys[k].Kty == "RSA" { + return getPublicKeyFromRSA(jwks.Keys[k]) + } + if jwks.Keys[k].Kty == "EC" { + return getPublicKeyFromECDSA(jwks.Keys[k]) + } + } + + return nil, errKeyNotFound +} + +func getPublicKeyFromECDSA(jwk JSONWebKey) (publicKey *ecdsa.PublicKey, err error) { + if jwk.X == "" || jwk.Y == "" || jwk.Crv == "" { + return nil, fmt.Errorf("ecdsa key incomplete") + } + + var xCoordinate []byte + if xCoordinate, err = base64.RawURLEncoding.DecodeString(jwk.X); err != nil { + return nil, err + } + + var yCoordinate []byte + if yCoordinate, err = base64.RawURLEncoding.DecodeString(jwk.Y); err != nil { + return nil, err + } + + publicKey = &ecdsa.PublicKey{} + + var curve elliptic.Curve + switch jwk.Crv { + case p256: + curve = elliptic.P256() + case p384: + curve = elliptic.P384() + case p521: + curve = elliptic.P521() + } + + publicKey.Curve = curve + publicKey.X = big.NewInt(0).SetBytes(xCoordinate) + publicKey.Y = big.NewInt(0).SetBytes(yCoordinate) + + return publicKey, nil +} + +func getPublicKeyFromRSA(jwk JSONWebKey) (*rsa.PublicKey, error) { + decodedE, err := base64.RawURLEncoding.DecodeString(jwk.E) + if err != nil { + return nil, err + } + decodedN, err := base64.RawURLEncoding.DecodeString(jwk.N) + if err != nil { + return nil, err + } + + var n, e big.Int + e.SetBytes(decodedE) + n.SetBytes(decodedN) + + return &rsa.PublicKey{ + E: int(e.Int64()), + N: &n, + }, nil +} + +// @todo propose min timeout, for example the cache-control from auth0 is 15s, we might as well not cache them at all for such short duration +// getMaxAgeFromCacheHeader extracts max-age directive from the Cache-Control header +func getMaxAgeFromCacheHeader(cacheControl string) int { + // Split into individual directives + directives := strings.Split(cacheControl, ",") + + for _, directive := range directives { + directive = strings.TrimSpace(directive) + if strings.HasPrefix(directive, "max-age=") { + // Extract the max-age value + maxAgeStr := strings.TrimPrefix(directive, "max-age=") + maxAge, err := strconv.Atoi(maxAgeStr) + if err != nil { + return 0 + } + + return maxAge + } + } + + return 0 +} diff --git a/management/server/auth/manager.go b/management/server/auth/manager.go new file mode 100644 index 00000000000..922f95c0ce7 --- /dev/null +++ b/management/server/auth/manager.go @@ -0,0 +1,168 @@ +package auth + +import ( + "context" + "crypto/sha256" + "encoding/base64" + "fmt" + "hash/crc32" + + "github.com/golang-jwt/jwt" + + "github.com/netbirdio/netbird/base62" + nbjwt "github.com/netbirdio/netbird/management/server/auth/jwt" + nbcontext "github.com/netbirdio/netbird/management/server/context" + "github.com/netbirdio/netbird/management/server/store" + "github.com/netbirdio/netbird/management/server/types" +) + +var _ Manager = (*manager)(nil) + +type Manager interface { + ValidateAndParseToken(ctx context.Context, value string) (nbcontext.UserAuth, *jwt.Token, error) + EnsureUserAccessByJWTGroups(ctx context.Context, userAuth nbcontext.UserAuth, token *jwt.Token) (nbcontext.UserAuth, error) + MarkPATUsed(ctx context.Context, tokenID string) error + GetPATInfo(ctx context.Context, token string) (user *types.User, pat *types.PersonalAccessToken, domain string, category string, err error) +} + +type manager struct { + store store.Store + + validator *nbjwt.Validator + extractor *nbjwt.ClaimsExtractor +} + +func NewManager(store store.Store, issuer, audience, keysLocation, userIdClaim string, allAudiences []string, idpRefreshKeys bool) Manager { + // @note if invalid/missing parameters are sent the validator will instantiate + // but it will fail when validating and parsing the token + jwtValidator := nbjwt.NewValidator( + issuer, + allAudiences, + keysLocation, + idpRefreshKeys, + ) + + claimsExtractor := nbjwt.NewClaimsExtractor( + nbjwt.WithAudience(audience), + nbjwt.WithUserIDClaim(userIdClaim), + ) + + return &manager{ + store: store, + + validator: jwtValidator, + extractor: claimsExtractor, + } +} + +func (m *manager) ValidateAndParseToken(ctx context.Context, value string) (nbcontext.UserAuth, *jwt.Token, error) { + token, err := m.validator.ValidateAndParse(ctx, value) + if err != nil { + return nbcontext.UserAuth{}, nil, err + } + + userAuth, err := m.extractor.ToUserAuth(token) + return userAuth, token, err +} + +func (m *manager) EnsureUserAccessByJWTGroups(ctx context.Context, userAuth nbcontext.UserAuth, token *jwt.Token) (nbcontext.UserAuth, error) { + if userAuth.IsChild || userAuth.IsPAT { + return userAuth, nil + } + + settings, err := m.store.GetAccountSettings(ctx, store.LockingStrengthShare, userAuth.AccountId) + if err != nil { + return userAuth, err + } + + // Ensures JWT group synchronization to the management is enabled before, + // filtering access based on the allowed groups. + if settings != nil && settings.JWTGroupsEnabled { + if allowedGroups := settings.JWTAllowGroups; len(allowedGroups) > 0 { + userAuth.Groups = m.extractor.ToGroups(token, settings.JWTGroupsClaimName) + + if !userHasAllowedGroup(allowedGroups, userAuth.Groups) { + return userAuth, fmt.Errorf("user does not belong to any of the allowed JWT groups") + } + } + } + + return userAuth, nil +} + +// MarkPATUsed marks a personal access token as used +func (am *manager) MarkPATUsed(ctx context.Context, tokenID string) error { + return am.store.MarkPATUsed(ctx, store.LockingStrengthUpdate, tokenID) +} + +// GetPATInfo retrieves user, personal access token, domain, and category details from a personal access token. +func (am *manager) GetPATInfo(ctx context.Context, token string) (user *types.User, pat *types.PersonalAccessToken, domain string, category string, err error) { + user, pat, err = am.extractPATFromToken(ctx, token) + if err != nil { + return nil, nil, "", "", err + } + + domain, category, err = am.store.GetAccountDomainAndCategory(ctx, store.LockingStrengthShare, user.AccountID) + if err != nil { + return nil, nil, "", "", err + } + + return user, pat, domain, category, nil +} + +// extractPATFromToken validates the token structure and retrieves associated User and PAT. +func (am *manager) extractPATFromToken(ctx context.Context, token string) (*types.User, *types.PersonalAccessToken, error) { + if len(token) != types.PATLength { + return nil, nil, fmt.Errorf("PAT has incorrect length") + } + + prefix := token[:len(types.PATPrefix)] + if prefix != types.PATPrefix { + return nil, nil, fmt.Errorf("PAT has wrong prefix") + } + secret := token[len(types.PATPrefix) : len(types.PATPrefix)+types.PATSecretLength] + encodedChecksum := token[len(types.PATPrefix)+types.PATSecretLength : len(types.PATPrefix)+types.PATSecretLength+types.PATChecksumLength] + + verificationChecksum, err := base62.Decode(encodedChecksum) + if err != nil { + return nil, nil, fmt.Errorf("PAT checksum decoding failed: %w", err) + } + + secretChecksum := crc32.ChecksumIEEE([]byte(secret)) + if secretChecksum != verificationChecksum { + return nil, nil, fmt.Errorf("PAT checksum does not match") + } + + hashedToken := sha256.Sum256([]byte(token)) + encodedHashedToken := base64.StdEncoding.EncodeToString(hashedToken[:]) + + var user *types.User + var pat *types.PersonalAccessToken + + err = am.store.ExecuteInTransaction(ctx, func(transaction store.Store) error { + pat, err = transaction.GetPATByHashedToken(ctx, store.LockingStrengthShare, encodedHashedToken) + if err != nil { + return err + } + + user, err = transaction.GetUserByPATID(ctx, store.LockingStrengthShare, pat.ID) + return err + }) + if err != nil { + return nil, nil, err + } + + return user, pat, nil +} + +// userHasAllowedGroup checks if a user belongs to any of the allowed groups. +func userHasAllowedGroup(allowedGroups []string, userGroups []string) bool { + for _, userGroup := range userGroups { + for _, allowedGroup := range allowedGroups { + if userGroup == allowedGroup { + return true + } + } + } + return false +} diff --git a/management/server/auth/manager_mock.go b/management/server/auth/manager_mock.go new file mode 100644 index 00000000000..9f9596c61a6 --- /dev/null +++ b/management/server/auth/manager_mock.go @@ -0,0 +1,54 @@ +package auth + +import ( + "context" + + "github.com/golang-jwt/jwt" + + nbcontext "github.com/netbirdio/netbird/management/server/context" + "github.com/netbirdio/netbird/management/server/types" +) + +var ( + _ Manager = (*MockManager)(nil) +) + +// @note really dislike this mocking approach but rather than have to do additional test refactoring. +type MockManager struct { + ValidateAndParseTokenFunc func(ctx context.Context, value string) (nbcontext.UserAuth, *jwt.Token, error) + EnsureUserAccessByJWTGroupsFunc func(ctx context.Context, userAuth nbcontext.UserAuth, token *jwt.Token) (nbcontext.UserAuth, error) + MarkPATUsedFunc func(ctx context.Context, tokenID string) error + GetAccountInfoFromPATFunc func(ctx context.Context, token string) (user *types.User, pat *types.PersonalAccessToken, domain string, category string, err error) +} + +// EnsureUserAccessByJWTGroups implements Manager. +func (m *MockManager) EnsureUserAccessByJWTGroups(ctx context.Context, userAuth nbcontext.UserAuth, token *jwt.Token) (nbcontext.UserAuth, error) { + if m.EnsureUserAccessByJWTGroupsFunc != nil { + return m.EnsureUserAccessByJWTGroupsFunc(ctx, userAuth, token) + } + return nbcontext.UserAuth{}, nil +} + +// GetPATInfo implements Manager. +func (m *MockManager) GetPATInfo(ctx context.Context, token string) (user *types.User, pat *types.PersonalAccessToken, domain string, category string, err error) { + if m.GetAccountInfoFromPATFunc != nil { + return m.GetAccountInfoFromPATFunc(ctx, token) + } + return &types.User{}, &types.PersonalAccessToken{}, "", "", nil +} + +// MarkPATUsed implements Manager. +func (m *MockManager) MarkPATUsed(ctx context.Context, tokenID string) error { + if m.MarkPATUsedFunc != nil { + return m.MarkPATUsedFunc(ctx, tokenID) + } + return nil +} + +// ValidateAndParseToken implements Manager. +func (m *MockManager) ValidateAndParseToken(ctx context.Context, value string) (nbcontext.UserAuth, *jwt.Token, error) { + if m.ValidateAndParseTokenFunc != nil { + return m.ValidateAndParseTokenFunc(ctx, value) + } + return nbcontext.UserAuth{}, &jwt.Token{}, nil +} diff --git a/management/server/auth/manager_test.go b/management/server/auth/manager_test.go new file mode 100644 index 00000000000..1ba334d5205 --- /dev/null +++ b/management/server/auth/manager_test.go @@ -0,0 +1,206 @@ +package auth_test + +import ( + "context" + "crypto/sha256" + "encoding/base64" + "testing" + + "github.com/golang-jwt/jwt" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/management/server/auth" + nbcontext "github.com/netbirdio/netbird/management/server/context" + "github.com/netbirdio/netbird/management/server/store" + "github.com/netbirdio/netbird/management/server/types" +) + +func TestAuthManager_GetAccountInfoFromPAT(t *testing.T) { + store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir()) + if err != nil { + t.Fatalf("Error when creating store: %s", err) + } + t.Cleanup(cleanup) + + token := "nbp_9999EUDNdkeusjentDLSJEn1902u84390W6W" + hashedToken := sha256.Sum256([]byte(token)) + encodedHashedToken := base64.StdEncoding.EncodeToString(hashedToken[:]) + account := &types.Account{ + Id: "account_id", + Users: map[string]*types.User{"someUser": { + Id: "someUser", + PATs: map[string]*types.PersonalAccessToken{ + "tokenId": { + ID: "tokenId", + UserID: "someUser", + HashedToken: encodedHashedToken, + }, + }, + }}, + } + + err = store.SaveAccount(context.Background(), account) + if err != nil { + t.Fatalf("Error when saving account: %s", err) + } + + manager := auth.NewManager(store, "", "", "", "", []string{}, false) + + user, pat, _, _, err := manager.GetPATInfo(context.Background(), token) + if err != nil { + t.Fatalf("Error when getting Account from PAT: %s", err) + } + + assert.Equal(t, "account_id", user.AccountID) + assert.Equal(t, "someUser", user.Id) + assert.Equal(t, account.Users["someUser"].PATs["tokenId"].ID, pat.ID) +} + +func TestAuthManager_MarkPATUsed(t *testing.T) { + store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir()) + if err != nil { + t.Fatalf("Error when creating store: %s", err) + } + t.Cleanup(cleanup) + + token := "nbp_9999EUDNdkeusjentDLSJEn1902u84390W6W" + hashedToken := sha256.Sum256([]byte(token)) + encodedHashedToken := base64.StdEncoding.EncodeToString(hashedToken[:]) + account := &types.Account{ + Id: "account_id", + Users: map[string]*types.User{"someUser": { + Id: "someUser", + PATs: map[string]*types.PersonalAccessToken{ + "tokenId": { + ID: "tokenId", + HashedToken: encodedHashedToken, + }, + }, + }}, + } + + err = store.SaveAccount(context.Background(), account) + if err != nil { + t.Fatalf("Error when saving account: %s", err) + } + + manager := auth.NewManager(store, "", "", "", "", []string{}, false) + + err = manager.MarkPATUsed(context.Background(), "tokenId") + if err != nil { + t.Fatalf("Error when marking PAT used: %s", err) + } + + account, err = store.GetAccount(context.Background(), "account_id") + if err != nil { + t.Fatalf("Error when getting account: %s", err) + } + assert.True(t, !account.Users["someUser"].PATs["tokenId"].GetLastUsed().IsZero()) +} + +func TestAuthManager_EnsureUserAccessByJWTGroups(t *testing.T) { + store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir()) + if err != nil { + t.Fatalf("Error when creating store: %s", err) + } + t.Cleanup(cleanup) + + userId := "user-id" + domain := "test.domain" + + account := &types.Account{ + Id: "account_id", + Domain: domain, + Users: map[string]*types.User{"someUser": { + Id: "someUser", + }}, + Settings: &types.Settings{}, + } + + err = store.SaveAccount(context.Background(), account) + if err != nil { + t.Fatalf("Error when saving account: %s", err) + } + + // this has been validated and parsed by ValidateAndParseToken + userAuth := nbcontext.UserAuth{ + AccountId: account.Id, + Domain: domain, + UserId: userId, + DomainCategory: "test-category", + // Groups: []string{"group1", "group2"}, + } + + // these tests only assert groups are parsed from token as per account settings + token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{"idp-groups": []interface{}{"group1", "group2"}}) + + manager := auth.NewManager(store, "", "", "", "", []string{}, false) + + t.Run("JWT groups disabled", func(t *testing.T) { + userAuth, err := manager.EnsureUserAccessByJWTGroups(context.Background(), userAuth, token) + require.NoError(t, err, "ensure user access by JWT groups failed") + require.Len(t, userAuth.Groups, 0, "account not enabled to ensure access by groups") + }) + + t.Run("User impersonated", func(t *testing.T) { + userAuth, err := manager.EnsureUserAccessByJWTGroups(context.Background(), userAuth, token) + require.NoError(t, err, "ensure user access by JWT groups failed") + require.Len(t, userAuth.Groups, 0, "account not enabled to ensure access by groups") + }) + + t.Run("User PAT", func(t *testing.T) { + userAuth, err := manager.EnsureUserAccessByJWTGroups(context.Background(), userAuth, token) + require.NoError(t, err, "ensure user access by JWT groups failed") + require.Len(t, userAuth.Groups, 0, "account not enabled to ensure access by groups") + }) + + t.Run("JWT groups enabled without claim name", func(t *testing.T) { + account.Settings.JWTGroupsEnabled = true + err := store.SaveAccount(context.Background(), account) + require.NoError(t, err, "save account failed") + + userAuth, err := manager.EnsureUserAccessByJWTGroups(context.Background(), userAuth, token) + require.NoError(t, err, "ensure user access by JWT groups failed") + require.Len(t, userAuth.Groups, 0, "account missing groups claim name") + }) + + t.Run("JWT groups enabled without allowed groups", func(t *testing.T) { + account.Settings.JWTGroupsEnabled = true + account.Settings.JWTGroupsClaimName = "idp-groups" + err := store.SaveAccount(context.Background(), account) + require.NoError(t, err, "save account failed") + + userAuth, err := manager.EnsureUserAccessByJWTGroups(context.Background(), userAuth, token) + require.NoError(t, err, "ensure user access by JWT groups failed") + require.Len(t, userAuth.Groups, 0, "account missing allowed groups") + }) + + t.Run("User in allowed JWT groups", func(t *testing.T) { + account.Settings.JWTGroupsEnabled = true + account.Settings.JWTGroupsClaimName = "idp-groups" + account.Settings.JWTAllowGroups = []string{"group1"} + err := store.SaveAccount(context.Background(), account) + require.NoError(t, err, "save account failed") + + userAuth, err := manager.EnsureUserAccessByJWTGroups(context.Background(), userAuth, token) + require.NoError(t, err, "ensure user access by JWT groups failed") + + require.Equal(t, []string{"group1", "group2"}, userAuth.Groups, "group parsed do not match") + }) + + t.Run("User not in allowed JWT groups", func(t *testing.T) { + account.Settings.JWTGroupsEnabled = true + account.Settings.JWTGroupsClaimName = "idp-groups" + account.Settings.JWTAllowGroups = []string{"not-a-group"} + err := store.SaveAccount(context.Background(), account) + require.NoError(t, err, "save account failed") + + _, err = manager.EnsureUserAccessByJWTGroups(context.Background(), userAuth, token) + require.Error(t, err, "ensure user access is not in allowed groups") + }) +} + +func TestAuthManager_ValidateAndParseToken(t *testing.T) { + // @todo should be an integration test that covers the validator and extractor with valid JWT +} diff --git a/management/server/config.go b/management/server/config.go index 397b5f0e66c..ce2ff4d1635 100644 --- a/management/server/config.go +++ b/management/server/config.go @@ -2,7 +2,6 @@ package server import ( "net/netip" - "net/url" "github.com/netbirdio/netbird/management/server/idp" "github.com/netbirdio/netbird/management/server/store" @@ -180,9 +179,3 @@ type ReverseProxy struct { // trusted IP prefixes. TrustedPeers []netip.Prefix } - -// validateURL validates input http url -func validateURL(httpURL string) bool { - _, err := url.ParseRequestURI(httpURL) - return err == nil -} diff --git a/management/server/context/auth.go b/management/server/context/auth.go new file mode 100644 index 00000000000..c73f8fa1fdf --- /dev/null +++ b/management/server/context/auth.go @@ -0,0 +1,56 @@ +package context + +import ( + "context" + "fmt" + "net/http" + "time" +) + +type key int + +const ( + UserAuthContextKey key = iota +) + +type UserAuth struct { + // The account id the user is accessing + AccountId string + // The account domain + Domain string + // The account domain category, TBC values + DomainCategory string + // Indicates whether this user was invited, TBC logic + Invited bool + // Indicates whether this is a child account + IsChild bool + + // The user id + UserId string + // Last login time for this user + LastLogin time.Time + // The Groups the user belongs to on this account + Groups []string + + // Indicates whether this user has authenticated with a Personal Access Token + IsPAT bool +} + +func GetUserAuthFromRequest(r *http.Request) (UserAuth, error) { + return GetUserAuthFromContext(r.Context()) +} + +func SetUserAuthInRequest(r *http.Request, userAuth UserAuth) *http.Request { + return r.WithContext(SetUserAuthInContext(r.Context(), userAuth)) +} + +func GetUserAuthFromContext(ctx context.Context) (UserAuth, error) { + if userAuth, ok := ctx.Value(UserAuthContextKey).(UserAuth); ok { + return userAuth, nil + } + return UserAuth{}, fmt.Errorf("user auth not in context") +} + +func SetUserAuthInContext(ctx context.Context, userAuth UserAuth) context.Context { + return context.WithValue(ctx, UserAuthContextKey, userAuth) +} diff --git a/management/server/grpcserver.go b/management/server/grpcserver.go index eec109ee970..e454fd536f2 100644 --- a/management/server/grpcserver.go +++ b/management/server/grpcserver.go @@ -19,8 +19,8 @@ import ( "github.com/netbirdio/netbird/encryption" "github.com/netbirdio/netbird/management/proto" + "github.com/netbirdio/netbird/management/server/auth" nbContext "github.com/netbirdio/netbird/management/server/context" - "github.com/netbirdio/netbird/management/server/jwtclaims" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/posture" "github.com/netbirdio/netbird/management/server/settings" @@ -38,11 +38,10 @@ type GRPCServer struct { peersUpdateManager *PeersUpdateManager config *Config secretsManager SecretsManager - jwtValidator jwtclaims.JWTValidator - jwtClaimsExtractor *jwtclaims.ClaimsExtractor appMetrics telemetry.AppMetrics ephemeralManager *EphemeralManager peerLocks sync.Map + authManager auth.Manager } // NewServer creates a new Management server @@ -55,29 +54,13 @@ func NewServer( secretsManager SecretsManager, appMetrics telemetry.AppMetrics, ephemeralManager *EphemeralManager, + authManager auth.Manager, ) (*GRPCServer, error) { key, err := wgtypes.GeneratePrivateKey() if err != nil { return nil, err } - var jwtValidator jwtclaims.JWTValidator - - if config.HttpConfig != nil && config.HttpConfig.AuthIssuer != "" && config.HttpConfig.AuthAudience != "" && validateURL(config.HttpConfig.AuthKeysLocation) { - jwtValidator, err = jwtclaims.NewJWTValidator( - ctx, - config.HttpConfig.AuthIssuer, - config.GetAuthAudiences(), - config.HttpConfig.AuthKeysLocation, - config.HttpConfig.IdpSignKeyRefreshEnabled, - ) - if err != nil { - return nil, status.Errorf(codes.Internal, "unable to create new jwt middleware, err: %v", err) - } - } else { - log.WithContext(ctx).Debug("unable to use http config to create new jwt middleware") - } - if appMetrics != nil { // update gauge based on number of connected peers which is equal to open gRPC streams err = appMetrics.GRPCMetrics().RegisterConnectedStreams(func() int64 { @@ -88,16 +71,6 @@ func NewServer( } } - var audience, userIDClaim string - if config.HttpConfig != nil { - audience = config.HttpConfig.AuthAudience - userIDClaim = config.HttpConfig.AuthUserIDClaim - } - jwtClaimsExtractor := jwtclaims.NewClaimsExtractor( - jwtclaims.WithAudience(audience), - jwtclaims.WithUserIDClaim(userIDClaim), - ) - return &GRPCServer{ wgKey: key, // peerKey -> event channel @@ -106,8 +79,7 @@ func NewServer( settingsManager: settingsManager, config: config, secretsManager: secretsManager, - jwtValidator: jwtValidator, - jwtClaimsExtractor: jwtClaimsExtractor, + authManager: authManager, appMetrics: appMetrics, ephemeralManager: ephemeralManager, }, nil @@ -281,26 +253,32 @@ func (s *GRPCServer) cancelPeerRoutines(ctx context.Context, accountID string, p } func (s *GRPCServer) validateToken(ctx context.Context, jwtToken string) (string, error) { - if s.jwtValidator == nil { - return "", status.Error(codes.Internal, "no jwt validator set") + if s.authManager == nil { + return "", status.Errorf(codes.Internal, "missing auth manager") } - token, err := s.jwtValidator.ValidateAndParse(ctx, jwtToken) + userAuth, token, err := s.authManager.ValidateAndParseToken(ctx, jwtToken) if err != nil { return "", status.Errorf(codes.InvalidArgument, "invalid jwt token, err: %v", err) } - claims := s.jwtClaimsExtractor.FromToken(token) + // we need to call this method because if user is new, we will automatically add it to existing or create a new account - _, _, err = s.accountManager.GetAccountIDFromToken(ctx, claims) + _, _, err = s.accountManager.GetAccountIDFromUserAuth(ctx, userAuth) if err != nil { return "", status.Errorf(codes.Internal, "unable to fetch account with claims, err: %v", err) } - if err := s.accountManager.CheckUserAccessByJWTGroups(ctx, claims); err != nil { + userAuth, err = s.authManager.EnsureUserAccessByJWTGroups(ctx, userAuth, token) + if err != nil { return "", status.Error(codes.PermissionDenied, err.Error()) } - return claims.UserId, nil + err = s.accountManager.SyncUserJWTGroups(ctx, userAuth) + if err != nil { + log.WithContext(ctx).Errorf("gRPC server failed to sync user JWT groups: %s", err) + } + + return userAuth.UserId, nil } func (s *GRPCServer) acquirePeerLockByUID(ctx context.Context, uniqueID string) (unlock func()) { diff --git a/management/server/http/handler.go b/management/server/http/handler.go index 7ce09fffaff..2b87c5f2542 100644 --- a/management/server/http/handler.go +++ b/management/server/http/handler.go @@ -11,9 +11,9 @@ import ( "github.com/netbirdio/management-integrations/integrations" s "github.com/netbirdio/netbird/management/server" + "github.com/netbirdio/netbird/management/server/auth" "github.com/netbirdio/netbird/management/server/geolocation" nbgroups "github.com/netbirdio/netbird/management/server/groups" - "github.com/netbirdio/netbird/management/server/http/configs" "github.com/netbirdio/netbird/management/server/http/handlers/accounts" "github.com/netbirdio/netbird/management/server/http/handlers/dns" "github.com/netbirdio/netbird/management/server/http/handlers/events" @@ -26,7 +26,6 @@ import ( "github.com/netbirdio/netbird/management/server/http/handlers/users" "github.com/netbirdio/netbird/management/server/http/middleware" "github.com/netbirdio/netbird/management/server/integrated_validator" - "github.com/netbirdio/netbird/management/server/jwtclaims" nbnetworks "github.com/netbirdio/netbird/management/server/networks" "github.com/netbirdio/netbird/management/server/networks/resources" "github.com/netbirdio/netbird/management/server/networks/routers" @@ -36,55 +35,51 @@ import ( const apiPrefix = "/api" // NewAPIHandler creates the Management service HTTP API handler registering all the available endpoints. -func NewAPIHandler(ctx context.Context, accountManager s.AccountManager, networksManager nbnetworks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager nbgroups.Manager, LocationManager geolocation.Geolocation, jwtValidator jwtclaims.JWTValidator, appMetrics telemetry.AppMetrics, authCfg configs.AuthCfg, integratedValidator integrated_validator.IntegratedValidator) (http.Handler, error) { - claimsExtractor := jwtclaims.NewClaimsExtractor( - jwtclaims.WithAudience(authCfg.Audience), - jwtclaims.WithUserIDClaim(authCfg.UserIDClaim), - ) +func NewAPIHandler( + ctx context.Context, + accountManager s.AccountManager, + networksManager nbnetworks.Manager, + resourceManager resources.Manager, + routerManager routers.Manager, + groupsManager nbgroups.Manager, + LocationManager geolocation.Geolocation, + authManager auth.Manager, + appMetrics telemetry.AppMetrics, + config *s.Config, + integratedValidator integrated_validator.IntegratedValidator) (http.Handler, error) { authMiddleware := middleware.NewAuthMiddleware( - accountManager.GetPATInfo, - jwtValidator.ValidateAndParse, - accountManager.MarkPATUsed, - accountManager.CheckUserAccessByJWTGroups, - claimsExtractor, - authCfg.Audience, - authCfg.UserIDClaim, + authManager, + accountManager.GetAccountIDFromUserAuth, + accountManager.SyncUserJWTGroups, ) corsMiddleware := cors.AllowAll() - claimsExtractor = jwtclaims.NewClaimsExtractor( - jwtclaims.WithAudience(authCfg.Audience), - jwtclaims.WithUserIDClaim(authCfg.UserIDClaim), - ) - - acMiddleware := middleware.NewAccessControl( - authCfg.Audience, - authCfg.UserIDClaim, - accountManager.GetUser) + acMiddleware := middleware.NewAccessControl(accountManager.GetUserFromUserAuth) rootRouter := mux.NewRouter() metricsMiddleware := appMetrics.HTTPMiddleware() prefix := apiPrefix router := rootRouter.PathPrefix(prefix).Subrouter() + router.Use(metricsMiddleware.Handler, corsMiddleware.Handler, authMiddleware.Handler, acMiddleware.Handler) - if _, err := integrations.RegisterHandlers(ctx, prefix, router, accountManager, claimsExtractor, integratedValidator, appMetrics.GetMeter()); err != nil { + if _, err := integrations.RegisterHandlers(ctx, prefix, router, accountManager, integratedValidator, appMetrics.GetMeter()); err != nil { return nil, fmt.Errorf("register integrations endpoints: %w", err) } - accounts.AddEndpoints(accountManager, authCfg, router) - peers.AddEndpoints(accountManager, authCfg, router) - users.AddEndpoints(accountManager, authCfg, router) - setup_keys.AddEndpoints(accountManager, authCfg, router) - policies.AddEndpoints(accountManager, LocationManager, authCfg, router) - groups.AddEndpoints(accountManager, authCfg, router) - routes.AddEndpoints(accountManager, authCfg, router) - dns.AddEndpoints(accountManager, authCfg, router) - events.AddEndpoints(accountManager, authCfg, router) - networks.AddEndpoints(networksManager, resourceManager, routerManager, groupsManager, accountManager, accountManager.GetAccountIDFromToken, authCfg, router) + accounts.AddEndpoints(accountManager, router) + peers.AddEndpoints(accountManager, router) + users.AddEndpoints(accountManager, router) + setup_keys.AddEndpoints(accountManager, router) + policies.AddEndpoints(accountManager, LocationManager, router) + groups.AddEndpoints(accountManager, router) + routes.AddEndpoints(accountManager, router) + dns.AddEndpoints(accountManager, router) + events.AddEndpoints(accountManager, router) + networks.AddEndpoints(networksManager, resourceManager, routerManager, groupsManager, accountManager, router) return rootRouter, nil } diff --git a/management/server/http/handlers/accounts/accounts_handler.go b/management/server/http/handlers/accounts/accounts_handler.go index a23628cdcc4..1f3dde717f8 100644 --- a/management/server/http/handlers/accounts/accounts_handler.go +++ b/management/server/http/handlers/accounts/accounts_handler.go @@ -9,47 +9,42 @@ import ( "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/account" + nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/http/api" - "github.com/netbirdio/netbird/management/server/http/configs" "github.com/netbirdio/netbird/management/server/http/util" - "github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/management/server/types" ) // handler is a handler that handles the server.Account HTTP endpoints type handler struct { - accountManager server.AccountManager - claimsExtractor *jwtclaims.ClaimsExtractor + accountManager server.AccountManager } -func AddEndpoints(accountManager server.AccountManager, authCfg configs.AuthCfg, router *mux.Router) { - accountsHandler := newHandler(accountManager, authCfg) +func AddEndpoints(accountManager server.AccountManager, router *mux.Router) { + accountsHandler := newHandler(accountManager) router.HandleFunc("/accounts/{accountId}", accountsHandler.updateAccount).Methods("PUT", "OPTIONS") router.HandleFunc("/accounts/{accountId}", accountsHandler.deleteAccount).Methods("DELETE", "OPTIONS") router.HandleFunc("/accounts", accountsHandler.getAllAccounts).Methods("GET", "OPTIONS") } // newHandler creates a new handler HTTP handler -func newHandler(accountManager server.AccountManager, authCfg configs.AuthCfg) *handler { +func newHandler(accountManager server.AccountManager) *handler { return &handler{ accountManager: accountManager, - claimsExtractor: jwtclaims.NewClaimsExtractor( - jwtclaims.WithAudience(authCfg.Audience), - jwtclaims.WithUserIDClaim(authCfg.UserIDClaim), - ), } } // getAllAccounts is HTTP GET handler that returns a list of accounts. Effectively returns just a single account. func (h *handler) getAllAccounts(w http.ResponseWriter, r *http.Request) { - claims := h.claimsExtractor.FromRequestContext(r) - accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) if err != nil { util.WriteError(r.Context(), err, w) return } + accountID, userID := userAuth.AccountId, userAuth.UserId + settings, err := h.accountManager.GetAccountSettings(r.Context(), accountID, userID) if err != nil { util.WriteError(r.Context(), err, w) @@ -62,15 +57,17 @@ func (h *handler) getAllAccounts(w http.ResponseWriter, r *http.Request) { // updateAccount is HTTP PUT handler that updates the provided account. Updates only account settings (server.Settings) func (h *handler) updateAccount(w http.ResponseWriter, r *http.Request) { - claims := h.claimsExtractor.FromRequestContext(r) - _, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) if err != nil { util.WriteError(r.Context(), err, w) return } + _, userID := userAuth.AccountId, userAuth.UserId + vars := mux.Vars(r) accountID := vars["accountId"] + // @todo additional check for account access, consider impersonated if len(accountID) == 0 { util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid accountID ID"), w) return @@ -125,15 +122,21 @@ func (h *handler) updateAccount(w http.ResponseWriter, r *http.Request) { // deleteAccount is a HTTP DELETE handler to delete an account func (h *handler) deleteAccount(w http.ResponseWriter, r *http.Request) { - claims := h.claimsExtractor.FromRequestContext(r) + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + vars := mux.Vars(r) targetAccountID := vars["accountId"] + // @todo additional check for account access, consider impersonated if len(targetAccountID) == 0 { util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid account ID"), w) return } - err := h.accountManager.DeleteAccount(r.Context(), targetAccountID, claims.UserId) + err = h.accountManager.DeleteAccount(r.Context(), targetAccountID, userAuth.UserId) if err != nil { util.WriteError(r.Context(), err, w) return diff --git a/management/server/http/handlers/accounts/accounts_handler_test.go b/management/server/http/handlers/accounts/accounts_handler_test.go index e8a599863ce..a8d57a13fd7 100644 --- a/management/server/http/handlers/accounts/accounts_handler_test.go +++ b/management/server/http/handlers/accounts/accounts_handler_test.go @@ -13,19 +13,16 @@ import ( "github.com/gorilla/mux" "github.com/stretchr/testify/assert" + nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/http/api" - "github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/mock_server" "github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/management/server/types" ) -func initAccountsTestData(account *types.Account, admin *types.User) *handler { +func initAccountsTestData(account *types.Account) *handler { return &handler{ accountManager: &mock_server.MockAccountManager{ - GetAccountIDFromTokenFunc: func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) { - return account.Id, admin.Id, nil - }, GetAccountSettingsFunc: func(ctx context.Context, accountID string, userID string) (*types.Settings, error) { return account.Settings, nil }, @@ -44,15 +41,6 @@ func initAccountsTestData(account *types.Account, admin *types.User) *handler { return accCopy, nil }, }, - claimsExtractor: jwtclaims.NewClaimsExtractor( - jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims { - return jwtclaims.AuthorizationClaims{ - UserId: "test_user", - Domain: "hotmail.com", - AccountId: "test_account", - } - }), - ), } } @@ -75,7 +63,7 @@ func TestAccounts_AccountsHandler(t *testing.T) { PeerLoginExpiration: time.Hour, RegularUsersViewBlocked: true, }, - }, adminUser) + }) tt := []struct { name string @@ -191,6 +179,11 @@ func TestAccounts_AccountsHandler(t *testing.T) { t.Run(tc.name, func(t *testing.T) { recorder := httptest.NewRecorder() req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody) + req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{ + UserId: adminUser.Id, + AccountId: accountID, + Domain: "hotmail.com", + }) router := mux.NewRouter() router.HandleFunc("/api/accounts", handler.getAllAccounts).Methods("GET") diff --git a/management/server/http/handlers/dns/dns_settings_handler.go b/management/server/http/handlers/dns/dns_settings_handler.go index 112eee1797b..6ff938369ec 100644 --- a/management/server/http/handlers/dns/dns_settings_handler.go +++ b/management/server/http/handlers/dns/dns_settings_handler.go @@ -8,51 +8,44 @@ import ( log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/management/server" + nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/http/api" - "github.com/netbirdio/netbird/management/server/http/configs" "github.com/netbirdio/netbird/management/server/http/util" - "github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/types" ) // dnsSettingsHandler is a handler that returns the DNS settings of the account type dnsSettingsHandler struct { - accountManager server.AccountManager - claimsExtractor *jwtclaims.ClaimsExtractor + accountManager server.AccountManager } -func AddEndpoints(accountManager server.AccountManager, authCfg configs.AuthCfg, router *mux.Router) { - addDNSSettingEndpoint(accountManager, authCfg, router) - addDNSNameserversEndpoint(accountManager, authCfg, router) +func AddEndpoints(accountManager server.AccountManager, router *mux.Router) { + addDNSSettingEndpoint(accountManager, router) + addDNSNameserversEndpoint(accountManager, router) } -func addDNSSettingEndpoint(accountManager server.AccountManager, authCfg configs.AuthCfg, router *mux.Router) { - dnsSettingsHandler := newDNSSettingsHandler(accountManager, authCfg) +func addDNSSettingEndpoint(accountManager server.AccountManager, router *mux.Router) { + dnsSettingsHandler := newDNSSettingsHandler(accountManager) router.HandleFunc("/dns/settings", dnsSettingsHandler.getDNSSettings).Methods("GET", "OPTIONS") router.HandleFunc("/dns/settings", dnsSettingsHandler.updateDNSSettings).Methods("PUT", "OPTIONS") } // newDNSSettingsHandler returns a new instance of dnsSettingsHandler handler -func newDNSSettingsHandler(accountManager server.AccountManager, authCfg configs.AuthCfg) *dnsSettingsHandler { - return &dnsSettingsHandler{ - accountManager: accountManager, - claimsExtractor: jwtclaims.NewClaimsExtractor( - jwtclaims.WithAudience(authCfg.Audience), - jwtclaims.WithUserIDClaim(authCfg.UserIDClaim), - ), - } +func newDNSSettingsHandler(accountManager server.AccountManager) *dnsSettingsHandler { + return &dnsSettingsHandler{accountManager: accountManager} } // getDNSSettings returns the DNS settings for the account func (h *dnsSettingsHandler) getDNSSettings(w http.ResponseWriter, r *http.Request) { - claims := h.claimsExtractor.FromRequestContext(r) - accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) if err != nil { log.WithContext(r.Context()).Error(err) http.Redirect(w, r, "/", http.StatusInternalServerError) return } + accountID, userID := userAuth.AccountId, userAuth.UserId + dnsSettings, err := h.accountManager.GetDNSSettings(r.Context(), accountID, userID) if err != nil { util.WriteError(r.Context(), err, w) @@ -68,13 +61,14 @@ func (h *dnsSettingsHandler) getDNSSettings(w http.ResponseWriter, r *http.Reque // updateDNSSettings handles update to DNS settings of an account func (h *dnsSettingsHandler) updateDNSSettings(w http.ResponseWriter, r *http.Request) { - claims := h.claimsExtractor.FromRequestContext(r) - accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) if err != nil { util.WriteError(r.Context(), err, w) return } + accountID, userID := userAuth.AccountId, userAuth.UserId + var req api.PutApiDnsSettingsJSONRequestBody err = json.NewDecoder(r.Body).Decode(&req) if err != nil { diff --git a/management/server/http/handlers/dns/dns_settings_handler_test.go b/management/server/http/handlers/dns/dns_settings_handler_test.go index 9ca1dc03253..ca81adf4366 100644 --- a/management/server/http/handlers/dns/dns_settings_handler_test.go +++ b/management/server/http/handlers/dns/dns_settings_handler_test.go @@ -17,7 +17,8 @@ import ( "github.com/gorilla/mux" - "github.com/netbirdio/netbird/management/server/jwtclaims" + nbcontext "github.com/netbirdio/netbird/management/server/context" + "github.com/netbirdio/netbird/management/server/mock_server" ) @@ -52,19 +53,7 @@ func initDNSSettingsTestData() *dnsSettingsHandler { } return status.Errorf(status.InvalidArgument, "the dns settings provided are nil") }, - GetAccountIDFromTokenFunc: func(ctx context.Context, _ jwtclaims.AuthorizationClaims) (string, string, error) { - return testingDNSSettingsAccount.Id, testingDNSSettingsAccount.Users[testDNSSettingsUserID].Id, nil - }, }, - claimsExtractor: jwtclaims.NewClaimsExtractor( - jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims { - return jwtclaims.AuthorizationClaims{ - UserId: "test_user", - Domain: "hotmail.com", - AccountId: testDNSSettingsAccountID, - } - }), - ), } } @@ -118,6 +107,11 @@ func TestDNSSettingsHandlers(t *testing.T) { t.Run(tc.name, func(t *testing.T) { recorder := httptest.NewRecorder() req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody) + req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{ + UserId: testingDNSSettingsAccount.Users[testDNSSettingsUserID].Id, + AccountId: testingDNSSettingsAccount.Id, + Domain: testingDNSSettingsAccount.Domain, + }) router := mux.NewRouter() router.HandleFunc("/api/dns/settings", p.getDNSSettings).Methods("GET") diff --git a/management/server/http/handlers/dns/nameservers_handler.go b/management/server/http/handlers/dns/nameservers_handler.go index 09047e231af..33d07047737 100644 --- a/management/server/http/handlers/dns/nameservers_handler.go +++ b/management/server/http/handlers/dns/nameservers_handler.go @@ -10,21 +10,19 @@ import ( nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/management/server" + nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/http/api" - "github.com/netbirdio/netbird/management/server/http/configs" "github.com/netbirdio/netbird/management/server/http/util" - "github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/status" ) // nameserversHandler is the nameserver group handler of the account type nameserversHandler struct { - accountManager server.AccountManager - claimsExtractor *jwtclaims.ClaimsExtractor + accountManager server.AccountManager } -func addDNSNameserversEndpoint(accountManager server.AccountManager, authCfg configs.AuthCfg, router *mux.Router) { - nameserversHandler := newNameserversHandler(accountManager, authCfg) +func addDNSNameserversEndpoint(accountManager server.AccountManager, router *mux.Router) { + nameserversHandler := newNameserversHandler(accountManager) router.HandleFunc("/dns/nameservers", nameserversHandler.getAllNameservers).Methods("GET", "OPTIONS") router.HandleFunc("/dns/nameservers", nameserversHandler.createNameserverGroup).Methods("POST", "OPTIONS") router.HandleFunc("/dns/nameservers/{nsgroupId}", nameserversHandler.updateNameserverGroup).Methods("PUT", "OPTIONS") @@ -33,26 +31,21 @@ func addDNSNameserversEndpoint(accountManager server.AccountManager, authCfg con } // newNameserversHandler returns a new instance of nameserversHandler handler -func newNameserversHandler(accountManager server.AccountManager, authCfg configs.AuthCfg) *nameserversHandler { - return &nameserversHandler{ - accountManager: accountManager, - claimsExtractor: jwtclaims.NewClaimsExtractor( - jwtclaims.WithAudience(authCfg.Audience), - jwtclaims.WithUserIDClaim(authCfg.UserIDClaim), - ), - } +func newNameserversHandler(accountManager server.AccountManager) *nameserversHandler { + return &nameserversHandler{accountManager: accountManager} } // getAllNameservers returns the list of nameserver groups for the account func (h *nameserversHandler) getAllNameservers(w http.ResponseWriter, r *http.Request) { - claims := h.claimsExtractor.FromRequestContext(r) - accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) if err != nil { log.WithContext(r.Context()).Error(err) http.Redirect(w, r, "/", http.StatusInternalServerError) return } + accountID, userID := userAuth.AccountId, userAuth.UserId + nsGroups, err := h.accountManager.ListNameServerGroups(r.Context(), accountID, userID) if err != nil { util.WriteError(r.Context(), err, w) @@ -69,13 +62,14 @@ func (h *nameserversHandler) getAllNameservers(w http.ResponseWriter, r *http.Re // createNameserverGroup handles nameserver group creation request func (h *nameserversHandler) createNameserverGroup(w http.ResponseWriter, r *http.Request) { - claims := h.claimsExtractor.FromRequestContext(r) - accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) if err != nil { util.WriteError(r.Context(), err, w) return } + accountID, userID := userAuth.AccountId, userAuth.UserId + var req api.PostApiDnsNameserversJSONRequestBody err = json.NewDecoder(r.Body).Decode(&req) if err != nil { @@ -102,13 +96,14 @@ func (h *nameserversHandler) createNameserverGroup(w http.ResponseWriter, r *htt // updateNameserverGroup handles update to a nameserver group identified by a given ID func (h *nameserversHandler) updateNameserverGroup(w http.ResponseWriter, r *http.Request) { - claims := h.claimsExtractor.FromRequestContext(r) - accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) if err != nil { util.WriteError(r.Context(), err, w) return } + accountID, userID := userAuth.AccountId, userAuth.UserId + nsGroupID := mux.Vars(r)["nsgroupId"] if len(nsGroupID) == 0 { util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid nameserver group ID"), w) @@ -153,13 +148,14 @@ func (h *nameserversHandler) updateNameserverGroup(w http.ResponseWriter, r *htt // deleteNameserverGroup handles nameserver group deletion request func (h *nameserversHandler) deleteNameserverGroup(w http.ResponseWriter, r *http.Request) { - claims := h.claimsExtractor.FromRequestContext(r) - accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) if err != nil { util.WriteError(r.Context(), err, w) return } + accountID, userID := userAuth.AccountId, userAuth.UserId + nsGroupID := mux.Vars(r)["nsgroupId"] if len(nsGroupID) == 0 { util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid nameserver group ID"), w) @@ -177,14 +173,14 @@ func (h *nameserversHandler) deleteNameserverGroup(w http.ResponseWriter, r *htt // getNameserverGroup handles a nameserver group Get request identified by ID func (h *nameserversHandler) getNameserverGroup(w http.ResponseWriter, r *http.Request) { - claims := h.claimsExtractor.FromRequestContext(r) - accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) if err != nil { - log.WithContext(r.Context()).Error(err) - http.Redirect(w, r, "/", http.StatusInternalServerError) + util.WriteError(r.Context(), err, w) return } + accountID, userID := userAuth.AccountId, userAuth.UserId + nsGroupID := mux.Vars(r)["nsgroupId"] if len(nsGroupID) == 0 { util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid nameserver group ID"), w) diff --git a/management/server/http/handlers/dns/nameservers_handler_test.go b/management/server/http/handlers/dns/nameservers_handler_test.go index c6561e4d826..45283bc377a 100644 --- a/management/server/http/handlers/dns/nameservers_handler_test.go +++ b/management/server/http/handlers/dns/nameservers_handler_test.go @@ -18,7 +18,8 @@ import ( "github.com/gorilla/mux" - "github.com/netbirdio/netbird/management/server/jwtclaims" + nbcontext "github.com/netbirdio/netbird/management/server/context" + "github.com/netbirdio/netbird/management/server/mock_server" ) @@ -81,19 +82,7 @@ func initNameserversTestData() *nameserversHandler { } return status.Errorf(status.NotFound, "nameserver group with ID %s was not found", nsGroupToSave.ID) }, - GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) { - return claims.AccountId, claims.UserId, nil - }, }, - claimsExtractor: jwtclaims.NewClaimsExtractor( - jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims { - return jwtclaims.AuthorizationClaims{ - UserId: "test_user", - Domain: "hotmail.com", - AccountId: testNSGroupAccountID, - } - }), - ), } } @@ -204,6 +193,11 @@ func TestNameserversHandlers(t *testing.T) { t.Run(tc.name, func(t *testing.T) { recorder := httptest.NewRecorder() req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody) + req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{ + UserId: "test_user", + AccountId: testNSGroupAccountID, + Domain: "hotmail.com", + }) router := mux.NewRouter() router.HandleFunc("/api/dns/nameservers/{nsgroupId}", p.getNameserverGroup).Methods("GET") diff --git a/management/server/http/handlers/events/events_handler.go b/management/server/http/handlers/events/events_handler.go index 62da5953524..0fb2295a839 100644 --- a/management/server/http/handlers/events/events_handler.go +++ b/management/server/http/handlers/events/events_handler.go @@ -10,44 +10,37 @@ import ( "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/activity" + nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/http/api" - "github.com/netbirdio/netbird/management/server/http/configs" "github.com/netbirdio/netbird/management/server/http/util" - "github.com/netbirdio/netbird/management/server/jwtclaims" ) // handler HTTP handler type handler struct { - accountManager server.AccountManager - claimsExtractor *jwtclaims.ClaimsExtractor + accountManager server.AccountManager } -func AddEndpoints(accountManager server.AccountManager, authCfg configs.AuthCfg, router *mux.Router) { - eventsHandler := newHandler(accountManager, authCfg) +func AddEndpoints(accountManager server.AccountManager, router *mux.Router) { + eventsHandler := newHandler(accountManager) router.HandleFunc("/events", eventsHandler.getAllEvents).Methods("GET", "OPTIONS") } // newHandler creates a new events handler -func newHandler(accountManager server.AccountManager, authCfg configs.AuthCfg) *handler { - return &handler{ - accountManager: accountManager, - claimsExtractor: jwtclaims.NewClaimsExtractor( - jwtclaims.WithAudience(authCfg.Audience), - jwtclaims.WithUserIDClaim(authCfg.UserIDClaim), - ), - } +func newHandler(accountManager server.AccountManager) *handler { + return &handler{accountManager: accountManager} } // getAllEvents list of the given account func (h *handler) getAllEvents(w http.ResponseWriter, r *http.Request) { - claims := h.claimsExtractor.FromRequestContext(r) - accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) if err != nil { log.WithContext(r.Context()).Error(err) http.Redirect(w, r, "/", http.StatusInternalServerError) return } + accountID, userID := userAuth.AccountId, userAuth.UserId + accountEvents, err := h.accountManager.GetEvents(r.Context(), accountID, userID) if err != nil { util.WriteError(r.Context(), err, w) diff --git a/management/server/http/handlers/events/events_handler_test.go b/management/server/http/handlers/events/events_handler_test.go index 17478aba351..5cdb4739aa7 100644 --- a/management/server/http/handlers/events/events_handler_test.go +++ b/management/server/http/handlers/events/events_handler_test.go @@ -13,9 +13,10 @@ import ( "github.com/gorilla/mux" "github.com/stretchr/testify/assert" + nbcontext "github.com/netbirdio/netbird/management/server/context" + "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/http/api" - "github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/mock_server" "github.com/netbirdio/netbird/management/server/types" ) @@ -29,22 +30,10 @@ func initEventsTestData(account string, events ...*activity.Event) *handler { } return []*activity.Event{}, nil }, - GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) { - return claims.AccountId, claims.UserId, nil - }, GetUsersFromAccountFunc: func(_ context.Context, accountID, userID string) ([]*types.UserInfo, error) { return make([]*types.UserInfo, 0), nil }, }, - claimsExtractor: jwtclaims.NewClaimsExtractor( - jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims { - return jwtclaims.AuthorizationClaims{ - UserId: "test_user", - Domain: "hotmail.com", - AccountId: "test_account", - } - }), - ), } } @@ -199,6 +188,11 @@ func TestEvents_GetEvents(t *testing.T) { t.Run(tc.name, func(t *testing.T) { recorder := httptest.NewRecorder() req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody) + req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{ + UserId: "test_user", + Domain: "hotmail.com", + AccountId: "test_account", + }) router := mux.NewRouter() router.HandleFunc("/api/events/", handler.getAllEvents).Methods("GET") diff --git a/management/server/http/handlers/groups/groups_handler.go b/management/server/http/handlers/groups/groups_handler.go index ec635a35800..040c08b87dd 100644 --- a/management/server/http/handlers/groups/groups_handler.go +++ b/management/server/http/handlers/groups/groups_handler.go @@ -7,24 +7,23 @@ import ( "github.com/gorilla/mux" log "github.com/sirupsen/logrus" + nbcontext "github.com/netbirdio/netbird/management/server/context" + nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/http/api" - "github.com/netbirdio/netbird/management/server/http/configs" "github.com/netbirdio/netbird/management/server/http/util" - "github.com/netbirdio/netbird/management/server/jwtclaims" - nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/management/server/types" ) // handler is a handler that returns groups of the account type handler struct { - accountManager server.AccountManager - claimsExtractor *jwtclaims.ClaimsExtractor + accountManager server.AccountManager } -func AddEndpoints(accountManager server.AccountManager, authCfg configs.AuthCfg, router *mux.Router) { - groupsHandler := newHandler(accountManager, authCfg) +func AddEndpoints(accountManager server.AccountManager, router *mux.Router) { + groupsHandler := newHandler(accountManager) router.HandleFunc("/groups", groupsHandler.getAllGroups).Methods("GET", "OPTIONS") router.HandleFunc("/groups", groupsHandler.createGroup).Methods("POST", "OPTIONS") router.HandleFunc("/groups/{groupId}", groupsHandler.updateGroup).Methods("PUT", "OPTIONS") @@ -33,25 +32,21 @@ func AddEndpoints(accountManager server.AccountManager, authCfg configs.AuthCfg, } // newHandler creates a new groups handler -func newHandler(accountManager server.AccountManager, authCfg configs.AuthCfg) *handler { +func newHandler(accountManager server.AccountManager) *handler { return &handler{ accountManager: accountManager, - claimsExtractor: jwtclaims.NewClaimsExtractor( - jwtclaims.WithAudience(authCfg.Audience), - jwtclaims.WithUserIDClaim(authCfg.UserIDClaim), - ), } } // getAllGroups list for the account func (h *handler) getAllGroups(w http.ResponseWriter, r *http.Request) { - claims := h.claimsExtractor.FromRequestContext(r) - accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) if err != nil { log.WithContext(r.Context()).Error(err) http.Redirect(w, r, "/", http.StatusInternalServerError) return } + accountID, userID := userAuth.AccountId, userAuth.UserId groups, err := h.accountManager.GetAllGroups(r.Context(), accountID, userID) if err != nil { @@ -75,13 +70,14 @@ func (h *handler) getAllGroups(w http.ResponseWriter, r *http.Request) { // updateGroup handles update to a group identified by a given ID func (h *handler) updateGroup(w http.ResponseWriter, r *http.Request) { - claims := h.claimsExtractor.FromRequestContext(r) - accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) if err != nil { util.WriteError(r.Context(), err, w) return } + accountID, userID := userAuth.AccountId, userAuth.UserId + vars := mux.Vars(r) groupID, ok := vars["groupId"] if !ok { @@ -164,13 +160,14 @@ func (h *handler) updateGroup(w http.ResponseWriter, r *http.Request) { // createGroup handles group creation request func (h *handler) createGroup(w http.ResponseWriter, r *http.Request) { - claims := h.claimsExtractor.FromRequestContext(r) - accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) if err != nil { util.WriteError(r.Context(), err, w) return } + accountID, userID := userAuth.AccountId, userAuth.UserId + var req api.PostApiGroupsJSONRequestBody err = json.NewDecoder(r.Body).Decode(&req) if err != nil { @@ -223,13 +220,14 @@ func (h *handler) createGroup(w http.ResponseWriter, r *http.Request) { // deleteGroup handles group deletion request func (h *handler) deleteGroup(w http.ResponseWriter, r *http.Request) { - claims := h.claimsExtractor.FromRequestContext(r) - accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) if err != nil { util.WriteError(r.Context(), err, w) return } + accountID, userID := userAuth.AccountId, userAuth.UserId + groupID := mux.Vars(r)["groupId"] if len(groupID) == 0 { util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid group ID"), w) @@ -253,12 +251,13 @@ func (h *handler) deleteGroup(w http.ResponseWriter, r *http.Request) { // getGroup returns a group func (h *handler) getGroup(w http.ResponseWriter, r *http.Request) { - claims := h.claimsExtractor.FromRequestContext(r) - accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) if err != nil { util.WriteError(r.Context(), err, w) return } + + accountID, userID := userAuth.AccountId, userAuth.UserId groupID := mux.Vars(r)["groupId"] if len(groupID) == 0 { util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid group ID"), w) diff --git a/management/server/http/handlers/groups/groups_handler_test.go b/management/server/http/handlers/groups/groups_handler_test.go index 0668982f31a..c4b9e46ab6d 100644 --- a/management/server/http/handlers/groups/groups_handler_test.go +++ b/management/server/http/handlers/groups/groups_handler_test.go @@ -18,9 +18,9 @@ import ( "golang.org/x/exp/maps" "github.com/netbirdio/netbird/management/server" + nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/management/server/http/util" - "github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/mock_server" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/status" @@ -59,9 +59,6 @@ func initGroupTestData(initGroups ...*types.Group) *handler { return group, nil }, - GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) { - return claims.AccountId, claims.UserId, nil - }, GetGroupByNameFunc: func(ctx context.Context, groupName, _ string) (*types.Group, error) { if groupName == "All" { return &types.Group{ID: "id-all", Name: "All", Issued: types.GroupIssuedAPI}, nil @@ -87,15 +84,6 @@ func initGroupTestData(initGroups ...*types.Group) *handler { return nil }, }, - claimsExtractor: jwtclaims.NewClaimsExtractor( - jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims { - return jwtclaims.AuthorizationClaims{ - UserId: "test_user", - Domain: "hotmail.com", - AccountId: "test_id", - } - }), - ), } } @@ -134,6 +122,11 @@ func TestGetGroup(t *testing.T) { t.Run(tc.name, func(t *testing.T) { recorder := httptest.NewRecorder() req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody) + req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{ + UserId: "test_user", + Domain: "hotmail.com", + AccountId: "test_id", + }) router := mux.NewRouter() router.HandleFunc("/api/groups/{groupId}", p.getGroup).Methods("GET") @@ -255,6 +248,11 @@ func TestWriteGroup(t *testing.T) { t.Run(tc.name, func(t *testing.T) { recorder := httptest.NewRecorder() req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody) + req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{ + UserId: "test_user", + Domain: "hotmail.com", + AccountId: "test_id", + }) router := mux.NewRouter() router.HandleFunc("/api/groups", p.createGroup).Methods("POST") @@ -332,7 +330,11 @@ func TestDeleteGroup(t *testing.T) { t.Run(tc.name, func(t *testing.T) { recorder := httptest.NewRecorder() req := httptest.NewRequest(tc.requestType, tc.requestPath, nil) - + req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{ + UserId: "test_user", + Domain: "hotmail.com", + AccountId: "test_id", + }) router := mux.NewRouter() router.HandleFunc("/api/groups/{groupId}", p.deleteGroup).Methods("DELETE") router.ServeHTTP(recorder, req) diff --git a/management/server/http/handlers/networks/handler.go b/management/server/http/handlers/networks/handler.go index 316b936115b..f4c884290fb 100644 --- a/management/server/http/handlers/networks/handler.go +++ b/management/server/http/handlers/networks/handler.go @@ -9,11 +9,10 @@ import ( "github.com/gorilla/mux" s "github.com/netbirdio/netbird/management/server" + nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/groups" "github.com/netbirdio/netbird/management/server/http/api" - "github.com/netbirdio/netbird/management/server/http/configs" "github.com/netbirdio/netbird/management/server/http/util" - "github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/networks" "github.com/netbirdio/netbird/management/server/networks/resources" "github.com/netbirdio/netbird/management/server/networks/routers" @@ -30,16 +29,14 @@ type handler struct { routerManager routers.Manager accountManager s.AccountManager - groupsManager groups.Manager - extractFromToken func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) - claimsExtractor *jwtclaims.ClaimsExtractor + groupsManager groups.Manager } -func AddEndpoints(networksManager networks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager groups.Manager, accountManager s.AccountManager, extractFromToken func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error), authCfg configs.AuthCfg, router *mux.Router) { - addRouterEndpoints(routerManager, extractFromToken, authCfg, router) - addResourceEndpoints(resourceManager, groupsManager, extractFromToken, authCfg, router) +func AddEndpoints(networksManager networks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager groups.Manager, accountManager s.AccountManager, router *mux.Router) { + addRouterEndpoints(routerManager, router) + addResourceEndpoints(resourceManager, groupsManager, router) - networksHandler := newHandler(networksManager, resourceManager, routerManager, groupsManager, accountManager, extractFromToken, authCfg) + networksHandler := newHandler(networksManager, resourceManager, routerManager, groupsManager, accountManager) router.HandleFunc("/networks", networksHandler.getAllNetworks).Methods("GET", "OPTIONS") router.HandleFunc("/networks", networksHandler.createNetwork).Methods("POST", "OPTIONS") router.HandleFunc("/networks/{networkId}", networksHandler.getNetwork).Methods("GET", "OPTIONS") @@ -47,29 +44,25 @@ func AddEndpoints(networksManager networks.Manager, resourceManager resources.Ma router.HandleFunc("/networks/{networkId}", networksHandler.deleteNetwork).Methods("DELETE", "OPTIONS") } -func newHandler(networksManager networks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager groups.Manager, accountManager s.AccountManager, extractFromToken func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error), authCfg configs.AuthCfg) *handler { +func newHandler(networksManager networks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager groups.Manager, accountManager s.AccountManager) *handler { return &handler{ - networksManager: networksManager, - resourceManager: resourceManager, - routerManager: routerManager, - groupsManager: groupsManager, - accountManager: accountManager, - extractFromToken: extractFromToken, - claimsExtractor: jwtclaims.NewClaimsExtractor( - jwtclaims.WithAudience(authCfg.Audience), - jwtclaims.WithUserIDClaim(authCfg.UserIDClaim), - ), + networksManager: networksManager, + resourceManager: resourceManager, + routerManager: routerManager, + groupsManager: groupsManager, + accountManager: accountManager, } } func (h *handler) getAllNetworks(w http.ResponseWriter, r *http.Request) { - claims := h.claimsExtractor.FromRequestContext(r) - accountID, userID, err := h.extractFromToken(r.Context(), claims) + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) if err != nil { util.WriteError(r.Context(), err, w) return } + accountID, userID := userAuth.AccountId, userAuth.UserId + networks, err := h.networksManager.GetAllNetworks(r.Context(), accountID, userID) if err != nil { util.WriteError(r.Context(), err, w) @@ -104,12 +97,12 @@ func (h *handler) getAllNetworks(w http.ResponseWriter, r *http.Request) { } func (h *handler) createNetwork(w http.ResponseWriter, r *http.Request) { - claims := h.claimsExtractor.FromRequestContext(r) - accountID, userID, err := h.extractFromToken(r.Context(), claims) + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) if err != nil { util.WriteError(r.Context(), err, w) return } + accountID, userID := userAuth.AccountId, userAuth.UserId var req api.NetworkRequest err = json.NewDecoder(r.Body).Decode(&req) @@ -140,12 +133,12 @@ func (h *handler) createNetwork(w http.ResponseWriter, r *http.Request) { } func (h *handler) getNetwork(w http.ResponseWriter, r *http.Request) { - claims := h.claimsExtractor.FromRequestContext(r) - accountID, userID, err := h.extractFromToken(r.Context(), claims) + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) if err != nil { util.WriteError(r.Context(), err, w) return } + accountID, userID := userAuth.AccountId, userAuth.UserId vars := mux.Vars(r) networkID := vars["networkId"] @@ -178,13 +171,13 @@ func (h *handler) getNetwork(w http.ResponseWriter, r *http.Request) { } func (h *handler) updateNetwork(w http.ResponseWriter, r *http.Request) { - claims := h.claimsExtractor.FromRequestContext(r) - accountID, userID, err := h.extractFromToken(r.Context(), claims) + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) if err != nil { util.WriteError(r.Context(), err, w) return } + accountID, userID := userAuth.AccountId, userAuth.UserId vars := mux.Vars(r) networkID := vars["networkId"] if len(networkID) == 0 { @@ -228,13 +221,13 @@ func (h *handler) updateNetwork(w http.ResponseWriter, r *http.Request) { } func (h *handler) deleteNetwork(w http.ResponseWriter, r *http.Request) { - claims := h.claimsExtractor.FromRequestContext(r) - accountID, userID, err := h.extractFromToken(r.Context(), claims) + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) if err != nil { util.WriteError(r.Context(), err, w) return } + accountID, userID := userAuth.AccountId, userAuth.UserId vars := mux.Vars(r) networkID := vars["networkId"] if len(networkID) == 0 { diff --git a/management/server/http/handlers/networks/resources_handler.go b/management/server/http/handlers/networks/resources_handler.go index f2dc8e3b86d..fba7026e8df 100644 --- a/management/server/http/handlers/networks/resources_handler.go +++ b/management/server/http/handlers/networks/resources_handler.go @@ -1,30 +1,26 @@ package networks import ( - "context" "encoding/json" "net/http" "github.com/gorilla/mux" + nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/groups" "github.com/netbirdio/netbird/management/server/http/api" - "github.com/netbirdio/netbird/management/server/http/configs" "github.com/netbirdio/netbird/management/server/http/util" - "github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/networks/resources" "github.com/netbirdio/netbird/management/server/networks/resources/types" ) type resourceHandler struct { - resourceManager resources.Manager - groupsManager groups.Manager - extractFromToken func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) - claimsExtractor *jwtclaims.ClaimsExtractor + resourceManager resources.Manager + groupsManager groups.Manager } -func addResourceEndpoints(resourcesManager resources.Manager, groupsManager groups.Manager, extractFromToken func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error), authCfg configs.AuthCfg, router *mux.Router) { - resourceHandler := newResourceHandler(resourcesManager, groupsManager, extractFromToken, authCfg) +func addResourceEndpoints(resourcesManager resources.Manager, groupsManager groups.Manager, router *mux.Router) { + resourceHandler := newResourceHandler(resourcesManager, groupsManager) router.HandleFunc("/networks/resources", resourceHandler.getAllResourcesInAccount).Methods("GET", "OPTIONS") router.HandleFunc("/networks/{networkId}/resources", resourceHandler.getAllResourcesInNetwork).Methods("GET", "OPTIONS") router.HandleFunc("/networks/{networkId}/resources", resourceHandler.createResource).Methods("POST", "OPTIONS") @@ -33,26 +29,21 @@ func addResourceEndpoints(resourcesManager resources.Manager, groupsManager grou router.HandleFunc("/networks/{networkId}/resources/{resourceId}", resourceHandler.deleteResource).Methods("DELETE", "OPTIONS") } -func newResourceHandler(resourceManager resources.Manager, groupsManager groups.Manager, extractFromToken func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error), authCfg configs.AuthCfg) *resourceHandler { +func newResourceHandler(resourceManager resources.Manager, groupsManager groups.Manager) *resourceHandler { return &resourceHandler{ - resourceManager: resourceManager, - groupsManager: groupsManager, - extractFromToken: extractFromToken, - claimsExtractor: jwtclaims.NewClaimsExtractor( - jwtclaims.WithAudience(authCfg.Audience), - jwtclaims.WithUserIDClaim(authCfg.UserIDClaim), - ), + resourceManager: resourceManager, + groupsManager: groupsManager, } } func (h *resourceHandler) getAllResourcesInNetwork(w http.ResponseWriter, r *http.Request) { - claims := h.claimsExtractor.FromRequestContext(r) - accountID, userID, err := h.extractFromToken(r.Context(), claims) + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) if err != nil { util.WriteError(r.Context(), err, w) return } + accountID, userID := userAuth.AccountId, userAuth.UserId networkID := mux.Vars(r)["networkId"] resources, err := h.resourceManager.GetAllResourcesInNetwork(r.Context(), accountID, userID, networkID) if err != nil { @@ -76,13 +67,14 @@ func (h *resourceHandler) getAllResourcesInNetwork(w http.ResponseWriter, r *htt util.WriteJSONObject(r.Context(), w, resourcesResponse) } func (h *resourceHandler) getAllResourcesInAccount(w http.ResponseWriter, r *http.Request) { - claims := h.claimsExtractor.FromRequestContext(r) - accountID, userID, err := h.extractFromToken(r.Context(), claims) + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) if err != nil { util.WriteError(r.Context(), err, w) return } + accountID, userID := userAuth.AccountId, userAuth.UserId + resources, err := h.resourceManager.GetAllResourcesInAccount(r.Context(), accountID, userID) if err != nil { util.WriteError(r.Context(), err, w) @@ -106,13 +98,14 @@ func (h *resourceHandler) getAllResourcesInAccount(w http.ResponseWriter, r *htt } func (h *resourceHandler) createResource(w http.ResponseWriter, r *http.Request) { - claims := h.claimsExtractor.FromRequestContext(r) - accountID, userID, err := h.extractFromToken(r.Context(), claims) + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) if err != nil { util.WriteError(r.Context(), err, w) return } + accountID, userID := userAuth.AccountId, userAuth.UserId + var req api.NetworkResourceRequest err = json.NewDecoder(r.Body).Decode(&req) if err != nil { @@ -144,13 +137,13 @@ func (h *resourceHandler) createResource(w http.ResponseWriter, r *http.Request) } func (h *resourceHandler) getResource(w http.ResponseWriter, r *http.Request) { - claims := h.claimsExtractor.FromRequestContext(r) - accountID, userID, err := h.extractFromToken(r.Context(), claims) + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) if err != nil { util.WriteError(r.Context(), err, w) return } + accountID, userID := userAuth.AccountId, userAuth.UserId networkID := mux.Vars(r)["networkId"] resourceID := mux.Vars(r)["resourceId"] resource, err := h.resourceManager.GetResource(r.Context(), accountID, userID, networkID, resourceID) @@ -171,13 +164,13 @@ func (h *resourceHandler) getResource(w http.ResponseWriter, r *http.Request) { } func (h *resourceHandler) updateResource(w http.ResponseWriter, r *http.Request) { - claims := h.claimsExtractor.FromRequestContext(r) - accountID, userID, err := h.extractFromToken(r.Context(), claims) + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) if err != nil { util.WriteError(r.Context(), err, w) return } + accountID, userID := userAuth.AccountId, userAuth.UserId var req api.NetworkResourceRequest err = json.NewDecoder(r.Body).Decode(&req) if err != nil { @@ -209,12 +202,12 @@ func (h *resourceHandler) updateResource(w http.ResponseWriter, r *http.Request) } func (h *resourceHandler) deleteResource(w http.ResponseWriter, r *http.Request) { - claims := h.claimsExtractor.FromRequestContext(r) - accountID, userID, err := h.extractFromToken(r.Context(), claims) + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) if err != nil { util.WriteError(r.Context(), err, w) return } + accountID, userID := userAuth.AccountId, userAuth.UserId networkID := mux.Vars(r)["networkId"] resourceID := mux.Vars(r)["resourceId"] diff --git a/management/server/http/handlers/networks/routers_handler.go b/management/server/http/handlers/networks/routers_handler.go index 7ca95d902f9..f98da49661f 100644 --- a/management/server/http/handlers/networks/routers_handler.go +++ b/management/server/http/handlers/networks/routers_handler.go @@ -1,28 +1,24 @@ package networks import ( - "context" "encoding/json" "net/http" "github.com/gorilla/mux" + nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/http/api" - "github.com/netbirdio/netbird/management/server/http/configs" "github.com/netbirdio/netbird/management/server/http/util" - "github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/networks/routers" "github.com/netbirdio/netbird/management/server/networks/routers/types" ) type routersHandler struct { - routersManager routers.Manager - extractFromToken func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) - claimsExtractor *jwtclaims.ClaimsExtractor + routersManager routers.Manager } -func addRouterEndpoints(routersManager routers.Manager, extractFromToken func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error), authCfg configs.AuthCfg, router *mux.Router) { - routersHandler := newRoutersHandler(routersManager, extractFromToken, authCfg) +func addRouterEndpoints(routersManager routers.Manager, router *mux.Router) { + routersHandler := newRoutersHandler(routersManager) router.HandleFunc("/networks/{networkId}/routers", routersHandler.getAllRouters).Methods("GET", "OPTIONS") router.HandleFunc("/networks/{networkId}/routers", routersHandler.createRouter).Methods("POST", "OPTIONS") router.HandleFunc("/networks/{networkId}/routers/{routerId}", routersHandler.getRouter).Methods("GET", "OPTIONS") @@ -30,25 +26,21 @@ func addRouterEndpoints(routersManager routers.Manager, extractFromToken func(ct router.HandleFunc("/networks/{networkId}/routers/{routerId}", routersHandler.deleteRouter).Methods("DELETE", "OPTIONS") } -func newRoutersHandler(routersManager routers.Manager, extractFromToken func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error), authCfg configs.AuthCfg) *routersHandler { +func newRoutersHandler(routersManager routers.Manager) *routersHandler { return &routersHandler{ - routersManager: routersManager, - extractFromToken: extractFromToken, - claimsExtractor: jwtclaims.NewClaimsExtractor( - jwtclaims.WithAudience(authCfg.Audience), - jwtclaims.WithUserIDClaim(authCfg.UserIDClaim), - ), + routersManager: routersManager, } } func (h *routersHandler) getAllRouters(w http.ResponseWriter, r *http.Request) { - claims := h.claimsExtractor.FromRequestContext(r) - accountID, userID, err := h.extractFromToken(r.Context(), claims) + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) if err != nil { util.WriteError(r.Context(), err, w) return } + accountID, userID := userAuth.AccountId, userAuth.UserId + networkID := mux.Vars(r)["networkId"] routers, err := h.routersManager.GetAllRoutersInNetwork(r.Context(), accountID, userID, networkID) if err != nil { @@ -65,13 +57,14 @@ func (h *routersHandler) getAllRouters(w http.ResponseWriter, r *http.Request) { } func (h *routersHandler) createRouter(w http.ResponseWriter, r *http.Request) { - claims := h.claimsExtractor.FromRequestContext(r) - accountID, userID, err := h.extractFromToken(r.Context(), claims) + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) if err != nil { util.WriteError(r.Context(), err, w) return } + accountID, userID := userAuth.AccountId, userAuth.UserId + networkID := mux.Vars(r)["networkId"] var req api.NetworkRouterRequest err = json.NewDecoder(r.Body).Decode(&req) @@ -96,13 +89,14 @@ func (h *routersHandler) createRouter(w http.ResponseWriter, r *http.Request) { } func (h *routersHandler) getRouter(w http.ResponseWriter, r *http.Request) { - claims := h.claimsExtractor.FromRequestContext(r) - accountID, userID, err := h.extractFromToken(r.Context(), claims) + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) if err != nil { util.WriteError(r.Context(), err, w) return } + accountID, userID := userAuth.AccountId, userAuth.UserId + routerID := mux.Vars(r)["routerId"] networkID := mux.Vars(r)["networkId"] router, err := h.routersManager.GetRouter(r.Context(), accountID, userID, networkID, routerID) @@ -115,13 +109,14 @@ func (h *routersHandler) getRouter(w http.ResponseWriter, r *http.Request) { } func (h *routersHandler) updateRouter(w http.ResponseWriter, r *http.Request) { - claims := h.claimsExtractor.FromRequestContext(r) - accountID, userID, err := h.extractFromToken(r.Context(), claims) + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) if err != nil { util.WriteError(r.Context(), err, w) return } + accountID, userID := userAuth.AccountId, userAuth.UserId + var req api.NetworkRouterRequest err = json.NewDecoder(r.Body).Decode(&req) if err != nil { @@ -146,13 +141,13 @@ func (h *routersHandler) updateRouter(w http.ResponseWriter, r *http.Request) { } func (h *routersHandler) deleteRouter(w http.ResponseWriter, r *http.Request) { - claims := h.claimsExtractor.FromRequestContext(r) - accountID, userID, err := h.extractFromToken(r.Context(), claims) + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) if err != nil { util.WriteError(r.Context(), err, w) return } + accountID, userID := userAuth.AccountId, userAuth.UserId routerID := mux.Vars(r)["routerId"] networkID := mux.Vars(r)["networkId"] err = h.routersManager.DeleteRouter(r.Context(), accountID, userID, networkID, routerID) diff --git a/management/server/http/handlers/peers/peers_handler.go b/management/server/http/handlers/peers/peers_handler.go index cdd8026f257..1b8fb7e4f2a 100644 --- a/management/server/http/handlers/peers/peers_handler.go +++ b/management/server/http/handlers/peers/peers_handler.go @@ -10,11 +10,10 @@ import ( log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/management/server" + nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/groups" "github.com/netbirdio/netbird/management/server/http/api" - "github.com/netbirdio/netbird/management/server/http/configs" "github.com/netbirdio/netbird/management/server/http/util" - "github.com/netbirdio/netbird/management/server/jwtclaims" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/management/server/types" @@ -22,12 +21,11 @@ import ( // Handler is a handler that returns peers of the account type Handler struct { - accountManager server.AccountManager - claimsExtractor *jwtclaims.ClaimsExtractor + accountManager server.AccountManager } -func AddEndpoints(accountManager server.AccountManager, authCfg configs.AuthCfg, router *mux.Router) { - peersHandler := NewHandler(accountManager, authCfg) +func AddEndpoints(accountManager server.AccountManager, router *mux.Router) { + peersHandler := NewHandler(accountManager) router.HandleFunc("/peers", peersHandler.GetAllPeers).Methods("GET", "OPTIONS") router.HandleFunc("/peers/{peerId}", peersHandler.HandlePeer). Methods("GET", "PUT", "DELETE", "OPTIONS") @@ -35,13 +33,9 @@ func AddEndpoints(accountManager server.AccountManager, authCfg configs.AuthCfg, } // NewHandler creates a new peers Handler -func NewHandler(accountManager server.AccountManager, authCfg configs.AuthCfg) *Handler { +func NewHandler(accountManager server.AccountManager) *Handler { return &Handler{ accountManager: accountManager, - claimsExtractor: jwtclaims.NewClaimsExtractor( - jwtclaims.WithAudience(authCfg.Audience), - jwtclaims.WithUserIDClaim(authCfg.UserIDClaim), - ), } } @@ -149,12 +143,13 @@ func (h *Handler) deletePeer(ctx context.Context, accountID, userID string, peer // HandlePeer handles all peer requests for GET, PUT and DELETE operations func (h *Handler) HandlePeer(w http.ResponseWriter, r *http.Request) { - claims := h.claimsExtractor.FromRequestContext(r) - accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) if err != nil { util.WriteError(r.Context(), err, w) return } + + accountID, userID := userAuth.AccountId, userAuth.UserId vars := mux.Vars(r) peerID := vars["peerId"] if len(peerID) == 0 { @@ -179,13 +174,14 @@ func (h *Handler) HandlePeer(w http.ResponseWriter, r *http.Request) { // GetAllPeers returns a list of all peers associated with a provided account func (h *Handler) GetAllPeers(w http.ResponseWriter, r *http.Request) { - claims := h.claimsExtractor.FromRequestContext(r) - accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) if err != nil { util.WriteError(r.Context(), err, w) return } + accountID, userID := userAuth.AccountId, userAuth.UserId + peers, err := h.accountManager.GetPeers(r.Context(), accountID, userID) if err != nil { util.WriteError(r.Context(), err, w) @@ -230,13 +226,14 @@ func (h *Handler) setApprovalRequiredFlag(respBody []*api.PeerBatch, approvedPee // GetAccessiblePeers returns a list of all peers that the specified peer can connect to within the network. func (h *Handler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request) { - claims := h.claimsExtractor.FromRequestContext(r) - accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) if err != nil { util.WriteError(r.Context(), err, w) return } + accountID, userID := userAuth.AccountId, userAuth.UserId + vars := mux.Vars(r) peerID := vars["peerId"] if len(peerID) == 0 { diff --git a/management/server/http/handlers/peers/peers_handler_test.go b/management/server/http/handlers/peers/peers_handler_test.go index 16065a677a7..63b8c0ab3bd 100644 --- a/management/server/http/handlers/peers/peers_handler_test.go +++ b/management/server/http/handlers/peers/peers_handler_test.go @@ -15,8 +15,8 @@ import ( "github.com/gorilla/mux" "golang.org/x/exp/maps" + nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/http/api" - "github.com/netbirdio/netbird/management/server/jwtclaims" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/types" @@ -25,16 +25,13 @@ import ( "github.com/netbirdio/netbird/management/server/mock_server" ) -type ctxKey string - const ( testPeerID = "test_peer" noUpdateChannelTestPeerID = "no-update-channel" - adminUser = "admin_user" - regularUser = "regular_user" - serviceUser = "service_user" - userIDKey ctxKey = "user_id" + adminUser = "admin_user" + regularUser = "regular_user" + serviceUser = "service_user" ) func initTestMetaData(peers ...*nbpeer.Peer) *Handler { @@ -146,9 +143,6 @@ func initTestMetaData(peers ...*nbpeer.Peer) *Handler { GetDNSDomainFunc: func() string { return "netbird.selfhosted" }, - GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) { - return claims.AccountId, claims.UserId, nil - }, GetAccountFunc: func(ctx context.Context, accountID string) (*types.Account, error) { return account, nil }, @@ -167,16 +161,6 @@ func initTestMetaData(peers ...*nbpeer.Peer) *Handler { return ok }, }, - claimsExtractor: jwtclaims.NewClaimsExtractor( - jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims { - userID := r.Context().Value(userIDKey).(string) - return jwtclaims.AuthorizationClaims{ - UserId: userID, - Domain: "hotmail.com", - AccountId: "test_id", - } - }), - ), } } @@ -267,8 +251,11 @@ func TestGetPeers(t *testing.T) { recorder := httptest.NewRecorder() req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody) - ctx := context.WithValue(context.Background(), userIDKey, "admin_user") - req = req.WithContext(ctx) + req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{ + UserId: "admin_user", + Domain: "hotmail.com", + AccountId: "test_id", + }) router := mux.NewRouter() router.HandleFunc("/api/peers/", p.GetAllPeers).Methods("GET") @@ -412,8 +399,11 @@ func TestGetAccessiblePeers(t *testing.T) { recorder := httptest.NewRecorder() req := httptest.NewRequest(http.MethodGet, fmt.Sprintf("/api/peers/%s/accessible-peers", tc.peerID), nil) - ctx := context.WithValue(context.Background(), userIDKey, tc.callerUserID) - req = req.WithContext(ctx) + req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{ + UserId: tc.callerUserID, + Domain: "hotmail.com", + AccountId: "test_id", + }) router := mux.NewRouter() router.HandleFunc("/api/peers/{peerId}/accessible-peers", p.GetAccessiblePeers).Methods("GET") diff --git a/management/server/http/handlers/policies/geolocation_handler_test.go b/management/server/http/handlers/policies/geolocation_handler_test.go index fc5839baaab..fbdc324d650 100644 --- a/management/server/http/handlers/policies/geolocation_handler_test.go +++ b/management/server/http/handlers/policies/geolocation_handler_test.go @@ -13,9 +13,9 @@ import ( "github.com/gorilla/mux" "github.com/stretchr/testify/assert" + nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/geolocation" "github.com/netbirdio/netbird/management/server/http/api" - "github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/mock_server" "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/util" @@ -43,23 +43,11 @@ func initGeolocationTestData(t *testing.T) *geolocationsHandler { return &geolocationsHandler{ accountManager: &mock_server.MockAccountManager{ - GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) { - return claims.AccountId, claims.UserId, nil - }, GetUserByIDFunc: func(ctx context.Context, id string) (*types.User, error) { return types.NewAdminUser(id), nil }, }, geolocationManager: geo, - claimsExtractor: jwtclaims.NewClaimsExtractor( - jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims { - return jwtclaims.AuthorizationClaims{ - UserId: "test_user", - Domain: "hotmail.com", - AccountId: "test_id", - } - }), - ), } } @@ -112,6 +100,11 @@ func TestGetCitiesByCountry(t *testing.T) { t.Run(tc.name, func(t *testing.T) { recorder := httptest.NewRecorder() req := httptest.NewRequest(tc.requestType, tc.requestPath, nil) + req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{ + UserId: "test_user", + Domain: "hotmail.com", + AccountId: "test_id", + }) router := mux.NewRouter() router.HandleFunc("/api/locations/countries/{country}/cities", geolocationHandler.getCitiesByCountry).Methods("GET") @@ -200,6 +193,11 @@ func TestGetAllCountries(t *testing.T) { t.Run(tc.name, func(t *testing.T) { recorder := httptest.NewRecorder() req := httptest.NewRequest(tc.requestType, tc.requestPath, nil) + req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{ + UserId: "test_user", + Domain: "hotmail.com", + AccountId: "test_id", + }) router := mux.NewRouter() router.HandleFunc("/api/locations/countries", geolocationHandler.getAllCountries).Methods("GET") diff --git a/management/server/http/handlers/policies/geolocations_handler.go b/management/server/http/handlers/policies/geolocations_handler.go index 161d974022a..c4868f879d0 100644 --- a/management/server/http/handlers/policies/geolocations_handler.go +++ b/management/server/http/handlers/policies/geolocations_handler.go @@ -7,11 +7,10 @@ import ( "github.com/gorilla/mux" "github.com/netbirdio/netbird/management/server" + nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/geolocation" "github.com/netbirdio/netbird/management/server/http/api" - "github.com/netbirdio/netbird/management/server/http/configs" "github.com/netbirdio/netbird/management/server/http/util" - "github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/status" ) @@ -23,24 +22,19 @@ var ( type geolocationsHandler struct { accountManager server.AccountManager geolocationManager geolocation.Geolocation - claimsExtractor *jwtclaims.ClaimsExtractor } -func addLocationsEndpoint(accountManager server.AccountManager, locationManager geolocation.Geolocation, authCfg configs.AuthCfg, router *mux.Router) { - locationHandler := newGeolocationsHandlerHandler(accountManager, locationManager, authCfg) +func addLocationsEndpoint(accountManager server.AccountManager, locationManager geolocation.Geolocation, router *mux.Router) { + locationHandler := newGeolocationsHandlerHandler(accountManager, locationManager) router.HandleFunc("/locations/countries", locationHandler.getAllCountries).Methods("GET", "OPTIONS") router.HandleFunc("/locations/countries/{country}/cities", locationHandler.getCitiesByCountry).Methods("GET", "OPTIONS") } // newGeolocationsHandlerHandler creates a new Geolocations handler -func newGeolocationsHandlerHandler(accountManager server.AccountManager, geolocationManager geolocation.Geolocation, authCfg configs.AuthCfg) *geolocationsHandler { +func newGeolocationsHandlerHandler(accountManager server.AccountManager, geolocationManager geolocation.Geolocation) *geolocationsHandler { return &geolocationsHandler{ accountManager: accountManager, geolocationManager: geolocationManager, - claimsExtractor: jwtclaims.NewClaimsExtractor( - jwtclaims.WithAudience(authCfg.Audience), - jwtclaims.WithUserIDClaim(authCfg.UserIDClaim), - ), } } @@ -104,12 +98,13 @@ func (l *geolocationsHandler) getCitiesByCountry(w http.ResponseWriter, r *http. } func (l *geolocationsHandler) authenticateUser(r *http.Request) error { - claims := l.claimsExtractor.FromRequestContext(r) - _, userID, err := l.accountManager.GetAccountIDFromToken(r.Context(), claims) + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) if err != nil { return err } + _, userID := userAuth.AccountId, userAuth.UserId + user, err := l.accountManager.GetUserByID(r.Context(), userID) if err != nil { return err diff --git a/management/server/http/handlers/policies/policies_handler.go b/management/server/http/handlers/policies/policies_handler.go index a748e73b8ed..63fc8a03b87 100644 --- a/management/server/http/handlers/policies/policies_handler.go +++ b/management/server/http/handlers/policies/policies_handler.go @@ -8,51 +8,46 @@ import ( "github.com/gorilla/mux" "github.com/netbirdio/netbird/management/server" + nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/geolocation" "github.com/netbirdio/netbird/management/server/http/api" - "github.com/netbirdio/netbird/management/server/http/configs" "github.com/netbirdio/netbird/management/server/http/util" - "github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/management/server/types" ) // handler is a handler that returns policy of the account type handler struct { - accountManager server.AccountManager - claimsExtractor *jwtclaims.ClaimsExtractor + accountManager server.AccountManager } -func AddEndpoints(accountManager server.AccountManager, locationManager geolocation.Geolocation, authCfg configs.AuthCfg, router *mux.Router) { - policiesHandler := newHandler(accountManager, authCfg) +func AddEndpoints(accountManager server.AccountManager, locationManager geolocation.Geolocation, router *mux.Router) { + policiesHandler := newHandler(accountManager) router.HandleFunc("/policies", policiesHandler.getAllPolicies).Methods("GET", "OPTIONS") router.HandleFunc("/policies", policiesHandler.createPolicy).Methods("POST", "OPTIONS") router.HandleFunc("/policies/{policyId}", policiesHandler.updatePolicy).Methods("PUT", "OPTIONS") router.HandleFunc("/policies/{policyId}", policiesHandler.getPolicy).Methods("GET", "OPTIONS") router.HandleFunc("/policies/{policyId}", policiesHandler.deletePolicy).Methods("DELETE", "OPTIONS") - addPostureCheckEndpoint(accountManager, locationManager, authCfg, router) + addPostureCheckEndpoint(accountManager, locationManager, router) } // newHandler creates a new policies handler -func newHandler(accountManager server.AccountManager, authCfg configs.AuthCfg) *handler { +func newHandler(accountManager server.AccountManager) *handler { return &handler{ accountManager: accountManager, - claimsExtractor: jwtclaims.NewClaimsExtractor( - jwtclaims.WithAudience(authCfg.Audience), - jwtclaims.WithUserIDClaim(authCfg.UserIDClaim), - ), } } // getAllPolicies list for the account func (h *handler) getAllPolicies(w http.ResponseWriter, r *http.Request) { - claims := h.claimsExtractor.FromRequestContext(r) - accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) if err != nil { util.WriteError(r.Context(), err, w) return } + accountID, userID := userAuth.AccountId, userAuth.UserId + listPolicies, err := h.accountManager.ListPolicies(r.Context(), accountID, userID) if err != nil { util.WriteError(r.Context(), err, w) @@ -80,13 +75,14 @@ func (h *handler) getAllPolicies(w http.ResponseWriter, r *http.Request) { // updatePolicy handles update to a policy identified by a given ID func (h *handler) updatePolicy(w http.ResponseWriter, r *http.Request) { - claims := h.claimsExtractor.FromRequestContext(r) - accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) if err != nil { util.WriteError(r.Context(), err, w) return } + accountID, userID := userAuth.AccountId, userAuth.UserId + vars := mux.Vars(r) policyID := vars["policyId"] if len(policyID) == 0 { @@ -105,13 +101,14 @@ func (h *handler) updatePolicy(w http.ResponseWriter, r *http.Request) { // createPolicy handles policy creation request func (h *handler) createPolicy(w http.ResponseWriter, r *http.Request) { - claims := h.claimsExtractor.FromRequestContext(r) - accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) if err != nil { util.WriteError(r.Context(), err, w) return } + accountID, userID := userAuth.AccountId, userAuth.UserId + h.savePolicy(w, r, accountID, userID, "") } @@ -306,13 +303,13 @@ func (h *handler) savePolicy(w http.ResponseWriter, r *http.Request, accountID s // deletePolicy handles policy deletion request func (h *handler) deletePolicy(w http.ResponseWriter, r *http.Request) { - claims := h.claimsExtractor.FromRequestContext(r) - accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) if err != nil { util.WriteError(r.Context(), err, w) return } + accountID, userID := userAuth.AccountId, userAuth.UserId vars := mux.Vars(r) policyID := vars["policyId"] if len(policyID) == 0 { @@ -330,13 +327,14 @@ func (h *handler) deletePolicy(w http.ResponseWriter, r *http.Request) { // getPolicy handles a group Get request identified by ID func (h *handler) getPolicy(w http.ResponseWriter, r *http.Request) { - claims := h.claimsExtractor.FromRequestContext(r) - accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) if err != nil { util.WriteError(r.Context(), err, w) return } + accountID, userID := userAuth.AccountId, userAuth.UserId + vars := mux.Vars(r) policyID := vars["policyId"] if len(policyID) == 0 { diff --git a/management/server/http/handlers/policies/policies_handler_test.go b/management/server/http/handlers/policies/policies_handler_test.go index 8fbf84d4b09..6450295eb3d 100644 --- a/management/server/http/handlers/policies/policies_handler_test.go +++ b/management/server/http/handlers/policies/policies_handler_test.go @@ -13,8 +13,8 @@ import ( "github.com/gorilla/mux" "github.com/stretchr/testify/assert" + nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/http/api" - "github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/mock_server" "github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/management/server/types" @@ -44,9 +44,6 @@ func initPoliciesTestData(policies ...*types.Policy) *handler { GetAllGroupsFunc: func(ctx context.Context, accountID, userID string) ([]*types.Group, error) { return []*types.Group{{ID: "F"}, {ID: "G"}}, nil }, - GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) { - return claims.AccountId, claims.UserId, nil - }, GetAccountByIDFunc: func(ctx context.Context, accountID string, userID string) (*types.Account, error) { user := types.NewAdminUser(userID) return &types.Account{ @@ -65,15 +62,6 @@ func initPoliciesTestData(policies ...*types.Policy) *handler { }, nil }, }, - claimsExtractor: jwtclaims.NewClaimsExtractor( - jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims { - return jwtclaims.AuthorizationClaims{ - UserId: "test_user", - Domain: "hotmail.com", - AccountId: "test_id", - } - }), - ), } } @@ -115,6 +103,11 @@ func TestPoliciesGetPolicy(t *testing.T) { t.Run(tc.name, func(t *testing.T) { recorder := httptest.NewRecorder() req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody) + req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{ + UserId: "test_user", + Domain: "hotmail.com", + AccountId: "test_id", + }) router := mux.NewRouter() router.HandleFunc("/api/policies/{policyId}", p.getPolicy).Methods("GET") @@ -274,6 +267,11 @@ func TestPoliciesWritePolicy(t *testing.T) { t.Run(tc.name, func(t *testing.T) { recorder := httptest.NewRecorder() req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody) + req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{ + UserId: "test_user", + Domain: "hotmail.com", + AccountId: "test_id", + }) router := mux.NewRouter() router.HandleFunc("/api/policies", p.createPolicy).Methods("POST") diff --git a/management/server/http/handlers/policies/posture_checks_handler.go b/management/server/http/handlers/policies/posture_checks_handler.go index ce0d4878c92..e6e58da58c6 100644 --- a/management/server/http/handlers/policies/posture_checks_handler.go +++ b/management/server/http/handlers/policies/posture_checks_handler.go @@ -7,11 +7,10 @@ import ( "github.com/gorilla/mux" "github.com/netbirdio/netbird/management/server" + nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/geolocation" "github.com/netbirdio/netbird/management/server/http/api" - "github.com/netbirdio/netbird/management/server/http/configs" "github.com/netbirdio/netbird/management/server/http/util" - "github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/posture" "github.com/netbirdio/netbird/management/server/status" ) @@ -20,40 +19,35 @@ import ( type postureChecksHandler struct { accountManager server.AccountManager geolocationManager geolocation.Geolocation - claimsExtractor *jwtclaims.ClaimsExtractor } -func addPostureCheckEndpoint(accountManager server.AccountManager, locationManager geolocation.Geolocation, authCfg configs.AuthCfg, router *mux.Router) { - postureCheckHandler := newPostureChecksHandler(accountManager, locationManager, authCfg) +func addPostureCheckEndpoint(accountManager server.AccountManager, locationManager geolocation.Geolocation, router *mux.Router) { + postureCheckHandler := newPostureChecksHandler(accountManager, locationManager) router.HandleFunc("/posture-checks", postureCheckHandler.getAllPostureChecks).Methods("GET", "OPTIONS") router.HandleFunc("/posture-checks", postureCheckHandler.createPostureCheck).Methods("POST", "OPTIONS") router.HandleFunc("/posture-checks/{postureCheckId}", postureCheckHandler.updatePostureCheck).Methods("PUT", "OPTIONS") router.HandleFunc("/posture-checks/{postureCheckId}", postureCheckHandler.getPostureCheck).Methods("GET", "OPTIONS") router.HandleFunc("/posture-checks/{postureCheckId}", postureCheckHandler.deletePostureCheck).Methods("DELETE", "OPTIONS") - addLocationsEndpoint(accountManager, locationManager, authCfg, router) + addLocationsEndpoint(accountManager, locationManager, router) } // newPostureChecksHandler creates a new PostureChecks handler -func newPostureChecksHandler(accountManager server.AccountManager, geolocationManager geolocation.Geolocation, authCfg configs.AuthCfg) *postureChecksHandler { +func newPostureChecksHandler(accountManager server.AccountManager, geolocationManager geolocation.Geolocation) *postureChecksHandler { return &postureChecksHandler{ accountManager: accountManager, geolocationManager: geolocationManager, - claimsExtractor: jwtclaims.NewClaimsExtractor( - jwtclaims.WithAudience(authCfg.Audience), - jwtclaims.WithUserIDClaim(authCfg.UserIDClaim), - ), } } // getAllPostureChecks list for the account func (p *postureChecksHandler) getAllPostureChecks(w http.ResponseWriter, r *http.Request) { - claims := p.claimsExtractor.FromRequestContext(r) - accountID, userID, err := p.accountManager.GetAccountIDFromToken(r.Context(), claims) + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) if err != nil { util.WriteError(r.Context(), err, w) return } + accountID, userID := userAuth.AccountId, userAuth.UserId listPostureChecks, err := p.accountManager.ListPostureChecks(r.Context(), accountID, userID) if err != nil { util.WriteError(r.Context(), err, w) @@ -70,13 +64,14 @@ func (p *postureChecksHandler) getAllPostureChecks(w http.ResponseWriter, r *htt // updatePostureCheck handles update to a posture check identified by a given ID func (p *postureChecksHandler) updatePostureCheck(w http.ResponseWriter, r *http.Request) { - claims := p.claimsExtractor.FromRequestContext(r) - accountID, userID, err := p.accountManager.GetAccountIDFromToken(r.Context(), claims) + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) if err != nil { util.WriteError(r.Context(), err, w) return } + accountID, userID := userAuth.AccountId, userAuth.UserId + vars := mux.Vars(r) postureChecksID := vars["postureCheckId"] if len(postureChecksID) == 0 { @@ -95,25 +90,26 @@ func (p *postureChecksHandler) updatePostureCheck(w http.ResponseWriter, r *http // createPostureCheck handles posture check creation request func (p *postureChecksHandler) createPostureCheck(w http.ResponseWriter, r *http.Request) { - claims := p.claimsExtractor.FromRequestContext(r) - accountID, userID, err := p.accountManager.GetAccountIDFromToken(r.Context(), claims) + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) if err != nil { util.WriteError(r.Context(), err, w) return } + accountID, userID := userAuth.AccountId, userAuth.UserId + p.savePostureChecks(w, r, accountID, userID, "") } // getPostureCheck handles a posture check Get request identified by ID func (p *postureChecksHandler) getPostureCheck(w http.ResponseWriter, r *http.Request) { - claims := p.claimsExtractor.FromRequestContext(r) - accountID, userID, err := p.accountManager.GetAccountIDFromToken(r.Context(), claims) + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) if err != nil { util.WriteError(r.Context(), err, w) return } + accountID, userID := userAuth.AccountId, userAuth.UserId vars := mux.Vars(r) postureChecksID := vars["postureCheckId"] if len(postureChecksID) == 0 { @@ -132,13 +128,13 @@ func (p *postureChecksHandler) getPostureCheck(w http.ResponseWriter, r *http.Re // deletePostureCheck handles posture check deletion request func (p *postureChecksHandler) deletePostureCheck(w http.ResponseWriter, r *http.Request) { - claims := p.claimsExtractor.FromRequestContext(r) - accountID, userID, err := p.accountManager.GetAccountIDFromToken(r.Context(), claims) + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) if err != nil { util.WriteError(r.Context(), err, w) return } + accountID, userID := userAuth.AccountId, userAuth.UserId vars := mux.Vars(r) postureChecksID := vars["postureCheckId"] if len(postureChecksID) == 0 { diff --git a/management/server/http/handlers/policies/posture_checks_handler_test.go b/management/server/http/handlers/policies/posture_checks_handler_test.go index 237687fd4a4..e3844caa206 100644 --- a/management/server/http/handlers/policies/posture_checks_handler_test.go +++ b/management/server/http/handlers/policies/posture_checks_handler_test.go @@ -14,9 +14,9 @@ import ( "github.com/gorilla/mux" "github.com/stretchr/testify/assert" + nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/geolocation" "github.com/netbirdio/netbird/management/server/http/api" - "github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/mock_server" "github.com/netbirdio/netbird/management/server/posture" "github.com/netbirdio/netbird/management/server/status" @@ -66,20 +66,8 @@ func initPostureChecksTestData(postureChecks ...*posture.Checks) *postureChecksH } return accountPostureChecks, nil }, - GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) { - return claims.AccountId, claims.UserId, nil - }, }, geolocationManager: &geolocation.Mock{}, - claimsExtractor: jwtclaims.NewClaimsExtractor( - jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims { - return jwtclaims.AuthorizationClaims{ - UserId: "test_user", - Domain: "hotmail.com", - AccountId: "test_id", - } - }), - ), } } @@ -187,6 +175,11 @@ func TestGetPostureCheck(t *testing.T) { t.Run(tc.name, func(t *testing.T) { recorder := httptest.NewRecorder() req := httptest.NewRequest(http.MethodGet, "/api/posture-checks/"+tc.id, tc.requestBody) + req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{ + UserId: "test_user", + Domain: "hotmail.com", + AccountId: "test_id", + }) router := mux.NewRouter() router.HandleFunc("/api/posture-checks/{postureCheckId}", p.getPostureCheck).Methods("GET") @@ -835,6 +828,11 @@ func TestPostureCheckUpdate(t *testing.T) { t.Run(tc.name, func(t *testing.T) { recorder := httptest.NewRecorder() req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody) + req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{ + UserId: "test_user", + Domain: "hotmail.com", + AccountId: "test_id", + }) defaultHandler := *p if tc.setupHandlerFunc != nil { diff --git a/management/server/http/handlers/routes/routes_handler.go b/management/server/http/handlers/routes/routes_handler.go index a29ba45629d..a0fbfda53ef 100644 --- a/management/server/http/handlers/routes/routes_handler.go +++ b/management/server/http/handlers/routes/routes_handler.go @@ -13,10 +13,9 @@ import ( "github.com/netbirdio/netbird/management/domain" "github.com/netbirdio/netbird/management/server" + nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/http/api" - "github.com/netbirdio/netbird/management/server/http/configs" "github.com/netbirdio/netbird/management/server/http/util" - "github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/route" ) @@ -26,12 +25,11 @@ const failedToConvertRoute = "failed to convert route to response: %v" // handler is the routes handler of the account type handler struct { - accountManager server.AccountManager - claimsExtractor *jwtclaims.ClaimsExtractor + accountManager server.AccountManager } -func AddEndpoints(accountManager server.AccountManager, authCfg configs.AuthCfg, router *mux.Router) { - routesHandler := newHandler(accountManager, authCfg) +func AddEndpoints(accountManager server.AccountManager, router *mux.Router) { + routesHandler := newHandler(accountManager) router.HandleFunc("/routes", routesHandler.getAllRoutes).Methods("GET", "OPTIONS") router.HandleFunc("/routes", routesHandler.createRoute).Methods("POST", "OPTIONS") router.HandleFunc("/routes/{routeId}", routesHandler.updateRoute).Methods("PUT", "OPTIONS") @@ -40,25 +38,22 @@ func AddEndpoints(accountManager server.AccountManager, authCfg configs.AuthCfg, } // newHandler returns a new instance of routes handler -func newHandler(accountManager server.AccountManager, authCfg configs.AuthCfg) *handler { +func newHandler(accountManager server.AccountManager) *handler { return &handler{ accountManager: accountManager, - claimsExtractor: jwtclaims.NewClaimsExtractor( - jwtclaims.WithAudience(authCfg.Audience), - jwtclaims.WithUserIDClaim(authCfg.UserIDClaim), - ), } } // getAllRoutes returns the list of routes for the account func (h *handler) getAllRoutes(w http.ResponseWriter, r *http.Request) { - claims := h.claimsExtractor.FromRequestContext(r) - accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) if err != nil { util.WriteError(r.Context(), err, w) return } + accountID, userID := userAuth.AccountId, userAuth.UserId + routes, err := h.accountManager.ListRoutes(r.Context(), accountID, userID) if err != nil { util.WriteError(r.Context(), err, w) @@ -79,13 +74,14 @@ func (h *handler) getAllRoutes(w http.ResponseWriter, r *http.Request) { // createRoute handles route creation request func (h *handler) createRoute(w http.ResponseWriter, r *http.Request) { - claims := h.claimsExtractor.FromRequestContext(r) - accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) if err != nil { util.WriteError(r.Context(), err, w) return } + accountID, userID := userAuth.AccountId, userAuth.UserId + var req api.PostApiRoutesJSONRequestBody err = json.NewDecoder(r.Body).Decode(&req) if err != nil { @@ -176,13 +172,13 @@ func (h *handler) validateRoute(req api.PostApiRoutesJSONRequestBody) error { // updateRoute handles update to a route identified by a given ID func (h *handler) updateRoute(w http.ResponseWriter, r *http.Request) { - claims := h.claimsExtractor.FromRequestContext(r) - accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) if err != nil { util.WriteError(r.Context(), err, w) return } + accountID, userID := userAuth.AccountId, userAuth.UserId vars := mux.Vars(r) routeID := vars["routeId"] if len(routeID) == 0 { @@ -269,13 +265,13 @@ func (h *handler) updateRoute(w http.ResponseWriter, r *http.Request) { // deleteRoute handles route deletion request func (h *handler) deleteRoute(w http.ResponseWriter, r *http.Request) { - claims := h.claimsExtractor.FromRequestContext(r) - accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) if err != nil { util.WriteError(r.Context(), err, w) return } + accountID, userID := userAuth.AccountId, userAuth.UserId routeID := mux.Vars(r)["routeId"] if len(routeID) == 0 { util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid route ID"), w) @@ -293,13 +289,14 @@ func (h *handler) deleteRoute(w http.ResponseWriter, r *http.Request) { // getRoute handles a route Get request identified by ID func (h *handler) getRoute(w http.ResponseWriter, r *http.Request) { - claims := h.claimsExtractor.FromRequestContext(r) - accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) if err != nil { util.WriteError(r.Context(), err, w) return } + accountID, userID := userAuth.AccountId, userAuth.UserId + routeID := mux.Vars(r)["routeId"] if len(routeID) == 0 { util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid route ID"), w) diff --git a/management/server/http/handlers/routes/routes_handler_test.go b/management/server/http/handlers/routes/routes_handler_test.go index 4064ec361f7..a58eadeef89 100644 --- a/management/server/http/handlers/routes/routes_handler_test.go +++ b/management/server/http/handlers/routes/routes_handler_test.go @@ -16,12 +16,10 @@ import ( "github.com/stretchr/testify/require" "github.com/netbirdio/netbird/management/domain" + nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/http/api" - "github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/mock_server" - nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/status" - "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/util" "github.com/netbirdio/netbird/route" ) @@ -60,32 +58,6 @@ var baseExistingRoute = &route.Route{ Groups: []string{existingGroupID}, } -var testingAccount = &types.Account{ - Id: testAccountID, - Domain: "hotmail.com", - Peers: map[string]*nbpeer.Peer{ - existingPeerID: { - Key: existingPeerKey, - IP: netip.MustParseAddr(existingPeerIP1).AsSlice(), - ID: existingPeerID, - Meta: nbpeer.PeerSystemMeta{ - GoOS: "linux", - }, - }, - nonLinuxExistingPeerID: { - Key: nonLinuxExistingPeerID, - IP: netip.MustParseAddr(existingPeerIP2).AsSlice(), - ID: nonLinuxExistingPeerID, - Meta: nbpeer.PeerSystemMeta{ - GoOS: "darwin", - }, - }, - }, - Users: map[string]*types.User{ - "test_user": types.NewAdminUser("test_user"), - }, -} - func initRoutesTestData() *handler { return &handler{ accountManager: &mock_server.MockAccountManager{ @@ -150,20 +122,7 @@ func initRoutesTestData() *handler { } return nil }, - GetAccountIDFromTokenFunc: func(_ context.Context, _ jwtclaims.AuthorizationClaims) (string, string, error) { - // return testingAccount, testingAccount.Users["test_user"], nil - return testingAccount.Id, testingAccount.Users["test_user"].Id, nil - }, }, - claimsExtractor: jwtclaims.NewClaimsExtractor( - jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims { - return jwtclaims.AuthorizationClaims{ - UserId: "test_user", - Domain: "hotmail.com", - AccountId: testAccountID, - } - }), - ), } } @@ -526,6 +485,11 @@ func TestRoutesHandlers(t *testing.T) { t.Run(tc.name, func(t *testing.T) { recorder := httptest.NewRecorder() req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody) + req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{ + UserId: "test_user", + Domain: "hotmail.com", + AccountId: testAccountID, + }) router := mux.NewRouter() router.HandleFunc("/api/routes/{routeId}", p.getRoute).Methods("GET") diff --git a/management/server/http/handlers/setup_keys/setupkeys_handler.go b/management/server/http/handlers/setup_keys/setupkeys_handler.go index 67e2969016d..36282cb69d3 100644 --- a/management/server/http/handlers/setup_keys/setupkeys_handler.go +++ b/management/server/http/handlers/setup_keys/setupkeys_handler.go @@ -9,22 +9,20 @@ import ( "github.com/gorilla/mux" "github.com/netbirdio/netbird/management/server" + nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/http/api" - "github.com/netbirdio/netbird/management/server/http/configs" "github.com/netbirdio/netbird/management/server/http/util" - "github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/management/server/types" ) // handler is a handler that returns a list of setup keys of the account type handler struct { - accountManager server.AccountManager - claimsExtractor *jwtclaims.ClaimsExtractor + accountManager server.AccountManager } -func AddEndpoints(accountManager server.AccountManager, authCfg configs.AuthCfg, router *mux.Router) { - keysHandler := newHandler(accountManager, authCfg) +func AddEndpoints(accountManager server.AccountManager, router *mux.Router) { + keysHandler := newHandler(accountManager) router.HandleFunc("/setup-keys", keysHandler.getAllSetupKeys).Methods("GET", "OPTIONS") router.HandleFunc("/setup-keys", keysHandler.createSetupKey).Methods("POST", "OPTIONS") router.HandleFunc("/setup-keys/{keyId}", keysHandler.getSetupKey).Methods("GET", "OPTIONS") @@ -33,25 +31,21 @@ func AddEndpoints(accountManager server.AccountManager, authCfg configs.AuthCfg, } // newHandler creates a new setup key handler -func newHandler(accountManager server.AccountManager, authCfg configs.AuthCfg) *handler { +func newHandler(accountManager server.AccountManager) *handler { return &handler{ accountManager: accountManager, - claimsExtractor: jwtclaims.NewClaimsExtractor( - jwtclaims.WithAudience(authCfg.Audience), - jwtclaims.WithUserIDClaim(authCfg.UserIDClaim), - ), } } // createSetupKey is a POST requests that creates a new SetupKey func (h *handler) createSetupKey(w http.ResponseWriter, r *http.Request) { - claims := h.claimsExtractor.FromRequestContext(r) - accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) if err != nil { util.WriteError(r.Context(), err, w) return } + accountID, userID := userAuth.AccountId, userAuth.UserId req := &api.PostApiSetupKeysJSONRequestBody{} err = json.NewDecoder(r.Body).Decode(&req) if err != nil { @@ -102,12 +96,12 @@ func (h *handler) createSetupKey(w http.ResponseWriter, r *http.Request) { // getSetupKey is a GET request to get a SetupKey by ID func (h *handler) getSetupKey(w http.ResponseWriter, r *http.Request) { - claims := h.claimsExtractor.FromRequestContext(r) - accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) if err != nil { util.WriteError(r.Context(), err, w) return } + accountID, userID := userAuth.AccountId, userAuth.UserId vars := mux.Vars(r) keyID := vars["keyId"] @@ -127,13 +121,13 @@ func (h *handler) getSetupKey(w http.ResponseWriter, r *http.Request) { // updateSetupKey is a PUT request to update server.SetupKey func (h *handler) updateSetupKey(w http.ResponseWriter, r *http.Request) { - claims := h.claimsExtractor.FromRequestContext(r) - accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) if err != nil { util.WriteError(r.Context(), err, w) return } + accountID, userID := userAuth.AccountId, userAuth.UserId vars := mux.Vars(r) keyID := vars["keyId"] if len(keyID) == 0 { @@ -168,13 +162,13 @@ func (h *handler) updateSetupKey(w http.ResponseWriter, r *http.Request) { // getAllSetupKeys is a GET request that returns a list of SetupKey func (h *handler) getAllSetupKeys(w http.ResponseWriter, r *http.Request) { - claims := h.claimsExtractor.FromRequestContext(r) - accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) if err != nil { util.WriteError(r.Context(), err, w) return } + accountID, userID := userAuth.AccountId, userAuth.UserId setupKeys, err := h.accountManager.ListSetupKeys(r.Context(), accountID, userID) if err != nil { util.WriteError(r.Context(), err, w) @@ -190,13 +184,13 @@ func (h *handler) getAllSetupKeys(w http.ResponseWriter, r *http.Request) { } func (h *handler) deleteSetupKey(w http.ResponseWriter, r *http.Request) { - claims := h.claimsExtractor.FromRequestContext(r) - accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) if err != nil { util.WriteError(r.Context(), err, w) return } + accountID, userID := userAuth.AccountId, userAuth.UserId vars := mux.Vars(r) keyID := vars["keyId"] if len(keyID) == 0 { diff --git a/management/server/http/handlers/setup_keys/setupkeys_handler_test.go b/management/server/http/handlers/setup_keys/setupkeys_handler_test.go index f56227c10dc..74d4473990a 100644 --- a/management/server/http/handlers/setup_keys/setupkeys_handler_test.go +++ b/management/server/http/handlers/setup_keys/setupkeys_handler_test.go @@ -14,8 +14,8 @@ import ( "github.com/gorilla/mux" "github.com/stretchr/testify/assert" + nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/http/api" - "github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/mock_server" "github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/management/server/types" @@ -28,14 +28,9 @@ const ( notFoundSetupKeyID = "notFoundSetupKeyID" ) -func initSetupKeysTestMetaData(defaultKey *types.SetupKey, newKey *types.SetupKey, updatedSetupKey *types.SetupKey, - user *types.User, -) *handler { +func initSetupKeysTestMetaData(defaultKey *types.SetupKey, newKey *types.SetupKey, updatedSetupKey *types.SetupKey) *handler { return &handler{ accountManager: &mock_server.MockAccountManager{ - GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) { - return claims.AccountId, claims.UserId, nil - }, CreateSetupKeyFunc: func(_ context.Context, _ string, keyName string, typ types.SetupKeyType, _ time.Duration, _ []string, _ int, _ string, ephemeral bool, ) (*types.SetupKey, error) { @@ -75,15 +70,6 @@ func initSetupKeysTestMetaData(defaultKey *types.SetupKey, newKey *types.SetupKe return status.Errorf(status.NotFound, "key %s not found", keyID) }, }, - claimsExtractor: jwtclaims.NewClaimsExtractor( - jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims { - return jwtclaims.AuthorizationClaims{ - UserId: user.Id, - Domain: "hotmail.com", - AccountId: "testAccountId", - } - }), - ), } } @@ -170,12 +156,17 @@ func TestSetupKeysHandlers(t *testing.T) { }, } - handler := initSetupKeysTestMetaData(defaultSetupKey, newSetupKey, updatedDefaultSetupKey, adminUser) + handler := initSetupKeysTestMetaData(defaultSetupKey, newSetupKey, updatedDefaultSetupKey) for _, tc := range tt { t.Run(tc.name, func(t *testing.T) { recorder := httptest.NewRecorder() req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody) + req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{ + UserId: adminUser.Id, + Domain: "hotmail.com", + AccountId: "testAccountId", + }) router := mux.NewRouter() router.HandleFunc("/api/setup-keys", handler.getAllSetupKeys).Methods("GET", "OPTIONS") diff --git a/management/server/http/handlers/users/pat_handler.go b/management/server/http/handlers/users/pat_handler.go index 7b93d2ae116..84fbef93e63 100644 --- a/management/server/http/handlers/users/pat_handler.go +++ b/management/server/http/handlers/users/pat_handler.go @@ -7,22 +7,20 @@ import ( "github.com/gorilla/mux" "github.com/netbirdio/netbird/management/server" + nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/http/api" - "github.com/netbirdio/netbird/management/server/http/configs" "github.com/netbirdio/netbird/management/server/http/util" - "github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/management/server/types" ) // patHandler is the nameserver group handler of the account type patHandler struct { - accountManager server.AccountManager - claimsExtractor *jwtclaims.ClaimsExtractor + accountManager server.AccountManager } -func addUsersTokensEndpoint(accountManager server.AccountManager, authCfg configs.AuthCfg, router *mux.Router) { - tokenHandler := newPATsHandler(accountManager, authCfg) +func addUsersTokensEndpoint(accountManager server.AccountManager, router *mux.Router) { + tokenHandler := newPATsHandler(accountManager) router.HandleFunc("/users/{userId}/tokens", tokenHandler.getAllTokens).Methods("GET", "OPTIONS") router.HandleFunc("/users/{userId}/tokens", tokenHandler.createToken).Methods("POST", "OPTIONS") router.HandleFunc("/users/{userId}/tokens/{tokenId}", tokenHandler.getToken).Methods("GET", "OPTIONS") @@ -30,25 +28,21 @@ func addUsersTokensEndpoint(accountManager server.AccountManager, authCfg config } // newPATsHandler creates a new patHandler HTTP handler -func newPATsHandler(accountManager server.AccountManager, authCfg configs.AuthCfg) *patHandler { +func newPATsHandler(accountManager server.AccountManager) *patHandler { return &patHandler{ accountManager: accountManager, - claimsExtractor: jwtclaims.NewClaimsExtractor( - jwtclaims.WithAudience(authCfg.Audience), - jwtclaims.WithUserIDClaim(authCfg.UserIDClaim), - ), } } // getAllTokens is HTTP GET handler that returns a list of all personal access tokens for the given user func (h *patHandler) getAllTokens(w http.ResponseWriter, r *http.Request) { - claims := h.claimsExtractor.FromRequestContext(r) - accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) if err != nil { util.WriteError(r.Context(), err, w) return } + accountID, userID := userAuth.AccountId, userAuth.UserId vars := mux.Vars(r) targetUserID := vars["userId"] if len(userID) == 0 { @@ -72,13 +66,13 @@ func (h *patHandler) getAllTokens(w http.ResponseWriter, r *http.Request) { // getToken is HTTP GET handler that returns a personal access token for the given user func (h *patHandler) getToken(w http.ResponseWriter, r *http.Request) { - claims := h.claimsExtractor.FromRequestContext(r) - accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) if err != nil { util.WriteError(r.Context(), err, w) return } + accountID, userID := userAuth.AccountId, userAuth.UserId vars := mux.Vars(r) targetUserID := vars["userId"] if len(targetUserID) == 0 { @@ -103,13 +97,13 @@ func (h *patHandler) getToken(w http.ResponseWriter, r *http.Request) { // createToken is HTTP POST handler that creates a personal access token for the given user func (h *patHandler) createToken(w http.ResponseWriter, r *http.Request) { - claims := h.claimsExtractor.FromRequestContext(r) - accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) if err != nil { util.WriteError(r.Context(), err, w) return } + accountID, userID := userAuth.AccountId, userAuth.UserId vars := mux.Vars(r) targetUserID := vars["userId"] if len(targetUserID) == 0 { @@ -135,13 +129,13 @@ func (h *patHandler) createToken(w http.ResponseWriter, r *http.Request) { // deleteToken is HTTP DELETE handler that deletes a personal access token for the given user func (h *patHandler) deleteToken(w http.ResponseWriter, r *http.Request) { - claims := h.claimsExtractor.FromRequestContext(r) - accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) if err != nil { util.WriteError(r.Context(), err, w) return } + accountID, userID := userAuth.AccountId, userAuth.UserId vars := mux.Vars(r) targetUserID := vars["userId"] if len(targetUserID) == 0 { diff --git a/management/server/http/handlers/users/pat_handler_test.go b/management/server/http/handlers/users/pat_handler_test.go index 9388067a49c..6593de64a06 100644 --- a/management/server/http/handlers/users/pat_handler_test.go +++ b/management/server/http/handlers/users/pat_handler_test.go @@ -12,11 +12,12 @@ import ( "github.com/google/go-cmp/cmp" "github.com/gorilla/mux" - "github.com/netbirdio/netbird/management/server/util" "github.com/stretchr/testify/assert" + "github.com/netbirdio/netbird/management/server/util" + + nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/http/api" - "github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/mock_server" "github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/management/server/types" @@ -77,10 +78,6 @@ func initPATTestData() *patHandler { PersonalAccessToken: types.PersonalAccessToken{}, }, nil }, - - GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) { - return claims.AccountId, claims.UserId, nil - }, DeletePATFunc: func(_ context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) error { if accountID != existingAccountID { return status.Errorf(status.NotFound, "account with ID %s not found", accountID) @@ -115,15 +112,6 @@ func initPATTestData() *patHandler { return []*types.PersonalAccessToken{testAccount.Users[existingUserID].PATs[existingTokenID], testAccount.Users[existingUserID].PATs["token2"]}, nil }, }, - claimsExtractor: jwtclaims.NewClaimsExtractor( - jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims { - return jwtclaims.AuthorizationClaims{ - UserId: existingUserID, - Domain: testDomain, - AccountId: existingAccountID, - } - }), - ), } } @@ -185,6 +173,11 @@ func TestTokenHandlers(t *testing.T) { t.Run(tc.name, func(t *testing.T) { recorder := httptest.NewRecorder() req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody) + req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{ + UserId: existingUserID, + Domain: testDomain, + AccountId: existingAccountID, + }) router := mux.NewRouter() router.HandleFunc("/api/users/{userId}/tokens", p.getAllTokens).Methods("GET") diff --git a/management/server/http/handlers/users/users_handler.go b/management/server/http/handlers/users/users_handler.go index 7380dd97e8a..3869f21f071 100644 --- a/management/server/http/handlers/users/users_handler.go +++ b/management/server/http/handlers/users/users_handler.go @@ -9,39 +9,33 @@ import ( log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/management/server/http/api" - "github.com/netbirdio/netbird/management/server/http/configs" "github.com/netbirdio/netbird/management/server/http/util" "github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server" - "github.com/netbirdio/netbird/management/server/jwtclaims" + nbcontext "github.com/netbirdio/netbird/management/server/context" ) // handler is a handler that returns users of the account type handler struct { - accountManager server.AccountManager - claimsExtractor *jwtclaims.ClaimsExtractor + accountManager server.AccountManager } -func AddEndpoints(accountManager server.AccountManager, authCfg configs.AuthCfg, router *mux.Router) { - userHandler := newHandler(accountManager, authCfg) +func AddEndpoints(accountManager server.AccountManager, router *mux.Router) { + userHandler := newHandler(accountManager) router.HandleFunc("/users", userHandler.getAllUsers).Methods("GET", "OPTIONS") router.HandleFunc("/users/{userId}", userHandler.updateUser).Methods("PUT", "OPTIONS") router.HandleFunc("/users/{userId}", userHandler.deleteUser).Methods("DELETE", "OPTIONS") router.HandleFunc("/users", userHandler.createUser).Methods("POST", "OPTIONS") router.HandleFunc("/users/{userId}/invite", userHandler.inviteUser).Methods("POST", "OPTIONS") - addUsersTokensEndpoint(accountManager, authCfg, router) + addUsersTokensEndpoint(accountManager, router) } // newHandler creates a new UsersHandler HTTP handler -func newHandler(accountManager server.AccountManager, authCfg configs.AuthCfg) *handler { +func newHandler(accountManager server.AccountManager) *handler { return &handler{ accountManager: accountManager, - claimsExtractor: jwtclaims.NewClaimsExtractor( - jwtclaims.WithAudience(authCfg.Audience), - jwtclaims.WithUserIDClaim(authCfg.UserIDClaim), - ), } } @@ -52,13 +46,13 @@ func (h *handler) updateUser(w http.ResponseWriter, r *http.Request) { return } - claims := h.claimsExtractor.FromRequestContext(r) - accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) if err != nil { util.WriteError(r.Context(), err, w) return } + accountID, userID := userAuth.AccountId, userAuth.UserId vars := mux.Vars(r) targetUserID := vars["userId"] if len(targetUserID) == 0 { @@ -103,7 +97,7 @@ func (h *handler) updateUser(w http.ResponseWriter, r *http.Request) { util.WriteError(r.Context(), err, w) return } - util.WriteJSONObject(r.Context(), w, toUserResponse(newUser, claims.UserId)) + util.WriteJSONObject(r.Context(), w, toUserResponse(newUser, userID)) } // deleteUser is a DELETE request to delete a user @@ -113,13 +107,13 @@ func (h *handler) deleteUser(w http.ResponseWriter, r *http.Request) { return } - claims := h.claimsExtractor.FromRequestContext(r) - accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) if err != nil { util.WriteError(r.Context(), err, w) return } + accountID, userID := userAuth.AccountId, userAuth.UserId vars := mux.Vars(r) targetUserID := vars["userId"] if len(targetUserID) == 0 { @@ -143,12 +137,12 @@ func (h *handler) createUser(w http.ResponseWriter, r *http.Request) { return } - claims := h.claimsExtractor.FromRequestContext(r) - accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) if err != nil { util.WriteError(r.Context(), err, w) return } + accountID, userID := userAuth.AccountId, userAuth.UserId req := &api.PostApiUsersJSONRequestBody{} err = json.NewDecoder(r.Body).Decode(&req) @@ -184,7 +178,7 @@ func (h *handler) createUser(w http.ResponseWriter, r *http.Request) { util.WriteError(r.Context(), err, w) return } - util.WriteJSONObject(r.Context(), w, toUserResponse(newUser, claims.UserId)) + util.WriteJSONObject(r.Context(), w, toUserResponse(newUser, userID)) } // getAllUsers returns a list of users of the account this user belongs to. @@ -195,13 +189,13 @@ func (h *handler) getAllUsers(w http.ResponseWriter, r *http.Request) { return } - claims := h.claimsExtractor.FromRequestContext(r) - accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) if err != nil { util.WriteError(r.Context(), err, w) return } + accountID, userID := userAuth.AccountId, userAuth.UserId data, err := h.accountManager.GetUsersFromAccount(r.Context(), accountID, userID) if err != nil { util.WriteError(r.Context(), err, w) @@ -216,7 +210,7 @@ func (h *handler) getAllUsers(w http.ResponseWriter, r *http.Request) { continue } if serviceUser == "" { - users = append(users, toUserResponse(d, claims.UserId)) + users = append(users, toUserResponse(d, userID)) continue } @@ -227,7 +221,7 @@ func (h *handler) getAllUsers(w http.ResponseWriter, r *http.Request) { return } if includeServiceUser == d.IsServiceUser { - users = append(users, toUserResponse(d, claims.UserId)) + users = append(users, toUserResponse(d, userID)) } } @@ -242,12 +236,12 @@ func (h *handler) inviteUser(w http.ResponseWriter, r *http.Request) { return } - claims := h.claimsExtractor.FromRequestContext(r) - accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) if err != nil { util.WriteError(r.Context(), err, w) return } + accountID, userID := userAuth.AccountId, userAuth.UserId vars := mux.Vars(r) targetUserID := vars["userId"] diff --git a/management/server/http/handlers/users/users_handler_test.go b/management/server/http/handlers/users/users_handler_test.go index 90081830a0d..b5d08937f47 100644 --- a/management/server/http/handlers/users/users_handler_test.go +++ b/management/server/http/handlers/users/users_handler_test.go @@ -13,8 +13,8 @@ import ( "github.com/gorilla/mux" "github.com/stretchr/testify/assert" + nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/http/api" - "github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/mock_server" "github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/management/server/types" @@ -64,9 +64,6 @@ var usersTestAccount = &types.Account{ func initUsersTestData() *handler { return &handler{ accountManager: &mock_server.MockAccountManager{ - GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) { - return usersTestAccount.Id, claims.UserId, nil - }, GetUserByIDFunc: func(ctx context.Context, id string) (*types.User, error) { return usersTestAccount.Users[id], nil }, @@ -127,15 +124,6 @@ func initUsersTestData() *handler { return nil }, }, - claimsExtractor: jwtclaims.NewClaimsExtractor( - jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims { - return jwtclaims.AuthorizationClaims{ - UserId: existingUserID, - Domain: testDomain, - AccountId: existingAccountID, - } - }), - ), } } @@ -158,6 +146,11 @@ func TestGetUsers(t *testing.T) { t.Run(tc.name, func(t *testing.T) { recorder := httptest.NewRecorder() req := httptest.NewRequest(tc.requestType, tc.requestPath, nil) + req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{ + UserId: existingUserID, + Domain: testDomain, + AccountId: existingAccountID, + }) userHandler.getAllUsers(recorder, req) @@ -263,6 +256,11 @@ func TestUpdateUser(t *testing.T) { t.Run(tc.name, func(t *testing.T) { recorder := httptest.NewRecorder() req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody) + req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{ + UserId: existingUserID, + Domain: testDomain, + AccountId: existingAccountID, + }) router := mux.NewRouter() router.HandleFunc("/api/users/{userId}", userHandler.updateUser).Methods("PUT") @@ -355,6 +353,11 @@ func TestCreateUser(t *testing.T) { t.Run(tc.name, func(t *testing.T) { req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody) rr := httptest.NewRecorder() + req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{ + UserId: existingUserID, + Domain: testDomain, + AccountId: existingAccountID, + }) userHandler.createUser(rr, req) @@ -399,6 +402,12 @@ func TestInviteUser(t *testing.T) { t.Run(tc.name, func(t *testing.T) { req := httptest.NewRequest(tc.requestType, tc.requestPath, nil) req = mux.SetURLVars(req, tc.requestVars) + req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{ + UserId: existingUserID, + Domain: testDomain, + AccountId: existingAccountID, + }) + rr := httptest.NewRecorder() userHandler.inviteUser(rr, req) @@ -452,6 +461,12 @@ func TestDeleteUser(t *testing.T) { t.Run(tc.name, func(t *testing.T) { req := httptest.NewRequest(tc.requestType, tc.requestPath, nil) req = mux.SetURLVars(req, tc.requestVars) + req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{ + UserId: existingUserID, + Domain: testDomain, + AccountId: existingAccountID, + }) + rr := httptest.NewRecorder() userHandler.deleteUser(rr, req) diff --git a/management/server/http/middleware/access_control.go b/management/server/http/middleware/access_control.go index c5bdf5fe7f1..4ed90f47b42 100644 --- a/management/server/http/middleware/access_control.go +++ b/management/server/http/middleware/access_control.go @@ -7,30 +7,24 @@ import ( log "github.com/sirupsen/logrus" + nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/http/middleware/bypass" "github.com/netbirdio/netbird/management/server/http/util" "github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/management/server/types" - - "github.com/netbirdio/netbird/management/server/jwtclaims" ) // GetUser function defines a function to fetch user from Account by jwtclaims.AuthorizationClaims -type GetUser func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*types.User, error) +type GetUser func(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error) // AccessControl middleware to restrict to make POST/PUT/DELETE requests by admin only type AccessControl struct { - claimsExtract jwtclaims.ClaimsExtractor - getUser GetUser + getUser GetUser } // NewAccessControl instance constructor -func NewAccessControl(audience, userIDClaim string, getUser GetUser) *AccessControl { +func NewAccessControl(getUser GetUser) *AccessControl { return &AccessControl{ - claimsExtract: *jwtclaims.NewClaimsExtractor( - jwtclaims.WithAudience(audience), - jwtclaims.WithUserIDClaim(userIDClaim), - ), getUser: getUser, } } @@ -45,12 +39,16 @@ func (a *AccessControl) Handler(h http.Handler) http.Handler { return } - claims := a.claimsExtract.FromRequestContext(r) + userAuth, err := nbcontext.GetUserAuthFromRequest(r) + if err != nil { + log.WithContext(r.Context()).Errorf("failed to get user auth from request: %s", err) + util.WriteError(r.Context(), status.Errorf(status.Unauthorized, "invalid user auth"), w) + } - user, err := a.getUser(r.Context(), claims) + user, err := a.getUser(r.Context(), userAuth) if err != nil { - log.WithContext(r.Context()).Errorf("failed to get user from claims: %s", err) - util.WriteError(r.Context(), status.Errorf(status.Unauthorized, "invalid JWT"), w) + log.WithContext(r.Context()).Errorf("failed to get user: %s", err) + util.WriteError(r.Context(), status.Errorf(status.Unauthorized, "invalid user auth"), w) return } diff --git a/management/server/http/middleware/auth_middleware.go b/management/server/http/middleware/auth_middleware.go index dcf73259a80..0ec1f581f84 100644 --- a/management/server/http/middleware/auth_middleware.go +++ b/management/server/http/middleware/auth_middleware.go @@ -8,67 +8,41 @@ import ( "strings" "time" - "github.com/golang-jwt/jwt" log "github.com/sirupsen/logrus" - nbContext "github.com/netbirdio/netbird/management/server/context" + "github.com/netbirdio/netbird/management/server/auth" + nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/http/middleware/bypass" "github.com/netbirdio/netbird/management/server/http/util" - "github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/status" - "github.com/netbirdio/netbird/management/server/types" ) -// GetAccountInfoFromPATFunc function -type GetAccountInfoFromPATFunc func(ctx context.Context, token string) (user *types.User, pat *types.PersonalAccessToken, domain string, category string, err error) - -// ValidateAndParseTokenFunc function -type ValidateAndParseTokenFunc func(ctx context.Context, token string) (*jwt.Token, error) - -// MarkPATUsedFunc function -type MarkPATUsedFunc func(ctx context.Context, token string) error - -// CheckUserAccessByJWTGroupsFunc function -type CheckUserAccessByJWTGroupsFunc func(ctx context.Context, claims jwtclaims.AuthorizationClaims) error +type EnsureAccountFunc func(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) +type SyncUserJWTGroupsFunc func(ctx context.Context, userAuth nbcontext.UserAuth) error // AuthMiddleware middleware to verify personal access tokens (PAT) and JWT tokens type AuthMiddleware struct { - getAccountInfoFromPAT GetAccountInfoFromPATFunc - validateAndParseToken ValidateAndParseTokenFunc - markPATUsed MarkPATUsedFunc - checkUserAccessByJWTGroups CheckUserAccessByJWTGroupsFunc - claimsExtractor *jwtclaims.ClaimsExtractor - audience string - userIDClaim string + authManager auth.Manager + ensureAccount EnsureAccountFunc + syncUserJWTGroups SyncUserJWTGroupsFunc } -const ( - userProperty = "user" -) - // NewAuthMiddleware instance constructor -func NewAuthMiddleware(getAccountInfoFromPAT GetAccountInfoFromPATFunc, validateAndParseToken ValidateAndParseTokenFunc, - markPATUsed MarkPATUsedFunc, checkUserAccessByJWTGroups CheckUserAccessByJWTGroupsFunc, claimsExtractor *jwtclaims.ClaimsExtractor, - audience string, userIdClaim string) *AuthMiddleware { - if userIdClaim == "" { - userIdClaim = jwtclaims.UserIDClaim - } - +func NewAuthMiddleware( + authManager auth.Manager, + ensureAccount EnsureAccountFunc, + syncUserJWTGroups SyncUserJWTGroupsFunc, +) *AuthMiddleware { return &AuthMiddleware{ - getAccountInfoFromPAT: getAccountInfoFromPAT, - validateAndParseToken: validateAndParseToken, - markPATUsed: markPATUsed, - checkUserAccessByJWTGroups: checkUserAccessByJWTGroups, - claimsExtractor: claimsExtractor, - audience: audience, - userIDClaim: userIdClaim, + authManager: authManager, + ensureAccount: ensureAccount, + syncUserJWTGroups: syncUserJWTGroups, } } // Handler method of the middleware which authenticates a user either by JWT claims or by PAT func (m *AuthMiddleware) Handler(h http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if bypass.ShouldBypass(r.URL.Path, h, w, r) { return } @@ -84,108 +58,106 @@ func (m *AuthMiddleware) Handler(h http.Handler) http.Handler { switch authType { case "bearer": - err := m.checkJWTFromRequest(w, r, auth) + request, err := m.checkJWTFromRequest(r, auth) if err != nil { - log.WithContext(r.Context()).Errorf("Error when validating JWT claims: %s", err.Error()) + log.WithContext(r.Context()).Errorf("Error when validating JWT: %s", err.Error()) util.WriteError(r.Context(), status.Errorf(status.Unauthorized, "token invalid"), w) return } + + h.ServeHTTP(w, request) case "token": - err := m.checkPATFromRequest(w, r, auth) + request, err := m.checkPATFromRequest(r, auth) if err != nil { - log.WithContext(r.Context()).Debugf("Error when validating PAT claims: %s", err.Error()) + log.WithContext(r.Context()).Debugf("Error when validating PAT: %s", err.Error()) util.WriteError(r.Context(), status.Errorf(status.Unauthorized, "token invalid"), w) return } + h.ServeHTTP(w, request) default: util.WriteError(r.Context(), status.Errorf(status.Unauthorized, "no valid authentication provided"), w) return } - claims := m.claimsExtractor.FromRequestContext(r) - //nolint - ctx := context.WithValue(r.Context(), nbContext.UserIDKey, claims.UserId) - //nolint - ctx = context.WithValue(ctx, nbContext.AccountIDKey, claims.AccountId) - h.ServeHTTP(w, r.WithContext(ctx)) }) } // CheckJWTFromRequest checks if the JWT is valid -func (m *AuthMiddleware) checkJWTFromRequest(w http.ResponseWriter, r *http.Request, auth []string) error { +func (m *AuthMiddleware) checkJWTFromRequest(r *http.Request, auth []string) (*http.Request, error) { token, err := getTokenFromJWTRequest(auth) // If an error occurs, call the error handler and return an error if err != nil { - return fmt.Errorf("Error extracting token: %w", err) + return r, fmt.Errorf("error extracting token: %w", err) } - validatedToken, err := m.validateAndParseToken(r.Context(), token) + ctx := r.Context() + + userAuth, validatedToken, err := m.authManager.ValidateAndParseToken(ctx, token) if err != nil { - return err + return r, err } - if validatedToken == nil { - return nil + if impersonate, ok := r.URL.Query()["account"]; ok && len(impersonate) == 1 { + userAuth.AccountId = impersonate[0] + userAuth.IsChild = ok } - if err := m.verifyUserAccess(r.Context(), validatedToken); err != nil { - return err + // we need to call this method because if user is new, we will automatically add it to existing or create a new account + _, _, err = m.ensureAccount(ctx, userAuth) + if err != nil { + return r, err } - // If we get here, everything worked and we can set the - // user property in context. - newRequest := r.WithContext(context.WithValue(r.Context(), userProperty, validatedToken)) //nolint - // Update the current request with the new context information. - *r = *newRequest - return nil -} + userAuth, err = m.authManager.EnsureUserAccessByJWTGroups(ctx, userAuth, validatedToken) + if err != nil { + return r, err + } -// verifyUserAccess checks if a user, based on a validated JWT token, -// is allowed access, particularly in cases where the admin enabled JWT -// group propagation and designated certain groups with access permissions. -func (m *AuthMiddleware) verifyUserAccess(ctx context.Context, validatedToken *jwt.Token) error { - authClaims := m.claimsExtractor.FromToken(validatedToken) - return m.checkUserAccessByJWTGroups(ctx, authClaims) + err = m.syncUserJWTGroups(ctx, userAuth) + if err != nil { + log.WithContext(ctx).Errorf("HTTP server failed to sync user JWT groups: %s", err) + } + + return nbcontext.SetUserAuthInRequest(r, userAuth), nil } // CheckPATFromRequest checks if the PAT is valid -func (m *AuthMiddleware) checkPATFromRequest(w http.ResponseWriter, r *http.Request, auth []string) error { +func (m *AuthMiddleware) checkPATFromRequest(r *http.Request, auth []string) (*http.Request, error) { token, err := getTokenFromPATRequest(auth) if err != nil { - return fmt.Errorf("error extracting token: %w", err) + return r, fmt.Errorf("error extracting token: %w", err) } - user, pat, accDomain, accCategory, err := m.getAccountInfoFromPAT(r.Context(), token) + ctx := r.Context() + user, pat, accDomain, accCategory, err := m.authManager.GetPATInfo(ctx, token) if err != nil { - return fmt.Errorf("invalid Token: %w", err) + return r, fmt.Errorf("invalid Token: %w", err) } if time.Now().After(pat.GetExpirationDate()) { - return fmt.Errorf("token expired") + return r, fmt.Errorf("token expired") } - err = m.markPATUsed(r.Context(), pat.ID) + err = m.authManager.MarkPATUsed(ctx, pat.ID) if err != nil { - return err + return r, err + } + + userAuth := nbcontext.UserAuth{ + UserId: user.Id, + AccountId: user.AccountID, + Domain: accDomain, + DomainCategory: accCategory, + IsPAT: true, } - claimMaps := jwt.MapClaims{} - claimMaps[m.userIDClaim] = user.Id - claimMaps[m.audience+jwtclaims.AccountIDSuffix] = user.AccountID - claimMaps[m.audience+jwtclaims.DomainIDSuffix] = accDomain - claimMaps[m.audience+jwtclaims.DomainCategorySuffix] = accCategory - claimMaps[jwtclaims.IsToken] = true - jwtToken := jwt.NewWithClaims(jwt.SigningMethodHS256, claimMaps) - newRequest := r.WithContext(context.WithValue(r.Context(), jwtclaims.TokenUserProperty, jwtToken)) //nolint - // Update the current request with the new context information. - *r = *newRequest - return nil + return nbcontext.SetUserAuthInRequest(r, userAuth), nil } // getTokenFromJWTRequest is a "TokenExtractor" that takes auth header parts and extracts // the JWT token from the Authorization header. func getTokenFromJWTRequest(authHeaderParts []string) (string, error) { if len(authHeaderParts) != 2 || strings.ToLower(authHeaderParts[0]) != "bearer" { - return "", errors.New("Authorization header format must be Bearer {token}") + return "", errors.New("authorization header format must be Bearer {token}") } return authHeaderParts[1], nil @@ -195,7 +167,7 @@ func getTokenFromJWTRequest(authHeaderParts []string) (string, error) { // the PAT token from the Authorization header. func getTokenFromPATRequest(authHeaderParts []string) (string, error) { if len(authHeaderParts) != 2 || strings.ToLower(authHeaderParts[0]) != "token" { - return "", errors.New("Authorization header format must be Token {token}") + return "", errors.New("authorization header format must be Token {token}") } return authHeaderParts[1], nil diff --git a/management/server/http/middleware/auth_middleware_test.go b/management/server/http/middleware/auth_middleware_test.go index c1686ed440e..c579f47403e 100644 --- a/management/server/http/middleware/auth_middleware_test.go +++ b/management/server/http/middleware/auth_middleware_test.go @@ -9,10 +9,13 @@ import ( "time" "github.com/golang-jwt/jwt" + + "github.com/netbirdio/netbird/management/server/auth" + nbjwt "github.com/netbirdio/netbird/management/server/auth/jwt" + nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/util" "github.com/netbirdio/netbird/management/server/http/middleware/bypass" - "github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/types" ) @@ -58,17 +61,21 @@ func mockGetAccountInfoFromPAT(_ context.Context, token string) (user *types.Use return nil, nil, "", "", fmt.Errorf("PAT invalid") } -func mockValidateAndParseToken(_ context.Context, token string) (*jwt.Token, error) { +func mockValidateAndParseToken(_ context.Context, token string) (nbcontext.UserAuth, *jwt.Token, error) { if token == JWT { - return &jwt.Token{ - Claims: jwt.MapClaims{ - userIDClaim: userID, - audience + jwtclaims.AccountIDSuffix: accountID, + return nbcontext.UserAuth{ + UserId: userID, + AccountId: accountID, }, - Valid: true, - }, nil + &jwt.Token{ + Claims: jwt.MapClaims{ + userIDClaim: userID, + audience + nbjwt.AccountIDSuffix: accountID, + }, + Valid: true, + }, nil } - return nil, fmt.Errorf("JWT invalid") + return nbcontext.UserAuth{}, nil, fmt.Errorf("JWT invalid") } func mockMarkPATUsed(_ context.Context, token string) error { @@ -78,16 +85,16 @@ func mockMarkPATUsed(_ context.Context, token string) error { return fmt.Errorf("Should never get reached") } -func mockCheckUserAccessByJWTGroups(_ context.Context, claims jwtclaims.AuthorizationClaims) error { - if testAccount.Id != claims.AccountId { - return fmt.Errorf("account with id %s does not exist", claims.AccountId) +func mockEnsureUserAccessByJWTGroups(_ context.Context, userAuth nbcontext.UserAuth, token *jwt.Token) (nbcontext.UserAuth, error) { + if testAccount.Id != userAuth.AccountId { + return userAuth, fmt.Errorf("account with id %s does not exist", userAuth.AccountId) } - if _, ok := testAccount.Users[claims.UserId]; !ok { - return fmt.Errorf("user with id %s does not exist", claims.UserId) + if _, ok := testAccount.Users[userAuth.UserId]; !ok { + return userAuth, fmt.Errorf("user with id %s does not exist", userAuth.UserId) } - return nil + return userAuth, nil } func TestAuthMiddleware_Handler(t *testing.T) { @@ -158,22 +165,24 @@ func TestAuthMiddleware_Handler(t *testing.T) { } nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // do nothing + }) - claimsExtractor := jwtclaims.NewClaimsExtractor( - jwtclaims.WithAudience(audience), - jwtclaims.WithUserIDClaim(userIDClaim), - ) + mockAuth := &auth.MockManager{ + ValidateAndParseTokenFunc: mockValidateAndParseToken, + EnsureUserAccessByJWTGroupsFunc: mockEnsureUserAccessByJWTGroups, + MarkPATUsedFunc: mockMarkPATUsed, + GetAccountInfoFromPATFunc: mockGetAccountInfoFromPAT, + } authMiddleware := NewAuthMiddleware( - mockGetAccountInfoFromPAT, - mockValidateAndParseToken, - mockMarkPATUsed, - mockCheckUserAccessByJWTGroups, - claimsExtractor, - audience, - userIDClaim, + mockAuth, + func(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) { + return userAuth.AccountId, userAuth.UserId, nil + }, + func(ctx context.Context, userAuth nbcontext.UserAuth) error { + return nil + }, ) handlerToTest := authMiddleware.Handler(nextHandler) @@ -195,6 +204,7 @@ func TestAuthMiddleware_Handler(t *testing.T) { result := rec.Result() defer result.Body.Close() + if result.StatusCode != tc.expectedStatusCode { t.Errorf("expected status code %d, got %d", tc.expectedStatusCode, result.StatusCode) } diff --git a/management/server/http/testing/testing_tools/tools.go b/management/server/http/testing/testing_tools/tools.go index 006d5679c00..8e01f7b7f3b 100644 --- a/management/server/http/testing/testing_tools/tools.go +++ b/management/server/http/testing/testing_tools/tools.go @@ -13,17 +13,15 @@ import ( "testing" "time" - "github.com/netbirdio/netbird/management/server/util" "github.com/stretchr/testify/assert" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/activity" + "github.com/netbirdio/netbird/management/server/auth" "github.com/netbirdio/netbird/management/server/geolocation" "github.com/netbirdio/netbird/management/server/groups" nbhttp "github.com/netbirdio/netbird/management/server/http" - "github.com/netbirdio/netbird/management/server/http/configs" - "github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/networks" "github.com/netbirdio/netbird/management/server/networks/resources" "github.com/netbirdio/netbird/management/server/networks/routers" @@ -32,6 +30,7 @@ import ( "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/telemetry" "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/management/server/util" ) const ( @@ -114,12 +113,13 @@ func BuildApiBlackBoxWithDBState(t TB, sqlFile string, expectedPeerUpdate *serve if err != nil { t.Fatalf("Failed to create manager: %v", err) } - + // @note only PAT's in store will be authed + authManager := auth.NewManager(store, "", "", "", "", []string{}, false) networksManagerMock := networks.NewManagerMock() resourcesManagerMock := resources.NewManagerMock() routersManagerMock := routers.NewManagerMock() groupsManagerMock := groups.NewManagerMock() - apiHandler, err := nbhttp.NewAPIHandler(context.Background(), am, networksManagerMock, resourcesManagerMock, routersManagerMock, groupsManagerMock, geoMock, &jwtclaims.JwtValidatorMock{}, metrics, configs.AuthCfg{}, validatorMock) + apiHandler, err := nbhttp.NewAPIHandler(context.Background(), am, networksManagerMock, resourcesManagerMock, routersManagerMock, groupsManagerMock, geoMock, authManager, metrics, &server.Config{}, validatorMock) if err != nil { t.Fatalf("Failed to create API handler: %v", err) } diff --git a/management/server/jwtclaims/claims.go b/management/server/jwtclaims/claims.go index 2527acbe329..e6b3adc3603 100644 --- a/management/server/jwtclaims/claims.go +++ b/management/server/jwtclaims/claims.go @@ -6,7 +6,7 @@ import ( "github.com/golang-jwt/jwt" ) -// AuthorizationClaims stores authorization information from JWTs +// deprecated, use UserAuth instead type AuthorizationClaims struct { UserId string AccountId string diff --git a/management/server/jwtclaims/extractor_test.go b/management/server/jwtclaims/extractor_test.go deleted file mode 100644 index eccd7c9e7c9..00000000000 --- a/management/server/jwtclaims/extractor_test.go +++ /dev/null @@ -1,227 +0,0 @@ -package jwtclaims - -import ( - "context" - "net/http" - "testing" - "time" - - "github.com/golang-jwt/jwt" - "github.com/stretchr/testify/require" -) - -func newTestRequestWithJWT(t *testing.T, claims AuthorizationClaims, audience string) *http.Request { - t.Helper() - const layout = "2006-01-02T15:04:05.999Z" - - claimMaps := jwt.MapClaims{} - if claims.UserId != "" { - claimMaps[UserIDClaim] = claims.UserId - } - if claims.AccountId != "" { - claimMaps[audience+AccountIDSuffix] = claims.AccountId - } - if claims.Domain != "" { - claimMaps[audience+DomainIDSuffix] = claims.Domain - } - if claims.DomainCategory != "" { - claimMaps[audience+DomainCategorySuffix] = claims.DomainCategory - } - if claims.LastLogin != (time.Time{}) { - claimMaps[audience+LastLoginSuffix] = claims.LastLogin.Format(layout) - } - - if claims.Invited { - claimMaps[audience+Invited] = true - } - token := jwt.NewWithClaims(jwt.SigningMethodHS256, claimMaps) - r, err := http.NewRequest(http.MethodGet, "http://localhost", nil) - require.NoError(t, err, "creating testing request failed") - testRequest := r.WithContext(context.WithValue(r.Context(), TokenUserProperty, token)) // nolint - - return testRequest -} - -func TestExtractClaimsFromRequestContext(t *testing.T) { - type test struct { - name string - inputAuthorizationClaims AuthorizationClaims - inputAudiance string - testingFunc require.ComparisonAssertionFunc - expectedMSG string - } - - const layout = "2006-01-02T15:04:05.999Z" - lastLogin, _ := time.Parse(layout, "2023-08-17T09:30:40.465Z") - - testCase1 := test{ - name: "All Claim Fields", - inputAudiance: "https://login/", - inputAuthorizationClaims: AuthorizationClaims{ - UserId: "test", - Domain: "test.com", - AccountId: "testAcc", - LastLogin: lastLogin, - DomainCategory: "public", - Invited: true, - Raw: jwt.MapClaims{ - "https://login/wt_account_domain": "test.com", - "https://login/wt_account_domain_category": "public", - "https://login/wt_account_id": "testAcc", - "https://login/nb_last_login": lastLogin.Format(layout), - "sub": "test", - "https://login/" + Invited: true, - }, - }, - testingFunc: require.EqualValues, - expectedMSG: "extracted claims should match input claims", - } - - testCase2 := test{ - name: "Domain Is Empty", - inputAudiance: "https://login/", - inputAuthorizationClaims: AuthorizationClaims{ - UserId: "test", - AccountId: "testAcc", - Raw: jwt.MapClaims{ - "https://login/wt_account_id": "testAcc", - "sub": "test", - }, - }, - testingFunc: require.EqualValues, - expectedMSG: "extracted claims should match input claims", - } - - testCase3 := test{ - name: "Account ID Is Empty", - inputAudiance: "https://login/", - inputAuthorizationClaims: AuthorizationClaims{ - UserId: "test", - Domain: "test.com", - Raw: jwt.MapClaims{ - "https://login/wt_account_domain": "test.com", - "sub": "test", - }, - }, - testingFunc: require.EqualValues, - expectedMSG: "extracted claims should match input claims", - } - - testCase4 := test{ - name: "Category Is Empty", - inputAudiance: "https://login/", - inputAuthorizationClaims: AuthorizationClaims{ - UserId: "test", - Domain: "test.com", - AccountId: "testAcc", - Raw: jwt.MapClaims{ - "https://login/wt_account_domain": "test.com", - "https://login/wt_account_id": "testAcc", - "sub": "test", - }, - }, - testingFunc: require.EqualValues, - expectedMSG: "extracted claims should match input claims", - } - - testCase5 := test{ - name: "Only User ID Is set", - inputAudiance: "https://login/", - inputAuthorizationClaims: AuthorizationClaims{ - UserId: "test", - Raw: jwt.MapClaims{ - "sub": "test", - }, - }, - testingFunc: require.EqualValues, - expectedMSG: "extracted claims should match input claims", - } - - for _, testCase := range []test{testCase1, testCase2, testCase3, testCase4, testCase5} { - t.Run(testCase.name, func(t *testing.T) { - request := newTestRequestWithJWT(t, testCase.inputAuthorizationClaims, testCase.inputAudiance) - - extractor := NewClaimsExtractor(WithAudience(testCase.inputAudiance)) - extractedClaims := extractor.FromRequestContext(request) - - testCase.testingFunc(t, testCase.inputAuthorizationClaims, extractedClaims, testCase.expectedMSG) - }) - } -} - -func TestExtractClaimsSetOptions(t *testing.T) { - t.Helper() - type test struct { - name string - extractor *ClaimsExtractor - check func(t *testing.T, c test) - } - - testCase1 := test{ - name: "No custom options", - extractor: NewClaimsExtractor(), - check: func(t *testing.T, c test) { - t.Helper() - if c.extractor.authAudience != "" { - t.Error("audience should be empty") - return - } - if c.extractor.userIDClaim != UserIDClaim { - t.Errorf("user id claim should be default, expected %s, got %s", UserIDClaim, c.extractor.userIDClaim) - return - } - if c.extractor.FromRequestContext == nil { - t.Error("from request context should not be nil") - return - } - }, - } - - testCase2 := test{ - name: "Custom audience", - extractor: NewClaimsExtractor(WithAudience("https://login/")), - check: func(t *testing.T, c test) { - t.Helper() - if c.extractor.authAudience != "https://login/" { - t.Errorf("audience expected %s, got %s", "https://login/", c.extractor.authAudience) - return - } - }, - } - - testCase3 := test{ - name: "Custom user id claim", - extractor: NewClaimsExtractor(WithUserIDClaim("customUserId")), - check: func(t *testing.T, c test) { - t.Helper() - if c.extractor.userIDClaim != "customUserId" { - t.Errorf("user id claim expected %s, got %s", "customUserId", c.extractor.userIDClaim) - return - } - }, - } - - testCase4 := test{ - name: "Custom extractor from request context", - extractor: NewClaimsExtractor( - WithFromRequestContext(func(r *http.Request) AuthorizationClaims { - return AuthorizationClaims{ - UserId: "testCustomRequest", - } - })), - check: func(t *testing.T, c test) { - t.Helper() - claims := c.extractor.FromRequestContext(&http.Request{}) - if claims.UserId != "testCustomRequest" { - t.Errorf("user id claim expected %s, got %s", "testCustomRequest", claims.UserId) - return - } - }, - } - - for _, testCase := range []test{testCase1, testCase2, testCase3, testCase4} { - t.Run(testCase.name, func(t *testing.T) { - testCase.check(t, testCase) - }) - } -} diff --git a/management/server/jwtclaims/jwtValidator.go b/management/server/jwtclaims/jwtValidator.go deleted file mode 100644 index 79e59e76feb..00000000000 --- a/management/server/jwtclaims/jwtValidator.go +++ /dev/null @@ -1,349 +0,0 @@ -package jwtclaims - -import ( - "context" - "crypto/ecdsa" - "crypto/elliptic" - "crypto/rsa" - "encoding/base64" - "encoding/json" - "errors" - "fmt" - "math/big" - "net/http" - "strconv" - "strings" - "sync" - "time" - - "github.com/golang-jwt/jwt" - log "github.com/sirupsen/logrus" -) - -// Options is a struct for specifying configuration options for the middleware. -type Options struct { - // The function that will return the Key to validate the JWT. - // It can be either a shared secret or a public key. - // Default value: nil - ValidationKeyGetter jwt.Keyfunc - // The name of the property in the request where the user information - // from the JWT will be stored. - // Default value: "user" - UserProperty string - // The function that will be called when there's an error validating the token - // Default value: - CredentialsOptional bool - // A function that extracts the token from the request - // Default: FromAuthHeader (i.e., from Authorization header as bearer token) - Debug bool - // When set, all requests with the OPTIONS method will use authentication - // Default: false - EnableAuthOnOptions bool -} - -// Jwks is a collection of JSONWebKey obtained from Config.HttpServerConfig.AuthKeysLocation -type Jwks struct { - Keys []JSONWebKey `json:"keys"` - expiresInTime time.Time -} - -// The supported elliptic curves types -const ( - // p256 represents a cryptographic elliptical curve type. - p256 = "P-256" - - // p384 represents a cryptographic elliptical curve type. - p384 = "P-384" - - // p521 represents a cryptographic elliptical curve type. - p521 = "P-521" -) - -// JSONWebKey is a representation of a Jason Web Key -type JSONWebKey struct { - Kty string `json:"kty"` - Kid string `json:"kid"` - Use string `json:"use"` - N string `json:"n"` - E string `json:"e"` - Crv string `json:"crv"` - X string `json:"x"` - Y string `json:"y"` - X5c []string `json:"x5c"` -} - -type JWTValidator interface { - ValidateAndParse(ctx context.Context, token string) (*jwt.Token, error) -} - -// jwtValidatorImpl struct to handle token validation and parsing -type jwtValidatorImpl struct { - options Options -} - -var keyNotFound = errors.New("unable to find appropriate key") - -// NewJWTValidator constructor -func NewJWTValidator(ctx context.Context, issuer string, audienceList []string, keysLocation string, idpSignkeyRefreshEnabled bool) (JWTValidator, error) { - keys, err := getPemKeys(ctx, keysLocation) - if err != nil { - return nil, err - } - - var lock sync.Mutex - options := Options{ - ValidationKeyGetter: func(token *jwt.Token) (interface{}, error) { - // Verify 'aud' claim - var checkAud bool - for _, audience := range audienceList { - checkAud = token.Claims.(jwt.MapClaims).VerifyAudience(audience, false) - if checkAud { - break - } - } - if !checkAud { - return token, errors.New("invalid audience") - } - // Verify 'issuer' claim - checkIss := token.Claims.(jwt.MapClaims).VerifyIssuer(issuer, false) - if !checkIss { - return token, errors.New("invalid issuer") - } - - // If keys are rotated, verify the keys prior to token validation - if idpSignkeyRefreshEnabled { - // If the keys are invalid, retrieve new ones - if !keys.stillValid() { - lock.Lock() - defer lock.Unlock() - - refreshedKeys, err := getPemKeys(ctx, keysLocation) - if err != nil { - log.WithContext(ctx).Debugf("cannot get JSONWebKey: %v, falling back to old keys", err) - refreshedKeys = keys - } - - log.WithContext(ctx).Debugf("keys refreshed, new UTC expiration time: %s", refreshedKeys.expiresInTime.UTC()) - - keys = refreshedKeys - } - } - - publicKey, err := getPublicKey(ctx, token, keys) - if err == nil { - return publicKey, nil - } - - msg := fmt.Sprintf("getPublicKey error: %s", err) - if errors.Is(err, keyNotFound) && !idpSignkeyRefreshEnabled { - msg = fmt.Sprintf("getPublicKey error: %s. You can enable key refresh by setting HttpServerConfig.IdpSignKeyRefreshEnabled to true in your management.json file and restart the service", err) - } - - log.WithContext(ctx).Error(msg) - - return nil, err - }, - EnableAuthOnOptions: false, - } - - if options.UserProperty == "" { - options.UserProperty = "user" - } - - return &jwtValidatorImpl{ - options: options, - }, nil -} - -// ValidateAndParse validates the token and returns the parsed token -func (m *jwtValidatorImpl) ValidateAndParse(ctx context.Context, token string) (*jwt.Token, error) { - // If the token is empty... - if token == "" { - // Check if it was required - if m.options.CredentialsOptional { - log.WithContext(ctx).Debugf("no credentials found (CredentialsOptional=true)") - // No error, just no token (and that is ok given that CredentialsOptional is true) - return nil, nil //nolint:nilnil - } - - // If we get here, the required token is missing - errorMsg := "required authorization token not found" - log.WithContext(ctx).Debugf(" Error: No credentials found (CredentialsOptional=false)") - return nil, errors.New(errorMsg) - } - - // Now parse the token - parsedToken, err := jwt.Parse(token, m.options.ValidationKeyGetter) - - // Check if there was an error in parsing... - if err != nil { - log.WithContext(ctx).Errorf("error parsing token: %v", err) - return nil, fmt.Errorf("error parsing token: %w", err) - } - - // Check if the parsed token is valid... - if !parsedToken.Valid { - errorMsg := "token is invalid" - log.WithContext(ctx).Debug(errorMsg) - return nil, errors.New(errorMsg) - } - - return parsedToken, nil -} - -// stillValid returns true if the JSONWebKey still valid and have enough time to be used -func (jwks *Jwks) stillValid() bool { - return !jwks.expiresInTime.IsZero() && time.Now().Add(5*time.Second).Before(jwks.expiresInTime) -} - -func getPemKeys(ctx context.Context, keysLocation string) (*Jwks, error) { - resp, err := http.Get(keysLocation) - if err != nil { - return nil, err - } - defer resp.Body.Close() - - jwks := &Jwks{} - err = json.NewDecoder(resp.Body).Decode(jwks) - if err != nil { - return jwks, err - } - - cacheControlHeader := resp.Header.Get("Cache-Control") - expiresIn := getMaxAgeFromCacheHeader(ctx, cacheControlHeader) - jwks.expiresInTime = time.Now().Add(time.Duration(expiresIn) * time.Second) - - return jwks, err -} - -func getPublicKey(ctx context.Context, token *jwt.Token, jwks *Jwks) (interface{}, error) { - // todo as we load the jkws when the server is starting, we should build a JKS map with the pem cert at the boot time - - for k := range jwks.Keys { - if token.Header["kid"] != jwks.Keys[k].Kid { - continue - } - - if len(jwks.Keys[k].X5c) != 0 { - cert := "-----BEGIN CERTIFICATE-----\n" + jwks.Keys[k].X5c[0] + "\n-----END CERTIFICATE-----" - return jwt.ParseRSAPublicKeyFromPEM([]byte(cert)) - } - - if jwks.Keys[k].Kty == "RSA" { - log.WithContext(ctx).Debugf("generating PublicKey from RSA JWK") - return getPublicKeyFromRSA(jwks.Keys[k]) - } - if jwks.Keys[k].Kty == "EC" { - log.WithContext(ctx).Debugf("generating PublicKey from ECDSA JWK") - return getPublicKeyFromECDSA(jwks.Keys[k]) - } - - log.WithContext(ctx).Debugf("Key Type: %s not yet supported, please raise ticket!", jwks.Keys[k].Kty) - } - - return nil, keyNotFound -} - -func getPublicKeyFromECDSA(jwk JSONWebKey) (publicKey *ecdsa.PublicKey, err error) { - - if jwk.X == "" || jwk.Y == "" || jwk.Crv == "" { - return nil, fmt.Errorf("ecdsa key incomplete") - } - - var xCoordinate []byte - if xCoordinate, err = base64.RawURLEncoding.DecodeString(jwk.X); err != nil { - return nil, err - } - - var yCoordinate []byte - if yCoordinate, err = base64.RawURLEncoding.DecodeString(jwk.Y); err != nil { - return nil, err - } - - publicKey = &ecdsa.PublicKey{} - - var curve elliptic.Curve - switch jwk.Crv { - case p256: - curve = elliptic.P256() - case p384: - curve = elliptic.P384() - case p521: - curve = elliptic.P521() - } - - publicKey.Curve = curve - publicKey.X = big.NewInt(0).SetBytes(xCoordinate) - publicKey.Y = big.NewInt(0).SetBytes(yCoordinate) - - return publicKey, nil -} - -func getPublicKeyFromRSA(jwk JSONWebKey) (*rsa.PublicKey, error) { - - decodedE, err := base64.RawURLEncoding.DecodeString(jwk.E) - if err != nil { - return nil, err - } - decodedN, err := base64.RawURLEncoding.DecodeString(jwk.N) - if err != nil { - return nil, err - } - - var n, e big.Int - e.SetBytes(decodedE) - n.SetBytes(decodedN) - - return &rsa.PublicKey{ - E: int(e.Int64()), - N: &n, - }, nil -} - -// getMaxAgeFromCacheHeader extracts max-age directive from the Cache-Control header -func getMaxAgeFromCacheHeader(ctx context.Context, cacheControl string) int { - // Split into individual directives - directives := strings.Split(cacheControl, ",") - - for _, directive := range directives { - directive = strings.TrimSpace(directive) - if strings.HasPrefix(directive, "max-age=") { - // Extract the max-age value - maxAgeStr := strings.TrimPrefix(directive, "max-age=") - maxAge, err := strconv.Atoi(maxAgeStr) - if err != nil { - log.WithContext(ctx).Debugf("error parsing max-age: %v", err) - return 0 - } - - return maxAge - } - } - - return 0 -} - -type JwtValidatorMock struct{} - -func (j *JwtValidatorMock) ValidateAndParse(ctx context.Context, token string) (*jwt.Token, error) { - claimMaps := jwt.MapClaims{} - - switch token { - case "testUserId", "testAdminId", "testOwnerId", "testServiceUserId", "testServiceAdminId", "blockedUserId": - claimMaps[UserIDClaim] = token - claimMaps[AccountIDSuffix] = "testAccountId" - claimMaps[DomainIDSuffix] = "test.com" - claimMaps[DomainCategorySuffix] = "private" - case "otherUserId": - claimMaps[UserIDClaim] = "otherUserId" - claimMaps[AccountIDSuffix] = "otherAccountId" - claimMaps[DomainIDSuffix] = "other.com" - claimMaps[DomainCategorySuffix] = "private" - case "invalidToken": - return nil, errors.New("invalid token") - } - - jwtToken := jwt.NewWithClaims(jwt.SigningMethodHS256, claimMaps) - return jwtToken, nil -} - diff --git a/management/server/management_proto_test.go b/management/server/management_proto_test.go index bcdf75b8cc1..675dd567136 100644 --- a/management/server/management_proto_test.go +++ b/management/server/management_proto_test.go @@ -440,7 +440,7 @@ func startManagementForTest(t *testing.T, testFile string, config *Config) (*grp secretsManager := NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay) ephemeralMgr := NewEphemeralManager(store, accountManager) - mgmtServer, err := NewServer(context.Background(), config, accountManager, settings.NewManager(store), peersUpdateManager, secretsManager, nil, ephemeralMgr) + mgmtServer, err := NewServer(context.Background(), config, accountManager, settings.NewManager(store), peersUpdateManager, secretsManager, nil, ephemeralMgr, nil) if err != nil { return nil, nil, "", cleanup, err } diff --git a/management/server/management_test.go b/management/server/management_test.go index 43a6e40d502..06f3658303c 100644 --- a/management/server/management_test.go +++ b/management/server/management_test.go @@ -1,5 +1,7 @@ package server_test +// @todo investigate failures + import ( "context" "math/rand" @@ -513,7 +515,7 @@ func startServer(config *server.Config, dataDir string, testFile string) (*grpc. } secretsManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay) - mgmtServer, err := server.NewServer(context.Background(), config, accountManager, settings.NewManager(store), peersUpdateManager, secretsManager, nil, nil) + mgmtServer, err := server.NewServer(context.Background(), config, accountManager, settings.NewManager(store), peersUpdateManager, secretsManager, nil, nil, nil) Expect(err).NotTo(HaveOccurred()) mgmtProto.RegisterManagementServiceServer(s, mgmtServer) go func() { diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go index e1f8e270983..adbe655c489 100644 --- a/management/server/mock_server/account_mock.go +++ b/management/server/mock_server/account_mock.go @@ -13,6 +13,7 @@ import ( "github.com/netbirdio/netbird/management/domain" "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/activity" + nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/idp" "github.com/netbirdio/netbird/management/server/jwtclaims" nbpeer "github.com/netbirdio/netbird/management/server/peer" @@ -21,6 +22,8 @@ import ( "github.com/netbirdio/netbird/route" ) +var _ server.AccountManager = (*MockAccountManager)(nil) + type MockAccountManager struct { GetOrCreateAccountByUserFunc func(ctx context.Context, userId, domain string) (*types.Account, error) GetAccountFunc func(ctx context.Context, accountID string) (*types.Account, error) @@ -54,8 +57,6 @@ type MockAccountManager struct { DeletePolicyFunc func(ctx context.Context, accountID, policyID, userID string) error ListPoliciesFunc func(ctx context.Context, accountID, userID string) ([]*types.Policy, error) GetUsersFromAccountFunc func(ctx context.Context, accountID, userID string) ([]*types.UserInfo, error) - GetPATInfoFunc func(ctx context.Context, token string) (*types.User, *types.PersonalAccessToken, string, string, error) - MarkPATUsedFunc func(ctx context.Context, pat string) error UpdatePeerMetaFunc func(ctx context.Context, peerID string, meta nbpeer.PeerSystemMeta) error UpdatePeerFunc func(ctx context.Context, accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error) CreateRouteFunc func(ctx context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peer string, peerGroups []string, description string, netID route.NetID, masquerade bool, metric int, groups, accessControlGroupIDs []string, enabled bool, userID string, keepRoute bool) (*route.Route, error) @@ -81,7 +82,6 @@ type MockAccountManager struct { ListNameServerGroupsFunc func(ctx context.Context, accountID string, userID string) ([]*nbdns.NameServerGroup, error) CreateUserFunc func(ctx context.Context, accountID, userID string, key *types.UserInfo) (*types.UserInfo, error) GetAccountIDFromTokenFunc func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) - CheckUserAccessByJWTGroupsFunc func(ctx context.Context, claims jwtclaims.AuthorizationClaims) error DeleteAccountFunc func(ctx context.Context, accountID, userID string) error GetDNSDomainFunc func() string StoreEventFunc func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) @@ -238,14 +238,6 @@ func (am *MockAccountManager) MarkPeerConnected(ctx context.Context, peerKey str return status.Errorf(codes.Unimplemented, "method MarkPeerConnected is not implemented") } -// GetPATInfo mock implementation of GetPATInfo from server.AccountManager interface -func (am *MockAccountManager) GetPATInfo(ctx context.Context, pat string) (*types.User, *types.PersonalAccessToken, string, string, error) { - if am.GetPATInfoFunc != nil { - return am.GetPATInfoFunc(ctx, pat) - } - return nil, nil, "", "", status.Errorf(codes.Unimplemented, "method GetPATInfo is not implemented") -} - // DeleteAccount mock implementation of DeleteAccount from server.AccountManager interface func (am *MockAccountManager) DeleteAccount(ctx context.Context, accountID, userID string) error { if am.DeleteAccountFunc != nil { @@ -254,14 +246,6 @@ func (am *MockAccountManager) DeleteAccount(ctx context.Context, accountID, user return status.Errorf(codes.Unimplemented, "method DeleteAccount is not implemented") } -// MarkPATUsed mock implementation of MarkPATUsed from server.AccountManager interface -func (am *MockAccountManager) MarkPATUsed(ctx context.Context, pat string) error { - if am.MarkPATUsedFunc != nil { - return am.MarkPATUsedFunc(ctx, pat) - } - return status.Errorf(codes.Unimplemented, "method MarkPATUsed is not implemented") -} - // CreatePAT mock implementation of GetPAT from server.AccountManager interface func (am *MockAccountManager) CreatePAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, name string, expiresIn int) (*types.PersonalAccessTokenGenerated, error) { if am.CreatePATFunc != nil { @@ -620,13 +604,6 @@ func (am *MockAccountManager) GetAccountIDFromToken(ctx context.Context, claims return "", "", status.Errorf(codes.Unimplemented, "method GetAccountIDFromToken is not implemented") } -func (am *MockAccountManager) CheckUserAccessByJWTGroups(ctx context.Context, claims jwtclaims.AuthorizationClaims) error { - if am.CheckUserAccessByJWTGroupsFunc != nil { - return am.CheckUserAccessByJWTGroupsFunc(ctx, claims) - } - return status.Errorf(codes.Unimplemented, "method CheckUserAccessByJWTGroups is not implemented") -} - // GetPeers mocks GetPeers of the AccountManager interface func (am *MockAccountManager) GetPeers(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error) { if am.GetPeersFunc != nil { @@ -849,3 +826,15 @@ func (am *MockAccountManager) GetPeerGroups(ctx context.Context, accountID, peer } return nil, status.Errorf(codes.Unimplemented, "method GetPeerGroups is not implemented") } + +func (am *MockAccountManager) GetAccountIDFromUserAuth(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) { + return "", "", status.Errorf(codes.Unimplemented, "method GetAccountIDFromUserAuth is not implemented") +} + +func (am *MockAccountManager) GetUserFromUserAuth(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error) { + return nil, status.Errorf(codes.Unimplemented, "method GetUserFromUserAuth is not implemented") +} + +func (am *MockAccountManager) SyncUserJWTGroups(ctx context.Context, userAuth nbcontext.UserAuth) error { + return status.Errorf(codes.Unimplemented, "method SyncUserJWTGroups is not implemented") +} diff --git a/management/server/user.go b/management/server/user.go index 5efbd2efafe..874c04473dc 100644 --- a/management/server/user.go +++ b/management/server/user.go @@ -8,6 +8,8 @@ import ( "time" "github.com/google/uuid" + log "github.com/sirupsen/logrus" + "github.com/netbirdio/netbird/management/server/activity" nbContext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/idp" @@ -17,7 +19,6 @@ import ( "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/util" - log "github.com/sirupsen/logrus" ) // createServiceUser creates a new service user under the given account. @@ -177,28 +178,39 @@ func (am *DefaultAccountManager) GetUserByID(ctx context.Context, id string) (*t // GetUser looks up a user by provided authorization claims. // It will also create an account if didn't exist for this user before. func (am *DefaultAccountManager) GetUser(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*types.User, error) { - accountID, userID, err := am.GetAccountIDFromToken(ctx, claims) + userAuth, err := nbContext.GetUserAuthFromContext(ctx) if err != nil { - return nil, fmt.Errorf("failed to get account with token claims %v", err) + return nil, err } + return am.GetUserFromUserAuth(ctx, userAuth) +} - user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID) +// GetUser looks up a user by provided authorization claims. +// It will also create an account if didn't exist for this user before. +func (am *DefaultAccountManager) GetUserFromUserAuth(ctx context.Context, userAuth nbContext.UserAuth) (*types.User, error) { + // @note below is unnecessary, auth middleware not ensures that the account is created + // accountID, userID, err := am.GetAccountIDFromToken(ctx, claims) + // if err != nil { + // return nil, fmt.Errorf("failed to get account with token claims %v", err) + // } + + user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userAuth.UserId) if err != nil { return nil, err } // this code should be outside of the am.GetAccountIDFromToken(claims) because this method is called also by the gRPC // server when user authenticates a device. And we need to separate the Dashboard login event from the Device login event. - newLogin := user.LastDashboardLoginChanged(claims.LastLogin) + newLogin := user.LastDashboardLoginChanged(userAuth.LastLogin) - err = am.Store.SaveUserLastLogin(ctx, accountID, userID, claims.LastLogin) + err = am.Store.SaveUserLastLogin(ctx, userAuth.AccountId, userAuth.UserId, userAuth.LastLogin) if err != nil { log.WithContext(ctx).Errorf("failed saving user last login: %v", err) } if newLogin { - meta := map[string]any{"timestamp": claims.LastLogin} - am.StoreEvent(ctx, claims.UserId, claims.UserId, accountID, activity.DashboardLogin, meta) + meta := map[string]any{"timestamp": userAuth.LastLogin} + am.StoreEvent(ctx, userAuth.UserId, userAuth.UserId, userAuth.AccountId, activity.DashboardLogin, meta) } return user, nil diff --git a/management/server/user_test.go b/management/server/user_test.go index 5c4b1e2cbef..a2cc53da09a 100644 --- a/management/server/user_test.go +++ b/management/server/user_test.go @@ -10,6 +10,8 @@ import ( "github.com/eko/gocache/v3/cache" cacheStore "github.com/eko/gocache/v3/store" "github.com/google/go-cmp/cmp" + + nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/util" nbpeer "github.com/netbirdio/netbird/management/server/peer" @@ -24,7 +26,6 @@ import ( "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/idp" "github.com/netbirdio/netbird/management/server/integration_reference" - "github.com/netbirdio/netbird/management/server/jwtclaims" ) const ( @@ -921,11 +922,12 @@ func TestDefaultAccountManager_GetUser(t *testing.T) { eventStore: &activity.InMemoryEventStore{}, } - claims := jwtclaims.AuthorizationClaims{ - UserId: mockUserID, + claims := nbcontext.UserAuth{ + UserId: mockUserID, + AccountId: mockAccountID, } - user, err := am.GetUser(context.Background(), claims) + user, err := am.GetUserFromUserAuth(context.Background(), claims) if err != nil { t.Fatalf("Error when checking user role: %s", err) }