From 65b19356aefdc1a9dae8f07b13261b8094b32455 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?I=C3=B1aki=20Baz=20Castillo?= <ibc@aliax.net>
Date: Mon, 22 Apr 2024 17:40:55 +0200
Subject: [PATCH] Huge changes: make all methods in RtpPacket remove payload
 padding if present

---
 worker/src/RTC/RtpPacket.cpp          | 65 ++++++++++++++++++++-------
 worker/test/src/RTC/TestRtpPacket.cpp | 49 +++++++++++---------
 2 files changed, 76 insertions(+), 38 deletions(-)

diff --git a/worker/src/RTC/RtpPacket.cpp b/worker/src/RTC/RtpPacket.cpp
index 379bd323d0..bd036418b1 100644
--- a/worker/src/RTC/RtpPacket.cpp
+++ b/worker/src/RTC/RtpPacket.cpp
@@ -96,6 +96,7 @@ namespace RTC
 			}
 
 			payloadPadding = data[len - 1];
+
 			if (payloadPadding == 0)
 			{
 				MS_WARN_TAG(rtp, "padding byte cannot be 0, packet discarded");
@@ -112,6 +113,7 @@ namespace RTC
 
 				return nullptr;
 			}
+
 			payloadLength -= size_t{ payloadPadding };
 		}
 
@@ -371,6 +373,9 @@ namespace RTC
 		                        : flatbuffers::nullopt);
 	}
 
+	/**
+	 * NOTE: This method automatically removes payload padding if present.
+	 */
 	void RtpPacket::SetExtensions(uint8_t type, const std::vector<GenericExtension>& extensions)
 	{
 		MS_ASSERT(type == 1u || type == 2u, "type must be 1 or 2");
@@ -462,7 +467,7 @@ namespace RTC
 		if (this->headerExtension && shift != 0)
 		{
 			// Shift the payload.
-			std::memmove(this->payload + shift, this->payload, this->payloadLength + this->payloadPadding);
+			std::memmove(this->payload + shift, this->payload, this->payloadLength);
 			this->payload += shift;
 
 			// Update packet total size.
@@ -480,7 +485,7 @@ namespace RTC
 			this->headerExtension = reinterpret_cast<HeaderExtension*>(this->payload);
 
 			// Shift the payload.
-			std::memmove(this->payload + shift, this->payload, this->payloadLength + this->payloadPadding);
+			std::memmove(this->payload + shift, this->payload, this->payloadLength);
 			this->payload += shift;
 
 			// Update packet total size.
@@ -547,6 +552,15 @@ namespace RTC
 		}
 
 		MS_ASSERT(ptr == this->payload, "wrong ptr calculation");
+
+		// Remove padding if present.
+		if (this->payloadPadding != 0u)
+		{
+			SetPayloadPaddingFlag(false);
+
+			this->size -= size_t{ this->payloadPadding };
+			this->payloadPadding = 0u;
+		}
 	}
 
 	void RtpPacket::UpdateMid(const std::string& mid)
@@ -651,22 +665,24 @@ namespace RTC
 	}
 
 	/**
-	 * NOTE: This method automatically adjusts padding for padding+payload to be
-	 * padded to 4 bytes.
+	 * NOTE: This method automatically removes payload padding if present.
 	 */
 	void RtpPacket::SetPayloadLength(size_t length)
 	{
 		MS_TRACE();
 
-		auto payloadPadding = Utils::Byte::PadTo4Bytes(static_cast<uint16_t>(length)) - length;
-
 		this->size -= this->payloadLength;
-		this->size -= size_t{ this->payloadPadding };
-		this->payloadLength  = length;
-		this->payloadPadding = payloadPadding;
-		this->size += this->payloadLength + this->payloadPadding;
+		this->payloadLength = length;
+		this->size += this->payloadLength;
+
+		// Remove padding if present.
+		if (this->payloadPadding != 0u)
+		{
+			SetPayloadPaddingFlag(false);
 
-		SetPayloadPaddingFlag(this->payloadPadding > 0);
+			this->size -= size_t{ this->payloadPadding };
+			this->payloadPadding = 0u;
+		}
 	}
 
 	RtpPacket* RtpPacket::Clone() const
@@ -752,8 +768,12 @@ namespace RTC
 		return packet;
 	}
 
