diff --git a/backend/src/nodes/impl/pytorch/pix_transform/auto_split.py b/backend/src/nodes/impl/pytorch/pix_transform/auto_split.py index 0265d60e6..24067ee18 100644 --- a/backend/src/nodes/impl/pytorch/pix_transform/auto_split.py +++ b/backend/src/nodes/impl/pytorch/pix_transform/auto_split.py @@ -33,6 +33,12 @@ def split(self, tile_size: Size) -> Size: return size, size +def _as_3d(img: np.ndarray) -> np.ndarray: + if img.ndim == 3: + return img + return np.expand_dims(img, axis=2) + + def pix_transform_auto_split( source: np.ndarray, guide: np.ndarray, @@ -63,7 +69,7 @@ def upscale(tile: np.ndarray, region: Region): try: tile_guide = region.scale(scale).read_from(guide) pix_op = to_op(PixTransform)( - guide_img=np.transpose(tile_guide, (2, 0, 1)), + guide_img=np.transpose(_as_3d(tile_guide), (2, 0, 1)), device=device, params=params, )