Skip to content

Commit

Permalink
Tweak handling of target_device in cli.select_head
Browse files Browse the repository at this point in the history
  • Loading branch information
bernstei committed Jan 7, 2025
1 parent 5d6e60d commit ad1719a
Showing 1 changed file with 13 additions and 2 deletions.
15 changes: 13 additions & 2 deletions mace/cli/select_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,22 +20,33 @@ def main():
action="store_true",
help="list names of the heads",
)
parser.add_argument(
"--target_device",
"-d",
help="target device, defaults to model's current device",
)
parser.add_argument(
"--output_file",
"-o",
help="name for output model, defaults to model_file.target_device",
help="name for output model, defaults to model.head_name, followed by .target_device if specified",
)
parser.add_argument("model_file", help="input model file path")
args = parser.parse_args()

model = torch.load(args.model_file)

if args.list_heads:
print("Available heads:")
print("\n".join([" " + h for h in model.heads]))
else:

if args.output_file is None:
args.output_file = args.model_file + "." + args.head_name + "." + str(next(model.parameters()).device)
args.output_file = args.model_file + "." + args.head_name + ("." + args.target_device if (args.target_device is not None) else "")

model_single = remove_pt_head(model, args.head_name)
if args.target_device is not None:
target_device = str(next(model.parameters()).device)
model_single.to(target_device)
torch.save(model_single, args.output_file)


Expand Down

0 comments on commit ad1719a

Please sign in to comment.