Skip to content

Commit

Permalink
Allow adding ServerInterceptors to specific services and methods (#…
Browse files Browse the repository at this point in the history
…2096)

## Motivation
We want to allow users to customise the RPCs a registered interceptor
should apply to on the server:
- Intercept all requests
- Intercept requests only meant for specific services
- Intercept requests only meant for specific methods

## Modifications
This PR adds a new `ServerInterceptorTarget` type that allows users to
specify what the target of the interceptor should be.
Existing APIs accepting `[any ServerInterceptor]` have been changed to
instead take `[ServerInterceptorTarget]`.

## Result
Users can have more control over to which requests interceptors are
applied.

---------

Co-authored-by: George Barnett <gbarnett@apple.com>
  • Loading branch information
gjcairo and glbrntt authored Nov 13, 2024
1 parent f963523 commit c3f09df
Show file tree
Hide file tree
Showing 9 changed files with 480 additions and 49 deletions.
32 changes: 27 additions & 5 deletions Sources/GRPCCore/Call/Server/RPCRouter.swift
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
/// the router has a handler for a method with ``hasHandler(forMethod:)`` or get a list of all
/// methods with handlers registered by calling ``methods``. You can also remove the handler for a
/// given method by calling ``removeHandler(forMethod:)``.
/// You can also register any interceptors that you want applied to registered handlers via the
/// ``registerInterceptors(pipeline:)`` method.
///
/// In most cases you won't need to interact with the router directly. Instead you should register
/// your services with ``GRPCServer/init(transport:services:interceptors:)`` which will in turn
Expand Down Expand Up @@ -82,7 +84,8 @@ public struct RPCRouter: Sendable {
}

@usableFromInline
private(set) var handlers: [MethodDescriptor: RPCHandler]
private(set) var handlers:
[MethodDescriptor: (handler: RPCHandler, interceptors: [any ServerInterceptor])]

