From 943ba2d139e411437b5e857a83836230e77d0b2a Mon Sep 17 00:00:00 2001 From: yangboz Date: Mon, 26 Dec 2022 10:11:46 +0800 Subject: [PATCH] added:notebook's python codebase files; --- point_e/examples/image2pointcloud.py | 71 ++++++++++++++++++++++++++++ point_e/examples/pointcloud2mesh.py | 62 ++++++++++++++++++++++++ point_e/examples/text2pointcloud.py | 71 ++++++++++++++++++++++++++++ 3 files changed, 204 insertions(+) create mode 100644 point_e/examples/image2pointcloud.py create mode 100644 point_e/examples/pointcloud2mesh.py create mode 100644 point_e/examples/text2pointcloud.py diff --git a/point_e/examples/image2pointcloud.py b/point_e/examples/image2pointcloud.py new file mode 100644 index 0000000..a0eeb0e --- /dev/null +++ b/point_e/examples/image2pointcloud.py @@ -0,0 +1,71 @@ +#!/usr/bin/env python +# coding: utf-8 + +# In[ ]: + + +from PIL import Image +import torch +from tqdm.auto import tqdm + +from point_e.diffusion.configs import DIFFUSION_CONFIGS, diffusion_from_config +from point_e.diffusion.sampler import PointCloudSampler +from point_e.models.download import load_checkpoint +from point_e.models.configs import MODEL_CONFIGS, model_from_config +from point_e.util.plotting import plot_point_cloud + + +# In[ ]: + + +device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + +print('creating base model...') +base_name = 'base40M' # use base300M or base1B for better results +base_model = model_from_config(MODEL_CONFIGS[base_name], device) +base_model.eval() +base_diffusion = diffusion_from_config(DIFFUSION_CONFIGS[base_name]) + +print('creating upsample model...') +upsampler_model = model_from_config(MODEL_CONFIGS['upsample'], device) +upsampler_model.eval() +upsampler_diffusion = diffusion_from_config(DIFFUSION_CONFIGS['upsample']) + +print('downloading base checkpoint...') +base_model.load_state_dict(load_checkpoint(base_name, device)) + +print('downloading upsampler checkpoint...') +upsampler_model.load_state_dict(load_checkpoint('upsample', device)) + + +# In[ ]: + + +sampler = PointCloudSampler( + device=device, + models=[base_model, upsampler_model], + diffusions=[base_diffusion, upsampler_diffusion], + num_points=[1024, 4096 - 1024], + aux_channels=['R', 'G', 'B'], + guidance_scale=[3.0, 3.0], +) + + +# In[ ]: + + +# Load an image to condition on. +img = Image.open('example_data/cube_stack.jpg') + +# Produce a sample from the model. +samples = None +for x in tqdm(sampler.sample_batch_progressive(batch_size=1, model_kwargs=dict(images=[img]))): + samples = x + + +# In[ ]: + + +pc = sampler.output_to_point_clouds(samples)[0] +fig = plot_point_cloud(pc, grid_size=3, fixed_bounds=((-0.75, -0.75, -0.75),(0.75, 0.75, 0.75))) + diff --git a/point_e/examples/pointcloud2mesh.py b/point_e/examples/pointcloud2mesh.py new file mode 100644 index 0000000..8fc6d0c --- /dev/null +++ b/point_e/examples/pointcloud2mesh.py @@ -0,0 +1,62 @@ +#!/usr/bin/env python +# coding: utf-8 + +# In[ ]: + + +from PIL import Image +import torch +import matplotlib.pyplot as plt +from tqdm.auto import tqdm + +from point_e.models.download import load_checkpoint +from point_e.models.configs import MODEL_CONFIGS, model_from_config +from point_e.util.pc_to_mesh import marching_cubes_mesh +from point_e.util.plotting import plot_point_cloud +from point_e.util.point_cloud import PointCloud + + +# In[ ]: + + +device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + +print('creating SDF model...') +name = 'sdf' +model = model_from_config(MODEL_CONFIGS[name], device) +model.eval() + +print('loading SDF model...') +model.load_state_dict(load_checkpoint(name, device)) + + +# In[ ]: + + +# Load a point cloud we want to convert into a mesh. +pc = PointCloud.load('example_data/pc_corgi.npz') + +# Plot the point cloud as a sanity check. +fig = plot_point_cloud(pc, grid_size=2) + + +# In[ ]: + + +# Produce a mesh (with vertex colors) +mesh = marching_cubes_mesh( + pc=pc, + model=model, + batch_size=4096, + grid_size=32, # increase to 128 for resolution used in evals + progress=True, +) + + +# In[ ]: + + +# Write the mesh to a PLY file to import into some other program. +with open('mesh.ply', 'wb') as f: + mesh.write_ply(f) + diff --git a/point_e/examples/text2pointcloud.py b/point_e/examples/text2pointcloud.py new file mode 100644 index 0000000..4bd06b4 --- /dev/null +++ b/point_e/examples/text2pointcloud.py @@ -0,0 +1,71 @@ +#!/usr/bin/env python +# coding: utf-8 + +# In[ ]: + + +import torch +from tqdm.auto import tqdm + +from point_e.diffusion.configs import DIFFUSION_CONFIGS, diffusion_from_config +from point_e.diffusion.sampler import PointCloudSampler +from point_e.models.download import load_checkpoint +from point_e.models.configs import MODEL_CONFIGS, model_from_config +from point_e.util.plotting import plot_point_cloud + + +# In[ ]: + + +device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + +print('creating base model...') +base_name = 'base40M-textvec' +base_model = model_from_config(MODEL_CONFIGS[base_name], device) +base_model.eval() +base_diffusion = diffusion_from_config(DIFFUSION_CONFIGS[base_name]) + +print('creating upsample model...') +upsampler_model = model_from_config(MODEL_CONFIGS['upsample'], device) +upsampler_model.eval() +upsampler_diffusion = diffusion_from_config(DIFFUSION_CONFIGS['upsample']) + +print('downloading base checkpoint...') +base_model.load_state_dict(load_checkpoint(base_name, device)) + +print('downloading upsampler checkpoint...') +upsampler_model.load_state_dict(load_checkpoint('upsample', device)) + + +# In[ ]: + + +sampler = PointCloudSampler( + device=device, + models=[base_model, upsampler_model], + diffusions=[base_diffusion, upsampler_diffusion], + num_points=[1024, 4096 - 1024], + aux_channels=['R', 'G', 'B'], + guidance_scale=[3.0, 0.0], + model_kwargs_key_filter=('texts', ''), # Do not condition the upsampler at all +) + + +# In[ ]: + + +# Set a prompt to condition on. +prompt = 'a red motorcycle' + +# Produce a sample from the model. +samples = None +for x in tqdm(sampler.sample_batch_progressive(batch_size=1, model_kwargs=dict(texts=[prompt]))): + samples = x + + +# In[ ]: + + +pc = sampler.output_to_point_clouds(samples)[0] +fig = plot_point_cloud(pc, grid_size=3, fixed_bounds=((-0.75, -0.75, -0.75),(0.75, 0.75, 0.75))) +