Skip to content

Commit

Permalink
Add more auth error handling
Browse files Browse the repository at this point in the history
  • Loading branch information
KevinJoiner committed Apr 25, 2024
1 parent 3daf74f commit 520cebc
Showing 1 changed file with 13 additions and 6 deletions.
19 changes: 13 additions & 6 deletions cmd/telemetry-api/jwt.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package main

import (
"context"
"errors"
"fmt"
"net/http"
"net/url"
Expand All @@ -19,6 +20,8 @@ import (
"github.com/rs/zerolog"
)

var errUnauthorized = fmt.Errorf("unauthorized")

// CustomContextKey is a custom key for the context to store the custom claims.
type CustomContextKey struct{}

Expand Down Expand Up @@ -93,7 +96,11 @@ func AddClaimHandler(next http.Handler, logger *zerolog.Logger) http.Handler {
func ErrorHandler(logger *zerolog.Logger) func(w http.ResponseWriter, r *http.Request, err error) {
return func(w http.ResponseWriter, r *http.Request, err error) {
logger.Error().Err(err).Msg("error validating token")
jwtmiddleware.DefaultErrorHandler(w, r, err)
if errors.Is(err, errUnauthorized) {
jwtmiddleware.DefaultErrorHandler(w, r, errUnauthorized)
} else {
jwtmiddleware.DefaultErrorHandler(w, r, err)
}
}
}

Expand All @@ -111,7 +118,7 @@ type customClaimWrapper struct {
// Validate function is required to implement the validator.CustomClaims interface.
func (c *customClaimWrapper) Validate(context.Context) error {
if c.expectedContractAddress != c.CustomClaims.ContractAddress {
return fmt.Errorf("incorrect contract address expected %v got %v", c.expectedContractAddress, c.CustomClaims.ContractAddress)
return fmt.Errorf("%w: incorrect contract address expected %v got %v", errUnauthorized, c.expectedContractAddress, c.CustomClaims.ContractAddress)
}
return nil
}
Expand All @@ -120,7 +127,7 @@ func requiresPrivilegeCheck(ctx context.Context, obj interface{}, next graphql.R
claim := getClaim(ctx)
for _, priv := range privileges {
if _, ok := claim.privileges[priv]; !ok {
return nil, fmt.Errorf("unathorized")
return nil, fmt.Errorf("%w: missing required privileges", errUnauthorized)
}
}
return next(ctx)
Expand All @@ -137,16 +144,16 @@ var privToAPI = map[privileges.Privilege]model.Privilege{
func requiresTokenCheck(ctx context.Context, obj interface{}, next graphql.Resolver) (res interface{}, err error) {
fCtx := graphql.GetFieldContext(ctx)
if fCtx == nil {
return nil, fmt.Errorf("no field context found")
return nil, fmt.Errorf("%w: no field context found", err)
}
tokenID, ok := fCtx.Args["tokenID"].(int)
if !ok {
return nil, fmt.Errorf("failed to get tokenID from args")
return nil, fmt.Errorf("%w: failed to get tokenID from args", err)
}

claim := getClaim(ctx)
if strconv.Itoa(tokenID) != claim.TokenID {
return nil, fmt.Errorf("unathorized")
return nil, fmt.Errorf("%w: tokenID mismatch", err)
}
return next(ctx)
}

0 comments on commit 520cebc

Please sign in to comment.