From 49e4a0886c2bd20355ba55d48c073c50b6efbabb Mon Sep 17 00:00:00 2001 From: Adam Fowler Date: Wed, 6 Nov 2024 13:52:43 +0000 Subject: [PATCH] HTTP2 improvements (#601) * Use configureHTTP2AsyncSecureUpgrade * Add HTTP2ServerConnectionManager that records streams being added/removed * HTTP2 connection state machine * Move HTTP2 tests to own target * Handle closing connection * remove immediate close in triggerGracefulShutdown * Fixed HTTP2 channel shutdown * Add HTTP2StreamChannel to handle http2 stream setup * graceful shutdown timeout * Re-order code * Add maxAge timeout, handleInputClosed * Updated comments to answer some PR comments * Set closed state on inputClosed * enhanceYourCalm * Update function header comments * Fix breaking changes * Remove HTTPChannelHandler conformance from HTTP2StreamChannel * minor comment change --- Package.swift | 11 +- .../HummingbirdCore/Request/RequestBody.swift | 15 +- .../Response/ResponseWriter.swift | 4 + .../Server/HTTP/HTTP1Channel.swift | 3 +- .../Server/HTTP/HTTPChannelHandler.swift | 2 +- Sources/HummingbirdHTTP2/HTTP2Channel.swift | 158 +++++++-- ...ServerConnectionManager+StateMachine.swift | 288 ++++++++++++++++ .../HTTP2ServerConnectionManager.swift | 321 ++++++++++++++++++ .../HummingbirdHTTP2/HTTP2StreamChannel.swift | 90 +++++ .../HTTPServerBuilder+http2.swift | 4 +- Tests/HummingbirdCoreTests/HTTP2Tests.swift | 63 ---- .../HummingbirdHTTP2Tests/Certificates.swift | 182 ++++++++++ ...erConnectionManagerStateMachineTests.swift | 105 ++++++ Tests/HummingbirdHTTP2Tests/HTTP2Tests.swift | 188 ++++++++++ Tests/HummingbirdHTTP2Tests/TestUtils.swift | 148 ++++++++ Tests/HummingbirdTests/TracingTests.swift | 2 - 16 files changed, 1469 insertions(+), 115 deletions(-) create mode 100644 Sources/HummingbirdHTTP2/HTTP2ServerConnectionManager+StateMachine.swift create mode 100644 Sources/HummingbirdHTTP2/HTTP2ServerConnectionManager.swift create mode 100644 Sources/HummingbirdHTTP2/HTTP2StreamChannel.swift delete mode 100644 Tests/HummingbirdCoreTests/HTTP2Tests.swift create mode 100644 Tests/HummingbirdHTTP2Tests/Certificates.swift create mode 100644 Tests/HummingbirdHTTP2Tests/HTTP2ServerConnectionManagerStateMachineTests.swift create mode 100644 Tests/HummingbirdHTTP2Tests/HTTP2Tests.swift create mode 100644 Tests/HummingbirdHTTP2Tests/TestUtils.swift diff --git a/Package.swift b/Package.swift index 41fa09a0d..70e192811 100644 --- a/Package.swift +++ b/Package.swift @@ -139,13 +139,22 @@ let package = Package( dependencies: [ .byName(name: "HummingbirdCore"), - .byName(name: "HummingbirdHTTP2"), .byName(name: "HummingbirdTLS"), .byName(name: "HummingbirdTesting"), .product(name: "AsyncHTTPClient", package: "async-http-client"), ], resources: [.process("Certificates")] ), + .testTarget( + name: "HummingbirdHTTP2Tests", + dependencies: + [ + .byName(name: "HummingbirdCore"), + .byName(name: "HummingbirdHTTP2"), + .byName(name: "HummingbirdTesting"), + .product(name: "AsyncHTTPClient", package: "async-http-client"), + ] + ), ], swiftLanguageVersions: [.v5, .version("6")] ) diff --git a/Sources/HummingbirdCore/Request/RequestBody.swift b/Sources/HummingbirdCore/Request/RequestBody.swift index fe1315260..a559ad273 100644 --- a/Sources/HummingbirdCore/Request/RequestBody.swift +++ b/Sources/HummingbirdCore/Request/RequestBody.swift @@ -195,12 +195,9 @@ extension RequestBody { /// Request body that is a stream of ByteBuffers sourced from a NIOAsyncChannelInboundStream. /// /// This is a unicast async sequence that allows a single iterator to be created. -@usableFromInline -final class NIOAsyncChannelRequestBody: Sendable, AsyncSequence { - @usableFromInline - typealias Element = ByteBuffer - @usableFromInline - typealias InboundStream = NIOAsyncChannelInboundStream +public final class NIOAsyncChannelRequestBody: Sendable, AsyncSequence { + public typealias Element = ByteBuffer + public typealias InboundStream = NIOAsyncChannelInboundStream @usableFromInline internal let underlyingIterator: UnsafeTransfer.AsyncIterator> @@ -209,7 +206,7 @@ final class NIOAsyncChannelRequestBody: Sendable, AsyncSequence { /// Initialize NIOAsyncChannelRequestBody from AsyncIterator of a NIOAsyncChannelInboundStream @inlinable - init(iterator: InboundStream.AsyncIterator) { + public init(iterator: InboundStream.AsyncIterator) { self.underlyingIterator = .init(iterator) self.alreadyIterated = .init(false) } @@ -228,7 +225,7 @@ final class NIOAsyncChannelRequestBody: Sendable, AsyncSequence { } @inlinable - mutating func next() async throws -> ByteBuffer? { + public mutating func next() async throws -> ByteBuffer? { if self.done { return nil } // if we are still expecting parts and the iterator finishes. // In this case I think we can just assume we hit an .end @@ -246,7 +243,7 @@ final class NIOAsyncChannelRequestBody: Sendable, AsyncSequence { } @inlinable - func makeAsyncIterator() -> AsyncIterator { + public func makeAsyncIterator() -> AsyncIterator { // verify if an iterator has already been created. If it has then create an // iterator that returns nothing. This could be a precondition failure (currently // an assert) as you should not be allowed to do this. diff --git a/Sources/HummingbirdCore/Response/ResponseWriter.swift b/Sources/HummingbirdCore/Response/ResponseWriter.swift index edd924489..1ee6a1742 100644 --- a/Sources/HummingbirdCore/Response/ResponseWriter.swift +++ b/Sources/HummingbirdCore/Response/ResponseWriter.swift @@ -21,6 +21,10 @@ public struct ResponseWriter: ~Copyable { @usableFromInline let outbound: NIOAsyncChannelOutboundWriter + public init(outbound: NIOAsyncChannelOutboundWriter) { + self.outbound = outbound + } + /// Write HTTP head part and return ``ResponseBodyWriter`` to write response body /// /// - Parameter head: Response head diff --git a/Sources/HummingbirdCore/Server/HTTP/HTTP1Channel.swift b/Sources/HummingbirdCore/Server/HTTP/HTTP1Channel.swift index 1607876cf..bc6ce2259 100644 --- a/Sources/HummingbirdCore/Server/HTTP/HTTP1Channel.swift +++ b/Sources/HummingbirdCore/Server/HTTP/HTTP1Channel.swift @@ -24,7 +24,7 @@ public struct HTTP1Channel: ServerChildChannel, HTTPChannelHandler { /// HTTP1Channel configuration public struct Configuration: Sendable { - /// Additional channel handlers to add to channel after HTTP part decoding and before HTTP request processing + /// Additional channel handlers to add to channel pipeline after HTTP part decoding and before HTTP request handling public var additionalChannelHandlers: @Sendable () -> [any RemovableChannelHandler] /// Time before closing an idle channel public var idleTimeout: TimeAmount? @@ -98,6 +98,7 @@ public struct HTTP1Channel: ServerChildChannel, HTTPChannelHandler { /// - Parameters: /// - asyncChannel: NIOAsyncChannel handling HTTP parts /// - logger: Logger to use while processing messages + @inlinable public func handle( value asyncChannel: NIOCore.NIOAsyncChannel, logger: Logging.Logger diff --git a/Sources/HummingbirdCore/Server/HTTP/HTTPChannelHandler.swift b/Sources/HummingbirdCore/Server/HTTP/HTTPChannelHandler.swift index d66a506b4..0b8d8115d 100644 --- a/Sources/HummingbirdCore/Server/HTTP/HTTPChannelHandler.swift +++ b/Sources/HummingbirdCore/Server/HTTP/HTTPChannelHandler.swift @@ -28,7 +28,7 @@ public protocol HTTPChannelHandler: ServerChildChannel { /// Internal error thrown when an unexpected HTTP part is received eg we didn't receive /// a head part when we expected one @usableFromInline -enum HTTPChannelError: Error { +package enum HTTPChannelError: Error { case unexpectedHTTPPart(HTTPRequestPart) } diff --git a/Sources/HummingbirdHTTP2/HTTP2Channel.swift b/Sources/HummingbirdHTTP2/HTTP2Channel.swift index d5d4ed5ca..d75243b24 100644 --- a/Sources/HummingbirdHTTP2/HTTP2Channel.swift +++ b/Sources/HummingbirdHTTP2/HTTP2Channel.swift @@ -2,7 +2,7 @@ // // This source file is part of the Hummingbird server framework project // -// Copyright (c) 2023 the Hummingbird authors +// Copyright (c) 2023-2024 the Hummingbird authors // Licensed under Apache License v2.0 // // See LICENSE.txt for license information @@ -16,44 +16,65 @@ import HTTPTypes import HummingbirdCore import Logging import NIOCore +import NIOHTTP1 import NIOHTTP2 import NIOHTTPTypes import NIOHTTPTypesHTTP1 import NIOHTTPTypesHTTP2 import NIOPosix import NIOSSL +import NIOTLS /// Child channel for processing HTTP1 with the option of upgrading to HTTP2 public struct HTTP2UpgradeChannel: HTTPChannelHandler { + typealias HTTP1ConnectionOutput = HTTP1Channel.Value + typealias HTTP2ConnectionOutput = NIOHTTP2Handler.AsyncStreamMultiplexer public struct Value: ServerChildChannelValue { - let negotiatedHTTPVersion: EventLoopFuture, NIOHTTP2Handler.AsyncStreamMultiplexer)>> + let negotiatedHTTPVersion: EventLoopFuture> public let channel: Channel } /// HTTP2 Upgrade configuration public struct Configuration: Sendable { + /// Idle timeout, how long connection is kept idle before closing + public var idleTimeout: Duration? + /// Maximum amount of time to wait for client response before all streams are closed after second GOAWAY has been sent + public var gracefulCloseTimeout: Duration? + /// Maximum amount of time a connection can be open + public var maxAgeTimeout: Duration? /// Configuration applied to HTTP2 stream channels public var streamConfiguration: HTTP1Channel.Configuration /// Initialize HTTP2UpgradeChannel.Configuration /// - Parameters: - /// - additionalChannelHandlers: Additional channel handlers to add to HTTP2 connection channel - /// - streamConfiguration: Configuration applied to HTTP2 stream channels + /// - idleTimeout: How long connection is kept idle before closing + /// - maxGraceCloseTimeout: Maximum amount of time to wait for client response before all streams are closed after second GOAWAY + /// - streamConfiguration: Configuration applieds to HTTP2 stream channels public init( + idleTimeout: Duration? = nil, + gracefulCloseTimeout: Duration? = nil, + maxAgeTimeout: Duration? = nil, streamConfiguration: HTTP1Channel.Configuration = .init() ) { + self.idleTimeout = idleTimeout + self.gracefulCloseTimeout = gracefulCloseTimeout self.streamConfiguration = streamConfiguration } } private let sslContext: NIOSSLContext private let http1: HTTP1Channel - public var responder: HTTPChannelHandler.Responder { self.http1.responder } + private let http2Stream: HTTP2StreamChannel + public let configuration: Configuration + public var responder: Responder { + self.http2Stream.responder + } /// Initialize HTTP2Channel /// - Parameters: /// - tlsConfiguration: TLS configuration - /// - additionalChannelHandlers: Additional channel handlers to add to channel pipeline + /// - additionalChannelHandlers: Additional channel handlers to add to stream channel pipeline after HTTP part decoding and + /// before HTTP request handling /// - responder: Function returning a HTTP response for a HTTP request @available(*, deprecated, renamed: "HTTP1Channel(tlsConfiguration:configuration:responder:)") public init( @@ -64,13 +85,21 @@ public struct HTTP2UpgradeChannel: HTTPChannelHandler { var tlsConfiguration = tlsConfiguration tlsConfiguration.applicationProtocols = NIOHTTP2SupportedALPNProtocols self.sslContext = try NIOSSLContext(configuration: tlsConfiguration) - self.http1 = HTTP1Channel(responder: responder, configuration: .init(additionalChannelHandlers: additionalChannelHandlers())) + self.configuration = .init() + self.http1 = HTTP1Channel( + responder: responder, + configuration: .init(additionalChannelHandlers: additionalChannelHandlers()) + ) + self.http2Stream = HTTP2StreamChannel( + responder: responder, + configuration: .init(additionalChannelHandlers: additionalChannelHandlers()) + ) } /// Initialize HTTP2Channel /// - Parameters: /// - tlsConfiguration: TLS configuration - /// - additionalChannelHandlers: Additional channel handlers to add to channel pipeline + /// - configuration: HTTP2 channel configuration /// - responder: Function returning a HTTP response for a HTTP request public init( tlsConfiguration: TLSConfiguration, @@ -80,7 +109,9 @@ public struct HTTP2UpgradeChannel: HTTPChannelHandler { var tlsConfiguration = tlsConfiguration tlsConfiguration.applicationProtocols = NIOHTTP2SupportedALPNProtocols self.sslContext = try NIOSSLContext(configuration: tlsConfiguration) + self.configuration = configuration self.http1 = HTTP1Channel(responder: responder, configuration: configuration.streamConfiguration) + self.http2Stream = HTTP2StreamChannel(responder: responder, configuration: configuration.streamConfiguration) } /// Setup child channel for HTTP1 with HTTP2 upgrade @@ -95,31 +126,28 @@ public struct HTTP2UpgradeChannel: HTTPChannelHandler { return channel.eventLoop.makeFailedFuture(error) } - return channel.configureAsyncHTTPServerPipeline { http1Channel -> EventLoopFuture in - return http1Channel.eventLoop.makeCompletedFuture { - try http1Channel.pipeline.syncOperations.addHandler(HTTP1ToHTTPServerCodec(secure: true)) - try http1Channel.pipeline.syncOperations.addHandlers(self.http1.configuration.additionalChannelHandlers()) - if let idleTimeout = self.http1.configuration.idleTimeout { - try http1Channel.pipeline.syncOperations.addHandler(IdleStateHandler(readTimeout: idleTimeout)) - } - try http1Channel.pipeline.syncOperations.addHandler(HTTPUserEventHandler(logger: logger)) - return try HTTP1Channel.Value(wrappingChannelSynchronously: http1Channel) - } - } http2ConnectionInitializer: { http2Channel -> EventLoopFuture> in - http2Channel.eventLoop.makeCompletedFuture { - try NIOAsyncChannel(wrappingChannelSynchronously: http2Channel) - } - } http2StreamInitializer: { http2ChildChannel -> EventLoopFuture in - return http2ChildChannel.eventLoop.makeCompletedFuture { - try http2ChildChannel.pipeline.syncOperations.addHandler(HTTP2FramePayloadToHTTPServerCodec()) - try http2ChildChannel.pipeline.syncOperations.addHandlers(self.http1.configuration.additionalChannelHandlers()) - if let idleTimeout = self.http1.configuration.idleTimeout { - try http2ChildChannel.pipeline.syncOperations.addHandler(IdleStateHandler(readTimeout: idleTimeout)) + return channel.configureHTTP2AsyncSecureUpgrade { channel in + self.http1.setup(channel: channel, logger: logger) + } http2ConnectionInitializer: { channel in + channel.eventLoop.makeCompletedFuture { + let connectionManager = HTTP2ServerConnectionManager( + eventLoop: channel.eventLoop, + idleTimeout: self.configuration.idleTimeout, + maxAgeTimeout: self.configuration.maxAgeTimeout, + gracefulCloseTimeout: self.configuration.gracefulCloseTimeout + ) + let handler: HTTP2ConnectionOutput = try channel.pipeline.syncOperations.configureAsyncHTTP2Pipeline( + mode: .server, + streamDelegate: connectionManager.streamDelegate, + configuration: .init() + ) { http2ChildChannel in + self.http2Stream.setup(channel: http2ChildChannel, logger: logger) } - try http2ChildChannel.pipeline.syncOperations.addHandler(HTTPUserEventHandler(logger: logger)) - return try HTTP1Channel.Value(wrappingChannelSynchronously: http2ChildChannel) + try channel.pipeline.syncOperations.addHandler(connectionManager) + return handler } - }.map { + } + .map { .init(negotiatedHTTPVersion: $0, channel: channel) } } @@ -133,24 +161,80 @@ public struct HTTP2UpgradeChannel: HTTPChannelHandler { let channel = try await value.negotiatedHTTPVersion.get() switch channel { case .http1_1(let http1): - await handleHTTP(asyncChannel: http1, logger: logger) - case .http2((let http2, let multiplexer)): + await self.http1.handle(value: http1, logger: logger) + case .http2(let multiplexer): do { try await withThrowingDiscardingTaskGroup { group in - for try await client in multiplexer.inbound.cancelOnGracefulShutdown() { + for try await client in multiplexer.inbound { group.addTask { - await handleHTTP(asyncChannel: client, logger: logger) + await self.http2Stream.handle(value: client, logger: logger) } } } } catch { logger.error("Error handling inbound connection for HTTP2 handler: \(error)") } - // have to run this to ensure http2 channel outbound writer is closed - try await http2.executeThenClose { _, _ in } } } catch { logger.error("Error getting HTTP2 upgrade negotiated value: \(error)") } } } + +// Code taken from NIOHTTP2 +extension Channel { + /// Configures a channel to perform an HTTP/2 secure upgrade with typed negotiation results. + /// + /// HTTP/2 secure upgrade uses the Application Layer Protocol Negotiation TLS extension to + /// negotiate the inner protocol as part of the TLS handshake. For this reason, until the TLS + /// handshake is complete, the ultimate configuration of the channel pipeline cannot be known. + /// + /// This function configures the channel with a pair of callbacks that will handle the result + /// of the negotiation. It explicitly **does not** configure a TLS handler to actually attempt + /// to negotiate ALPN. The supported ALPN protocols are provided in + /// `NIOHTTP2SupportedALPNProtocols`: please ensure that the TLS handler you are using for your + /// pipeline is appropriately configured to perform this protocol negotiation. + /// + /// If negotiation results in an unexpected protocol, the pipeline will close the connection + /// and no callback will fire. + /// + /// This configuration is acceptable for use on both client and server channel pipelines. + /// + /// - Parameters: + /// - http1ConnectionInitializer: A callback that will be invoked if HTTP/1.1 has been explicitly + /// negotiated, or if no protocol was negotiated. Must return a future that completes when the + /// channel has been fully mutated. + /// - http2ConnectionInitializer: A callback that will be invoked if HTTP/2 has been negotiated, and that + /// should configure the channel for HTTP/2 use. Must return a future that completes when the + /// channel has been fully mutated. + /// - Returns: An `EventLoopFuture` of an `EventLoopFuture` containing the `NIOProtocolNegotiationResult` that completes when the channel + /// is ready to negotiate. + @inlinable + internal func configureHTTP2AsyncSecureUpgrade( + http1ConnectionInitializer: @escaping NIOChannelInitializerWithOutput, + http2ConnectionInitializer: @escaping NIOChannelInitializerWithOutput + ) -> EventLoopFuture>> { + let alpnHandler = NIOTypedApplicationProtocolNegotiationHandler>() { result in + switch result { + case .negotiated("h2"): + // Successful upgrade to HTTP/2. Let the user configure the pipeline. + return http2ConnectionInitializer(self).map { http2Output in .http2(http2Output) } + case .negotiated("http/1.1"), .fallback: + // Explicit or implicit HTTP/1.1 choice. + return http1ConnectionInitializer(self).map { http1Output in .http1_1(http1Output) } + case .negotiated: + // We negotiated something that isn't HTTP/1.1. This is a bad scene, and is a good indication + // of a user configuration error. We're going to close the connection directly. + return self.close().flatMap { self.eventLoop.makeFailedFuture(NIOHTTP2Errors.invalidALPNToken()) } + } + } + + return self.pipeline + .addHandler(alpnHandler) + .flatMap { _ in + self.pipeline.handler(type: NIOTypedApplicationProtocolNegotiationHandler>.self).map { alpnHandler in + alpnHandler.protocolNegotiationResult + } + } + } +} diff --git a/Sources/HummingbirdHTTP2/HTTP2ServerConnectionManager+StateMachine.swift b/Sources/HummingbirdHTTP2/HTTP2ServerConnectionManager+StateMachine.swift new file mode 100644 index 000000000..d827b5b67 --- /dev/null +++ b/Sources/HummingbirdHTTP2/HTTP2ServerConnectionManager+StateMachine.swift @@ -0,0 +1,288 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the Hummingbird server framework project +// +// Copyright (c) 2024 the Hummingbird authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See hummingbird/CONTRIBUTORS.txt for the list of Hummingbird authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +import NIOCore +import NIOHTTP2 + +extension HTTP2ServerConnectionManager { + struct StateMachine { + var state: State + + init() { + self.state = .active(.init()) + } + + mutating func streamOpened(_ id: HTTP2StreamID) { + switch self.state { + case .active(var activeState): + activeState.openStreams.insert(id) + activeState.lastStreamId = id + self.state = .active(activeState) + + case .closing(var closingState): + closingState.openStreams.insert(id) + closingState.lastStreamId = id + self.state = .closing(closingState) + + case .closed: + break + } + } + + enum StreamClosedResult { + case startIdleTimer + case close + case none + } + + mutating func streamClosed(_ id: HTTP2StreamID) -> StreamClosedResult { + switch self.state { + case .active(var activeState): + activeState.openStreams.remove(id) + self.state = .active(activeState) + if activeState.openStreams.isEmpty { + return .startIdleTimer + } else { + return .none + } + + case .closing(var closingState): + closingState.openStreams.remove(id) + if closingState.openStreams.isEmpty, closingState.sentSecondGoAway == true { + self.state = .closed + return .close + } else { + self.state = .closing(closingState) + return .none + } + + case .closed: + return .none + } + } + + enum TriggerGracefulShutdownResult { + case sendGoAway(pingData: HTTP2PingData) + case none + } + + mutating func triggerGracefulShutdown() -> TriggerGracefulShutdownResult { + switch self.state { + case .active(let activeState): + let closingState = State.ClosingState(from: activeState) + self.state = .closing(closingState) + return .sendGoAway(pingData: closingState.goAwayPingData) + + case .closing: + return .none + + case .closed: + return .none + } + } + + enum ReceivedPingResult { + case sendPingAck(pingData: HTTP2PingData) + case enhanceYourCalmAndClose(lastStreamId: HTTP2StreamID) // Sent when client sends too many pings + case none + } + + mutating func receivedPing(atTime time: NIODeadline, data: HTTP2PingData) -> ReceivedPingResult { + switch self.state { + case .active(var activeState): + let tooManyPings = activeState.keepalive.receivedPing(atTime: time, hasOpenStreams: activeState.openStreams.count > 0) + if tooManyPings { + self.state = .closed + return .enhanceYourCalmAndClose(lastStreamId: activeState.lastStreamId) + } else { + self.state = .active(activeState) + return .sendPingAck(pingData: data) + } + + case .closing(var closingState): + let tooManyPings = closingState.keepalive.receivedPing(atTime: time, hasOpenStreams: closingState.openStreams.count > 0) + if tooManyPings { + self.state = .closed + return .enhanceYourCalmAndClose(lastStreamId: closingState.lastStreamId) + } else { + self.state = .closing(closingState) + return .sendPingAck(pingData: data) + } + + case .closed: + return .none + } + } + + enum ReceivedPingAckResult { + case sendGoAway(lastStreamId: HTTP2StreamID, close: Bool) + case none + } + + mutating func receivedPingAck(data: HTTP2PingData) -> ReceivedPingAckResult { + switch self.state { + case .active: + return .none + + case .closing(var state): + guard state.goAwayPingData == data else { + return .none + } + state.sentSecondGoAway = true + if state.openStreams.count > 0 { + self.state = .closing(state) + return .sendGoAway(lastStreamId: state.lastStreamId, close: false) + } else { + self.state = .closed + return .sendGoAway(lastStreamId: state.lastStreamId, close: true) + } + + case .closed: + return .none + } + } + + enum InputClosedResult { + case closeWithGoAway(lastStreamId: HTTP2StreamID) + case close + case none + } + + mutating func inputClosed() -> InputClosedResult { + switch self.state { + case .active(let activeState): + self.state = .closed + return .closeWithGoAway(lastStreamId: activeState.lastStreamId) + + case .closing(let closeState): + if closeState.sentSecondGoAway { + self.state = .closed + return .close + } else { + return .closeWithGoAway(lastStreamId: closeState.lastStreamId) + } + + case .closed: + return .none + } + } + } +} + +extension HTTP2ServerConnectionManager.StateMachine { + enum State { + struct ActiveState { + var openStreams: Set + var lastStreamId: HTTP2StreamID + var keepalive: Keepalive + + init() { + self.openStreams = .init() + self.lastStreamId = .rootStream + self.keepalive = .init(allowWithoutCalls: true, minPingReceiveIntervalWithoutCalls: .seconds(30)) + } + } + + struct ClosingState { + var openStreams: Set + var lastStreamId: HTTP2StreamID + var keepalive: Keepalive + var sentSecondGoAway: Bool + let goAwayPingData: HTTP2PingData + + init(from activeState: ActiveState) { + self.openStreams = activeState.openStreams + self.lastStreamId = activeState.lastStreamId + self.keepalive = activeState.keepalive + self.sentSecondGoAway = false + self.goAwayPingData = HTTP2PingData(withInteger: .random(in: .min ... .max)) + } + } + + case active(ActiveState) + case closing(ClosingState) + case closed + } +} + +extension HTTP2ServerConnectionManager.StateMachine { + struct Keepalive { + /// Allow the client to send keep alive pings when there are no active calls. + private let allowWithoutCalls: Bool + + /// The minimum time interval which pings may be received at when there are no active calls. + private let minPingReceiveIntervalWithoutCalls: TimeAmount + + /// The maximum number of "bad" pings sent by the client the server tolerates before closing + /// the connection. + private let maxPingStrikes: Int + + /// The number of "bad" pings sent by the client. This can be reset when the server sends + /// DATA or HEADERS frames. + /// + /// Ping strikes account for pings being occasionally being used for purposes other than keep + /// alive (a low number of strikes is therefore expected and okay). + private var pingStrikes: Int + + /// The last time a valid ping happened. + /// + /// Note: `distantPast` isn't used to indicate no previous valid ping as `NIODeadline` uses + /// the monotonic clock on Linux which uses an undefined starting point and in some cases isn't + /// always that distant. + private var lastValidPingTime: NIODeadline? + + init(allowWithoutCalls: Bool, minPingReceiveIntervalWithoutCalls: TimeAmount) { + self.allowWithoutCalls = allowWithoutCalls + self.minPingReceiveIntervalWithoutCalls = minPingReceiveIntervalWithoutCalls + self.maxPingStrikes = 2 + self.pingStrikes = 0 + self.lastValidPingTime = nil + } + + /// Reset ping strikes and the time of the last valid ping. + mutating func reset() { + self.lastValidPingTime = nil + self.pingStrikes = 0 + } + + /// Returns whether the client has sent too many pings. + mutating func receivedPing(atTime time: NIODeadline, hasOpenStreams: Bool) -> Bool { + let interval: TimeAmount + + if hasOpenStreams || self.allowWithoutCalls { + interval = self.minPingReceiveIntervalWithoutCalls + } else { + // If there are no open streams and keep alive pings aren't allowed without calls then + // use an interval of two hours. + // + // This comes from gRFC A8: https://github.com/grpc/proposal/blob/master/A8-client-side-keepalive.md + interval = .hours(2) + } + + // If there's no last ping time then the first is acceptable. + let isAcceptablePing = self.lastValidPingTime.map { $0 + interval <= time } ?? true + let tooManyPings: Bool + + if isAcceptablePing { + self.lastValidPingTime = time + tooManyPings = false + } else { + self.pingStrikes += 1 + tooManyPings = self.pingStrikes > self.maxPingStrikes + } + + return tooManyPings + } + } +} diff --git a/Sources/HummingbirdHTTP2/HTTP2ServerConnectionManager.swift b/Sources/HummingbirdHTTP2/HTTP2ServerConnectionManager.swift new file mode 100644 index 000000000..a4f63bfff --- /dev/null +++ b/Sources/HummingbirdHTTP2/HTTP2ServerConnectionManager.swift @@ -0,0 +1,321 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the Hummingbird server framework project +// +// Copyright (c) 2024 the Hummingbird authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See hummingbird/CONTRIBUTORS.txt for the list of Hummingbird authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +import NIOCore +import NIOHTTP2 + +/// HTTP2 server connection manager +/// +/// This is heavily based off the ServerConnectionManagementHandler from https://github.com/grpc/grpc-swift-nio-transport +final class HTTP2ServerConnectionManager: ChannelDuplexHandler { + package typealias InboundIn = HTTP2Frame + package typealias InboundOut = HTTP2Frame + package typealias OutboundIn = HTTP2Frame + package typealias OutboundOut = HTTP2Frame + + /// HTTP2ServerConnectionManager state + var state: StateMachine + /// Idle timer + var idleTimer: Timer? + /// Maximum time a connection be open timer + var maxAgeTimer: Timer? + /// Maximum amount of time we wait before closing the connection + var gracefulCloseTimer: Timer? + /// EventLoop connection manager running on + var eventLoop: EventLoop + /// Channel handler context + var channelHandlerContext: ChannelHandlerContext? + /// Are we reading + var inReadLoop: Bool + /// flush pending when read completes + var flushPending: Bool + + init( + eventLoop: EventLoop, + idleTimeout: Duration?, + maxAgeTimeout: Duration?, + gracefulCloseTimeout: Duration? + ) { + self.eventLoop = eventLoop + self.state = .init() + self.inReadLoop = false + self.flushPending = false + self.idleTimer = idleTimeout.map { Timer(delay: .init($0)) } + self.maxAgeTimer = maxAgeTimeout.map { Timer(delay: .init($0)) } + self.gracefulCloseTimer = gracefulCloseTimeout.map { Timer(delay: .init($0)) } + } + + func handlerAdded(context: ChannelHandlerContext) { + self.channelHandlerContext = context + let loopBoundHandler = LoopBoundHandler(self) + self.idleTimer?.schedule(on: self.eventLoop) { + loopBoundHandler.triggerGracefulShutdown() + } + self.maxAgeTimer?.schedule(on: self.eventLoop) { + loopBoundHandler.triggerGracefulShutdown() + } + } + + func handlerRemoved(context: ChannelHandlerContext) { + self.idleTimer?.cancel() + self.gracefulCloseTimer?.cancel() + self.channelHandlerContext = nil + } + + func channelActive(context: ChannelHandlerContext) { + context.fireChannelActive() + } + + func channelInactive(context: ChannelHandlerContext) { + context.fireChannelInactive() + } + + func channelRead(context: ChannelHandlerContext, data: NIOAny) { + self.inReadLoop = true + + let frame = self.unwrapInboundIn(data) + switch frame.payload { + case .ping(let data, let ack): + if ack { + self.handlePingAck(context: context, data: data) + } else { + self.handlePing(context: context, data: data) + } + + default: + break // Only interested in PING frames, ignore the rest. + } + + context.fireChannelRead(data) + } + + func channelReadComplete(context: ChannelHandlerContext) { + self.inReadLoop = false + if self.flushPending { + context.flush() + self.flushPending = false + } + context.fireChannelReadComplete() + } + + func userInboundEventTriggered(context: ChannelHandlerContext, event: Any) { + switch event { + case is ChannelShouldQuiesceEvent: + self.triggerGracefulShutdown(context: context) + case let channelEvent as ChannelEvent where channelEvent == .inputClosed: + self.handleInputClosed(context: context) + default: + break + } + context.fireUserInboundEventTriggered(event) + } + + func errorCaught(context: ChannelHandlerContext, error: any Error) { + context.close(mode: .all, promise: nil) + } + + func optionallyFlush(context: ChannelHandlerContext) { + if self.inReadLoop { + self.flushPending = true + } else { + context.flush() + } + } + + func handlePing(context: ChannelHandlerContext, data: HTTP2PingData) { + switch self.state.receivedPing(atTime: .now(), data: data) { + case .sendPingAck: + break // ping acks are sent by NIOHTTP2 channel handler + + case .enhanceYourCalmAndClose(let lastStreamId): + let goAway = HTTP2Frame( + streamID: .rootStream, + payload: .goAway( + lastStreamID: lastStreamId, + errorCode: .enhanceYourCalm, + opaqueData: context.channel.allocator.buffer(string: "too_many_pings") + ) + ) + + context.write(self.wrapOutboundOut(goAway), promise: nil) + self.optionallyFlush(context: context) + context.close(promise: nil) + + case .none: + break + } + } + + func handlePingAck(context: ChannelHandlerContext, data: HTTP2PingData) { + switch self.state.receivedPingAck(data: data) { + case .sendGoAway(let lastStreamId, let close): + let goAway = HTTP2Frame( + streamID: .rootStream, + payload: .goAway( + lastStreamID: lastStreamId, + errorCode: .noError, + opaqueData: nil + ) + ) + context.write(self.wrapOutboundOut(goAway), promise: nil) + self.optionallyFlush(context: context) + + if close { + context.close(promise: nil) + } else { + // Setup grace period for closing. Close the connection abruptly once the grace period passes. + let loopBound = NIOLoopBound(context, eventLoop: context.eventLoop) + self.gracefulCloseTimer?.schedule(on: context.eventLoop) { + loopBound.value.close(promise: nil) + } + } + case .none: + break + } + } + + func triggerGracefulShutdown(context: ChannelHandlerContext) { + switch self.state.triggerGracefulShutdown() { + case .sendGoAway(let pingData): + let goAway = HTTP2Frame( + streamID: .rootStream, + payload: .goAway( + lastStreamID: .maxID, + errorCode: .noError, + opaqueData: nil + ) + ) + let ping = HTTP2Frame(streamID: .rootStream, payload: .ping(pingData, ack: false)) + context.write(self.wrapOutboundOut(goAway), promise: nil) + context.write(self.wrapOutboundOut(ping), promise: nil) + self.optionallyFlush(context: context) + + case .none: + break + } + } + + func handleInputClosed(context: ChannelHandlerContext) { + switch self.state.inputClosed() { + case .closeWithGoAway(let lastStreamId): + let goAway = HTTP2Frame( + streamID: .rootStream, + payload: .goAway( + lastStreamID: lastStreamId, + errorCode: .connectError, + opaqueData: context.channel.allocator.buffer(string: "input_closed") + ) + ) + + context.write(self.wrapOutboundOut(goAway), promise: nil) + self.optionallyFlush(context: context) + context.close(promise: nil) + + case .close: + context.close(promise: nil) + + case .none: + break + } + } +} + +extension HTTP2ServerConnectionManager { + struct LoopBoundHandler: @unchecked Sendable { + let handler: HTTP2ServerConnectionManager + init(_ handler: HTTP2ServerConnectionManager) { + self.handler = handler + } + + func triggerGracefulShutdown() { + self.handler.eventLoop.preconditionInEventLoop() + guard let context = self.handler.channelHandlerContext else { return } + self.handler.triggerGracefulShutdown(context: context) + } + } +} + +extension HTTP2ServerConnectionManager { + /// Stream delegate + struct HTTP2StreamDelegate: NIOHTTP2StreamDelegate, @unchecked Sendable { + let handler: HTTP2ServerConnectionManager + + /// A new HTTP/2 stream was created with the given ID. + func streamCreated(_ id: HTTP2StreamID, channel: Channel) { + if self.handler.eventLoop.inEventLoop { + self.handler._streamCreated(id, channel: channel) + } else { + self.handler.eventLoop.execute { + self.handler._streamCreated(id, channel: channel) + } + } + } + + /// An HTTP/2 stream with the given ID was closed. + func streamClosed(_ id: HTTP2StreamID, channel: Channel) { + if self.handler.eventLoop.inEventLoop { + self.handler._streamClosed(id, channel: channel) + } else { + self.handler.eventLoop.execute { + self.handler._streamClosed(id, channel: channel) + } + } + } + } + + var streamDelegate: HTTP2StreamDelegate { + .init(handler: self) + } + + /// A new HTTP/2 stream was created with the given ID. + func _streamCreated(_ id: HTTP2StreamID, channel: Channel) { + self.state.streamOpened(id) + self.idleTimer?.cancel() + } + + /// An HTTP/2 stream with the given ID was closed. + func _streamClosed(_ id: HTTP2StreamID, channel: Channel) { + switch self.state.streamClosed(id) { + case .startIdleTimer: + let loopBoundHandler = LoopBoundHandler(self) + self.idleTimer?.schedule(on: self.eventLoop) { + loopBoundHandler.triggerGracefulShutdown() + } + case .close: + LoopBoundHandler(self).triggerGracefulShutdown() + case .none: + break + } + } +} + +struct Timer { + var scheduled: Scheduled? + let delay: TimeAmount + + init(delay: TimeAmount) { + self.delay = delay + self.scheduled = nil + } + + mutating func schedule(on eventLoop: EventLoop, _ task: @escaping @Sendable () throws -> Void) { + self.cancel() + self.scheduled = eventLoop.scheduleTask(in: self.delay, task) + } + + mutating func cancel() { + self.scheduled?.cancel() + self.scheduled = nil + } +} diff --git a/Sources/HummingbirdHTTP2/HTTP2StreamChannel.swift b/Sources/HummingbirdHTTP2/HTTP2StreamChannel.swift new file mode 100644 index 000000000..af385b9b0 --- /dev/null +++ b/Sources/HummingbirdHTTP2/HTTP2StreamChannel.swift @@ -0,0 +1,90 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the Hummingbird server framework project +// +// Copyright (c) 2024 the Hummingbird authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See hummingbird/CONTRIBUTORS.txt for the list of Hummingbird authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +import HTTPTypes +import HummingbirdCore +import Logging +import NIOCore +import NIOHTTPTypes +import NIOHTTPTypesHTTP2 + +/// HTTP2 Child channel for processing an HTTP2 stream +struct HTTP2StreamChannel: ServerChildChannel { + typealias Value = NIOAsyncChannel + typealias Configuration = HTTP1Channel.Configuration + + /// Initialize HTTP2StreamChannel + /// - Parameters: + /// - responder: Function returning a HTTP response for a HTTP request + /// - configuration: HTTP2 stream channel configuration + init( + responder: @escaping HTTPChannelHandler.Responder, + configuration: Configuration = .init() + ) { + self.configuration = configuration + self.responder = responder + } + + /// Setup child channel for HTTP2 stream + /// - Parameters: + /// - channel: Child channel + /// - logger: Logger used during setup + /// - Returns: Object to process input/output on child channel + func setup(channel: Channel, logger: Logger) -> EventLoopFuture { + channel.eventLoop.makeCompletedFuture { + try channel.pipeline.syncOperations.addHandler(HTTP2FramePayloadToHTTPServerCodec()) + try channel.pipeline.syncOperations.addHandlers(self.configuration.additionalChannelHandlers()) + if let idleTimeout = self.configuration.idleTimeout { + try channel.pipeline.syncOperations.addHandler(IdleStateHandler(readTimeout: idleTimeout)) + } + try channel.pipeline.syncOperations.addHandler(HTTPUserEventHandler(logger: logger)) + return try HTTP1Channel.Value(wrappingChannelSynchronously: channel) + } + } + + /// handle single HTTP request/response + /// - Parameters: + /// - asyncChannel: NIOAsyncChannel handling HTTP parts + /// - logger: Logger to use while processing messages + func handle( + value asyncChannel: NIOCore.NIOAsyncChannel, + logger: Logging.Logger + ) async { + do { + try await withTaskCancellationHandler { + try await asyncChannel.executeThenClose { inbound, outbound in + var iterator = inbound.makeAsyncIterator() + + // read first part, verify it is a head + guard let part = try await iterator.next() else { return } + guard case .head(let head) = part else { + throw HTTPChannelError.unexpectedHTTPPart(part) + } + let bodyStream = NIOAsyncChannelRequestBody(iterator: iterator) + let request = Request(head: head, body: .init(asyncSequence: bodyStream)) + let responseWriter = ResponseWriter(outbound: outbound) + try await self.responder(request, responseWriter, asyncChannel.channel) + } + } onCancel: { + asyncChannel.channel.close(mode: .input, promise: nil) + } + } catch { + // we got here because we failed to either read or write to the channel + logger.trace("Failed to read/write to Channel. Error: \(error)") + } + } + + let responder: HTTPChannelHandler.Responder + let configuration: Configuration +} diff --git a/Sources/HummingbirdHTTP2/HTTPServerBuilder+http2.swift b/Sources/HummingbirdHTTP2/HTTPServerBuilder+http2.swift index df69a1578..b1c174a62 100644 --- a/Sources/HummingbirdHTTP2/HTTPServerBuilder+http2.swift +++ b/Sources/HummingbirdHTTP2/HTTPServerBuilder+http2.swift @@ -28,7 +28,8 @@ extension HTTPServerBuilder { /// ``` /// - Parameters: /// - tlsConfiguration: TLS configuration - /// - additionalChannelHandlers: Additional channel handlers to call before handling HTTP + /// - additionalChannelHandlers: Additional channel handlers to add to stream channel pipeline after HTTP part decoding and + /// before HTTP request handling /// - Returns: HTTPChannelHandler builder @available(*, deprecated, renamed: "http2Upgrade(tlsConfiguration:configuration:)") public static func http2Upgrade( @@ -54,6 +55,7 @@ extension HTTPServerBuilder { /// ) /// ``` /// - Parameters: + /// - tlsConfiguration: TLS configuration /// - configuration: HTTP2 Upgrade channel configuration /// - Returns: HTTPChannelHandler builder public static func http2Upgrade( diff --git a/Tests/HummingbirdCoreTests/HTTP2Tests.swift b/Tests/HummingbirdCoreTests/HTTP2Tests.swift deleted file mode 100644 index ff9e30a5c..000000000 --- a/Tests/HummingbirdCoreTests/HTTP2Tests.swift +++ /dev/null @@ -1,63 +0,0 @@ -//===----------------------------------------------------------------------===// -// -// This source file is part of the Hummingbird server framework project -// -// Copyright (c) 2023 the Hummingbird authors -// Licensed under Apache License v2.0 -// -// See LICENSE.txt for license information -// See hummingbird/CONTRIBUTORS.txt for the list of Hummingbird authors -// -// SPDX-License-Identifier: Apache-2.0 -// -//===----------------------------------------------------------------------===// - -import AsyncHTTPClient -import HummingbirdCore -import HummingbirdHTTP2 -import HummingbirdTesting -import Logging -import NIOCore -import NIOHTTP1 -import NIOPosix -import NIOSSL -import NIOTransportServices -import XCTest - -final class HummingBirdHTTP2Tests: XCTestCase { - func testConnect() async throws { - let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 2) - defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) } - var logger = Logger(label: "Hummingbird") - logger.logLevel = .trace - try await testServer( - responder: { (_, responseWriter: consuming ResponseWriter, _) in - try await responseWriter.writeResponse(.init(status: .ok)) - }, - httpChannelSetup: .http2Upgrade(tlsConfiguration: getServerTLSConfiguration()), - configuration: .init(address: .hostname(port: 0), serverName: testServerName), - eventLoopGroup: eventLoopGroup, - logger: logger - ) { port in - var tlsConfiguration = try getClientTLSConfiguration() - // no way to override the SSL server name with AsyncHTTPClient so need to set - // hostname verification off - tlsConfiguration.certificateVerification = .noHostnameVerification - let httpClient = HTTPClient( - eventLoopGroupProvider: .shared(eventLoopGroup), - configuration: .init(tlsConfiguration: tlsConfiguration) - ) - - let response: HTTPClientResponse - do { - let request = HTTPClientRequest(url: "https://localhost:\(port)/") - response = try await httpClient.execute(request, deadline: .now() + .seconds(30)) - } catch { - try? await httpClient.shutdown() - throw error - } - try await httpClient.shutdown() - XCTAssertEqual(response.status, .ok) - } - } -} diff --git a/Tests/HummingbirdHTTP2Tests/Certificates.swift b/Tests/HummingbirdHTTP2Tests/Certificates.swift new file mode 100644 index 000000000..72c022d6c --- /dev/null +++ b/Tests/HummingbirdHTTP2Tests/Certificates.swift @@ -0,0 +1,182 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the Hummingbird open source project +// +// Copyright (c) YEARS the Hummingbird authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of Hummingbird authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +import NIOSSL + +let testServerName = "hummingbird.codes" + +let caCertificateData = + """ + -----BEGIN CERTIFICATE----- + MIIDyTCCArGgAwIBAgIUMlPXRgNMa+eUbn/hsCK88Zm1FI8wDQYJKoZIhvcNAQEL + BQAwdDELMAkGA1UEBhMCVUsxEjAQBgNVBAgMCUVkaW5idXJnaDESMBAGA1UEBwwJ + RWRpbmJ1cmdoMRQwEgYDVQQKDAtIdW1taW5nYmlyZDELMAkGA1UECwwCQ0ExGjAY + BgNVBAMMEWh1bW1pbmdiaXJkLmNvZGVzMB4XDTI0MDEyNzE1NDc0MloXDTI1MDEy + NjE1NDc0MlowdDELMAkGA1UEBhMCVUsxEjAQBgNVBAgMCUVkaW5idXJnaDESMBAG + A1UEBwwJRWRpbmJ1cmdoMRQwEgYDVQQKDAtIdW1taW5nYmlyZDELMAkGA1UECwwC + Q0ExGjAYBgNVBAMMEWh1bW1pbmdiaXJkLmNvZGVzMIIBIjANBgkqhkiG9w0BAQEF + AAOCAQ8AMIIBCgKCAQEAoSMlfkfyINkI63a0q5KpMjtulVb9/MESJtaiZeG0HNMj + pVGJ5c9p/Ypzp7qodgoX/6vEQahLqdfyw0dB9MzA5hOuKrLDTXhnBFiyOBrrzYLH + CBYwhJiGVPaG8HUof/UfZwYmK7NpK+g3oSyl7PKbiWTQTq+Z3uOmV7FGD1XSTSks + cU2ARsJROxWz2sTFGwqc7I4Qa8XuIIhRhLVJinagKnGnv6dyTNwFO6fl4oU0Ils9 + V19jIrBZ6cDRLTPsqMuIxjqk6YQNZ+W7CmrgT6MEceigidyBRJi7Q5iz7FniXurz + +T3lMXBaZFVFv1E3P5j4FTfBVt9n7yo07fp/QoVd3QIDAQABo1MwUTAdBgNVHQ4E + FgQU19wgSafcyM6xz0CJt+0IePdAj24wHwYDVR0jBBgwFoAU19wgSafcyM6xz0CJ + t+0IePdAj24wDwYDVR0TAQH/BAUwAwEB/zANBgkqhkiG9w0BAQsFAAOCAQEALb9o + kmr6qHQFvgaM7Cv4ETZ67rsZ7PlG3uTH1m3zqjJhMDYaDGHcUXOioSkfwON2+tYK + crs6IjuE9XqiZBoszBqUeSze67/095xysUTC0JyljE259PSb2Woal3g/zOh1d3Dm + SKFNZDmkp/coRkz9UlJNbafwmYFzaMl0nVkIf3LKFj8gBd1qW+H+2uSQAZFIWCDY + MJoLF8vhJR0W5/vO3axmASFyAwSiP0NlIVC3HE0rNziE2CMBs5aXkUcikZKvC+q+ + TRLXM40Ead4Ne1aJb4aABzscrzApfa1ZRfF9CuVawqp1pYn/XJS/WCSlMyiJe8ms + oJasPVFZ9xo0TPg5uw== + -----END CERTIFICATE----- + """ + +let serverCertificateData = + """ + -----BEGIN CERTIFICATE----- + MIIECDCCAvCgAwIBAgIUZ3cPKQJZL0/i8e3twD3UNRQnJfUwDQYJKoZIhvcNAQEL + BQAwdDELMAkGA1UEBhMCVUsxEjAQBgNVBAgMCUVkaW5idXJnaDESMBAGA1UEBwwJ + RWRpbmJ1cmdoMRQwEgYDVQQKDAtIdW1taW5nYmlyZDELMAkGA1UECwwCQ0ExGjAY + BgNVBAMMEWh1bW1pbmdiaXJkLmNvZGVzMB4XDTI0MDEyNzE1NDc0MloXDTI1MDEy + NjE1NDc0MloweDELMAkGA1UEBhMCVUsxEjAQBgNVBAgMCUVkaW5idXJnaDESMBAG + A1UEBwwJRWRpbmJ1cmdoMRQwEgYDVQQKDAtIdW1taW5nYmlyZDEPMA0GA1UECwwG + U2VydmVyMRowGAYDVQQDDBFodW1taW5nYmlyZC5jb2RlczCCASIwDQYJKoZIhvcN + AQEBBQADggEPADCCAQoCggEBAKX+mOG6fZko5yv3OrOrHBuBWE+dchwezM5hi7xp + Zyja/dDhhO6IZBkJtmR9Uw11+ZAxWao2yVIkpT+0jehhDzGRFn88+CrKPR2/r5eh + Bmv4dUQNxnJPjvMzx9QgcjSJf6uxTFNngJID0BmA5UeJ2Xi+/WsX8zELm+CD7e7V + 1gfcCTLY5Y12dfHd0J1ZbTxp7k3XpadXLdhZq0lLjYIwdLbmZxtOgqXirwCRU4SR + bvJLEwnMcnJvEIg9Q4zXf4aWM45BAUYz9rMr8WlLKl31j6fQP/TcZ4QNVrVVpKfC + Ok0w8b9BEebvMNhStgndJ4sn5oBpZEA40kbCcdr0d8rM6wsCAwEAAaOBjTCBijAL + BgNVHQ8EBAMCBaAwHQYDVR0lBBYwFAYIKwYBBQUHAwEGCCsGAQUFBwMCMBwGA1Ud + EQQVMBOCEWh1bW1pbmdiaXJkLmNvZGVzMB0GA1UdDgQWBBSyCStEARj5dJucaIGj + WlJYkEOeBTAfBgNVHSMEGDAWgBTX3CBJp9zIzrHPQIm37Qh490CPbjANBgkqhkiG + 9w0BAQsFAAOCAQEASb4IHtnr1GcbgpyX/6rjoeZ1s56O1mG3bv4c91dV4ca0nr7r + UxbgUkqBSf88fpgd82Dr/AcU4XmD/W1b5J8P/+RZiIH4+ztuN1MWiWiRduEbN3Vo + 2hfTcCQFTcvO36nkqy/vFUgKwAUS7/Qm5pNoThf7paWSvOdcPg3zZhjU2qzIb9KR + SlXZ0YooUc7uQ6lFmgmgZEZ2bKykKue2TfXRPLI86yXv2dVzShMvv+njCbxkWBEh + Mra6nBnNkdj9PoB2eKZV3VvWgGrSVher8JVDW7bN1dJ94ppugO6Pnwy06fbLo7+h + ijBsqIWiDQNOQQrPx1iCTbdtg5UOKNIFwWynVQ== + -----END CERTIFICATE----- + """ + +let serverPrivateKeyData = + """ + -----BEGIN PRIVATE KEY----- + MIIEvgIBADANBgkqhkiG9w0BAQEFAASCBKgwggSkAgEAAoIBAQCl/pjhun2ZKOcr + 9zqzqxwbgVhPnXIcHszOYYu8aWco2v3Q4YTuiGQZCbZkfVMNdfmQMVmqNslSJKU/ + tI3oYQ8xkRZ/PPgqyj0dv6+XoQZr+HVEDcZyT47zM8fUIHI0iX+rsUxTZ4CSA9AZ + gOVHidl4vv1rF/MxC5vgg+3u1dYH3Aky2OWNdnXx3dCdWW08ae5N16WnVy3YWatJ + S42CMHS25mcbToKl4q8AkVOEkW7ySxMJzHJybxCIPUOM13+GljOOQQFGM/azK/Fp + Sypd9Y+n0D/03GeEDVa1VaSnwjpNMPG/QRHm7zDYUrYJ3SeLJ+aAaWRAONJGwnHa + 9HfKzOsLAgMBAAECggEAA/RGyP2Mw6+cG5Ev4yZhg2GEd9tnmHoCaVKGQHUZ2w+H + uBeFUfdz6RiLT4NkSa4JTRAna9R8KKapLcokO/d7dtIa6p5wBGIbhtXkP8Eklciu + GusHNJRBEOqJNnDhTxy6Odo59g7nBDkUcOy8LyMkGtdOHeRjdQffpYqztx5ukwlR + OP3Br47APSlhB/LoJxfg6tk59tqidx+cwxZeu3HaW4s46/mHWw5FqW74PAt53XN9 + aSQdWscXhuutRtPbM+GY9d6Lf1pp9/Cq+XCAV3iRgjfFQPgD2INc0ySTY10uGOWc + YQpa4UBX1ZPRAPa0YBwCwF4Xz/4BBfDzfiCrosungQKBgQDXuMFEcuQFsjLYm3Vz + W9N4UZg3HeuABH29Dt09yyFlg6hZD3FFFBBWDPipT6aFdmewYDTyTZ5RPBfBLH4r + MTIYuIje/jF9bxGkKzK0Pos5nsBeJDxLf3CcJvALEc2u7l0spS5y2MJ9RU6PMaF9 + 0DoABf7yNmYbyHtyRB5a52fBEQKBgQDE/PD0TinVgWsjxkogv/rQcLn0aFVufoT7 + j01VPloATzntYdmvVT/EiAnTf4IJx8p0TqvFnvlYQj7Tffzv1B8iXmx6wAmWeS94 + ElnWi/5pkIy1XI1DUy6MCdOcOCSwL36tcKEzFjYrVJtwGiR8gxz3PQuxquFmoofQ + jYy1V1aqWwKBgBKZ0LhpO7YuBmpdBUScL2DZkEl4X/0a5giuRm90m32YW6TKSxcM + wtfYqHxY7N/nNMulkAswnC0fBGFYx8xLoqk1CEBKJNRPBnNkcivOlMy0Hpw/fZ94 + 7qnYRax+rYCe9xPJbnbir+qDVmHMgsNJeCbWXYRfInDU2aghrYhjGbQxAoGBAMRG + k3+Zci1+alaW+L1xDGQsLdzNKHKUNcTBoHhTTDIKvtk8Kj59XrBgLApEfjlojN0e + liCuqhu6xgbM/f2pCeyg0M3uEp+P2DB3eHRBwRlGIi2DLm3qr/JwyBxcBJJYgIwo + MTZJ52d9QfOM2NYHfhELDl/UuAof39t5br4xa/UJAoGBAKUYB48+AK9oFPj6YG8U + N7XHfElJPggmygAUtndYwJeNRGbFew+P4fhrRG+2nhIVjVOBSwAvTnGPexwbkWCz + NOAf/sV8p4YEBa+KwkN/dSXUD+LE+s3ARmLzFaEkYyl03U9bJ9RfCz2wLCpTmjq0 + 50ies8PMrKKJxjkhVycT3pFJ + -----END PRIVATE KEY----- + """ + +let clientCertificateData = + """ + -----BEGIN CERTIFICATE----- + MIIDvDCCAqSgAwIBAgIUZ3cPKQJZL0/i8e3twD3UNRQnJfYwDQYJKoZIhvcNAQEL + BQAwdDELMAkGA1UEBhMCVUsxEjAQBgNVBAgMCUVkaW5idXJnaDESMBAGA1UEBwwJ + RWRpbmJ1cmdoMRQwEgYDVQQKDAtIdW1taW5nYmlyZDELMAkGA1UECwwCQ0ExGjAY + BgNVBAMMEWh1bW1pbmdiaXJkLmNvZGVzMB4XDTI0MDEyNzE1NDc0MloXDTI1MDEy + NjE1NDc0MloweDELMAkGA1UEBhMCVUsxEjAQBgNVBAgMCUVkaW5idXJnaDESMBAG + A1UEBwwJRWRpbmJ1cmdoMRQwEgYDVQQKDAtIdW1taW5nYmlyZDEPMA0GA1UECwwG + Q2xpZW50MRowGAYDVQQDDBFodW1taW5nYmlyZC5jb2RlczCCASIwDQYJKoZIhvcN + AQEBBQADggEPADCCAQoCggEBAKWNbU5Xk/FBhHdVu1CPuQJGwxqTOggJq/7tp5Wu + HR9aMpgb/zWuEaT/eL5tZJKYX8Y2MY8/AOkoVE0fjB8sK8nwG4CgGrrxBV7MsSQJ + 43PqQE4WXxC2bZbn5dLIr6ABZ4nTvuQvq8Pv/ylp/7Pek6aFEM8APIac0lAFcJzn + OArC2x7jUap53cgHP64xiO+ZF2tT88CGVNEBYCWAZ6x1Eaz0PbKm/wWc5pIGbgW+ + i4lP69bkfzXczLjN3xce61Jyx9Kj6DeUqIPR2YQwYHORnEpwDCrlhL1o6NGDzM/j + 2/t9IzMnjIeoNGOZtrbx1QhjH6Hu4waRhkck30my+ukYLpsCAwEAAaNCMEAwHQYD + VR0OBBYEFL8Uh8IaSnv66cS3mHy4rE1RHdm9MB8GA1UdIwQYMBaAFNfcIEmn3MjO + sc9AibftCHj3QI9uMA0GCSqGSIb3DQEBCwUAA4IBAQAGG8Fv4eTFT8UaNZkuhnMA + BT2+he8O0xlvFXse+QpL451ISU1KjSbh/N2jDfpob3/nO1EKYEuG5XKHmhlTjrzb + sa0YW5ad31jPgCExm69WRVfJOlnVL1olbzmyibGbQ8lFax0QgYO9rLhvkJocQs2D + tJX0xNL/2BccaVQvj7i8qAHeiQ9NqO46g4Uob5jE2nswJLZh9REddNsFWKxxL8jK + k+Ez6oW1s6QUaOoOm3Dh94fuYD34hgDeDIu+ec7bOiIIwKAholKQuoHqphMbZvQ5 + QWv2gB3vE9Ep1VKrVr9dT4NST80Bmw7E1piUuqBShsohLc0GEkSrWboGP8vWbu+F + -----END CERTIFICATE----- + """ + +let clientPrivateKeyData = + """ + -----BEGIN PRIVATE KEY----- + MIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQCljW1OV5PxQYR3 + VbtQj7kCRsMakzoICav+7aeVrh0fWjKYG/81rhGk/3i+bWSSmF/GNjGPPwDpKFRN + H4wfLCvJ8BuAoBq68QVezLEkCeNz6kBOFl8Qtm2W5+XSyK+gAWeJ077kL6vD7/8p + af+z3pOmhRDPADyGnNJQBXCc5zgKwtse41Gqed3IBz+uMYjvmRdrU/PAhlTRAWAl + gGesdRGs9D2ypv8FnOaSBm4FvouJT+vW5H813My4zd8XHutScsfSo+g3lKiD0dmE + MGBzkZxKcAwq5YS9aOjRg8zP49v7fSMzJ4yHqDRjmba28dUIYx+h7uMGkYZHJN9J + svrpGC6bAgMBAAECggEACW02pdXXYjVK78KaPyzLEF9rUszCt3XZhqANdIWGbTEY + UJ0tGIDk/bV547OPg2HMXkx0R1+DU6nMtw5OgiRK1mpNUcy/PBkz2mFWATyg6D17 + IOPynZ1NoZPQ/DVNYfm1snbnCs/RSRkvn2UrC380GBcoM4+kL3DbI6kgb7nvJBZu + p5ftCeUjSOJWi5ImmaPFvBsF24bxCAuwk0Gw8q9ybqpJHLm8ybkXpiF3SvXlnKGt + RLxKhAVSOKbyrWQNUv9RDx2xAfibpqUAo3gZVyxkDY2Gkb1J5YT27bhFricJNaVz + FFxhC3O+X3ihMBBNnq1VwwjoeSzWmwS1BFPgVFzLbQKBgQDWTnT2eFrzV7ADXCMT + 1bg4hoFJD/QUJXqozAvCIAQx6xSamoadUHKMzYpGvI/5YEYbn3sgoRiQDrZd5jZf + zWRJuyQxdq1bakBsx4vji3TJ1eN0ovzTQHAB1Z+5tAw6CN9htUTyW/NbvzS4/eMd + 9we68ye2gHrgFVtfVC6emIriNQKBgQDFwsYb5xRKce5F/iL4o11LfA0Dyu9Vekvg + FPBXdE6pSzZimeyC3Y8u144eWiXTfo7DT8nY1b5JTXmhUH9Q84lTkELq666rTn9N + KV3LIMEweHX//GBi+unZC5K6H8dnc9YzBsL4P6SO/ZqHzef02EGadGEyg4rlSkpp + yqJ+SI7njwKBgBer7OF4o9szOV71o25CcinUOZ2fZH+BME5K05Wqwavd4pW9MddY + ln6VCYwMsf6CstvEPu54vOTUqzIuBp2Ia2Z1hGbuS/HIB7u8QuhsdAcDWC9+/Vw8 + RuL8/Lqfd6ZFap85TZdTrsrYkPNKH/ckXTc6Oo2/HVN5KHGcM9YS1WxtAoGAKMnE + bIrbn4MiHuOMuPWQz3nVgVvAw0OHFL+c1pzRgI9XtzyCEHe8CXBCCraTKKzoqxXw + zr0/EwVcuc3NhJfGUirl8mgLzZ9SGEsY4kVuMx4VUGfwRVn1E2QUrjjRut+kZT/W + xLbzrN5Xmfz5A4H6/e1VAsMoyaPp9ynpG9zBRLcCgYEAj/J9KsG6gqECwK12dcqz + brMAb7X3v05Kk22Nskhis6p31AOgg67MI8y3ANko2LADOHfov1HNaTwkCdhAaFoZ + 1mJhowXVjxJJA4QWzPYGQSrVfKrUGJf8y5vHos5NQWF2VYNVJsSP5D17MoMwagqW + kPQvfvMHrv6al2joWL+8/3U= + -----END PRIVATE KEY----- + """ + +func getServerTLSConfiguration() throws -> TLSConfiguration { + let caCertificate = try NIOSSLCertificate(bytes: [UInt8](caCertificateData.utf8), format: .pem) + let certificate = try NIOSSLCertificate(bytes: [UInt8](serverCertificateData.utf8), format: .pem) + let privateKey = try NIOSSLPrivateKey(bytes: [UInt8](serverPrivateKeyData.utf8), format: .pem) + var tlsConfig = TLSConfiguration.makeServerConfiguration(certificateChain: [.certificate(certificate)], privateKey: .privateKey(privateKey)) + tlsConfig.trustRoots = .certificates([caCertificate]) + return tlsConfig +} + +func getClientTLSConfiguration() throws -> TLSConfiguration { + let caCertificate = try NIOSSLCertificate(bytes: [UInt8](caCertificateData.utf8), format: .pem) + let certificate = try NIOSSLCertificate(bytes: [UInt8](clientCertificateData.utf8), format: .pem) + let privateKey = try NIOSSLPrivateKey(bytes: [UInt8](clientPrivateKeyData.utf8), format: .pem) + var tlsConfig = TLSConfiguration.makeClientConfiguration() + tlsConfig.trustRoots = .certificates([caCertificate]) + tlsConfig.certificateChain = [.certificate(certificate)] + tlsConfig.privateKey = .privateKey(privateKey) + return tlsConfig +} diff --git a/Tests/HummingbirdHTTP2Tests/HTTP2ServerConnectionManagerStateMachineTests.swift b/Tests/HummingbirdHTTP2Tests/HTTP2ServerConnectionManagerStateMachineTests.swift new file mode 100644 index 000000000..38e1d3f7c --- /dev/null +++ b/Tests/HummingbirdHTTP2Tests/HTTP2ServerConnectionManagerStateMachineTests.swift @@ -0,0 +1,105 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the Hummingbird server framework project +// +// Copyright (c) 2023 the Hummingbird authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See hummingbird/CONTRIBUTORS.txt for the list of Hummingbird authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +@testable import HummingbirdHTTP2 +import NIOCore +import NIOHTTP2 +import XCTest + +final class HTTP2ServerConnectionManagerStateMachineTests: XCTestCase { + func testAddRemoveClose() { + var stateMachine = HTTP2ServerConnectionManager.StateMachine() + stateMachine.streamOpened(.init(2)) + XCTAssertEqual(stateMachine.streamClosed(.init(2)), .startIdleTimer) + let triggerGracefulShutdownResult = stateMachine.triggerGracefulShutdown() + guard case .sendGoAway(let pingData) = triggerGracefulShutdownResult else { XCTFail(); return } + let pingAckResult = stateMachine.receivedPingAck(data: pingData) + guard case .sendGoAway(let lastStreamId, let close) = pingAckResult else { XCTFail(); return } + XCTAssertEqual(close, true) + XCTAssertEqual(lastStreamId, 2) + guard case .closed = stateMachine.state else { XCTFail(); return } + } + + func testAddCloseRemove() { + var stateMachine = HTTP2ServerConnectionManager.StateMachine() + stateMachine.streamOpened(.init(2)) + let triggerGracefulShutdownResult = stateMachine.triggerGracefulShutdown() + guard case .sendGoAway(let pingData) = triggerGracefulShutdownResult else { XCTFail(); return } + let pingAckResult = stateMachine.receivedPingAck(data: pingData) + guard case .sendGoAway(let lastStreamId, let close) = pingAckResult else { XCTFail(); return } + XCTAssertEqual(close, false) + XCTAssertEqual(lastStreamId, 2) + XCTAssertEqual(stateMachine.streamClosed(.init(2)), .close) + guard case .closed = stateMachine.state else { XCTFail(); return } + } + + func testCloseAddRemove() { + var stateMachine = HTTP2ServerConnectionManager.StateMachine() + let triggerGracefulShutdownResult = stateMachine.triggerGracefulShutdown() + guard case .sendGoAway(let pingData) = triggerGracefulShutdownResult else { XCTFail(); return } + stateMachine.streamOpened(.init(2)) + let pingAckResult = stateMachine.receivedPingAck(data: pingData) + guard case .sendGoAway(let lastStreamId, let close) = pingAckResult else { XCTFail(); return } + XCTAssertEqual(close, false) + XCTAssertEqual(lastStreamId, 2) + XCTAssertEqual(stateMachine.streamClosed(.init(2)), .close) + guard case .closed = stateMachine.state else { XCTFail(); return } + } + + func testReceivedPing() { + let now = NIODeadline.now() + let pingData = HTTP2PingData(withInteger: .random(in: .min ... .max)) + var stateMachine = HTTP2ServerConnectionManager.StateMachine() + stateMachine.streamOpened(.init(4)) + var pingResult = stateMachine.receivedPing(atTime: now, data: pingData) + guard case .sendPingAck(let data) = pingResult else { XCTFail(); return } + XCTAssertEqual(data, pingData) + pingResult = stateMachine.receivedPing(atTime: now + .seconds(1), data: pingData) + guard case .sendPingAck = pingResult else { XCTFail(); return } + pingResult = stateMachine.receivedPing(atTime: now + .seconds(1), data: pingData) + guard case .sendPingAck = pingResult else { XCTFail(); return } + pingResult = stateMachine.receivedPing(atTime: now + .seconds(2), data: pingData) + guard case .enhanceYourCalmAndClose(let id) = pingResult else { XCTFail(); return } + XCTAssertEqual(id, 4) + guard case .closed = stateMachine.state else { XCTFail(); return } + } + + func testClosedState() { + // get statemachine into closed state + var stateMachine = HTTP2ServerConnectionManager.StateMachine() + let triggerGracefulShutdownResult = stateMachine.triggerGracefulShutdown() + guard case .sendGoAway(let pingData) = triggerGracefulShutdownResult else { XCTFail(); return } + let pingAckResult = stateMachine.receivedPingAck(data: pingData) + guard case .sendGoAway(_, let close) = pingAckResult else { XCTFail(); return } + XCTAssertEqual(close, true) + + // test closed state responses + XCTAssertEqual(stateMachine.streamClosed(.init(0)), .none) + guard case .none = stateMachine.receivedPing(atTime: .now(), data: .init()) else { XCTFail(); return } + guard case .none = stateMachine.receivedPingAck(data: .init()) else { XCTFail(); return } + guard case .none = stateMachine.triggerGracefulShutdown() else { XCTFail(); return } + } + + func testClosePingAckWrongData() { + let randomPingData = HTTP2PingData(withInteger: .random(in: .min ... .max)) + // get statemachine into closed state + var stateMachine = HTTP2ServerConnectionManager.StateMachine() + let triggerGracefulShutdownResult = stateMachine.triggerGracefulShutdown() + guard case .sendGoAway(let pingData) = triggerGracefulShutdownResult else { XCTFail(); return } + var pingAckResult = stateMachine.receivedPingAck(data: randomPingData) + guard case .none = pingAckResult else { XCTFail(); return } + pingAckResult = stateMachine.receivedPingAck(data: pingData) + guard case .sendGoAway = pingAckResult else { XCTFail(); return } + } +} diff --git a/Tests/HummingbirdHTTP2Tests/HTTP2Tests.swift b/Tests/HummingbirdHTTP2Tests/HTTP2Tests.swift new file mode 100644 index 000000000..91bf96c72 --- /dev/null +++ b/Tests/HummingbirdHTTP2Tests/HTTP2Tests.swift @@ -0,0 +1,188 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the Hummingbird server framework project +// +// Copyright (c) 2023 the Hummingbird authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See hummingbird/CONTRIBUTORS.txt for the list of Hummingbird authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +import AsyncHTTPClient +import HummingbirdCore +import HummingbirdHTTP2 +import HummingbirdTesting +import Logging +import NIOCore +import NIOHTTP1 +import NIOHTTPTypes +import NIOPosix +import NIOSSL +import NIOTransportServices +import XCTest + +final class HummingBirdHTTP2Tests: XCTestCase { + func testConnect() async throws { + let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 2) + defer { + XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) + } + var logger = Logger(label: "Hummingbird") + logger.logLevel = .trace + + var tlsConfiguration = try getClientTLSConfiguration() + // no way to override the SSL server name with AsyncHTTPClient so need to set + // hostname verification off + tlsConfiguration.certificateVerification = .noHostnameVerification + try await withHTTPClient(.init(tlsConfiguration: tlsConfiguration)) { httpClient in + try await testServer( + responder: { (_, responseWriter: consuming ResponseWriter, _) in + try await responseWriter.writeResponse(.init(status: .ok)) + }, + httpChannelSetup: .http2Upgrade(tlsConfiguration: getServerTLSConfiguration()), + configuration: .init(address: .hostname(port: 0), serverName: testServerName), + eventLoopGroup: eventLoopGroup, + logger: logger + ) { port in + let request = HTTPClientRequest(url: "https://localhost:\(port)/") + let response = try await httpClient.execute(request, deadline: .now() + .seconds(30)) + XCTAssertEqual(response.status, .ok) + } + } + } + + func testMultipleSerialRequests() async throws { + let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 2) + defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) } + var logger = Logger(label: "Hummingbird") + logger.logLevel = .trace + + var tlsConfiguration = try getClientTLSConfiguration() + // no way to override the SSL server name with AsyncHTTPClient so need to set + // hostname verification off + tlsConfiguration.certificateVerification = .noHostnameVerification + try await withHTTPClient(.init(tlsConfiguration: tlsConfiguration)) { httpClient in + try await testServer( + responder: { (_, responseWriter: consuming ResponseWriter, _) in + try await responseWriter.writeResponse(.init(status: .ok)) + }, + httpChannelSetup: .http2Upgrade(tlsConfiguration: getServerTLSConfiguration()), + configuration: .init(address: .hostname(port: 0), serverName: testServerName), + eventLoopGroup: eventLoopGroup, + logger: logger + ) { port in + let request = HTTPClientRequest(url: "https://localhost:\(port)/") + for _ in 0..<16 { + let response = try await httpClient.execute(request, deadline: .now() + .seconds(30)) + _ = try await response.body.collect(upTo: .max) + XCTAssertEqual(response.status, .ok) + } + } + } + } + + func testMultipleConcurrentRequests() async throws { + let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 2) + defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) } + var logger = Logger(label: "Hummingbird") + logger.logLevel = .trace + + var tlsConfiguration = try getClientTLSConfiguration() + // no way to override the SSL server name with AsyncHTTPClient so need to set + // hostname verification off + tlsConfiguration.certificateVerification = .noHostnameVerification + try await withHTTPClient(.init(tlsConfiguration: tlsConfiguration)) { httpClient in + try await testServer( + responder: { (_, responseWriter: consuming ResponseWriter, _) in + try await responseWriter.writeResponse(.init(status: .ok)) + }, + httpChannelSetup: .http2Upgrade(tlsConfiguration: getServerTLSConfiguration()), + configuration: .init(address: .hostname(port: 0), serverName: testServerName), + eventLoopGroup: eventLoopGroup, + logger: logger + ) { port in + let request = HTTPClientRequest(url: "https://localhost:\(port)/") + try await withThrowingTaskGroup(of: Void.self) { group in + group.addTask { + for _ in 0..<16 { + let response = try await httpClient.execute(request, deadline: .now() + .seconds(30)) + _ = try await response.body.collect(upTo: .max) + XCTAssertEqual(response.status, .ok) + } + } + try await group.waitForAll() + } + } + } + } + + func testConnectionClosed() async throws { + let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 2) + defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) } + var logger = Logger(label: "Hummingbird") + logger.logLevel = .trace + + try await testServer( + responder: { (_, responseWriter: consuming ResponseWriter, _) in + try await responseWriter.writeResponse(.init(status: .ok)) + }, + httpChannelSetup: .http2Upgrade( + tlsConfiguration: getServerTLSConfiguration() + ), + configuration: .init(address: .hostname(port: 0), serverName: testServerName), + eventLoopGroup: eventLoopGroup, + logger: logger + ) { port in + var tlsConfiguration = try getClientTLSConfiguration() + // no way to override the SSL server name with AsyncHTTPClient so need to set + // hostname verification off + tlsConfiguration.certificateVerification = .noHostnameVerification + try await withHTTPClient(.init(tlsConfiguration: tlsConfiguration)) { httpClient in + let request = HTTPClientRequest(url: "https://localhost:\(port)/") + let response = try await httpClient.execute(request, deadline: .now() + .seconds(30)) + XCTAssertEqual(response.status, .ok) + } + } + } + + func testHTTP1Connect() async throws { + let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 2) + defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) } + var logger = Logger(label: "Hummingbird") + logger.logLevel = .trace + try await testServer( + responder: { (_, responseWriter: consuming ResponseWriter, _) in + try await responseWriter.writeResponse(.init(status: .ok)) + }, + httpChannelSetup: .http2Upgrade(tlsConfiguration: getServerTLSConfiguration()), + configuration: .init(address: .hostname(port: 0), serverName: testServerName), + eventLoopGroup: eventLoopGroup, + logger: logger + ) { port in + var tlsConfiguration = try getClientTLSConfiguration() + // no way to override the SSL server name with AsyncHTTPClient so need to set + // hostname verification off + tlsConfiguration.certificateVerification = .noHostnameVerification + let client = TestClient( + host: "localhost", + port: port, + configuration: .init(tlsConfiguration: tlsConfiguration), + eventLoopGroupProvider: .shared(eventLoopGroup) + ) + client.connect() + let response: TestClient.Response + do { + response = try await client.get("/") + } catch { + try? await client.shutdown() + throw error + } + try await client.shutdown() + XCTAssertEqual(response.status, .ok) + } + } +} diff --git a/Tests/HummingbirdHTTP2Tests/TestUtils.swift b/Tests/HummingbirdHTTP2Tests/TestUtils.swift new file mode 100644 index 000000000..0a269e6a2 --- /dev/null +++ b/Tests/HummingbirdHTTP2Tests/TestUtils.swift @@ -0,0 +1,148 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the Hummingbird server framework project +// +// Copyright (c) 2023 the Hummingbird authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See hummingbird/CONTRIBUTORS.txt for the list of Hummingbird authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +import AsyncHTTPClient +import HummingbirdCore +import HummingbirdTesting +import Logging +import NIOCore +import NIOPosix +import NIOSSL +import ServiceLifecycle +import XCTest + +public enum TestErrors: Error { + case timeout +} + +/// Basic responder that just returns "Hello" in body +@Sendable public func helloResponder(to request: Request, responseWriter: consuming ResponseWriter, channel: Channel) async throws { + let responseBody = channel.allocator.buffer(string: "Hello") + var bodyWriter = try await responseWriter.writeHead(.init(status: .ok)) + try await bodyWriter.write(responseBody) + try await bodyWriter.finish(nil) +} + +/// Helper function for testing a server +public func testServer( + responder: @escaping HTTPChannelHandler.Responder, + httpChannelSetup: HTTPServerBuilder, + configuration: ServerConfiguration, + eventLoopGroup: EventLoopGroup, + logger: Logger, + _ test: @escaping @Sendable (Int) async throws -> Value +) async throws -> Value { + try await withThrowingTaskGroup(of: Void.self) { group in + let promise = Promise() + let server = try httpChannelSetup.buildServer( + configuration: configuration, + eventLoopGroup: eventLoopGroup, + logger: logger, + responder: responder, + onServerRunning: { await promise.complete($0.localAddress!.port!) } + ) + let serviceGroup = ServiceGroup( + configuration: .init( + services: [server], + gracefulShutdownSignals: [.sigterm, .sigint], + logger: logger + ) + ) + + group.addTask { + try await serviceGroup.run() + } + let value = try await test(promise.wait()) + await serviceGroup.triggerGracefulShutdown() + return value + } +} + +func withHTTPClient( + _ configuration: HTTPClient.Configuration, + eventLoopGroup: EventLoopGroup = MultiThreadedEventLoopGroup.singleton, + _ process: (HTTPClient) async throws -> Value +) async throws -> Value { + let httpClient = HTTPClient( + eventLoopGroupProvider: .shared(eventLoopGroup), + configuration: configuration + ) + let value: Value + do { + value = try await process(httpClient) + } catch { + try? await httpClient.shutdown() + throw error + } + try await httpClient.shutdown() + return value +} + +/// Run process with a timeout +/// - Parameters: +/// - timeout: Amount of time before timeout error is thrown +/// - process: Process to run +public func withTimeout(_ timeout: TimeAmount, _ process: @escaping @Sendable () async throws -> Void) async throws { + try await withThrowingTaskGroup(of: Void.self) { group in + group.addTask { + try await Task.sleep(nanoseconds: numericCast(timeout.nanoseconds)) + throw TestErrors.timeout + } + group.addTask { + try await process() + } + try await group.next() + group.cancelAll() + } +} + +/// Promise type. +actor Promise { + enum State { + case blocked([CheckedContinuation]) + case unblocked(Value) + } + + var state: State + + init() { + self.state = .blocked([]) + } + + /// wait from promise to be completed + func wait() async -> Value { + switch self.state { + case .blocked(var continuations): + return await withCheckedContinuation { cont in + continuations.append(cont) + self.state = .blocked(continuations) + } + case .unblocked(let value): + return value + } + } + + /// complete promise with value + func complete(_ value: Value) { + switch self.state { + case .blocked(let continuations): + for cont in continuations { + cont.resume(returning: value) + } + self.state = .unblocked(value) + case .unblocked: + break + } + } +} diff --git a/Tests/HummingbirdTests/TracingTests.swift b/Tests/HummingbirdTests/TracingTests.swift index 67562d46e..f4075f212 100644 --- a/Tests/HummingbirdTests/TracingTests.swift +++ b/Tests/HummingbirdTests/TracingTests.swift @@ -568,9 +568,7 @@ final class TracingTests: XCTestCase { XCTAssertEqual(span2.context.testID, "testMiddleware") } -} -extension TracingTests { /// Test tracing middleware serviceContext is propagated to async route handlers func testServiceContextPropagationAsync() async throws { let expectation = expectation(description: "Expected span to be ended.")