Skip to content

Commit

Permalink
weightload
Browse files Browse the repository at this point in the history
  • Loading branch information
ljwztc committed Sep 18, 2023
1 parent 41f6525 commit 67577b9
Show file tree
Hide file tree
Showing 6 changed files with 17 additions and 6 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
out
pretrained_weights/swin_unetr.base_5000ep_f48_lr2e-4_pretrained.pt
pretrained_weights/Genesis_Chest_CT.pt
pretrained_weights/swinunetr.pth
Binary file modified dataset/__pycache__/dataloader.cpython-37.pyc
Binary file not shown.
16 changes: 11 additions & 5 deletions test.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,24 +127,25 @@ def main():
parser.add_argument("--device")
parser.add_argument("--epoch", default=0)
## logging
parser.add_argument('--log_name', default='Nvidia/old_fold0', help='The path resume from checkpoint')
parser.add_argument('--log_name', default='inference', help='The path resume from checkpoint')
## model load
parser.add_argument('--resume', default='./out/Nvidia/old_fold0/aepoch_500.pth', help='The path resume from checkpoint')
parser.add_argument('--resume', default='./pretrained_weights/swinunetr.pth', help='The path resume from checkpoint')
parser.add_argument('--pretrain', default='./pretrained_weights/swin_unetr.base_5000ep_f48_lr2e-4_pretrained.pt',
help='The path of pretrain model')
parser.add_argument('--backbone', default='swinunetr', help='backbone [swinunetr or unet]')
## hyperparameter
parser.add_argument('--max_epoch', default=1000, type=int, help='Number of training epoches')
parser.add_argument('--store_num', default=10, type=int, help='Store model how often')
parser.add_argument('--lr', default=1e-4, type=float, help='Learning rate')
parser.add_argument('--weight_decay', default=1e-5, type=float, help='Weight Decay')

## dataset
parser.add_argument('--dataset_list', nargs='+', default=['PAOT_123457891213', 'PAOT_10_inner']) # 'PAOT', 'felix'
parser.add_argument('--dataset_list', nargs='+', default=['PAOT_123457891213']) # 'PAOT', 'felix'
### please check this argment carefully
### PAOT: include PAOT_123457891213 and PAOT_10
### PAOT_123457891213: include 1 2 3 4 5 7 8 9 12 13
### PAOT_10_inner
parser.add_argument('--data_root_path', default='/home/jliu288/data/whole_organ/', help='data root path')
parser.add_argument('--data_root_path', default='DATA_ROOT', help='data root path')
parser.add_argument('--data_txt_path', default='./dataset/dataset_list/', help='data txt path')
parser.add_argument('--batch_size', default=1, type=int, help='batch size')
parser.add_argument('--num_workers', default=8, type=int, help='workers numebr for DataLoader')
Expand Down Expand Up @@ -185,9 +186,14 @@ def main():
# args.epoch = checkpoint['epoch']

for key, value in load_dict.items():
name = '.'.join(key.split('.')[1:])
if 'swinViT' in key or 'encoder' in key or 'decoder' in key:
name = '.'.join(key.split('.')[1:])
name = 'backbone.' + name
else:
name = '.'.join(key.split('.')[1:])
store_dict[name] = value


model.load_state_dict(store_dict)
print('Use pretrained weights')

Expand Down
Binary file modified utils/__pycache__/loss.cpython-37.pyc
Binary file not shown.
Binary file modified utils/__pycache__/utils.cpython-37.pyc
Binary file not shown.
6 changes: 5 additions & 1 deletion validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,11 @@ def main():
load_dict = torch.load(store_path)['net']

for key, value in load_dict.items():
name = '.'.join(key.split('.')[1:])
if 'swinViT' in key or 'encoder' in key or 'decoder' in key:
name = '.'.join(key.split('.')[1:])
name = 'backbone.' + name
else:
name = '.'.join(key.split('.')[1:])
store_dict[name] = value

model.load_state_dict(store_dict)
Expand Down

0 comments on commit 67577b9

Please sign in to comment.