Update test script for VAE, other minor fixes
saihv committed Jan 7, 2022
1 parent c85db06 commit 9dcb75c
Showing 3 changed files with 57 additions and 60 deletions.
24 changes: 24 additions & 0 deletions event_vae/dataloader.py
@@ -55,6 +55,30 @@ def get_event_batch(self, start_idx):

event_batch_np = np.asarray(event_stack, dtype=np.float32)
return event_batch_np

def get_event_batch_idx(self, start_idx, size):
event_stack = []

t_start = self.events[start_idx][0]
t_final = self.events[start_idx + size - 1][0]

dt = t_final - t_start
idx = start_idx
        # Iterate over `size` consecutive events, rewriting each timestamp as
        # (t_final - t) / dt so it lies in [0, 1] relative to the window
while idx - start_idx < size:
e_curr = deepcopy(self.events[idx])
event_stack.append(e_curr)

if dt > 0:
t_relative = float(t_final - e_curr[0]) / dt
else:
t_relative = 0
event_stack[idx - start_idx][0] = t_relative

idx += 1

event_batch_np = np.asarray(event_stack, dtype=np.float32)
return event_batch_np

def get_event_timeslice(self, start_idx, dt_max=16000):
"""
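
For reference, a minimal sketch of how the new get_event_batch_idx helper could be exercised on its own. The file path and the positional constructor arguments mirror the defaults used by test.py in this commit, but the exact values are placeholders, not something this diff prescribes.

    from dataloader import EventStreamArray

    # Placeholder construction, following the positional order used in test.py:
    # (input_file, batch_num, batch_size, data_len)
    dataset = EventStreamArray("data/gates.txt", 1, 50, 3)

    # Fetch 50 consecutive events starting at index 0. Each timestamp is
    # rewritten as (t_final - t) / dt, i.e. rescaled to [0, 1] within the window.
    batch = dataset.get_event_batch_idx(start_idx=0, size=50)
    print(batch.shape, batch.dtype)  # float32 array with one row per event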
90 changes: 32 additions & 58 deletions event_vae/test.py
@@ -3,40 +3,27 @@
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
from torch.utils import data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
from torch.autograd import Variable
import torch.nn.functional as F
from copy import deepcopy
import matplotlib.pyplot as plt
import time
import loss

from torch.utils.tensorboard import SummaryWriter

from event_ae import EventAE
from chamfer import ChamferDistance, ChamferLoss
from dataloader import EventStreamDataset, EventStreamArray
from event_vae import EventVAE
from dataloader import EventStreamArray
from data_utils import *

parser = argparse.ArgumentParser()
parser.add_argument(
"--input_file",
type=str,
default="data/MSTrain_bytestream.txt",
default="data/gates.txt",
help="training data filename",
)
parser.add_argument("--batch_size", type=int,
default=1000, help="input batch size")
parser.add_argument("--batch_num", type=int, default=50,
parser.add_argument("--batch_num", type=int, default=1,
help="number of batches")
parser.add_argument("--data_len", type=int, default=2,
parser.add_argument("--data_len", type=int, default=3,
help="event element length")
parser.add_argument("--tcode", type=bool, default=False,
help="consider timestamps")
@@ -53,64 +40,41 @@
parser.add_argument(
"--decoder", type=str, default="image", help="decoder type: stream or image"
)
parser.add_argument("--outf", type=str, default="weights",
parser.add_argument("--output_dir", type=str, default="weights",
help="output folder")
parser.add_argument("--model", type=str, default="", help="model path")
parser.add_argument("--model", type=str, default="weights/evae_racing.pt", help="model path")
parser.add_argument(
"--norm_type",
type=str,
default="none",
help="normalization type: scale: [0, 1]; center: [-1, 1]",
)
parser.add_argument("--arch", type=str, default="vanilla")

opt = parser.parse_args()
print(opt)


def blue(x): return "\033[94m" + x + "\033[0m"


opt.manualSeed = random.randint(1, 10000) # fix seed
print("Random Seed: ", opt.manualSeed)
random.seed(opt.manualSeed)
torch.manual_seed(opt.manualSeed)

try:
os.makedirs(opt.outf)
except OSError:
pass

writer = SummaryWriter("runs/str_to_img_test")

# Params:
# n_events, data_size for stream decoder
# Height, width for image decoder

H = 32
W = 32
H = 64
W = 64
params = [H, W]

event_dataset = EventStreamArray(
opt.input_file, opt.batch_num, opt.batch_size, opt.data_len
)

"""
batch_size_total = opt.batch_size * opt.batch_num
train_loader = data.DataLoader(
event_dataset,
batch_size=batch_size_total,
shuffle=False,
num_workers=0,
drop_last=True,
)
"""
data_utils = EventDataUtils(H, W, opt.norm_type)

data_utils = EventDataUtils(32, 32, opt.norm_type)
enet = EventVAE(
opt.data_len, opt.latent_size, params, decoder=opt.decoder)

enet = EventAE(
opt.data_len, opt.latent_size, params, decoder=opt.decoder, norm_type=opt.norm_type
)
if opt.model != "":
enet.load_state_dict(torch.load(opt.model))
enet.cuda()
@@ -124,22 +88,32 @@ def blue(x): return "\033[94m" + x + "\033[0m"
else:
pol = False

step_size = 50
start_idx = 0

event_dataset = EventStreamArray(opt.input_file, opt.batch_num, step_size, opt.data_len)

with torch.no_grad():
for i, data in enumerate(test_loader, 0):
# events = data_utils.normalize(EventExtractor(data, batch_num=1))
idx = random.randint(0, len(event_dataset.events) - 200000)

idx = random.randint(0, 1000000)
events = data_utils.normalize(event_array.get_event_stack(idx))
events = Variable(events)
while step_size < 4000:
event_data = event_dataset.get_event_batch_idx(idx, step_size)
gt = data_utils.create_frame(event_data, pol, opt.tcode)
gt = torch.from_numpy(gt).cuda()
events, timestamps = event_dataset.extract(event_data.reshape(1, step_size, 4))
events = events.transpose(2, 1)
events = events.cuda()

recon, z = enet(events)
recon, z, mu, logvar = enet(events)

events = events.transpose(2, 1).contiguous()
data_utils.ax1.cla()
data_utils.ax2.cla()

if opt.decoder == "stream":
recon = recon.transpose(2, 1).contiguous()
z = z.cpu().numpy()

data_utils.compare_frames(events, recon)
fig = data_utils.compare_frames(gt, recon)
# data_utils.ax3.scatter(z[0][0], z[0][1], z[0][2], 20)
# data_utils.ax3.set_aspect("equal", adjustable="box")
plt.draw()
plt.pause(0.001)
step_size += 50
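
If the visual comparison in the loop above were reduced to a quick numeric check, the plotting calls could be swapped for something like the sketch below. The mean-squared-error metric, and the assumption that recon and gt hold the same number of elements, are illustrations only and not part of this commit.

    # Hypothetical replacement for the plotting block inside the while loop,
    # assuming the decoder output and the ground-truth frame are the same size.
    mse = torch.mean((recon.flatten() - gt.flatten()) ** 2).item()
    print(f"window = {step_size:5d} events, reconstruction MSE = {mse:.6f}")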
3 changes: 1 addition & 2 deletions event_vae/train_vae.py
@@ -34,7 +34,6 @@
parser.add_argument(
"--decoder", type=str, default="image", help="decoder type: stream or image"
)
parser.add_argument("--outf", type=str, default="weights", help="output folder")
parser.add_argument("--model", type=str, default="", help="model path")
parser.add_argument(
"--norm_type",
@@ -57,7 +56,7 @@
torch.manual_seed(opt.manualSeed)

try:
os.makedirs(opt.outf)
os.makedirs(opt.output_dir)
except OSError:
pass

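
For comparison, the try/except pattern kept above can also be written with exist_ok, which suppresses only the directory-already-exists case; this is a sketch using the default "weights" folder, not a change made by this commit.

    import os

    # Alternative to the try/except OSError pattern in train_vae.py;
    # only the "directory already exists" condition is ignored.
    os.makedirs("weights", exist_ok=True)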