/// Creates a new router with no methods registered.
public init() {
Expand Down Expand Up @@ -126,12 +129,13 @@ public struct RPCRouter: Sendable {
_ context: ServerContext
) async throws -> StreamingServerResponse<Output>
) {
self.handlers[descriptor] = RPCHandler(
let handler = RPCHandler(
method: descriptor,
deserializer: deserializer,
serializer: serializer,
handler: handler
)
self.handlers[descriptor] = (handler, [])
}

/// Removes any handler registered for the specified method.
Expand All @@ -142,6 +146,25 @@ public struct RPCRouter: Sendable {
public mutating func removeHandler(forMethod descriptor: MethodDescriptor) -> Bool {
return self.handlers.removeValue(forKey: descriptor) != nil
}

/// Registers applicable interceptors to all currently-registered handlers.
///
/// - Important: Calling this method will apply the interceptors only to existing handlers. Any handlers registered via
/// ``registerHandler(forMethod:deserializer:serializer:handler:)`` _after_ calling this method will not have
/// any interceptors applied to them. If you want to make sure all registered methods have any applicable interceptors applied,
/// only call this method _after_ you have registered all handlers.
/// - Parameter pipeline: The interceptor pipeline operations to register to all currently-registered handlers. The order of the
/// interceptors matters.
/// - SeeAlso: ``ServerInterceptorPipelineOperation``.
@inlinable
public mutating func registerInterceptors(pipeline: [ServerInterceptorPipelineOperation]) {
for descriptor in self.handlers.keys {
let applicableOperations = pipeline.filter { $0.applies(to: descriptor) }
if !applicableOperations.isEmpty {
self.handlers[descriptor]?.interceptors = applicableOperations.map { $0.interceptor }
}
}
}
}

extension RPCRouter {
Expand All @@ -150,10 +173,9 @@ extension RPCRouter {
RPCAsyncSequence<RPCRequestPart, any Error>,
RPCWriter<RPCResponsePart>.Closable
>,
context: ServerContext,
interceptors: [any ServerInterceptor]
context: ServerContext
) async {
if let handler = self.handlers[stream.descriptor] {
if let (handler, interceptors) = self.handlers[stream.descriptor] {
await handler.handle(stream: stream, context: context, interceptors: interceptors)
} else {
// If this throws then the stream must be closed which we can't do anything about, so ignore
Expand Down
17 changes: 9 additions & 8 deletions Sources/GRPCCore/Call/Server/ServerInterceptor.swift
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,11 @@
/// been returned from a service. They are typically used for cross-cutting concerns like filtering
/// requests, validating messages, logging additional data, and tracing.
///
/// Interceptors are registered with the server apply to all RPCs. If you need to modify the
/// behavior of an interceptor on a per-RPC basis then you can use the
/// ``ServerContext/descriptor`` to determine which RPC is being called and
/// conditionalise behavior accordingly.
/// Interceptors can be registered with the server either directly or via ``ServerInterceptorPipelineOperation``s.
/// You may register them for all services registered with a server, for RPCs directed to specific services, or
/// for RPCs directed to specific methods. If you need to modify the behavior of an interceptor on a
/// per-RPC basis in more detail, then you can use the ``ServerContext/descriptor`` to determine
/// which RPC is being called and conditionalise behavior accordingly.
///
/// ## RPC filtering
///
Expand All @@ -33,19 +34,19 @@
/// demonstrates this.
///
/// ```swift
/// struct AuthServerInterceptor: Sendable {
/// struct AuthServerInterceptor: ServerInterceptor {
/// let isAuthorized: @Sendable (String, MethodDescriptor) async throws -> Void
///
/// func intercept<Input: Sendable, Output: Sendable>(
/// request: StreamingServerRequest<Input>,
/// context: ServerInterceptorContext,
/// context: ServerContext,
/// next: @Sendable (
/// _ request: StreamingServerRequest<Input>,
/// _ context: ServerInterceptorContext
/// _ context: ServerContext
/// ) async throws -> StreamingServerResponse<Output>
/// ) async throws -> StreamingServerResponse<Output> {
/// // Extract the auth token.
/// guard let token = request.metadata["authorization"] else {
/// guard let token = request.metadata[stringValues: "authorization"].first(where: { _ in true }) else {
/// throw RPCError(code: .unauthenticated, message: "Not authenticated")
/// }
///
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
/*
* Copyright 2024, gRPC Authors All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

/// A `ServerInterceptorPipelineOperation` describes to which RPCs a server interceptor should be applied.
///
/// You can configure a server interceptor to be applied to:
/// - all RPCs and services;
/// - requests directed only to specific services registered with your server; or
/// - requests directed only to specific methods (of a specific service).
///
/// - SeeAlso: ``ServerInterceptor`` for more information on server interceptors.
public struct ServerInterceptorPipelineOperation: Sendable {
/// The subject of a ``ServerInterceptorPipelineOperation``.
/// The subject of an interceptor can either be all services and methods, only specific services, or only specific methods.
public struct Subject: Sendable {
internal enum Wrapped: Sendable {
case all
case services(Set<ServiceDescriptor>)
case methods(Set<MethodDescriptor>)
}

private let wrapped: Wrapped

/// An operation subject specifying an interceptor that applies to all RPCs across all services will be registered with this server.
public static var all: Self { .init(wrapped: .all) }

/// An operation subject specifying an interceptor that will be applied only to RPCs directed to the specified services.
/// - Parameters:
/// - services: The list of service names for which this interceptor should intercept RPCs.
/// - Returns: A ``ServerInterceptorPipelineOperation``.
public static func services(_ services: Set<ServiceDescriptor>) -> Self {
Self(wrapped: .services(services))
}

/// An operation subject specifying an interceptor that will be applied only to RPCs directed to the specified service methods.
/// - Parameters:
/// - methods: The list of method descriptors for which this interceptor should intercept RPCs.
/// - Returns: A ``ServerInterceptorPipelineOperation``.
public static func methods(_ methods: Set<MethodDescriptor>) -> Self {
Self(wrapped: .methods(methods))
}

@usableFromInline
internal func applies(to descriptor: MethodDescriptor) -> Bool {
switch self.wrapped {
case .all:
return true

case .services(let services):
return services.map({ $0.fullyQualifiedService }).contains(descriptor.service)

case .methods(let methods):
return methods.contains(descriptor)
}
}
}

/// The interceptor specified for this operation.
public let interceptor: any ServerInterceptor

@usableFromInline
internal let subject: Subject

private init(interceptor: any ServerInterceptor, appliesTo: Subject) {
self.interceptor = interceptor
self.subject = appliesTo
}

/// Create an operation, specifying which ``ServerInterceptor`` to apply and to which ``Subject``.
/// - Parameters:
/// - interceptor: The ``ServerInterceptor`` to register with the server.
/// - subject: The ``Subject`` to which the `interceptor` applies.
/// - Returns: A ``ServerInterceptorPipelineOperation``.
public static func apply(_ interceptor: any ServerInterceptor, to subject: Subject) -> Self {
Self(interceptor: interceptor, appliesTo: subject)
}

/// Returns whether this ``ServerInterceptorPipelineOperation`` applies to the given `descriptor`.
/// - Parameter descriptor: A ``MethodDescriptor`` for which to test whether this interceptor applies.
/// - Returns: `true` if this interceptor applies to the given `descriptor`, or `false` otherwise.
@inlinable
internal func applies(to descriptor: MethodDescriptor) -> Bool {
self.subject.applies(to: descriptor)
}
}
47 changes: 26 additions & 21 deletions Sources/GRPCCore/GRPCServer.swift
Original file line number Diff line number Diff line change
Expand Up @@ -78,14 +78,6 @@ public final class GRPCServer: Sendable {
/// The services registered which the server is serving.
private let router: RPCRouter

/// A collection of ``ServerInterceptor`` implementations which are applied to all accepted
/// RPCs.
///
/// RPCs are intercepted in the order that interceptors are added. That is, a request received
/// from the client will first be intercepted by the first added interceptor followed by the
/// second, and so on.
private let interceptors: [any ServerInterceptor]

/// The state of the server.
private let state: Mutex<State>

Expand Down Expand Up @@ -154,33 +146,46 @@ public final class GRPCServer: Sendable {
services: [any RegistrableRPCService],
interceptors: [any ServerInterceptor] = []
) {
var router = RPCRouter()
for service in services {
service.registerMethods(with: &router)
}

self.init(transport: transport, router: router, interceptors: interceptors)
self.init(
transport: transport,
services: services,
interceptorPipeline: interceptors.map { .apply($0, to: .all) }
)
}

/// Creates a new server with no resources.
///
/// - Parameters:
/// - transport: The transport the server should listen on.
/// - router: A ``RPCRouter`` used by the server to route accepted streams to method handlers.
/// - interceptors: A collection of interceptors providing cross-cutting functionality to each
/// - services: Services offered by the server.
/// - interceptorPipeline: A collection of interceptors providing cross-cutting functionality to each
/// accepted RPC. The order in which interceptors are added reflects the order in which they
/// are called. The first interceptor added will be the first interceptor to intercept each
/// request. The last interceptor added will be the final interceptor to intercept each
/// request before calling the appropriate handler.
public init(
public convenience init(
transport: any ServerTransport,
router: RPCRouter,
interceptors: [any ServerInterceptor] = []
services: [any RegistrableRPCService],
interceptorPipeline: [ServerInterceptorPipelineOperation]
) {
var router = RPCRouter()
for service in services {
service.registerMethods(with: &router)
}
router.registerInterceptors(pipeline: interceptorPipeline)

self.init(transport: transport, router: router)
}

/// Creates a new server with no resources.
///
/// - Parameters:
/// - transport: The transport the server should listen on.
/// - router: A ``RPCRouter`` used by the server to route accepted streams to method handlers.
public init(transport: any ServerTransport, router: RPCRouter) {
self.state = Mutex(.notStarted)
self.transport = transport
self.router = router
self.interceptors = interceptors
}

/// Starts the server and runs until the registered transport has closed.
Expand All @@ -206,7 +211,7 @@ public final class GRPCServer: Sendable {

do {
try await transport.listen { stream, context in
await self.router.handle(stream: stream, context: context, interceptors: self.interceptors)
await self.router.handle(stream: stream, context: context)
}
} catch {
throw RuntimeError(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -333,7 +333,9 @@ final class ServerRPCExecutorTests: XCTestCase {

func testThrowingInterceptor() async throws {
let harness = ServerRPCExecutorTestHarness(
interceptors: [.throwError(RPCError(code: .unavailable, message: "Unavailable"))]
interceptors: [
.throwError(RPCError(code: .unavailable, message: "Unavailable"))
]
)

try await harness.execute(handler: .echo) { inbound in
Expand Down
Loading

0 comments on commit c3f09df

Please sign in to comment.