Skip to content

Commit

Permalink
Allow site state and position modifiers after generators
Browse files Browse the repository at this point in the history
  • Loading branch information
dean0x7d committed Jul 14, 2017
1 parent affad64 commit 3b6d08c
Show file tree
Hide file tree
Showing 13 changed files with 186 additions and 14 deletions.
15 changes: 15 additions & 0 deletions cppcore/include/numeric/dense.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -318,4 +318,19 @@ ArrayX<T> make_integer_range(idx_t size) {
return result;
}

template<class Vector, class Bools>
Vector slice(Vector const& v, Bools const& keep) {
using std::begin; using std::end;
auto const original_size = v.size();
auto const result_size = std::accumulate(begin(keep), end(keep), idx_t{0});

auto result = Vector(result_size);
auto count = 0;
for (auto i = idx_t{0}; i < original_size; ++i) {
if (keep[i])
result[count++] = v[i];
}
return result;
};

} // namespace cpb
4 changes: 4 additions & 0 deletions cppcore/include/system/CompressedSublattices.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,10 @@ class CompressedSublattices {

/// Start a new sublattice block or increment the site count for the existing block
void add(SubAliasID id, idx_t norb);

/// Remove sites for which `keep == false`
void filter(VectorX<bool> const& keep);

/// Verify that the stored data is correct: `sum(site_counts) == num_sites`
void verify(idx_t num_sites) const;

Expand Down
3 changes: 3 additions & 0 deletions cppcore/include/system/HoppingBlocks.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,9 @@ class HoppingBlocks {
/// Append a range of coordinates to the given family block
void append(HopID family_id, ArrayXi&& rows, ArrayXi&& cols);

/// Remove sites for which `keep == false`
void filter(VectorX<bool> const& keep);

/// Return the matrix in the CSR sparse matrix format
SparseMatrixX<storage_idx_t> tocsr() const;

Expand Down
2 changes: 2 additions & 0 deletions cppcore/include/system/StructureModifiers.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,8 @@ void apply(SiteStateModifier const& m, Foundation& f);
void apply(PositionModifier const& m, Foundation& f);

template<class M> void apply(M const&, System&) {}
void apply(SiteStateModifier const& m, System& s);
void apply(PositionModifier const& m, System& s);
void apply(HoppingGenerator const& g, System& s);

template<class M> constexpr bool requires_system(M const&) { return false; }
Expand Down
2 changes: 2 additions & 0 deletions cppcore/include/system/System.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ struct System {
CompressedSublattices compressed_sublattices;
HoppingBlocks hopping_blocks;
std::vector<Boundary> boundaries;
ArrayX<bool> is_valid;

System(Lattice const& lattice) : lattice(lattice) {}

Expand Down Expand Up @@ -69,6 +70,7 @@ namespace detail {
void populate_system(System& system, Foundation const& foundation);
void populate_boundaries(System& system, Foundation const& foundation,
TranslationalSymmetry const& symmetry);
void remove_invalid(System& system);
} // namespace detail

} // namespace cpb
2 changes: 2 additions & 0 deletions cppcore/src/Model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,8 @@ std::shared_ptr<System> Model::make_system() const {
apply(modifier, *sys);
}

detail::remove_invalid(*sys);

if (sys->num_sites() == 0) { throw std::runtime_error{"Impossible system: 0 sites"}; }

return sys;
Expand Down
16 changes: 16 additions & 0 deletions cppcore/src/system/CompressedSublattices.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,22 @@ void CompressedSublattices::add(SubAliasID id, idx_t norb) {
}
}

void CompressedSublattices::filter(VectorX<bool> const& keep) {
using std::begin;

auto new_counts = std::vector<storage_idx_t>();
new_counts.reserve(data.size());

for (auto const& sub : *this) {
new_counts.push_back(std::accumulate(begin(keep) + sub.sys_start(),
begin(keep) + sub.sys_end(), storage_idx_t{0}));
}

for (auto i = size_t{0}; i < data.size(); ++i) {
data[i].num_sites = new_counts[i];
}
}

void CompressedSublattices::verify(idx_t num_sites) const {
using std::begin; using std::end;

Expand Down
11 changes: 11 additions & 0 deletions cppcore/src/system/HoppingBlocks.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,17 @@ void HoppingBlocks::append(HopID family_id, ArrayXi&& rows, ArrayXi&& cols) {
block.erase(std::unique(block.begin(), block.end()), block.end());
}

void HoppingBlocks::filter(VectorX<bool> const& keep) {
using std::begin; using std::end;

num_sites = std::accumulate(begin(keep), end(keep), idx_t{0});
for (auto& block : blocks) {
block.erase(std::remove_if(block.begin(), block.end(), [&](COO coo) {
return !keep[coo.row] || !keep[coo.col];
}), block.end());
}
}

SparseMatrixX<storage_idx_t> HoppingBlocks::tocsr() const {
auto csr = SparseMatrixX<storage_idx_t>(num_sites, num_sites);
csr.reserve(nnz());
Expand Down
26 changes: 26 additions & 0 deletions cppcore/src/system/StructureModifiers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,40 @@ void apply(SiteStateModifier const& m, Foundation& f) {
}
}

void apply(SiteStateModifier const& m, System& s) {
if (s.is_valid.size() == 0) {
s.is_valid = ArrayX<bool>::Constant(s.num_sites(), true);
}

for (auto const& sub : s.compressed_sublattices) {
m.apply(s.is_valid.segment(sub.sys_start(), sub.num_sites()),
s.positions.segment(sub.sys_start(), sub.num_sites()),
s.lattice.sublattice_name(SubID(sub.alias_id())));
}

if (m.min_neighbors > 0) {
throw std::runtime_error("Eliminating dangling bonds after a generator "
"has not been implemented yet");
}
}

void apply(PositionModifier const& m, Foundation& f) {
for (auto const& pair : f.get_lattice().get_sublattices()) {
auto slice = f[pair.second.unique_id];
m.apply(slice.get_positions(), pair.first);
}
}

void apply(PositionModifier const& m, System& s) {
for (auto const& sub : s.compressed_sublattices) {
m.apply(s.positions.segment(sub.sys_start(), sub.num_sites()),
s.lattice.sublattice_name(SubID(sub.alias_id())));
}
}

void apply(HoppingGenerator const& g, System& s) {
detail::remove_invalid(s);

auto const& lattice = s.lattice;
auto const sublattices = s.compressed_sublattices.decompressed();
auto const family_id = lattice.hopping_family(g.name).family_id;
Expand Down
16 changes: 16 additions & 0 deletions cppcore/src/system/System.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -157,5 +157,21 @@ void populate_boundaries(System& system, Foundation const& foundation,
}
}

void remove_invalid(System& s) {
if (s.is_valid.size() == 0) { return; }

s.positions.x = slice(s.positions.x, s.is_valid);
s.positions.y = slice(s.positions.y, s.is_valid);
s.positions.z = slice(s.positions.z, s.is_valid);
s.compressed_sublattices.filter(s.is_valid);
s.hopping_blocks.filter(s.is_valid);

for (auto& b : s.boundaries) {
b.hopping_blocks.filter(s.is_valid);
}

s.is_valid.resize(0);
}

} // namespace detail
} // namespace cpb
11 changes: 11 additions & 0 deletions cppcore/tests/fixtures.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -185,3 +185,14 @@ cpb::HoppingModifier force_complex_numbers() {
}

} // namespace field


