-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathconfig.py
27 lines (20 loc) · 1.01 KB
/
config.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
import os
import numpy as np
import torch
import torch.nn as nn
from torchvision import transforms
# Target (height, width) in pixels that every image is resized to. Presumably
# chosen to match the 384-pixel input of the "vit_base_patch16_384" backbone
# named in this config -- TODO confirm.
IMG_SIZE = (384, 384)

# Training-time pipeline: resize, stochastic augmentation, then conversion to a
# tensor. Each RandomApply group fires as a unit with probability 0.3, and the
# two groups are independent, so colour jitter may hit an image 0, 1 or 2 times.
image_transform = transforms.Compose(
    [
        # Resizing to an explicit (h, w) pair does not preserve aspect ratio.
        transforms.Resize(IMG_SIZE),
        # Group 1: brightness/contrast/saturation jitter, then a mild
        # perspective warp. RandomPerspective keeps its own default p=0.5, so
        # the warp actually happens in ~0.5 * 0.3 of images.
        transforms.RandomApply(
            [
                transforms.ColorJitter(0.2, 0.2, 0.2),
                transforms.RandomPerspective(distortion_scale=0.2),
            ],
            p=0.3,
        ),
        # Group 2: the same jitter, then a random rotation in [-10, +10] degrees.
        transforms.RandomApply(
            [
                transforms.ColorJitter(0.2, 0.2, 0.2),
                transforms.RandomAffine(degrees=10),
            ],
            p=0.3,
        ),
        # Independent vertical and horizontal flips, each with probability 0.3.
        transforms.RandomVerticalFlip(p=0.3),
        transforms.RandomHorizontalFlip(p=0.3),
        transforms.ToTensor(),
        # NOTE(review): normalization is intentionally off. Enabling
        # transforms.Normalize here needs DATASET_IMAGE_MEAN / DATASET_IMAGE_STD,
        # which are not defined in this file, and the same step must then be
        # added to valid_transform so train and eval inputs stay consistent.
    ]
)
# Evaluation-time pipeline: deterministic resize plus tensor conversion only.
# There is no augmentation and, like image_transform, no normalization step.
valid_transform = transforms.Compose(
    [
        transforms.Resize(IMG_SIZE),
        transforms.ToTensor(),
    ]
)
# Backbone model identifiers. Presumably passed to a model factory elsewhere;
# confirm against callers.
# NOTE(review): the values look swapped relative to the variable names:
# ``vits_name`` holds a CNN name ("resnet50") and ``cnn_name`` holds a ViT name
# ("vit_base_patch16_384"). They are kept as-is because downstream code may
# depend on the current assignment; verify before swapping.
vits_name: str = "resnet50"
cnn_name: str = "vit_base_patch16_384"

# Random seed value; where it is applied is not visible in this chunk.
seed: int = 16122004

# Number of target classes (two).
num_class: int = 2