From 63eb137a8e31cc28c8eef7c3509b04c3c39ba46e Mon Sep 17 00:00:00 2001 From: Ye Luo Date: Wed, 29 Nov 2023 15:15:04 -0600 Subject: [PATCH 1/2] Change myVarsFull_ to ValueType --- src/QMCWaveFunctions/RotatedSPOs.cpp | 76 +++++-------------- src/QMCWaveFunctions/RotatedSPOs.h | 6 +- .../tests/test_RotatedSPOs.cpp | 2 +- 3 files changed, 25 insertions(+), 59 deletions(-) diff --git a/src/QMCWaveFunctions/RotatedSPOs.cpp b/src/QMCWaveFunctions/RotatedSPOs.cpp index 310395cac0..18fac54595 100644 --- a/src/QMCWaveFunctions/RotatedSPOs.cpp +++ b/src/QMCWaveFunctions/RotatedSPOs.cpp @@ -114,15 +114,9 @@ void RotatedSPOs::resetParametersExclusive(const opt_variables_type& active) const size_t N = m_act_rot_inds_.size(); std::vector delta_param(N); - size_t psize = N; + const size_t psize = use_global_rot_ ? m_full_rot_inds_.size() : N; + assert(psize >= N); - if (use_global_rot_) - { - psize = m_full_rot_inds_.size(); - assert(psize >= N); - } - - std::vector old_param(psize); std::vector new_param(psize); //cast ValueType to RealType for delta param @@ -138,18 +132,9 @@ void RotatedSPOs::resetParametersExclusive(const opt_variables_type& active) if (use_global_rot_) { - //can do this trick since params are stored as p0_r, p0_i, p1_r, p1_i, ... - auto* old_param_data_real = (RealType*)old_param.data(); - //cant use std::copy here because myVarsFull is pair type - for (int i = 0; i < myVarsFull_.size(); i++) - old_param_data_real[i] = myVarsFull_[i]; - - applyDeltaRotation(delta_param, old_param, new_param); - - // Save the the params - auto* new_param_data_real = (RealType*)new_param.data(); - for (int i = 0; i < myVarsFull_.size(); i++) - myVarsFull_[i] = new_param_data_real[i]; + std::vector old_param(psize); + std::copy_n(myVarsFull_.data(), myVarsFull_.size(), old_param.data()); + applyDeltaRotation(delta_param, old_param, myVarsFull_); } else { @@ -169,12 +154,7 @@ void RotatedSPOs::writeVariationalParameters(hdf_archive& hout) std::string rot_global_name = std::string("rotation_global_") + SPOSet::getName(); //Store in h5 as ValueType - std::vector full_params(m_full_rot_inds_.size()); - auto* full_params_data_real = (RealType*)full_params.data(); - for (int i = 0; i < myVarsFull_.size(); i++) - full_params_data_real[i] = myVarsFull_[i]; - - hout.write(full_params, rot_global_name); + hout.write(myVarsFull_, rot_global_name); hout.pop(); } else @@ -223,35 +203,24 @@ void RotatedSPOs::readVariationalParameters(hdf_archive& hin) if (grp_global_exists) { hin.push("rotation_global", false); - std::string rot_global_name = std::string("rotation_global_") + SPOSet::getName(); + const std::string rot_global_name = std::string("rotation_global_") + SPOSet::getName(); std::vector sizes(1); - if (!hin.getShape(rot_global_name, sizes)) + if (!hin.getShape(rot_global_name, sizes)) throw std::runtime_error("Failed to read rotation_global in VP file"); - //h5 is storing std::vectorvec as {vec.size(), 2} - //So sizes[0] needs to be doubled if complex for storage into myVarsFull_ - int nparam_full_actual = IsComplex_t::value ? 2 * sizes[0] : sizes[0]; - int nparam_full = myVarsFull_.size(); - - if (nparam_full != nparam_full_actual) + //h5 is storing a std::vector vec as {vec.size(), 2} + if (myVarsFull_.size() != sizes[0]) { std::ostringstream tmp_err; - tmp_err << "Expected number of full rotation parameters (" << nparam_full << ") does not match number in file (" - << nparam_full_actual << ")"; + tmp_err << "Expected number of full rotation parameters (" << myVarsFull_.size() + << ") does not match number in file (" << sizes[0] << ")"; throw std::runtime_error(tmp_err.str()); } - std::vector full_params(sizes[0]); - hin.read(full_params, rot_global_name); - - //values stored as ValueType. Now unpack into reals - auto* full_params_data_real = (RealType*)full_params.data(); - for (int i = 0; i < nparam_full; i++) - myVarsFull_[i] = full_params_data_real[i]; - + hin.read(myVarsFull_, rot_global_name); hin.pop(); - applyFullRotation(full_params, true); + applyFullRotation(myVarsFull_, true); } else if (grp_hist_exists) { @@ -371,14 +340,9 @@ void RotatedSPOs::buildOptVariables(const RotationIndices& rotations, const Rota // If the user input parameters, use those. Otherwise, initialize the parameters to zero if (params_supplied_) - { - RealType x = real_part ? std::real(params[i]) : std::imag(params[i]); - optvars.insert(sstr.str(), x); - } + optvars.insert(sstr.str(), real_part ? std::real(params[i]) : std::imag(params[i])); else - { optvars.insert(sstr.str(), 0.0); - } }; myVars.clear(); @@ -393,15 +357,17 @@ void RotatedSPOs::buildOptVariables(const RotationIndices& rotations, const Rota if (use_global_rot_) { - myVarsFull_.clear(); const size_t N = m_full_rot_inds_.size(); + myVarsFull_.resize(N); for (int i = 0; i < N; i++) { p = m_full_rot_inds_[i].first; q = m_full_rot_inds_[i].second; - registerParameter(i, p, q, myVarsFull_, params_, true); - if constexpr (IsComplex_t::value) - registerParameter(i, p, q, myVarsFull_, params_, false); + + if (params_supplied_) + myVarsFull_[i] = params_[i]; + else + myVarsFull_[i] = 0.0; } } diff --git a/src/QMCWaveFunctions/RotatedSPOs.h b/src/QMCWaveFunctions/RotatedSPOs.h index 5ab9c31055..062407703f 100644 --- a/src/QMCWaveFunctions/RotatedSPOs.h +++ b/src/QMCWaveFunctions/RotatedSPOs.h @@ -21,7 +21,7 @@ namespace qmcplusplus class RotatedSPOs; namespace testing { -opt_variables_type& getMyVarsFull(RotatedSPOs& rot); +std::vector& getMyVarsFull(RotatedSPOs& rot); std::vector>& getHistoryParams(RotatedSPOs& rot); } // namespace testing @@ -450,7 +450,7 @@ class RotatedSPOs : public SPOSet, public OptimizableObject std::vector params_; /// Full set of rotation matrix parameters for use in global rotation method - opt_variables_type myVarsFull_; + std::vector myVarsFull_; /// timer for apply_rotation NewTimer& apply_rotation_timer_; @@ -463,7 +463,7 @@ class RotatedSPOs : public SPOSet, public OptimizableObject /// Use global rotation or history list bool use_global_rot_ = true; - friend opt_variables_type& testing::getMyVarsFull(RotatedSPOs& rot); + friend std::vector& testing::getMyVarsFull(RotatedSPOs& rot); friend std::vector>& testing::getHistoryParams(RotatedSPOs& rot); }; diff --git a/src/QMCWaveFunctions/tests/test_RotatedSPOs.cpp b/src/QMCWaveFunctions/tests/test_RotatedSPOs.cpp index e98c9bdf5e..9a4062581f 100644 --- a/src/QMCWaveFunctions/tests/test_RotatedSPOs.cpp +++ b/src/QMCWaveFunctions/tests/test_RotatedSPOs.cpp @@ -730,7 +730,7 @@ TEST_CASE("RotatedSPOs construct delta matrix", "[wavefunction]") namespace testing { opt_variables_type& getMyVars(SPOSet& rot) { return rot.myVars; } -opt_variables_type& getMyVarsFull(RotatedSPOs& rot) { return rot.myVarsFull_; } +std::vector& getMyVarsFull(RotatedSPOs& rot) { return rot.myVarsFull_; } std::vector>& getHistoryParams(RotatedSPOs& rot) { return rot.history_params_; } } // namespace testing From 42329d461a3d8bddec87668bb73a6d2c812fa612 Mon Sep 17 00:00:00 2001 From: Ye Luo Date: Wed, 29 Nov 2023 15:37:53 -0600 Subject: [PATCH 2/2] Minor simplification. --- src/QMCWaveFunctions/RotatedSPOs.cpp | 15 +++------------ 1 file changed, 3 insertions(+), 12 deletions(-) diff --git a/src/QMCWaveFunctions/RotatedSPOs.cpp b/src/QMCWaveFunctions/RotatedSPOs.cpp index 18fac54595..c47d85a4be 100644 --- a/src/QMCWaveFunctions/RotatedSPOs.cpp +++ b/src/QMCWaveFunctions/RotatedSPOs.cpp @@ -321,7 +321,6 @@ void RotatedSPOs::buildOptVariables(const RotationIndices& rotations, const Rota // This will add the orbital rotation parameters to myVars // and will also read in initial parameter values supplied in input file - int p, q; int nparams_active = m_act_rot_inds_.size(); if (params_supplied_) @@ -348,8 +347,8 @@ void RotatedSPOs::buildOptVariables(const RotationIndices& rotations, const Rota myVars.clear(); for (int i = 0; i < nparams_active; i++) { - p = m_act_rot_inds_[i].first; - q = m_act_rot_inds_[i].second; + const int p = m_act_rot_inds_[i].first; + const int q = m_act_rot_inds_[i].second; registerParameter(i, p, q, myVars, params_, true); if constexpr (IsComplex_t::value) registerParameter(i, p, q, myVars, params_, false); @@ -360,15 +359,7 @@ void RotatedSPOs::buildOptVariables(const RotationIndices& rotations, const Rota const size_t N = m_full_rot_inds_.size(); myVarsFull_.resize(N); for (int i = 0; i < N; i++) - { - p = m_full_rot_inds_[i].first; - q = m_full_rot_inds_[i].second; - - if (params_supplied_) - myVarsFull_[i] = params_[i]; - else - myVarsFull_[i] = 0.0; - } + myVarsFull_[i] = params_supplied_ ? params_[i] : 0.0; } //Printing the parameters