Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Expose mixed_precision dtype arguments #348

Merged
merged 3 commits into from
May 21, 2024

Conversation

wconstab
Copy link
Contributor

@wconstab wconstab commented May 20, 2024

Stack from ghstack (oldest at bottom):

add training.mixed_precision_param and .mixed_precision_reduce options

refactor a util to map strings to torch dtypes

[ghstack-poisoned]
@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Meta Open Source bot. label May 20, 2024
@wconstab wconstab mentioned this pull request May 20, 2024
[ghstack-poisoned]
Copy link
Contributor

@kwen2501 kwen2501 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for adding this.

TORCH_DTYPE_ARGS = [
"checkpoint.export_dtype",
"training.mixed_precision_param",
"training.mixed_precision_reduce",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: should "reduce" be "grad"?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

no, im following the existing naming for the mixed_precision config struct

Comment on lines +244 to +245
torch dtype to use for reductions when applying mixed precision via FSDP.
This feature only takes effect when data_parallel_degree > 1
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: "reductions" -> "gradients"

Comment on lines 403 to 405
for k_, v_ in v.items():
if ".".join([k, k_]) in TORCH_DTYPE_ARGS:
v[k_] = torch_dtype(v_)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: comment please?

[ghstack-poisoned]
Copy link
Contributor

@wanchaol wanchaol left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lgtm

@wconstab wconstab merged commit a4f1d9d into gh/wconstab/27/base May 21, 2024
4 checks passed
wconstab added a commit that referenced this pull request May 21, 2024
add training.mixed_precision_param and .mixed_precision_reduce options

refactor a util to map strings to torch dtypes

ghstack-source-id: 387e1ca13ad23e859d21d7760f858ee6e269a796
Pull Request resolved: #348
@wconstab wconstab deleted the gh/wconstab/27/head branch May 21, 2024 01:03
tianyu-l pushed a commit that referenced this pull request May 28, 2024
add training.mixed_precision_param and .mixed_precision_reduce options

refactor a util to map strings to torch dtypes

ghstack-source-id: 387e1ca13ad23e859d21d7760f858ee6e269a796
Pull Request resolved: #348
tianyu-l pushed a commit to tianyu-l/torchtitan_intern24 that referenced this pull request Aug 16, 2024
add training.mixed_precision_param and .mixed_precision_reduce options

refactor a util to map strings to torch dtypes

ghstack-source-id: 387e1ca13ad23e859d21d7760f858ee6e269a796
Pull Request resolved: pytorch#348
tianyu-l pushed a commit that referenced this pull request Aug 16, 2024
add training.mixed_precision_param and .mixed_precision_reduce options

refactor a util to map strings to torch dtypes

ghstack-source-id: 387e1ca13ad23e859d21d7760f858ee6e269a796
Pull Request resolved: #348
philippguevorguian pushed a commit to YerevaNN/YNNtitan that referenced this pull request Aug 17, 2024
add training.mixed_precision_param and .mixed_precision_reduce options

refactor a util to map strings to torch dtypes

ghstack-source-id: 387e1ca13ad23e859d21d7760f858ee6e269a796
Pull Request resolved: pytorch#348
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
CLA Signed This label is managed by the Meta Open Source bot.
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants