From 0d0ac222afc34322cb7a5f22be861c674e6b7ec4 Mon Sep 17 00:00:00 2001
From: h9419
Date: Wed, 6 Apr 2022 21:47:42 +0800
Subject: [PATCH 1/8] Improve video inference performance by reducing GPU
 memory copies and offloading CPU video encoding to child processes

Two improvements are made in this fork:
1. Removed the repeated copying of the background image to GPU memory on
   every batch.
2. Minimized idle GPU time by handing video encoding work to child processes
   as soon as each batch is copied back to CPU memory, allowing for higher
   GPU utilization. A sketch of the pattern follows.
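To illustrate the second point, here is a minimal, self-contained sketch of
the pipe-and-worker pattern this patch adopts (generic stand-in code, not the
patch itself; the real worker encodes frames with cv2.VideoWriter):

    # The parent sends CPU batches through a Pipe; a dedicated child process
    # consumes them, so encoding no longer blocks the GPU loop.
    from multiprocessing import Process, Pipe

    def writer_worker(conn):
        while True:
            item = conn.recv()
            if item is None:            # sentinel: no more batches
                break
            print('encoding a batch of', len(item), 'frames')  # stand-in for cv2 encoding

    if __name__ == '__main__':          # required on Windows (see patch 8)
        parent_conn, child_conn = Pipe()
        worker = Process(target=writer_worker, args=(child_conn,))
        worker.start()
        child_conn.close()              # parent keeps only its own end
        for batch in ([0] * 4, [1] * 4):
            parent_conn.send(batch)     # hand the batch to the encoder process
        parent_conn.send(None)
        worker.join()                   # reap the child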
---
 inference_video.py | 54 +++++++++++++++++++++++++++++++---------------
 1 file changed, 37 insertions(+), 17 deletions(-)

diff --git a/inference_video.py b/inference_video.py
index 1fa8cd4..5e9bfd6 100644
--- a/inference_video.py
+++ b/inference_video.py
@@ -30,11 +30,12 @@
 from torch.utils.data import DataLoader
 from torchvision import transforms as T
 from torchvision.transforms.functional import to_pil_image
+from multiprocessing import Process, Pipe
 from threading import Thread
 from tqdm import tqdm
 from PIL import Image
 
-from dataset import VideoDataset, ZipDataset
+from dataset import VideoDataset
 from dataset import augmentation as A
 from model import MattingBase, MattingRefine
 from inference_utils import HomographicAlignment
@@ -79,15 +80,26 @@
 
 class VideoWriter:
     def __init__(self, path, frame_rate, width, height):
-        self.out = cv2.VideoWriter(path, cv2.VideoWriter_fourcc(*'mp4v'), frame_rate, (width, height))
+        output_p, input_p = Pipe()
+        Process(target=self.VideoWriterWorker, args=(path, frame_rate, width, height, (output_p, input_p))).start()
+        output_p.close()
+        self.input_p = input_p
 
     def add_batch(self, frames):
-        frames = frames.mul(255).byte()
-        frames = frames.cpu().permute(0, 2, 3, 1).numpy()
-        for i in range(frames.shape[0]):
-            frame = frames[i]
-            frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
-            self.out.write(frame)
+        frames = frames.mul(255).byte().permute(0, 2, 3, 1)
+        self.input_p.send(frames.cpu())
+
+    @staticmethod
+    def VideoWriterWorker(path, frame_rate, width, height, pipe):
+        output_p, input_p = pipe
+        input_p.close()
+        out = cv2.VideoWriter(path, cv2.VideoWriter_fourcc(*'mp4v'), frame_rate, (width, height))
+        while True:
+            frames = output_p.recv().numpy()
+            for i in range(frames.shape[0]):
+                frame = frames[i]
+                frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
+                out.write(frame)
 
 
 class ImageSequenceWriter:
@@ -132,12 +144,16 @@ def _add_batch(self, frames, index):
 
 # Load video and background
 vid = VideoDataset(args.video_src)
-bgr = [Image.open(args.video_bgr).convert('RGB')]
-dataset = ZipDataset([vid, bgr], transforms=A.PairCompose([
-    A.PairApply(T.Resize(args.video_resize[::-1]) if args.video_resize else nn.Identity()),
-    HomographicAlignment() if args.preprocess_alignment else A.PairApply(nn.Identity()),
-    A.PairApply(T.ToTensor())
-]))
+bgr = Image.open(args.video_bgr).convert('RGB')
+
+transforms = T.Compose([
+    T.Resize(args.video_resize[::-1]) if args.video_resize else nn.Identity(),
+    T.ToTensor()
+])
+
+bgr = transforms(bgr)
+dataset = VideoDataset(args.video_src, transforms=transforms)
+
 if args.video_target_bgr:
     dataset = ZipDataset([dataset, VideoDataset(args.video_target_bgr, transforms=T.ToTensor())])
@@ -179,15 +195,17 @@ def _add_batch(self, frames, index):
 
 # Conversion loop
 with torch.no_grad():
+    # move background to device
+    bgr = (bgr[None]).to(device, non_blocking=False)
     for input_batch in tqdm(DataLoader(dataset, batch_size=1, pin_memory=True)):
         if args.video_target_bgr:
-            (src, bgr), tgt_bgr = input_batch
+            src, tgt_bgr = input_batch
             tgt_bgr = tgt_bgr.to(device, non_blocking=True)
         else:
-            src, bgr = input_batch
+            src = input_batch
             tgt_bgr = torch.tensor([120/255, 255/255, 155/255], device=device).view(1, 3, 1, 1)
+        # move frame to device
         src = src.to(device, non_blocking=True)
-        bgr = bgr.to(device, non_blocking=True)
 
         if args.model_type == 'mattingbase':
             pha, fgr, err, _ = model(src, bgr)
@@ -213,3 +231,5 @@ def _add_batch(self, frames, index):
         err_writer.add_batch(F.interpolate(err, src.shape[2:], mode='bilinear', align_corners=False))
     if 'ref' in args.output_types:
         ref_writer.add_batch(F.interpolate(ref, src.shape[2:], mode='nearest'))
+    # exit parent and all children processes
+    exit(0)

From 7a716a5ba635f836e085ab8b790592ef1cf1e289 Mon Sep 17 00:00:00 2001
From: h9419
Date: Wed, 6 Apr 2022 22:54:24 +0800
Subject: [PATCH 2/8] Added termination and reaping of child processes

---
 inference_video.py | 30 ++++++++++++++++++++++++++----
 1 file changed, 26 insertions(+), 4 deletions(-)

diff --git a/inference_video.py b/inference_video.py
index 5e9bfd6..6d6caeb 100644
--- a/inference_video.py
+++ b/inference_video.py
@@ -81,7 +81,8 @@ class VideoWriter:
     def __init__(self, path, frame_rate, width, height):
         output_p, input_p = Pipe()
-        Process(target=self.VideoWriterWorker, args=(path, frame_rate, width, height, (output_p, input_p))).start()
+        self.worker = Process(target=self.VideoWriterWorker, args=(path, frame_rate, width, height, (output_p, input_p)))
+        self.worker.start()
         output_p.close()
         self.input_p = input_p
 
@@ -89,13 +90,22 @@ def add_batch(self, frames):
         frames = frames.mul(255).byte().permute(0, 2, 3, 1)
         self.input_p.send(frames.cpu())
 
+    def close(self):
+        self.input_p.send(0)
+        self.worker.join()
+
     @staticmethod
     def VideoWriterWorker(path, frame_rate, width, height, pipe):
         output_p, input_p = pipe
         input_p.close()
         out = cv2.VideoWriter(path, cv2.VideoWriter_fourcc(*'mp4v'), frame_rate, (width, height))
         while True:
-            frames = output_p.recv().numpy()
+            read_buffer = output_p.recv()
+            # gracefully exit with provided exit code if it is an integer
+            if type(read_buffer) == type(int):
+                out.release()
+                exit(read_buffer)
+            frames = read_buffer.numpy()
             for i in range(frames.shape[0]):
                 frame = frames[i]
                 frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
                 out.write(frame)
@@ -231,5 +241,17 @@ def _add_batch(self, frames, index):
         err_writer.add_batch(F.interpolate(err, src.shape[2:], mode='bilinear', align_corners=False))
     if 'ref' in args.output_types:
         ref_writer.add_batch(F.interpolate(ref, src.shape[2:], mode='nearest'))
-    # exit parent and all children processes
-    exit(0)
+    # terminate children processes
+    if args.output_format == 'video':
+        h = args.video_resize[1] if args.video_resize is not None else vid.height
+        w = args.video_resize[0] if args.video_resize is not None else vid.width
+        if 'com' in args.output_types:
+            com_writer.close()
+        if 'pha' in args.output_types:
+            pha_writer.close()
+        if 'fgr' in args.output_types:
+            fgr_writer.close()
+        if 'err' in args.output_types:
+            err_writer.close()
+        if 'ref' in args.output_types:
+            ref_writer.close()

From a284cd6b1f55752494e853cc78783313a01711ea Mon Sep 17 00:00:00 2001
From: h9419 <58384315+h9419@users.noreply.github.com>
Date: Wed, 6 Apr 2022 23:37:31 +0800
Subject: [PATCH 3/8] Update inference_video.py

Break out of the loop instead of exiting directly, so the worker releases
the VideoWriter exactly once
---
 inference_video.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/inference_video.py b/inference_video.py
index 6d6caeb..36973c7 100644
--- a/inference_video.py
+++ b/inference_video.py
@@ -103,13 +103,13 @@ def VideoWriterWorker(path, frame_rate, width, height, pipe):
             read_buffer = output_p.recv()
             # gracefully exit with provided exit code if it is an integer
             if type(read_buffer) == type(int):
-                out.release()
-                exit(read_buffer)
+                break
             frames = read_buffer.numpy()
             for i in range(frames.shape[0]):
                 frame = frames[i]
                 frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
                 out.write(frame)
+        out.release()
 
 
 class ImageSequenceWriter:

From 156b3b77bacc9c4bfc6094b01fba8d51f151093a Mon Sep 17 00:00:00 2001
From: h9419 <58384315+h9419@users.noreply.github.com>
Date: Wed, 6 Apr 2022 23:49:09 +0800
Subject: [PATCH 4/8] Update inference_video.py

Fixed the sentinel check: replaced type(read_buffer) == type(int) with
type(read_buffer) == int
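A short illustration of the bug being fixed (a hypothetical REPL session, not
part of the patch): type(int) is `type` itself, int's metaclass, so the old
check could never match the integer sentinel, and the worker crashed on
read_buffer.numpy() instead of exiting cleanly.

    >>> type(int)             # int is itself an object of class `type`
    <class 'type'>
    >>> type(0) == type(int)  # the old check: never true for the sentinel 0
    False
    >>> type(0) == int        # the fixed check
    True
    >>> isinstance(0, int)    # the more idiomatic spelling
    True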
---
 inference_video.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/inference_video.py b/inference_video.py
index 36973c7..b520b4d 100644
--- a/inference_video.py
+++ b/inference_video.py
@@ -102,7 +102,7 @@ def VideoWriterWorker(path, frame_rate, width, height, pipe):
         while True:
             read_buffer = output_p.recv()
             # gracefully exit with provided exit code if it is an integer
-            if type(read_buffer) == type(int):
+            if type(read_buffer) == int:
                 break
             frames = read_buffer.numpy()
             for i in range(frames.shape[0]):

From 2851e277d60372abe39110aebe6a23df0f3fce63 Mon Sep 17 00:00:00 2001
From: h9419 <58384315+h9419@users.noreply.github.com>
Date: Thu, 7 Apr 2022 07:54:41 +0800
Subject: [PATCH 5/8] Update inference_video.py

Fixed imports: restored ZipDataset, which is still needed when
args.video_target_bgr is set
---
 inference_video.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/inference_video.py b/inference_video.py
index b520b4d..63c8241 100644
--- a/inference_video.py
+++ b/inference_video.py
@@ -35,7 +35,7 @@
 from tqdm import tqdm
 from PIL import Image
 
-from dataset import VideoDataset
+from dataset import VideoDataset, ZipDataset
 from dataset import augmentation as A
 from model import MattingBase, MattingRefine
 from inference_utils import HomographicAlignment

From dd2d3a44b5e41e49a3ee102e8b7ecb5948813730 Mon Sep 17 00:00:00 2001
From: h9419 <58384315+h9419@users.noreply.github.com>
Date: Thu, 7 Apr 2022 07:57:10 +0800
Subject: [PATCH 6/8] Removed leftover dead code

---
 inference_video.py | 2 --
 1 file changed, 2 deletions(-)

diff --git a/inference_video.py b/inference_video.py
index 63c8241..5f36805 100644
--- a/inference_video.py
+++ b/inference_video.py
@@ -243,8 +243,6 @@ def _add_batch(self, frames, index):
         ref_writer.add_batch(F.interpolate(ref, src.shape[2:], mode='nearest'))
     # terminate children processes
     if args.output_format == 'video':
-        h = args.video_resize[1] if args.video_resize is not None else vid.height
-        w = args.video_resize[0] if args.video_resize is not None else vid.width
         if 'com' in args.output_types:
             com_writer.close()
         if 'pha' in args.output_types:

From 810291038caaacbc7d4e369b0d7ea00ddda33984 Mon Sep 17 00:00:00 2001
From: h9419 <58384315+h9419@users.noreply.github.com>
Date: Thu, 7 Apr 2022 09:34:33 +0800
Subject: [PATCH 7/8] Remove CPU decoding bottleneck

I found that the CPU time spent in the DataLoader accounts for another
30-40% of the execution time. I added a thread that loads data and reserved
the main thread for driving the GPU.
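The shape of the producer/consumer split, as a minimal runnable sketch (the
decode_batches and run_gpu stand-ins are placeholders, not repo code):

    from queue import Queue
    from threading import Thread

    def decode_batches():                # stand-in for the CPU-bound DataLoader
        yield from (list(range(3)) for _ in range(5))

    def run_gpu(batch):                  # stand-in for model(src, bgr) on the GPU
        print('processed', batch)

    q = Queue(maxsize=1)                 # size 1: decode at most one batch ahead

    def load_worker():
        for batch in decode_batches():
            q.put(batch)                 # blocks while the consumer is behind
        q.put(None)                      # sentinel: end of stream

    loader = Thread(target=load_worker)
    loader.start()
    while (task := q.get()) is not None:
        run_gpu(task)                    # main thread stays free to drive the GPU
    loader.join()

Note that comparing the dequeued task with `is not None` (rather than the
patch's `task == None`) is the safer idiom, since `==` can be overloaded by
the queued objects.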
---
 inference_video.py | 27 ++++++++++++++++++++-------
 1 file changed, 20 insertions(+), 7 deletions(-)

diff --git a/inference_video.py b/inference_video.py
index 5f36805..9af63f8 100644
--- a/inference_video.py
+++ b/inference_video.py
@@ -31,6 +31,7 @@
 from torchvision import transforms as T
 from torchvision.transforms.functional import to_pil_image
 from multiprocessing import Process, Pipe
+from queue import Queue
 from threading import Thread
 from tqdm import tqdm
 from PIL import Image
@@ -205,17 +206,28 @@ def _add_batch(self, frames, index):
 
 # Conversion loop
 with torch.no_grad():
+    queue = Queue(1)
+    def load_worker():
+        tgt_bgr = torch.tensor([120/255, 255/255, 155/255]).view(1, 3, 1, 1)
+        for input_batch in tqdm(DataLoader(dataset, batch_size=1, pin_memory=True)):
+            if args.video_target_bgr:
+                src, tgt_bgr = input_batch
+            else:
+                src = input_batch
+            queue.put((src, tgt_bgr))
+        queue.put(None)
+    loader = Thread(target=load_worker)
+    loader.start()
     # move background to device
     bgr = (bgr[None]).to(device, non_blocking=False)
-    for input_batch in tqdm(DataLoader(dataset, batch_size=1, pin_memory=True)):
-        if args.video_target_bgr:
-            src, tgt_bgr = input_batch
-            tgt_bgr = tgt_bgr.to(device, non_blocking=True)
-        else:
-            src = input_batch
-            tgt_bgr = torch.tensor([120/255, 255/255, 155/255], device=device).view(1, 3, 1, 1)
+    while True:
+        task = queue.get()
+        if task == None:
+            break
+        src, tgt_bgr = task
         # move frame to device
         src = src.to(device, non_blocking=True)
+        tgt_bgr = tgt_bgr.to(device, non_blocking=True)
 
         if args.model_type == 'mattingbase':
             pha, fgr, err, _ = model(src, bgr)
@@ -242,6 +254,7 @@ def _add_batch(self, frames, index):
     if 'ref' in args.output_types:
         ref_writer.add_batch(F.interpolate(ref, src.shape[2:], mode='nearest'))
     # terminate children processes
+    loader.join()
     if args.output_format == 'video':
         if 'com' in args.output_types:
             com_writer.close()

From f6525104b17542e783e270747159f7b40ce6057f Mon Sep 17 00:00:00 2001
From: h9419 <58384315+h9419@users.noreply.github.com>
Date: Wed, 13 Apr 2022 13:15:43 +0800
Subject: [PATCH 8/8] Made multiprocessing also work on Windows

Added an if __name__ == '__main__' guard so that Windows, which starts child
processes by re-importing the main module (the spawn start method), can use
multiprocessing.Process
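A generic illustration of why the guard is required (not repo code): on
Windows, the spawn start method re-imports the main module inside each child,
so without the guard the child would re-run the module-level code, including
the Process(...) calls themselves.

    from multiprocessing import Process

    def work():
        print('hello from the child')

    if __name__ == '__main__':    # only the parent executes what follows
        p = Process(target=work)
        p.start()
        p.join()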
---
 inference_video.py | 224 ++++++++++++++++++++++-----------------------
 1 file changed, 112 insertions(+), 112 deletions(-)

diff --git a/inference_video.py b/inference_video.py
index 9af63f8..be07cb6 100644
--- a/inference_video.py
+++ b/inference_video.py
@@ -133,136 +133,136 @@ def _add_batch(self, frames, index):
 
 # --------------- Main ---------------
 
+if __name__ == '__main__':
+    device = torch.device(args.device)
-device = torch.device(args.device)
 
+    # Load model
+    if args.model_type == 'mattingbase':
+        model = MattingBase(args.model_backbone)
+    if args.model_type == 'mattingrefine':
+        model = MattingRefine(
+            args.model_backbone,
+            args.model_backbone_scale,
+            args.model_refine_mode,
+            args.model_refine_sample_pixels,
+            args.model_refine_threshold,
+            args.model_refine_kernel_size)
-# Load model
-if args.model_type == 'mattingbase':
-    model = MattingBase(args.model_backbone)
-if args.model_type == 'mattingrefine':
-    model = MattingRefine(
-        args.model_backbone,
-        args.model_backbone_scale,
-        args.model_refine_mode,
-        args.model_refine_sample_pixels,
-        args.model_refine_threshold,
-        args.model_refine_kernel_size)
 
+    model = model.to(device).eval()
+    model.load_state_dict(torch.load(args.model_checkpoint, map_location=device), strict=False)
-model = model.to(device).eval()
-model.load_state_dict(torch.load(args.model_checkpoint, map_location=device), strict=False)
 
+    # Load video and background
+    vid = VideoDataset(args.video_src)
+    bgr = Image.open(args.video_bgr).convert('RGB')
-# Load video and background
-vid = VideoDataset(args.video_src)
-bgr = Image.open(args.video_bgr).convert('RGB')
 
+    transforms = T.Compose([
+        T.Resize(args.video_resize[::-1]) if args.video_resize else nn.Identity(),
+        T.ToTensor()
+    ])
-transforms = T.Compose([
-    T.Resize(args.video_resize[::-1]) if args.video_resize else nn.Identity(),
-    T.ToTensor()
-])
 
+    bgr = transforms(bgr)
+    dataset = VideoDataset(args.video_src, transforms=transforms)
-bgr = transforms(bgr)
-dataset = VideoDataset(args.video_src, transforms=transforms)
 
+    if args.video_target_bgr:
+        dataset = ZipDataset([dataset, VideoDataset(args.video_target_bgr, transforms=T.ToTensor())])
-if args.video_target_bgr:
-    dataset = ZipDataset([dataset, VideoDataset(args.video_target_bgr, transforms=T.ToTensor())])
 
+    # Create output directory
+    if os.path.exists(args.output_dir):
+        if input(f'Directory {args.output_dir} already exists. Override? [Y/N]: ').lower() == 'y':
+            shutil.rmtree(args.output_dir)
+        else:
+            exit()
+    os.makedirs(args.output_dir)
-# Create output directory
-if os.path.exists(args.output_dir):
-    if input(f'Directory {args.output_dir} already exists. Override? [Y/N]: ').lower() == 'y':
-        shutil.rmtree(args.output_dir)
-    else:
-        exit()
-os.makedirs(args.output_dir)
 
 
-# Prepare writers
-if args.output_format == 'video':
-    h = args.video_resize[1] if args.video_resize is not None else vid.height
-    w = args.video_resize[0] if args.video_resize is not None else vid.width
-    if 'com' in args.output_types:
-        com_writer = VideoWriter(os.path.join(args.output_dir, 'com.mp4'), vid.frame_rate, w, h)
-    if 'pha' in args.output_types:
-        pha_writer = VideoWriter(os.path.join(args.output_dir, 'pha.mp4'), vid.frame_rate, w, h)
-    if 'fgr' in args.output_types:
-        fgr_writer = VideoWriter(os.path.join(args.output_dir, 'fgr.mp4'), vid.frame_rate, w, h)
-    if 'err' in args.output_types:
-        err_writer = VideoWriter(os.path.join(args.output_dir, 'err.mp4'), vid.frame_rate, w, h)
-    if 'ref' in args.output_types:
-        ref_writer = VideoWriter(os.path.join(args.output_dir, 'ref.mp4'), vid.frame_rate, w, h)
-else:
-    if 'com' in args.output_types:
-        com_writer = ImageSequenceWriter(os.path.join(args.output_dir, 'com'), 'png')
-    if 'pha' in args.output_types:
-        pha_writer = ImageSequenceWriter(os.path.join(args.output_dir, 'pha'), 'jpg')
-    if 'fgr' in args.output_types:
-        fgr_writer = ImageSequenceWriter(os.path.join(args.output_dir, 'fgr'), 'jpg')
-    if 'err' in args.output_types:
-        err_writer = ImageSequenceWriter(os.path.join(args.output_dir, 'err'), 'jpg')
-    if 'ref' in args.output_types:
-        ref_writer = ImageSequenceWriter(os.path.join(args.output_dir, 'ref'), 'jpg')
-
-
-# Conversion loop
-with torch.no_grad():
-    queue = Queue(1)
-    def load_worker():
-        tgt_bgr = torch.tensor([120/255, 255/255, 155/255]).view(1, 3, 1, 1)
-        for input_batch in tqdm(DataLoader(dataset, batch_size=1, pin_memory=True)):
-            if args.video_target_bgr:
-                src, tgt_bgr = input_batch
-            else:
-                src = input_batch
-            queue.put((src, tgt_bgr))
-        queue.put(None)
-    loader = Thread(target=load_worker)
-    loader.start()
-    # move background to device
-    bgr = (bgr[None]).to(device, non_blocking=False)
-    while True:
-        task = queue.get()
-        if task == None:
-            break
-        src, tgt_bgr = task
-        # move frame to device
-        src = src.to(device, non_blocking=True)
-        tgt_bgr = tgt_bgr.to(device, non_blocking=True)
-
-        if args.model_type == 'mattingbase':
-            pha, fgr, err, _ = model(src, bgr)
-        elif args.model_type == 'mattingrefine':
-            pha, fgr, _, _, err, ref = model(src, bgr)
-        elif args.model_type == 'mattingbm':
-            pha, fgr = model(src, bgr)
-
-        if 'com' in args.output_types:
-            if args.output_format == 'video':
-                # Output composite with green background
-                com = fgr * pha + tgt_bgr * (1 - pha)
-                com_writer.add_batch(com)
-            else:
-                # Output composite as rgba png images
-                com = torch.cat([fgr * pha.ne(0), pha], dim=1)
-                com_writer.add_batch(com)
-        if 'pha' in args.output_types:
-            pha_writer.add_batch(pha)
-        if 'fgr' in args.output_types:
-            fgr_writer.add_batch(fgr)
-        if 'err' in args.output_types:
-            err_writer.add_batch(F.interpolate(err, src.shape[2:], mode='bilinear', align_corners=False))
-        if 'ref' in args.output_types:
-            ref_writer.add_batch(F.interpolate(ref, src.shape[2:], mode='nearest'))
-    # terminate children processes
-    loader.join()
-    if args.output_format == 'video':
-        if 'com' in args.output_types:
-            com_writer.close()
-        if 'pha' in args.output_types:
-            pha_writer.close()
-        if 'fgr' in args.output_types:
-            fgr_writer.close()
-        if 'err' in args.output_types:
-            err_writer.close()
-        if 'ref' in args.output_types:
-            ref_writer.close()
+    # Prepare writers
+    if args.output_format == 'video':
+        h = args.video_resize[1] if args.video_resize is not None else vid.height
+        w = args.video_resize[0] if args.video_resize is not None else vid.width
+        if 'com' in args.output_types:
+            com_writer = VideoWriter(os.path.join(args.output_dir, 'com.mp4'), vid.frame_rate, w, h)
+        if 'pha' in args.output_types:
+            pha_writer = VideoWriter(os.path.join(args.output_dir, 'pha.mp4'), vid.frame_rate, w, h)
+        if 'fgr' in args.output_types:
+            fgr_writer = VideoWriter(os.path.join(args.output_dir, 'fgr.mp4'), vid.frame_rate, w, h)
+        if 'err' in args.output_types:
+            err_writer = VideoWriter(os.path.join(args.output_dir, 'err.mp4'), vid.frame_rate, w, h)
+        if 'ref' in args.output_types:
+            ref_writer = VideoWriter(os.path.join(args.output_dir, 'ref.mp4'), vid.frame_rate, w, h)
+    else:
+        if 'com' in args.output_types:
+            com_writer = ImageSequenceWriter(os.path.join(args.output_dir, 'com'), 'png')
+        if 'pha' in args.output_types:
+            pha_writer = ImageSequenceWriter(os.path.join(args.output_dir, 'pha'), 'jpg')
+        if 'fgr' in args.output_types:
+            fgr_writer = ImageSequenceWriter(os.path.join(args.output_dir, 'fgr'), 'jpg')
+        if 'err' in args.output_types:
+            err_writer = ImageSequenceWriter(os.path.join(args.output_dir, 'err'), 'jpg')
+        if 'ref' in args.output_types:
+            ref_writer = ImageSequenceWriter(os.path.join(args.output_dir, 'ref'), 'jpg')
+
+
+    # Conversion loop
+    with torch.no_grad():
+        queue = Queue(1)
+        def load_worker():
+            tgt_bgr = torch.tensor([120/255, 255/255, 155/255]).view(1, 3, 1, 1)
+            for input_batch in tqdm(DataLoader(dataset, batch_size=1, pin_memory=True)):
+                if args.video_target_bgr:
+                    src, tgt_bgr = input_batch
+                else:
+                    src = input_batch
+                queue.put((src, tgt_bgr))
+            queue.put(None)
+        loader = Thread(target=load_worker)
+        loader.start()
+        # move background to device
+        bgr = (bgr[None]).to(device, non_blocking=False)
+        while True:
+            task = queue.get()
+            if task == None:
+                break
+            src, tgt_bgr = task
+            # move frame to device
+            src = src.to(device, non_blocking=True)
+            tgt_bgr = tgt_bgr.to(device, non_blocking=True)
+
+            if args.model_type == 'mattingbase':
+                pha, fgr, err, _ = model(src, bgr)
+            elif args.model_type == 'mattingrefine':
+                pha, fgr, _, _, err, ref = model(src, bgr)
+            elif args.model_type == 'mattingbm':
+                pha, fgr = model(src, bgr)
+
+            if 'com' in args.output_types:
+                if args.output_format == 'video':
+                    # Output composite with green background
+                    com = fgr * pha + tgt_bgr * (1 - pha)
+                    com_writer.add_batch(com)
+                else:
+                    # Output composite as rgba png images
+                    com = torch.cat([fgr * pha.ne(0), pha], dim=1)
+                    com_writer.add_batch(com)
+            if 'pha' in args.output_types:
+                pha_writer.add_batch(pha)
+            if 'fgr' in args.output_types:
+                fgr_writer.add_batch(fgr)
+            if 'err' in args.output_types:
+                err_writer.add_batch(F.interpolate(err, src.shape[2:], mode='bilinear', align_corners=False))
+            if 'ref' in args.output_types:
+                ref_writer.add_batch(F.interpolate(ref, src.shape[2:], mode='nearest'))
+        # terminate children processes
+        loader.join()
+        if args.output_format == 'video':
+            if 'com' in args.output_types:
+                com_writer.close()
+            if 'pha' in args.output_types:
+                pha_writer.close()
+            if 'fgr' in args.output_types:
+                fgr_writer.close()
+            if 'err' in args.output_types:
+                err_writer.close()
+            if 'ref' in args.output_types:
+                ref_writer.close()
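A closing note on a detail the series relies on throughout: the DataLoader is
created with pin_memory=True, which pairs with the non_blocking=True copies in
the conversion loop. A minimal sketch of why (generic PyTorch, not repo code):
pinned, page-locked host memory lets .to(device, non_blocking=True) issue an
asynchronous DMA transfer, so the CPU can keep decoding while the previous
batch is still being copied to the GPU.

    import torch

    if torch.cuda.is_available():
        device = torch.device('cuda')
        x = torch.empty(1, 3, 1080, 1920).pin_memory()  # page-locked staging buffer
        y = x.to(device, non_blocking=True)             # returns before the copy completes
        torch.cuda.synchronize()                        # wait only when the result is needed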