-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path: train_c4.py
48 lines (44 loc) · 1.28 KB
/
train_c4.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
import torch
import torchvision
from torchvision.transforms import ToTensor, Lambda,Compose,Resize,Normalize
import matplotlib.pyplot as plt
from functools import reduce
from models.equivariant import *
from groupconv.groups import *
from utils.train import count_params, train,val
from data.loader import get_loader
import os
# --- Run configuration -------------------------------------------------------
BS, EPOCHS, NCLASSES = 16, 15, 10  # batch size, training epochs, output classes
NAME = "C4SimpleCNN_FashionMNIST"  # run identifier; also names the checkpoint file
print(f"MODEL NAME:{NAME}")

# --- Data ----------------------------------------------------------------------
# One loader per split, requested in the order train -> val -> test.
trainloader, valloader, testloader = (
    get_loader(split, BS, NCLASSES) for split in ("train", "val", "test")
)

# --- Model -------------------------------------------------------------------
# Equivariant CNN built on CyclicGroup(4) (the "C4" in NAME).
# NOTE(review): meaning of the second positional argument (4) is defined in
# models.equivariant.EqSimpleCNN — confirm before changing it.
m = EqSimpleCNN(CyclicGroup(4), 4, NCLASSES)
print(m)
print(f"N PARAMS: {count_params(m)}")
m = m.cuda()  # everything below runs on the GPU
# Checkpoint location for this run; its presence decides whether this
# invocation trains the model or evaluates existing weights.
CKPT_PATH = f"ckpt/{NAME}.pth"

if not os.path.exists(CKPT_PATH):
    # No checkpoint yet: train with cross-entropy loss and Adam (library-default
    # hyperparameters), validating on `valloader`.
    # NOTE(review): assumes `train` writes the weights to ckpt/{model_name}.pth
    # so that a re-run takes the evaluation branch — confirm in utils.train.
    train(
        trainloader,
        m,
        torch.nn.CrossEntropyLoss(),
        torch.optim.Adam(m.parameters()),
        device="cuda",
        epochs=EPOCHS,
        val_loader=valloader,
        model_name=NAME,
    )
else:
    # The weights do not depend on the test-time augmentation, so load them
    # once rather than re-reading the checkpoint on every iteration.
    # map_location keeps the tensors on the model's device even if the
    # checkpoint was saved from a different CUDA device index.
    m.load_state_dict(torch.load(CKPT_PATH, map_location="cuda"))
    criterion = torch.nn.CrossEntropyLoss()  # stateless; safe to reuse

    # Evaluate the same weights on the test split without augmentation and
    # with the "c4" / "d4" augmentations provided by get_loader.
    for aug in (None, "c4", "d4"):
        testloader = get_loader("test", BS, NCLASSES, augment=aug)
        acc, loss = val(testloader, m, criterion, "cuda")
        print(f"Augmentation: {aug}, Test Acc:{acc*100:.2f}% Test Loss:{loss:>7f}")