diff --git a/middleware/lushauthmw/grpc.go b/middleware/lushauthmw/grpc.go index dcbd4f3..9d5dabf 100644 --- a/middleware/lushauthmw/grpc.go +++ b/middleware/lushauthmw/grpc.go @@ -103,11 +103,17 @@ func InterceptServerJWT(ctx context.Context, broker CopierRenewer) (lushauth.Con return claims.Consumer, nil } -func handleInterceptError(err error) { +func handleInterceptError(err error) error { switch err { - case ErrMetadataMissing, ErrAuthTokenMissing: + case nil: + return nil + case + ErrMetadataMissing, + ErrAuthTokenMissing: + return nil default: log.Printf("grpc auth middleware error: %v\n", err) + return err } } @@ -115,8 +121,8 @@ func handleInterceptError(err error) { func UnaryServerInterceptor(broker CopierRenewer) func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { consumer, err := InterceptServerJWT(ctx, broker) - if err != nil { - handleInterceptError(err) + if err := handleInterceptError(err); err != nil { + return nil, err } resp, err := handler(lushauth.ContextWithConsumer(ctx, consumer), req) return resp, err @@ -127,8 +133,8 @@ func UnaryServerInterceptor(broker CopierRenewer) func(ctx context.Context, req func StreamServerInterceptor(broker CopierRenewer) func(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error { return func(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error { consumer, err := InterceptServerJWT(ss.Context(), broker) - if err != nil { - handleInterceptError(err) + if err := handleInterceptError(err); err != nil { + return err } err = handler(srv, &authenticatedServerStream{ss, consumer}) return err diff --git a/middleware/lushauthmw/grpc_test.go b/middleware/lushauthmw/grpc_test.go index dc2a001..aae9098 100644 --- a/middleware/lushauthmw/grpc_test.go +++ b/middleware/lushauthmw/grpc_test.go @@ -11,6 +11,7 @@ import ( "github.com/LUSHDigital/core/test" "github.com/LUSHDigital/core/workers/keybroker/keybrokermock" "github.com/LUSHDigital/uuid" + "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/metadata" "google.golang.org/grpc/status" @@ -49,6 +50,72 @@ func TestGRPCMiddleware(t *testing.T) { ) } +var ( + ok = grpc.UnaryHandler(func(ctx context.Context, req interface{}) (interface{}, error) { + return nil, nil + }) + fail = grpc.UnaryHandler(func(ctx context.Context, req interface{}) (interface{}, error) { + return nil, status.New(codes.Internal, "request failed").Err() + }) +) + +func TestUnaryServerInterceptor(t *testing.T) { + broker := keybrokermock.MockRSAPublicKey(public) + cases := []struct { + name string + tokens []string + handler grpc.UnaryHandler + errors bool + message string + code codes.Code + }{ + { + name: "no token", + tokens: nil, + handler: ok, + code: codes.OK, + errors: false, + }, + { + name: "empty token", + tokens: []string{""}, + handler: ok, + errors: true, + code: codes.InvalidArgument, + message: "token contains an invalid number of segments", + }, + { + name: "malformatted token", + tokens: []string{"abcd123!"}, + handler: ok, + errors: true, + code: codes.InvalidArgument, + message: "token contains an invalid number of segments", + }, + } + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + md := metadata.MD{} + if c.tokens != nil { + md.Set("auth-token", c.tokens...) + } + ctx := metadata.NewIncomingContext(context.Background(), md) + mw := lushauthmw.UnaryServerInterceptor(broker) + _, err := mw(ctx, nil, nil, ok) + if c.errors { + s, ok := status.FromError(err) + if !ok { + t.Errorf("unknown status from err: %v", err) + } + test.Equals(t, c.message, s.Message()) + test.Equals(t, c.code, s.Code()) + } else { + test.Equals(t, nil, err) + } + }) + } +} + func TestInterceptServerJWT(t *testing.T) { cases := []struct { name string