Skip to content

Commit

Permalink
No more struct
Browse files Browse the repository at this point in the history
  • Loading branch information
elffjs committed Aug 3, 2024
1 parent 66eebaa commit af296be
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 63 deletions.
9 changes: 2 additions & 7 deletions cmd/telemetry-api/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,14 +62,9 @@ func main() {
IdentityService: idService,
}

tknValidator :=
auth.TokenValidator{
IdentitySvc: idService,
}

cfg := graph.Config{Resolvers: resolver}
cfg.Directives.RequiresVehicleToken = tknValidator.VehicleTokenCheck(settings.VehicleNFTAddress)
cfg.Directives.RequiresManufacturerToken = tknValidator.ManufacturerTokenCheck(settings.VehicleNFTAddress)
cfg.Directives.RequiresVehicleToken = auth.NewVehicleTokenCheck(settings.VehicleNFTAddress)
cfg.Directives.RequiresManufacturerToken = auth.NewManufacturerTokenCheck(settings.VehicleNFTAddress, idService)
cfg.Directives.RequiresPrivileges = auth.PrivilegeCheck
cfg.Directives.IsSignal = noOp
cfg.Directives.HasAggregation = noOp
Expand Down
60 changes: 26 additions & 34 deletions internal/auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ func newError(msg string, args ...any) error {
return UnauthorizedError{message: fmt.Sprintf(msg, args...)}
}

