allow to set mixed_precision in accelerate config when using a given deepspeed config #3385
Conversation
@@ -1335,7 +1335,6 @@ def _deepspeed_config_checks(self):
     "ACCELERATE_DEEPSPEED_OFFLOAD_PARAM_NVME_PATH",
     "ACCELERATE_DEEPSPEED_OFFLOAD_OPTIMIZER_NVME_PATH",
     "ACCELERATE_DEEPSPEED_ZERO3_SAVE_16BIT_MODEL",
-    "ACCELERATE_MIXED_PRECISION",
The ds config doesn't have such an attribute; the setting only takes effect when `set_mixed_precision` is called (https://github.com/huggingface/accelerate/blob/main/src/accelerate/accelerator.py#L351):
accelerate/src/accelerate/utils/dataclasses.py
Lines 1256 to 1284 in 81d8a03
def set_mixed_precision(self, mixed_precision):
    ds_config = self.deepspeed_config
    kwargs = {
        "fp16.enabled": mixed_precision == "fp16",
        # When training in fp8, we still rely on bf16 autocast for the core mixed precision
        "bf16.enabled": mixed_precision in ("bf16", "fp8"),
    }
    if mixed_precision == "fp16":
        if "fp16" not in ds_config:
            ds_config["fp16"] = {"enabled": True, "auto_cast": True}
    elif mixed_precision in ("bf16", "fp8"):
        if "bf16" not in ds_config:
            ds_config["bf16"] = {"enabled": True}
    if mixed_precision == "fp8" and self.enable_msamp:
        if "msamp" not in ds_config:
            ds_config["msamp"] = {"enabled": True, "opt_level": self.msamp_opt_level}
    if mixed_precision != "no":
        diff_dtype = "bf16" if mixed_precision == "fp16" else "fp16"
        if str(ds_config.get(diff_dtype, {}).get("enabled", "False")).lower() == "true":
            raise ValueError(
                f"`--mixed_precision` arg cannot be set to `{mixed_precision}` when `{diff_dtype}` is set in the DeepSpeed config file."
            )
    for dtype in ["fp16", "bf16"]:
        if dtype not in ds_config:
            ds_config[dtype] = {"enabled": False}
    self.fill_match("fp16.enabled", must_match=False, **kwargs)
    self.fill_match("bf16.enabled", must_match=False, **kwargs)
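To illustrate the merging rules in the quoted method, here is a minimal standalone sketch. The function name `apply_mixed_precision` is illustrative (not part of Accelerate), and the `self.fill_match` pass and the msamp branch are omitted since they depend on instance state:

```python
def apply_mixed_precision(ds_config: dict, mixed_precision: str) -> dict:
    """Sketch of how a `--mixed_precision` value is merged into a DeepSpeed config dict."""
    # Enable the requested dtype only if the user's config doesn't already set it.
    if mixed_precision == "fp16":
        ds_config.setdefault("fp16", {"enabled": True, "auto_cast": True})
    elif mixed_precision in ("bf16", "fp8"):
        ds_config.setdefault("bf16", {"enabled": True})
    # Reject a conflicting dtype that the user enabled in the config file.
    if mixed_precision != "no":
        other = "bf16" if mixed_precision == "fp16" else "fp16"
        if str(ds_config.get(other, {}).get("enabled", "False")).lower() == "true":
            raise ValueError(
                f"`--mixed_precision` cannot be `{mixed_precision}` when "
                f"`{other}` is enabled in the DeepSpeed config file."
            )
    # Fill in explicit disabled entries for whichever dtypes are still unset.
    for dtype in ("fp16", "bf16"):
        ds_config.setdefault(dtype, {"enabled": False})
    return ds_config

cfg = apply_mixed_precision({}, "bf16")
print(cfg)  # {'bf16': {'enabled': True}, 'fp16': {'enabled': False}}
```

With an empty config, the flag fully determines the fp16/bf16 sections; with a user-provided config, the existing sections win unless they directly conflict with the flag, in which case a `ValueError` is raised.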
Force-pushed from 7552c3e to be09963.
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Force-pushed from be09963 to a43dec6.
What does this PR do?
Fixes #3360 (comment).
Before submitting
- Did you read the contributor guideline, Pull Request section?
- Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
- Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.