Skip to content

Commit

Permalink
Refactor: Use memory_op to set diag_const_nums (deepmodeling#5246)
Browse files Browse the repository at this point in the history
* Use memory_op to set consts

* Fix segfault

* Pyabacus test

* Revert changes

* Modify pointer usage

* I will win

* Malloc test

* Initialize with nullptr

* Remove useless code

---------

Co-authored-by: Haozhi Han <[email protected]>
  • Loading branch information
Critsium-xy and haozhihan authored Oct 18, 2024
1 parent 08ea40c commit 545d2ec
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 18 deletions.
45 changes: 39 additions & 6 deletions source/module_hsolver/diag_const_nums.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,27 +9,60 @@ template class const_nums<std::complex<float>>;

// Specialize templates to support double types
template <>
const_nums<double>::const_nums() : zero(0.0), one(1.0), neg_one(-1.0)
const_nums<double>::const_nums()
{
base_device::memory::resize_memory_op<double, base_device::DEVICE_CPU>()(
this->cpu_ctx, this->zero, 1);
this->zero[0] = 0.0;
base_device::memory::resize_memory_op<double, base_device::DEVICE_CPU>()(
this->cpu_ctx, this->one, 1);
this->one[0] = 1.0;
base_device::memory::resize_memory_op<double, base_device::DEVICE_CPU>()(
this->cpu_ctx, this->neg_one, 1);
this->neg_one[0] = -1.0;
}

// Specialize templates to support double types
template <>
const_nums<float>::const_nums() : zero(0.0), one(1.0), neg_one(-1.0)
const_nums<float>::const_nums()
{
base_device::memory::resize_memory_op<float, base_device::DEVICE_CPU>()(
this->cpu_ctx, this->zero, 1);
this->zero[0] = 0.0;
base_device::memory::resize_memory_op<float, base_device::DEVICE_CPU>()(
this->cpu_ctx, this->one, 1);
this->one[0] = 1.0;
base_device::memory::resize_memory_op<float, base_device::DEVICE_CPU>()(
this->cpu_ctx, this->neg_one, 1);
this->neg_one[0] = -1.0;
}

// Specialized templates to support std:: complex<double>types
template <>
const_nums<std::complex<double>>::const_nums()
: zero(std::complex<double>(0.0, 0.0)), one(std::complex<double>(1.0, 0.0)),
neg_one(std::complex<double>(-1.0, 0.0))
{
base_device::memory::resize_memory_op<std::complex<double>, base_device::DEVICE_CPU>()(
this->cpu_ctx, this->zero, 1);
this->zero[0] = std::complex<double>(0.0, 0.0);
base_device::memory::resize_memory_op<std::complex<double>, base_device::DEVICE_CPU>()(
this->cpu_ctx, this->one, 1);
this->one[0] = std::complex<double>(1.0, 0.0);
base_device::memory::resize_memory_op<std::complex<double>, base_device::DEVICE_CPU>()(
this->cpu_ctx, this->neg_one, 1);
this->neg_one[0] = std::complex<double>(-1.0, 0.0);
}

// Specialized templates to support std:: complex<float>types
template <>
const_nums<std::complex<float>>::const_nums()
: zero(std::complex<float>(0.0, 0.0)), one(std::complex<float>(1.0, 0.0)), neg_one(std::complex<float>(-1.0, 0.0))
{
}
base_device::memory::resize_memory_op<std::complex<float>, base_device::DEVICE_CPU>()(
this->cpu_ctx, this->zero, 1);
this->zero[0] = std::complex<float>(0.0, 0.0);
base_device::memory::resize_memory_op<std::complex<float>, base_device::DEVICE_CPU>()(
this->cpu_ctx, this->one, 1);
this->one[0] = std::complex<float>(1.0, 0.0);
base_device::memory::resize_memory_op<std::complex<float>, base_device::DEVICE_CPU>()(
this->cpu_ctx, this->neg_one, 1);
this->neg_one[0] = std::complex<float>(-1.0, 0.0);
}
8 changes: 5 additions & 3 deletions source/module_hsolver/diag_const_nums.h
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
#ifndef DIAG_CONST_NUMS
#define DIAG_CONST_NUMS
#include "module_base/module_device/memory_op.h"

template <typename T>
struct const_nums
{
const_nums();
T zero;
T one;
T neg_one;
base_device::DEVICE_CPU* cpu_ctx = {};
T* zero = nullptr;
T* one = nullptr;
T* neg_one = nullptr;
};

#endif
18 changes: 9 additions & 9 deletions source/module_hsolver/diago_dav_subspace.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,9 @@ Diago_DavSubspace<T, Device>::Diago_DavSubspace(const std::vector<Real>& precond
{
this->device = base_device::get_device_type<Device>(this->ctx);

this->one = &this->cs.one;
this->zero = &this->cs.zero;
this->neg_one = &this->cs.neg_one;
this->one = this->cs.one;
this->zero = this->cs.zero;
this->neg_one = this->cs.neg_one;

assert(david_ndim_in > 1);
assert(david_ndim_in * nband_in < nbasis_in * this->diag_comm.nproc);
Expand Down Expand Up @@ -534,8 +534,8 @@ void Diago_DavSubspace<T, Device>::diag_zhegvx(const int& nbase,
}
else
{
std::vector<std::vector<T>> h_diag(nbase, std::vector<T>(nbase, cs.zero));
std::vector<std::vector<T>> s_diag(nbase, std::vector<T>(nbase, cs.zero));
std::vector<std::vector<T>> h_diag(nbase, std::vector<T>(nbase, cs.zero[0]));
std::vector<std::vector<T>> s_diag(nbase, std::vector<T>(nbase, cs.zero[0]));

for (size_t i = 0; i < nbase; i++)
{
Expand Down Expand Up @@ -564,10 +564,10 @@ void Diago_DavSubspace<T, Device>::diag_zhegvx(const int& nbase,

for (size_t j = nbase; j < this->nbase_x; j++)
{
hcc[i * this->nbase_x + j] = cs.zero;
hcc[j * this->nbase_x + i] = cs.zero;
scc[i * this->nbase_x + j] = cs.zero;
scc[j * this->nbase_x + i] = cs.zero;
hcc[i * this->nbase_x + j] = cs.zero[0];
hcc[j * this->nbase_x + i] = cs.zero[0];
scc[i * this->nbase_x + j] = cs.zero[0];
scc[j * this->nbase_x + i] = cs.zero[0];
}
}
}
Expand Down

0 comments on commit 545d2ec

Please sign in to comment.