Skip to content

Commit

Permalink
Update hopping generator interface and allow matrix hoppings
Browse files Browse the repository at this point in the history
  • Loading branch information
dean0x7d committed Jul 20, 2017
1 parent b330269 commit 1f4b3d9
Show file tree
Hide file tree
Showing 8 changed files with 39 additions and 44 deletions.
1 change: 1 addition & 0 deletions cppcore/include/system/Registry.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ inline MatrixXcd canonical_onsite_energy(MatrixXcd const& energy) { return energ
void check_hopping_energy(MatrixXcd const& energy);
/// Convert the hopping energy into the canonical format
MatrixXcd canonical_hopping_energy(std::complex<double> energy);
inline MatrixXcd canonical_hopping_energy(MatrixXcd const& energy) { return energy; }

} // namespace detail
} // namespace cpb
16 changes: 5 additions & 11 deletions cppcore/include/system/StructureModifiers.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,14 +38,6 @@ class PositionModifier {
PositionModifier(Function const& apply) : apply(apply) {}
};

/**
Helper class for passing sublattice information
*/
struct SubIdRef {
ArrayX<storage_idx_t> const& ids;
std::unordered_map<std::string, storage_idx_t> name_map;
};

/**
Introduces a new site family (with new sub_id)
Expand Down Expand Up @@ -82,14 +74,16 @@ class HoppingGenerator {
ArrayXi from;
ArrayXi to;
};
using Function = std::function<Result(cpb::CartesianArray const&, SubIdRef)>;
using Function = std::function<Result(System const&)>;

std::string name; ///< friendly hopping identifier - will be added to lattice registry
std::complex<double> energy; ///< hopping energy - also added to lattice registry
MatrixXcd energy; ///< hopping energy - also added to hopping registry
Function make; ///< function which will generate the new hopping index pairs

HoppingGenerator(std::string const& name, std::complex<double> energy, Function const& make)
HoppingGenerator(string_view name, MatrixXcd const& energy, Function const& make)
: name(name), energy(energy), make(make) {}
HoppingGenerator(string_view name, std::complex<double> energy, Function const& make)
: HoppingGenerator(name, MatrixXcd::Constant(1, 1, energy), make) {}

explicit operator bool() const { return static_cast<bool>(make); }
};
Expand Down
2 changes: 1 addition & 1 deletion cppcore/src/Model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ void Model::add(SiteGenerator const& g) {

void Model::add(HoppingGenerator const& g) {
structure_modifiers.emplace_back(g);
hopping_registry.register_family(g.name, MatrixXcd::Constant(1, 1, g.energy));
hopping_registry.register_family(g.name, g.energy);
clear_structure();
}

Expand Down
7 changes: 3 additions & 4 deletions cppcore/src/system/StructureModifiers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,10 +61,9 @@ void apply(SiteGenerator const& g, System& s) {
void apply(HoppingGenerator const& g, System& s) {
detail::remove_invalid(s);

auto const sublattices = s.compressed_sublattices.decompressed();
auto const family_id = s.hopping_registry.id(g.name);
auto pairs = g.make(s.positions, {sublattices, s.site_registry.name_map()});
s.hopping_blocks.append(family_id, std::move(pairs.from), std::move(pairs.to));
auto pairs = g.make(s);
s.hopping_blocks.append(s.hopping_registry.id(g.name),
std::move(pairs.from), std::move(pairs.to));
}

} // namespace cpb
2 changes: 1 addition & 1 deletion cppcore/tests/fixtures.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ cpb::HoppingModifier force_complex_numbers() {
namespace generator {

cpb::HoppingGenerator do_nothing_hopping(std::string const& name) {
return {name, 0, [](cpb::CartesianArray const&, SubIdRef) {
return {name, 0.0, [](System const&) {
return HoppingGenerator::Result{ArrayXi{}, ArrayXi{}};
}};
}
Expand Down
6 changes: 3 additions & 3 deletions cppcore/tests/test_modifiers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,7 @@ TEST_CASE("HoppingGenerator") {
REQUIRE(model.system()->hopping_blocks.nnz() == 0);

SECTION("Add real generator") {
model.add(HoppingGenerator("t2", 2.0, [](CartesianArray const&, SubIdRef) {
model.add(HoppingGenerator("t2", 2.0, [](System const&) {
auto r = HoppingGenerator::Result{ArrayXi(1), ArrayXi(1)};
r.from << 0;
r.to << 1;
Expand All @@ -273,7 +273,7 @@ TEST_CASE("HoppingGenerator") {
}

SECTION("Add complex generator") {
model.add(HoppingGenerator("t2", {0.0, 1.0}, [](CartesianArray const&, SubIdRef) {
model.add(HoppingGenerator("t2", std::complex<double>{0.0, 1.0}, [](System const&) {
return HoppingGenerator::Result{ArrayXi(), ArrayXi()};
}));

Expand All @@ -282,7 +282,7 @@ TEST_CASE("HoppingGenerator") {
}

SECTION("Upper triangular form should be preserved") {
model.add(HoppingGenerator("t2", 2.0, [](CartesianArray const&, SubIdRef) {
model.add(HoppingGenerator("t2", 2.0, [](System const&) {
auto r = HoppingGenerator::Result{ArrayXi(2), ArrayXi(2)};
r.from << 0, 1;
r.to << 1, 0;
Expand Down
33 changes: 18 additions & 15 deletions cppmodule/src/modifiers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,11 +54,23 @@ void init_site_generator(SiteGenerator& self, string_view name, T const& energy,
);
};

void wrap_modifiers(py::module& m) {
py::class_<SubIdRef>(m, "SubIdRef")
.def_property_readonly("ids", [](SubIdRef const& s) { return arrayref(s.ids); })
.def_readonly("name_map", &SubIdRef::name_map);
template<class T>
void init_hopping_generator(HoppingGenerator& self, std::string const& name,
T const& energy, py::object make) {
new (&self) HoppingGenerator(
name, detail::canonical_hopping_energy(energy),
[make](System const& s) {
py::gil_scoped_acquire guard{};
auto const& p = CartesianArrayConstRef(s.positions);
auto sites_type = py::module::import("pybinding.system").attr("_CppSites");
auto result = make(p.x(), p.y(), p.z(), sites_type(&s));
auto t = py::reinterpret_borrow<py::tuple>(result);
return HoppingGenerator::Result{t[0].cast<ArrayXi>(), t[1].cast<ArrayXi>()};
}
);
}

void wrap_modifiers(py::module& m) {
py::class_<SiteStateModifier>(m, "SiteStateModifier")
.def("__init__", [](SiteStateModifier& self, py::object apply, int min_neighbors) {
new (&self) SiteStateModifier(
Expand Down Expand Up @@ -90,17 +102,8 @@ void wrap_modifiers(py::module& m) {
.def("__init__", init_site_generator<MatrixXcd>);

py::class_<HoppingGenerator>(m, "HoppingGenerator")
.def("__init__", [](HoppingGenerator& self, std::string const& name,
std::complex<double> energy, py::object make) {
new (&self) HoppingGenerator(
name, energy,
[make](CartesianArray const& p, SubIdRef sub) {
py::gil_scoped_acquire guard{};
auto t = py::tuple(make(arrayref(p.x), arrayref(p.y), arrayref(p.z), sub));
return HoppingGenerator::Result{t[0].cast<ArrayXi>(), t[1].cast<ArrayXi>()};
}
);
});
.def("__init__", init_hopping_generator<std::complex<double>>)
.def("__init__", init_hopping_generator<MatrixXcd>);

py::class_<OnsiteModifier>(m, "OnsiteModifier")
.def("__init__", [](OnsiteModifier& self, py::object apply,
Expand Down
16 changes: 7 additions & 9 deletions pybinding/modifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from . import _cpp
from .system import Sites
from .support.inspect import get_call_signature
from .support.alias import AliasArray, AliasIndex, SplitName
from .support.alias import AliasIndex, SplitName
from .support.deprecated import LoudDeprecationWarning
from .utils.misc import decorator_decorator

Expand All @@ -39,9 +39,7 @@ def _process_modifier_args(args, keywords, requested_argnames):
orbs = 1, 1

def process(obj):
if isinstance(obj, _cpp.SubIdRef):
return AliasArray(obj.ids, obj.name_map)
elif isinstance(obj, str):
if isinstance(obj, str):
return AliasIndex(SplitName(obj), shape, orbs)
elif obj.size == shape[0]:
obj.shape = shape
Expand All @@ -51,7 +49,7 @@ def process(obj):

kwargs = {k: process(v) for k, v in zip(keywords, args) if k in requested_argnames}

if "sites" in requested_argnames:
if "sites" in requested_argnames and "sites" not in kwargs:
kwargs["sites"] = Sites((kwargs[k] for k in ("x", "y", "z")), kwargs["sub_id"])

return kwargs
Expand Down Expand Up @@ -510,14 +508,14 @@ def hopping_generator(name, energy):
x, y, z : np.ndarray
Lattice site position.
sub_id : np.ndarray
Sublattice identifier: can be checked for equality with sublattice names
specified in :class:`.Lattice`.
sites : :class:`.Sites`
Information about sites families, positions and various utility functions.
See :class:`.Sites` for details.
The function must return:
Tuple[np.ndarray, np.ndarray]
Arrays of index pairs which form the new hoppings.
"""
return functools.partial(_make_generator, kind=_cpp.HoppingGenerator,
name=name, energy=energy, keywords="x, y, z, sub_id")
name=name, energy=energy, keywords="x, y, z, sites")

0 comments on commit 1f4b3d9

Please sign in to comment.