diff --git a/rascaline-torch/CMakeLists.txt b/rascaline-torch/CMakeLists.txt
index bf14f9cef..b2096e2a6 100644
--- a/rascaline-torch/CMakeLists.txt
+++ b/rascaline-torch/CMakeLists.txt
@@ -95,6 +95,7 @@ set(RASCALINE_TORCH_HEADERS
 set(RASCALINE_TORCH_SOURCE
     "src/system.cpp"
     "src/autograd.cpp"
+    "src/openmp.cpp"
     "src/calculator.cpp"
     "src/register.cpp"
 )
@@ -133,6 +134,26 @@ generate_export_header(rascaline_torch
 )
 target_compile_definitions(rascaline_torch PRIVATE rascaline_torch_EXPORTS)
 
+find_package(OpenMP)
+if (OpenMP_CXX_FOUND)
+    # Torch bundles its own copy of the OpenMP runtime library, and if we
+    # compile and link against the system version as well this can lead to
+    # crashes during initialization on macOS.
+    #
+    # So on this platform we instead compile the code with OpenMP flags, and
+    # leave the corresponding symbols undefined in `rascaline_torch`, hoping
+    # that when Torch is loaded we'll get these symbols in the global namespace.
+    #
+    # On other platforms, this seems to be less of an issue, maybe because torch
+    # adds a hash to the library name it bundles (e.g. `libgomp-de42aff.so`)
+    if (APPLE)
+        string(REPLACE " " ";" omp_cxx_flags_list ${OpenMP_CXX_FLAGS})
+        target_compile_options(rascaline_torch PRIVATE ${omp_cxx_flags_list})
+        target_link_libraries(rascaline_torch PRIVATE -Wl,-undefined,dynamic_lookup)
+    else()
+        target_link_libraries(rascaline_torch PRIVATE OpenMP::OpenMP_CXX)
+    endif()
+endif()
 
 if (RASCALINE_TORCH_TESTS)
     enable_testing()
diff --git a/rascaline-torch/src/autograd.cpp b/rascaline-torch/src/autograd.cpp
index 65f33e6a3..fc4f3055f 100644
--- a/rascaline-torch/src/autograd.cpp
+++ b/rascaline-torch/src/autograd.cpp
@@ -3,6 +3,8 @@
 #include "metatensor/torch/tensor.hpp"
 #include "rascaline/torch/autograd.hpp"
 
+#include "./openmp.hpp"
+
 using namespace metatensor_torch;
 using namespace rascaline_torch;
 
@@ -252,10 +254,6 @@ std::vector<torch::Tensor> PositionsGrad::forward(
     always_assert(samples->names()[2] == "atom");
 
     // ========================= extract pointers =========================== //
-    auto dA_dr = torch::zeros_like(all_positions);
-    always_assert(dA_dr.is_contiguous() && dA_dr.is_cpu());
-    auto* dA_dr_ptr = dA_dr.data_ptr<double>();
-
     // TODO: remove all CPU <=> device data movement by rewriting the VJP
     // below with torch primitives
     auto dX_dr_values = dX_dr->values().to(torch::kCPU);
@@ -274,24 +272,41 @@
     }
 
     // =========================== compute dA_dr ============================ //
-    for (int64_t grad_sample_i=0; grad_sample_i<samples->count(); grad_sample_i++) {
-        auto sample_i = sample_ptr[grad_sample_i * 3 + 0];
-        auto system_i = sample_ptr[grad_sample_i * 3 + 1];
-        auto atom_i = sample_ptr[grad_sample_i* 3 + 2];
-
-        auto global_atom_i = systems_start[system_i] + atom_i;
-
-        for (int64_t xyz=0; xyz<3; xyz++) {
-            auto dot = 0.0;
-            for (int64_t i=0; i<n_features; i++) {
-                dot += (
-                    dX_dr_ptr[(grad_sample_i * 3 + xyz) * n_features + i]
-                    * dA_dX_ptr[sample_i * n_features + i]
-                );
-            }
-            dA_dr_ptr[global_atom_i * 3 + xyz] += dot;
-        }
-    }
+    auto dA_dr_multiple = ThreadLocalTensor();
+    #pragma omp parallel
+    {
+        #pragma omp single
+        dA_dr_multiple.init(omp_get_num_threads(), all_positions.sizes(), all_positions.options());
+
+        auto dA_dr_local = dA_dr_multiple.get();
+        always_assert(dA_dr_local.is_contiguous() && dA_dr_local.is_cpu());
+        auto* dA_dr_ptr = dA_dr_local.data_ptr<double>();
+
+        #pragma omp for
+        for (int64_t grad_sample_i=0; grad_sample_i<samples->count(); grad_sample_i++) {
+            auto sample_i = sample_ptr[grad_sample_i * 3 + 0];
+            auto system_i = sample_ptr[grad_sample_i * 3 + 1];
+            auto atom_i = sample_ptr[grad_sample_i * 3 + 2];
+
+            auto global_atom_i = systems_start[system_i] + atom_i;
+
+            for (int64_t xyz=0; xyz<3; xyz++) {
+                auto dot = 0.0;
+                for (int64_t i=0; i<n_features; i++) {
+                    dot += (
+                        dX_dr_ptr[(grad_sample_i * 3 + xyz) * n_features + i]
+                        * dA_dX_ptr[sample_i * n_features + i]
+                    );
+                }
+                dA_dr_ptr[global_atom_i * 3 + xyz] += dot;
+            }
+        }
+    }
+    auto dA_dr = dA_dr_multiple.sum();
 
     ctx->save_for_backward({all_positions, dA_dX});
@@ -359,31 +374,42 @@ std::vector<torch::Tensor> PositionsGrad::backward(
     // ============ gradient of B w.r.t. dA/dX (input of forward) =========== //
     auto dB_d_dA_dX = torch::Tensor();
     if (dA_dX.requires_grad()) {
-        dB_d_dA_dX = torch::zeros_like(dA_dX_cpu);
-        always_assert(dB_d_dA_dX.is_contiguous() && dB_d_dA_dX.is_cpu());
-        auto* dB_d_dA_dX_ptr = dB_d_dA_dX.data_ptr<double>();
+        auto dB_d_dA_dX_multiple = ThreadLocalTensor();
 
-        // dX_dr.shape == [positions gradient samples, 3, features...]
-        // dB_d_dA_dr.shape == [n_atoms, 3]
-        // dB_d_dA_dX.shape == [samples, features...]
-        for (int64_t grad_sample_i=0; grad_sample_i<samples->count(); grad_sample_i++) {
-            auto sample_i = sample_ptr[grad_sample_i * 3 + 0];
-            auto system_i = sample_ptr[grad_sample_i * 3 + 1];
-            auto atom_i = sample_ptr[grad_sample_i* 3 + 2];
+        #pragma omp parallel
+        {
+            #pragma omp single
+            dB_d_dA_dX_multiple.init(omp_get_num_threads(), dA_dX_cpu.sizes(), dA_dX_cpu.options());
 
-            auto global_atom_i = systems_start[system_i] + atom_i;
+            auto dB_d_dA_dX_local = dB_d_dA_dX_multiple.get();
+            always_assert(dB_d_dA_dX_local.is_contiguous() && dB_d_dA_dX_local.is_cpu());
+            auto* dB_d_dA_dX_ptr = dB_d_dA_dX_local.data_ptr<double>();
 
-            for (int64_t i=0; i<n_features; i++) {
-                auto dot = 0.0;
-                for (int64_t xyz=0; xyz<3; xyz++) {
-                    dot += (
-                        dX_dr_ptr[(grad_sample_i * 3 + xyz) * n_features + i]
-                        * dB_d_dA_dr_ptr[global_atom_i * 3 + xyz]
-                    );
-                }
-                dB_d_dA_dX_ptr[sample_i * n_features + i] += dot;
-            }
-        }
+            // dX_dr.shape == [positions gradient samples, 3, features...]
+            // dB_d_dA_dr.shape == [n_atoms, 3]
+            // dB_d_dA_dX.shape == [samples, features...]
+            #pragma omp for
+            for (int64_t grad_sample_i=0; grad_sample_i<samples->count(); grad_sample_i++) {
+                auto sample_i = sample_ptr[3 * grad_sample_i + 0];
+                auto system_i = sample_ptr[3 * grad_sample_i + 1];
+                auto atom_i = sample_ptr[3 * grad_sample_i + 2];
+
+                auto global_atom_i = systems_start[system_i] + atom_i;
+
+                for (int64_t i=0; i<n_features; i++) {
+                    auto dot = 0.0;
+                    for (int64_t xyz=0; xyz<3; xyz++) {
+                        dot += (
+                            dX_dr_ptr[(grad_sample_i * 3 + xyz) * n_features + i]
+                            * dB_d_dA_dr_ptr[global_atom_i * 3 + xyz]
+                        );
+                    }
+                    dB_d_dA_dX_ptr[sample_i * n_features + i] += dot;
+                }
+            }
+        }
+        dB_d_dA_dX = dB_d_dA_dX_multiple.sum();
     }
@@ -426,9 +452,6 @@ std::vector<torch::Tensor> CellGrad::forward(
 ) {
     // ====================== input parameters checks ======================= //
     always_assert(all_cells.requires_grad());
-    auto cell_grad = torch::zeros_like(all_cells);
-    always_assert(cell_grad.is_contiguous() && cell_grad.is_cpu());
-    auto* cell_grad_ptr = cell_grad.data_ptr<double>();
 
     auto samples = dX_dH->samples();
     const auto* sample_ptr = samples->as_metatensor().values().data();
@@ -438,31 +461,44 @@ std::vector<torch::Tensor> CellGrad::forward(
     }
 
     // =========================== compute dA_dH ============================ //
-    for (int64_t grad_sample_i=0; grad_sample_i<samples->count(); grad_sample_i++) {
-        auto sample_i = sample_ptr[grad_sample_i];
-        // we get the system index from the samples of the values
-        auto system_i = static_cast<int64_t>(systems[sample_i].item<int32_t>());
+    auto dA_dH_multiple = ThreadLocalTensor();
+    #pragma omp parallel
+    {
+        #pragma omp single
+        dA_dH_multiple.init(omp_get_num_threads(), all_cells.sizes(), all_cells.options());
 
-        for (int64_t xyz_1=0; xyz_1<3; xyz_1++) {
-            for (int64_t xyz_2=0; xyz_2<3; xyz_2++) {
-                auto dot = 0.0;
-                for (int64_t i=0; i<n_features; i++) {
-                    dot += (
-                        dX_dH_ptr[(grad_sample_i * 9 + xyz_1 * 3 + xyz_2) * n_features + i]
-                        * dA_dX_ptr[sample_i * n_features + i]
-                    );
-                }
-                cell_grad_ptr[system_i * 9 + xyz_1 * 3 + xyz_2] += dot;
-            }
-        }
-    }
+        auto dA_dH_local = dA_dH_multiple.get();
+        always_assert(dA_dH_local.is_contiguous() && dA_dH_local.is_cpu());
+        auto* dA_dH_ptr = dA_dH_local.data_ptr<double>();
+
+        #pragma omp for
+        for (int64_t grad_sample_i=0; grad_sample_i<samples->count(); grad_sample_i++) {
+            auto sample_i = sample_ptr[grad_sample_i];
+            // we get the system index from the samples of the values
+            auto system_i = static_cast<int64_t>(systems[sample_i].item<int32_t>());
+
+            for (int64_t xyz_1=0; xyz_1<3; xyz_1++) {
+                for (int64_t xyz_2=0; xyz_2<3; xyz_2++) {
+                    auto dot = 0.0;
+                    for (int64_t i=0; i<n_features; i++) {
+                        dot += (
+                            dX_dH_ptr[(grad_sample_i * 9 + xyz_1 * 3 + xyz_2) * n_features + i]
+                            * dA_dX_ptr[sample_i * n_features + i]
+                        );
+                    }
+                    dA_dH_ptr[system_i * 9 + xyz_1 * 3 + xyz_2] += dot;
+                }
+            }
+        }
+    }
+    auto dA_dH = dA_dH_multiple.sum();
 
     ctx->save_for_backward({all_cells, dA_dX, systems});
     ctx->saved_data.emplace("cell_gradients", dX_dH);
 
-    return {cell_grad};
+    return {dA_dH};
 }
@@ -520,30 +556,41 @@ std::vector<torch::Tensor> CellGrad::backward(
     // ============ gradient of B w.r.t. dA/dX (input of forward) =========== //
     auto dB_d_dA_dX = torch::Tensor();
     if (dA_dX.requires_grad()) {
-        dB_d_dA_dX = torch::zeros_like(dA_dX_cpu);
-        always_assert(dB_d_dA_dX.is_contiguous() && dB_d_dA_dX.is_cpu());
-        auto* dB_d_dA_dX_ptr = dB_d_dA_dX.data_ptr<double>();
+        auto dB_d_dA_dX_multiple = ThreadLocalTensor();
 
-        // dX_dH.shape == [cell gradient samples, 3, 3, features...]
-        // dB_d_dA_dH.shape == [systems, 3, 3]
-        // dB_d_dA_dX.shape == [samples, features...]
-        for (int64_t grad_sample_i=0; grad_sample_i<samples->count(); grad_sample_i++) {
-            auto sample_i = sample_ptr[grad_sample_i];
-            auto system_i = static_cast<int64_t>(systems[sample_i].item<int32_t>());
+        #pragma omp parallel
+        {
+            #pragma omp single
+            dB_d_dA_dX_multiple.init(omp_get_num_threads(), dA_dX_cpu.sizes(), dA_dX_cpu.options());
 
-            for (int64_t i=0; i<n_features; i++) {
-                auto dot = 0.0;
-                for (int64_t xyz_1=0; xyz_1<3; xyz_1++) {
-                    for (int64_t xyz_2=0; xyz_2<3; xyz_2++) {
-                        auto idx_1 = system_i * 9 + xyz_1 * 3 + xyz_2;
-                        auto idx_2 = grad_sample_i * 9 + xyz_1 * 3 + xyz_2;
+            auto dB_d_dA_dX_local = dB_d_dA_dX_multiple.get();
+            always_assert(dB_d_dA_dX_local.is_contiguous() && dB_d_dA_dX_local.is_cpu());
+            auto* dB_d_dA_dX_ptr = dB_d_dA_dX_local.data_ptr<double>();
+
+            // dX_dH.shape == [cell gradient samples, 3, 3, features...]
+            // dB_d_dA_dH.shape == [systems, 3, 3]
+            // dB_d_dA_dX.shape == [samples, features...]
+            #pragma omp for
+            for (int64_t grad_sample_i=0; grad_sample_i<samples->count(); grad_sample_i++) {
+                auto sample_i = sample_ptr[grad_sample_i];
+                auto system_i = static_cast<int64_t>(systems[sample_i].item<int32_t>());
 
-                        dot += dB_d_dA_dH_ptr[idx_1] * dX_dH_ptr[idx_2 * n_features + i];
-                    }
-                }
-                dB_d_dA_dX_ptr[sample_i * n_features + i] += dot;
-            }
-        }
+                for (int64_t i=0; i<n_features; i++) {
+                    auto dot = 0.0;
+                    for (int64_t xyz_1=0; xyz_1<3; xyz_1++) {
+                        for (int64_t xyz_2=0; xyz_2<3; xyz_2++) {
+                            auto idx_1 = system_i * 9 + xyz_1 * 3 + xyz_2;
+                            auto idx_2 = grad_sample_i * 9 + xyz_1 * 3 + xyz_2;
+
+                            dot += dB_d_dA_dH_ptr[idx_1] * dX_dH_ptr[idx_2 * n_features + i];
+                        }
+                    }
+                    dB_d_dA_dX_ptr[sample_i * n_features + i] += dot;
+                }
+            }
+        }
+        dB_d_dA_dX = dB_d_dA_dX_multiple.sum();
     }
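
All four rewritten loops above rely on the same trick: the output slot an iteration writes to (`global_atom_i` for positions, `system_i` for cells) is not unique per gradient sample, so a plain `#pragma omp for` accumulating with `+=` into a single shared tensor would be a data race. Instead, each thread accumulates into its own zero-initialized copy of the output, and the copies are summed serially at the end. Below is a minimal, self-contained sketch of this accumulate-then-reduce pattern; all names, shapes and values in it are illustrative, not taken from the patch:

    #include <iostream>
    #include <vector>

    #include <torch/torch.h>

    #ifdef _OPENMP
    #include <omp.h>
    #else
    // serial fallbacks, mirroring the stubs added in src/openmp.cpp below
    static int omp_get_num_threads() { return 1; }
    static int omp_get_thread_num() { return 0; }
    #endif

    int main() {
        // one zero-initialized partial result per thread
        std::vector<torch::Tensor> partials;
        auto options = torch::TensorOptions().dtype(torch::kFloat64);

        #pragma omp parallel
        {
            // executed by exactly one thread; the implicit barrier at the
            // end of `single` guarantees all threads see the filled vector
            #pragma omp single
            for (int t = 0; t < omp_get_num_threads(); t++) {
                partials.emplace_back(torch::zeros({5, 3}, options));
            }

            auto local = partials[omp_get_thread_num()];
            auto* ptr = local.data_ptr<double>();

            // several iterations hit the same output row, but each thread
            // only writes to its own buffer, so there is no data race
            #pragma omp for
            for (int64_t sample = 0; sample < 100; sample++) {
                auto row = sample % 5;
                for (int64_t xyz = 0; xyz < 3; xyz++) {
                    ptr[row * 3 + xyz] += 1.0;
                }
            }
        }

        // serial reduction over the per-thread partial results
        auto result = torch::zeros_like(partials[0]);
        for (const auto& partial: partials) {
            result += partial;
        }
        std::cout << result << std::endl;   // every entry is 20
    }

The `ThreadLocalTensor` class introduced in the two new files below packages exactly these three steps (per-thread allocation, per-thread access, final reduction) behind `init()`, `get()` and `sum()`.
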