-	// NOTE: The caller must ensure that the buffer/memmory of the packet has
-	// space enough for adding 2 extra bytes.
+	/**
+	 * NOTE: The caller must ensure that the buffer/memmory of the packet has
+	 * space enough for adding 2 extra bytes.
+	 *
+	 * NOTE: This method automatically removes payload padding if present.
+	 */
 	void RtpPacket::RtxEncode(uint8_t payloadType, uint32_t ssrc, uint16_t seq)
 	{
 		MS_TRACE();
@@ -787,6 +807,9 @@ namespace RTC
 		}
 	}
 
+	/**
+	 * NOTE: This method automatically removes payload padding if present.
+	 */
 	bool RtpPacket::RtxDecode(uint8_t payloadType, uint32_t ssrc)
 	{
 		MS_TRACE();
@@ -855,8 +878,7 @@ namespace RTC
 	/**
 	 * Shifts the payload given offset (to right or to left).
 	 *
-	 * NOTE: This method doesn't automatically adjust padding for padding+payload
-	 * to be padded to 4 bytes.
+	 * NOTE: This method automatically removes payload padding if present.
 	 */
 	void RtpPacket::ShiftPayload(size_t payloadOffset, size_t shift, bool expand)
 	{
@@ -879,7 +901,7 @@ namespace RTC
 
 		if (expand)
 		{
-			shiftedLen = this->payloadLength + size_t{ this->payloadPadding } - payloadOffset;
+			shiftedLen = this->payloadLength - payloadOffset;
 
 			std::memmove(payloadOffsetPtr + shift, payloadOffsetPtr, shiftedLen);
 
@@ -888,13 +910,22 @@ namespace RTC
 		}
 		else
 		{
-			shiftedLen = this->payloadLength + size_t{ this->payloadPadding } - payloadOffset - shift;
+			shiftedLen = this->payloadLength - payloadOffset - shift;
 
 			std::memmove(payloadOffsetPtr, payloadOffsetPtr + shift, shiftedLen);
 
 			this->payloadLength -= shift;
 			this->size -= shift;
 		}
+
+		// Remove padding if present.
+		if (this->payloadPadding != 0u)
+		{
+			SetPayloadPaddingFlag(false);
+
+			this->size -= size_t{ this->payloadPadding };
+			this->payloadPadding = 0u;
+		}
 	}
 
 	void RtpPacket::ParseExtensions()
diff --git a/worker/test/src/RTC/TestRtpPacket.cpp b/worker/test/src/RTC/TestRtpPacket.cpp
index d5ec7279cf..dc8d54737d 100644
--- a/worker/test/src/RTC/TestRtpPacket.cpp
+++ b/worker/test/src/RTC/TestRtpPacket.cpp
@@ -435,6 +435,7 @@ SCENARIO("parse RTP packets", "[parser][rtp]")
 		REQUIRE(packet->HasTwoBytesExtensions() == false);
 		REQUIRE(packet->GetPayloadLength() == 8);
 		REQUIRE(packet->GetPayloadPadding() == 4);
+		REQUIRE(packet->GetPayload()[packet->GetPayloadLength() + packet->GetPayloadPadding() - 1] == 4);
 		REQUIRE(packet->GetSize() == 40);
 
 		auto* payload = packet->GetPayload();
@@ -448,11 +449,12 @@ SCENARIO("parse RTP packets", "[parser][rtp]")
 		REQUIRE(payload[6] == 0x06);
 		REQUIRE(payload[7] == 0x07);
 
+		// NOTE: This will remove padding.
 		packet->ShiftPayload(0, 2, true);
 
 		REQUIRE(packet->GetPayloadLength() == 10);
-		REQUIRE(packet->GetPayloadPadding() == 4);
-		REQUIRE(packet->GetSize() == 42);
+		REQUIRE(packet->GetPayloadPadding() == 0);
+		REQUIRE(packet->GetSize() == 38);
 		REQUIRE(payload[2] == 0x00);
 		REQUIRE(payload[3] == 0x01);
 		REQUIRE(payload[4] == 0x02);
@@ -465,8 +467,8 @@ SCENARIO("parse RTP packets", "[parser][rtp]")
 		packet->ShiftPayload(0, 2, false);
 
 		REQUIRE(packet->GetPayloadLength() == 8);
-		REQUIRE(packet->GetPayloadPadding() == 4);
-		REQUIRE(packet->GetSize() == 40);
+		REQUIRE(packet->GetPayloadPadding() == 0);
+		REQUIRE(packet->GetSize() == 36);
 		REQUIRE(payload[0] == 0x00);
 		REQUIRE(payload[1] == 0x01);
 		REQUIRE(payload[2] == 0x02);
@@ -476,18 +478,18 @@ SCENARIO("parse RTP packets", "[parser][rtp]")
 		REQUIRE(payload[6] == 0x06);
 		REQUIRE(payload[7] == 0x07);
 
-		// NOTE: This will require padding to 2 bytes.
+		// NOTE: This will remove padding.
 		packet->SetPayloadLength(14);
 
 		REQUIRE(packet->GetPayloadLength() == 14);
-		REQUIRE(packet->GetPayloadPadding() == 2);
-		REQUIRE(packet->GetSize() == 44);
+		REQUIRE(packet->GetPayloadPadding() == 0);
+		REQUIRE(packet->GetSize() == 42);
 
 		packet->ShiftPayload(4, 4, true);
 
 		REQUIRE(packet->GetPayloadLength() == 18);
-		REQUIRE(packet->GetPayloadPadding() == 2);
-		REQUIRE(packet->GetSize() == 48);
+		REQUIRE(packet->GetPayloadPadding() == 0);
+		REQUIRE(packet->GetSize() == 46);
 		REQUIRE(payload[0] == 0x00);
 		REQUIRE(payload[1] == 0x01);
 		REQUIRE(payload[2] == 0x02);
@@ -545,21 +547,23 @@ SCENARIO("parse RTP packets", "[parser][rtp]")
 		REQUIRE(packet->HasTwoBytesExtensions() == false);
 		REQUIRE(packet->GetPayloadLength() == 12);
 		REQUIRE(packet->GetPayloadPadding() == 4);
+		REQUIRE(packet->GetPayload()[packet->GetPayloadLength() + packet->GetPayloadPadding() - 1] == 4);
 		REQUIRE(packet->GetPayload()[0] == 0x11);
 		REQUIRE(packet->GetPayload()[packet->GetPayloadLength() - 1] == 0xCC);
 
 		extensions.clear();
 
+		// NOTE: This will remove padding.
 		packet->SetExtensions(1, extensions);
 
-		REQUIRE(packet->GetSize() == 32);
+		REQUIRE(packet->GetSize() == 28);
 		REQUIRE(packet->HasHeaderExtension() == true);
 		REQUIRE(packet->GetHeaderExtensionId() == 0xBEDE);
 		REQUIRE(packet->GetHeaderExtensionLength() == 0);
 		REQUIRE(packet->HasOneByteExtensions() == true);
 		REQUIRE(packet->HasTwoBytesExtensions() == false);
 		REQUIRE(packet->GetPayloadLength() == 12);
-		REQUIRE(packet->GetPayloadPadding() == 4);
+		REQUIRE(packet->GetPayloadPadding() == 0);
 		REQUIRE(packet->GetPayload()[0] == 0x11);
 		REQUIRE(packet->GetPayload()[packet->GetPayloadLength() - 1] == 0xCC);
 
@@ -604,14 +608,14 @@ SCENARIO("parse RTP packets", "[parser][rtp]")
 
 		packet->SetExtensions(1, extensions);
 
-		REQUIRE(packet->GetSize() == 52); // 49 + 3 bytes for padding in header extension.
+		REQUIRE(packet->GetSize() == 48); // 49 + 3 bytes for padding in header extension.
 		REQUIRE(packet->HasHeaderExtension() == true);
 		REQUIRE(packet->GetHeaderExtensionId() == 0xBEDE);
 		REQUIRE(packet->GetHeaderExtensionLength() == 20); // 17 + 3 bytes for padding.
 		REQUIRE(packet->HasOneByteExtensions() == true);
 		REQUIRE(packet->HasTwoBytesExtensions() == false);
 		REQUIRE(packet->GetPayloadLength() == 12);
-		REQUIRE(packet->GetPayloadPadding() == 4);
+		REQUIRE(packet->GetPayloadPadding() == 0);
 		REQUIRE(packet->GetPayload()[0] == 0x11);
 		REQUIRE(packet->GetPayload()[packet->GetPayloadLength() - 1] == 0xCC);
 		REQUIRE(packet->GetExtension(0, extenLen) == nullptr);
@@ -639,14 +643,14 @@ SCENARIO("parse RTP packets", "[parser][rtp]")
 
 		packet->SetExtensions(1, extensions);
 
-		REQUIRE(packet->GetSize() == 40);
+		REQUIRE(packet->GetSize() == 36);
 		REQUIRE(packet->HasHeaderExtension() == true);
 		REQUIRE(packet->GetHeaderExtensionId() == 0xBEDE);
 		REQUIRE(packet->GetHeaderExtensionLength() == 8); // 5 + 3 bytes for padding.
 		REQUIRE(packet->HasOneByteExtensions() == true);
 		REQUIRE(packet->HasTwoBytesExtensions() == false);
 		REQUIRE(packet->GetPayloadLength() == 12);
-		REQUIRE(packet->GetPayloadPadding() == 4);
+		REQUIRE(packet->GetPayloadPadding() == 0);
 		REQUIRE(packet->GetPayload()[0] == 0x11);
 		REQUIRE(packet->GetPayload()[packet->GetPayloadLength() - 1] == 0xCC);
 		REQUIRE(packet->GetExtension(1, extenLen) == nullptr);
@@ -710,21 +714,23 @@ SCENARIO("parse RTP packets", "[parser][rtp]")
 		REQUIRE(packet->HasTwoBytesExtensions() == false);
 		REQUIRE(packet->GetPayloadLength() == 12);
 		REQUIRE(packet->GetPayloadPadding() == 4);
+		REQUIRE(packet->GetPayload()[packet->GetPayloadLength() + packet->GetPayloadPadding() - 1] == 4);
 		REQUIRE(packet->GetPayload()[0] == 0x11);
 		REQUIRE(packet->GetPayload()[packet->GetPayloadLength() - 1] == 0xCC);
 
 		extensions.clear();
 
+		// NOTE: This will remove padding.
 		packet->SetExtensions(2, extensions);
 
-		REQUIRE(packet->GetSize() == 32);
+		REQUIRE(packet->GetSize() == 28);
 		REQUIRE(packet->HasHeaderExtension() == true);
 		REQUIRE(packet->GetHeaderExtensionId() == 0b0001000000000000);
 		REQUIRE(packet->GetHeaderExtensionLength() == 0);
 		REQUIRE(packet->HasOneByteExtensions() == false);
 		REQUIRE(packet->HasTwoBytesExtensions() == true);
 		REQUIRE(packet->GetPayloadLength() == 12);
-		REQUIRE(packet->GetPayloadPadding() == 4);
+		REQUIRE(packet->GetPayloadPadding() == 0);
 		REQUIRE(packet->GetPayload()[0] == 0x11);
 		REQUIRE(packet->GetPayload()[packet->GetPayloadLength() - 1] == 0xCC);
 
@@ -753,16 +759,17 @@ SCENARIO("parse RTP packets", "[parser][rtp]")
 		  value2 // value
 		);
 
