Showing 30 changed files with 6,570 additions and 5 deletions.
@@ -16,6 +16,9 @@ data/FLAME2020/generic_model.pkl
 train.py.bak
 test.py.bak

+__pycache__/
+**/__pycache__/
@@ -0,0 +1,141 @@
import math
import torch
from torch import Tensor
from torch.optim.optimizer import Optimizer
from typing import List, Optional
from diff_gaussian_rasterization._C import AdamUpdate

'''
Shengjie Ma
2024/2/2
Accelerated Adam optimizer. We found the stock optimizer to be one of the bottlenecks in training.
Reference: torch/optim/adam.py
Only for float32, contiguous, GPU-resident tensors.
'''

def adamAccCore(params: List[Tensor],
                grads: List[Tensor],
                exp_avgs: List[Tensor],
                exp_avg_sqs: List[Tensor],
                max_exp_avg_sqs: List[Tensor],
                state_steps: List[int],
                *,
                amsgrad: bool,
                beta1: float,
                beta2: float,
                lr: float,
                weight_decay: float,
                eps: float):
    r"""Functional API that performs the Adam algorithm computation.
    See :class:`~torch.optim.Adam` for details.
    """

    assert not amsgrad  # amsgrad is not a supported feature yet

    for i, param in enumerate(params):

        grad = grads[i]
        exp_avg = exp_avgs[i]
        exp_avg_sq = exp_avg_sqs[i]
        step = state_steps[i]

        # Standard Adam bias corrections, folded into scalar arguments so the
        # fused kernel only needs per-parameter tensors plus a few scalars.
        bias_correction1 = 1 - beta1 ** step
        bias_correction2 = 1 - beta2 ** step
        sqrt_bias_correction2 = math.sqrt(bias_correction2)
        step_size = lr / bias_correction1

        # Fused CUDA kernel: updates param, exp_avg and exp_avg_sq in place.
        AdamUpdate(
            beta1, beta2, bias_correction1, sqrt_bias_correction2,
            step_size, eps, weight_decay,
            param, grad, exp_avg, exp_avg_sq
        )
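
# For reference, a pure-PyTorch sketch of what the fused AdamUpdate kernel is
# expected to compute per parameter, inferred from torch/optim/adam.py (the
# authoritative semantics live in the CUDA extension, not here):
#
#   if weight_decay != 0:
#       grad = grad.add(param, alpha=weight_decay)              # L2 penalty
#   exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)             # first moment
#   exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)  # second moment
#   denom = (exp_avg_sq.sqrt() / sqrt_bias_correction2).add_(eps)
#   param.addcdiv_(exp_avg, denom, value=-step_size)            # step_size = lr / bias_correction1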

class AdamAcc(Optimizer):
    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
                 weight_decay=0, amsgrad=False):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
        if not 0.0 <= weight_decay:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
        defaults = dict(lr=lr, betas=betas, eps=eps,
                        weight_decay=weight_decay, amsgrad=amsgrad)
        super(AdamAcc, self).__init__(params, defaults)

    def __setstate__(self, state):
        super(AdamAcc, self).__setstate__(state)
        for group in self.param_groups:
            group.setdefault('amsgrad', False)

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.
        Args:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            params_with_grad = []
            grads = []
            exp_avgs = []
            exp_avg_sqs = []
            max_exp_avg_sqs = []
            state_steps = []
            beta1, beta2 = group['betas']

            for p in group['params']:
                if p.grad is not None:
                    params_with_grad.append(p)
                    if p.grad.is_sparse:
                        raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')
                    grads.append(p.grad)

                    state = self.state[p]
                    # Lazy state initialization
                    if len(state) == 0:
                        state['step'] = 0
                        # Exponential moving average of gradient values
                        state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
                        # Exponential moving average of squared gradient values
                        state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
                        if group['amsgrad']:
                            # Maintains max of all exp. moving avg. of sq. grad. values
                            state['max_exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)

                    exp_avgs.append(state['exp_avg'])
                    exp_avg_sqs.append(state['exp_avg_sq'])

                    if group['amsgrad']:
                        max_exp_avg_sqs.append(state['max_exp_avg_sq'])

                    # update the steps for each param group update
                    state['step'] += 1
                    # record the step after step update
                    state_steps.append(state['step'])

            adamAccCore(params_with_grad,
                        grads,
                        exp_avgs,
                        exp_avg_sqs,
                        max_exp_avg_sqs,
                        state_steps,
                        amsgrad=group['amsgrad'],
                        beta1=beta1,
                        beta2=beta2,
                        lr=group['lr'],
                        weight_decay=group['weight_decay'],
                        eps=group['eps'])
        return loss
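
A minimal usage sketch: AdamAcc is intended as a drop-in replacement for torch.optim.Adam, subject to the constraints in the module docstring (float32, contiguous, CUDA-resident tensors; no amsgrad) and a built diff_gaussian_rasterization extension. The model, shapes, and hyperparameters below are illustrative assumptions, not part of this commit.

    # Hypothetical usage sketch, assuming a CUDA device is available.
    import torch
    from torch import nn

    model = nn.Linear(16, 4).cuda()  # illustrative model, not from this repo
    optimizer = AdamAcc(model.parameters(), lr=1e-3, betas=(0.9, 0.999), eps=1e-8)

    for _ in range(100):
        optimizer.zero_grad()
        x = torch.randn(8, 16, device="cuda")
        loss = model(x).square().mean()  # dummy objective
        loss.backward()
        optimizer.step()                 # dispatches to the fused AdamUpdate kernel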
@@ -0,0 +1,135 @@
import os
import random

from pytorch3d.renderer import PerspectiveCameras
from pytorch3d.transforms import rotation_6d_to_matrix

import numpy as np
import torch
import glob
import cv2

class DummyObj:
    def __init__(self):
        pass

class FaceDataset:
    def __init__(self, dataset_name, load_iterations=None, shuffle=True, resolution_scale=[1.0]):
        self.dataset_name = dataset_name
        file_list = glob.glob(os.path.join(dataset_name, "checkpoint", "*.frame"))
        self.shuffle = shuffle
        self.n_frames = len(file_list)
        self.n_seg = 350  # use the last 350 frames as the test set
        self.n_extract_ratio = -1
        train_ids = []
        test_ids = []
        for ii in range(self.n_frames):
            if ii + self.n_seg >= self.n_frames:
                test_ids.append(ii)
            else:
                train_ids.append(ii)
        self.train_ids = train_ids
        self.test_ids = test_ids
        if self.shuffle:
            random.shuffle(self.train_ids)
            random.shuffle(self.test_ids)
        self.output_list = None

    def getTrainCameras(self):
        return self.train_ids

    def getTestCameras(self):
        return self.test_ids

    def prepare_data(self, reside_image_on_gpu=True, device="cuda"):
        # Preload and cache every frame so getData() becomes a cheap lookup.
        output_list = []
        for ii in range(self.n_frames):
            output_list.append(self.getData(ii, reside_image_on_gpu, device))
        self.output_list = output_list

    def getData(self, id, reside_image_on_gpu=True, device="cuda"):

        if self.output_list is not None:
            return self.output_list[id]

        # RGBA image: the alpha channel is used as the foreground mask.
        image = cv2.imread(os.path.join(self.dataset_name, "images/%05d.png" % id), cv2.IMREAD_UNCHANGED)
        image = cv2.cvtColor(image, cv2.COLOR_BGRA2RGBA)
        image = np.asarray(image, dtype=np.float32) / 255.
        image_mask = image[..., -1]
        image = image[..., :3]

        image = torch.from_numpy(image)
        image_mask = torch.from_numpy(image_mask)
        if reside_image_on_gpu:
            image = image.to(device)
            image_mask = image_mask.to(device)

        H, W, _ = image.shape

        frame = torch.load(os.path.join(self.dataset_name, "checkpoint/%05d.frame" % id))
        frame_flame = frame['flame']
        frame_camera = frame['camera']
        frame_opencv = frame['opencv']

        camera = PerspectiveCameras(
            device=device,
            principal_point=torch.from_numpy(frame_camera['pp']).to(device),
            focal_length=torch.from_numpy(frame_camera['fl']).to(device),
            R=rotation_6d_to_matrix(torch.from_numpy(frame_camera['R']).to(device)),
            T=torch.from_numpy(frame_camera['t']).to(device),
            image_size=[[H, W]]
        )
        output = DummyObj()
        # BASE
        output.original_image = image
        output.mask = image_mask
        output.image_name = id
        # FLAME params
        output.cameras = camera
        output.image_size = [H, W]
        output.shape = torch.from_numpy(frame_flame['shape']).to(device)
        output.exp = torch.from_numpy(frame_flame['exp']).to(device)
        output.tex = torch.from_numpy(frame_flame['tex']).to(device)
        output.eyes = torch.from_numpy(frame_flame['eyes']).to(device)
        output.jaw = torch.from_numpy(frame_flame['jaw']).to(device)
        output.eyelids = torch.from_numpy(frame_flame['eyelids']).to(device)

        output.R = rotation_6d_to_matrix(torch.from_numpy(frame_camera['R']).to(device))
        output.t = torch.from_numpy(frame_camera['t']).to(device)

        # World-to-camera matrix assembled from the OpenCV extrinsics.
        w2c = np.zeros([4, 4])
        w2c[3, 3] = 1
        w2c[:3, :3] = frame_opencv['R'][0]
        w2c[:3, 3] = frame_opencv['t'][0]

        c2w = np.linalg.inv(w2c)

        t_w2c = torch.from_numpy(w2c.transpose()).float().to(device)
        t_c2w = torch.from_numpy(c2w.transpose()).float().to(device)

        # Perspective projection built from the OpenCV intrinsics K, mapping
        # x/y to [-1, 1] NDC and depth between znear and zfar.
        znear = 0.01
        zfar = 100.0
        z_sign = 1.0
        proj = np.zeros([4, 4])
        proj[0, :2] = frame_opencv['K'][0, 0, :2] * 2. / W
        proj[1, :2] = frame_opencv['K'][0, 1, :2] * 2. / H
        proj[0, 2] = frame_opencv['K'][0, 0, 2] * 2. / W - 1.
        proj[1, 2] = frame_opencv['K'][0, 1, 2] * 2. / H - 1.
        proj[3, 2] = z_sign
        proj[2, 2] = z_sign * zfar / (zfar - znear)
        proj[2, 3] = -(zfar * znear) / (zfar - znear)

        proj_w2c = proj @ w2c
        t_proj_w2c = torch.from_numpy(proj_w2c.transpose()).float().to(device)

        # Fields of view recovered from the focal lengths in K.
        output.FoVx = 2 * np.arctan(W * 0.5 / frame_opencv['K'][0, 0, 0])
        output.FoVy = 2 * np.arctan(H * 0.5 / frame_opencv['K'][0, 1, 1])
        output.image_height = H
        output.image_width = W
        output.world_view_transform = t_w2c.contiguous()
        output.full_proj_transform = t_proj_w2c.contiguous()
        output.camera_center = t_c2w[3, :3].contiguous()

        return output
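
A minimal usage sketch for FaceDataset: the dataset path below is an illustrative assumption; the directory must follow the images/%05d.png and checkpoint/%05d.frame layout the class expects.

    # Hypothetical usage sketch, assuming a dataset directory in the expected layout.
    dataset = FaceDataset("data/subject01", shuffle=True)          # illustrative path
    dataset.prepare_data(reside_image_on_gpu=True, device="cuda")  # optional caching

    for frame_id in dataset.getTrainCameras():
        data = dataset.getData(frame_id)
        # data.original_image, data.mask, data.world_view_transform, and
        # data.full_proj_transform feed the downstream rasterizer.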