Use default device obtained from architecture #165

Merged 2 commits on Apr 11, 2024

Contributor

@PicoCentauri PicoCentauri commented Apr 3, 2024

Similar to #155, this PR selects the device dynamically based on the architecture and the devices available on the system. This only applies if the user did not specify a device manually in the options.yaml.

To achieve this, I slightly changed the API of pick_devices. The requested_device parameter is now optional and renamed to desired_device. If this parameter is not given, the device is picked solely based on the list of architecture_devices and the devices available on the current system. I also made the get_available_devices function private because it is only used within the pick_devices function.
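
The selection logic described above could look roughly like the sketch below. This is not the code from this PR: `pick_device_sketch` is a hypothetical stand-in for pick_devices, the available devices are passed in as an argument instead of being queried from the system, and the sketch assumes that architecture_devices is ordered by preference.

```python
from typing import List, Optional


def pick_device_sketch(
    architecture_devices: List[str],
    available_devices: List[str],
    desired_device: Optional[str] = None,
) -> str:
    """Hypothetical stand-in for pick_devices; not the project's implementation."""
    # Keep only devices the architecture supports and the system offers,
    # preserving the (assumed) order of preference of the architecture.
    possible = [d for d in architecture_devices if d in available_devices]
    if not possible:
        raise ValueError(
            f"No supported device found: architecture supports "
            f"{architecture_devices}, system offers {available_devices}."
        )

    # No explicit request: fall back to the first usable architecture device.
    if desired_device is None:
        return possible[0]

    # Explicit request: honor it only if it is actually usable.
    if desired_device not in possible:
        raise ValueError(
            f"Requested device {desired_device!r} is not among {possible}."
        )
    return desired_device
```

With this sketch, `pick_device_sketch(["cuda", "cpu"], ["cpu"])` falls back to `"cpu"`, while passing `desired_device="cpu"` on a machine that also offers `"cuda"` keeps the user's choice.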


📚 Documentation preview 📚: https://metatensor-models--165.org.readthedocs.build/en/165/

@PicoCentauri PicoCentauri changed the title from "Default device" to "Use default device obtained from architecture" Apr 3, 2024
@PicoCentauri PicoCentauri force-pushed the default-device branch 3 times, most recently from bd30dd9 to 9d23279 April 3, 2024 15:06
Comment on lines 69 to 74
    if desired_device == "multi-cuda" and torch.cuda.device_count() < 1:
        raise ValueError(
            "Requested device 'multi-gpu' or 'multi-cuda', but found only one CUDA "
            "device. If you want to run on a single GPU, please use 'gpu' or "
            "'cuda' instead."
        )
Collaborator

< 1?

Contributor Author

< 2. Good catch.
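
A minimal sketch of the corrected check, for illustration only. `check_multi_cuda` is a hypothetical helper that takes the device count as an argument so it runs without a GPU; in the reviewed code the count comes from `torch.cuda.device_count()`. The error wording is adapted slightly so that it also covers the case of zero devices.

```python
from typing import Optional


def check_multi_cuda(desired_device: Optional[str], cuda_device_count: int) -> None:
    """Hypothetical helper: reject 'multi-cuda' unless at least two GPUs exist."""
    # Multi-GPU training needs two or more devices, hence `< 2` rather than `< 1`.
    if desired_device == "multi-cuda" and cuda_device_count < 2:
        raise ValueError(
            "Requested device 'multi-gpu' or 'multi-cuda', but found fewer than "
            "two CUDA devices. If you want to run on a single GPU, please use "
            "'gpu' or 'cuda' instead."
        )
```

With `< 1`, a machine with exactly one GPU would pass the check even though the error message is written for that case.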



def default_device(_root_: BaseContainer) -> str:
    """Custom OmegaConf resolver to find the default precision of an architecture.
Collaborator

Default device (not dtype). I also think this is not doing what the docstring says: it picks the device for training rather than finding the default device of the architecture.

Contributor Author

Sorry, I just copied the docstring from the other function and forgot to change it. It should be fixed in the latest version.

@frostedoyster
Collaborator

Otherwise, I've tested the code a bit and it works.

@frostedoyster frostedoyster merged commit 5548548 into main Apr 11, 2024
11 checks passed
@frostedoyster frostedoyster deleted the default-device branch April 11, 2024 15:13