Skip to content

Commit

Permalink
Huge changes: make all methods in RtpPacket remove payload padding if…
Browse files Browse the repository at this point in the history
… present
  • Loading branch information
ibc committed Apr 22, 2024
1 parent a703ca9 commit 65b1935
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 38 deletions.
65 changes: 48 additions & 17 deletions worker/src/RTC/RtpPacket.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ namespace RTC
}

payloadPadding = data[len - 1];

if (payloadPadding == 0)
{
MS_WARN_TAG(rtp, "padding byte cannot be 0, packet discarded");
Expand All @@ -112,6 +113,7 @@ namespace RTC

return nullptr;
}

payloadLength -= size_t{ payloadPadding };
}

Expand Down Expand Up @@ -371,6 +373,9 @@ namespace RTC
: flatbuffers::nullopt);
}

/**
* NOTE: This method automatically removes payload padding if present.
*/
void RtpPacket::SetExtensions(uint8_t type, const std::vector<GenericExtension>& extensions)
{
MS_ASSERT(type == 1u || type == 2u, "type must be 1 or 2");
Expand Down Expand Up @@ -462,7 +467,7 @@ namespace RTC
if (this->headerExtension && shift != 0)
{
// Shift the payload.
std::memmove(this->payload + shift, this->payload, this->payloadLength + this->payloadPadding);
std::memmove(this->payload + shift, this->payload, this->payloadLength);
this->payload += shift;

// Update packet total size.
Expand All @@ -480,7 +485,7 @@ namespace RTC
this->headerExtension = reinterpret_cast<HeaderExtension*>(this->payload);

// Shift the payload.
std::memmove(this->payload + shift, this->payload, this->payloadLength + this->payloadPadding);
std::memmove(this->payload + shift, this->payload, this->payloadLength);
this->payload += shift;

// Update packet total size.
Expand Down Expand Up @@ -547,6 +552,15 @@ namespace RTC
}

MS_ASSERT(ptr == this->payload, "wrong ptr calculation");

// Remove padding if present.
if (this->payloadPadding != 0u)
{
SetPayloadPaddingFlag(false);

this->size -= size_t{ this->payloadPadding };
this->payloadPadding = 0u;
}
}

void RtpPacket::UpdateMid(const std::string& mid)
Expand Down Expand Up @@ -651,22 +665,24 @@ namespace RTC
}

/**
* NOTE: This method automatically adjusts padding for padding+payload to be
* padded to 4 bytes.
* NOTE: This method automatically removes payload padding if present.
*/
void RtpPacket::SetPayloadLength(size_t length)
{
MS_TRACE();

auto payloadPadding = Utils::Byte::PadTo4Bytes(static_cast<uint16_t>(length)) - length;

this->size -= this->payloadLength;
this->size -= size_t{ this->payloadPadding };
this->payloadLength = length;
this->payloadPadding = payloadPadding;
this->size += this->payloadLength + this->payloadPadding;
this->payloadLength = length;
this->size += this->payloadLength;

// Remove padding if present.
if (this->payloadPadding != 0u)
{
SetPayloadPaddingFlag(false);

SetPayloadPaddingFlag(this->payloadPadding > 0);
this->size -= size_t{ this->payloadPadding };
this->payloadPadding = 0u;
}
}

RtpPacket* RtpPacket::Clone() const
Expand Down Expand Up @@ -752,8 +768,12 @@ namespace RTC
return packet;
}

// NOTE: The caller must ensure that the buffer/memmory of the packet has
// space enough for adding 2 extra bytes.
/**
* NOTE: The caller must ensure that the buffer/memmory of the packet has
* space enough for adding 2 extra bytes.
*
* NOTE: This method automatically removes payload padding if present.
*/
void RtpPacket::RtxEncode(uint8_t payloadType, uint32_t ssrc, uint16_t seq)
{
MS_TRACE();
Expand Down Expand Up @@ -787,6 +807,9 @@ namespace RTC
}
}

