From 6e9fb3385ee6adea446bb56c0accf7beb0b9f86e Mon Sep 17 00:00:00 2001
From: Scott Schneider
Date: Mon, 16 Dec 2024 13:35:09 -0800
Subject: [PATCH 1/2] Generalize logic to recreate colorspace conversion objects

---
 .../decoders/_core/VideoDecoder.cpp           | 150 ++++++++++--------
 src/torchcodec/decoders/_core/VideoDecoder.h  |  20 ++-
 version.txt                                   |   2 +-
 3 files changed, 96 insertions(+), 76 deletions(-)

diff --git a/src/torchcodec/decoders/_core/VideoDecoder.cpp b/src/torchcodec/decoders/_core/VideoDecoder.cpp
index ab3d3f9d..3a7927d1 100644
--- a/src/torchcodec/decoders/_core/VideoDecoder.cpp
+++ b/src/torchcodec/decoders/_core/VideoDecoder.cpp
@@ -204,15 +204,16 @@ VideoDecoder::BatchDecodedOutput::BatchDecodedOutput(
   frames = allocateEmptyHWCTensor(height, width, options.device, numFrames);
 }
 
-bool VideoDecoder::SwsContextKey::operator==(
-    const VideoDecoder::SwsContextKey& other) {
+bool VideoDecoder::DecodedFrameContext::operator==(
+    const VideoDecoder::DecodedFrameContext& other) {
   return decodedWidth == other.decodedWidth && decodedHeight == other.decodedHeight &&
       decodedFormat == other.decodedFormat &&
-      outputWidth == other.outputWidth && outputHeight == other.outputHeight;
+      expectedWidth == other.expectedWidth &&
+      expectedHeight == other.expectedHeight;
 }
 
-bool VideoDecoder::SwsContextKey::operator!=(
-    const VideoDecoder::SwsContextKey& other) {
+bool VideoDecoder::DecodedFrameContext::operator!=(
+    const VideoDecoder::DecodedFrameContext& other) {
   return !(*this == other);
 }
 
@@ -313,17 +314,14 @@ std::unique_ptr<VideoDecoder> VideoDecoder::createFromBuffer(
   return std::unique_ptr<VideoDecoder>(new VideoDecoder(buffer, length));
 }
 
-void VideoDecoder::initializeFilterGraph(
+void VideoDecoder::createFilterGraph(
     StreamInfo& streamInfo,
     int expectedOutputHeight,
     int expectedOutputWidth) {
   FilterState& filterState = streamInfo.filterState;
-  if (filterState.filterGraph) {
-    return;
-  }
-
   filterState.filterGraph.reset(avfilter_graph_alloc());
   TORCH_CHECK(filterState.filterGraph.get() != nullptr);
+
   if (streamInfo.options.ffmpegThreadCount.has_value()) {
     filterState.filterGraph->nb_threads =
         streamInfo.options.ffmpegThreadCount.value();
@@ -921,12 +919,32 @@ void VideoDecoder::convertAVFrameToDecodedOutputOnCPU(
   torch::Tensor outputTensor;
   if (output.streamType == AVMEDIA_TYPE_VIDEO) {
+    // We need to compare the current frame context with our previous frame
+    // context. If they are different, then we need to re-create our colorspace
+    // conversion objects. We create our colorspace conversion objects late so
+    // that we don't have to depend on the unreliable metadata in the header.
+    // And we sometimes re-create them because it's possible for frame
+    // resolution to change mid-stream. Finally, we want to reuse the colorspace
+    // conversion objects as much as possible for performance reasons.
+
+    enum AVPixelFormat frameFormat =
+        static_cast<AVPixelFormat>(frame->format);
+    auto frameContext = DecodedFrameContext{
+        frame->width,
+        frame->height,
+        frameFormat,
+        expectedOutputWidth,
+        expectedOutputHeight};
+
     if (streamInfo.colorConversionLibrary == ColorConversionLibrary::SWSCALE) {
       outputTensor = preAllocatedOutputTensor.value_or(allocateEmptyHWCTensor(
           expectedOutputHeight, expectedOutputWidth, torch::kCPU));
+      if (!streamInfo.swsContext || streamInfo.prevFrame != frameContext) {
+        createSwsContext(streamInfo, frameContext, frame->colorspace);
+        streamInfo.prevFrame = frameContext;
+      }
       int resultHeight =
-          convertFrameToBufferUsingSwsScale(streamIndex, frame, outputTensor);
+          convertFrameToTensorUsingSwsScale(streamIndex, frame, outputTensor);
       // If this check failed, it would mean that the frame wasn't reshaped to
       // the expected height.
       // TODO: Can we do the same check for width?
@@ -941,16 +959,11 @@ void VideoDecoder::convertAVFrameToDecodedOutputOnCPU(
   } else if (
       streamInfo.colorConversionLibrary ==
       ColorConversionLibrary::FILTERGRAPH) {
-    // Note that is a lazy init; we initialize filtergraph the first time
-    // we have a raw decoded frame. We do this lazily because up until this
-    // point, we really don't know what the resolution of the frames are
-    // without modification. In theory, we should be able to get that from the
-    // stream metadata, but in practice, we have encountered videos where the
-    // stream metadata had a different resolution from the actual resolution
-    // of the raw decoded frames.
-    if (!streamInfo.filterState.filterGraph) {
-      initializeFilterGraph(
+    if (!streamInfo.filterState.filterGraph ||
+        streamInfo.prevFrame != frameContext) {
+      createFilterGraph(
           streamInfo, expectedOutputHeight, expectedOutputWidth);
+      streamInfo.prevFrame = frameContext;
     }
     outputTensor = convertFrameToTensorUsingFilterGraph(streamIndex, frame);
@@ -1351,7 +1364,53 @@ double VideoDecoder::getPtsSecondsForFrame(
   return ptsToSeconds(stream.allFrames[frameIndex].pts, stream.timeBase);
 }
 
-int VideoDecoder::convertFrameToBufferUsingSwsScale(
+void VideoDecoder::createSwsContext(
+    StreamInfo& streamInfo,
+    const DecodedFrameContext& frameContext,
+    const enum AVColorSpace colorspace) {
+  SwsContext* swsContext = sws_getContext(
+      frameContext.decodedWidth,
+      frameContext.decodedHeight,
+      frameContext.decodedFormat,
+      frameContext.expectedWidth,
+      frameContext.expectedHeight,
+      AV_PIX_FMT_RGB24,
+      SWS_BILINEAR,
+      nullptr,
+      nullptr,
+      nullptr);
+  TORCH_CHECK(swsContext, "sws_getContext() returned nullptr");
+
+  int* invTable = nullptr;
+  int* table = nullptr;
+  int srcRange, dstRange, brightness, contrast, saturation;
+  int ret = sws_getColorspaceDetails(
+      swsContext,
+      &invTable,
+      &srcRange,
+      &table,
+      &dstRange,
+      &brightness,
+      &contrast,
+      &saturation);
+  TORCH_CHECK(ret != -1, "sws_getColorspaceDetails returned -1");
+
+  const int* colorspaceTable = sws_getCoefficients(colorspace);
+  ret = sws_setColorspaceDetails(
+      swsContext,
+      colorspaceTable,
+      srcRange,
+      colorspaceTable,
+      dstRange,
+      brightness,
+      contrast,
+      saturation);
+  TORCH_CHECK(ret != -1, "sws_setColorspaceDetails returned -1");
+
+  streamInfo.swsContext.reset(swsContext);
+}
+
+int VideoDecoder::convertFrameToTensorUsingSwsScale(
     int streamIndex,
     const AVFrame* frame,
     torch::Tensor& outputTensor) {
@@ -1361,50 +1420,6 @@ int VideoDecoder::convertFrameToBufferUsingSwsScale(
   int expectedOutputHeight = outputTensor.sizes()[0];
   int expectedOutputWidth = outputTensor.sizes()[1];
-  auto curFrameSwsContextKey =
-      SwsContextKey{
-          frame->width,
-          frame->height,
-          frameFormat,
-          expectedOutputWidth,
-          expectedOutputHeight};
-  if (activeStream.swsContext.get() == nullptr ||
-      activeStream.swsContextKey != curFrameSwsContextKey) {
-    SwsContext* swsContext = sws_getContext(
-        frame->width,
-        frame->height,
-        frameFormat,
-        expectedOutputWidth,
-        expectedOutputHeight,
-        AV_PIX_FMT_RGB24,
-        SWS_BILINEAR,
-        nullptr,
-        nullptr,
-        nullptr);
-    int* invTable = nullptr;
-    int* table = nullptr;
-    int srcRange, dstRange, brightness, contrast, saturation;
-    sws_getColorspaceDetails(
-        swsContext,
-        &invTable,
-        &srcRange,
-        &table,
-        &dstRange,
-        &brightness,
-        &contrast,
-        &saturation);
-    const int* colorspaceTable = sws_getCoefficients(frame->colorspace);
-    sws_setColorspaceDetails(
-        swsContext,
-        colorspaceTable,
-        srcRange,
-        colorspaceTable,
-        dstRange,
-        brightness,
-        contrast,
-        saturation);
-    activeStream.swsContextKey = curFrameSwsContextKey;
-    activeStream.swsContext.reset(swsContext);
-  }
   SwsContext* swsContext = activeStream.swsContext.get();
   uint8_t* pointers[4] = {
       outputTensor.data_ptr<uint8_t>(), nullptr, nullptr, nullptr};
@@ -1428,10 +1443,12 @@ torch::Tensor VideoDecoder::convertFrameToTensorUsingFilterGraph(
   if (ffmpegStatus < AVSUCCESS) {
     throw std::runtime_error("Failed to add frame to buffer source context");
   }
+
   UniqueAVFrame filteredFrame(av_frame_alloc());
   ffmpegStatus =
       av_buffersink_get_frame(filterState.sinkContext, filteredFrame.get());
   TORCH_CHECK_EQ(filteredFrame->format, AV_PIX_FMT_RGB24);
+
   auto frameDims = getHeightAndWidthFromResizedAVFrame(*filteredFrame.get());
   int height = frameDims.height;
   int width = frameDims.width;
@@ -1441,9 +1458,8 @@ torch::Tensor VideoDecoder::convertFrameToTensorUsingFilterGraph(
   auto deleter = [filteredFramePtr](void*) {
     UniqueAVFrame frameToDelete(filteredFramePtr);
   };
-  torch::Tensor tensor = torch::from_blob(
+  return torch::from_blob(
       filteredFramePtr->data[0], shape, strides, deleter, {torch::kUInt8});
-  return tensor;
 }
 
 VideoDecoder::~VideoDecoder() {
diff --git a/src/torchcodec/decoders/_core/VideoDecoder.h b/src/torchcodec/decoders/_core/VideoDecoder.h
index 6ad2ab5e..0893d9a6 100644
--- a/src/torchcodec/decoders/_core/VideoDecoder.h
+++ b/src/torchcodec/decoders/_core/VideoDecoder.h
@@ -313,14 +313,14 @@ class VideoDecoder {
     AVFilterContext* sourceContext = nullptr;
     AVFilterContext* sinkContext = nullptr;
   };
-  struct SwsContextKey {
+  struct DecodedFrameContext {
     int decodedWidth;
     int decodedHeight;
     AVPixelFormat decodedFormat;
-    int outputWidth;
-    int outputHeight;
-    bool operator==(const SwsContextKey&);
-    bool operator!=(const SwsContextKey&);
+    int expectedWidth;
+    int expectedHeight;
+    bool operator==(const DecodedFrameContext&);
+    bool operator!=(const DecodedFrameContext&);
   };
   // Stores information for each stream.
   struct StreamInfo {
@@ -342,7 +342,7 @@ class VideoDecoder {
     ColorConversionLibrary colorConversionLibrary = FILTERGRAPH;
     std::vector<FrameInfo> keyFrames;
     std::vector<FrameInfo> allFrames;
-    SwsContextKey swsContextKey;
+    DecodedFrameContext prevFrame;
     UniqueSwsContext swsContext;
   };
   // Returns the key frame index of the presentation timestamp using FFMPEG's
@@ -371,10 +371,14 @@ class VideoDecoder {
   void validateFrameIndex(const StreamInfo& stream, int64_t frameIndex);
   // Creates and initializes a filter graph for a stream. The filter graph can
   // do rescaling and color conversion.
-  void initializeFilterGraph(
+  void createFilterGraph(
       StreamInfo& streamInfo,
       int expectedOutputHeight,
       int expectedOutputWidth);
+  void createSwsContext(
+      StreamInfo& streamInfo,
+      const DecodedFrameContext& frameContext,
+      const enum AVColorSpace colorspace);
   void maybeSeekToBeforeDesiredPts();
   RawDecodedOutput getDecodedOutputWithFilter(
       std::function<bool(int, AVFrame*)>);
@@ -389,7 +393,7 @@ class VideoDecoder {
   torch::Tensor convertFrameToTensorUsingFilterGraph(
       int streamIndex,
       const AVFrame* frame);
-  int convertFrameToBufferUsingSwsScale(
+  int convertFrameToTensorUsingSwsScale(
       int streamIndex,
       const AVFrame* frame,
       torch::Tensor& outputTensor);
diff --git a/version.txt b/version.txt
index c730b7f4..d917d3e2 100644
--- a/version.txt
+++ b/version.txt
@@ -1 +1 @@
-0.0.4a0
+0.1.2

From e8c548c2b991b606d53f660d48eb1583cb54b4e5 Mon Sep 17 00:00:00 2001
From: Scott Schneider
Date: Mon, 16 Dec 2024 13:44:15 -0800
Subject: [PATCH 2/2] Let's retain the "alpha" tag on the main repo.

---
 version.txt | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/version.txt b/version.txt
index d917d3e2..db5a1c4b 100644
--- a/version.txt
+++ b/version.txt
@@ -1 +1 @@
-0.1.2
+0.1.2a0
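
Below is a minimal, standalone sketch of the caching policy this series introduces, for experimenting with it outside the decoder. It assumes only FFmpeg's swscale headers. DecodedFrameContext mirrors the struct added to VideoDecoder.h, but everything else is an illustrative stand-in rather than the patch's code: the helper ensureSwsContext, the raw-pointer cache, and the fake frame loop are hypothetical, and the patch itself splits this logic between convertAVFrameToDecodedOutputOnCPU() and createSwsContext(), owns the context through UniqueSwsContext, and additionally transfers colorspace coefficients with sws_setColorspaceDetails().

// sketch.cpp -- standalone illustration of the frame-context caching idea
// from PATCH 1/2. Build (assuming pkg-config knows your FFmpeg install):
//   g++ sketch.cpp $(pkg-config --cflags --libs libswscale) -o sketch
extern "C" {
#include <libswscale/swscale.h>
}
#include <cstdio>

// Mirrors the DecodedFrameContext struct added to VideoDecoder.h.
struct DecodedFrameContext {
  int decodedWidth;
  int decodedHeight;
  AVPixelFormat decodedFormat;
  int expectedWidth;
  int expectedHeight;

  bool operator==(const DecodedFrameContext& other) const {
    return decodedWidth == other.decodedWidth &&
        decodedHeight == other.decodedHeight &&
        decodedFormat == other.decodedFormat &&
        expectedWidth == other.expectedWidth &&
        expectedHeight == other.expectedHeight;
  }
  bool operator!=(const DecodedFrameContext& other) const {
    return !(*this == other);
  }
};

// Rebuilds the SwsContext only when the decoded or requested frame properties
// differ from the previously seen ones; otherwise the cached context is
// reused. Returns true when a rebuild happened. sws_getContext() is too
// expensive to call once per frame, which is the point of the cache. Unlike
// the patch, this manages the context manually rather than via RAII and
// skips the colorspace-details transfer.
bool ensureSwsContext(
    SwsContext*& cached,
    DecodedFrameContext& prevFrame,
    const DecodedFrameContext& frameContext) {
  if (cached != nullptr && prevFrame == frameContext) {
    return false; // cache hit: nothing to do
  }
  sws_freeContext(cached); // accepts nullptr
  cached = sws_getContext(
      frameContext.decodedWidth,
      frameContext.decodedHeight,
      frameContext.decodedFormat,
      frameContext.expectedWidth,
      frameContext.expectedHeight,
      AV_PIX_FMT_RGB24,
      SWS_BILINEAR,
      nullptr,
      nullptr,
      nullptr);
  prevFrame = frameContext;
  return true;
}

int main() {
  SwsContext* cached = nullptr;
  DecodedFrameContext prevFrame{};

  // Simulate a stream whose resolution changes mid-stream: the context is
  // created for frame 0, reused for frame 1, and rebuilt for frame 2.
  const DecodedFrameContext frames[] = {
      {640, 480, AV_PIX_FMT_YUV420P, 640, 480},
      {640, 480, AV_PIX_FMT_YUV420P, 640, 480},
      {1280, 720, AV_PIX_FMT_YUV420P, 640, 480},
  };
  for (const auto& f : frames) {
    std::printf(
        "%s\n", ensureSwsContext(cached, prevFrame, f) ? "rebuilt" : "reused");
  }
  sws_freeContext(cached);
  return 0;
}

Keying the cache on the decoded width, height, and format as well as the requested output size is what lets the swscale and filtergraph branches share the single prevFrame field in StreamInfo.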