Skip to content

Commit

Permalink
Integrate the fused_adamw operator (rename diopiFused_AdamW to diopiFusedAdamW and pass fused=True)
Browse files Browse the repository at this point in the history
  • Loading branch information
zyf654321 committed Aug 29, 2024
1 parent 67467e9 commit e70717e
Show file tree
Hide file tree
Showing 4 changed files with 4 additions and 3 deletions.
1 change: 1 addition & 0 deletions diopi_test/python/conformance/customized_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,7 @@ def fused_adamw(
weight_decay=weight_decay,
eps=eps,
maximize=maximize,
fused=True,
)
return params, exp_avgs, exp_avg_sqs, max_exp_avg_sqs

Expand Down
2 changes: 1 addition & 1 deletion diopi_test/python/conformance/diopi_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3179,7 +3179,7 @@ def fused_adamw(
for state_step in state_steps:
c_state_steps.append(TensorP(state_step))

func = check_function("diopiFused_AdamW")
func = check_function("diopiFusedAdamW")
ret = func(
params[0].context(),
list(c_params),
Expand Down
2 changes: 1 addition & 1 deletion impl/torch/functions/functions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2815,7 +2815,7 @@ diopiError_t diopiMeshGrid(diopiContextHandle_t ctx, diopiTensorHandle_t* outs,
return diopiSuccess;
}

diopiError_t diopiFused_AdamW(diopiContextHandle_t ctx, diopiTensorHandle_t* params, diopiConstTensorHandle_t* grads, diopiTensorHandle_t* exp_avgs,
diopiError_t diopiFusedAdamW(diopiContextHandle_t ctx, diopiTensorHandle_t* params, diopiConstTensorHandle_t* grads, diopiTensorHandle_t* exp_avgs,
diopiTensorHandle_t* exp_avg_sqs, diopiTensorHandle_t* max_exp_avg_sqs, diopiConstTensorHandle_t* state_steps, int64_t nums,
float lr, float beta1, float beta2, float eps, float weight_decay, bool amsgrad, bool maximize) {
impl::aten::setCurStream(ctx);
Expand Down
2 changes: 1 addition & 1 deletion proto/include/diopi/functions.h
Original file line number Diff line number Diff line change
Expand Up @@ -2469,7 +2469,7 @@ DIOPI_API diopiError_t diopiReciprocalInp(diopiContextHandle_t ctx, diopiTensorH
* @param[in] amsgrad whether to use the AMSGrad variant of this algorithm from the paper `On the Convergence of Adam and Beyond`.
* @param[in] maximize maximize the objective with respect to the params, instead of minimizing.
*/
DIOPI_API diopiError_t diopiFused_AdamW(diopiContextHandle_t ctx, diopiTensorHandle_t* params, diopiConstTensorHandle_t* grads, diopiTensorHandle_t* exp_avgs,
DIOPI_API diopiError_t diopiFusedAdamW(diopiContextHandle_t ctx, diopiTensorHandle_t* params, diopiConstTensorHandle_t* grads, diopiTensorHandle_t* exp_avgs,
diopiTensorHandle_t* exp_avg_sqs, diopiTensorHandle_t* max_exp_avg_sqs, diopiConstTensorHandle_t* state_steps,
int64_t nums, float lr, float beta1, float beta2, float eps, float weight_decay, bool amsgrad, bool maximize);

Expand Down

0 comments on commit e70717e

Please sign in to comment.