/**
* NOTE: This method automatically removes payload padding if present.
*/
bool RtpPacket::RtxDecode(uint8_t payloadType, uint32_t ssrc)
{
MS_TRACE();
Expand Down Expand Up @@ -855,8 +878,7 @@ namespace RTC
/**
* Shifts the payload given offset (to right or to left).
*
* NOTE: This method doesn't automatically adjust padding for padding+payload
* to be padded to 4 bytes.
* NOTE: This method automatically removes payload padding if present.
*/
void RtpPacket::ShiftPayload(size_t payloadOffset, size_t shift, bool expand)
{
Expand All @@ -879,7 +901,7 @@ namespace RTC

if (expand)
{
shiftedLen = this->payloadLength + size_t{ this->payloadPadding } - payloadOffset;
shiftedLen = this->payloadLength - payloadOffset;

std::memmove(payloadOffsetPtr + shift, payloadOffsetPtr, shiftedLen);

Expand All @@ -888,13 +910,22 @@ namespace RTC
}
else
{
shiftedLen = this->payloadLength + size_t{ this->payloadPadding } - payloadOffset - shift;
shiftedLen = this->payloadLength - payloadOffset - shift;

std::memmove(payloadOffsetPtr, payloadOffsetPtr + shift, shiftedLen);

this->payloadLength -= shift;
this->size -= shift;
}

// Remove padding if present.
if (this->payloadPadding != 0u)
{
SetPayloadPaddingFlag(false);

this->size -= size_t{ this->payloadPadding };
this->payloadPadding = 0u;
}
}

void RtpPacket::ParseExtensions()
Expand Down
49 changes: 28 additions & 21 deletions worker/test/src/RTC/TestRtpPacket.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -435,6 +435,7 @@ SCENARIO("parse RTP packets", "[parser][rtp]")
REQUIRE(packet->HasTwoBytesExtensions() == false);
REQUIRE(packet->GetPayloadLength() == 8);
REQUIRE(packet->GetPayloadPadding() == 4);
REQUIRE(packet->GetPayload()[packet->GetPayloadLength() + packet->GetPayloadPadding() - 1] == 4);
REQUIRE(packet->GetSize() == 40);

auto* payload = packet->GetPayload();
Expand All @@ -448,11 +449,12 @@ SCENARIO("parse RTP packets", "[parser][rtp]")
REQUIRE(payload[6] == 0x06);
REQUIRE(payload[7] == 0x07);

// NOTE: This will remove padding.
packet->ShiftPayload(0, 2, true);

REQUIRE(packet->GetPayloadLength() == 10);
REQUIRE(packet->GetPayloadPadding() == 4);
REQUIRE(packet->GetSize() == 42);
REQUIRE(packet->GetPayloadPadding() == 0);
REQUIRE(packet->GetSize() == 38);
REQUIRE(payload[2] == 0x00);
REQUIRE(payload[3] == 0x01);
REQUIRE(payload[4] == 0x02);
Expand All @@ -465,8 +467,8 @@ SCENARIO("parse RTP packets", "[parser][rtp]")
packet->ShiftPayload(0, 2, false);

REQUIRE(packet->GetPayloadLength() == 8);
REQUIRE(packet->GetPayloadPadding() == 4);
REQUIRE(packet->GetSize() == 40);
REQUIRE(packet->GetPayloadPadding() == 0);
REQUIRE(packet->GetSize() == 36);
REQUIRE(payload[0] == 0x00);
REQUIRE(payload[1] == 0x01);
REQUIRE(payload[2] == 0x02);
Expand All @@ -476,18 +478,18 @@ SCENARIO("parse RTP packets", "[parser][rtp]")
REQUIRE(payload[6] == 0x06);
REQUIRE(payload[7] == 0x07);

// NOTE: This will require padding to 2 bytes.
// NOTE: This will remove padding.
packet->SetPayloadLength(14);

REQUIRE(packet->GetPayloadLength() == 14);
REQUIRE(packet->GetPayloadPadding() == 2);
REQUIRE(packet->GetSize() == 44);
REQUIRE(packet->GetPayloadPadding() == 0);
REQUIRE(packet->GetSize() == 42);

packet->ShiftPayload(4, 4, true);

REQUIRE(packet->GetPayloadLength() == 18);
REQUIRE(packet->GetPayloadPadding() == 2);
REQUIRE(packet->GetSize() == 48);
REQUIRE(packet->GetPayloadPadding() == 0);
REQUIRE(packet->GetSize() == 46);
REQUIRE(payload[0] == 0x00);
REQUIRE(payload[1] == 0x01);
REQUIRE(payload[2] == 0x02);
Expand Down Expand Up @@ -545,21 +547,23 @@ SCENARIO("parse RTP packets", "[parser][rtp]")
REQUIRE(packet->HasTwoBytesExtensions() == false);
REQUIRE(packet->GetPayloadLength() == 12);
REQUIRE(packet->GetPayloadPadding() == 4);
REQUIRE(packet->GetPayload()[packet->GetPayloadLength() + packet->GetPayloadPadding() - 1] == 4);
REQUIRE(packet->GetPayload()[0] == 0x11);
REQUIRE(packet->GetPayload()[packet->GetPayloadLength() - 1] == 0xCC);

extensions.clear();

// NOTE: This will remove padding.
packet->SetExtensions(1, extensions);

REQUIRE(packet->GetSize() == 32);
REQUIRE(packet->GetSize() == 28);
REQUIRE(packet->HasHeaderExtension() == true);
REQUIRE(packet->GetHeaderExtensionId() == 0xBEDE);
REQUIRE(packet->GetHeaderExtensionLength() == 0);
REQUIRE(packet->HasOneByteExtensions() == true);
REQUIRE(packet->HasTwoBytesExtensions() == false);
REQUIRE(packet->GetPayloadLength() == 12);
REQUIRE(packet->GetPayloadPadding() == 4);
REQUIRE(packet->GetPayloadPadding() == 0);
REQUIRE(packet->GetPayload()[0] == 0x11);
REQUIRE(packet->GetPayload()[packet->GetPayloadLength() - 1] == 0xCC);

Expand Down Expand Up @@ -604,14 +608,14 @@ SCENARIO("parse RTP packets", "[parser][rtp]")

packet->SetExtensions(1, extensions);

REQUIRE(packet->GetSize() == 52); // 49 + 3 bytes for padding in header extension.
REQUIRE(packet->GetSize() == 48); // 49 + 3 bytes for padding in header extension.
REQUIRE(packet->HasHeaderExtension() == true);
REQUIRE(packet->GetHeaderExtensionId() == 0xBEDE);
REQUIRE(packet->GetHeaderExtensionLength() == 20); // 17 + 3 bytes for padding.
REQUIRE(packet->HasOneByteExtensions() == true);
REQUIRE(packet->HasTwoBytesExtensions() == false);
REQUIRE(packet->GetPayloadLength() == 12);
REQUIRE(packet->GetPayloadPadding() == 4);
REQUIRE(packet->GetPayloadPadding() == 0);
REQUIRE(packet->GetPayload()[0] == 0x11);
REQUIRE(packet->GetPayload()[packet->GetPayloadLength() - 1] == 0xCC);
REQUIRE(packet->GetExtension(0, extenLen) == nullptr);
Expand Down Expand Up @@ -639,14 +643,14 @@ SCENARIO("parse RTP packets", "[parser][rtp]")

packet->SetExtensions(1, extensions);

REQUIRE(packet->GetSize() == 40);
REQUIRE(packet->GetSize() == 36);
REQUIRE(packet->HasHeaderExtension() == true);
REQUIRE(packet->GetHeaderExtensionId() == 0xBEDE);
REQUIRE(packet->GetHeaderExtensionLength() == 8); // 5 + 3 bytes for padding.
REQUIRE(packet->HasOneByteExtensions() == true);
REQUIRE(packet->HasTwoBytesExtensions() == false);
REQUIRE(packet->GetPayloadLength() == 12);
REQUIRE(packet->GetPayloadPadding() == 4);
REQUIRE(packet->GetPayloadPadding() == 0);
REQUIRE(packet->GetPayload()[0] == 0x11);
REQUIRE(packet->GetPayload()[packet->GetPayloadLength() - 1] == 0xCC);
REQUIRE(packet->GetExtension(1, extenLen) == nullptr);
Expand Down Expand Up @@ -710,21 +714,23 @@ SCENARIO("parse RTP packets", "[parser][rtp]")
REQUIRE(packet->HasTwoBytesExtensions() == false);
REQUIRE(packet->GetPayloadLength() == 12);
REQUIRE(packet->GetPayloadPadding() == 4);
REQUIRE(packet->GetPayload()[packet->GetPayloadLength() + packet->GetPayloadPadding() - 1] == 4);
REQUIRE(packet->GetPayload()[0] == 0x11);
REQUIRE(packet->GetPayload()[packet->GetPayloadLength() - 1] == 0xCC);

extensions.clear();

// NOTE: This will remove padding.
packet->SetExtensions(2, extensions);

REQUIRE(packet->GetSize() == 32);
REQUIRE(packet->GetSize() == 28);
REQUIRE(packet->HasHeaderExtension() == true);
REQUIRE(packet->GetHeaderExtensionId() == 0b0001000000000000);
REQUIRE(packet->GetHeaderExtensionLength() == 0);
REQUIRE(packet->HasOneByteExtensions() == false);
REQUIRE(packet->HasTwoBytesExtensions() == true);
REQUIRE(packet->GetPayloadLength() == 12);
REQUIRE(packet->GetPayloadPadding() == 4);
REQUIRE(packet->GetPayloadPadding() == 0);
REQUIRE(packet->GetPayload()[0] == 0x11);
REQUIRE(packet->GetPayload()[packet->GetPayloadLength() - 1] == 0xCC);

Expand Down Expand Up @@ -753,16 +759,17 @@ SCENARIO("parse RTP packets", "[parser][rtp]")
value2 // value
);

