Commit

update training code

zjumsj committed Aug 13, 2024
1 parent 77e3a20 commit a0ab7c0
Showing 30 changed files with 6,570 additions and 5 deletions.
3 changes: 3 additions & 0 deletions .gitignore
@@ -16,6 +16,9 @@ data/FLAME2020/generic_model.pkl
train.py.bak
test.py.bak

__pycache__/
**/__pycache__/




391 changes: 391 additions & 0 deletions FLAME/FLAME.py

Large diffs are not rendered by default.

141 changes: 141 additions & 0 deletions FLAME/adamacc.py
@@ -0,0 +1,141 @@
import math
import torch
from torch import Tensor
from torch.optim.optimizer import Optimizer
from typing import List, Optional
from diff_gaussian_rasterization._C import AdamUpdate

'''
Shengjie Ma
2024/2/2
Accelerated Adam optimizer. We found the stock optimizer to be one of the
bottlenecks in training. Reference: torch/optim/adam.py
Supports only float32, contiguous, GPU-resident tensors.
'''

def adamAccCore(params: List[Tensor],
                grads: List[Tensor],
                exp_avgs: List[Tensor],
                exp_avg_sqs: List[Tensor],
                max_exp_avg_sqs: List[Tensor],
                state_steps: List[int],
                *,
                amsgrad: bool,
                beta1: float,
                beta2: float,
                lr: float,
                weight_decay: float,
                eps: float):
    r"""Functional API that performs the Adam algorithm computation.
    See :class:`~torch.optim.Adam` for details.
    """

    assert amsgrad == False  # amsgrad is not yet supported

    for i, param in enumerate(params):

        grad = grads[i]
        exp_avg = exp_avgs[i]
        exp_avg_sq = exp_avg_sqs[i]
        step = state_steps[i]

        bias_correction1 = 1 - beta1 ** step
        bias_correction2 = 1 - beta2 ** step
        sqrt_bias_correction2 = math.sqrt(bias_correction2)
        step_size = lr / bias_correction1

        # Fused CUDA kernel: updates param, exp_avg and exp_avg_sq in place.
        AdamUpdate(
            beta1, beta2, bias_correction1, sqrt_bias_correction2,
            step_size, eps, weight_decay,
            param, grad, exp_avg, exp_avg_sq
        )
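# For reference, a sketch of the per-parameter update we assume the fused
# AdamUpdate kernel performs (standard Adam, mirroring torch/optim/adam.py;
# the kernel itself lives in diff_gaussian_rasterization and is not shown
# in this diff):
#
#   grad       = grad + weight_decay * param             # L2-style decay
#   exp_avg    = beta1 * exp_avg    + (1 - beta1) * grad
#   exp_avg_sq = beta2 * exp_avg_sq + (1 - beta2) * grad * grad
#   denom      = sqrt(exp_avg_sq) / sqrt_bias_correction2 + eps
#   param      = param - step_size * exp_avg / denom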

class AdamAcc(Optimizer):
    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
                 weight_decay=0, amsgrad=False):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
        if not 0.0 <= weight_decay:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
        defaults = dict(lr=lr, betas=betas, eps=eps,
                        weight_decay=weight_decay, amsgrad=amsgrad)
        super(AdamAcc, self).__init__(params, defaults)

    def __setstate__(self, state):
        super(AdamAcc, self).__setstate__(state)
        for group in self.param_groups:
            group.setdefault('amsgrad', False)

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.
        Args:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            params_with_grad = []
            grads = []
            exp_avgs = []
            exp_avg_sqs = []
            max_exp_avg_sqs = []
            state_steps = []
            beta1, beta2 = group['betas']

            for p in group['params']:
                if p.grad is not None:
                    params_with_grad.append(p)
                    if p.grad.is_sparse:
                        raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')
                    grads.append(p.grad)

                    state = self.state[p]
                    # Lazy state initialization
                    if len(state) == 0:
                        state['step'] = 0
                        # Exponential moving average of gradient values
                        state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
                        # Exponential moving average of squared gradient values
                        state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
                        if group['amsgrad']:
                            # Maintains max of all exp. moving avg. of sq. grad. values
                            state['max_exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)

                    exp_avgs.append(state['exp_avg'])
                    exp_avg_sqs.append(state['exp_avg_sq'])

                    if group['amsgrad']:
                        max_exp_avg_sqs.append(state['max_exp_avg_sq'])

                    # Update the step count for this parameter
                    state['step'] += 1
                    # Record the step after the update
                    state_steps.append(state['step'])

            adamAccCore(params_with_grad,
                        grads,
                        exp_avgs,
                        exp_avg_sqs,
                        max_exp_avg_sqs,
                        state_steps,
                        amsgrad=group['amsgrad'],
                        beta1=beta1,
                        beta2=beta2,
                        lr=group['lr'],
                        weight_decay=group['weight_decay'],
                        eps=group['eps'])
        return loss
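For context, a minimal usage sketch of AdamAcc as a drop-in replacement for torch.optim.Adam; the model, loss, and import path below are placeholders, and it assumes the diff_gaussian_rasterization extension exposing AdamUpdate is built and a CUDA device is available:

import torch
from FLAME.adamacc import AdamAcc  # assumes the repo root is on PYTHONPATH

model = torch.nn.Linear(10, 3).cuda()       # placeholder float32 GPU model
opt = AdamAcc(model.parameters(), lr=1e-3)  # same constructor signature as torch.optim.Adam

x = torch.randn(8, 10, device="cuda")
loss = model(x).square().mean()             # placeholder loss
loss.backward()
opt.step()                                  # dispatches to the fused AdamUpdate kernel
opt.zero_grad()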
Binary file added FLAME/blendshapes/l_eyelid.npy
Binary file added FLAME/blendshapes/r_eyelid.npy
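A hedged sketch of how these eyelid blendshapes might be applied; the (V, 3) per-vertex-offset layout and the two-coefficient eyelids vector are assumptions suggested by FLAME/dataset.py, not confirmed by this diff:

import numpy as np
import torch

# Assumption: each .npy stores a (V, 3) per-vertex offset for one eyelid.
l_eyelid = torch.from_numpy(np.load("FLAME/blendshapes/l_eyelid.npy")).float()
r_eyelid = torch.from_numpy(np.load("FLAME/blendshapes/r_eyelid.npy")).float()

def apply_eyelids(vertices, eyelids):
    # Linear blendshape: assumes eyelids holds (left, right) coefficients.
    return vertices + eyelids[0] * l_eyelid + eyelids[1] * r_eyelid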
135 changes: 135 additions & 0 deletions FLAME/dataset.py
@@ -0,0 +1,135 @@
import os
import random

from pytorch3d.renderer import PerspectiveCameras
from pytorch3d.transforms import rotation_6d_to_matrix

import numpy as np
import torch
import glob
import cv2


class DummyObj:
    def __init__(self):
        pass

class FaceDataset:
    def __init__(self, dataset_name, load_iterations=None, shuffle=True, resolution_scale=[1.0]):
        self.dataset_name = dataset_name
        file_list = glob.glob(os.path.join(dataset_name, "checkpoint", "*.frame"))
        self.shuffle = shuffle
        self.n_frames = len(file_list)
        self.n_seg = 350  # use the last 350 frames as the test set
        self.n_extract_ratio = -1
        train_ids = []
        test_ids = []
        for ii in range(self.n_frames):
            if ii + self.n_seg >= self.n_frames:
                test_ids.append(ii)
            else:
                train_ids.append(ii)
        self.train_ids = train_ids
        self.test_ids = test_ids
        if self.shuffle:
            random.shuffle(self.train_ids)
            random.shuffle(self.test_ids)
        self.output_list = None

    def getTrainCameras(self):
        return self.train_ids

    def getTestCameras(self):
        return self.test_ids

    def prepare_data(self, reside_image_on_gpu=True, device="cuda"):
        output_list = []
        for ii in range(self.n_frames):
            output_list.append(self.getData(ii, reside_image_on_gpu, device))
        self.output_list = output_list

    def getData(self, id, reside_image_on_gpu=True, device="cuda"):

        if self.output_list is not None:
            return self.output_list[id]

        image = cv2.imread(os.path.join(self.dataset_name, "images/%05d.png" % id), cv2.IMREAD_UNCHANGED)
        image = cv2.cvtColor(image, cv2.COLOR_BGRA2RGBA)
        image = np.asarray(image, dtype=np.float32) / 255.
        image_mask = image[..., -1]
        image = image[..., :3]

        image = torch.from_numpy(image)
        image_mask = torch.from_numpy(image_mask)
        if reside_image_on_gpu:
            image = image.to(device)
            image_mask = image_mask.to(device)

        H, W, _ = image.shape

        frame = torch.load(os.path.join(self.dataset_name, "checkpoint/%05d.frame" % id))
        frame_flame = frame['flame']
        frame_camera = frame['camera']
        frame_opencv = frame['opencv']

        camera = PerspectiveCameras(
            device=device,
            principal_point=torch.from_numpy(frame_camera['pp']).to(device),
            focal_length=torch.from_numpy(frame_camera['fl']).to(device),
            R=rotation_6d_to_matrix(torch.from_numpy(frame_camera['R']).to(device)),
            T=torch.from_numpy(frame_camera['t']).to(device),
            image_size=[[H, W]]
        )
        output = DummyObj()
        # BASE
        output.original_image = image
        output.mask = image_mask
        output.image_name = id
        # FLAME params
        output.cameras = camera
        output.image_size = [H, W]
        output.shape = torch.from_numpy(frame_flame['shape']).to(device)
        output.exp = torch.from_numpy(frame_flame['exp']).to(device)
        output.tex = torch.from_numpy(frame_flame['tex']).to(device)
        output.eyes = torch.from_numpy(frame_flame['eyes']).to(device)
        output.jaw = torch.from_numpy(frame_flame['jaw']).to(device)
        output.eyelids = torch.from_numpy(frame_flame['eyelids']).to(device)

        output.R = rotation_6d_to_matrix(torch.from_numpy(frame_camera['R']).to(device))
        output.t = torch.from_numpy(frame_camera['t']).to(device)

        # World-to-camera matrix assembled from the OpenCV extrinsics
        w2c = np.zeros([4, 4])
        w2c[3, 3] = 1
        w2c[:3, :3] = frame_opencv['R'][0]
        w2c[:3, 3] = frame_opencv['t'][0]

        c2w = np.linalg.inv(w2c)

        t_w2c = torch.from_numpy(w2c.transpose()).float().to(device)
        t_c2w = torch.from_numpy(c2w.transpose()).float().to(device)

        # OpenGL-style projection matrix built from the OpenCV intrinsics K
        znear = 0.01
        zfar = 100.0
        z_sign = 1.0
        proj = np.zeros([4, 4])
        proj[0, :2] = frame_opencv['K'][0, 0, :2] * 2. / W
        proj[1, :2] = frame_opencv['K'][0, 1, :2] * 2. / H
        proj[0, 2] = frame_opencv['K'][0, 0, 2] * 2. / W - 1.
        proj[1, 2] = frame_opencv['K'][0, 1, 2] * 2. / H - 1.
        proj[3, 2] = z_sign
        proj[2, 2] = z_sign * zfar / (zfar - znear)
        proj[2, 3] = -(zfar * znear) / (zfar - znear)

        proj_w2c = proj @ w2c
        t_proj_w2c = torch.from_numpy(proj_w2c.transpose()).float().to(device)

        output.FoVx = 2 * np.arctan(W * 0.5 / frame_opencv['K'][0, 0, 0])
        output.FoVy = 2 * np.arctan(H * 0.5 / frame_opencv['K'][0, 1, 1])
        output.image_height = H
        output.image_width = W
        output.world_view_transform = t_w2c.contiguous()
        output.full_proj_transform = t_proj_w2c.contiguous()
        output.camera_center = t_c2w[3, :3].contiguous()

        return output
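A minimal sketch of how FaceDataset is likely driven during training; the dataset root below is a placeholder and assumes the images/%05d.png and checkpoint/%05d.frame layout that getData reads:

dataset = FaceDataset("data/subject01")         # placeholder dataset root
dataset.prepare_data(reside_image_on_gpu=True)  # optional: cache every frame up front

for frame_id in dataset.getTrainCameras():
    data = dataset.getData(frame_id)
    # data.original_image: HxWx3 float tensor in [0, 1]
    # data.full_proj_transform: transposed (projection @ world-to-camera) matrix
    # data.camera_center: camera position in world space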

(Remaining changed files not rendered.)
