Skip to content

Commit

Permalink
Fixed device choice in mtt eval
Browse files Browse the repository at this point in the history
  • Loading branch information
abmazitov committed Nov 6, 2024
1 parent 5a60116 commit add538d
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion src/metatrain/cli/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,9 @@ def _eval_targets(
# Infer the device and dtype from the model
model_tensor = next(itertools.chain(model.parameters(), model.buffers()))
dtype = model_tensor.dtype
device = "cuda" if "cuda" in model.capabilities().supported_devices else "cpu"
device = "cpu"
if torch.cuda.is_available() and "cuda" in model.capabilities().supported_devices:
device = "cuda"

Check warning on line 189 in src/metatrain/cli/eval.py

View check run for this annotation

Codecov / codecov/patch

src/metatrain/cli/eval.py#L189

Added line #L189 was not covered by tests
logger.info(f"Running on device {device} with dtype {dtype}")
model.to(dtype=dtype, device=device)

Expand Down

0 comments on commit add538d

Please sign in to comment.