Skip to content

Commit

Permalink
Update train.py
Browse files Browse the repository at this point in the history
  • Loading branch information
MrGiovanni committed May 26, 2023
1 parent 8bd2533 commit f7612bb
Showing 1 changed file with 4 additions and 3 deletions.
7 changes: 4 additions & 3 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,8 @@ def process(args):
)

#Load pre-trained weights
model.load_params(torch.load(args.pretrain)["state_dict"])
if args.pretrain is not None:
model.load_params(torch.load(args.pretrain)["state_dict"])

if args.trans_encoding == 'word_embedding':
word_embedding = torch.load(args.word_embedding)
Expand Down Expand Up @@ -212,8 +213,8 @@ def main():
## model load
parser.add_argument('--backbone', default='unet', help='backbone [swinunetr or unet or dints or unetpp]')
parser.add_argument('--resume', default=None, help='The path resume from checkpoint')
parser.add_argument('--pretrain', default='./pretrained_weights/swin_unetr.base_5000ep_f48_lr2e-4_pretrained.pt', #swin_unetr.base_5000ep_f48_lr2e-4_pretrained.pt
help='The path of pretrain model')
parser.add_argument('--pretrain', default=None, #swin_unetr.base_5000ep_f48_lr2e-4_pretrained.pt
help='The path of pretrain model. Eg, ./pretrained_weights/swin_unetr.base_5000ep_f48_lr2e-4_pretrained.pt')
parser.add_argument('--trans_encoding', default='word_embedding',
help='the type of encoding: rand_embedding or word_embedding')
parser.add_argument('--word_embedding', default='./pretrained_weights/txt_encoding.pth',
Expand Down

0 comments on commit f7612bb

Please sign in to comment.