diff --git a/Neo4j.Driver/Neo4j.Driver.Tests/Connector/DelegatedConnectionTests.cs b/Neo4j.Driver/Neo4j.Driver.Tests/Connector/DelegatedConnectionTests.cs index 29546f150..d37ff9684 100644 --- a/Neo4j.Driver/Neo4j.Driver.Tests/Connector/DelegatedConnectionTests.cs +++ b/Neo4j.Driver/Neo4j.Driver.Tests/Connector/DelegatedConnectionTests.cs @@ -36,9 +36,14 @@ public TestDelegatedConnection(IConnection connection) : base(connection) } public override void OnError(Exception error) + { + throw new NotImplementedException(); + } + + public override Task OnErrorAsync(Exception error) { ErrorList.Add(error); - throw error; + return Task.FromException(error); } } diff --git a/Neo4j.Driver/Neo4j.Driver.Tests/Routing/ClusterConnectionPoolTests.cs b/Neo4j.Driver/Neo4j.Driver.Tests/Routing/ClusterConnectionPoolTests.cs index 7d5660b28..ff7a9af0e 100644 --- a/Neo4j.Driver/Neo4j.Driver.Tests/Routing/ClusterConnectionPoolTests.cs +++ b/Neo4j.Driver/Neo4j.Driver.Tests/Routing/ClusterConnectionPoolTests.cs @@ -206,7 +206,7 @@ public void ShouldDeactivateServerPoolIfNotPresentInNewServersButHasInUseConnect public class AddMethod { [Fact] - public void ShouldActivateIfExist() + public void ShouldDeactivateIfExist() { // Given var mockedConnectionPool = new Mock(); diff --git a/Neo4j.Driver/Neo4j.Driver/Internal/Connector/DelegatedConnection.cs b/Neo4j.Driver/Neo4j.Driver/Internal/Connector/DelegatedConnection.cs index 5bb342898..97bbdda3a 100644 --- a/Neo4j.Driver/Neo4j.Driver/Internal/Connector/DelegatedConnection.cs +++ b/Neo4j.Driver/Neo4j.Driver/Internal/Connector/DelegatedConnection.cs @@ -33,9 +33,10 @@ protected DelegatedConnection(IConnection connection) public abstract void OnError(Exception error); - private void OnError(AggregateException error) + public virtual Task OnErrorAsync(Exception error) { - OnError(error.GetBaseException()); + OnError(error); + return TaskExtensions.GetCompletedTask(); } public void Sync() @@ -167,45 +168,16 @@ public virtual Task CloseAsync() return Delegate.CloseAsync(); } - internal Task TaskWithErrorHandling(Func task) + internal async Task TaskWithErrorHandling(Func task) { - var tcs = new TaskCompletionSource(); - try { - task().ContinueWith(t => - { - if (t.IsFaulted) - { - try - { - OnError(t.Exception); - } - catch (AggregateException exc) - { - tcs.SetException(exc.GetBaseException()); - } - catch (Exception exc) - { - tcs.SetException(exc); - } - } - else if (t.IsCanceled) - { - tcs.SetCanceled(); - } - else - { - tcs.SetResult(true); - } - }, TaskContinuationOptions.ExecuteSynchronously); - } - catch (Exception e) // this is to catch whatever direct error in `task()` before returning a task + await task().ConfigureAwait(false); + } + catch (Exception e) { - OnError(e); + await OnErrorAsync(e).ConfigureAwait(false); } - - return tcs.Task; } } } diff --git a/Neo4j.Driver/Neo4j.Driver/Internal/Routing/ClusterConnection.cs b/Neo4j.Driver/Neo4j.Driver/Internal/Routing/ClusterConnection.cs index c45506555..481889370 100644 --- a/Neo4j.Driver/Neo4j.Driver/Internal/Routing/ClusterConnection.cs +++ b/Neo4j.Driver/Neo4j.Driver/Internal/Routing/ClusterConnection.cs @@ -55,7 +55,7 @@ public override void OnError(Exception error) throw error; } - public async Task OnErrorAsync(Exception error) + public override async Task OnErrorAsync(Exception error) { if (error is ServiceUnavailableException) { diff --git a/Neo4j.Driver/Neo4j.Driver/Internal/Routing/ClusterConnectionPool.cs b/Neo4j.Driver/Neo4j.Driver/Internal/Routing/ClusterConnectionPool.cs index 076d510fa..a93691f9b 100644 --- a/Neo4j.Driver/Neo4j.Driver/Internal/Routing/ClusterConnectionPool.cs +++ b/Neo4j.Driver/Neo4j.Driver/Internal/Routing/ClusterConnectionPool.cs @@ -142,7 +142,6 @@ public void Update(IEnumerable added, IEnumerable removed) public async Task UpdateAsync(IEnumerable added, IEnumerable removed) { await AddAsync(added).ConfigureAwait(false); - // TODO chain this part and use task.waitAll foreach (var uri in removed) { if (_pools.TryGetValue(uri, out var pool))