-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathutils_helper.py
49 lines (38 loc) · 1.49 KB
/
utils_helper.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import gc
import torch
import numpy as np
from torchvision.transforms import v2 as T
def get_random_color():
    """Return a random RGB color as a tuple of three builtin Python ints.

    Each channel is drawn from ``[0, 255]`` with NumPy's global RNG, so
    seeding via ``np.random.seed`` makes the result reproducible.

    Returns:
        tuple[int, int, int]: e.g. ``(12, 200, 87)``.
    """
    # ``.tolist()`` turns the NumPy int64 scalars into builtin ints.
    # NumPy scalars are rejected by consumers that expect native numbers
    # (e.g. ``json.dumps``). The values drawn from the RNG are unchanged.
    return tuple(np.random.randint(0, 256, 3).tolist())
def clear_cuda():
    """Release unoccupied cached CUDA memory held by PyTorch.

    Delegates to ``torch.cuda.empty_cache()`` and returns ``None``.
    """
    cuda_backend = torch.cuda
    cuda_backend.empty_cache()
def garbage_collect():
    """Run a full (oldest-generation) garbage-collection pass.

    The number of objects collected is discarded; returns ``None``.
    """
    # Generation 2 is gc.collect()'s default, i.e. a full collection.
    gc.collect(2)
def get_transform(train):
    """Build the torchvision ``transforms.v2`` pipeline.

    Args:
        train: when truthy, random augmentations (horizontal flip,
            vertical flip, colour jitter) are placed at the front of the
            pipeline.

    Returns:
        A ``T.Compose`` that always ends by converting to float with value
        scaling and then converting to plain tensors.
    """
    steps = []
    if train:
        steps += [
            T.RandomHorizontalFlip(0.5),
            T.RandomVerticalFlip(0.5),
            T.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2),
        ]
    # Applied in both modes, after any augmentation: dtype conversion with
    # rescaling, then unwrap TVTensor subclasses into plain tensors.
    steps.extend((T.ToDtype(torch.float, scale=True), T.ToPureTensor()))
    return T.Compose(steps)
def save_checkpoint(model, optimizer, epoch, checkpoint_path):
    """Write model and optimizer state, plus the epoch, to *checkpoint_path*.

    The saved object is a dict with the keys ``'epoch'``,
    ``'model_state_dict'`` and ``'optimizer_state_dict'``. These are the
    keys that ``load_checkpoint`` and ``load_model_checkpoint`` read.
    """
    payload = dict(
        epoch=epoch,
        model_state_dict=model.state_dict(),
        optimizer_state_dict=optimizer.state_dict(),
    )
    torch.save(payload, checkpoint_path)
def load_checkpoint(checkpoint_path, device, model, optimizer):
    """Restore model and optimizer state in place from a checkpoint file.

    Args:
        checkpoint_path: file produced by ``torch.save`` holding a dict with
            ``'model_state_dict'``, ``'optimizer_state_dict'`` and ``'epoch'``.
        device: forwarded to ``torch.load`` as ``map_location``, so stored
            tensors are remapped onto it.
        model: module whose state is overwritten via ``load_state_dict``.
        optimizer: optimizer whose state is overwritten via ``load_state_dict``.

    Returns:
        The ``'epoch'`` value stored in the checkpoint.

    Raises:
        KeyError: if one of the expected keys is missing.
    """
    # NOTE(review): torch.load unpickles the file. Only load checkpoints from
    # trusted sources (PyTorch 2.6+ defaults to weights_only=True).
    state = torch.load(checkpoint_path, map_location=device)
    # Model first, then optimizer, before the epoch is read.
    for target, key in ((model, 'model_state_dict'),
                        (optimizer, 'optimizer_state_dict')):
        target.load_state_dict(state[key])
    return state['epoch']
def load_model_checkpoint(checkpoint_path, model, device='cpu'):
    """Load only the model weights from a checkpoint and ignore other entries.

    Args:
        checkpoint_path: file produced by ``torch.save`` holding a dict with
            a ``'model_state_dict'`` entry.
        model: module whose state is overwritten via ``load_state_dict``.
        device: forwarded to ``torch.load`` as ``map_location``
            (``'cpu'`` by default).

    Returns:
        The same *model* object, which makes call chaining convenient.

    Raises:
        KeyError: if ``'model_state_dict'`` is missing.
    """
    # NOTE(review): torch.load unpickles the file; load trusted files only.
    weights = torch.load(checkpoint_path, map_location=device)['model_state_dict']
    model.load_state_dict(weights)
    return model
def print_gpu_memory_usage(device=None):
    """Print the CUDA memory currently occupied by tensors, in MB (MiB).

    Args:
        device: optional CUDA device (index, string or ``torch.device``)
            forwarded to ``torch.cuda.memory_allocated``. ``None`` (the
            default) reports on the current device, as before.
    """
    allocated = torch.cuda.memory_allocated(device)
    # Divide by 1024**2 bytes; the "MB" label is kept unchanged for
    # anything that parses this output.
    print(f"Allocated GPU memory: {allocated / 1024**2:.2f} MB")