[Dygraph] Fix EagerReducer::MarkVarReady()'s lack of HasGrad() branch (PaddlePaddle#62299)

* fix eager reducer

* Update reducer.cc

* fix approval error
chen2016013 authored Mar 4, 2024
1 parent 437293b commit fc3fb05
Showing 1 changed file with 25 additions and 15 deletions.
40 changes: 25 additions & 15 deletions paddle/fluid/distributed/collective/reducer.cc
@@ -831,23 +831,33 @@ void EagerReducer::MarkVarReady(const size_t var_index,
   auto &group_tensor = group.dense_tensors_[inside_group_index];
   const auto length = group.length_[inside_group_index];
   if (is_used_var) {
-    auto *autograd_meta = tensors_[var_index].get_autograd_meta();
-    paddle::Tensor grad_tensor =
-        static_cast<egr::AutogradMeta *>(autograd_meta)->Grad();
-    if (grad_tensor.is_dense_tensor()) {
-      const auto &tensor_impl = grad_tensor.impl();
-      auto dense_tensor =
-          std::dynamic_pointer_cast<phi::DenseTensor>(tensor_impl);
-      if (!dense_tensor->meta().is_contiguous()) {
-        grad_tensor.set_impl(std::make_shared<phi::DenseTensor>(std::move(
-            paddle::experimental::Trans2Contiguous(*dense_tensor))));
+    if (HasGrad(var_index)) {
+      auto *autograd_meta = tensors_[var_index].get_autograd_meta();
+      paddle::Tensor grad_tensor =
+          static_cast<egr::AutogradMeta *>(autograd_meta)->Grad();
+      if (grad_tensor.is_dense_tensor()) {
+        const auto &tensor_impl = grad_tensor.impl();
+        auto dense_tensor =
+            std::dynamic_pointer_cast<phi::DenseTensor>(tensor_impl);
+        if (!dense_tensor->meta().is_contiguous()) {
+          grad_tensor.set_impl(std::make_shared<phi::DenseTensor>(std::move(
+              paddle::experimental::Trans2Contiguous(*dense_tensor))));
+        }
       }
-    }
 
-    group_tensor
-        .ShareDataWith(*(
-            std::dynamic_pointer_cast<phi::DenseTensor>(grad_tensor.impl())))
-        .Resize({grad_tensor.numel()});
+      group_tensor
+          .ShareDataWith(*(std::dynamic_pointer_cast<phi::DenseTensor>(
+              grad_tensor.impl())))
+          .Resize({grad_tensor.numel()});
+    } else {
+      VLOG(3) << "Tensor[" << tensors_[var_index].name()
+              << "] doesn't have grad";
+      auto *dev_ctx =
+          platform::DeviceContextPool::Instance().Get(inner_place_);
+      group_tensor.Resize({static_cast<int64_t>(length)});
+      dev_ctx->Alloc(&group_tensor, group.dtype_);
+      phi::funcs::set_constant(*dev_ctx, &group_tensor, 0.0f);
+    }
   } else {
     // TODO(shenliang03): maybe save the memory by avoiding tensor
     // construction
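A note on the fix, for readers outside the Paddle codebase: before this commit, MarkVarReady read the variable's Grad() unconditionally, so the grad-sharing path ran even for a used parameter that had produced no gradient. The patch wraps the old path in if (HasGrad(var_index)) and, in the new else branch, resizes the variable's slot in the group to its expected length and zero-fills it, so the later allreduce still sees a well-formed, all-zero contribution. The sketch below is a minimal standalone illustration of that control flow only; ToyGroup, MarkSlotReady, and the std::vector-based buffers are hypothetical stand-ins, not Paddle APIs.

// Minimal sketch of the zero-fill-when-no-grad pattern added by this commit.
// ToyGroup and MarkSlotReady are hypothetical; they only mirror the control flow.
#include <cstddef>
#include <cstdio>
#include <optional>
#include <vector>

struct ToyGroup {
  // One flat buffer per variable slot, concatenated before the allreduce.
  std::vector<std::vector<float>> slots;
};

// grad == std::nullopt models HasGrad(var_index) == false.
void MarkSlotReady(ToyGroup &group,
                   std::size_t slot,
                   std::size_t length,
                   const std::optional<std::vector<float>> &grad) {
  if (grad.has_value()) {
    // Used variable with a gradient: copy (in Paddle, share) it into the slot.
    group.slots[slot] = *grad;
  } else {
    // Used variable without a gradient (the branch this commit adds):
    // size the slot to its expected length and zero-fill it so the
    // group-wide allreduce still gets a well-formed buffer.
    std::printf("slot %zu has no grad, zero-filling %zu elements\n", slot, length);
    group.slots[slot].assign(length, 0.0f);
  }
}

int main() {
  ToyGroup group;
  group.slots.resize(2);
  MarkSlotReady(group, 0, 4, std::vector<float>{1.f, 2.f, 3.f, 4.f});
  MarkSlotReady(group, 1, 4, std::nullopt);  // forward-used, but no grad produced
  return 0;
}

In the real patch the same idea is expressed with group_tensor.Resize, dev_ctx->Alloc, and phi::funcs::set_constant, as shown in the diff above.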
