From 2c93c2b17a5d11517ba00791334e5968dc070a19 Mon Sep 17 00:00:00 2001 From: Ke Date: Tue, 21 May 2024 16:47:30 +0800 Subject: [PATCH 1/2] Update eval_helper.py Use half precision to reduce memory usage. --- .../segmentation_refinement/eval_helper.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/segmentation-refinement/segmentation_refinement/eval_helper.py b/segmentation-refinement/segmentation_refinement/eval_helper.py index 8d9bf37..a7f153d 100644 --- a/segmentation-refinement/segmentation_refinement/eval_helper.py +++ b/segmentation-refinement/segmentation_refinement/eval_helper.py @@ -30,11 +30,14 @@ def safe_forward(model, im, seg, inter_s8=None, inter_s4=None): p_inter_s8 = torch.zeros(b, 1, newH, newW, device=im.device) - 1 p_inter_s8[:,:,0:ph,0:pw] = inter_s8 inter_s8 = p_inter_s8 + inter_s8 = inter_s8.half() if inter_s4 is not None: p_inter_s4 = torch.zeros(b, 1, newH, newW, device=im.device) - 1 p_inter_s4[:,:,0:ph,0:pw] = inter_s4 inter_s4 = p_inter_s4 - + inter_s4 = inter_s4.half() + im = im.half() + seg = seg.half() images = model(im, seg, inter_s8, inter_s4) return_im = {} From ac9ce27efc0fb6d0b195185017d4d8956bf9b247 Mon Sep 17 00:00:00 2001 From: Ke Date: Tue, 21 May 2024 16:50:53 +0800 Subject: [PATCH 2/2] Update main.py Use half precision. --- segmentation-refinement/segmentation_refinement/main.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/segmentation-refinement/segmentation_refinement/main.py b/segmentation-refinement/segmentation_refinement/main.py index 9891d19..43b35b4 100644 --- a/segmentation-refinement/segmentation_refinement/main.py +++ b/segmentation-refinement/segmentation_refinement/main.py @@ -35,7 +35,8 @@ def __init__(self, device='cpu', model_folder=None, download_and_check_model=Tru new_dict[name] = v self.model.load_state_dict(new_dict) self.model.eval().to(device) - + self.model = self.model.half() + self.im_transform = transforms.Compose([ transforms.ToTensor(), transforms.Normalize(