From 65b19356aefdc1a9dae8f07b13261b8094b32455 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?I=C3=B1aki=20Baz=20Castillo?= <ibc@aliax.net> Date: Mon, 22 Apr 2024 17:40:55 +0200 Subject: [PATCH] Huge changes: make all methods in RtpPacket remove payload padding if present --- worker/src/RTC/RtpPacket.cpp | 65 ++++++++++++++++++++------- worker/test/src/RTC/TestRtpPacket.cpp | 49 +++++++++++--------- 2 files changed, 76 insertions(+), 38 deletions(-) diff --git a/worker/src/RTC/RtpPacket.cpp b/worker/src/RTC/RtpPacket.cpp index 379bd323d0..bd036418b1 100644 --- a/worker/src/RTC/RtpPacket.cpp +++ b/worker/src/RTC/RtpPacket.cpp @@ -96,6 +96,7 @@ namespace RTC } payloadPadding = data[len - 1]; + if (payloadPadding == 0) { MS_WARN_TAG(rtp, "padding byte cannot be 0, packet discarded"); @@ -112,6 +113,7 @@ namespace RTC return nullptr; } + payloadLength -= size_t{ payloadPadding }; } @@ -371,6 +373,9 @@ namespace RTC : flatbuffers::nullopt); } + /** + * NOTE: This method automatically removes payload padding if present. + */ void RtpPacket::SetExtensions(uint8_t type, const std::vector<GenericExtension>& extensions) { MS_ASSERT(type == 1u || type == 2u, "type must be 1 or 2"); @@ -462,7 +467,7 @@ namespace RTC if (this->headerExtension && shift != 0) { // Shift the payload. - std::memmove(this->payload + shift, this->payload, this->payloadLength + this->payloadPadding); + std::memmove(this->payload + shift, this->payload, this->payloadLength); this->payload += shift; // Update packet total size. @@ -480,7 +485,7 @@ namespace RTC this->headerExtension = reinterpret_cast<HeaderExtension*>(this->payload); // Shift the payload. - std::memmove(this->payload + shift, this->payload, this->payloadLength + this->payloadPadding); + std::memmove(this->payload + shift, this->payload, this->payloadLength); this->payload += shift; // Update packet total size. @@ -547,6 +552,15 @@ namespace RTC } MS_ASSERT(ptr == this->payload, "wrong ptr calculation"); + + // Remove padding if present. + if (this->payloadPadding != 0u) + { + SetPayloadPaddingFlag(false); + + this->size -= size_t{ this->payloadPadding }; + this->payloadPadding = 0u; + } } void RtpPacket::UpdateMid(const std::string& mid) @@ -651,22 +665,24 @@ namespace RTC } /** - * NOTE: This method automatically adjusts padding for padding+payload to be - * padded to 4 bytes. + * NOTE: This method automatically removes payload padding if present. */ void RtpPacket::SetPayloadLength(size_t length) { MS_TRACE(); - auto payloadPadding = Utils::Byte::PadTo4Bytes(static_cast<uint16_t>(length)) - length; - this->size -= this->payloadLength; - this->size -= size_t{ this->payloadPadding }; - this->payloadLength = length; - this->payloadPadding = payloadPadding; - this->size += this->payloadLength + this->payloadPadding; + this->payloadLength = length; + this->size += this->payloadLength; + + // Remove padding if present. + if (this->payloadPadding != 0u) + { + SetPayloadPaddingFlag(false); - SetPayloadPaddingFlag(this->payloadPadding > 0); + this->size -= size_t{ this->payloadPadding }; + this->payloadPadding = 0u; + } } RtpPacket* RtpPacket::Clone() const @@ -752,8 +768,12 @@ namespace RTC return packet; } - // NOTE: The caller must ensure that the buffer/memmory of the packet has - // space enough for adding 2 extra bytes. + /** + * NOTE: The caller must ensure that the buffer/memmory of the packet has + * space enough for adding 2 extra bytes. + * + * NOTE: This method automatically removes payload padding if present. + */ void RtpPacket::RtxEncode(uint8_t payloadType, uint32_t ssrc, uint16_t seq) { MS_TRACE(); @@ -787,6 +807,9 @@ namespace RTC } } + /** + * NOTE: This method automatically removes payload padding if present. + */ bool RtpPacket::RtxDecode(uint8_t payloadType, uint32_t ssrc) { MS_TRACE(); @@ -855,8 +878,7 @@ namespace RTC /** * Shifts the payload given offset (to right or to left). * - * NOTE: This method doesn't automatically adjust padding for padding+payload - * to be padded to 4 bytes. + * NOTE: This method automatically removes payload padding if present. */ void RtpPacket::ShiftPayload(size_t payloadOffset, size_t shift, bool expand) { @@ -879,7 +901,7 @@ namespace RTC if (expand) { - shiftedLen = this->payloadLength + size_t{ this->payloadPadding } - payloadOffset; + shiftedLen = this->payloadLength - payloadOffset; std::memmove(payloadOffsetPtr + shift, payloadOffsetPtr, shiftedLen); @@ -888,13 +910,22 @@ namespace RTC } else { - shiftedLen = this->payloadLength + size_t{ this->payloadPadding } - payloadOffset - shift; + shiftedLen = this->payloadLength - payloadOffset - shift; std::memmove(payloadOffsetPtr, payloadOffsetPtr + shift, shiftedLen); this->payloadLength -= shift; this->size -= shift; } + + // Remove padding if present. + if (this->payloadPadding != 0u) + { + SetPayloadPaddingFlag(false); + + this->size -= size_t{ this->payloadPadding }; + this->payloadPadding = 0u; + } } void RtpPacket::ParseExtensions() diff --git a/worker/test/src/RTC/TestRtpPacket.cpp b/worker/test/src/RTC/TestRtpPacket.cpp index d5ec7279cf..dc8d54737d 100644 --- a/worker/test/src/RTC/TestRtpPacket.cpp +++ b/worker/test/src/RTC/TestRtpPacket.cpp @@ -435,6 +435,7 @@ SCENARIO("parse RTP packets", "[parser][rtp]") REQUIRE(packet->HasTwoBytesExtensions() == false); REQUIRE(packet->GetPayloadLength() == 8); REQUIRE(packet->GetPayloadPadding() == 4); + REQUIRE(packet->GetPayload()[packet->GetPayloadLength() + packet->GetPayloadPadding() - 1] == 4); REQUIRE(packet->GetSize() == 40); auto* payload = packet->GetPayload(); @@ -448,11 +449,12 @@ SCENARIO("parse RTP packets", "[parser][rtp]") REQUIRE(payload[6] == 0x06); REQUIRE(payload[7] == 0x07); + // NOTE: This will remove padding. packet->ShiftPayload(0, 2, true); REQUIRE(packet->GetPayloadLength() == 10); - REQUIRE(packet->GetPayloadPadding() == 4); - REQUIRE(packet->GetSize() == 42); + REQUIRE(packet->GetPayloadPadding() == 0); + REQUIRE(packet->GetSize() == 38); REQUIRE(payload[2] == 0x00); REQUIRE(payload[3] == 0x01); REQUIRE(payload[4] == 0x02); @@ -465,8 +467,8 @@ SCENARIO("parse RTP packets", "[parser][rtp]") packet->ShiftPayload(0, 2, false); REQUIRE(packet->GetPayloadLength() == 8); - REQUIRE(packet->GetPayloadPadding() == 4); - REQUIRE(packet->GetSize() == 40); + REQUIRE(packet->GetPayloadPadding() == 0); + REQUIRE(packet->GetSize() == 36); REQUIRE(payload[0] == 0x00); REQUIRE(payload[1] == 0x01); REQUIRE(payload[2] == 0x02); @@ -476,18 +478,18 @@ SCENARIO("parse RTP packets", "[parser][rtp]") REQUIRE(payload[6] == 0x06); REQUIRE(payload[7] == 0x07); - // NOTE: This will require padding to 2 bytes. + // NOTE: This will remove padding. packet->SetPayloadLength(14); REQUIRE(packet->GetPayloadLength() == 14); - REQUIRE(packet->GetPayloadPadding() == 2); - REQUIRE(packet->GetSize() == 44); + REQUIRE(packet->GetPayloadPadding() == 0); + REQUIRE(packet->GetSize() == 42); packet->ShiftPayload(4, 4, true); REQUIRE(packet->GetPayloadLength() == 18); - REQUIRE(packet->GetPayloadPadding() == 2); - REQUIRE(packet->GetSize() == 48); + REQUIRE(packet->GetPayloadPadding() == 0); + REQUIRE(packet->GetSize() == 46); REQUIRE(payload[0] == 0x00); REQUIRE(payload[1] == 0x01); REQUIRE(payload[2] == 0x02); @@ -545,21 +547,23 @@ SCENARIO("parse RTP packets", "[parser][rtp]") REQUIRE(packet->HasTwoBytesExtensions() == false); REQUIRE(packet->GetPayloadLength() == 12); REQUIRE(packet->GetPayloadPadding() == 4); + REQUIRE(packet->GetPayload()[packet->GetPayloadLength() + packet->GetPayloadPadding() - 1] == 4); REQUIRE(packet->GetPayload()[0] == 0x11); REQUIRE(packet->GetPayload()[packet->GetPayloadLength() - 1] == 0xCC); extensions.clear(); + // NOTE: This will remove padding. packet->SetExtensions(1, extensions); - REQUIRE(packet->GetSize() == 32); + REQUIRE(packet->GetSize() == 28); REQUIRE(packet->HasHeaderExtension() == true); REQUIRE(packet->GetHeaderExtensionId() == 0xBEDE); REQUIRE(packet->GetHeaderExtensionLength() == 0); REQUIRE(packet->HasOneByteExtensions() == true); REQUIRE(packet->HasTwoBytesExtensions() == false); REQUIRE(packet->GetPayloadLength() == 12); - REQUIRE(packet->GetPayloadPadding() == 4); + REQUIRE(packet->GetPayloadPadding() == 0); REQUIRE(packet->GetPayload()[0] == 0x11); REQUIRE(packet->GetPayload()[packet->GetPayloadLength() - 1] == 0xCC); @@ -604,14 +608,14 @@ SCENARIO("parse RTP packets", "[parser][rtp]") packet->SetExtensions(1, extensions); - REQUIRE(packet->GetSize() == 52); // 49 + 3 bytes for padding in header extension. + REQUIRE(packet->GetSize() == 48); // 49 + 3 bytes for padding in header extension. REQUIRE(packet->HasHeaderExtension() == true); REQUIRE(packet->GetHeaderExtensionId() == 0xBEDE); REQUIRE(packet->GetHeaderExtensionLength() == 20); // 17 + 3 bytes for padding. REQUIRE(packet->HasOneByteExtensions() == true); REQUIRE(packet->HasTwoBytesExtensions() == false); REQUIRE(packet->GetPayloadLength() == 12); - REQUIRE(packet->GetPayloadPadding() == 4); + REQUIRE(packet->GetPayloadPadding() == 0); REQUIRE(packet->GetPayload()[0] == 0x11); REQUIRE(packet->GetPayload()[packet->GetPayloadLength() - 1] == 0xCC); REQUIRE(packet->GetExtension(0, extenLen) == nullptr); @@ -639,14 +643,14 @@ SCENARIO("parse RTP packets", "[parser][rtp]") packet->SetExtensions(1, extensions); - REQUIRE(packet->GetSize() == 40); + REQUIRE(packet->GetSize() == 36); REQUIRE(packet->HasHeaderExtension() == true); REQUIRE(packet->GetHeaderExtensionId() == 0xBEDE); REQUIRE(packet->GetHeaderExtensionLength() == 8); // 5 + 3 bytes for padding. REQUIRE(packet->HasOneByteExtensions() == true); REQUIRE(packet->HasTwoBytesExtensions() == false); REQUIRE(packet->GetPayloadLength() == 12); - REQUIRE(packet->GetPayloadPadding() == 4); + REQUIRE(packet->GetPayloadPadding() == 0); REQUIRE(packet->GetPayload()[0] == 0x11); REQUIRE(packet->GetPayload()[packet->GetPayloadLength() - 1] == 0xCC); REQUIRE(packet->GetExtension(1, extenLen) == nullptr); @@ -710,21 +714,23 @@ SCENARIO("parse RTP packets", "[parser][rtp]") REQUIRE(packet->HasTwoBytesExtensions() == false); REQUIRE(packet->GetPayloadLength() == 12); REQUIRE(packet->GetPayloadPadding() == 4); + REQUIRE(packet->GetPayload()[packet->GetPayloadLength() + packet->GetPayloadPadding() - 1] == 4); REQUIRE(packet->GetPayload()[0] == 0x11); REQUIRE(packet->GetPayload()[packet->GetPayloadLength() - 1] == 0xCC); extensions.clear(); + // NOTE: This will remove padding. packet->SetExtensions(2, extensions); - REQUIRE(packet->GetSize() == 32); + REQUIRE(packet->GetSize() == 28); REQUIRE(packet->HasHeaderExtension() == true); REQUIRE(packet->GetHeaderExtensionId() == 0b0001000000000000); REQUIRE(packet->GetHeaderExtensionLength() == 0); REQUIRE(packet->HasOneByteExtensions() == false); REQUIRE(packet->HasTwoBytesExtensions() == true); REQUIRE(packet->GetPayloadLength() == 12); - REQUIRE(packet->GetPayloadPadding() == 4); + REQUIRE(packet->GetPayloadPadding() == 0); REQUIRE(packet->GetPayload()[0] == 0x11); REQUIRE(packet->GetPayload()[packet->GetPayloadLength() - 1] == 0xCC); @@ -753,16 +759,17 @@ SCENARIO("parse RTP packets", "[parser][rtp]") value2 // value ); + // NOTE: This will remove padding. packet->SetExtensions(2, extensions); - REQUIRE(packet->GetSize() == 52); // 51 + 1 byte for padding in header extension. + REQUIRE(packet->GetSize() == 48); // 51 + 1 byte for padding in header extension. REQUIRE(packet->HasHeaderExtension() == true); REQUIRE(packet->GetHeaderExtensionId() == 0b0001000000000000); REQUIRE(packet->GetHeaderExtensionLength() == 20); // 19 + 1 byte for padding. REQUIRE(packet->HasOneByteExtensions() == false); REQUIRE(packet->HasTwoBytesExtensions() == true); REQUIRE(packet->GetPayloadLength() == 12); - REQUIRE(packet->GetPayloadPadding() == 4); + REQUIRE(packet->GetPayloadPadding() == 0); REQUIRE(packet->GetPayload()[0] == 0x11); REQUIRE(packet->GetPayload()[packet->GetPayloadLength() - 1] == 0xCC); REQUIRE(packet->GetExtension(0, extenLen) == nullptr); @@ -798,14 +805,14 @@ SCENARIO("parse RTP packets", "[parser][rtp]") packet->SetExtensions(2, extensions); - REQUIRE(packet->GetSize() == 40); + REQUIRE(packet->GetSize() == 36); REQUIRE(packet->HasHeaderExtension() == true); REQUIRE(packet->GetHeaderExtensionId() == 0b0001000000000000); REQUIRE(packet->GetHeaderExtensionLength() == 8); REQUIRE(packet->HasOneByteExtensions() == false); REQUIRE(packet->HasTwoBytesExtensions() == true); REQUIRE(packet->GetPayloadLength() == 12); - REQUIRE(packet->GetPayloadPadding() == 4); + REQUIRE(packet->GetPayloadPadding() == 0); REQUIRE(packet->GetPayload()[0] == 0x11); REQUIRE(packet->GetPayload()[packet->GetPayloadLength() - 1] == 0xCC); REQUIRE(packet->GetExtension(1, extenLen) == nullptr);