-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathconvnext_softmax.py
143 lines (110 loc) · 4.78 KB
/
convnext_softmax.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
import pandas as pd
import torch.optim as optim
import torch
import os
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from configs.base_config import config
from utils.data_related import data_split, get_dataloader
from transforms.convnext_transform import ConvnextTransform
from transforms.sketch_transform_develop import SketchTransform
from dataset.dataset import CustomDataset
from models.convnext_model import Convnext_Model
from losses.cross_entropy_loss import CrossEntropyLoss
from trainers.cv_trainer import Trainer
from utils.inference import load_model, inference_convnext
from losses.Focal_Loss import FocalLoss
def main():
    """Train a ConvNeXt classifier on the sketch dataset with K-Fold CV."""
    # Training metadata (image paths / labels) — schema defined by CustomDataset.
    train_info = pd.read_csv(config.train_data_info_file_path)

    # Dataset with training-time augmentation enabled.
    train_transform = SketchTransform(is_train=True)
    train_dataset = CustomDataset(
        config.train_data_dir_path,
        train_info,
        train_transform,
    )

    # Pretrained ConvNeXt-Large backbone with a 500-way classification head.
    model = Convnext_Model(
        model_name="convnext_large_mlp.clip_laion2b_soup_ft_in12k_in1k_320",
        num_classes=500,
        pretrained=True,
    )
    model.to(config.device)

    optimizer = optim.Adam(model.parameters(), lr=config.lr)
    loss_fn = CrossEntropyLoss()

    trainer = Trainer(
        model=model,
        device=config.device,
        train_dataset=train_dataset,  # full training set
        val_dataset=None,             # presumably the trainer carves folds out of train_dataset — verify Trainer
        optimizer=optimizer,
        scheduler=None,
        loss_fn=loss_fn,
        epochs=15,
        result_path=config.save_result_path,
        n_splits=5,                   # K of K-Fold
    )
    trainer.train_with_cv()
def ensemble_inference(models, device, test_loader):
    """
    Soft-voting ensemble prediction over several models.

    :param models: list of loaded models.
    :param device: device to run on (CPU or GPU).
    :param test_loader: test data loader yielding image batches.
    :return: np.ndarray of shape (num_samples, num_classes) holding the
             model-averaged softmax probabilities.
    """
    # Move every model to the target device and switch to eval mode.
    for model in models:
        model.to(device)
        model.eval()

    all_predictions = []

    with torch.no_grad():
        for images in tqdm(test_loader, desc="Ensembling"):
            # Some loaders yield an extra singleton dim (B, 1, C, H, W); drop it.
            if len(images.shape) == 5:
                images = images.squeeze(1)
            images = images.to(device)

            # Per-model class probabilities for this batch.
            model_outputs = []
            for model in models:
                logits = model(images)
                probs = F.softmax(logits, dim=1)  # logits -> probabilities
                model_outputs.append(probs.cpu().numpy())

            # Soft voting: average probabilities across models.
            avg_output = np.mean(model_outputs, axis=0)
            all_predictions.append(avg_output)

    # BUG FIX: the original stacked the undefined name `all_softmax_probs`,
    # which raised NameError after all inference was done. Stack the
    # accumulated per-batch averages instead.
    return np.vstack(all_predictions)
def test():
    """Run the 5-fold ensemble on the test set and write prediction CSVs.

    Writes:
        softmax_probabilities.csv — raw averaged class probabilities.
        output_ensemble.csv       — per-sample predicted class ids.
    """
    test_info = pd.read_csv(config.test_data_info_file_path)
    test_transform = ConvnextTransform(is_train=False)
    test_dataset = CustomDataset(
        config.test_data_dir_path,
        test_info,
        test_transform,
        is_inference=True,
    )
    test_loader = get_dataloader(
        test_dataset,
        batch_size=config.batch_size,
        shuffle=config.test_shuffle,
        drop_last=False,
    )

    # Best checkpoint saved by each CV fold.
    model_paths = [
        os.path.join(config.save_result_path, f'fold_{i}_best_model.pt') for i in range(5)
    ]

    # Rebuild the architecture and load each fold's weights.
    models = []
    for path in model_paths:
        model = Convnext_Model(
            model_name="convnext_large_mlp.clip_laion2b_soup_ft_in12k_in1k_320",
            num_classes=500,
            pretrained=False,  # weights come from the checkpoint, not the hub
        )
        model.load_state_dict(load_model(config.save_result_path, os.path.basename(path)))
        models.append(model)

    device = config.device

    # BUG FIX: run the expensive ensemble pass exactly once — the original
    # called ensemble_inference twice (once for the CSV, once for 'target').
    softmax_probs = ensemble_inference(models, device, test_loader)

    # Persist the raw averaged softmax probabilities.
    pd.DataFrame(softmax_probs).to_csv("softmax_probabilities.csv", index=False)

    # BUG FIX: the original assigned the full (N, num_classes) probability
    # matrix to the scalar 'target' column; take the argmax to get one
    # predicted class id per sample.
    test_info['target'] = np.argmax(softmax_probs, axis=1)
    test_info = test_info.reset_index().rename(columns={"index": "ID"})
    test_info.to_csv("output_ensemble.csv", index=False)
if __name__ == "__main__":
#main()
test()