Skip to content

Commit

Permalink
fix exports
Browse files Browse the repository at this point in the history
  • Loading branch information
DomFijan committed Jan 21, 2025
1 parent c3df090 commit f9d7f27
Showing 1 changed file with 74 additions and 7 deletions.
81 changes: 74 additions & 7 deletions freud/environment/export-MatchEnv.cc
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,10 @@ void compute_env_rmsd_min(const std::shared_ptr<EnvironmentRMSDMinimizer>& env_r
const locality::QueryArgs& qargs,
const nb_array<float, nanobind::shape<-1, 3>>& motif,
const unsigned int motif_size,
const float threshold,
const bool registration
)
{
auto* motif_data = reinterpret_cast<vec3<float>*>(motif.data());
// TODO: where should threshold go?
env_rmsd_min->compute(nq, nlist, qargs, motif_data, motif_size, registration);
}

Expand Down Expand Up @@ -76,6 +74,73 @@ std::map<unsigned int, unsigned int> compute_is_similar(
return isSimilar(box, refPoints1_data, refPoints2_data, numRef, threshold_sq, registration);
}

// TODO refactor to resuse code
nb::object getClusterEnv(const std::shared_ptr<EnvironmentCluster>& env_cls)
{
auto cluster_envs = env_cls->getClusterEnvironments();

// convert to list of of list of lists for returning to python
nb::list cluster_envs_python;
for (const auto& cluster_env:cluster_envs)
{
nb::list env;
for (const auto& cluster:cluster_env)
{
nb::list vec;
vec.append(cluster.x);
vec.append(cluster.y);
vec.append(cluster.z);
env.append(vec);
}
cluster_envs_python.append(env);
}
return cluster_envs_python;
}

nb::object getPointEnv(const std::shared_ptr<EnvironmentCluster>& env_cls)
{
auto cluster_envs = env_cls->getPointEnvironments();

// convert to list of of list of lists for returning to python
nb::list cluster_envs_python;
for (const auto& cluster_env:cluster_envs)
{
nb::list env;
for (const auto& cluster:cluster_env)
{
nb::list vec;
vec.append(cluster.x);
vec.append(cluster.y);
vec.append(cluster.z);
env.append(vec);
}
cluster_envs_python.append(env);
}
return cluster_envs_python;
}

nb::object getPointEnvmm(const std::shared_ptr<EnvironmentMotifMatch>& env_cls)
{
auto cluster_envs = env_cls->getPointEnvironments();

// convert to list of of list of lists for returning to python
nb::list cluster_envs_python;
for (const auto& cluster_env:cluster_envs)
{
nb::list env;
for (const auto& cluster:cluster_env)
{
nb::list vec;
vec.append(cluster.x);
vec.append(cluster.y);
vec.append(cluster.z);
env.append(vec);
}
cluster_envs_python.append(env);
}
return cluster_envs_python;
}

};

namespace detail {
Expand All @@ -92,19 +157,21 @@ void export_MatchEnv(nb::module_& module)

nb::class_<EnvironmentCluster>(module, "EnvironmentCluster")
.def(nb::init<>())
.def("compute", &EnvironmentCluster::compute)
// .def("getClusters", &EnvironmentCluster::getClusterIdx) // TODO: should be there
.def("getClusterEnvironments", &EnvironmentCluster::getClusterEnvironments)
.def("compute", &EnvironmentCluster::compute, nb::arg("nq"), nb::arg("nlist").none(), nb::arg("qargs"), nb::arg("env_nlist").none(), nb::arg("env_qargs"), nb::arg("threshold"), nb::arg("registration"))
.def("getClusterEnvironments", &wrap::getClusterEnv)
.def("getPointEnvironments", &wrap::getPointEnv)
.def("getClusters", &EnvironmentCluster::getClusters)
.def("getNumClusters", &EnvironmentCluster::getNumClusters);

nb::class_<EnvironmentMotifMatch>(module, "EnvironmentMotifMatch")
.def(nb::init<>())
.def("compute", &wrap::compute_env_motif_match, nb::arg("nq"), nb::arg("nlist"), nb::arg("qargs"), nb::arg("motif"), nb::arg("motif_size"), nb::arg("threshold"), nb::arg("registration"))
.def("compute", &wrap::compute_env_motif_match, nb::arg("nq"), nb::arg("nlist").none(), nb::arg("qargs"), nb::arg("motif"), nb::arg("motif_size"), nb::arg("threshold"), nb::arg("registration"))
.def("getPointEnvironments", &wrap::getPointEnvmm)
.def("getMatches", &EnvironmentMotifMatch::getMatches);

nb::class_<EnvironmentRMSDMinimizer>(module, "EnvironmentRMSDMinimizer")
.def(nb::init<>())
.def("compute", &wrap::compute_env_rmsd_min, nb::arg("nq"), nb::arg("nlist"), nb::arg("qargs"), nb::arg("motif"), nb::arg("motif_size"), nb::arg("threshold"), nb::arg("registration"))
.def("compute", &wrap::compute_env_rmsd_min, nb::arg("nq"), nb::arg("nlist").none(), nb::arg("qargs"), nb::arg("motif"), nb::arg("motif_size"), nb::arg("registration"))
.def("getRMSDs", &EnvironmentRMSDMinimizer::getRMSDs);

}
Expand Down

0 comments on commit f9d7f27

Please sign in to comment.