diff --git a/libp2p/protocols/connectivity/relay/rtransport.nim b/libp2p/protocols/connectivity/relay/rtransport.nim index 007d462267..3d49f039c5 100644 --- a/libp2p/protocols/connectivity/relay/rtransport.nim +++ b/libp2p/protocols/connectivity/relay/rtransport.nim @@ -29,7 +29,9 @@ type RelayTransport* = ref object of Transport queue: AsyncQueue[Connection] selfRunning: bool -method start*(self: RelayTransport, ma: seq[MultiAddress]) {.async.} = +method start*( + self: RelayTransport, ma: seq[MultiAddress] +) {.async: (raises: [LPError, transport.TransportError]).} = if self.selfRunning: trace "Relay transport already running" return @@ -43,12 +45,15 @@ method start*(self: RelayTransport, ma: seq[MultiAddress]) {.async.} = await procCall Transport(self).start(ma) trace "Starting Relay transport" -method stop*(self: RelayTransport) {.async.} = +method stop*(self: RelayTransport) {.async: (raises: []).} = self.running = false self.selfRunning = false self.client.onNewConnection = nil while not self.queue.empty(): - await self.queue.popFirstNoWait().close() + try: + await self.queue.popFirstNoWait().close() + except AsyncQueueEmptyError: + continue # checked with self.queue.empty() method accept*(self: RelayTransport): Future[Connection] {.async.} = result = await self.queue.popFirst() diff --git a/libp2p/transports/quictransport.nim b/libp2p/transports/quictransport.nim index 6a07aa6dc3..400dc9bb93 100644 --- a/libp2p/transports/quictransport.nim +++ b/libp2p/transports/quictransport.nim @@ -21,6 +21,7 @@ logScope: type P2PConnection = connection.Connection QuicConnection = quic.Connection + QuicTransportError* = object of transport.TransportError # Stream type QuicStream* = ref object of P2PConnection @@ -148,22 +149,30 @@ method handles*(transport: QuicTransport, address: MultiAddress): bool = return false QUIC_V1.match(address) -method start*(transport: QuicTransport, addrs: seq[MultiAddress]) {.async.} = - doAssert transport.listener.isNil, "start() already called" +method start*( + self: QuicTransport, addrs: seq[MultiAddress] +) {.async: (raises: [LPError, transport.TransportError]).} = + doAssert self.listener.isNil, "start() already called" #TODO handle multiple addr - transport.listener = listen(initTAddress(addrs[0]).tryGet) - await procCall Transport(transport).start(addrs) - transport.addrs[0] = - MultiAddress.init(transport.listener.localAddress(), IPPROTO_UDP).tryGet() & - MultiAddress.init("/quic-v1").get() - transport.running = true - -method stop*(transport: QuicTransport) {.async.} = + try: + self.listener = listen(initTAddress(addrs[0]).tryGet) + await procCall Transport(self).start(addrs) + self.addrs[0] = + MultiAddress.init(self.listener.localAddress(), IPPROTO_UDP).tryGet() & + MultiAddress.init("/quic-v1").get() + except TransportOsError as exc: + raise (ref QuicTransportError)(msg: exc.msg, parent: exc) + self.running = true + +method stop*(transport: QuicTransport) {.async: (raises: []).} = if transport.running: for c in transport.connections: await c.close() await procCall Transport(transport).stop() - await transport.listener.stop() + try: + await transport.listener.stop() + except CatchableError as exc: + trace "Error shutting down Quic transport", description = exc.msg transport.running = false transport.listener = nil diff --git a/libp2p/transports/tcptransport.nim b/libp2p/transports/tcptransport.nim index 91a01583dd..42fa186b2a 100644 --- a/libp2p/transports/tcptransport.nim +++ b/libp2p/transports/tcptransport.nim @@ -105,113 +105,97 @@ proc new*( connectionsTimeout: connectionsTimeout, ) -method start*(self: TcpTransport, addrs: seq[MultiAddress]): Future[void] = +method start*( + self: TcpTransport, addrs: seq[MultiAddress] +): Future[void] {.async: (raises: [LPError, transport.TransportError]).} = ## Start transport listening to the given addresses - for dial-only transports, ## start with an empty list - # TODO remove `impl` indirection throughout when `raises` is added to base + if self.running: + warn "TCP transport already running" + return - proc impl( - self: TcpTransport, addrs: seq[MultiAddress] - ): Future[void] {.async: (raises: [transport.TransportError, CancelledError]).} = - if self.running: - warn "TCP transport already running" - return - - trace "Starting TCP transport" - - self.flags.incl(ServerFlags.ReusePort) - - var supported: seq[MultiAddress] - var initialized = false - try: - for i, ma in addrs: - if not self.handles(ma): - trace "Invalid address detected, skipping!", address = ma - continue - - let - ta = initTAddress(ma).expect("valid address per handles check above") - server = - try: - createStreamServer(ta, flags = self.flags) - except common.TransportError as exc: - raise (ref TcpTransportError)(msg: exc.msg, parent: exc) - - self.servers &= server - - trace "Listening on", address = ma - supported.add( - MultiAddress.init(server.sock.getLocalAddress()).expect( - "Can init from local address" - ) - ) - - initialized = true - finally: - if not initialized: - # Clean up partial success on exception - await noCancel allFutures(self.servers.mapIt(it.closeWait())) - reset(self.servers) + trace "Starting TCP transport" - try: - await procCall Transport(self).start(supported) - except CatchableError: - raiseAssert "Base method does not raise" + self.flags.incl(ServerFlags.ReusePort) - trackCounter(TcpTransportTrackerName) + var supported: seq[MultiAddress] + var initialized = false + try: + for i, ma in addrs: + if not self.handles(ma): + trace "Invalid address detected, skipping!", address = ma + continue - impl(self, addrs) + let + ta = initTAddress(ma).expect("valid address per handles check above") + server = + try: + createStreamServer(ta, flags = self.flags) + except common.TransportError as exc: + raise (ref TcpTransportError)(msg: exc.msg, parent: exc) -method stop*(self: TcpTransport): Future[void] = - ## Stop the transport and close all connections it created - proc impl(self: TcpTransport) {.async: (raises: []).} = - trace "Stopping TCP transport" - self.stopping = true - defer: - self.stopping = false + self.servers &= server - if self.running: - # Reset the running flag - try: - await noCancel procCall Transport(self).stop() - except CatchableError: # TODO remove when `accept` is annotated with raises - raiseAssert "doesn't actually raise" - - # Stop each server by closing the socket - this will cause all accept loops - # to fail - since the running flag has been reset, it's also safe to close - # all known clients since no more of them will be added - await noCancel allFutures( - self.servers.mapIt(it.closeWait()) & - self.clients[Direction.In].mapIt(it.closeWait()) & - self.clients[Direction.Out].mapIt(it.closeWait()) + trace "Listening on", address = ma + supported.add( + MultiAddress.init(server.sock.getLocalAddress()).expect( + "Can init from local address" + ) ) - self.servers = @[] - - for acceptFut in self.acceptFuts: - if acceptFut.completed(): - await acceptFut.value().closeWait() - self.acceptFuts = @[] - - if self.clients[Direction.In].len != 0 or self.clients[Direction.Out].len != 0: - # Future updates could consider turning this warn into an assert since - # it should never happen if the shutdown code is correct - warn "Couldn't clean up clients", - len = self.clients[Direction.In].len + self.clients[Direction.Out].len - - trace "Transport stopped" - untrackCounter(TcpTransportTrackerName) - else: - # For legacy reasons, `stop` on a transpart that wasn't started is - # expected to close outgoing connections created by the transport - warn "TCP transport already stopped" - - doAssert self.clients[Direction.In].len == 0, - "No incoming connections possible without start" - await noCancel allFutures(self.clients[Direction.Out].mapIt(it.closeWait())) + initialized = true + finally: + if not initialized: + # Clean up partial success on exception + await noCancel allFutures(self.servers.mapIt(it.closeWait())) + reset(self.servers) + + await procCall Transport(self).start(supported) + + trackCounter(TcpTransportTrackerName) + +method stop*(self: TcpTransport): Future[void] {.async: (raises: []).} = + trace "Stopping TCP transport" + self.stopping = true + defer: + self.stopping = false + + if self.running: + # Reset the running flag + await noCancel procCall Transport(self).stop() + # Stop each server by closing the socket - this will cause all accept loops + # to fail - since the running flag has been reset, it's also safe to close + # all known clients since no more of them will be added + await noCancel allFutures( + self.servers.mapIt(it.closeWait()) & + self.clients[Direction.In].mapIt(it.closeWait()) & + self.clients[Direction.Out].mapIt(it.closeWait()) + ) - impl(self) + self.servers = @[] + + for acceptFut in self.acceptFuts: + if acceptFut.completed(): + await acceptFut.value().closeWait() + self.acceptFuts = @[] + + if self.clients[Direction.In].len != 0 or self.clients[Direction.Out].len != 0: + # Future updates could consider turning this warn into an assert since + # it should never happen if the shutdown code is correct + warn "Couldn't clean up clients", + len = self.clients[Direction.In].len + self.clients[Direction.Out].len + + trace "Transport stopped" + untrackCounter(TcpTransportTrackerName) + else: + # For legacy reasons, `stop` on a transpart that wasn't started is + # expected to close outgoing connections created by the transport + warn "TCP transport already stopped" + + doAssert self.clients[Direction.In].len == 0, + "No incoming connections possible without start" + await noCancel allFutures(self.clients[Direction.Out].mapIt(it.closeWait())) method accept*(self: TcpTransport): Future[Connection] = ## accept a new TCP connection, returning nil on non-fatal errors diff --git a/libp2p/transports/tortransport.nim b/libp2p/transports/tortransport.nim index f98fa83840..0f74fc3d93 100644 --- a/libp2p/transports/tortransport.nim +++ b/libp2p/transports/tortransport.nim @@ -222,7 +222,9 @@ method dial*( await transp.closeWait() raise err -method start*(self: TorTransport, addrs: seq[MultiAddress]) {.async.} = +method start*( + self: TorTransport, addrs: seq[MultiAddress] +) {.async: (raises: [LPError, transport.TransportError]).} = ## listen on the transport ## @@ -254,7 +256,7 @@ method accept*(self: TorTransport): Future[Connection] {.async.} = conn.observedAddr = Opt.none(MultiAddress) return conn -method stop*(self: TorTransport) {.async.} = +method stop*(self: TorTransport) {.async: (raises: []).} = ## stop the transport ## await procCall Transport(self).stop() # call base diff --git a/libp2p/transports/transport.nim b/libp2p/transports/transport.nim index 94c605eb72..cd7aff5e4e 100644 --- a/libp2p/transports/transport.nim +++ b/libp2p/transports/transport.nim @@ -39,7 +39,9 @@ type proc newTransportClosedError*(parent: ref Exception = nil): ref TransportError = newException(TransportClosedError, "Transport closed, no more connections!", parent) -method start*(self: Transport, addrs: seq[MultiAddress]) {.base, async.} = +method start*( + self: Transport, addrs: seq[MultiAddress] +) {.base, async: (raises: [LPError, TransportError]).} = ## start the transport ## @@ -47,7 +49,7 @@ method start*(self: Transport, addrs: seq[MultiAddress]) {.base, async.} = self.addrs = addrs self.running = true -method stop*(self: Transport) {.base, async.} = +method stop*(self: Transport) {.base, async: (raises: []).} = ## stop and cleanup the transport ## including all outstanding connections ## diff --git a/libp2p/transports/wstransport.nim b/libp2p/transports/wstransport.nim index 9eb8a340d4..c1e2cce1c0 100644 --- a/libp2p/transports/wstransport.nim +++ b/libp2p/transports/wstransport.nim @@ -34,8 +34,11 @@ export transport, websock, results const DefaultHeadersTimeout = 3.seconds -type WsStream = ref object of Connection - session: WSSession +type + WsStream = ref object of Connection + session: WSSession + + WsTransportError* = object of transport.TransportError method initStream*(s: WsStream) = if s.objName.len == 0: @@ -116,7 +119,9 @@ type WsTransport* = ref object of Transport proc secure*(self: WsTransport): bool = not (isNil(self.tlsPrivateKey) or isNil(self.tlsCertificate)) -method start*(self: WsTransport, addrs: seq[MultiAddress]) {.async.} = +method start*( + self: WsTransport, addrs: seq[MultiAddress] +) {.async: (raises: [LPError, transport.TransportError]).} = ## listen on the transport ## @@ -140,19 +145,22 @@ method start*(self: WsTransport, addrs: seq[MultiAddress]) {.async.} = else: false + let address = ma.initTAddress().tryGet() + let httpserver = - if isWss: - TlsHttpServer.create( - address = ma.initTAddress().tryGet(), - tlsPrivateKey = self.tlsPrivateKey, - tlsCertificate = self.tlsCertificate, - flags = self.flags, - handshakeTimeout = self.handshakeTimeout, - ) - else: - HttpServer.create( - ma.initTAddress().tryGet(), handshakeTimeout = self.handshakeTimeout - ) + try: + if isWss: + TlsHttpServer.create( + address = address, + tlsPrivateKey = self.tlsPrivateKey, + tlsCertificate = self.tlsCertificate, + flags = self.flags, + handshakeTimeout = self.handshakeTimeout, + ) + else: + HttpServer.create(address, handshakeTimeout = self.handshakeTimeout) + except CatchableError as exc: + raise (ref WsTransportError)(msg: exc.msg, parent: exc) self.httpservers &= httpserver @@ -173,7 +181,7 @@ method start*(self: WsTransport, addrs: seq[MultiAddress]) {.async.} = self.running = true -method stop*(self: WsTransport) {.async.} = +method stop*(self: WsTransport) {.async: (raises: []).} = ## stop the transport ##