Skip to content

Commit

Permalink
Give generators the entire system as an argument
Browse files Browse the repository at this point in the history
  • Loading branch information
dean0x7d committed Jul 20, 2017
1 parent 1f4b3d9 commit 469c4fe
Show file tree
Hide file tree
Showing 5 changed files with 37 additions and 40 deletions.
4 changes: 1 addition & 3 deletions cppcore/include/system/StructureModifiers.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,7 @@ class PositionModifier {
*/
class SiteGenerator {
public:
using Function = std::function<
CartesianArray(CartesianArrayConstRef, CompressedSublattices const&, HoppingBlocks const&)
>;
using Function = std::function<CartesianArray(System const&)>;

std::string name; ///< friendly site family identifier
MatrixXcd energy; ///< onsite energy - also added to the site registry
Expand Down
2 changes: 1 addition & 1 deletion cppcore/src/system/StructureModifiers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ void apply(PositionModifier const& m, System& s) {
void apply(SiteGenerator const& g, System& s) {
detail::remove_invalid(s);

auto const new_positions = g.make(s.positions, s.compressed_sublattices, s.hopping_blocks);
auto const new_positions = g.make(s);
auto const norb = g.energy.rows();
auto const nsites = new_positions.size();
s.compressed_sublattices.add(s.site_registry.id(g.name), norb, nsites);
Expand Down
7 changes: 2 additions & 5 deletions cppcore/tests/test_modifiers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -190,8 +190,7 @@ TEST_CASE("SiteGenerator") {
REQUIRE(model.system()->hopping_blocks.nnz() == 0);

SECTION("Errors") {
auto const noop = [](CartesianArrayConstRef, CompressedSublattices const&,
HoppingBlocks const&) { return CartesianArray(); };
auto const noop = [](System const&) { return CartesianArray(); };

auto const complex_vector = MatrixXcd::Constant(1, 2, 2.0);
REQUIRE_THROWS_WITH(model.add(SiteGenerator("C", complex_vector, noop)),
Expand All @@ -204,9 +203,7 @@ TEST_CASE("SiteGenerator") {

SECTION("Structure") {
auto const energy = MatrixXcd::Constant(1, 1, 2.0);
model.add(SiteGenerator("C", energy, [](CartesianArrayConstRef,
CompressedSublattices const&,
HoppingBlocks const&) {
model.add(SiteGenerator("C", energy, [](System const&) {
auto const size = 5;
auto x = ArrayXf::Constant(size, 1);
auto y = ArrayXf::LinSpaced(size, 1, 5);
Expand Down
14 changes: 6 additions & 8 deletions cppmodule/src/modifiers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,12 @@ void extract_modifier_result(T& v, py::object const& o) {

template<class T>
void init_site_generator(SiteGenerator& self, string_view name, T const& energy, py::object make) {
auto system_type = py::module::import("pybinding.system").attr("System");
new (&self) SiteGenerator(
name, detail::canonical_onsite_energy(energy),
[make](CartesianArrayConstRef p, CompressedSublattices const& s, HoppingBlocks const& h) {
[make, system_type](System const& s) {
py::gil_scoped_acquire guard{};
auto result = make(p.x(), p.y(), p.z(), &s, &h);
auto t = py::reinterpret_borrow<py::tuple>(result);
auto t = make(system_type(&s)).cast<py::tuple>();
return CartesianArray(t[0].cast<ArrayXf>(),
t[1].cast<ArrayXf>(),
t[2].cast<ArrayXf>());
Expand All @@ -57,14 +57,12 @@ void init_site_generator(SiteGenerator& self, string_view name, T const& energy,
template<class T>
void init_hopping_generator(HoppingGenerator& self, std::string const& name,
T const& energy, py::object make) {
auto system_type = py::module::import("pybinding.system").attr("System");
new (&self) HoppingGenerator(
name, detail::canonical_hopping_energy(energy),
[make](System const& s) {
[make, system_type](System const& s) {
py::gil_scoped_acquire guard{};
auto const& p = CartesianArrayConstRef(s.positions);
auto sites_type = py::module::import("pybinding.system").attr("_CppSites");
auto result = make(p.x(), p.y(), p.z(), sites_type(&s));
auto t = py::reinterpret_borrow<py::tuple>(result);
auto t = make(system_type(&s)).cast<py::tuple>();
return HoppingGenerator::Result{t[0].cast<ArrayXi>(), t[1].cast<ArrayXi>()};
}
);
Expand Down
50 changes: 27 additions & 23 deletions pybinding/modifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,32 +27,39 @@ def _process_modifier_args(args, keywords, requested_argnames):
Also process any special args like 'sub_id', 'hop_id' and 'sites'.
"""
prime_arg = args[0]
if prime_arg.ndim > 1:
# Move axis so that sites are first -- makes a nicer modifier interface
norb1, norb2, nsites = prime_arg.shape
prime_arg = np.moveaxis(prime_arg, 2, 0)
args = [prime_arg] + list(args[1:])
shape = nsites, 1, 1
orbs = norb1, norb2
else:
shape = prime_arg.shape
orbs = 1, 1
if isinstance(prime_arg, np.ndarray):
if prime_arg.ndim > 1:
# Move axis so that sites are first -- makes a nicer modifier interface
norb1, norb2, nsites = prime_arg.shape
prime_arg = np.moveaxis(prime_arg, 2, 0)
args = [prime_arg] + list(args[1:])
shape = nsites, 1, 1
orbs = norb1, norb2
else:
shape = prime_arg.shape
orbs = 1, 1

def process(obj):
if isinstance(obj, str):
return AliasIndex(SplitName(obj), shape, orbs)
elif obj.size == shape[0]:
elif isinstance(obj, np.ndarray) and obj.size == shape[0]:
obj.shape = shape
return obj
else:
return obj

kwargs = {k: process(v) for k, v in zip(keywords, args) if k in requested_argnames}
kwargs = dict(zip(keywords, args))
requested_kwargs = {k: process(v) for k, v in kwargs.items()
if k in requested_argnames}

if "sites" in requested_argnames and "sites" not in kwargs:
kwargs["sites"] = Sites((kwargs[k] for k in ("x", "y", "z")), kwargs["sub_id"])
requested_kwargs["sites"] = Sites((kwargs[k] for k in ("x", "y", "z")), kwargs["sub_id"])

if "system" in keywords:
requested_kwargs.update({p: getattr(kwargs["system"], p) for p in "xyz"
if p in requested_argnames})

return kwargs
return requested_kwargs


def _check_modifier_spec(func, keywords, has_sites=False):
Expand Down Expand Up @@ -474,18 +481,16 @@ def site_generator(name, energy):
x, y, z : np.ndarray
Lattice site position.
sublattices : CompressedSublattices
TBD
hoppings : Hoppings
TBD
system : :class:`.System`
Structural data of the model constructed so far. See :class:`.System` for details.
The function must return:
Tuple[np.ndarray, np.ndarray, np.ndarray]
Tuple of (x, y, z) arrays which indicate the positions of the new sites.
"""
return functools.partial(_make_generator, kind=_cpp.SiteGenerator,
name=name, energy=energy, keywords="x, y, z, sublattices, hoppings")
name=name, energy=energy, keywords="system, x, y, z")


@decorator_decorator
Expand All @@ -508,14 +513,13 @@ def hopping_generator(name, energy):
x, y, z : np.ndarray
Lattice site position.
sites : :class:`.Sites`
Information about sites families, positions and various utility functions.
See :class:`.Sites` for details.
system : :class:`.System`
Structural data of the model constructed so far. See :class:`.System` for details.
The function must return:
Tuple[np.ndarray, np.ndarray]
Arrays of index pairs which form the new hoppings.
"""
return functools.partial(_make_generator, kind=_cpp.HoppingGenerator,
name=name, energy=energy, keywords="x, y, z, sites")
name=name, energy=energy, keywords="system, x, y, z")

0 comments on commit 469c4fe

Please sign in to comment.