diff --git a/inference.py b/inference.py
index a116754..c47e4bb 100644
--- a/inference.py
+++ b/inference.py
@@ -20,6 +20,8 @@
 from tqdm.auto import tqdm
 from inference_utils import VideoReader, VideoWriter, ImageSequenceReader, ImageSequenceWriter
 
+from inference_utils import ImageReader, ConstantImage
+
 
 def convert_video(model,
                   input_source: str,
@@ -27,6 +29,7 @@ def convert_video(model,
                   downsample_ratio: Optional[float] = None,
                   output_type: str = 'video',
                   output_composition: Optional[str] = None,
+                  bgr_source: Optional[str] = None,
                   output_alpha: Optional[str] = None,
                   output_foreground: Optional[str] = None,
                   output_video_mbps: Optional[float] = None,
@@ -46,6 +49,8 @@ def convert_video(model,
             The composition output path. File path if output_type == 'video'. Directory path if output_type == 'png_sequence'.
             If output_type == 'video', the composition has green screen background.
             If output_type == 'png_sequence'. the composition is RGBA png images.
+        bgr_source: The background source. A video file, an image sequence directory, or a single image.
+            Only applicable when output_type == 'video'. If omitted, the default green-screen background is used.
         output_alpha: The alpha output from the model.
         output_foreground: The foreground output from the model.
         seq_chunk: Number of frames to process at once. Increase it for better parallelism.
@@ -110,16 +115,24 @@ def convert_video(model,
     param = next(model.parameters())
     dtype = param.dtype
     device = param.device
-    
+
     if (output_composition is not None) and (output_type == 'video'):
-        bgr = torch.tensor([120, 255, 155], device=device, dtype=dtype).div(255).view(1, 1, 3, 1, 1)
-    
+        if bgr_source is not None:
+            if os.path.isfile(bgr_source):
+                if os.path.splitext(bgr_source)[-1].lower() in [".png", ".jpg", ".jpeg"]:
+                    bgr_raw = ImageReader(bgr_source, transform=transform)
+                else:
+                    bgr_raw = VideoReader(bgr_source, transform)
+            else:
+                bgr_raw = ImageSequenceReader(bgr_source, transform)
+        else:
+            bgr_raw = ConstantImage(120, 255, 155, device=device, dtype=dtype)
+
     try:
         with torch.no_grad():
             bar = tqdm(total=len(source), disable=not progress, dynamic_ncols=True)
             rec = [None] * 4
-            for src in reader:
-                
+            for index, src in enumerate(reader):
                 if downsample_ratio is None:
                     downsample_ratio = auto_downsample_ratio(*src.shape[2:])
 
@@ -132,6 +145,7 @@ def convert_video(model,
                     writer_pha.write(pha[0])
                 if output_composition is not None:
                     if output_type == 'video':
+                        bgr = bgr_raw[min(index, len(bgr_raw) - 1)].to(device, dtype, non_blocking=True).unsqueeze(0) # broadcasts against [B, T, C, H, W]
                         com = fgr * pha + bgr * (1 - pha)
                     else:
                         fgr = fgr * pha.gt(0)
diff --git a/inference_utils.py b/inference_utils.py
index 1e92fd4..68227a3 100644
--- a/inference_utils.py
+++ b/inference_utils.py
@@ -2,8 +2,9 @@
 import os
 import pims
 import numpy as np
+import torch
 from torch.utils.data import Dataset
-from torchvision.transforms.functional import to_pil_image
+from torchvision.transforms.functional import to_pil_image, pil_to_tensor
 from PIL import Image
 
 
@@ -34,7 +35,7 @@ def __init__(self, path, frame_rate, bit_rate=1000000):
         self.stream = self.container.add_stream('h264', rate=round(frame_rate))
         self.stream.pix_fmt = 'yuv420p'
         self.stream.bit_rate = bit_rate
-    
+
     def write(self, frames):
         # frames: [T, C, H, W]
         self.stream.width = frames.size(3)
@@ -46,12 +47,39 @@ def write(self, frames):
             frame = frames[t]
             frame = av.VideoFrame.from_ndarray(frame, format='rgb24')
             self.container.mux(self.stream.encode(frame))
-    
+
     def close(self):
         self.container.mux(self.stream.encode())
         self.container.close()
+
+
+class ImageReader(Dataset):
+    def __init__(self, path, transform=None):
+        self.path = path
+        self.transform = transform
+
+    def __len__(self):
+        return 1
+
+    def __getitem__(self, idx):
+        with Image.open(self.path) as img:
+            img.load()
+        if self.transform is not None:
+            return self.transform(img)
+        return img
+
+
+class ConstantImage(Dataset):
+    def __init__(self, r, g, b, device=None, dtype=None):
+        self.tensor = torch.tensor([r, g, b], device=device, dtype=dtype).div(255).view(1, 3, 1, 1)
+
+    def __len__(self):
+        return 1
+
+    def __getitem__(self, idx):
+        return self.tensor
 
 
 class ImageSequenceReader(Dataset):
     def __init__(self, path, transform=None):
         self.path = path
@@ -85,4 +113,3 @@ def write(self, frames):
 
     def close(self):
         pass
-    
\ No newline at end of file