Skip to content

Commit

Permalink
Rework experimental video reader to avoid duplication + add enable_timestamps
Browse files Browse the repository at this point in the history

Signed-off-by: Joaquin Anton Guirao <[email protected]>
  • Loading branch information
jantonguirao committed Mar 5, 2025
1 parent 1ee7b46 commit 08da1f2
Show file tree
Hide file tree
Showing 16 changed files with 441 additions and 732 deletions.
2 changes: 1 addition & 1 deletion dali/operators/video/decoder/video_decoder_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ class DLL_PUBLIC VideoDecoderBase : public Operator<Backend> {
if constexpr (std::is_same_v<Backend, CPUBackend>) {
return std::make_unique<FramesDecoderImpl>(data, size, build_index, -1, source_info);
} else {
return std::make_unique<FramesDecoderImpl>(data, size, stream, build_index, -1, source_info);
return std::make_unique<FramesDecoderImpl>(data, size, build_index, -1, source_info, stream);
}
}

Expand Down
10 changes: 10 additions & 0 deletions dali/operators/video/frames_decoder_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,15 @@ class DLL_PUBLIC FramesDecoderBase {
*/
bool HasIndex() const { return !index_.empty(); }


/**
* @brief Check if a codec is supported by the particular implementation.
*
* @param codec_id Codec ID to check.
* @return True if the codec is supported, false otherwise.
*/
virtual bool CanDecode(AVCodecID codec_id) const = 0;

virtual ~FramesDecoderBase() = default;
FramesDecoderBase(FramesDecoderBase&&) = default;
FramesDecoderBase& operator=(FramesDecoderBase&&) = default;
Expand All @@ -231,6 +240,7 @@ class DLL_PUBLIC FramesDecoderBase {
}

/**
 * @brief Returns the index entry stored for the given frame.
 *
 * Only meaningful when an index is present (see HasIndex()). Both preconditions
 * below are checked with `assert`, so they are enforced in debug builds only.
 *
 * @param frame_id Position of the frame in the index; must be within [0, index size).
 * @return Reference to the entry held in `index_`.
 */
const IndexEntry& Index(int frame_id) const {
  assert(HasIndex());
  // A non-empty index does not imply that frame_id is in range; check that separately.
  assert(frame_id >= 0 && static_cast<size_t>(frame_id) < index_.size());
  return index_[frame_id];
}

Expand Down
2 changes: 1 addition & 1 deletion dali/operators/video/frames_decoder_cpu.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ class DLL_PUBLIC FramesDecoderCpu : public FramesDecoderBase {
* @param codec_id Codec ID to check.
* @return True if the codec is supported, false otherwise.
*/
bool CanDecode(AVCodecID codec_id) const;
bool CanDecode(AVCodecID codec_id) const override;
};

} // namespace dali
Expand Down
16 changes: 8 additions & 8 deletions dali/operators/video/frames_decoder_gpu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -417,10 +417,11 @@ void FramesDecoderGpu::InitGpuParser() {
}
}

