Skip to content

Commit

Permalink
Merge pull request #9 from LUSHDigital/fix/grpc-auth-mw-errors
Browse files Browse the repository at this point in the history
gRPC auth middleware should throw errors on malformatted tokens
  • Loading branch information
zeevallin authored Mar 30, 2020
2 parents 241573c + c554e66 commit cf1ed42
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 6 deletions.
18 changes: 12 additions & 6 deletions middleware/lushauthmw/grpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,20 +103,26 @@ func InterceptServerJWT(ctx context.Context, broker CopierRenewer) (lushauth.Con
return claims.Consumer, nil
}

func handleInterceptError(err error) {
func handleInterceptError(err error) error {
switch err {
case ErrMetadataMissing, ErrAuthTokenMissing:
case nil:
return nil
case
ErrMetadataMissing,
ErrAuthTokenMissing:
return nil
default:
log.Printf("grpc auth middleware error: %v\n", err)
return err
}
}

// UnaryServerInterceptor is a gRPC server-side interceptor that checks that JWT provided is valid for unary procedures
func UnaryServerInterceptor(broker CopierRenewer) func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
consumer, err := InterceptServerJWT(ctx, broker)
if err != nil {
handleInterceptError(err)
if err := handleInterceptError(err); err != nil {
return nil, err
}
resp, err := handler(lushauth.ContextWithConsumer(ctx, consumer), req)
return resp, err
Expand All @@ -127,8 +133,8 @@ func UnaryServerInterceptor(broker CopierRenewer) func(ctx context.Context, req
func StreamServerInterceptor(broker CopierRenewer) func(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
return func(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
consumer, err := InterceptServerJWT(ss.Context(), broker)
if err != nil {
handleInterceptError(err)
if err := handleInterceptError(err); err != nil {
return err
}
err = handler(srv, &authenticatedServerStream{ss, consumer})
return err
Expand Down
67 changes: 67 additions & 0 deletions middleware/lushauthmw/grpc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"github.com/LUSHDigital/core/test"
"github.com/LUSHDigital/core/workers/keybroker/keybrokermock"
"github.com/LUSHDigital/uuid"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"
Expand Down Expand Up @@ -49,6 +50,72 @@ func TestGRPCMiddleware(t *testing.T) {
)
}

var (
ok = grpc.UnaryHandler(func(ctx context.Context, req interface{}) (interface{}, error) {
return nil, nil
})
fail = grpc.UnaryHandler(func(ctx context.Context, req interface{}) (interface{}, error) {
return nil, status.New(codes.Internal, "request failed").Err()
})
)

func TestUnaryServerInterceptor(t *testing.T) {
broker := keybrokermock.MockRSAPublicKey(public)
cases := []struct {
name string
tokens []string
handler grpc.UnaryHandler
errors bool
message string
code codes.Code
}{
{
name: "no token",
tokens: nil,
handler: ok,
code: codes.OK,
errors: false,
},
{
name: "empty token",
tokens: []string{""},
handler: ok,
errors: true,
code: codes.InvalidArgument,
message: "token contains an invalid number of segments",
},
{
name: "malformatted token",
tokens: []string{"abcd123!"},
handler: ok,
errors: true,
code: codes.InvalidArgument,
message: "token contains an invalid number of segments",
},
}
for _, c := range cases {
t.Run(c.name, func(t *testing.T) {
md := metadata.MD{}
if c.tokens != nil {
md.Set("auth-token", c.tokens...)
}
ctx := metadata.NewIncomingContext(context.Background(), md)
mw := lushauthmw.UnaryServerInterceptor(broker)
_, err := mw(ctx, nil, nil, ok)
if c.errors {
s, ok := status.FromError(err)
if !ok {
t.Errorf("unknown status from err: %v", err)
}
test.Equals(t, c.message, s.Message())
test.Equals(t, c.code, s.Code())
} else {
test.Equals(t, nil, err)
}
})
}
}

func TestInterceptServerJWT(t *testing.T) {
cases := []struct {
name string
Expand Down

0 comments on commit cf1ed42

Please sign in to comment.