pseudo_label
ljwztc committed Sep 20, 2023
1 parent 67577b9 commit c8e829e
Showing 4 changed files with 90 additions and 38 deletions.
Binary file modified dataset/__pycache__/dataloader.cpython-37.pyc
47 changes: 47 additions & 0 deletions dataset/dataloader.py
@@ -388,6 +388,53 @@ def get_loader(args):
    test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=4, collate_fn=list_data_collate)
    return test_loader, val_transforms


def get_loader_without_gt(args):
    val_transforms = Compose(
        [
            LoadImaged(keys=["image"]),
            AddChanneld(keys=["image"]),
            Orientationd(keys=["image"], axcodes="RAS"),
            # ToTemplatelabeld(keys=['label']),
            # RL_Splitd(keys=['label']),
            Spacingd(
                keys=["image"],
                pixdim=(args.space_x, args.space_y, args.space_z),
                mode=("bilinear"),
            ),  # process h5 to here
            ScaleIntensityRanged(
                keys=["image"],
                a_min=args.a_min,
                a_max=args.a_max,
                b_min=args.b_min,
                b_max=args.b_max,
                clip=True,
            ),
            CropForegroundd(keys=["image"], source_key="image"),
            ToTensord(keys=["image"]),
        ]
    )

    ## test dict part
    test_img = []
    test_name = []
    for item in args.dataset_list:
        for line in open(args.data_txt_path + item + '_test.txt'):
            name = line.strip().split()[1].split('.')[0]
            test_img.append(args.data_root_path + line.strip().split()[0])
            test_name.append(name)
    data_dicts_test = [{'image': image, 'name': name}
                       for image, name in zip(test_img, test_name)]
    print('test len {}'.format(len(data_dicts_test)))

    if args.cache_dataset:
        test_dataset = CacheDataset(data=data_dicts_test, transform=val_transforms, cache_rate=args.cache_rate)
    else:
        test_dataset = Dataset(data=data_dicts_test, transform=val_transforms)
    test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=4, collate_fn=list_data_collate)
    return test_loader, val_transforms


if __name__ == "__main__":
    train_loader, test_loader = partial_label_dataloader()
    for index, item in enumerate(test_loader):
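For context, a minimal sketch (not part of the commit) of how the new get_loader_without_gt might be driven. The argument names mirror the parser in pred_pseudo.py; the dataset name, paths, spacing, and intensity window below are placeholder assumptions.

# Hypothetical driver for get_loader_without_gt; all values are assumptions.
import argparse

from dataset.dataloader import get_loader_without_gt

args = argparse.Namespace(
    dataset_list=['PAOT_10_inner'],                # resolves to PAOT_10_inner_test.txt
    data_txt_path='./dataset/dataset_list/',
    data_root_path='/path/to/ct_data/',            # placeholder data root
    space_x=1.5, space_y=1.5, space_z=1.5,         # assumed resampling spacing
    a_min=-175, a_max=250, b_min=0.0, b_max=1.0,   # assumed CT intensity window
    cache_dataset=False, cache_rate=0.005,
)

test_loader, val_transforms = get_loader_without_gt(args)
for batch in test_loader:
    # Only images and case names are loaded; no ground-truth labels.
    print(batch['name'], batch['image'].shape)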
79 changes: 42 additions & 37 deletions pred_pseudo.py
@@ -15,7 +15,7 @@
from monai.inferers import sliding_window_inference

from model.Universal_model import Universal_model
-from dataset.dataloader import get_loader
+from dataset.dataloader import get_loader_without_gt
from utils import loss
from utils.utils import dice_score, threshold_organ, visualize_label, merge_label, get_key
from utils.utils import TEMPLATE, ORGAN_NAME, NUM_CLASS
@@ -25,7 +25,7 @@


def validation(model, ValLoader, val_transforms, args):
-    save_dir = 'out/' + args.log_name + f'/pesudolbl_{args.epoch}'
+    save_dir = 'out/' + args.log_name  # + f'/pesudolbl_{args.epoch}'
    if not os.path.isdir(save_dir):
        os.mkdir(save_dir)
        os.mkdir(save_dir+'/predict')
@@ -53,17 +53,17 @@ def validation(model, ValLoader, val_transforms, args):
            content = 'case%s| '%(name[b])
            template_key = get_key(name[b])
            organ_list = TEMPLATE[template_key]
-            pred_hard_post = organ_post_process(pred_hard.numpy(), organ_list)
+            pred_hard_post = organ_post_process(pred_hard.numpy(), organ_list, args.log_name+'/'+name[0].split('/')[0]+'/'+name[0].split('/')[-1], args)
            pred_hard_post = torch.tensor(pred_hard_post)

-            for organ in organ_list:
-                if torch.sum(label[b,organ-1,:,:,:].cuda()) != 0:
-                    dice_organ, recall, precision = dice_score(pred_hard_post[b,organ-1,:,:,:].cuda(), label[b,organ-1,:,:,:].cuda())
-                    dice_list[template_key][0][organ-1] += dice_organ.item()
-                    dice_list[template_key][1][organ-1] += 1
-                    content += '%s: %.4f, '%(ORGAN_NAME[organ-1], dice_organ.item())
-                    print('%s: dice %.4f, recall %.4f, precision %.4f.'%(ORGAN_NAME[organ-1], dice_organ.item(), recall.item(), precision.item()))
-            print(content)
+            # for organ in organ_list:
+            #     if torch.sum(label[b,organ-1,:,:,:].cuda()) != 0:
+            #         dice_organ, recall, precision = dice_score(pred_hard_post[b,organ-1,:,:,:].cuda(), label[b,organ-1,:,:,:].cuda())
+            #         dice_list[template_key][0][organ-1] += dice_organ.item()
+            #         dice_list[template_key][1][organ-1] += 1
+            #         content += '%s: %.4f, '%(ORGAN_NAME[organ-1], dice_organ.item())
+            #         print('%s: dice %.4f, recall %.4f, precision %.4f.'%(ORGAN_NAME[organ-1], dice_organ.item(), recall.item(), precision.item()))
+            # print(content)


    ### testing phase for this function
@@ -76,27 +76,27 @@ def validation(model, ValLoader, val_transforms, args):

    ave_organ_dice = np.zeros((2, NUM_CLASS))

-    with open('out/'+args.log_name+f'/test_{args.epoch}.txt', 'w') as f:
-        for key in TEMPLATE.keys():
-            organ_list = TEMPLATE[key]
-            content = 'Task%s| '%(key)
-            for organ in organ_list:
-                dice = dice_list[key][0][organ-1] / dice_list[key][1][organ-1]
-                content += '%s: %.4f, '%(ORGAN_NAME[organ-1], dice)
-                ave_organ_dice[0][organ-1] += dice_list[key][0][organ-1]
-                ave_organ_dice[1][organ-1] += dice_list[key][1][organ-1]
-            print(content)
-            f.write(content)
-            f.write('\n')
-        content = 'Average | '
-        for i in range(NUM_CLASS):
-            content += '%s: %.4f, '%(ORGAN_NAME[i], ave_organ_dice[0][i] / ave_organ_dice[1][i])
-        print(content)
-        f.write(content)
-        f.write('\n')
-        print(np.mean(ave_organ_dice[0] / ave_organ_dice[1]))
-        f.write('%s: %.4f, '%('average', np.mean(ave_organ_dice[0] / ave_organ_dice[1])))
-        f.write('\n')
+    # with open('out/'+args.log_name+f'/test_{args.epoch}.txt', 'w') as f:
+    #     for key in TEMPLATE.keys():
+    #         organ_list = TEMPLATE[key]
+    #         content = 'Task%s| '%(key)
+    #         for organ in organ_list:
+    #             dice = dice_list[key][0][organ-1] / dice_list[key][1][organ-1]
+    #             content += '%s: %.4f, '%(ORGAN_NAME[organ-1], dice)
+    #             ave_organ_dice[0][organ-1] += dice_list[key][0][organ-1]
+    #             ave_organ_dice[1][organ-1] += dice_list[key][1][organ-1]
+    #         print(content)
+    #         f.write(content)
+    #         f.write('\n')
+    #     content = 'Average | '
+    #     for i in range(NUM_CLASS):
+    #         content += '%s: %.4f, '%(ORGAN_NAME[i], ave_organ_dice[0][i] / ave_organ_dice[1][i])
+    #     print(content)
+    #     f.write(content)
+    #     f.write('\n')
+    #     print(np.mean(ave_organ_dice[0] / ave_organ_dice[1]))
+    #     f.write('%s: %.4f, '%('average', np.mean(ave_organ_dice[0] / ave_organ_dice[1])))
+    #     f.write('\n')


    # np.save(save_dir + '/result.npy', dice_list)
@@ -115,11 +115,12 @@ def main():
parser.add_argument("--device")
parser.add_argument("--epoch", default=0)
## logging
parser.add_argument('--log_name', default='Nvidia/old_fold0', help='The path resume from checkpoint')
parser.add_argument('--log_name', default='Nvidia', help='The path resume from checkpoint')
## model load
parser.add_argument('--resume', default='./out/Nvidia/old_fold0/aepoch_500.pth', help='The path resume from checkpoint')
parser.add_argument('--resume', default='./pretrained_weights/swinunetr.pth', help='The path resume from checkpoint')
parser.add_argument('--pretrain', default='./pretrained_weights/swin_unetr.base_5000ep_f48_lr2e-4_pretrained.pt',
help='The path of pretrain model')
parser.add_argument('--backbone', default='swinunetr', help='backbone [swinunetr or unet]')
## hyperparameter
parser.add_argument('--max_epoch', default=1000, type=int, help='Number of training epoches')
parser.add_argument('--store_num', default=10, type=int, help='Store model how often')
@@ -133,7 +133,7 @@ def main():
    ### PAOT_123457891213: include 1 2 3 4 5 7 8 9 12 13
    ### PAOT_10_inner: same with NVIDIA for comparison
    ### PAOT_10: original division
-    parser.add_argument('--data_root_path', default='/home/jliu288/data/whole_organ/', help='data root path')
+    parser.add_argument('--data_root_path', default='/computenodes/node31/team1/jliu/data/ct_data/', help='data root path')
    parser.add_argument('--data_txt_path', default='./dataset/dataset_list/', help='data txt path')
    parser.add_argument('--batch_size', default=1, type=int, help='batch size')
    parser.add_argument('--num_workers', default=8, type=int, help='workers number for DataLoader')
@@ -174,7 +175,11 @@ def main():
    # args.epoch = checkpoint['epoch']

    for key, value in load_dict.items():
-        name = '.'.join(key.split('.')[1:])
+        if 'swinViT' in key or 'encoder' in key or 'decoder' in key:
+            name = '.'.join(key.split('.')[1:])
+            name = 'backbone.' + name
+        else:
+            name = '.'.join(key.split('.')[1:])
        store_dict[name] = value

    model.load_state_dict(store_dict)
@@ -184,7 +189,7 @@ def main():

    torch.backends.cudnn.benchmark = True

-    test_loader, val_transforms = get_loader(args)
+    test_loader, val_transforms = get_loader_without_gt(args)

    validation(model, test_loader, val_transforms, args)

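The new branch in the checkpoint-loading loop above re-roots Swin-UNETR weights (keys containing swinViT, encoder, or decoder) under the backbone. prefix expected by Universal_model, while other entries only lose their leading wrapper component. A toy illustration, assuming checkpoint keys carry a module. prefix (e.g. from DataParallel):

# Toy state dict; real checkpoints hold tensors, and the 'module.' prefix is an assumption.
load_dict = {
    'module.swinViT.layers1.0.blocks.0.norm1.weight': 'w0',
    'module.organ_embedding': 'w1',
}

store_dict = {}
for key, value in load_dict.items():
    name = '.'.join(key.split('.')[1:])    # drop the wrapper prefix ('module.')
    if 'swinViT' in key or 'encoder' in key or 'decoder' in key:
        name = 'backbone.' + name          # re-root under Universal_model's backbone
    store_dict[name] = value

print(sorted(store_dict))
# ['backbone.swinViT.layers1.0.blocks.0.norm1.weight', 'organ_embedding']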
2 changes: 1 addition & 1 deletion test.py
@@ -145,7 +145,7 @@ def main():
    ### PAOT: include PAOT_123457891213 and PAOT_10
    ### PAOT_123457891213: include 1 2 3 4 5 7 8 9 12 13
    ### PAOT_10_inner
-    parser.add_argument('--data_root_path', default='DATA_ROOT', help='data root path')
+    parser.add_argument('--data_root_path', default='/computenodes/node31/team1/jliu/data/ct_data/', help='data root path')
    parser.add_argument('--data_txt_path', default='./dataset/dataset_list/', help='data txt path')
    parser.add_argument('--batch_size', default=1, type=int, help='batch size')
    parser.add_argument('--num_workers', default=8, type=int, help='workers number for DataLoader')
