Skip to content

Commit

Permalink
Merge pull request #7 from camelto2/kappa_list_unpacking_for_ray
Browse files Browse the repository at this point in the history
fix to derivatives
  • Loading branch information
jptowns authored Dec 6, 2023
2 parents ab955c4 + 06bed8b commit 4ae4e66
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 58 deletions.
95 changes: 50 additions & 45 deletions src/QMCWaveFunctions/RotatedSPOs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,8 @@ void RotatedSPOs::setRotationParameters(const std::vector<RealType>& param_list)
params_.resize(num_param);

//handling both real and complex by casting data to RealType*, since laid out as real_0, imag_0, real_1, imag_1, etc
auto* params_data_real = (RealType*)params_.data();
std::copy(param_list.begin(), param_list.end(), params_data_real);
auto* params_data_alias = (RealType*)params_.data();
std::copy(param_list.begin(), param_list.end(), params_data_alias);

params_supplied_ = true;
}
Expand Down Expand Up @@ -111,29 +111,25 @@ void RotatedSPOs::extractParamsFromAntiSymmetricMatrix(const RotationIndices& ro

void RotatedSPOs::resetParametersExclusive(const opt_variables_type& active)
{
const size_t N = m_act_rot_inds_.size();
std::vector<ValueType> delta_param(N);

const size_t psize = use_global_rot_ ? m_full_rot_inds_.size() : N;
assert(psize >= N);

std::vector<ValueType> new_param(psize);
const size_t nact_rot = m_act_rot_inds_.size();
std::vector<ValueType> delta_param(nact_rot);

//cast ValueType to RealType for delta param
//allows us to work with both real and complex since
//active and myVars are stored as only reals
auto* delta_param_data_real = (RealType*)delta_param.data();
auto* delta_param_data_alias = (RealType*)delta_param.data();
for (int i = 0; i < myVars.size(); i++)
{
int loc = myVars.where(i);
delta_param_data_real[i] = active[loc] - myVars[i];
myVars[i] = active[loc];
int loc = myVars.where(i);
delta_param_data_alias[i] = active[loc] - myVars[i];
myVars[i] = active[loc];
}

if (use_global_rot_)
{
std::vector<ValueType> old_param(psize);
std::vector<ValueType> old_param(m_full_rot_inds_.size());
std::copy_n(myVarsFull_.data(), myVarsFull_.size(), old_param.data());

applyDeltaRotation(delta_param, old_param, myVarsFull_);
}
else
Expand All @@ -151,9 +147,8 @@ void RotatedSPOs::writeVariationalParameters(hdf_archive& hout)
if (use_global_rot_)
{
hout.push("rotation_global");
std::string rot_global_name = std::string("rotation_global_") + SPOSet::getName();
const std::string rot_global_name = std::string("rotation_global_") + SPOSet::getName();

//Store in h5 as ValueType
hout.write(myVarsFull_, rot_global_name);
hout.pop();
}
Expand All @@ -180,9 +175,9 @@ void RotatedSPOs::writeVariationalParameters(hdf_archive& hout)
std::string rot_params_name = std::string("rotation_params_") + SPOSet::getName();

std::vector<ValueType> params(m_act_rot_inds_.size());
auto* params_data_real = (RealType*)params.data();
auto* params_data_alias = (RealType*)params.data();
for (int i = 0; i < myVars.size(); i++)
params_data_real[i] = myVars[i];
params_data_alias[i] = myVars[i];

hout.write(params, rot_params_name);
hout.pop();
Expand All @@ -209,7 +204,6 @@ void RotatedSPOs::readVariationalParameters(hdf_archive& hin)
if (!hin.getShape<ValueType>(rot_global_name, sizes))
throw std::runtime_error("Failed to read rotation_global in VP file");

//h5 is storing a std::vector<ComplexType> vec as {vec.size(), 2}
if (myVarsFull_.size() != sizes[0])
{
std::ostringstream tmp_err;
Expand All @@ -218,6 +212,7 @@ void RotatedSPOs::readVariationalParameters(hdf_archive& hin)
throw std::runtime_error(tmp_err.str());
}
hin.read(myVarsFull_, rot_global_name);

hin.pop();

applyFullRotation(myVarsFull_, true);
Expand All @@ -227,7 +222,7 @@ void RotatedSPOs::readVariationalParameters(hdf_archive& hin)
hin.push("rotation_history", false);
std::string rot_hist_name = std::string("rotation_history_") + SPOSet::getName();
std::vector<int> sizes(2);
if (!hin.getShape<RealType>(rot_hist_name, sizes))
if (!hin.getShape<ValueType>(rot_hist_name, sizes))
throw std::runtime_error("Failed to read rotation history in VP file");

int rows = sizes[0];
Expand All @@ -251,7 +246,7 @@ void RotatedSPOs::readVariationalParameters(hdf_archive& hin)
std::string rot_param_name = std::string("rotation_params_") + SPOSet::getName();

std::vector<int> sizes(1);
if (!hin.getShape<RealType>(rot_param_name, sizes))
if (!hin.getShape<ValueType>(rot_param_name, sizes))
throw std::runtime_error("Failed to read rotation_params in VP file");

//values stored as ValueType. Now unpack into reals
Expand All @@ -267,9 +262,9 @@ void RotatedSPOs::readVariationalParameters(hdf_archive& hin)

std::vector<ValueType> params(sizes[0]);
hin.read(params, rot_param_name);
auto* params_data_real = (RealType*)params.data();
auto* params_data_alias = (RealType*)params.data();
for (int i = 0; i < nparam; i++)
myVars[i] = params_data_real[i];
myVars[i] = params_data_alias[i];

hin.pop();

Expand Down Expand Up @@ -356,10 +351,10 @@ void RotatedSPOs::buildOptVariables(const RotationIndices& rotations, const Rota

if (use_global_rot_)
{
const size_t N = m_full_rot_inds_.size();
myVarsFull_.resize(N);
for (int i = 0; i < N; i++)
myVarsFull_[i] = params_supplied_ ? params_[i] : 0.0;
const size_t nfull_rot = m_full_rot_inds_.size();
myVarsFull_.resize(nfull_rot);
for (int i = 0; i < nfull_rot; i++)
myVarsFull_[i] = (params_supplied_ && i < m_act_rot_inds_.size()) ? params_[i] : 0.0;
}

//Printing the parameters
Expand All @@ -374,10 +369,10 @@ void RotatedSPOs::buildOptVariables(const RotationIndices& rotations, const Rota
const size_t N = m_act_rot_inds_.size();
std::vector<ValueType> param(N);
//cast as RealType to copy from myVars into real or complex param vector
auto* param_data = (RealType*)param.data();
auto* param_data_alias = (RealType*)param.data();
//couldn't easily use std::copy since myVars is vector of pairs
for (size_t i = 0; i < myVars.size(); i++)
param_data[i] = myVars[i];
param_data_alias[i] = myVars[i];
apply_rotation(param, false);
}
}
Expand Down Expand Up @@ -700,16 +695,18 @@ void RotatedSPOs::evaluateDerivRatios(const VirtualParticleSet& VP,
// This multiply could be reduced to Ainv and the non-square part of A.
BLAS::gemm('N', 'N', nmo, nel, nel, ValueType(1.0), A, nmo, Ainv, nel, ValueType(0.0), T.data(), nmo);

for (int i = 0; i < m_act_rot_inds_.size(); i++)
for (int i = 0; i < myVars.size(); i++)
{
int kk = myVars.where(i);
if (kk >= 0)
{
const int p = m_act_rot_inds_.at(i).first;
const int q = m_act_rot_inds_.at(i).second;
int j = IsComplex_t<ValueType>::value ? i / 2 : i;
const int p = m_act_rot_inds_.at(j).first;
const int q = m_act_rot_inds_.at(j).second;
dratios(iat, kk) = T(p, q) - T_orig(p, q); // dratio size is (nknot, num_vars)
#ifdef QMC_COMPLEX
dratios(iat, kk + 1) = ComplexType(0, 1) * (T(p, q) - T_orig(p, q)); // dratio size is (nknot, num_vars)
if (i % 2 == 1)
dratios(iat, kk) *= ComplexType(0, 1);
#endif
}
}
Expand Down Expand Up @@ -753,16 +750,18 @@ void RotatedSPOs::evaluateDerivativesWF(ParticleSet& P,

BLAS::gemm('N', 'N', nmo, nel, nel, ValueType(1.0), A, nmo, Ainv, nel, ValueType(0.0), T.data(), nmo);

for (int i = 0; i < m_act_rot_inds_.size(); i++)
for (int i = 0; i < myVars.size(); i++)
{
int kk = myVars.where(i);
if (kk >= 0)
{
const int p = m_act_rot_inds_.at(i).first;
const int q = m_act_rot_inds_.at(i).second;
int j = IsComplex_t<ValueType>::value ? i / 2 : i;
const int p = m_act_rot_inds_.at(j).first;
const int q = m_act_rot_inds_.at(j).second;
dlogpsi[kk] = T(p, q);
#ifdef QMC_COMPLEX
dlogpsi[kk + 1] = ComplexType(0, 1) * T(p, q);
if (i % 2 == 1)
dlogpsi[kk] *= ComplexType(0, 1);
#endif
}
}
Expand Down Expand Up @@ -862,20 +861,26 @@ void RotatedSPOs::evaluateDerivatives(ParticleSet& P,
//possibly replace with BLAS call
Y4 = Y3 - Y2;

for (int i = 0; i < m_act_rot_inds_.size(); i++)
for (int i = 0; i < myVars.size(); i++)
{
int kk = myVars.where(i);
if (kk >= 0)
{
const int p = m_act_rot_inds_.at(i).first;
const int q = m_act_rot_inds_.at(i).second;
dlogpsi[kk] += T(p, q);
dhpsioverpsi[kk] += ValueType(-0.5) * Y4(p, q);
int j = IsComplex_t<ValueType>::value ? i / 2 : i;
const int p = m_act_rot_inds_.at(j).first;
const int q = m_act_rot_inds_.at(j).second;

ValueType pref1 = ValueType(1.0);
ValueType pref2 = ValueType(-0.5);
#ifdef QMC_COMPLEX
//imaginary part should be adjacent to real part
dlogpsi[kk + 1] += ComplexType(0, 1) * T(p, q);
dhpsioverpsi[kk + 1] += ComplexType(0, 1) * ValueType(-0.5) * Y4(p, q);
if (i % 2 == 1)
{
pref1 *= ComplexType(0, 1);
pref2 *= ComplexType(0, 1);
}
#endif
dlogpsi[kk] += pref1 * T(p, q);
dhpsioverpsi[kk] += pref2 * Y4(p, q);
}
}
}
Expand Down
21 changes: 8 additions & 13 deletions src/QMCWaveFunctions/tests/test_RotatedSPOs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -778,15 +778,12 @@ TEST_CASE("RotatedSPOs read and write parameters", "[wavefunction]")
for (size_t i = 0; i < vs.size(); i++)
CHECK(var[i] == Approx(vs[i]));

/*
opt_variables_type& full_var = testing::getMyVarsFull(rot2);
CHECK(full_var[0] == Approx(vs[0]));
CHECK(full_var[1] == Approx(vs[1]));
CHECK(full_var[2] == Approx(vs[2]));
CHECK(full_var[3] == Approx(vs[3]));
CHECK(full_var[4] == Approx(0.0));
CHECK(full_var[5] == Approx(0.0));
*/
//add extra parameters for full set
vs_values.push_back(0.0);
vs_values.push_back(0.0);
std::vector<SPOSet::ValueType>& full_var = testing::getMyVarsFull(rot2);
for (size_t i = 0; i < full_var.size(); i++)
CHECK(full_var[i] == ValueApprox(vs_values[i]));
}

// Test using history list.
Expand Down Expand Up @@ -830,10 +827,8 @@ TEST_CASE("RotatedSPOs read and write parameters history", "[wavefunction]")
rot2.readVariationalParameters(hin);

opt_variables_type& var = testing::getMyVars(rot2);
CHECK(var[0] == Approx(vs[0]));
CHECK(var[1] == Approx(vs[1]));
CHECK(var[2] == Approx(vs[2]));
CHECK(var[3] == Approx(vs[3]));
for (size_t i = 0; i < var.size(); i++)
CHECK(var[i] == Approx(vs[i]));

auto hist = testing::getHistoryParams(rot2);
REQUIRE(hist.size() == 1);
Expand Down

0 comments on commit 4ae4e66

Please sign in to comment.