// NOTE: This will remove padding.
packet->SetExtensions(2, extensions);

REQUIRE(packet->GetSize() == 52); // 51 + 1 byte for padding in header extension.
REQUIRE(packet->GetSize() == 48); // 51 + 1 byte for padding in header extension.
REQUIRE(packet->HasHeaderExtension() == true);
REQUIRE(packet->GetHeaderExtensionId() == 0b0001000000000000);
REQUIRE(packet->GetHeaderExtensionLength() == 20); // 19 + 1 byte for padding.
REQUIRE(packet->HasOneByteExtensions() == false);
REQUIRE(packet->HasTwoBytesExtensions() == true);
REQUIRE(packet->GetPayloadLength() == 12);
REQUIRE(packet->GetPayloadPadding() == 4);
REQUIRE(packet->GetPayloadPadding() == 0);
REQUIRE(packet->GetPayload()[0] == 0x11);
REQUIRE(packet->GetPayload()[packet->GetPayloadLength() - 1] == 0xCC);
REQUIRE(packet->GetExtension(0, extenLen) == nullptr);
Expand Down Expand Up @@ -798,14 +805,14 @@ SCENARIO("parse RTP packets", "[parser][rtp]")

packet->SetExtensions(2, extensions);

REQUIRE(packet->GetSize() == 40);
REQUIRE(packet->GetSize() == 36);
REQUIRE(packet->HasHeaderExtension() == true);
REQUIRE(packet->GetHeaderExtensionId() == 0b0001000000000000);
REQUIRE(packet->GetHeaderExtensionLength() == 8);
REQUIRE(packet->HasOneByteExtensions() == false);
REQUIRE(packet->HasTwoBytesExtensions() == true);
REQUIRE(packet->GetPayloadLength() == 12);
REQUIRE(packet->GetPayloadPadding() == 4);
REQUIRE(packet->GetPayloadPadding() == 0);
REQUIRE(packet->GetPayload()[0] == 0x11);
REQUIRE(packet->GetPayload()[packet->GetPayloadLength() - 1] == 0xCC);
REQUIRE(packet->GetExtension(1, extenLen) == nullptr);
Expand Down

0 comments on commit 65b1935

Please sign in to comment.