diff --git a/internal/auth/auth.go b/internal/auth/auth.go index 7908963..a0bde18 100644 --- a/internal/auth/auth.go +++ b/internal/auth/auth.go @@ -69,14 +69,13 @@ func newError(msg string, args ...any) error { func NewVehicleTokenCheck(contractAddr string) func(context.Context, any, graphql.Resolver) (any, error) { requiredAddr := common.HexToAddress(contractAddr) - return func(ctx context.Context, _ any, next graphql.Resolver) (any, error) { vehicleTokenID, err := getArg[int](ctx, tokenIdArg) if err != nil { return nil, UnauthorizedError{err: err} } - if err := headerTokenMatchesQuery(ctx, requiredAddr, vehicleTokenID); err != nil { + if err := validateHeader(ctx, requiredAddr, vehicleTokenID); err != nil { return nil, UnauthorizedError{err: err} } @@ -86,7 +85,6 @@ func NewVehicleTokenCheck(contractAddr string) func(context.Context, any, graphq func NewManufacturerTokenCheck(contractAddr string, identitySvc IdentityService) func(context.Context, any, graphql.Resolver) (any, error) { requiredAddr := common.HexToAddress(contractAddr) - return func(ctx context.Context, _ any, next graphql.Resolver) (any, error) { adFilter, err := getArg[model.AftermarketDeviceBy](ctx, byArg) if err != nil { @@ -98,7 +96,7 @@ func NewManufacturerTokenCheck(contractAddr string, identitySvc IdentityService) return nil, err } - if err := headerTokenMatchesQuery(ctx, requiredAddr, adResp.ManufacturerTokenID); err != nil { + if err := validateHeader(ctx, requiredAddr, adResp.ManufacturerTokenID); err != nil { return nil, UnauthorizedError{err: err} } @@ -106,7 +104,7 @@ func NewManufacturerTokenCheck(contractAddr string, identitySvc IdentityService) } } -func headerTokenMatchesQuery(ctx context.Context, requiredAddr common.Address, tokenID int) error { +func validateHeader(ctx context.Context, requiredAddr common.Address, tokenID int) error { claim, err := getTelemetryClaim(ctx) if err != nil { return err