-
Notifications
You must be signed in to change notification settings - Fork 5
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Use default device obtained from architecture #165
Conversation
bd30dd9
to
9d23279
Compare
if desired_device == "multi-cuda" and torch.cuda.device_count() < 1: | ||
raise ValueError( | ||
"Requested device 'multi-gpu' or 'multi-cuda', but found only one CUDA " | ||
"device. If you want to run on a single GPU, please use 'gpu' or " | ||
"'cuda' instead." | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
< 1?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
< 2
. Good catch.
|
||
|
||
def default_device(_root_: BaseContainer) -> str: | ||
"""Custom OmegaConf resolver to find the default precision of an architecture. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Default device (not dtype), I also think this is not doing what the docstring says. This is picking the device for training, not finding the default device of the architecture
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry, I just copied the docstring from the other function and forgot to change. Should be changed in the latest version.
Otherwise, I've tested the code a bit and it works |
9d23279
to
28121aa
Compare
Similar to #155 within this PR the device is dynamically selected based on the architecture and the available devices on the system. This only applied if the user did not specify a device manually in the
options.yaml
.To achieve this, I slightly changed the API of
pick_devices
. Therequested_device
parameter is now optional and renamed todesired_device
. If this parameter is not given the device is solely picked based on the list ofarchitecture_devices
and devices available on the current system. I also made theget_available_devices
function private because it is only of usage within thepick_devices
function.📚 Documentation preview 📚: https://metatensor-models--165.org.readthedocs.build/en/165/