diff --git a/source/agora/api/FullNode.d b/source/agora/api/FullNode.d index 94e7b2f1251..5d0a51fd886 100644 --- a/source/agora/api/FullNode.d +++ b/source/agora/api/FullNode.d @@ -24,6 +24,7 @@ import agora.consensus.data.PreImageInfo; import agora.consensus.data.Transaction; import agora.consensus.data.ValidatorInfo; import agora.crypto.Key; +import agora.network.RPC : noRPCRoute; import vibe.data.serialization; import vibe.http.common; @@ -429,4 +430,9 @@ public interface API ***************************************************************************/ public WrappedConsensusParams getConsensusParams (); + + /// Shutdown the node and its connections to other peers + @noRPCRoute + @noRoute + public void shutdown (); } diff --git a/source/agora/network/Client.d b/source/agora/network/Client.d index 2f002758216..3609746fb33 100644 --- a/source/agora/network/Client.d +++ b/source/agora/network/Client.d @@ -195,6 +195,13 @@ public class NetworkClient public void shutdown () @safe nothrow { this.gossip_timer.stop(); + foreach (conn; this.connections) + { + try + conn.api.shutdown(); + catch (Exception ex) + log.dbg("Connection ({}) failed to close, continuing: {}", conn.address, ex); + } } /// For gossiping we don't want to block the calling fiber, so we use @@ -654,6 +661,7 @@ public class NetworkClient { try { + conn.api.shutdown(); this.connections = this.connections.remove!(c => c == conn); this.connections_version++; log.warn("Removing banned address {} while performing {}, addresses left: {}", diff --git a/source/agora/network/Manager.d b/source/agora/network/Manager.d index d80679626a4..dcd5ec59cec 100644 --- a/source/agora/network/Manager.d +++ b/source/agora/network/Manager.d @@ -78,7 +78,7 @@ public class NetworkManager /// Called when we've connected and determined if this is /// a FullNode / Validator - public alias OnHandshakeComplete = void delegate ( + public alias OnHandshakeComplete = bool delegate ( in Address, agora.api.Validator.API, in Hash, in PublicKey); /// Ditto private OnHandshakeComplete onHandshakeComplete; @@ -206,19 +206,25 @@ public class NetworkManager } } - this.onHandshakeComplete(this.address, this.api, id.utxo, id.key); + if(!this.onHandshakeComplete(this.address, this.api, id.utxo, id.key)) + this.api.shutdown(); + return; } catch (Exception ex) { if (!this.onFailedRequest(this.address)) + { + this.api.shutdown(); return; + } // else try again this.outer.taskman.wait(this.outer.config.node.retry_delay); } } // failed to connect, try to ban (if not whitelisted) + this.api.shutdown(); this.outer.banman.ban(address); } } @@ -316,7 +322,8 @@ public class NetworkManager } /// Called after a node's handshake is complete - private void onHandshakeComplete ( + /// returns true when a client is added for the peer + private bool onHandshakeComplete ( in Address address, agora.api.Validator.API api, in Hash utxo, in PublicKey key) { @@ -342,7 +349,7 @@ public class NetworkManager if (prev.identity.utxo !is Hash.init) { if (utxo != prev.identity.utxo) - return; // Drop it + return false; // Drop it } else if (utxo !is Hash.init) { @@ -350,14 +357,14 @@ public class NetworkManager this.required_peers.remove(utxo); } prev.merge(address, api); - return; // All done + return true; // All done } } if (utxo !is Hash.init) this.required_peers.remove(utxo); else if (this.peerLimitReached()) - return; + return false; auto client = new NetworkClient(this.taskman, this.banman, this.config.node.retry_delay, this.config.node.max_retries); @@ -384,6 +391,8 @@ public class NetworkManager auto node_info = client.getNodeInfo(); this.addAddresses(node_info.addresses); } + + return true; } /*************************************************************************** diff --git a/source/agora/network/RPC.d b/source/agora/network/RPC.d index 4e6ad8ce8ee..d37347d44b5 100644 --- a/source/agora/network/RPC.d +++ b/source/agora/network/RPC.d @@ -118,6 +118,17 @@ struct ProxyProtocol } } +/// Methods marked with this attribute will not be treated as RPC endpoint +package struct NoRPCRouteAttribute {} + +/// Ditto +@property NoRPCRouteAttribute noRPCRoute() +{ + if (!__ctfe) + assert(0, "noRPCRoute must be used as an attribute"); + return NoRPCRouteAttribute.init; +} + /// A RPC packet private struct Packet { @@ -225,6 +236,17 @@ public class RPCClient (API) : API format("{}.{}.{}", __MODULE__, this.config.host, this.config.port)); } + ~this () + { + this.conn.close(); + } + + /// Close connection to the peer + public override void shutdown () @safe nothrow + { + this.conn.close(); + } + /// Returns: A new `TCPConnection` private TCPConnection connect () @safe { @@ -261,104 +283,109 @@ public class RPCClient (API) : API /// Implementation of the API's functions static foreach (member; __traits(allMembers, API)) + { static foreach (ovrld; __traits(getOverloads, API, member)) { - mixin(q{ - override ReturnType!(ovrld) } ~ member ~ q{ (Parameters!ovrld params) - { - ubyte[1024] buffer = void; - scope DeserializeDg reader = (size_t size) @safe + static if (!hasUDA!(ovrld, NoRPCRouteAttribute)) + { + mixin(q{ + override ReturnType!(ovrld) } ~ member ~ q{ (Parameters!ovrld params) { - ensure(size < buffer.length, "Out of bound read"); - this.conn.read(buffer[0 .. size]); - return buffer[0 .. size]; - }; + ubyte[1024] buffer = void; + scope DeserializeDg reader = (size_t size) @safe + { + ensure(size < buffer.length, "Out of bound read"); + this.conn.read(buffer[0 .. size]); + return buffer[0 .. size]; + }; - Packet packet; - packet.seq_id = this.seq_id++; - packet.method = this.lookup[ovrld.mangleof]; + Packet packet; + packet.seq_id = this.seq_id++; + packet.method = this.lookup[ovrld.mangleof]; - // Acquire the write lock and send the packet - { - this.wlock.lock(); - scope (exit) this.wlock.unlock(); - if (!this.conn.connected()) - this.conn = this.connect(); - - this.conn.write(serializeFull(packet)); - // List of parameters - foreach (ref p; params) - this.conn.write(serializeFull(p)); - this.conn.flush(); - } + // Acquire the write lock and send the packet + { + this.wlock.lock(); + scope (exit) this.wlock.unlock(); + if (!this.conn.connected()) + this.conn = this.connect(); + + this.conn.write(serializeFull(packet)); + // List of parameters + foreach (ref p; params) + this.conn.write(serializeFull(p)); + this.conn.flush(); + } - static if (!is(typeof(return) == void)) - { - scope (exit) this.waiting_list.remove(packet.seq_id); - ReturnType!(ovrld)[] response; - auto woke_up = 0; - auto start = MonoTime.currTime; - Waiting waiting; - // Attempt to acquire the read lock, if we cant; register ourself - // as a `waiter` and wait for the Fiber that has the read lock to signal - // us when it receives the response we are waiting for - while (!this.rlock.tryLock()) + static if (!is(typeof(return) == void)) { - if (waiting is null) + scope (exit) this.waiting_list.remove(packet.seq_id); + ReturnType!(ovrld)[] response; + auto woke_up = 0; + auto start = MonoTime.currTime; + Waiting waiting; + // Attempt to acquire the read lock, if we cant; register ourself + // as a `waiter` and wait for the Fiber that has the read lock to signal + // us when it receives the response we are waiting for + while (!this.rlock.tryLock()) { - waiting = new Waiting(createManualEvent(), () { - auto val = deserializeFull!(ReturnType!(ovrld))(reader); - response ~= val; - }); - this.waiting_list[packet.seq_id] = waiting; - } + if (waiting is null) + { + waiting = new Waiting(createManualEvent(), () { + auto val = deserializeFull!(ReturnType!(ovrld))(reader); + response ~= val; + }); + this.waiting_list[packet.seq_id] = waiting; + } - ensure(waiting.event.wait(start + this.config.read_timeout - - MonoTime.currTime, woke_up) > woke_up++, "Request timed out"); + ensure(waiting.event.wait(start + this.config.read_timeout + - MonoTime.currTime, woke_up) > woke_up++, "Request timed out"); - // reader fiber read the response and stored it for us - if (waiting.res.is_response) - { - ensure(waiting.res.method == packet.method, "Method mismatch"); - ensure(response.length == 1, "Error while reading response"); - return response[0]; + // reader fiber read the response and stored it for us + if (waiting.res.is_response) + { + ensure(waiting.res.method == packet.method, "Method mismatch"); + ensure(response.length == 1, "Error while reading response"); + return response[0]; + } } - } - // got the rlock - // keep reading response packets and waking up the fibers waiting for them - { - scope (success) - // wake up one of the waiters to get the rlock and start reading - if (this.waiting_list.length > 0) - this.waiting_list.byValue.front.event.emit(); - scope (exit) this.rlock.unlock(); - scope (failure) this.conn.close(); - - ensure(this.conn.connected(), "Connection dropped"); - - while (true) + // got the rlock + // keep reading response packets and waking up the fibers waiting for them { - auto any_response = deserializeFull!Packet(reader); - ensure(any_response.is_response, "Unexpected request on client socket"); + scope (success) + // wake up one of the waiters to get the rlock and start reading + if (this.waiting_list.length > 0) + this.waiting_list.byValue.front.event.emit(); + scope (exit) this.rlock.unlock(); + scope (failure) this.conn.close(); - if (any_response.seq_id == packet.seq_id) - { - ensure(any_response.method == packet.method, "Method mismatch"); - return deserializeFull!(ReturnType!(ovrld))(reader); - } - else if (auto waiter = any_response.seq_id in this.waiting_list) + ensure(this.conn.connected(), "Connection dropped"); + + while (true) { - (*waiter).res = any_response; - (*waiter).on_packet_received(); - (*waiter).event.emit(); + auto any_response = deserializeFull!Packet(reader); + ensure(any_response.is_response, "Unexpected request on client socket"); + + if (any_response.seq_id == packet.seq_id) + { + ensure(any_response.method == packet.method, "Method mismatch"); + return deserializeFull!(ReturnType!(ovrld))(reader); + } + else if (auto waiter = any_response.seq_id in this.waiting_list) + { + (*waiter).res = any_response; + (*waiter).on_packet_received(); + (*waiter).event.emit(); + } } } } } - } - }); + }); + } } + } } /// Aggregate configuration options for `RPCClient` diff --git a/source/agora/network/VibeManager.d b/source/agora/network/VibeManager.d index b5c2106fda8..d870097705e 100644 --- a/source/agora/network/VibeManager.d +++ b/source/agora/network/VibeManager.d @@ -59,6 +59,21 @@ shared static this () /// And implementation of `agora.network.Manager : NetworkManager` using Vibe.d public final class VibeNetworkManager : NetworkManager { + /// A stub class for overriding `shutdown` for the REST client + private class ValidatorRestClient : RestInterfaceClient!(agora.api.Validator.API) + { + this (RestInterfaceSettings settings) + { + super(settings); + } + + public override void shutdown () + { + // REST client doesn't need a shutdown + return; + } + } + /// Construct an instance of this object public this (in Config config, ManagedDatabase cache, ITaskManager taskman, Clock clock, agora.api.FullNode.API owner_node, Ledger ledger) @@ -83,8 +98,7 @@ public final class VibeNetworkManager : NetworkManager timeout, timeout, timeout); if (url.schema.startsWith("http")) - return new RestInterfaceClient!(agora.api.Validator.API)( - this.makeRestInterfaceSettings(url)); + return new ValidatorRestClient(this.makeRestInterfaceSettings(url)); assert(0, "Unknown agora schema: " ~ url.toString()); } diff --git a/source/agora/node/FullNode.d b/source/agora/node/FullNode.d index ca53590b23d..1dd4f57039c 100644 --- a/source/agora/node/FullNode.d +++ b/source/agora/node/FullNode.d @@ -708,7 +708,7 @@ public class FullNode : API ***************************************************************************/ - public void shutdown () @safe + public override void shutdown () @safe { this.is_shutting_down = true; log.info("Shutting down.."); diff --git a/source/agora/test/Base.d b/source/agora/test/Base.d index a1861ffabd2..947d3cb94f9 100644 --- a/source/agora/test/Base.d +++ b/source/agora/test/Base.d @@ -1457,12 +1457,26 @@ public class TestNetworkManager : NetworkManager this.address = address; } + private class TestAPIClient(T) : RemoteAPI!(T) + { + this (geod24.LocalRest.Listener!T api, Duration timeout = 5.seconds) + { + super(api, timeout); + } + + public override void shutdown () + { + // Local endpoint doesn't need a shutdown + return; + } + } + /// protected final override TestAPI makeClient (Address address) { auto tid = this.registry.locate!TestAPI(address.host); if (tid != typeof(tid).init) - return new RemoteAPI!TestAPI(tid, this.config.node.timeout); + return new TestAPIClient!TestAPI(tid, this.config.node.timeout); throw (new Exception( format!"Trying to access node at address '%s' from '%s' without first creating it"(address, this.address))); } @@ -1520,6 +1534,8 @@ public class TestNetworkManager : NetworkManager } } +import geod24.LocalRest : noCommand; + /******************************************************************************* API implemented by the test nodes runs by LocalRest @@ -1701,6 +1717,9 @@ public interface TestAPI : API ***************************************************************************/ public Amount getPenaltyDeposit (Hash utxo); + + @noCommand + public void shutdown (); } /// Return type for `TestAPI.getUTXOs` @@ -1982,6 +2001,11 @@ public class TestFullNode : FullNode, TestAPI { assert(0); } + + public override void shutdown () + { + super.shutdown(); + } } /// A Validator which also implements test routines in TestAPI @@ -2055,6 +2079,11 @@ public class TestValidatorNode : Validator, TestAPI } return quorums; } + + public override void shutdown () + { + super.shutdown(); + } } /// Convenience mixin for deriving classes diff --git a/submodules/localrest b/submodules/localrest index dbbe810a323..1047aa5ab37 160000 --- a/submodules/localrest +++ b/submodules/localrest @@ -1 +1 @@ -Subproject commit dbbe810a3234a106d31b30a6335e465ca4a740c6 +Subproject commit 1047aa5ab373c5a08710ee9dbe5c5782b9329483