func (tv *TokenValidator) VehicleTokenCheck(contractAddr string) func(context.Context, any, graphql.Resolver) (any, error) {
func NewVehicleTokenCheck(contractAddr string) func(context.Context, any, graphql.Resolver) (any, error) {
requiredAddr := common.HexToAddress(contractAddr)

return func(ctx context.Context, _ any, next graphql.Resolver) (any, error) {
Expand All @@ -73,17 +73,15 @@ func (tv *TokenValidator) VehicleTokenCheck(contractAddr string) func(context.Co
return nil, UnauthorizedError{err: err}
}

if err := headerTokenMatchesQuery(ctx, requiredAddr, func() (string, error) {
return strconv.Itoa(vehicleTokenID), nil
}); err != nil {
if err := headerTokenMatchesQuery(ctx, requiredAddr, vehicleTokenID); err != nil {
return nil, UnauthorizedError{err: err}
}

return next(ctx)
}
}

func (tv *TokenValidator) ManufacturerTokenCheck(contractAddr string) func(context.Context, any, graphql.Resolver) (any, error) {
func NewManufacturerTokenCheck(contractAddr string, identitySvc IdentityService) func(context.Context, any, graphql.Resolver) (any, error) {
requiredAddr := common.HexToAddress(contractAddr)

return func(ctx context.Context, _ any, next graphql.Resolver) (any, error) {
Expand All @@ -92,37 +90,20 @@ func (tv *TokenValidator) ManufacturerTokenCheck(contractAddr string) func(conte
return nil, fmt.Errorf("unauthorized: %w", err)
}

if err := headerTokenMatchesQuery(ctx, requiredAddr, func() (string, error) {
resp, err := tv.IdentitySvc.GetAftermarketDevice(ctx, adFilter.Address, adFilter.TokenID, adFilter.Serial)
if err != nil {
return "", err
}
return strconv.Itoa(resp.ManufacturerTokenID), nil
}); err != nil {
adResp, err := identitySvc.GetAftermarketDevice(ctx, adFilter.Address, adFilter.TokenID, adFilter.Serial)
if err != nil {
return nil, err
}

if err := headerTokenMatchesQuery(ctx, requiredAddr, adResp.ManufacturerTokenID); err != nil {
return nil, UnauthorizedError{err: err}
}

return next(ctx)
}
}

// PrivilegeCheck checks if the claim set in the context includes the required privileges.
func PrivilegeCheck(ctx context.Context, _ any, next graphql.Resolver, privs []model.Privilege) (any, error) {
claim, err := getTelemetryClaim(ctx)
if err != nil {
return nil, UnauthorizedError{err: err}
}

for _, priv := range privs {
if !claim.privileges.Contains(priv) {
return nil, newError("missing required privilege %s", priv)
}
}

return next(ctx)
}

func headerTokenMatchesQuery(ctx context.Context, requiredAddr common.Address, getTokenStrFromArgs func() (string, error)) error {
func headerTokenMatchesQuery(ctx context.Context, requiredAddr common.Address, tokenID int) error {
claim, err := getTelemetryClaim(ctx)
if err != nil {
return err
Expand All @@ -132,16 +113,27 @@ func headerTokenMatchesQuery(ctx context.Context, requiredAddr common.Address, g
return newError("contract in claim is %s instead of the required %s", claim.ContractAddress, requiredAddr)
}

tknStr, err := getTokenStrFromArgs()
if strconv.Itoa(tokenID) != claim.TokenID {
return fmt.Errorf("token id does not match")
}

return nil
}

// PrivilegeCheck checks if the claim set in the context includes the required privileges.
func PrivilegeCheck(ctx context.Context, _ any, next graphql.Resolver, privs []model.Privilege) (any, error) {
claim, err := getTelemetryClaim(ctx)
if err != nil {
return err
return nil, UnauthorizedError{err: err}
}

if tknStr != claim.TokenID {
return fmt.Errorf("token id does not match")
for _, priv := range privs {
if !claim.privileges.Contains(priv) {
return nil, newError("missing required privilege %s", priv)
}
}

return nil
return next(ctx)
}

func getArg[T any](ctx context.Context, name string) (T, error) {
Expand Down
39 changes: 17 additions & 22 deletions internal/auth/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ import (
"github.com/DIMO-Network/shared/privileges"
"github.com/DIMO-Network/telemetry-api/internal/graph/model"
"github.com/DIMO-Network/telemetry-api/internal/service/identity"
jwtmiddleware "github.com/auth0/go-jwt-middleware/v2"
"github.com/ethereum/go-ethereum/common"
"github.com/stretchr/testify/require"
"go.uber.org/mock/gomock"
Expand Down Expand Up @@ -92,7 +91,7 @@ func TestRequiresVehicleTokenCheck(t *testing.T) {
},
}

vehicleCheck := (&TokenValidator{}).VehicleTokenCheck(vehicleNFTAddrRaw)
vehicleCheck := NewVehicleTokenCheck(vehicleNFTAddrRaw)
for _, tc := range testCases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
Expand Down Expand Up @@ -127,9 +126,7 @@ func TestRequiresManufacturerTokenCheck(t *testing.T) {
id.EXPECT().GetAftermarketDevice(gomock.Any(), &invalidAutoPiAddr, gomock.Any(), gomock.Any()).Return(
nil, fmt.Errorf("")).AnyTimes()

tknValidator := &TokenValidator{
IdentitySvc: id,
}
mfrValidator := NewManufacturerTokenCheck(mtrNFTAddrRaw, id)

testCases := []struct {
name string
Expand All @@ -151,7 +148,7 @@ func TestRequiresManufacturerTokenCheck(t *testing.T) {
TokenID: "137",
},
},
tokenValidatorFunc: tknValidator.ManufacturerTokenCheck(mtrNFTAddrRaw),
tokenValidatorFunc: mfrValidator,
},
{
name: "wrong aftermarket device manufacturer",
Expand All @@ -166,8 +163,7 @@ func TestRequiresManufacturerTokenCheck(t *testing.T) {
TokenID: "138",
},
},
tokenValidatorFunc: tknValidator.ManufacturerTokenCheck(mtrNFTAddrRaw),
expectedError: fmt.Errorf("unauthorized: token id does not match"),
expectedError: fmt.Errorf("unauthorized: token id does not match"),
},
{
name: "invalid autopi address",
Expand All @@ -182,8 +178,7 @@ func TestRequiresManufacturerTokenCheck(t *testing.T) {
TokenID: "137",
},
},
tokenValidatorFunc: tknValidator.ManufacturerTokenCheck(mtrNFTAddrRaw),
expectedError: fmt.Errorf("unauthorized: token id does not match"),
expectedError: fmt.Errorf("unauthorized: token id does not match"),
},
}

Expand All @@ -195,7 +190,7 @@ func TestRequiresManufacturerTokenCheck(t *testing.T) {
Args: tc.args,
})
testCtx = context.WithValue(testCtx, TelemetryClaimContextKey{}, tc.telmetryClaim)
result, err := tc.tokenValidatorFunc(testCtx, nil, graphql.Resolver(emptyResolver))
result, err := mfrValidator(testCtx, nil, graphql.Resolver(emptyResolver))
if tc.expectedError != nil {
require.Error(t, err)
return
Expand All @@ -220,7 +215,7 @@ func TestRequiresPrivilegeCheck(t *testing.T) {
name string
privs []model.Privilege
telemetryClaim *TelemetryClaim
expectedError error
expectedError bool
}{
{
name: "valid_privileges",
Expand All @@ -237,6 +232,7 @@ func TestRequiresPrivilegeCheck(t *testing.T) {
ContractAddress: vehicleNFTAddr,
},
},
expectedError: false,
},
{
name: "missing_all_privilege",
Expand All @@ -250,7 +246,7 @@ func TestRequiresPrivilegeCheck(t *testing.T) {
ContractAddress: vehicleNFTAddr,
},
},
expectedError: newError("missing required privilege %s", model.PrivilegeVehicleAllTimeLocation),
expectedError: true,
},
{
name: "missing_one_privilege",
Expand All @@ -266,13 +262,13 @@ func TestRequiresPrivilegeCheck(t *testing.T) {
ContractAddress: vehicleNFTAddr,
},
},
expectedError: newError("missing required privilege %s", model.PrivilegeVehicleNonLocationData),
expectedError: true,
},
{
name: "missing_claim",
privs: []model.Privilege{},
telemetryClaim: nil,
expectedError: UnauthorizedError{err: jwtmiddleware.ErrJWTMissing},
expectedError: true,
},
{
name: "wrong contract for privilege",
Expand All @@ -289,7 +285,7 @@ func TestRequiresPrivilegeCheck(t *testing.T) {
ContractAddress: manufNFTAddr,
},
},
expectedError: newError("missing required privilege %s", model.PrivilegeVehicleAllTimeLocation),
expectedError: true,
},
}

Expand All @@ -302,13 +298,12 @@ func TestRequiresPrivilegeCheck(t *testing.T) {
}
testCtx := context.WithValue(context.Background(), TelemetryClaimContextKey{}, tc.telemetryClaim)
next, err := PrivilegeCheck(testCtx, nil, emptyResolver, tc.privs)
if tc.expectedError != nil {
require.Equal(t, tc.expectedError, err)
return
if tc.expectedError {
require.Error(t, err)
} else {
require.NoError(t, err)
require.Equal(t, expectedReturn, next)
}
require.NoError(t, err)
require.Equal(t, expectedReturn, next)

})
}
}

0 comments on commit af296be

Please sign in to comment.