Allow users to provide a custom video background #84

Open · wants to merge 2 commits into master
Changes from all commits
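In short: this PR adds a bgr_source argument to convert_video so the composition background can come from a video file, an image sequence directory, or a single still image instead of the hard-coded green screen. Two new reader classes in inference_utils.py, ImageReader and ConstantImage, back the feature; ConstantImage keeps the old green background as the default.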
24 changes: 19 additions & 5 deletions inference.py
@@ -20,13 +20,16 @@
 from tqdm.auto import tqdm
 
 from inference_utils import VideoReader, VideoWriter, ImageSequenceReader, ImageSequenceWriter
+from inference_utils import ImageReader, ConstantImage


 def convert_video(model,
                   input_source: str,
                   input_resize: Optional[Tuple[int, int]] = None,
                   downsample_ratio: Optional[float] = None,
                   output_type: str = 'video',
                   output_composition: Optional[str] = None,
+                  bgr_source: Optional[str] = None,
                   output_alpha: Optional[str] = None,
                   output_foreground: Optional[str] = None,
                   output_video_mbps: Optional[float] = None,
@@ -46,6 +49,8 @@ def convert_video(model,
             The composition output path. File path if output_type == 'video'. Directory path if output_type == 'png_sequence'.
             If output_type == 'video', the composition has a green screen background.
             If output_type == 'png_sequence', the composition is RGBA png images.
+        bgr_source: A video file, an image sequence directory, or an individual image to use as
+            the composition background. Only applicable if output_type == 'video'.
         output_alpha: The alpha output from the model.
         output_foreground: The foreground output from the model.
         seq_chunk: Number of frames to process at once. Increase it for better parallelism.
@@ -110,16 +115,24 @@ def convert_video(model,
     param = next(model.parameters())
     dtype = param.dtype
     device = param.device
 
     if (output_composition is not None) and (output_type == 'video'):
-        bgr = torch.tensor([120, 255, 155], device=device, dtype=dtype).div(255).view(1, 1, 3, 1, 1)
+        if bgr_source is not None:
+            if os.path.isfile(bgr_source):
+                if os.path.splitext(bgr_source)[-1].lower() in [".png", ".jpg"]:
+                    bgr_raw = ImageReader(bgr_source, transform=transform)
+                else:
+                    bgr_raw = VideoReader(bgr_source, transform)
+            else:
+                bgr_raw = ImageSequenceReader(bgr_source, transform)
+        else:
+            bgr_raw = ConstantImage(120, 255, 155, device=device, dtype=dtype)

     try:
         with torch.no_grad():
             bar = tqdm(total=len(source), disable=not progress, dynamic_ncols=True)
             rec = [None] * 4
-            for src in reader:
+            for index, src in enumerate(reader):
                 if downsample_ratio is None:
                     downsample_ratio = auto_downsample_ratio(*src.shape[2:])

@@ -132,6 +145,7 @@ def convert_video(model,
                     writer_pha.write(pha[0])
                 if output_composition is not None:
                     if output_type == 'video':
+                        bgr = bgr_raw[index].to(device, dtype, non_blocking=True).unsqueeze(0)  # broadcasts to [B, T, C, H, W]
                         com = fgr * pha + bgr * (1 - pha)
                     else:
                         fgr = fgr * pha.gt(0)
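For context, a minimal usage sketch of the new argument. Model loading follows the upstream RobustVideoMatting README; the checkpoint and media file names below are placeholders:

import torch
from model import MattingNetwork
from inference import convert_video

model = MattingNetwork('mobilenetv3').eval().cuda()
model.load_state_dict(torch.load('rvm_mobilenetv3.pth'))

convert_video(
    model,
    input_source='input.mp4',             # clip to matte
    output_type='video',                  # bgr_source only applies to video output
    output_composition='composited.mp4',  # composited result
    bgr_source='beach.mp4',               # or 'background.png', or a directory of frames
    seq_chunk=12,                         # frames per forward pass
)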
35 changes: 31 additions & 4 deletions inference_utils.py
@@ -2,8 +2,9 @@
 import os
 import pims
 import numpy as np
+import torch
 from torch.utils.data import Dataset
-from torchvision.transforms.functional import to_pil_image
+from torchvision.transforms.functional import to_pil_image, pil_to_tensor
 from PIL import Image


@@ -34,7 +35,7 @@ def __init__(self, path, frame_rate, bit_rate=1000000):
         self.stream = self.container.add_stream('h264', rate=round(frame_rate))
         self.stream.pix_fmt = 'yuv420p'
         self.stream.bit_rate = bit_rate
 
     def write(self, frames):
         # frames: [T, C, H, W]
         self.stream.width = frames.size(3)
@@ -46,12 +47,39 @@ def write(self, frames):
             frame = frames[t]
             frame = av.VideoFrame.from_ndarray(frame, format='rgb24')
             self.container.mux(self.stream.encode(frame))
 
     def close(self):
         self.container.mux(self.stream.encode())
         self.container.close()
 
 
+class ImageReader(Dataset):
+    def __init__(self, path, transform=None):
+        self.path = path
+        self.transform = transform
+
+    def __len__(self):
+        return 1
+
+    def __getitem__(self, idx):
+        # A single still image: idx is ignored, so every chunk gets the same frame.
+        with Image.open(self.path) as img:
+            img.load()
+        if self.transform is not None:
+            return self.transform(img)
+        return img
+
+
+class ConstantImage(Dataset):
+    def __init__(self, r, g, b, device=None, dtype=None):
+        # Shaped [T, C, H, W]; the call site unsqueezes to [B, T, C, H, W].
+        self.tensor = torch.tensor([r, g, b], device=device, dtype=dtype).div(255).view(1, 3, 1, 1)
+
+    def __len__(self):
+        return 1
+
+    def __getitem__(self, idx):
+        return self.tensor


 class ImageSequenceReader(Dataset):
     def __init__(self, path, transform=None):
         self.path = path
@@ -85,4 +113,3 @@ def write(self, frames):

     def close(self):
         pass
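And a quick sanity check of what the new readers hand back, assuming torchvision's ToTensor as the transform (the same transform convert_video uses); 'background.png' is a placeholder:

from torchvision import transforms
from inference_utils import ImageReader, ConstantImage

transform = transforms.ToTensor()

still = ImageReader('background.png', transform=transform)
print(still[0].shape)   # [C, H, W]; unsqueeze(0) at the call site adds the batch dim

green = ConstantImage(120, 255, 155)
print(green[0].shape)   # torch.Size([1, 3, 1, 1]); broadcasts against [B, T, C, H, W]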