diff --git a/point_e/diffusion/gaussian_diffusion.py b/point_e/diffusion/gaussian_diffusion.py index 8dc4dc3..4fc9589 100644 --- a/point_e/diffusion/gaussian_diffusion.py +++ b/point_e/diffusion/gaussian_diffusion.py @@ -7,6 +7,7 @@ import numpy as np import torch as th +from point_e.util.precision_compatibility import NP_FLOAT32_64, TH_FLOAT32_64 def get_beta_schedule(beta_schedule, *, beta_start, beta_end, num_diffusion_timesteps): @@ -16,7 +17,7 @@ def get_beta_schedule(beta_schedule, *, beta_start, beta_end, num_diffusion_time See get_named_beta_schedule() for the new library of schedules. """ if beta_schedule == "linear": - betas = np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64) + betas = np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=NP_FLOAT32_64) else: raise NotImplementedError(beta_schedule) assert betas.shape == (num_diffusion_timesteps,) @@ -159,8 +160,8 @@ def __init__( self.channel_scales = channel_scales self.channel_biases = channel_biases - # Use float64 for accuracy. - betas = np.array(betas, dtype=np.float64) + # using float64 (when available) for accuracy + betas = np.array(betas, dtype=NP_FLOAT32_64) self.betas = betas assert len(betas.shape) == 1, "betas must be 1-D" assert (betas > 0).all() and (betas <= 1).all() @@ -1013,7 +1014,7 @@ def _extract_into_tensor(arr, timesteps, broadcast_shape): dimension equal to the length of timesteps. :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims. """ - res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float() + res = th.from_numpy(arr).to(dtype=TH_FLOAT32_64, device=timesteps.device)[timesteps].to(TH_FLOAT32_64) while len(res.shape) < len(broadcast_shape): res = res[..., None] return res + th.zeros(broadcast_shape, device=timesteps.device) diff --git a/point_e/evals/feature_extractor.py b/point_e/evals/feature_extractor.py index 2a64597..2828207 100644 --- a/point_e/evals/feature_extractor.py +++ b/point_e/evals/feature_extractor.py @@ -14,6 +14,8 @@ def get_torch_devices() -> List[Union[str, torch.device]]: if torch.cuda.is_available(): return [torch.device(f"cuda:{i}") for i in range(torch.cuda.device_count())] + if torch.backends.mps.is_available(): + return [torch.device("mps")] else: return ["cpu"] diff --git a/point_e/examples/image2pointcloud.ipynb b/point_e/examples/image2pointcloud.ipynb index f698ac8..bc169af 100644 --- a/point_e/examples/image2pointcloud.ipynb +++ b/point_e/examples/image2pointcloud.ipynb @@ -23,7 +23,7 @@ "metadata": {}, "outputs": [], "source": [ - "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n", + "device = torch.device('cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu')\n", "\n", "print('creating base model...')\n", "base_name = 'base40M' # use base300M or base1B for better results\n", diff --git a/point_e/examples/pointcloud2mesh.ipynb b/point_e/examples/pointcloud2mesh.ipynb index b2b591c..7220438 100644 --- a/point_e/examples/pointcloud2mesh.ipynb +++ b/point_e/examples/pointcloud2mesh.ipynb @@ -24,7 +24,7 @@ "metadata": {}, "outputs": [], "source": [ - "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n", + "device = torch.device('cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu')\n", "\n", "print('creating SDF model...')\n", "name = 'sdf'\n", diff --git a/point_e/examples/text2pointcloud.ipynb b/point_e/examples/text2pointcloud.ipynb index 22785d5..8837b44 100644 --- a/point_e/examples/text2pointcloud.ipynb +++ b/point_e/examples/text2pointcloud.ipynb @@ -22,7 +22,7 @@ "metadata": {}, "outputs": [], "source": [ - "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n", + "device = torch.device('cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu')\n", "\n", "print('creating base model...')\n", "base_name = 'base40M-textvec'\n", diff --git a/point_e/util/precision_compatibility.py b/point_e/util/precision_compatibility.py new file mode 100644 index 0000000..d1ca8cf --- /dev/null +++ b/point_e/util/precision_compatibility.py @@ -0,0 +1,5 @@ +import torch +import numpy as np + +NP_FLOAT32_64 = np.float32 if torch.backends.mps.is_available() else np.float64 +TH_FLOAT32_64 = torch.float32 if torch.backends.mps.is_available() else torch.float64