
Commit

use cond instead of inline if+lint
youssef62 committed Jan 7, 2025
1 parent 5f765c9 commit e602378
Showing 4 changed files with 11 additions and 9 deletions.
aten/src/ATen/native/cuda/fused_adam_utils.cuh (2 additions & 2 deletions)

@@ -83,8 +83,8 @@ C10_DEVICE inline void adam_math(
       if (grad_scale_ptr) {
         r_args[kGradIdx][ii] = grad_to_store;
       }
-      //don't write into gradients if beta1 is 0
-      if (beta1>0){
+      // don't write into gradients if beta1 is 0
+      if (beta1>0) {
        r_args[kExpAvgIdx][ii] = exp_avg;
       }
       r_args[kExpAvgSqIdx][ii] = exp_avg_sq;
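Why the guard matters: with the adam.py change below, when beta1 == 0 the exp_avg argument handed to the kernel is the gradients tensor itself, so an unconditional store here would overwrite the incoming gradients. A minimal eager-mode sketch of that aliasing hazard (an illustration in Python, not the fused CUDA kernel):

    import torch

    grads = torch.ones(4)
    exp_avg = grads  # aliased, mirroring the beta1 == 0 fallback in adam.py
    beta1 = 0.0

    # the guard added above: only update the first moment when beta1 is non-zero
    if beta1 > 0:
        exp_avg.lerp_(grads, 1 - beta1)

    assert grads.eq(1).all()  # gradients are left intact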
test/test_optim.py (6 additions & 4 deletions)

@@ -2240,11 +2240,13 @@ def get_obj_size(d):
         ]

         num_params = 4
-        size_of_param_in_bytes = (
-            32 * 16 * dtype.__sizeof__()
-        )
+        size_of_param_in_bytes = 32 * 16 * dtype.__sizeof__()
         for optim_input in beta_1_optim_inputs:
-            zero = 0.0 if isinstance(optim_input.kwargs["betas"][0], float) else torch.tensor(0.0, device=device, dtype=dtype)
+            zero = (
+                0.0
+                if isinstance(optim_input.kwargs["betas"][0], float)
+                else torch.tensor(0.0, device=device, dtype=dtype)
+            )
             beta1_values = (optim_input.kwargs["betas"][0], zero)
             total_sizes = []  # will end up as [big_state_dict_size, no_exp_avg_sd_size]
             for beta1 in beta1_values:
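The test measures the total state-dict size twice, once with the configured beta1 and once with a matching zero (a float or a 0-dim tensor, to match the original type), expecting the second run to come out smaller because no exp_avg buffers are kept. A rough standalone version of the same check (a sketch assuming this commit's behavior; get_obj_size and the exact byte accounting in the real test differ):

    import torch

    param = torch.nn.Parameter(torch.randn(32, 16))
    sizes = []
    for beta1 in (0.9, 0.0):
        opt = torch.optim.Adam([param], betas=(beta1, 0.999))
        param.grad = torch.randn_like(param)
        opt.step()
        state = opt.state_dict()["state"][0]
        # sum the bytes held by tensor entries in the per-param state
        sizes.append(
            sum(
                v.numel() * v.element_size()
                for v in state.values()
                if isinstance(v, torch.Tensor)
            )
        )

    # expected under this change: [big_state_dict_size, no_exp_avg_sd_size]
    print(sizes)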
torch/optim/adam.py (2 additions & 2 deletions)

@@ -253,7 +253,7 @@ def step(self, closure=None):
             adam(
                 params_with_grad,
                 grads,
-                exp_avgs if beta1 > 0 else grads,
+                torch.cond(beta1 > 0, lambda: exp_avgs, lambda: grads),
                 exp_avg_sqs,
                 max_exp_avg_sqs,
                 state_steps,
@@ -433,7 +433,7 @@ def _single_tensor_adam(
             device_beta1 = beta1

         # Decay the first and second moment running average coefficient
-        if device_beta1 > 0:
+        if device_beta1 > 0:
             exp_avg.lerp_(grad, 1 - device_beta1)

         exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2)
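torch.cond is PyTorch's traceable higher-order conditional: both branches are given as callables, so a compiler sees the full structure instead of a data-dependent Python if. In eager mode it simply runs the selected branch, so the call above evaluates to exp_avgs or grads. A minimal sketch of the documented call shape (needs a recent PyTorch 2.x; unrelated to this diff's tensors):

    import torch

    def true_fn(x):
        return x.cos()

    def false_fn(x):
        return x.sin()

    x = torch.randn(4)
    # pred can be a Python bool or a boolean scalar tensor
    out = torch.cond(x.sum() > 0, true_fn, false_fn, (x,))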
torch/testing/_internal/common_optimizers.py (1 addition & 1 deletion)

@@ -570,7 +570,7 @@ def optim_inputs_func_adam(device, dtype=None):
                 params=None,
                 kwargs={"betas": (0.0, 0.999)},
                 desc="zero-beta1",
-            )
+            ),
         ]
         + (cuda_supported_configs if _get_device_type(device) == "cuda" else [])
         + (mps_supported_configs if _get_device_type(device) == "mps" else [])
