Skip to content

Commit

Permalink
[Java] Fix connection race + Mockito error.
Browse files Browse the repository at this point in the history
  • Loading branch information
vyazelenko committed Apr 11, 2024
1 parent 262c945 commit bf33269
Showing 1 changed file with 19 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import org.agrona.DirectBuffer;
import org.agrona.ErrorHandler;
import org.agrona.LangUtil;
import org.agrona.collections.MutableLong;
import org.agrona.concurrent.AgentInvoker;
import org.agrona.concurrent.QueuedPipe;
import org.agrona.concurrent.status.CountersReader;
Expand All @@ -32,6 +33,7 @@
import org.junit.Test;
import org.mockito.ArgumentCaptor;
import org.mockito.Mockito;
import org.mockito.stubbing.Answer;
import org.mockito.verification.VerificationMode;
import uk.co.real_logic.artio.CloseChecker;
import uk.co.real_logic.artio.FixCounters;
Expand Down Expand Up @@ -67,6 +69,7 @@
import static org.junit.Assert.*;
import static org.mockito.Mockito.*;
import static uk.co.real_logic.artio.CommonConfiguration.DEFAULT_NAME_PREFIX;
import static uk.co.real_logic.artio.GatewayProcess.NO_CONNECTION_ID;
import static uk.co.real_logic.artio.Timing.assertEventuallyTrue;
import static uk.co.real_logic.artio.engine.FixEngine.ENGINE_LIBRARY_ID;
import static uk.co.real_logic.artio.library.FixLibrary.NO_MESSAGE_REPLAY;
Expand Down Expand Up @@ -135,7 +138,7 @@ public class FramerTest

private Framer framer;

private final ArgumentCaptor<Long> connectionId = ArgumentCaptor.forClass(Long.class);
private final MutableLong connectionId = new MutableLong(NO_CONNECTION_ID);
private final ErrorHandler errorHandler = mock(ErrorHandler.class);

@Before
Expand All @@ -150,15 +153,19 @@ public void setUp() throws IOException
when(outboundLibrarySubscription.imageBySessionId(anyInt())).thenReturn(normalImage);

when(mockEndPointFactory.receiverEndPoint(
any(), connectionId.capture(), anyLong(), anyInt(), anyInt(), any()))
.thenReturn(mockReceiverEndPoint);
any(), anyLong(), anyLong(), anyInt(), anyInt(), any()))
.thenAnswer((Answer<FixReceiverEndPoint>)invocationOnMock ->
{
connectionId.set(invocationOnMock.getArgument(1));
return mockReceiverEndPoint;
});

when(mockEndPointFactory.senderEndPoint(any(), anyLong(), anyInt(), any(), any()))
.thenReturn(mockSenderEndPoint);

when(mockReceiverEndPoint.connectionId()).then((inv) -> connectionId.getValue());
when(mockReceiverEndPoint.connectionId()).then((inv) -> connectionId.get());

when(mockSenderEndPoint.connectionId()).then((inv) -> connectionId.getValue());
when(mockSenderEndPoint.connectionId()).then((inv) -> connectionId.get());

when(gatewaySession.session()).thenReturn(session);
when(gatewaySession.fixDictionary()).thenReturn(fixDictionary);
Expand Down Expand Up @@ -278,7 +285,7 @@ public void shouldCloseSocketUponDisconnect() throws Exception
aClientConnects();
framer.doWork();

framer.onDisconnect(LIBRARY_ID, connectionId.getValue(), APPLICATION_DISCONNECT);
framer.onDisconnect(LIBRARY_ID, connectionId.get(), APPLICATION_DISCONNECT);
framer.doWork();

verifyEndPointsDisconnected(APPLICATION_DISCONNECT);
Expand All @@ -298,6 +305,7 @@ public void shouldNotConnectIfLibraryUnknown() throws Exception
framer.doWork();

assertNull("Sender has connected to server", server.accept());
assertEquals(NO_CONNECTION_ID, connectionId.get());
verifyErrorPublished(UNKNOWN_LIBRARY);
}

Expand Down Expand Up @@ -357,7 +365,6 @@ public void shouldIdentifyDuplicateInitiatedSessions() throws Exception
assertEquals(CONTINUE, onInitiateConnection());

verifyErrorPublished(DUPLICATE_SESSION);
assertNull(server.accept());
}

@Test
Expand Down Expand Up @@ -730,7 +737,7 @@ private void releaseConnection(final Action expectedResult)
{
assertEquals(expectedResult, framer.onReleaseSession(
LIBRARY_ID,
connectionId.getValue(),
connectionId.get(),
SESSION_ID,
CORR_ID,
ACTIVE,
Expand All @@ -749,7 +756,7 @@ private Action onLibraryConnect()

private void givenAGatewayToManage()
{
when(gatewaySession.connectionId()).thenReturn(connectionId.getValue());
when(gatewaySession.connectionId()).thenReturn(connectionId.get());
when(gatewaySession.sessionKey()).thenReturn(mock(CompositeKey.class));
when(gatewaySessions.sessions()).thenReturn(singletonList(gatewaySession));
}
Expand Down Expand Up @@ -843,13 +850,10 @@ private void initiateConnection() throws Exception

assertEquals(CONTINUE, onInitiateConnection());

do
while (NO_CONNECTION_ID == connectionId.get())
{
framer.doWork();
}
while (server.accept() == null);

assertNotNull("Connection not completed yet", connectionId.getValue());
}

private Action onInitiateConnection()
Expand Down Expand Up @@ -900,7 +904,7 @@ private void notifyLibraryOfConnection()
private void notifyLibraryOfConnection(final VerificationMode times)
{
verify(inboundPublication, times).saveManageSession(eq(LIBRARY_ID),
eq(connectionId.getValue()),
eq(connectionId.get()),
anyLong(),
anyInt(),
anyInt(),
Expand Down Expand Up @@ -943,7 +947,7 @@ private void notifyLibraryOfConnection(final VerificationMode times)
private void verifySessionExistsSaved(final VerificationMode times, final SessionStatus status)
{
verify(inboundPublication, times).saveManageSession(eq(LIBRARY_ID),
eq(connectionId.getValue()),
eq(connectionId.get()),
anyLong(),
anyInt(),
anyInt(),
Expand Down

0 comments on commit bf33269

Please sign in to comment.