FramesDecoderGpu::FramesDecoderGpu(const std::string &filename, cudaStream_t stream) :
FramesDecoderBase(filename, true, false),
frame_buffer_(num_decode_surfaces_),
stream_(stream) {
FramesDecoderGpu::FramesDecoderGpu(const std::string &filename, bool build_index,
cudaStream_t stream)
: FramesDecoderBase(filename, build_index, false),
frame_buffer_(num_decode_surfaces_),
stream_(stream) {
if (is_valid_ && CanDecode(codec_params_->codec_id)) {
InitGpuParser();
} else {
Expand All @@ -429,10 +430,9 @@ FramesDecoderGpu::FramesDecoderGpu(const std::string &filename, cudaStream_t str
}

FramesDecoderGpu::FramesDecoderGpu(const char *memory_file, size_t memory_file_size,
cudaStream_t stream, bool build_index, int num_frames,
std::string_view source_info)
: FramesDecoderBase(memory_file, memory_file_size, build_index, false, num_frames,
source_info),
bool build_index, int num_frames, std::string_view source_info,
cudaStream_t stream)
: FramesDecoderBase(memory_file, memory_file_size, build_index, false, num_frames, source_info),
frame_buffer_(num_decode_surfaces_),
stream_(stream) {
if (is_valid_ && CanDecode(codec_params_->codec_id)) {
Expand Down
14 changes: 8 additions & 6 deletions dali/operators/video/frames_decoder_gpu.h
Original file line number Diff line number Diff line change
Expand Up @@ -130,9 +130,10 @@ class DLL_PUBLIC FramesDecoderGpu : public FramesDecoderBase {
* @brief Construct a new FramesDecoder object.
*
* @param filename Path to a video file.
* @param stream Stream used for decode processing.
* @param build_index If set to false, the index will not be built and some features are unavailable.
* @param stream CUDA stream to use for decoding.
*/
explicit FramesDecoderGpu(const std::string &filename, cudaStream_t stream = 0);
explicit FramesDecoderGpu(const std::string &filename, bool build_index = true, cudaStream_t stream = 0);

/**
* @brief Construct a new FramesDecoder object.
Expand All @@ -141,12 +142,13 @@ class DLL_PUBLIC FramesDecoderGpu : public FramesDecoderBase {
* @param memory_file_size Size of memory_file in bytes.
* @param build_index If set to false, the index will not be built and some features are unavailable.
* @param num_frames If set, number of frames in the video.
* @param stream CUDA stream to use for decoding.
*
* @note This constructor assumes that the `memory_file` and
* `memory_file_size` arguments cover the entire video file, including the header.
*/
FramesDecoderGpu(const char *memory_file, size_t memory_file_size, cudaStream_t stream = 0,
bool build_index = true, int num_frames = -1, std::string_view = {});
FramesDecoderGpu(const char *memory_file, size_t memory_file_size, bool build_index = true,
int num_frames = -1, std::string_view = {}, cudaStream_t stream = 0);

bool ReadNextFrame(uint8_t *data) override;

Expand Down Expand Up @@ -174,7 +176,7 @@ class DLL_PUBLIC FramesDecoderGpu : public FramesDecoderBase {
* @param codec_id Codec ID to check.
* @return True if the codec is supported, false otherwise.
*/
bool CanDecode(AVCodecID codec_id) const;
bool CanDecode(AVCodecID codec_id) const override;

private:
std::unique_ptr<NvDecodeState> nvdecode_state_;
Expand All @@ -197,7 +199,7 @@ class DLL_PUBLIC FramesDecoderGpu : public FramesDecoderBase {

std::queue<int> piped_pts_;

cudaStream_t stream_;
cudaStream_t stream_ = 0;

void SendLastPacket(bool flush = false);

Expand Down
20 changes: 10 additions & 10 deletions dali/operators/video/frames_decoder_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -305,7 +305,7 @@ TEST_F(FramesDecoderTest_CpuOnlyTests, NoIndexSeek) {

TEST_F(FramesDecoderGpuTest, VariableFrameRateNoIndex) {
auto memory_video = MemoryVideo(vfr_videos_paths_[0]);
FramesDecoderGpu decoder(memory_video.data(), memory_video.size(), 0, false);
FramesDecoderGpu decoder(memory_video.data(), memory_video.size(), false);
RunTest(decoder, vfr_videos_[0], false);
}

Expand All @@ -314,57 +314,57 @@ TEST_F(FramesDecoderGpuTest, VariableFrameRateHevcNoIndex) {
GTEST_SKIP();
}
auto memory_video = MemoryVideo(vfr_hevc_videos_paths_[1]);
FramesDecoderGpu decoder(memory_video.data(), memory_video.size(), 0, false);
FramesDecoderGpu decoder(memory_video.data(), memory_video.size(), false);
RunTest(decoder, vfr_hevc_videos_[1], false);
}

TEST_F(FramesDecoderGpuTest, CfrFrameRateMpeg4NoIndex) {
auto memory_video = MemoryVideo(cfr_mpeg4_videos_paths_[0]);
FramesDecoderGpu decoder(memory_video.data(), memory_video.size(), 0, false);
FramesDecoderGpu decoder(memory_video.data(), memory_video.size(), false);
RunTest(decoder, cfr_videos_[0], false, 3.0);
}

TEST_F(FramesDecoderGpuTest, VfrFrameRateMpeg4NoIndex) {
auto memory_video = MemoryVideo(vfr_mpeg4_videos_paths_[0]);
FramesDecoderGpu decoder(memory_video.data(), memory_video.size(), 0, false);
FramesDecoderGpu decoder(memory_video.data(), memory_video.size(), false);
RunTest(decoder, vfr_videos_[0], false, 3.0);
}

TEST_F(FramesDecoderGpuTest, CfrFrameRateMpeg4MkvNoIndex) {
auto memory_video = MemoryVideo(cfr_mpeg4_mkv_videos_paths_[0]);
FramesDecoderGpu decoder(
memory_video.data(), memory_video.size(), 0, false, cfr_videos_[0].NumFrames());
memory_video.data(), memory_video.size(), false, cfr_videos_[0].NumFrames());
RunTest(decoder, cfr_videos_[0], false, 3.0);
}

TEST_F(FramesDecoderGpuTest, CfrFrameRateMpeg4MkvNoIndexNoFrameNum) {
auto memory_video = MemoryVideo(cfr_mpeg4_mkv_videos_paths_[0]);
FramesDecoderGpu decoder(memory_video.data(), memory_video.size(), 0, false);
FramesDecoderGpu decoder(memory_video.data(), memory_video.size(), false);
RunTest(decoder, cfr_videos_[0], false, 3.0);
}

TEST_F(FramesDecoderGpuTest, VfrFrameRateMpeg4MkvNoIndex) {
auto memory_video = MemoryVideo(vfr_mpeg4_mkv_videos_paths_[1]);
FramesDecoderGpu decoder(
memory_video.data(), memory_video.size(), 0, false, vfr_videos_[1].NumFrames());
memory_video.data(), memory_video.size(), false, vfr_videos_[1].NumFrames());
RunTest(decoder, vfr_videos_[1], false, 3.0);
}

TEST_F(FramesDecoderGpuTest, VfrFrameRateMpeg4MkvNoIndexNoFrameNum) {
auto memory_video = MemoryVideo(vfr_mpeg4_mkv_videos_paths_[1]);
FramesDecoderGpu decoder(memory_video.data(), memory_video.size(), 0, false);
FramesDecoderGpu decoder(memory_video.data(), memory_video.size(), false);
RunTest(decoder, vfr_videos_[1], false, 3.0);
}

TEST_F(FramesDecoderGpuTest, RawH264) {
auto memory_video = MemoryVideo(cfr_raw_h264_videos_paths_[1]);
FramesDecoderGpu decoder(memory_video.data(), memory_video.size(), 0, false);
FramesDecoderGpu decoder(memory_video.data(), memory_video.size(), false);
RunTest(decoder, cfr_videos_[1], false, 1.5);
}

TEST_F(FramesDecoderGpuTest, RawH265) {
auto memory_video = MemoryVideo(cfr_raw_h264_videos_paths_[0]);
FramesDecoderGpu decoder(memory_video.data(), memory_video.size(), 0, false);
FramesDecoderGpu decoder(memory_video.data(), memory_video.size(), false);
RunTest(decoder, cfr_videos_[0], false, 1.5);
}

Expand Down
4 changes: 2 additions & 2 deletions dali/operators/video/input/video_input.h
Original file line number Diff line number Diff line change
Expand Up @@ -216,9 +216,9 @@ class VideoInput : public InputOperator<Backend> {
std::unique_ptr<FramesDecoderImpl> CreateDecoder(const char *data, size_t size,
cudaStream_t stream = 0) {
if constexpr (std::is_same_v<Backend, CPUBackend>) {
return std::make_unique<FramesDecoderImpl>(data, size, false, -1);
return std::make_unique<FramesDecoderImpl>(data, size, false, -1, std::string_view{});
} else {
return std::make_unique<FramesDecoderImpl>(data, size, stream, false, -1);
return std::make_unique<FramesDecoderImpl>(data, size, false, -1, std::string_view{}, stream);
}
}

Expand Down
156 changes: 156 additions & 0 deletions dali/operators/video/reader/video_loader_decoder.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
// Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef DALI_OPERATORS_READER_LOADER_VIDEO_VIDEO_LOADER_DECODER_BASE_H_
#define DALI_OPERATORS_READER_LOADER_VIDEO_VIDEO_LOADER_DECODER_BASE_H_


#include <algorithm>
#include <memory>
#include <random>
#include <string>
#include <utility>
#include <vector>
#include "dali/core/cuda_stream_pool.h"
#include "dali/operators/video/frames_decoder_base.h"
#include "dali/operators/video/frames_decoder_cpu.h"
#include "dali/operators/video/frames_decoder_gpu.h"

namespace dali {

/**
 * @brief Lightweight description of one frame sequence taken from a video file.
 *
 * Plain value type. Every field defaults to an "unset" state: an empty filename
 * and -1 for all integer fields.
 */
struct VideoSampleDesc {
  /**
   * @param filename Path of the video file the sequence comes from.
   * @param label    Label associated with the file, or -1 when there is none.
   * @param start    First frame of the sequence.
   * @param end      End of the frame range. VideoLoaderDecoder fills it with
   *                 `start + stride * sequence_length`, so it is presumably exclusive
   *                 - NOTE(review): confirm against the consumer.
   * @param stride   Distance between consecutive frames of the sequence.
   */
  VideoSampleDesc(std::string filename = {}, int label = -1, int start = -1, int end = -1,
                  int stride = -1)
      // `filename` is a by-value sink parameter: move it instead of copying a second time.
      : filename_(std::move(filename)), label_(label), start_(start), end_(end), stride_(stride) {}

  std::string filename_;
  int label_ = -1;
  int start_ = -1;
  int end_ = -1;
  int stride_ = -1;
};

/**
 * @brief A video sample: the sequence description plus the data loaded for it.
 *
 * Extends VideoSampleDesc with the fields that are filled in later by Prefetch.
 *
 * @tparam Backend Backend type used to instantiate the `Tensor` holding the data.
 */
template <typename Backend>
struct VideoSample : public VideoSampleDesc {
  /**
   * @brief Builds a sample from individual description fields.
   *
   * Parameters have the same meaning and defaults as in VideoSampleDesc.
   */
  VideoSample(std::string filename = {}, int label = -1, int start = -1, int end = -1,
              int stride = -1)
      // `filename` is a by-value sink parameter: hand it over instead of copying it again.
      : VideoSampleDesc{std::move(filename), label, start, end, stride} {}

  /**
   * @brief Builds a sample from an existing description; the data fields stay unset.
   *
   * Deliberately not `noexcept`: copying the description copies a `std::string`,
   * which may allocate and throw. Under `noexcept` that failure would call
   * `std::terminate` instead of propagating to the caller.
   */
  VideoSample(const VideoSampleDesc &other)
      : VideoSampleDesc(other) {}

  // to be filled by Prefetch
  Tensor<Backend> data_;
  int64_t start_timestamp_ = -1;
};

/**
 * @brief Loader that enumerates fixed-length frame sequences over a list of video files.
 *
 * During metadata preparation every file is opened with `FramesDecoderImpl` to learn
 * its frame count, and one VideoSampleDesc is recorded per sequence that fits in the
 * file. ReadSample() then hands these descriptions out one by one; the frame data
 * itself is not decoded here.
 *
 * @tparam Backend           Backend passed to the `Loader` base and to the sample type.
 * @tparam FramesDecoderImpl Decoder type; must be constructible from
 *                           `(filename, build_index)` and provide `IsValid()` and `NumFrames()`.
 * @tparam Sample            Sample type produced by the loader; must be constructible
 *                           from a VideoSampleDesc.
 */
template <typename Backend, typename FramesDecoderImpl, typename Sample = VideoSample<Backend>>
class VideoLoaderDecoder : public Loader<Backend, Sample, true> {
 public:
  /**
   * @brief Reads the loader configuration from the operator spec.
   *
   * Arguments read: `filenames`, `sequence_length`, `stride`, `step` and, optionally,
   * `labels` (which must then have one entry per file). A non-positive `step` is
   * replaced with `stride * sequence_length`.
   *
   * `sequence_length` and `stride` must be positive. Otherwise the derived step would
   * be non-positive as well and the sample enumeration in PrepareMetadataImpl() would
   * never terminate.
   */
  explicit inline VideoLoaderDecoder(const OpSpec &spec):
    Loader<Backend, Sample, true>(spec),
    filenames_(spec.GetRepeatedArgument<std::string>("filenames")),
    sequence_len_(spec.GetArgument<int>("sequence_length")),
    stride_(spec.GetArgument<int>("stride")),
    step_(spec.GetArgument<int>("step")) {
    has_labels_ = spec.TryGetRepeatedArgument(labels_, "labels");
    DALI_ENFORCE(
        !has_labels_ || labels_.size() == filenames_.size(),
        make_string(
            "Number of provided files and labels should match. Provided ",
            filenames_.size(), " files and ", labels_.size(), " labels."));
    DALI_ENFORCE(
        sequence_len_ > 0,
        make_string("`sequence_length` must be a positive number. Got: ", sequence_len_, "."));
    DALI_ENFORCE(
        stride_ > 0,
        make_string("`stride` must be a positive number. Got: ", stride_, "."));
    if (step_ <= 0) {
      step_ = stride_ * sequence_len_;
    }
  }

  /// Resets `sample` to a default-constructed Sample.
  void PrepareEmpty(Sample &sample) {
    sample = Sample();
  }

  /// Copies the description at the current position into `sample` and advances the position.
  void ReadSample(Sample &sample) override {
    sample = Sample(samples_[current_index_]);
    MoveToNextShard(++current_index_);
  }

  /// Advances the current position without producing a sample.
  void Skip() override {
    MoveToNextShard(++current_index_);
  }

  /// Number of sequences enumerated by PrepareMetadataImpl().
  Index SizeImpl() override {
    return samples_.size();
  }

  /**
   * @brief Opens every file, enumerates the sequences it contains and stores them in `samples_`.
   *
   * Files for which the decoder reports `!IsValid()` are logged and skipped.
   * When shuffling is enabled, the sample list is shuffled with a fixed seed.
   */
  void PrepareMetadataImpl() override {
    // Frame range reserved for one sequence; invariant across files, so compute it once.
    // NOTE(review): the last frame actually read would be start + (sequence_len_ - 1) * stride_,
    // so requiring the whole stride_ * sequence_len_ range to fit is the stricter condition -
    // confirm this is intended.
    const int sequence_span = stride_ * sequence_len_;

    for (size_t i = 0; i < filenames_.size(); ++i) {
      const auto &filename = filenames_[i];
      int label = has_labels_ ? labels_[i] : -1;
      // The decoder is scoped to this iteration, so it is released before the next file
      // is opened and before the sample list is shuffled.
      auto decoder = std::make_unique<FramesDecoderImpl>(filename, true);
      if (!decoder->IsValid()) {
        LOG_LINE << "Invalid video file: " << filename << std::endl;
        continue;
      }
      int64_t num_frames = decoder->NumFrames();
      for (int start = 0; start + sequence_span <= num_frames; start += step_) {
        LOG_LINE << "Sample #" << samples_.size() << ": " << filename << " " << label << " "
                 << start << ".." << start + sequence_span << std::endl;
        samples_.emplace_back(filename, label, start, start + sequence_span, stride_);
      }
    }

    if (shuffle_) {
      // seeded with hardcoded value to get
      // the same sequence on every shard
      std::mt19937 g(kDaliDataloaderSeed);
      std::shuffle(std::begin(samples_), std::end(samples_), g);
    }

    // set the initial index for each shard
    Reset(true);
  }

  /**
   * @brief Repositions the loader.
   *
   * @param wrap_to_shard If true, moves to the start index computed for this loader's
   *                      virtual shard; otherwise moves to sample 0.
   */
  void Reset(bool wrap_to_shard) override {
    current_index_ = wrap_to_shard ? start_index(virtual_shard_id_, num_shards_, SizeImpl()) : 0;
  }

 protected:
  using Base = Loader<Backend, Sample, true>;
  using Base::shard_id_;
  using Base::virtual_shard_id_;
  using Base::num_shards_;
  using Base::stick_to_shard_;
  using Base::shuffle_;
  using Base::dont_use_mmap_;
  using Base::initial_buffer_fill_;
  using Base::copy_read_data_;
  using Base::read_ahead_;
  using Base::IsCheckpointingEnabled;
  using Base::PrepareEmptyTensor;
  using Base::MoveToNextShard;
  using Base::ShouldSkipImage;

  // Input files and, when has_labels_ is set, one label per file.
  std::vector<std::string> filenames_;
  std::vector<int> labels_;
  bool has_labels_ = false;

  // Position in samples_ of the next sample to hand out.
  Index current_index_ = 0;

  int sequence_len_;  // number of frames in each sequence
  int stride_;        // distance between consecutive frames of a sequence
  int step_;          // distance between the starts of consecutive sequences

  // One entry per sequence found by PrepareMetadataImpl().
  std::vector<VideoSampleDesc> samples_;
  // Not used by this class itself; kept available to derived loaders.
  CUDAStreamLease cuda_stream_;
};

} // namespace dali

#endif // DALI_OPERATORS_READER_LOADER_VIDEO_VIDEO_LOADER_DECODER_BASE_H_
Loading

0 comments on commit 08da1f2

Please sign in to comment.