Skip to content

Commit

Permalink
fix the disconnect on oversized message fix #499
Browse files Browse the repository at this point in the history
- actually disconnect
- handle case where at most 2 bytes overflow
- handle back-pressure when saving invalid message
  • Loading branch information
wojciech-adaptive committed Dec 21, 2023
1 parent 076a5b4 commit d370dac
Show file tree
Hide file tree
Showing 7 changed files with 171 additions and 28 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -82,10 +82,6 @@ public final class TestMessages
public static final byte[] INVALID_LENGTH_MESSAGE = toAscii(
"8=FIX.4.4\0019=5\00135=A\00134=1\00149=TW\00152=20150604-12:46:54\00156=ISLD\00198=0\00110=000\001");

public static final byte[] OVERSIZED_MESSAGE_START =
toAscii("8=FIX.4.2\0019=99999\00135=D\00134=4\00149=ABC_DEFG01\001" +
"52=20090323-15:40:29\00156=CCG\001115=XYZ\00111=NF 0542/03232009\00154=1\00138=100\00155=CVS\00140=1");

public static final byte[] ZERO_CHECKSUM_MESSAGE = toAscii(
"8=FIX.4.4\0019=0067\00135=0\00149=acceptor\00156=initiator\00134=2" +
"\00152=20160415-12:50:23.294\001112=hi\00110=000\001");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -392,6 +392,7 @@ boolean retryFrameMessages()
// true - no more framed messages in the buffer data to process. This could mean no more messages, or some data
// that is an incomplete message.
// false - needs to be retried, aka back-pressured
@SuppressWarnings("MethodLength")
private boolean frameMessages(final long readTimestampInNs)
{
final MutableAsciiBuffer buffer = this.buffer;
Expand All @@ -407,6 +408,14 @@ private boolean frameMessages(final long readTimestampInNs)

try
{
// 8=FIX.4.4|9=60|35=5|49=initiator|56=acceptor|34=3|52=20231220-13:12:16.021|10=234|
// ^ ^ ^ ^ ^
// startOfBodyLength ---------+ | | | |
// endOfBodyLength -------------+ | | |
// startOfChecksumTag ----------------------------------------------------------------------+ | |
// endOfChecksumTag/startOfChecksumValue --------------------------------------------------------+ |
// endOfMessage -----------------------------------------------------------------------------------+

final int startOfBodyLength = scanForBodyLength(offset, readTimestampInNs);
if (startOfBodyLength < 0)
{
Expand All @@ -424,7 +433,10 @@ private boolean frameMessages(final long readTimestampInNs)
final int endOfChecksumTag = startOfChecksumTag + MIN_CHECKSUM_SIZE;
if (endOfChecksumTag >= usedBufferData)
{
disconnectOnOversizedMessage(offset, readTimestampInNs);
if (isMessageOversized(offset))
{
return saveOversizedMessageAndDisconnect(offset, readTimestampInNs);
}
break;
}

Expand All @@ -442,6 +454,10 @@ private boolean frameMessages(final long readTimestampInNs)
final int endOfMessage = scanEndOfMessage(startOfChecksumValue);
if (endOfMessage == UNKNOWN_INDEX)
{
if (isMessageOversized(offset))
{
return saveOversizedMessageAndDisconnect(offset, readTimestampInNs);
}
break; // Need more data
}

Expand Down Expand Up @@ -955,16 +971,27 @@ private boolean invalidateMessage(final int offset, final long readTimestamp)
return saveInvalidMessage(offset, readTimestamp);
}

private void disconnectOnOversizedMessage(final int offset, final long readTimestamp)
private boolean isMessageOversized(final int offset)
{
return offset == 0 && byteBuffer.remaining() == 0;
}

// returns false if back-pressured
private boolean saveOversizedMessageAndDisconnect(final int offset, final long readTimestamp)
{
if (offset == 0 && this.byteBuffer.remaining() == 0)
DebugLogger.log(FIX_MESSAGE, "Invalidated (oversized): ", buffer, offset, usedBufferData - offset);

if (saveInvalidMessage(offset, readTimestamp))
{
saveInvalidMessage(offset, readTimestamp);
errorHandler.onError(new Exception(String.format(
"Unable to frame message, receiver buffer too small. connectionId=%d",
connectionId)));
disconnectEndpoint(DisconnectReason.EXCEPTION);
return false;
}

errorHandler.onError(new Exception(
"Unable to frame message, receiver buffer too small. connectionId=" + connectionId));

completeDisconnect(DisconnectReason.EXCEPTION);

return true;
}

private boolean saveInvalidMessage(final int offset, final int length, final long readTimestamp)
Expand Down
36 changes: 36 additions & 0 deletions artio-core/src/test/java/uk/co/real_logic/artio/TestFixtures.java
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import io.aeron.driver.MediaDriver;
import org.agrona.IoUtil;
import org.agrona.concurrent.YieldingIdleStrategy;
import uk.co.real_logic.artio.util.MutableAsciiBuffer;

import java.io.File;
import java.io.IOException;
Expand All @@ -33,6 +34,7 @@
import java.util.Arrays;

import static io.aeron.driver.ThreadingMode.SHARED;
import static org.junit.jupiter.api.Assertions.assertEquals;

public final class TestFixtures
{
Expand Down Expand Up @@ -185,4 +187,38 @@ public static String largeTestReqId()

return new String(testReqIDChars);
}

public static byte[] largeMessage(final int messageLength)
{
// 8=FIX.4.4|9=00000|35=0|49=initiator|56=acceptor|34=2|52=20231220-13:12:16.020|112=...|10=...|

if (messageLength < 91)
{
throw new IllegalArgumentException(messageLength + " is not large enough");
}
final byte[] bytes = new byte[messageLength];
final MutableAsciiBuffer buffer = new MutableAsciiBuffer(bytes);
int index = buffer.putAscii(0, "8=FIX.4.4\0019=");
final int bodyLength = messageLength - (18 + 7); // header + trailer
if (bodyLength > 99_999)
{
throw new IllegalArgumentException(messageLength + " is too large");
}
buffer.putNaturalPaddedIntAscii(index, 5, bodyLength);
index += 5;
index += buffer.putAscii(index,
"\00135=0\00149=initiator\00156=acceptor\00134=2\00152=20231220-13:12:16.020\001112=");
while (index < messageLength - 8)
{
buffer.putCharAscii(index++, '+');
}
buffer.putSeparator(index++);
final int checksum = buffer.computeChecksum(0, index);
index += buffer.putAscii(index, "10=");
buffer.putNaturalPaddedIntAscii(index, 3, checksum);
index += 3;
buffer.putSeparator(index++);
assertEquals(messageLength, index);
return bytes;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
package uk.co.real_logic.artio;

import org.junit.jupiter.api.Test;
import uk.co.real_logic.artio.decoder.HeartbeatDecoder;
import uk.co.real_logic.artio.util.MutableAsciiBuffer;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;

class TestFixturesTest
{
@Test
void shouldGenerateMessageOfGivenLength()
{
final int messageLength = 256;
final byte[] bytes = TestFixtures.largeMessage(messageLength);
assertEquals(messageLength, bytes.length);

final HeartbeatDecoder decoder = new HeartbeatDecoder();
final MutableAsciiBuffer buffer = new MutableAsciiBuffer(bytes);
decoder.decode(buffer, 0, bytes.length);
assertTrue(decoder.validate());
assertEquals(buffer.computeChecksum(0, messageLength - 7),
Integer.parseInt(decoder.trailer().checkSumAsString(), 10));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,12 @@
import org.agrona.concurrent.status.AtomicCounter;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.mockito.InOrder;
import org.mockito.Mockito;
import org.mockito.verification.VerificationMode;
import uk.co.real_logic.artio.TestFixtures;
import uk.co.real_logic.artio.decoder.LogonDecoder;
import uk.co.real_logic.artio.dictionary.FixDictionary;
import uk.co.real_logic.artio.engine.FixEngine;
Expand All @@ -41,15 +44,14 @@
import java.nio.channels.ClosedChannelException;
import java.util.HashMap;
import java.util.function.ToIntFunction;
import java.util.stream.IntStream;

import static io.aeron.Publication.BACK_PRESSURED;
import static org.junit.jupiter.api.Assertions.*;
import static org.mockito.Mockito.*;
import static uk.co.real_logic.artio.dictionary.ExampleDictionary.TAG_SPECIFIED_OUT_OF_REQUIRED_ORDER_MESSAGE_BYTES;
import static uk.co.real_logic.artio.engine.EngineConfiguration.NO_THROTTLE_WINDOW;
import static uk.co.real_logic.artio.messages.DisconnectReason.DUPLICATE_SESSION;
import static uk.co.real_logic.artio.messages.DisconnectReason.REMOTE_DISCONNECT;
import static uk.co.real_logic.artio.messages.DisconnectReason.EXCEPTION;
import static uk.co.real_logic.artio.messages.DisconnectReason.*;
import static uk.co.real_logic.artio.messages.MessageStatus.*;
import static uk.co.real_logic.artio.session.Session.UNKNOWN;
import static uk.co.real_logic.artio.util.TestMessages.*;
Expand Down Expand Up @@ -206,10 +208,16 @@ void shouldFrameValidFixMessage()
sessionReceivesOneMessage();
}

@Test
void shouldDetectOversizedFixMessage()
static IntStream overflowRange()
{
return IntStream.range(1, 10);
}

@ParameterizedTest
@MethodSource("overflowRange")
void shouldDetectOversizedFixMessage(final int overflow)
{
theEndpointReceivesTheStartOfAnOversizedMessage();
theEndpointReceivesTheStartOfAnOversizedMessage(overflow);

polls(BUFFER_SIZE);

Expand All @@ -220,6 +228,23 @@ void shouldDetectOversizedFixMessage()
sessionReceivesNoMessages();
}

@ParameterizedTest
@MethodSource("overflowRange")
void shouldHandleBackPressureWhenSavingOversizedMessage(final int overflow)
{
theEndpointReceivesTheStartOfAnOversizedMessage(overflow);

firstSaveAttemptIsBackPressured();
polls(-BUFFER_SIZE);
assertTrue(endPoint.retryFrameMessages());

savesInvalidMessage(BUFFER_SIZE, times(2), INVALID, TIMESTAMP);
verifyError(times(1));
verifyDisconnected(EXCEPTION);

sessionReceivesNoMessages();
}

@Test
void shouldFrameValidFixMessageWhenBackpressuredSelectionKeyCase()
{
Expand Down Expand Up @@ -637,16 +662,16 @@ private void theEndpointReceivesACompleteMessage()
theEndpointReceives(EG_MESSAGE, 0, MSG_LEN);
}

private void theEndpointReceivesTheStartOfAnOversizedMessage()
private void theEndpointReceivesTheStartOfAnOversizedMessage(final int overflow)
{
endpointBufferUpdatedWith(
(buffer) ->
{
final int paddingLength = buffer.capacity() - OVERSIZED_MESSAGE_START.length;
assertTrue(paddingLength > 0);
buffer.put(OVERSIZED_MESSAGE_START, 0, OVERSIZED_MESSAGE_START.length);
buffer.put(new byte[paddingLength], 0, paddingLength);
return buffer.capacity();
final int capacity = buffer.capacity();
final int messageLength = capacity + overflow;
final byte[] bytes = TestFixtures.largeMessage(messageLength);
buffer.put(bytes, 0, capacity);
return capacity;
});
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,20 @@ public void sendBytes(final byte[] bytes)
send(0, length);
}

public void sendBytesLarge(final byte[] bytes)
{
int offset = 0;
int remaining = bytes.length;
while (remaining > 0)
{
final int length = Math.min(remaining, BUFFER_SIZE);
writeAsciiBuffer.putBytes(0, bytes, offset, length);
send(0, length);
offset += length;
remaining -= length;
}
}

public void logon(final boolean resetSeqNumFlag)
{
logon(resetSeqNumFlag, 30);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,7 @@
import org.hamcrest.Matchers;
import org.junit.Test;
import org.mockito.ArgumentMatchers;
import uk.co.real_logic.artio.Constants;
import uk.co.real_logic.artio.Reply;
import uk.co.real_logic.artio.Side;
import uk.co.real_logic.artio.Timing;
import uk.co.real_logic.artio.*;
import uk.co.real_logic.artio.builder.*;
import uk.co.real_logic.artio.decoder.*;
import uk.co.real_logic.artio.engine.SessionInfo;
Expand Down Expand Up @@ -53,6 +50,7 @@
import static uk.co.real_logic.artio.SessionRejectReason.COMPID_PROBLEM;
import static uk.co.real_logic.artio.TestFixtures.cleanupMediaDriver;
import static uk.co.real_logic.artio.dictionary.SessionConstants.*;
import static uk.co.real_logic.artio.engine.EngineConfiguration.DEFAULT_RECEIVER_BUFFER_SIZE;
import static uk.co.real_logic.artio.engine.logger.Replayer.MOST_RECENT_MESSAGE;
import static uk.co.real_logic.artio.messages.InitialAcceptedSessionOwner.ENGINE;
import static uk.co.real_logic.artio.messages.InitialAcceptedSessionOwner.SOLE_LIBRARY;
Expand Down Expand Up @@ -905,6 +903,27 @@ public void shouldSupportResendRequestsAfterOfflineSequenceReset() throws Except
}
}

@Test(timeout = TEST_TIMEOUT_IN_MS)
public void shouldDisconnectConnectionTryingToSendOversizedMessage() throws IOException
{
setup(true, true);

setupLibrary();

try (FixConnection connection = FixConnection.initiate(port))
{
logon(connection);
final Session session = acquireSession();

connection.sendBytesLarge(TestFixtures.largeMessage(DEFAULT_RECEIVER_BUFFER_SIZE + 5));

assertSessionDisconnected(testSystem, session);
assertEquals(1, session.lastReceivedMsgSeqNum());

assertConnectionDisconnects(testSystem, connection);
}
}

private void assertSell(final ExecutionReportDecoder executionReport)
{
assertEquals(executionReport.toString(), Side.SELL, executionReport.sideAsEnum());
Expand Down

0 comments on commit d370dac

Please sign in to comment.