Generalize logic to recreate colorspace conversion objects #436

Merged (2 commits) on Dec 16, 2024
150 changes: 83 additions & 67 deletions src/torchcodec/decoders/_core/VideoDecoder.cpp
@@ -204,15 +204,16 @@ VideoDecoder::BatchDecodedOutput::BatchDecodedOutput(
frames = allocateEmptyHWCTensor(height, width, options.device, numFrames);
}

bool VideoDecoder::SwsContextKey::operator==(
const VideoDecoder::SwsContextKey& other) {
bool VideoDecoder::DecodedFrameContext::operator==(
const VideoDecoder::DecodedFrameContext& other) {
return decodedWidth == other.decodedWidth && decodedHeight == decodedHeight &&
decodedFormat == other.decodedFormat &&
outputWidth == other.outputWidth && outputHeight == other.outputHeight;
expectedWidth == other.expectedWidth &&
expectedHeight == other.expectedHeight;
}

bool VideoDecoder::SwsContextKey::operator!=(
const VideoDecoder::SwsContextKey& other) {
bool VideoDecoder::DecodedFrameContext::operator!=(
const VideoDecoder::DecodedFrameContext& other) {
return !(*this == other);
}

@@ -313,17 +314,14 @@ std::unique_ptr<VideoDecoder> VideoDecoder::createFromBuffer(
return std::unique_ptr<VideoDecoder>(new VideoDecoder(buffer, length));
}

void VideoDecoder::initializeFilterGraph(
void VideoDecoder::createFilterGraph(
StreamInfo& streamInfo,
int expectedOutputHeight,
int expectedOutputWidth) {
FilterState& filterState = streamInfo.filterState;
if (filterState.filterGraph) {
return;
}

filterState.filterGraph.reset(avfilter_graph_alloc());
TORCH_CHECK(filterState.filterGraph.get() != nullptr);

if (streamInfo.options.ffmpegThreadCount.has_value()) {
filterState.filterGraph->nb_threads =
streamInfo.options.ffmpegThreadCount.value();
@@ -921,12 +919,32 @@ void VideoDecoder::convertAVFrameToDecodedOutputOnCPU(

torch::Tensor outputTensor;
if (output.streamType == AVMEDIA_TYPE_VIDEO) {
// We need to compare the current frame context with our previous frame
// context. If they are different, then we need to re-create our colorspace
// conversion objects. We create our colorspace conversion objects late so
// that we don't have to depend on the unreliable metadata in the header.
// And we sometimes re-create them because it's possible for frame
// resolution to change mid-stream. Finally, we want to reuse the colorspace
// conversion objects as much as possible for performance reasons.
enum AVPixelFormat frameFormat =
static_cast<enum AVPixelFormat>(frame->format);
auto frameContext = DecodedFrameContext{
frame->width,
frame->height,
frameFormat,
expectedOutputWidth,
expectedOutputHeight};

if (streamInfo.colorConversionLibrary == ColorConversionLibrary::SWSCALE) {
outputTensor = preAllocatedOutputTensor.value_or(allocateEmptyHWCTensor(
expectedOutputHeight, expectedOutputWidth, torch::kCPU));

if (!streamInfo.swsContext || streamInfo.prevFrame != frameContext) {
createSwsContext(streamInfo, frameContext, frame->colorspace);
streamInfo.prevFrame = frameContext;
}
int resultHeight =
convertFrameToBufferUsingSwsScale(streamIndex, frame, outputTensor);
convertFrameToTensorUsingSwsScale(streamIndex, frame, outputTensor);
// If this check failed, it would mean that the frame wasn't reshaped to
// the expected height.
// TODO: Can we do the same check for width?
@@ -941,16 +959,11 @@
} else if (
streamInfo.colorConversionLibrary ==
ColorConversionLibrary::FILTERGRAPH) {
// Note that is a lazy init; we initialize filtergraph the first time
// we have a raw decoded frame. We do this lazily because up until this
// point, we really don't know what the resolution of the frames are
// without modification. In theory, we should be able to get that from the
// stream metadata, but in practice, we have encountered videos where the
// stream metadata had a different resolution from the actual resolution
// of the raw decoded frames.
if (!streamInfo.filterState.filterGraph) {
initializeFilterGraph(
if (!streamInfo.filterState.filterGraph ||
streamInfo.prevFrame != frameContext) {
createFilterGraph(
streamInfo, expectedOutputHeight, expectedOutputWidth);
streamInfo.prevFrame = frameContext;
}
outputTensor = convertFrameToTensorUsingFilterGraph(streamIndex, frame);

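The comment added in the hunk above summarizes the pattern this PR generalizes: describe each decoded frame (and the requested output size) in a small context struct, and rebuild the colorspace conversion object only when that context differs from the one the existing object was built for. Below is a minimal, illustrative sketch of that recreate-on-change idiom; FrameContext, Converter, ConversionState, and ensureConverter are hypothetical names used only for illustration, not torchcodec APIs.

// Illustrative sketch only; none of these names are torchcodec APIs.
#include <memory>

struct FrameContext {
  int width = 0;
  int height = 0;
  int format = -1;  // e.g. an AVPixelFormat value
  bool operator==(const FrameContext& other) const {
    return width == other.width && height == other.height &&
        format == other.format;
  }
  bool operator!=(const FrameContext& other) const {
    return !(*this == other);
  }
};

// Stand-in for a real conversion object (SwsContext, filtergraph, ...).
struct Converter {
  FrameContext builtFor;
};

struct ConversionState {
  FrameContext prevFrame;  // context the current converter was built from
  std::unique_ptr<Converter> converter;
};

void ensureConverter(ConversionState& state, const FrameContext& current) {
  // Rebuild only when the decoded frame's properties (or the requested
  // output) differ from what the existing converter was built for.
  if (!state.converter || state.prevFrame != current) {
    state.converter = std::make_unique<Converter>();
    state.converter->builtFor = current;
    state.prevFrame = current;
  }
}

The diff applies this same guard twice, keyed on streamInfo.prevFrame: once before createSwsContext and once before createFilterGraph.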
@@ -1351,7 +1364,53 @@ double VideoDecoder::getPtsSecondsForFrame(
return ptsToSeconds(stream.allFrames[frameIndex].pts, stream.timeBase);
}

int VideoDecoder::convertFrameToBufferUsingSwsScale(
void VideoDecoder::createSwsContext(
StreamInfo& streamInfo,
const DecodedFrameContext& frameContext,
const enum AVColorSpace colorspace) {
SwsContext* swsContext = sws_getContext(
frameContext.decodedWidth,
frameContext.decodedHeight,
frameContext.decodedFormat,
frameContext.expectedWidth,
frameContext.expectedHeight,
AV_PIX_FMT_RGB24,
SWS_BILINEAR,
nullptr,
nullptr,
nullptr);
TORCH_CHECK(swsContext, "sws_getContext() returned nullptr");

int* invTable = nullptr;
int* table = nullptr;
int srcRange, dstRange, brightness, contrast, saturation;
int ret = sws_getColorspaceDetails(
swsContext,
&invTable,
&srcRange,
&table,
&dstRange,
&brightness,
&contrast,
&saturation);
TORCH_CHECK(ret != -1, "sws_getColorspaceDetails returned -1");

const int* colorspaceTable = sws_getCoefficients(colorspace);
ret = sws_setColorspaceDetails(
swsContext,
colorspaceTable,
srcRange,
colorspaceTable,
dstRange,
brightness,
contrast,
saturation);
TORCH_CHECK(ret != -1, "sws_setColorspaceDetails returned -1");

streamInfo.swsContext.reset(swsContext);
}

int VideoDecoder::convertFrameToTensorUsingSwsScale(
int streamIndex,
const AVFrame* frame,
torch::Tensor& outputTensor) {
@@ -1361,50 +1420,6 @@ int VideoDecoder::convertFrameToBufferUsingSwsScale(

int expectedOutputHeight = outputTensor.sizes()[0];
int expectedOutputWidth = outputTensor.sizes()[1];
auto curFrameSwsContextKey = SwsContextKey{
frame->width,
frame->height,
frameFormat,
expectedOutputWidth,
expectedOutputHeight};
if (activeStream.swsContext.get() == nullptr ||
activeStream.swsContextKey != curFrameSwsContextKey) {
SwsContext* swsContext = sws_getContext(
frame->width,
frame->height,
frameFormat,
expectedOutputWidth,
expectedOutputHeight,
AV_PIX_FMT_RGB24,
SWS_BILINEAR,
nullptr,
nullptr,
nullptr);
int* invTable = nullptr;
int* table = nullptr;
int srcRange, dstRange, brightness, contrast, saturation;
sws_getColorspaceDetails(
swsContext,
&invTable,
&srcRange,
&table,
&dstRange,
&brightness,
&contrast,
&saturation);
const int* colorspaceTable = sws_getCoefficients(frame->colorspace);
sws_setColorspaceDetails(
swsContext,
colorspaceTable,
srcRange,
colorspaceTable,
dstRange,
brightness,
contrast,
saturation);
activeStream.swsContextKey = curFrameSwsContextKey;
activeStream.swsContext.reset(swsContext);
}
SwsContext* swsContext = activeStream.swsContext.get();
uint8_t* pointers[4] = {
outputTensor.data_ptr<uint8_t>(), nullptr, nullptr, nullptr};
@@ -1428,10 +1443,12 @@ torch::Tensor VideoDecoder::convertFrameToTensorUsingFilterGraph(
if (ffmpegStatus < AVSUCCESS) {
throw std::runtime_error("Failed to add frame to buffer source context");
}

UniqueAVFrame filteredFrame(av_frame_alloc());
ffmpegStatus =
av_buffersink_get_frame(filterState.sinkContext, filteredFrame.get());
TORCH_CHECK_EQ(filteredFrame->format, AV_PIX_FMT_RGB24);

auto frameDims = getHeightAndWidthFromResizedAVFrame(*filteredFrame.get());
int height = frameDims.height;
int width = frameDims.width;
@@ -1441,9 +1458,8 @@
auto deleter = [filteredFramePtr](void*) {
UniqueAVFrame frameToDelete(filteredFramePtr);
};
torch::Tensor tensor = torch::from_blob(
return torch::from_blob(
filteredFramePtr->data[0], shape, strides, deleter, {torch::kUInt8});
return tensor;
}

VideoDecoder::~VideoDecoder() {
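convertFrameToTensorUsingFilterGraph, shown near the end of the diff above, hands the filtered frame's buffer to PyTorch without copying and keeps the AVFrame alive through a custom deleter passed to torch::from_blob. The following is a self-contained sketch of that wrapping technique, assuming a packed RGB24 frame; wrapRGB24FrameAsTensor is a hypothetical helper that takes ownership of the frame, not a torchcodec function.

// Illustrative sketch: zero-copy wrapping of an RGB24 AVFrame in a tensor.
// wrapRGB24FrameAsTensor is hypothetical and takes ownership of the frame.
#include <torch/torch.h>
#include <vector>
extern "C" {
#include <libavutil/frame.h>
}

torch::Tensor wrapRGB24FrameAsTensor(AVFrame* frame) {
  // HWC layout; the row stride comes from linesize, which may include padding.
  std::vector<int64_t> shape = {frame->height, frame->width, 3};
  std::vector<int64_t> strides = {frame->linesize[0], 3, 1};
  // The deleter frees the frame only when the tensor's storage is released,
  // so the tensor can safely alias the frame's buffer.
  auto deleter = [frame](void*) {
    AVFrame* toFree = frame;
    av_frame_free(&toFree);
  };
  return torch::from_blob(
      frame->data[0], shape, strides, deleter, {torch::kUInt8});
}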
20 changes: 12 additions & 8 deletions src/torchcodec/decoders/_core/VideoDecoder.h
@@ -313,14 +313,14 @@ class VideoDecoder {
AVFilterContext* sourceContext = nullptr;
AVFilterContext* sinkContext = nullptr;
};
struct SwsContextKey {
struct DecodedFrameContext {
int decodedWidth;
int decodedHeight;
AVPixelFormat decodedFormat;
int outputWidth;
int outputHeight;
bool operator==(const SwsContextKey&);
bool operator!=(const SwsContextKey&);
int expectedWidth;
int expectedHeight;
bool operator==(const DecodedFrameContext&);
bool operator!=(const DecodedFrameContext&);
};
// Stores information for each stream.
struct StreamInfo {
@@ -342,7 +342,7 @@
ColorConversionLibrary colorConversionLibrary = FILTERGRAPH;
std::vector<FrameInfo> keyFrames;
std::vector<FrameInfo> allFrames;
SwsContextKey swsContextKey;
DecodedFrameContext prevFrame;
UniqueSwsContext swsContext;
};
// Returns the key frame index of the presentation timestamp using FFMPEG's
@@ -371,10 +371,14 @@
void validateFrameIndex(const StreamInfo& stream, int64_t frameIndex);
// Creates and initializes a filter graph for a stream. The filter graph can
// do rescaling and color conversion.
void initializeFilterGraph(
void createFilterGraph(
StreamInfo& streamInfo,
int expectedOutputHeight,
int expectedOutputWidth);
void createSwsContext(
StreamInfo& streamInfo,
const DecodedFrameContext& frameContext,
const enum AVColorSpace colorspace);
void maybeSeekToBeforeDesiredPts();
RawDecodedOutput getDecodedOutputWithFilter(
std::function<bool(int, AVFrame*)>);
Expand All @@ -389,7 +393,7 @@ class VideoDecoder {
torch::Tensor convertFrameToTensorUsingFilterGraph(
int streamIndex,
const AVFrame* frame);
int convertFrameToBufferUsingSwsScale(
int convertFrameToTensorUsingSwsScale(
int streamIndex,
const AVFrame* frame,
torch::Tensor& outputTensor);
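The header now declares createSwsContext alongside the renamed convertFrameToTensorUsingSwsScale. As a rough sketch of how the two fit together at the FFmpeg level once the context exists, the call below scales and converts a decoded frame into a caller-provided packed RGB24 buffer. convertToRGB24 is a hypothetical helper, not part of this PR, and it assumes dst is sized for the output resolution the SwsContext was created with.

// Illustrative sketch: the core swscale call behind the tensor conversion.
extern "C" {
#include <libswscale/swscale.h>
#include <libavutil/frame.h>
}
#include <cstdint>

int convertToRGB24(
    SwsContext* swsContext,
    const AVFrame* frame,
    uint8_t* dst,
    int dstWidth) {
  uint8_t* dstPointers[4] = {dst, nullptr, nullptr, nullptr};
  int dstLinesizes[4] = {dstWidth * 3, 0, 0, 0};  // packed RGB24 row stride
  // Returns the height of the converted output; callers can compare it
  // against the expected output height to detect an unexpected reshape.
  return sws_scale(
      swsContext,
      frame->data,
      frame->linesize,
      0,
      frame->height,
      dstPointers,
      dstLinesizes);
}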
2 changes: 1 addition & 1 deletion version.txt
@@ -1 +1 @@
0.0.4a0
0.1.2a0