namespace generator {

cpb::HoppingGenerator do_nothing_hopping(std::string const& name) {
return {name, 0, [](cpb::CartesianArray const&, SubIdRef) {
return HoppingGenerator::Result{ArrayXi{}, ArrayXi{}};
}};
}

} // namespace generator
6 changes: 6 additions & 0 deletions cppcore/tests/fixtures.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,3 +37,9 @@ cpb::HoppingModifier force_double_precision();
cpb::HoppingModifier force_complex_numbers();

} // namespace field

namespace generator {

cpb::HoppingGenerator do_nothing_hopping(std::string const& name = "_t");

} // namespace generator
86 changes: 72 additions & 14 deletions cppcore/tests/test_modifiers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,29 +7,87 @@ TEST_CASE("SiteStateModifier") {
auto model = Model(lattice::square_2atom(), Primitive(2));
REQUIRE(model.system()->num_sites() == 4);

auto remove_site = [](Eigen::Ref<ArrayX<bool>> state, CartesianArrayConstRef, string_view s) {
if (s == "A") {
auto count = std::unordered_map<std::string, idx_t>();

auto remove_site = [&](Eigen::Ref<ArrayX<bool>> state, CartesianArrayConstRef, string_view s) {
count[s] = state.size();
if (s == "A" && state.size() != 0) {
state[0] = false;
}
};
model.add(SiteStateModifier(remove_site));
REQUIRE(model.system()->num_sites() == 3);
model.add(SiteStateModifier(remove_site, 1));
REQUIRE(model.system()->num_sites() == 2);
model.add(SiteStateModifier(remove_site, 2));
REQUIRE_THROWS(model.system());

SECTION("Apply to foundation") {
model.add(SiteStateModifier(remove_site));
REQUIRE(model.system()->num_sites() == 3);
REQUIRE(count["A"] == 2);
REQUIRE(count["B"] == 2);

model.add(SiteStateModifier(remove_site, 1));
REQUIRE(model.system()->num_sites() == 2);
REQUIRE(count["A"] == 2);
REQUIRE(count["B"] == 2);

model.add(SiteStateModifier(remove_site, 2));
REQUIRE_THROWS(model.system());
REQUIRE(count["A"] == 2);
REQUIRE(count["B"] == 2);
}

SECTION("Apply to system") {
model.add(SiteStateModifier(remove_site));
model.add(generator::do_nothing_hopping());

model.add(SiteStateModifier(remove_site));
REQUIRE(model.system()->num_sites() == 2);
REQUIRE(count["A"] == 1);
REQUIRE(count["B"] == 2);

model.add(generator::do_nothing_hopping("_t2"));
model.add(SiteStateModifier(remove_site));
REQUIRE(model.system()->num_sites() == 2);
REQUIRE(count["A"] == 0);
REQUIRE(count["B"] == 2);

model.add(SiteStateModifier(remove_site, 1));
REQUIRE_THROWS_WITH(model.system(), Catch::Contains("has not been implemented yet"));
REQUIRE(count["A"] == 0);
REQUIRE(count["B"] == 2);
}
}

TEST_CASE("SitePositionModifier") {
auto model = Model(lattice::square_2atom());
REQUIRE(model.system()->positions.y[1] == Approx(0.5));
auto model = Model(lattice::square_2atom(), shape::rectangle(2, 2));
REQUIRE(model.system()->num_sites() == 6);
REQUIRE(model.system()->positions.y[1] == Approx(-1));

auto count = std::unordered_map<std::string, idx_t>();
constexpr auto moved_pos = 10.0f;

model.add(PositionModifier([](CartesianArrayRef position, string_view sublattice) {
auto move_site = PositionModifier([&](CartesianArrayRef position, string_view sublattice) {
count[sublattice] = position.size();
if (sublattice == "B") {
position.y()[0] = 1;
position.y().setConstant(moved_pos);
}
}));
REQUIRE(model.system()->positions.y[1] == Approx(1));
});

SECTION("Apply to foundation") {
model.add(move_site);
model.eval();
REQUIRE(count["A"] == 25);
REQUIRE(count["B"] == 25);
REQUIRE(model.system()->num_sites() == 6);
REQUIRE(model.system()->positions.y.segment(4, 2).isApproxToConstant(moved_pos));
}

SECTION("Apply to system") {
model.add(generator::do_nothing_hopping());
model.add(move_site);
model.eval();
REQUIRE(count["A"] == 4);
REQUIRE(count["B"] == 2);
REQUIRE(model.system()->num_sites() == 6);
REQUIRE(model.system()->positions.y.segment(4, 2).isApproxToConstant(moved_pos));
}
}

TEST_CASE("State and position modifier ordering") {
Expand Down

0 comments on commit 3b6d08c

Please sign in to comment.