+		// NOTE: This will remove padding.
 		packet->SetExtensions(2, extensions);
 
-		REQUIRE(packet->GetSize() == 52); // 51 + 1 byte for padding in header extension.
+		REQUIRE(packet->GetSize() == 48); // 51 + 1 byte for padding in header extension.
 		REQUIRE(packet->HasHeaderExtension() == true);
 		REQUIRE(packet->GetHeaderExtensionId() == 0b0001000000000000);
 		REQUIRE(packet->GetHeaderExtensionLength() == 20); // 19 + 1 byte for padding.
 		REQUIRE(packet->HasOneByteExtensions() == false);
 		REQUIRE(packet->HasTwoBytesExtensions() == true);
 		REQUIRE(packet->GetPayloadLength() == 12);
-		REQUIRE(packet->GetPayloadPadding() == 4);
+		REQUIRE(packet->GetPayloadPadding() == 0);
 		REQUIRE(packet->GetPayload()[0] == 0x11);
 		REQUIRE(packet->GetPayload()[packet->GetPayloadLength() - 1] == 0xCC);
 		REQUIRE(packet->GetExtension(0, extenLen) == nullptr);
@@ -798,14 +805,14 @@ SCENARIO("parse RTP packets", "[parser][rtp]")
 
 		packet->SetExtensions(2, extensions);
 
-		REQUIRE(packet->GetSize() == 40);
+		REQUIRE(packet->GetSize() == 36);
 		REQUIRE(packet->HasHeaderExtension() == true);
 		REQUIRE(packet->GetHeaderExtensionId() == 0b0001000000000000);
 		REQUIRE(packet->GetHeaderExtensionLength() == 8);
 		REQUIRE(packet->HasOneByteExtensions() == false);
 		REQUIRE(packet->HasTwoBytesExtensions() == true);
 		REQUIRE(packet->GetPayloadLength() == 12);
-		REQUIRE(packet->GetPayloadPadding() == 4);
+		REQUIRE(packet->GetPayloadPadding() == 0);
 		REQUIRE(packet->GetPayload()[0] == 0x11);
 		REQUIRE(packet->GetPayload()[packet->GetPayloadLength() - 1] == 0xCC);
 		REQUIRE(packet->GetExtension(1, extenLen) == nullptr);