Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support writing the audio stream back into video #83

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 55 additions & 24 deletions inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
--seq-chunk 1
"""

import av
import torch
import os
from torch.utils.data import DataLoader
Expand All @@ -20,6 +21,8 @@
from tqdm.auto import tqdm

from inference_utils import VideoReader, VideoWriter, ImageSequenceReader, ImageSequenceWriter
from inference_utils import AudioVideoWriter


def convert_video(model,
input_source: str,
Expand All @@ -33,6 +36,7 @@ def convert_video(model,
seq_chunk: int = 1,
num_workers: int = 0,
progress: bool = True,
passthrough_audio: bool = True,
device: Optional[str] = None,
dtype: Optional[torch.dtype] = None):

Expand All @@ -51,10 +55,11 @@ def convert_video(model,
seq_chunk: Number of frames to process at once. Increase it for better parallelism.
num_workers: PyTorch's DataLoader workers. Only use >0 for image input.
progress: Show progress bar.
passthrough_audio: Should we passthrough any audio from the input video
device: Only need to manually provide if model is a TorchScript freezed model.
dtype: Only need to manually provide if model is a TorchScript freezed model.
"""

assert downsample_ratio is None or (downsample_ratio > 0 and downsample_ratio <= 1), 'Downsample ratio must be between 0 (exclusive) and 1 (inclusive).'
assert any([output_composition, output_alpha, output_foreground]), 'Must provide at least one output.'
assert output_type in ['video', 'png_sequence'], 'Only support "video" and "png_sequence" output modes.'
Expand All @@ -76,26 +81,52 @@ def convert_video(model,
else:
source = ImageSequenceReader(input_source, transform)
reader = DataLoader(source, batch_size=seq_chunk, pin_memory=True, num_workers=num_workers)


audio_source = None
if os.path.isfile(input_source):
container = av.open(input_source)
if container.streams.get(audio=0):
audio_source = container.streams.get(audio=0)[0]

# Initialize writers
if output_type == 'video':
frame_rate = source.frame_rate if isinstance(source, VideoReader) else 30
output_video_mbps = 1 if output_video_mbps is None else output_video_mbps
if output_composition is not None:
writer_com = VideoWriter(
path=output_composition,
frame_rate=frame_rate,
bit_rate=int(output_video_mbps * 1000000))
if output_alpha is not None:
writer_pha = VideoWriter(
path=output_alpha,
frame_rate=frame_rate,
bit_rate=int(output_video_mbps * 1000000))
if output_foreground is not None:
writer_fgr = VideoWriter(
path=output_foreground,
frame_rate=frame_rate,
bit_rate=int(output_video_mbps * 1000000))
if passthrough_audio and audio_source:
if output_composition is not None:
writer_com = AudioVideoWriter(
path=output_composition,
frame_rate=frame_rate,
audio_stream=audio_source,
bit_rate=int(output_video_mbps * 1000000))
if output_alpha is not None:
writer_pha = AudioVideoWriter(
path=output_alpha,
frame_rate=frame_rate,
audio_stream=audio_source,
bit_rate=int(output_video_mbps * 1000000))
if output_foreground is not None:
writer_fgr = AudioVideoWriter(
path=output_foreground,
frame_rate=frame_rate,
audio_stream=audio_source,
bit_rate=int(output_video_mbps * 1000000))
else:
if output_composition is not None:
writer_com = VideoWriter(
path=output_composition,
frame_rate=frame_rate,
bit_rate=int(output_video_mbps * 1000000))
if output_alpha is not None:
writer_pha = VideoWriter(
path=output_alpha,
frame_rate=frame_rate,
bit_rate=int(output_video_mbps * 1000000))
if output_foreground is not None:
writer_fgr = VideoWriter(
path=output_foreground,
frame_rate=frame_rate,
bit_rate=int(output_video_mbps * 1000000))
else:
if output_composition is not None:
writer_com = ImageSequenceWriter(output_composition, 'png')
Expand All @@ -113,7 +144,7 @@ def convert_video(model,

if (output_composition is not None) and (output_type == 'video'):
bgr = torch.tensor([120, 255, 155], device=device, dtype=dtype).div(255).view(1, 1, 3, 1, 1)

try:
with torch.no_grad():
bar = tqdm(total=len(source), disable=not progress, dynamic_ncols=True)
Expand All @@ -137,7 +168,7 @@ def convert_video(model,
fgr = fgr * pha.gt(0)
com = torch.cat([fgr, pha], dim=-3)
writer_com.write(com[0])

bar.update(src.size(1))

finally:
Expand Down Expand Up @@ -167,11 +198,12 @@ def __init__(self, variant: str, checkpoint: str, device: str):

def convert(self, *args, **kwargs):
    """Forward to convert_video, supplying this converter's model, device, and dtype."""
    convert_video(self.model, *args, device=self.device, dtype=torch.float32, **kwargs)



if __name__ == '__main__':
import argparse
from model import MattingNetwork

parser = argparse.ArgumentParser()
parser.add_argument('--variant', type=str, required=True, choices=['mobilenetv3', 'resnet50'])
parser.add_argument('--checkpoint', type=str, required=True)
Expand All @@ -188,7 +220,7 @@ def convert(self, *args, **kwargs):
parser.add_argument('--num-workers', type=int, default=0)
parser.add_argument('--disable-progress', action='store_true')
args = parser.parse_args()

converter = Converter(args.variant, args.checkpoint, args.device)
converter.convert(
input_source=args.input_source,
Expand All @@ -203,5 +235,4 @@ def convert(self, *args, **kwargs):
num_workers=args.num_workers,
progress=not args.disable_progress
)



42 changes: 34 additions & 8 deletions inference_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,14 @@ def __init__(self, path, transform=None):
self.video = pims.PyAVVideoReader(path)
self.rate = self.video.frame_rate
self.transform = transform

@property
def frame_rate(self):
    """Frame rate (fps) of the underlying video, as reported by the PyAV reader."""
    return self.rate

def __len__(self):
    """Number of frames in the video."""
    return len(self.video)

def __getitem__(self, idx):
frame = self.video[idx]
frame = Image.fromarray(np.asarray(frame))
Expand Down Expand Up @@ -57,10 +57,10 @@ def __init__(self, path, transform=None):
self.path = path
self.files = sorted(os.listdir(path))
self.transform = transform

def __len__(self):
    """Number of image files found in the input directory."""
    return len(self.files)

def __getitem__(self, idx):
with Image.open(os.path.join(self.path, self.files[idx])) as img:
img.load()
Expand All @@ -75,14 +75,40 @@ def __init__(self, path, extension='jpg'):
self.extension = extension
self.counter = 0
os.makedirs(path, exist_ok=True)

def write(self, frames):
    """Save each frame of a [T, C, H, W] tensor batch as a zero-padded, numbered image file."""
    for frame in frames:
        filename = f'{self.counter:04d}.{self.extension}'
        to_pil_image(frame).save(os.path.join(self.path, filename))
        self.counter += 1

def close(self):
    """No-op: each image is fully written inside write(), so there is nothing to flush."""
    pass



class AudioVideoWriter(VideoWriter):
    """A VideoWriter that additionally copies ("remuxes") the audio stream of a
    source container into the output file, without re-encoding the audio.

    Args:
        path: Output file path (passed through to VideoWriter).
        frame_rate: Output video frame rate (passed through to VideoWriter).
        audio_stream: The input container's av audio stream whose packets are
            copied into the output. Required; the None default is kept only for
            signature compatibility and is rejected explicitly.
        bit_rate: Video bit rate in bits/s (passed through to VideoWriter).
    """

    def __init__(self, path, frame_rate, audio_stream=None, bit_rate=1000000):
        super().__init__(path=path, frame_rate=frame_rate, bit_rate=bit_rate)
        if audio_stream is None:
            raise ValueError('AudioVideoWriter requires an input audio stream.')
        self.source_audio_stream = audio_stream
        # Create the output stream from the source stream as a template so that
        # all codec parameters (time base, channel layout, extradata, ...) are
        # copied. Building a fresh encoder from just codec name + rate drops the
        # codec extradata, which breaks stream-copied codecs such as AAC.
        self.output_audio_stream = self.container.add_stream(
            template=self.source_audio_stream)

    def remux_audio(self):
        """Copy every audio packet of the source stream into the output container."""
        input_audio_container = self.source_audio_stream.container
        for packet in input_audio_container.demux(self.source_audio_stream):
            # Skip the demuxer's flush packets, which carry no data.
            if packet.dts is None:
                continue
            # Retag the packet so it is muxed into our output audio stream.
            packet.stream = self.output_audio_stream
            self.container.mux(packet)

    def close(self):
        """Remux the audio, then finalize the container via VideoWriter.close()."""
        self.remux_audio()
        # Note: no encoder flush here. The audio packets are stream-copied, so
        # there is no audio encoder with buffered frames to drain; flushing a
        # never-used encoder can raise or emit spurious packets.
        super().close()