Skip to content

Commit

Permalink
RPC: Close TCP connection on shutdown
Browse files Browse the repository at this point in the history
  • Loading branch information
Muhammed Kadir Yücel authored and hewison-chris committed May 9, 2022
1 parent ae54d09 commit 09895d4
Show file tree
Hide file tree
Showing 8 changed files with 182 additions and 89 deletions.
6 changes: 6 additions & 0 deletions source/agora/api/FullNode.d
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import agora.consensus.data.PreImageInfo;
import agora.consensus.data.Transaction;
import agora.consensus.data.ValidatorInfo;
import agora.crypto.Key;
import agora.network.RPC : noRPCRoute;

import vibe.data.serialization;
import vibe.http.common;
Expand Down Expand Up @@ -429,4 +430,9 @@ public interface API
***************************************************************************/

public WrappedConsensusParams getConsensusParams ();

/// Shutdown the node and its connections to other peers
@noRPCRoute
@noRoute
public void shutdown ();
}
8 changes: 8 additions & 0 deletions source/agora/network/Client.d
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,13 @@ public class NetworkClient
public void shutdown () @safe nothrow
{
this.gossip_timer.stop();
foreach (conn; this.connections)
{
try
conn.api.shutdown();
catch (Exception ex)
log.dbg("Connection ({}) failed to close, continuing: {}", conn.address, ex);
}
}

/// For gossiping we don't want to block the calling fiber, so we use
Expand Down Expand Up @@ -654,6 +661,7 @@ public class NetworkClient
{
try
{
conn.api.shutdown();
this.connections = this.connections.remove!(c => c == conn);
this.connections_version++;
log.warn("Removing banned address {} while performing {}, addresses left: {}",
Expand Down
21 changes: 15 additions & 6 deletions source/agora/network/Manager.d
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ public class NetworkManager

/// Called when we've connected and determined if this is
/// a FullNode / Validator
public alias OnHandshakeComplete = void delegate (
public alias OnHandshakeComplete = bool delegate (
in Address, agora.api.Validator.API, in Hash, in PublicKey);
/// Ditto
private OnHandshakeComplete onHandshakeComplete;
Expand Down Expand Up @@ -206,19 +206,25 @@ public class NetworkManager
}
}

this.onHandshakeComplete(this.address, this.api, id.utxo, id.key);
if(!this.onHandshakeComplete(this.address, this.api, id.utxo, id.key))
this.api.shutdown();

return;
}
catch (Exception ex)
{
if (!this.onFailedRequest(this.address))
{
this.api.shutdown();
return;
}

// else try again
this.outer.taskman.wait(this.outer.config.node.retry_delay);
}
}
// failed to connect, try to ban (if not whitelisted)
this.api.shutdown();
this.outer.banman.ban(address);
}
}
Expand Down Expand Up @@ -316,7 +322,8 @@ public class NetworkManager
}

/// Called after a node's handshake is complete
private void onHandshakeComplete (
/// returns true when a client is added for the peer
private bool onHandshakeComplete (
in Address address, agora.api.Validator.API api,
in Hash utxo, in PublicKey key)
{
Expand All @@ -342,22 +349,22 @@ public class NetworkManager
if (prev.identity.utxo !is Hash.init)
{
if (utxo != prev.identity.utxo)
return; // Drop it
return false; // Drop it
}
else if (utxo !is Hash.init)
{
prev.setIdentity(utxo, key);
this.required_peers.remove(utxo);
}
prev.merge(address, api);
return; // All done
return true; // All done
}
}

if (utxo !is Hash.init)
this.required_peers.remove(utxo);
else if (this.peerLimitReached())
return;
return false;

auto client = new NetworkClient(this.taskman, this.banman,
this.config.node.retry_delay, this.config.node.max_retries);
Expand All @@ -384,6 +391,8 @@ public class NetworkManager
auto node_info = client.getNodeInfo();
this.addAddresses(node_info.addresses);
}

return true;
}

