diff --git a/freud/environment/export-MatchEnv.cc b/freud/environment/export-MatchEnv.cc index 13dc73c0e..61aa7d694 100644 --- a/freud/environment/export-MatchEnv.cc +++ b/freud/environment/export-MatchEnv.cc @@ -41,12 +41,10 @@ void compute_env_rmsd_min(const std::shared_ptr& env_r const locality::QueryArgs& qargs, const nb_array>& motif, const unsigned int motif_size, - const float threshold, const bool registration ) { auto* motif_data = reinterpret_cast*>(motif.data()); - // TODO: where should threshold go? env_rmsd_min->compute(nq, nlist, qargs, motif_data, motif_size, registration); } @@ -76,6 +74,73 @@ std::map compute_is_similar( return isSimilar(box, refPoints1_data, refPoints2_data, numRef, threshold_sq, registration); } +// TODO refactor to resuse code +nb::object getClusterEnv(const std::shared_ptr& env_cls) +{ + auto cluster_envs = env_cls->getClusterEnvironments(); + + // convert to list of of list of lists for returning to python + nb::list cluster_envs_python; + for (const auto& cluster_env:cluster_envs) + { + nb::list env; + for (const auto& cluster:cluster_env) + { + nb::list vec; + vec.append(cluster.x); + vec.append(cluster.y); + vec.append(cluster.z); + env.append(vec); + } + cluster_envs_python.append(env); + } + return cluster_envs_python; +} + +nb::object getPointEnv(const std::shared_ptr& env_cls) +{ + auto cluster_envs = env_cls->getPointEnvironments(); + + // convert to list of of list of lists for returning to python + nb::list cluster_envs_python; + for (const auto& cluster_env:cluster_envs) + { + nb::list env; + for (const auto& cluster:cluster_env) + { + nb::list vec; + vec.append(cluster.x); + vec.append(cluster.y); + vec.append(cluster.z); + env.append(vec); + } + cluster_envs_python.append(env); + } + return cluster_envs_python; +} + +nb::object getPointEnvmm(const std::shared_ptr& env_cls) +{ + auto cluster_envs = env_cls->getPointEnvironments(); + + // convert to list of of list of lists for returning to python + nb::list cluster_envs_python; + for (const auto& cluster_env:cluster_envs) + { + nb::list env; + for (const auto& cluster:cluster_env) + { + nb::list vec; + vec.append(cluster.x); + vec.append(cluster.y); + vec.append(cluster.z); + env.append(vec); + } + cluster_envs_python.append(env); + } + return cluster_envs_python; +} + }; namespace detail { @@ -92,19 +157,21 @@ void export_MatchEnv(nb::module_& module) nb::class_(module, "EnvironmentCluster") .def(nb::init<>()) - .def("compute", &EnvironmentCluster::compute) - // .def("getClusters", &EnvironmentCluster::getClusterIdx) // TODO: should be there - .def("getClusterEnvironments", &EnvironmentCluster::getClusterEnvironments) + .def("compute", &EnvironmentCluster::compute, nb::arg("nq"), nb::arg("nlist").none(), nb::arg("qargs"), nb::arg("env_nlist").none(), nb::arg("env_qargs"), nb::arg("threshold"), nb::arg("registration")) + .def("getClusterEnvironments", &wrap::getClusterEnv) + .def("getPointEnvironments", &wrap::getPointEnv) + .def("getClusters", &EnvironmentCluster::getClusters) .def("getNumClusters", &EnvironmentCluster::getNumClusters); nb::class_(module, "EnvironmentMotifMatch") .def(nb::init<>()) - .def("compute", &wrap::compute_env_motif_match, nb::arg("nq"), nb::arg("nlist"), nb::arg("qargs"), nb::arg("motif"), nb::arg("motif_size"), nb::arg("threshold"), nb::arg("registration")) + .def("compute", &wrap::compute_env_motif_match, nb::arg("nq"), nb::arg("nlist").none(), nb::arg("qargs"), nb::arg("motif"), nb::arg("motif_size"), nb::arg("threshold"), nb::arg("registration")) + .def("getPointEnvironments", &wrap::getPointEnvmm) .def("getMatches", &EnvironmentMotifMatch::getMatches); nb::class_(module, "EnvironmentRMSDMinimizer") .def(nb::init<>()) - .def("compute", &wrap::compute_env_rmsd_min, nb::arg("nq"), nb::arg("nlist"), nb::arg("qargs"), nb::arg("motif"), nb::arg("motif_size"), nb::arg("threshold"), nb::arg("registration")) + .def("compute", &wrap::compute_env_rmsd_min, nb::arg("nq"), nb::arg("nlist").none(), nb::arg("qargs"), nb::arg("motif"), nb::arg("motif_size"), nb::arg("registration")) .def("getRMSDs", &EnvironmentRMSDMinimizer::getRMSDs); }