Skip to content

Commit

Permalink
Merge pull request #5 from ye-luo/change-type
Browse files Browse the repository at this point in the history
Change myVarsFull_ to ValueType
  • Loading branch information
jptowns authored Nov 30, 2023
2 parents cbcea13 + 42329d4 commit ab955c4
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 66 deletions.
81 changes: 19 additions & 62 deletions src/QMCWaveFunctions/RotatedSPOs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -114,15 +114,9 @@ void RotatedSPOs::resetParametersExclusive(const opt_variables_type& active)
const size_t N = m_act_rot_inds_.size();
std::vector<ValueType> delta_param(N);

size_t psize = N;
const size_t psize = use_global_rot_ ? m_full_rot_inds_.size() : N;
assert(psize >= N);

if (use_global_rot_)
{
psize = m_full_rot_inds_.size();
assert(psize >= N);
}

std::vector<ValueType> old_param(psize);
std::vector<ValueType> new_param(psize);

//cast ValueType to RealType for delta param
Expand All @@ -138,18 +132,9 @@ void RotatedSPOs::resetParametersExclusive(const opt_variables_type& active)

if (use_global_rot_)
{
//can do this trick since params are stored as p0_r, p0_i, p1_r, p1_i, ...
auto* old_param_data_real = (RealType*)old_param.data();
//cant use std::copy here because myVarsFull is pair type
for (int i = 0; i < myVarsFull_.size(); i++)
old_param_data_real[i] = myVarsFull_[i];

applyDeltaRotation(delta_param, old_param, new_param);

// Save the the params
auto* new_param_data_real = (RealType*)new_param.data();
for (int i = 0; i < myVarsFull_.size(); i++)
myVarsFull_[i] = new_param_data_real[i];
std::vector<ValueType> old_param(psize);
std::copy_n(myVarsFull_.data(), myVarsFull_.size(), old_param.data());
applyDeltaRotation(delta_param, old_param, myVarsFull_);
}
else
{
Expand All @@ -169,12 +154,7 @@ void RotatedSPOs::writeVariationalParameters(hdf_archive& hout)
std::string rot_global_name = std::string("rotation_global_") + SPOSet::getName();

//Store in h5 as ValueType
std::vector<ValueType> full_params(m_full_rot_inds_.size());
auto* full_params_data_real = (RealType*)full_params.data();
for (int i = 0; i < myVarsFull_.size(); i++)
full_params_data_real[i] = myVarsFull_[i];

hout.write(full_params, rot_global_name);
hout.write(myVarsFull_, rot_global_name);
hout.pop();
}
else
Expand Down Expand Up @@ -223,35 +203,24 @@ void RotatedSPOs::readVariationalParameters(hdf_archive& hin)
if (grp_global_exists)
{
hin.push("rotation_global", false);
std::string rot_global_name = std::string("rotation_global_") + SPOSet::getName();
const std::string rot_global_name = std::string("rotation_global_") + SPOSet::getName();

std::vector<int> sizes(1);
if (!hin.getShape<RealType>(rot_global_name, sizes))
if (!hin.getShape<ValueType>(rot_global_name, sizes))
throw std::runtime_error("Failed to read rotation_global in VP file");

//h5 is storing std::vector<ComplexType>vec as {vec.size(), 2}
//So sizes[0] needs to be doubled if complex for storage into myVarsFull_
int nparam_full_actual = IsComplex_t<ValueType>::value ? 2 * sizes[0] : sizes[0];
int nparam_full = myVarsFull_.size();

if (nparam_full != nparam_full_actual)
//h5 is storing a std::vector<ComplexType> vec as {vec.size(), 2}
if (myVarsFull_.size() != sizes[0])
{
std::ostringstream tmp_err;
tmp_err << "Expected number of full rotation parameters (" << nparam_full << ") does not match number in file ("
<< nparam_full_actual << ")";
tmp_err << "Expected number of full rotation parameters (" << myVarsFull_.size()
<< ") does not match number in file (" << sizes[0] << ")";
throw std::runtime_error(tmp_err.str());
}
std::vector<ValueType> full_params(sizes[0]);
hin.read(full_params, rot_global_name);

//values stored as ValueType. Now unpack into reals
auto* full_params_data_real = (RealType*)full_params.data();
for (int i = 0; i < nparam_full; i++)
myVarsFull_[i] = full_params_data_real[i];

hin.read(myVarsFull_, rot_global_name);
hin.pop();

applyFullRotation(full_params, true);
applyFullRotation(myVarsFull_, true);
}
else if (grp_hist_exists)
{
Expand Down Expand Up @@ -352,7 +321,6 @@ void RotatedSPOs::buildOptVariables(const RotationIndices& rotations, const Rota

// This will add the orbital rotation parameters to myVars
// and will also read in initial parameter values supplied in input file
int p, q;
int nparams_active = m_act_rot_inds_.size();

if (params_supplied_)
Expand All @@ -371,38 +339,27 @@ void RotatedSPOs::buildOptVariables(const RotationIndices& rotations, const Rota

// If the user input parameters, use those. Otherwise, initialize the parameters to zero
if (params_supplied_)
{
RealType x = real_part ? std::real(params[i]) : std::imag(params[i]);
optvars.insert(sstr.str(), x);
}
optvars.insert(sstr.str(), real_part ? std::real(params[i]) : std::imag(params[i]));
else
{
optvars.insert(sstr.str(), 0.0);
}
};

myVars.clear();
for (int i = 0; i < nparams_active; i++)
{
p = m_act_rot_inds_[i].first;
q = m_act_rot_inds_[i].second;
const int p = m_act_rot_inds_[i].first;
const int q = m_act_rot_inds_[i].second;
registerParameter(i, p, q, myVars, params_, true);
if constexpr (IsComplex_t<ValueType>::value)
registerParameter(i, p, q, myVars, params_, false);
}

if (use_global_rot_)
{
myVarsFull_.clear();
const size_t N = m_full_rot_inds_.size();
myVarsFull_.resize(N);
for (int i = 0; i < N; i++)
{
p = m_full_rot_inds_[i].first;
q = m_full_rot_inds_[i].second;
registerParameter(i, p, q, myVarsFull_, params_, true);
if constexpr (IsComplex_t<ValueType>::value)
registerParameter(i, p, q, myVarsFull_, params_, false);
}
myVarsFull_[i] = params_supplied_ ? params_[i] : 0.0;
}

//Printing the parameters
Expand Down
6 changes: 3 additions & 3 deletions src/QMCWaveFunctions/RotatedSPOs.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ namespace qmcplusplus
class RotatedSPOs;
namespace testing
{
opt_variables_type& getMyVarsFull(RotatedSPOs& rot);
std::vector<QMCTraits::ValueType>& getMyVarsFull(RotatedSPOs& rot);
std::vector<std::vector<QMCTraits::ValueType>>& getHistoryParams(RotatedSPOs& rot);
} // namespace testing

Expand Down Expand Up @@ -450,7 +450,7 @@ class RotatedSPOs : public SPOSet, public OptimizableObject
std::vector<ValueType> params_;

/// Full set of rotation matrix parameters for use in global rotation method
opt_variables_type myVarsFull_;
std::vector<ValueType> myVarsFull_;

/// timer for apply_rotation
NewTimer& apply_rotation_timer_;
Expand All @@ -463,7 +463,7 @@ class RotatedSPOs : public SPOSet, public OptimizableObject
/// Use global rotation or history list
bool use_global_rot_ = true;

friend opt_variables_type& testing::getMyVarsFull(RotatedSPOs& rot);
friend std::vector<ValueType>& testing::getMyVarsFull(RotatedSPOs& rot);
friend std::vector<std::vector<ValueType>>& testing::getHistoryParams(RotatedSPOs& rot);
};

Expand Down
2 changes: 1 addition & 1 deletion src/QMCWaveFunctions/tests/test_RotatedSPOs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -730,7 +730,7 @@ TEST_CASE("RotatedSPOs construct delta matrix", "[wavefunction]")
namespace testing
{
opt_variables_type& getMyVars(SPOSet& rot) { return rot.myVars; }
opt_variables_type& getMyVarsFull(RotatedSPOs& rot) { return rot.myVarsFull_; }
std::vector<QMCTraits::ValueType>& getMyVarsFull(RotatedSPOs& rot) { return rot.myVarsFull_; }
std::vector<std::vector<QMCTraits::ValueType>>& getHistoryParams(RotatedSPOs& rot) { return rot.history_params_; }
} // namespace testing

Expand Down

0 comments on commit ab955c4

Please sign in to comment.