/***************************************************************************
Expand Down
183 changes: 105 additions & 78 deletions source/agora/network/RPC.d
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,17 @@ struct ProxyProtocol
}
}

/// Methods marked with this attribute will not be treated as RPC endpoint
package struct NoRPCRouteAttribute {}

/// Ditto
@property NoRPCRouteAttribute noRPCRoute()
{
if (!__ctfe)
assert(0, "noRPCRoute must be used as an attribute");
return NoRPCRouteAttribute.init;
}

/// A RPC packet
private struct Packet
{
Expand Down Expand Up @@ -225,6 +236,17 @@ public class RPCClient (API) : API
format("{}.{}.{}", __MODULE__, this.config.host, this.config.port));
}

~this ()
{
this.conn.close();
}

/// Close connection to the peer
public override void shutdown () @safe nothrow
{
this.conn.close();
}

/// Returns: A new `TCPConnection`
private TCPConnection connect () @safe
{
Expand Down Expand Up @@ -261,104 +283,109 @@ public class RPCClient (API) : API

/// Implementation of the API's functions
static foreach (member; __traits(allMembers, API))
{
static foreach (ovrld; __traits(getOverloads, API, member))
{
mixin(q{
override ReturnType!(ovrld) } ~ member ~ q{ (Parameters!ovrld params)
{
ubyte[1024] buffer = void;
scope DeserializeDg reader = (size_t size) @safe
static if (!hasUDA!(ovrld, NoRPCRouteAttribute))
{
mixin(q{
override ReturnType!(ovrld) } ~ member ~ q{ (Parameters!ovrld params)
{
ensure(size < buffer.length, "Out of bound read");
this.conn.read(buffer[0 .. size]);
return buffer[0 .. size];
};
ubyte[1024] buffer = void;
scope DeserializeDg reader = (size_t size) @safe
{
ensure(size < buffer.length, "Out of bound read");
this.conn.read(buffer[0 .. size]);
return buffer[0 .. size];
};

Packet packet;
packet.seq_id = this.seq_id++;
packet.method = this.lookup[ovrld.mangleof];
Packet packet;
packet.seq_id = this.seq_id++;
packet.method = this.lookup[ovrld.mangleof];

// Acquire the write lock and send the packet
{
this.wlock.lock();
scope (exit) this.wlock.unlock();
if (!this.conn.connected())
this.conn = this.connect();

this.conn.write(serializeFull(packet));
// List of parameters
foreach (ref p; params)
this.conn.write(serializeFull(p));
this.conn.flush();
}
// Acquire the write lock and send the packet
{
this.wlock.lock();
scope (exit) this.wlock.unlock();
if (!this.conn.connected())
this.conn = this.connect();

this.conn.write(serializeFull(packet));
// List of parameters
foreach (ref p; params)
this.conn.write(serializeFull(p));
this.conn.flush();
}

static if (!is(typeof(return) == void))
{
scope (exit) this.waiting_list.remove(packet.seq_id);
ReturnType!(ovrld)[] response;
auto woke_up = 0;
auto start = MonoTime.currTime;
Waiting waiting;
// Attempt to acquire the read lock, if we cant; register ourself
// as a `waiter` and wait for the Fiber that has the read lock to signal
// us when it receives the response we are waiting for
while (!this.rlock.tryLock())
static if (!is(typeof(return) == void))
{
if (waiting is null)
scope (exit) this.waiting_list.remove(packet.seq_id);
ReturnType!(ovrld)[] response;
auto woke_up = 0;
auto start = MonoTime.currTime;
Waiting waiting;
// Attempt to acquire the read lock, if we cant; register ourself
// as a `waiter` and wait for the Fiber that has the read lock to signal
// us when it receives the response we are waiting for
while (!this.rlock.tryLock())
{
waiting = new Waiting(createManualEvent(), () {
auto val = deserializeFull!(ReturnType!(ovrld))(reader);
response ~= val;
});
this.waiting_list[packet.seq_id] = waiting;
}
if (waiting is null)
{
waiting = new Waiting(createManualEvent(), () {
auto val = deserializeFull!(ReturnType!(ovrld))(reader);
response ~= val;
});
this.waiting_list[packet.seq_id] = waiting;
}

ensure(waiting.event.wait(start + this.config.read_timeout
- MonoTime.currTime, woke_up) > woke_up++, "Request timed out");
ensure(waiting.event.wait(start + this.config.read_timeout
- MonoTime.currTime, woke_up) > woke_up++, "Request timed out");

// reader fiber read the response and stored it for us
if (waiting.res.is_response)
{
ensure(waiting.res.method == packet.method, "Method mismatch");
ensure(response.length == 1, "Error while reading response");
return response[0];
// reader fiber read the response and stored it for us
if (waiting.res.is_response)
{
ensure(waiting.res.method == packet.method, "Method mismatch");
ensure(response.length == 1, "Error while reading response");
return response[0];
}
}
}

// got the rlock
// keep reading response packets and waking up the fibers waiting for them
{
scope (success)
// wake up one of the waiters to get the rlock and start reading
if (this.waiting_list.length > 0)
this.waiting_list.byValue.front.event.emit();
scope (exit) this.rlock.unlock();
scope (failure) this.conn.close();

ensure(this.conn.connected(), "Connection dropped");

while (true)
// got the rlock
// keep reading response packets and waking up the fibers waiting for them
{
auto any_response = deserializeFull!Packet(reader);
ensure(any_response.is_response, "Unexpected request on client socket");
scope (success)
// wake up one of the waiters to get the rlock and start reading
if (this.waiting_list.length > 0)
this.waiting_list.byValue.front.event.emit();
scope (exit) this.rlock.unlock();
scope (failure) this.conn.close();

if (any_response.seq_id == packet.seq_id)
{
ensure(any_response.method == packet.method, "Method mismatch");
return deserializeFull!(ReturnType!(ovrld))(reader);
}
else if (auto waiter = any_response.seq_id in this.waiting_list)
ensure(this.conn.connected(), "Connection dropped");

while (true)
{
(*waiter).res = any_response;
(*waiter).on_packet_received();
(*waiter).event.emit();
auto any_response = deserializeFull!Packet(reader);
ensure(any_response.is_response, "Unexpected request on client socket");

if (any_response.seq_id == packet.seq_id)
{
ensure(any_response.method == packet.method, "Method mismatch");
return deserializeFull!(ReturnType!(ovrld))(reader);
}
else if (auto waiter = any_response.seq_id in this.waiting_list)
{
(*waiter).res = any_response;
(*waiter).on_packet_received();
(*waiter).event.emit();
}
}
}
}
}
}
});
});
}
}
}
}

/// Aggregate configuration options for `RPCClient`
Expand Down
18 changes: 16 additions & 2 deletions source/agora/network/VibeManager.d
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,21 @@ shared static this ()
/// And implementation of `agora.network.Manager : NetworkManager` using Vibe.d
public final class VibeNetworkManager : NetworkManager
{
/// A stub class for overriding `shutdown` for the REST client
private class ValidatorRestClient : RestInterfaceClient!(agora.api.Validator.API)
{
this (RestInterfaceSettings settings)
{
super(settings);
}

public override void shutdown ()
{
// REST client doesn't need a shutdown
return;
}
}

/// Construct an instance of this object
public this (in Config config, ManagedDatabase cache, ITaskManager taskman,
Clock clock, agora.api.FullNode.API owner_node, Ledger ledger)
Expand All @@ -83,8 +98,7 @@ public final class VibeNetworkManager : NetworkManager
timeout, timeout, timeout);

if (url.schema.startsWith("http"))
return new RestInterfaceClient!(agora.api.Validator.API)(
this.makeRestInterfaceSettings(url));
return new ValidatorRestClient(this.makeRestInterfaceSettings(url));

assert(0, "Unknown agora schema: " ~ url.toString());
}
Expand Down
2 changes: 1 addition & 1 deletion source/agora/node/FullNode.d
Original file line number Diff line number Diff line change
Expand Up @@ -708,7 +708,7 @@ public class FullNode : API
***************************************************************************/

public void shutdown () @safe
public override void shutdown () @safe
{
this.is_shutting_down = true;
log.info("Shutting down..");
Expand Down
Loading

0 comments on commit 09895d4

Please sign in to comment.