diff --git a/artio-codecs/src/test/java/uk/co/real_logic/artio/util/TestMessages.java b/artio-codecs/src/test/java/uk/co/real_logic/artio/util/TestMessages.java index 7bddab5fba..7e3d8777e9 100644 --- a/artio-codecs/src/test/java/uk/co/real_logic/artio/util/TestMessages.java +++ b/artio-codecs/src/test/java/uk/co/real_logic/artio/util/TestMessages.java @@ -82,10 +82,6 @@ public final class TestMessages public static final byte[] INVALID_LENGTH_MESSAGE = toAscii( "8=FIX.4.4\0019=5\00135=A\00134=1\00149=TW\00152=20150604-12:46:54\00156=ISLD\00198=0\00110=000\001"); - public static final byte[] OVERSIZED_MESSAGE_START = - toAscii("8=FIX.4.2\0019=99999\00135=D\00134=4\00149=ABC_DEFG01\001" + - "52=20090323-15:40:29\00156=CCG\001115=XYZ\00111=NF 0542/03232009\00154=1\00138=100\00155=CVS\00140=1"); - public static final byte[] ZERO_CHECKSUM_MESSAGE = toAscii( "8=FIX.4.4\0019=0067\00135=0\00149=acceptor\00156=initiator\00134=2" + "\00152=20160415-12:50:23.294\001112=hi\00110=000\001"); diff --git a/artio-core/src/main/java/uk/co/real_logic/artio/engine/framer/FixReceiverEndPoint.java b/artio-core/src/main/java/uk/co/real_logic/artio/engine/framer/FixReceiverEndPoint.java index 0adfc26bd6..9742caf49c 100644 --- a/artio-core/src/main/java/uk/co/real_logic/artio/engine/framer/FixReceiverEndPoint.java +++ b/artio-core/src/main/java/uk/co/real_logic/artio/engine/framer/FixReceiverEndPoint.java @@ -392,6 +392,7 @@ boolean retryFrameMessages() // true - no more framed messages in the buffer data to process. This could mean no more messages, or some data // that is an incomplete message. // false - needs to be retried, aka back-pressured + @SuppressWarnings("MethodLength") private boolean frameMessages(final long readTimestampInNs) { final MutableAsciiBuffer buffer = this.buffer; @@ -407,6 +408,14 @@ private boolean frameMessages(final long readTimestampInNs) try { + // 8=FIX.4.4|9=60|35=5|49=initiator|56=acceptor|34=3|52=20231220-13:12:16.021|10=234| + // ^ ^ ^ ^ ^ + // startOfBodyLength ---------+ | | | | + // endOfBodyLength -------------+ | | | + // startOfChecksumTag ----------------------------------------------------------------------+ | | + // endOfChecksumTag/startOfChecksumValue --------------------------------------------------------+ | + // endOfMessage -----------------------------------------------------------------------------------+ + final int startOfBodyLength = scanForBodyLength(offset, readTimestampInNs); if (startOfBodyLength < 0) { @@ -424,7 +433,10 @@ private boolean frameMessages(final long readTimestampInNs) final int endOfChecksumTag = startOfChecksumTag + MIN_CHECKSUM_SIZE; if (endOfChecksumTag >= usedBufferData) { - disconnectOnOversizedMessage(offset, readTimestampInNs); + if (isMessageOversized(offset)) + { + return saveOversizedMessageAndDisconnect(offset, readTimestampInNs); + } break; } @@ -442,6 +454,10 @@ private boolean frameMessages(final long readTimestampInNs) final int endOfMessage = scanEndOfMessage(startOfChecksumValue); if (endOfMessage == UNKNOWN_INDEX) { + if (isMessageOversized(offset)) + { + return saveOversizedMessageAndDisconnect(offset, readTimestampInNs); + } break; // Need more data } @@ -955,16 +971,27 @@ private boolean invalidateMessage(final int offset, final long readTimestamp) return saveInvalidMessage(offset, readTimestamp); } - private void disconnectOnOversizedMessage(final int offset, final long readTimestamp) + private boolean isMessageOversized(final int offset) + { + return offset == 0 && byteBuffer.remaining() == 0; + } + + // returns false if back-pressured + private boolean saveOversizedMessageAndDisconnect(final int offset, final long readTimestamp) { - if (offset == 0 && this.byteBuffer.remaining() == 0) + DebugLogger.log(FIX_MESSAGE, "Invalidated (oversized): ", buffer, offset, usedBufferData - offset); + + if (saveInvalidMessage(offset, readTimestamp)) { - saveInvalidMessage(offset, readTimestamp); - errorHandler.onError(new Exception(String.format( - "Unable to frame message, receiver buffer too small. connectionId=%d", - connectionId))); - disconnectEndpoint(DisconnectReason.EXCEPTION); + return false; } + + errorHandler.onError(new Exception( + "Unable to frame message, receiver buffer too small. connectionId=" + connectionId)); + + completeDisconnect(DisconnectReason.EXCEPTION); + + return true; } private boolean saveInvalidMessage(final int offset, final int length, final long readTimestamp) diff --git a/artio-core/src/test/java/uk/co/real_logic/artio/TestFixtures.java b/artio-core/src/test/java/uk/co/real_logic/artio/TestFixtures.java index 629c201e11..502488e645 100644 --- a/artio-core/src/test/java/uk/co/real_logic/artio/TestFixtures.java +++ b/artio-core/src/test/java/uk/co/real_logic/artio/TestFixtures.java @@ -22,6 +22,7 @@ import io.aeron.driver.MediaDriver; import org.agrona.IoUtil; import org.agrona.concurrent.YieldingIdleStrategy; +import uk.co.real_logic.artio.util.MutableAsciiBuffer; import java.io.File; import java.io.IOException; @@ -33,6 +34,7 @@ import java.util.Arrays; import static io.aeron.driver.ThreadingMode.SHARED; +import static org.junit.jupiter.api.Assertions.assertEquals; public final class TestFixtures { @@ -185,4 +187,38 @@ public static String largeTestReqId() return new String(testReqIDChars); } + + public static byte[] largeMessage(final int messageLength) + { + // 8=FIX.4.4|9=00000|35=0|49=initiator|56=acceptor|34=2|52=20231220-13:12:16.020|112=...|10=...| + + if (messageLength < 91) + { + throw new IllegalArgumentException(messageLength + " is not large enough"); + } + final byte[] bytes = new byte[messageLength]; + final MutableAsciiBuffer buffer = new MutableAsciiBuffer(bytes); + int index = buffer.putAscii(0, "8=FIX.4.4\0019="); + final int bodyLength = messageLength - (18 + 7); // header + trailer + if (bodyLength > 99_999) + { + throw new IllegalArgumentException(messageLength + " is too large"); + } + buffer.putNaturalPaddedIntAscii(index, 5, bodyLength); + index += 5; + index += buffer.putAscii(index, + "\00135=0\00149=initiator\00156=acceptor\00134=2\00152=20231220-13:12:16.020\001112="); + while (index < messageLength - 8) + { + buffer.putCharAscii(index++, '+'); + } + buffer.putSeparator(index++); + final int checksum = buffer.computeChecksum(0, index); + index += buffer.putAscii(index, "10="); + buffer.putNaturalPaddedIntAscii(index, 3, checksum); + index += 3; + buffer.putSeparator(index++); + assertEquals(messageLength, index); + return bytes; + } } diff --git a/artio-core/src/test/java/uk/co/real_logic/artio/TestFixturesTest.java b/artio-core/src/test/java/uk/co/real_logic/artio/TestFixturesTest.java new file mode 100644 index 0000000000..4a28ace24c --- /dev/null +++ b/artio-core/src/test/java/uk/co/real_logic/artio/TestFixturesTest.java @@ -0,0 +1,26 @@ +package uk.co.real_logic.artio; + +import org.junit.jupiter.api.Test; +import uk.co.real_logic.artio.decoder.HeartbeatDecoder; +import uk.co.real_logic.artio.util.MutableAsciiBuffer; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; + +class TestFixturesTest +{ + @Test + void shouldGenerateMessageOfGivenLength() + { + final int messageLength = 256; + final byte[] bytes = TestFixtures.largeMessage(messageLength); + assertEquals(messageLength, bytes.length); + + final HeartbeatDecoder decoder = new HeartbeatDecoder(); + final MutableAsciiBuffer buffer = new MutableAsciiBuffer(bytes); + decoder.decode(buffer, 0, bytes.length); + assertTrue(decoder.validate()); + assertEquals(buffer.computeChecksum(0, messageLength - 7), + Integer.parseInt(decoder.trailer().checkSumAsString(), 10)); + } +} diff --git a/artio-core/src/test/java/uk/co/real_logic/artio/engine/framer/ReceiverEndPointTest.java b/artio-core/src/test/java/uk/co/real_logic/artio/engine/framer/ReceiverEndPointTest.java index 84fd9d3932..344a80d0be 100644 --- a/artio-core/src/test/java/uk/co/real_logic/artio/engine/framer/ReceiverEndPointTest.java +++ b/artio-core/src/test/java/uk/co/real_logic/artio/engine/framer/ReceiverEndPointTest.java @@ -22,9 +22,12 @@ import org.agrona.concurrent.status.AtomicCounter; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; import org.mockito.InOrder; import org.mockito.Mockito; import org.mockito.verification.VerificationMode; +import uk.co.real_logic.artio.TestFixtures; import uk.co.real_logic.artio.decoder.LogonDecoder; import uk.co.real_logic.artio.dictionary.FixDictionary; import uk.co.real_logic.artio.engine.FixEngine; @@ -41,15 +44,14 @@ import java.nio.channels.ClosedChannelException; import java.util.HashMap; import java.util.function.ToIntFunction; +import java.util.stream.IntStream; import static io.aeron.Publication.BACK_PRESSURED; import static org.junit.jupiter.api.Assertions.*; import static org.mockito.Mockito.*; import static uk.co.real_logic.artio.dictionary.ExampleDictionary.TAG_SPECIFIED_OUT_OF_REQUIRED_ORDER_MESSAGE_BYTES; import static uk.co.real_logic.artio.engine.EngineConfiguration.NO_THROTTLE_WINDOW; -import static uk.co.real_logic.artio.messages.DisconnectReason.DUPLICATE_SESSION; -import static uk.co.real_logic.artio.messages.DisconnectReason.REMOTE_DISCONNECT; -import static uk.co.real_logic.artio.messages.DisconnectReason.EXCEPTION; +import static uk.co.real_logic.artio.messages.DisconnectReason.*; import static uk.co.real_logic.artio.messages.MessageStatus.*; import static uk.co.real_logic.artio.session.Session.UNKNOWN; import static uk.co.real_logic.artio.util.TestMessages.*; @@ -206,10 +208,16 @@ void shouldFrameValidFixMessage() sessionReceivesOneMessage(); } - @Test - void shouldDetectOversizedFixMessage() + static IntStream overflowRange() + { + return IntStream.range(1, 10); + } + + @ParameterizedTest + @MethodSource("overflowRange") + void shouldDetectOversizedFixMessage(final int overflow) { - theEndpointReceivesTheStartOfAnOversizedMessage(); + theEndpointReceivesTheStartOfAnOversizedMessage(overflow); polls(BUFFER_SIZE); @@ -220,6 +228,23 @@ void shouldDetectOversizedFixMessage() sessionReceivesNoMessages(); } + @ParameterizedTest + @MethodSource("overflowRange") + void shouldHandleBackPressureWhenSavingOversizedMessage(final int overflow) + { + theEndpointReceivesTheStartOfAnOversizedMessage(overflow); + + firstSaveAttemptIsBackPressured(); + polls(-BUFFER_SIZE); + assertTrue(endPoint.retryFrameMessages()); + + savesInvalidMessage(BUFFER_SIZE, times(2), INVALID, TIMESTAMP); + verifyError(times(1)); + verifyDisconnected(EXCEPTION); + + sessionReceivesNoMessages(); + } + @Test void shouldFrameValidFixMessageWhenBackpressuredSelectionKeyCase() { @@ -637,16 +662,16 @@ private void theEndpointReceivesACompleteMessage() theEndpointReceives(EG_MESSAGE, 0, MSG_LEN); } - private void theEndpointReceivesTheStartOfAnOversizedMessage() + private void theEndpointReceivesTheStartOfAnOversizedMessage(final int overflow) { endpointBufferUpdatedWith( (buffer) -> { - final int paddingLength = buffer.capacity() - OVERSIZED_MESSAGE_START.length; - assertTrue(paddingLength > 0); - buffer.put(OVERSIZED_MESSAGE_START, 0, OVERSIZED_MESSAGE_START.length); - buffer.put(new byte[paddingLength], 0, paddingLength); - return buffer.capacity(); + final int capacity = buffer.capacity(); + final int messageLength = capacity + overflow; + final byte[] bytes = TestFixtures.largeMessage(messageLength); + buffer.put(bytes, 0, capacity); + return capacity; }); } diff --git a/artio-system-tests/src/test/java/uk/co/real_logic/artio/system_tests/FixConnection.java b/artio-system-tests/src/test/java/uk/co/real_logic/artio/system_tests/FixConnection.java index 2d34aa7d9e..8ebdfa51af 100644 --- a/artio-system-tests/src/test/java/uk/co/real_logic/artio/system_tests/FixConnection.java +++ b/artio-system-tests/src/test/java/uk/co/real_logic/artio/system_tests/FixConnection.java @@ -239,6 +239,20 @@ public void sendBytes(final byte[] bytes) send(0, length); } + public void sendBytesLarge(final byte[] bytes) + { + int offset = 0; + int remaining = bytes.length; + while (remaining > 0) + { + final int length = Math.min(remaining, BUFFER_SIZE); + writeAsciiBuffer.putBytes(0, bytes, offset, length); + send(0, length); + offset += length; + remaining -= length; + } + } + public void logon(final boolean resetSeqNumFlag) { logon(resetSeqNumFlag, 30); diff --git a/artio-system-tests/src/test/java/uk/co/real_logic/artio/system_tests/MessageBasedAcceptorSystemTest.java b/artio-system-tests/src/test/java/uk/co/real_logic/artio/system_tests/MessageBasedAcceptorSystemTest.java index e5c3ebefcb..688492be63 100644 --- a/artio-system-tests/src/test/java/uk/co/real_logic/artio/system_tests/MessageBasedAcceptorSystemTest.java +++ b/artio-system-tests/src/test/java/uk/co/real_logic/artio/system_tests/MessageBasedAcceptorSystemTest.java @@ -19,10 +19,7 @@ import org.hamcrest.Matchers; import org.junit.Test; import org.mockito.ArgumentMatchers; -import uk.co.real_logic.artio.Constants; -import uk.co.real_logic.artio.Reply; -import uk.co.real_logic.artio.Side; -import uk.co.real_logic.artio.Timing; +import uk.co.real_logic.artio.*; import uk.co.real_logic.artio.builder.*; import uk.co.real_logic.artio.decoder.*; import uk.co.real_logic.artio.engine.SessionInfo; @@ -53,6 +50,7 @@ import static uk.co.real_logic.artio.SessionRejectReason.COMPID_PROBLEM; import static uk.co.real_logic.artio.TestFixtures.cleanupMediaDriver; import static uk.co.real_logic.artio.dictionary.SessionConstants.*; +import static uk.co.real_logic.artio.engine.EngineConfiguration.DEFAULT_RECEIVER_BUFFER_SIZE; import static uk.co.real_logic.artio.engine.logger.Replayer.MOST_RECENT_MESSAGE; import static uk.co.real_logic.artio.messages.InitialAcceptedSessionOwner.ENGINE; import static uk.co.real_logic.artio.messages.InitialAcceptedSessionOwner.SOLE_LIBRARY; @@ -905,6 +903,27 @@ public void shouldSupportResendRequestsAfterOfflineSequenceReset() throws Except } } + @Test(timeout = TEST_TIMEOUT_IN_MS) + public void shouldDisconnectConnectionTryingToSendOversizedMessage() throws IOException + { + setup(true, true); + + setupLibrary(); + + try (FixConnection connection = FixConnection.initiate(port)) + { + logon(connection); + final Session session = acquireSession(); + + connection.sendBytesLarge(TestFixtures.largeMessage(DEFAULT_RECEIVER_BUFFER_SIZE + 5)); + + assertSessionDisconnected(testSystem, session); + assertEquals(1, session.lastReceivedMsgSeqNum()); + + assertConnectionDisconnects(testSystem, connection); + } + } + private void assertSell(final ExecutionReportDecoder executionReport) { assertEquals(executionReport.toString(), Side.SELL, executionReport.sideAsEnum());