diff --git a/Sources/KituraNet/HTTP/HTTPRequestHandler.swift b/Sources/KituraNet/HTTP/HTTPRequestHandler.swift index a57f2629..2cfba67f 100644 --- a/Sources/KituraNet/HTTP/HTTPRequestHandler.swift +++ b/Sources/KituraNet/HTTP/HTTPRequestHandler.swift @@ -79,6 +79,22 @@ internal class HTTPRequestHandler: ChannelInboundHandler, RemovableChannelHandle if errorResponseSent { return } switch request { case .head(let header): + _ = server.connectionCount.add(1) + if let connectionLimit = server.options.connectionLimit, + server.connectionCount.load() > connectionLimit { + // Reaching connection limit: closing now. + do { + if let (httpStatus, response) = server.options.connectionResponseGenerator(connectionLimit,serverRequest?.remoteAddress ?? "") { + serverResponse = HTTPServerResponse(channel: context.channel, handler: self) + errorResponseSent = true + try serverResponse?.end(with: httpStatus, message: response) + } + } catch { + Log.error("Failed to send error response") + } + context.close(promise: nil) + return + } serverRequest = HTTPServerRequest(channel: context.channel, requestHead: header, enableSSL: enableSSLVerification) if let requestSizeLimit = server.options.requestSizeLimit, let contentLength = header.headers["Content-Length"].first, @@ -93,7 +109,7 @@ internal class HTTPRequestHandler: ChannelInboundHandler, RemovableChannelHandle } catch { Log.error("Failed to send error response") } - context.close() + context.close(promise: nil) } } serverRequest = HTTPServerRequest(channel: context.channel, requestHead: header, enableSSL: enableSSLVerification) @@ -125,20 +141,6 @@ internal class HTTPRequestHandler: ChannelInboundHandler, RemovableChannelHandle case .end: requestSize = 0 - server.connectionCount.add(1) - if let connectionLimit = server.options.connectionLimit { - if server.connectionCount.load() > connectionLimit { - do { - if let (httpStatus, response) = server.options.connectionResponseGenerator(connectionLimit,serverRequest?.remoteAddress ?? "") { - serverResponse = HTTPServerResponse(channel: context.channel, handler: self) - errorResponseSent = true - try serverResponse?.end(with: httpStatus, message: response) - } - } catch { - Log.error("Failed to send error response") - } - } - } serverResponse = HTTPServerResponse(channel: context.channel, handler: self) //Make sure we use the latest delegate registered with the server DispatchQueue.global().async { @@ -201,6 +203,6 @@ internal class HTTPRequestHandler: ChannelInboundHandler, RemovableChannelHandle } func channelInactive(context: ChannelHandlerContext) { - server.connectionCount.sub(1) + _ = server.connectionCount.sub(1) } }