Use OpenMP to parallelize backward
Luthaf committed Jun 17, 2024
1 parent 4416615 commit 91418c6
Showing 5 changed files with 217 additions and 74 deletions.
21 changes: 21 additions & 0 deletions rascaline-torch/CMakeLists.txt
@@ -95,6 +95,7 @@ set(RASCALINE_TORCH_HEADERS
set(RASCALINE_TORCH_SOURCE
"src/system.cpp"
"src/autograd.cpp"
"src/openmp.cpp"
"src/calculator.cpp"
"src/register.cpp"
)
@@ -133,6 +134,26 @@ generate_export_header(rascaline_torch
)
target_compile_definitions(rascaline_torch PRIVATE rascaline_torch_EXPORTS)

find_package(OpenMP)
if (OpenMP_CXX_FOUND)
# Torch bundles its own copy of the OpenMP runtime library, and if we
# compile and link against the system version as well, this can lead to
# crashes during initialization on macOS.
#
# So on this platform we instead compile the code with OpenMP flags, and
# leave the corresponding symbols undefined in `rascaline_torch`, hoping
# that when Torch is loaded we'll get these symbols in the global namespace.
#
# On other platforms, this seems to be less of an issue, maybe because torch
# adds a hash to the library name it bundles (e.g. `libgomp-de42aff.so`).
if (APPLE)
string(REPLACE " " ";" omp_cxx_flags_list ${OpenMP_CXX_FLAGS})
target_compile_options(rascaline_torch PRIVATE ${omp_cxx_flags_list})
target_link_libraries(rascaline_torch PRIVATE -Wl,-undefined,dynamic_lookup)
else()
target_link_libraries(rascaline_torch PRIVATE OpenMP::OpenMP_CXX)
endif()
endif()

if (RASCALINE_TORCH_TESTS)
enable_testing()
193 changes: 120 additions & 73 deletions rascaline-torch/src/autograd.cpp
@@ -3,6 +3,8 @@
#include "metatensor/torch/tensor.hpp"
#include "rascaline/torch/autograd.hpp"

#include "./openmp.hpp"

using namespace metatensor_torch;
using namespace rascaline_torch;

@@ -252,10 +254,6 @@ std::vector<torch::Tensor> PositionsGrad<scalar_t>::forward(
always_assert(samples->names()[2] == "atom");

// ========================= extract pointers =========================== //
auto dA_dr = torch::zeros_like(all_positions);
always_assert(dA_dr.is_contiguous() && dA_dr.is_cpu());
auto* dA_dr_ptr = dA_dr.data_ptr<scalar_t>();

// TODO: remove all CPU <=> device data movement by rewriting the VJP
// below with torch primitives
auto dX_dr_values = dX_dr->values().to(torch::kCPU);
@@ -274,24 +272,41 @@
}

// =========================== compute dA_dr ============================ //
for (int64_t grad_sample_i=0; grad_sample_i<samples->count(); grad_sample_i++) {
auto sample_i = sample_ptr[grad_sample_i * 3 + 0];
auto system_i = sample_ptr[grad_sample_i * 3 + 1];
auto atom_i = sample_ptr[grad_sample_i* 3 + 2];

auto global_atom_i = systems_start[system_i] + atom_i;

for (int64_t xyz=0; xyz<3; xyz++) {
auto dot = 0.0;
for (int64_t i=0; i<n_features; i++) {
dot += (
dX_dr_ptr[(grad_sample_i * 3 + xyz) * n_features + i]
* dA_dX_ptr[sample_i * n_features + i]
);
// For OpenMP parallelization, we allocate a temporary output for each thread
// with ThreadLocalTensor, let every thread accumulate into its own copy, and
// finally sum the thread-local results. Separate gradient samples can refer
// to the same atom, so accumulating directly into a single shared tensor
// would be a data race.
auto dA_dr_multiple = ThreadLocalTensor();
#pragma omp parallel
{
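    // one thread zero-initializes the per-thread tensors; the implicit
    // barrier at the end of `omp single` guarantees they all exist before
    // any thread calls `get()` below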
#pragma omp single
dA_dr_multiple.init(omp_get_num_threads(), all_positions.sizes(), all_positions.options());

auto dA_dr_local = dA_dr_multiple.get();
always_assert(dA_dr_local.is_contiguous() && dA_dr_local.is_cpu());
auto dA_dr_ptr = dA_dr_local.data_ptr<scalar_t>();

#pragma omp for
for (int64_t grad_sample_i=0; grad_sample_i<samples->count(); grad_sample_i++) {
auto sample_i = sample_ptr[grad_sample_i * 3 + 0];
auto system_i = sample_ptr[grad_sample_i * 3 + 1];
auto atom_i = sample_ptr[grad_sample_i * 3 + 2];

auto global_atom_i = systems_start[system_i] + atom_i;

for (int64_t xyz=0; xyz<3; xyz++) {
auto dot = 0.0;
for (int64_t i=0; i<n_features; i++) {
dot += (
dX_dr_ptr[(grad_sample_i * 3 + xyz) * n_features + i]
* dA_dX_ptr[sample_i * n_features + i]
);
}
dA_dr_ptr[global_atom_i * 3 + xyz] += dot;
}
dA_dr_ptr[global_atom_i * 3 + xyz] += dot;
}
}
auto dA_dr = dA_dr_multiple.sum();
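    // Hedged sketch of the torch-primitives rewrite that the TODO above
    // alludes to, left as a comment since it is not part of this commit. It
    // assumes two hypothetical int64 index tensors built from the gradient
    // samples: `sample_indices` (the "sample" column) and
    // `global_atom_indices` (`systems_start[system] + atom`); `dA_dX_cpu`
    // stands for whichever CPU copy of `dA_dX` the loop above reads from.
    //
    //     auto dX_dr_3d = dX_dr_values.reshape({-1, 3, n_features});
    //     auto selected = dA_dX_cpu.index_select(0, sample_indices);
    //     auto dots = torch::bmm(dX_dr_3d, selected.unsqueeze(-1)).squeeze(-1);
    //     auto dA_dr_alt = torch::zeros_like(all_positions);
    //     dA_dr_alt.index_add_(0, global_atom_indices, dots);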


// ===================== data for double backward ======================= //
ctx->save_for_backward({all_positions, dA_dX});
@@ -359,31 +374,42 @@ std::vector<torch::Tensor> PositionsGrad<scalar_t>::backward(
// ============ gradient of B w.r.t. dA/dX (input of forward) =========== //
auto dB_d_dA_dX = torch::Tensor();
if (dA_dX.requires_grad()) {
dB_d_dA_dX = torch::zeros_like(dA_dX_cpu);
always_assert(dB_d_dA_dX.is_contiguous() && dB_d_dA_dX.is_cpu());
auto* dB_d_dA_dX_ptr = dB_d_dA_dX.data_ptr<scalar_t>();
auto dB_d_dA_dX_multiple = ThreadLocalTensor();

// dX_dr.shape == [positions gradient samples, 3, features...]
// dB_d_dA_dr.shape == [n_atoms, 3]
// dB_d_dA_dX.shape == [samples, features...]
for (int64_t grad_sample_i=0; grad_sample_i<samples->count(); grad_sample_i++) {
auto sample_i = sample_ptr[grad_sample_i * 3 + 0];
auto system_i = sample_ptr[grad_sample_i * 3 + 1];
auto atom_i = sample_ptr[grad_sample_i* 3 + 2];
#pragma omp parallel
{
#pragma omp single
dB_d_dA_dX_multiple.init(omp_get_num_threads(), dA_dX_cpu.sizes(), dA_dX_cpu.options());

auto global_atom_i = systems_start[system_i] + atom_i;
auto dB_d_dA_dX_local = dB_d_dA_dX_multiple.get();
always_assert(dB_d_dA_dX_local.is_contiguous() && dB_d_dA_dX_local.is_cpu());
auto* dB_d_dA_dX_ptr = dB_d_dA_dX_local.data_ptr<scalar_t>();

for (int64_t i=0; i<n_features; i++) {
auto dot = 0.0;
for (int64_t xyz=0; xyz<3; xyz++) {
dot += (
dX_dr_ptr[(grad_sample_i * 3 + xyz) * n_features + i]
* dB_d_dA_dr_ptr[global_atom_i * 3 + xyz]
);
// dX_dr.shape == [positions gradient samples, 3, features...]
// dB_d_dA_dr.shape == [n_atoms, 3]
// dB_d_dA_dX.shape == [samples, features...]
#pragma omp for
for (int64_t grad_sample_i=0; grad_sample_i<samples->count(); grad_sample_i++) {
auto sample_i = sample_ptr[3 * grad_sample_i + 0];
auto system_i = sample_ptr[3 * grad_sample_i + 1];
auto atom_i = sample_ptr[3 * grad_sample_i + 2];

auto global_atom_i = systems_start[system_i] + atom_i;

for (int64_t i=0; i<n_features; i++) {
auto dot = 0.0;
for (int64_t xyz=0; xyz<3; xyz++) {
dot += (
dX_dr_ptr[(grad_sample_i * 3 + xyz) * n_features + i]
* dB_d_dA_dr_ptr[global_atom_i * 3 + xyz]
);
}
dB_d_dA_dX_ptr[sample_i * n_features + i] += dot;
}
dB_d_dA_dX_ptr[sample_i * n_features + i] += dot;
}
}

dB_d_dA_dX = dB_d_dA_dX_multiple.sum();
}

return {
Expand All @@ -409,9 +435,6 @@ std::vector<torch::Tensor> CellGrad<scalar_t>::forward(
) {
// ====================== input parameters checks ======================= //
always_assert(all_cells.requires_grad());
auto cell_grad = torch::zeros_like(all_cells);
always_assert(cell_grad.is_contiguous() && cell_grad.is_cpu());
auto* cell_grad_ptr = cell_grad.data_ptr<scalar_t>();

auto samples = dX_dH->samples();
const auto* sample_ptr = samples->as_metatensor().values().data();
@@ -438,31 +461,44 @@
}

// =========================== compute dA_dH ============================ //
for (int64_t grad_sample_i=0; grad_sample_i<samples->count(); grad_sample_i++) {
auto sample_i = sample_ptr[grad_sample_i];
// we get the system index from the samples of the values
auto system_i = static_cast<int64_t>(systems[sample_i].item<int32_t>());
auto dA_dH_multiple = ThreadLocalTensor();
#pragma omp parallel
{
#pragma omp single
dA_dH_multiple.init(omp_get_num_threads(), all_cells.sizes(), all_cells.options());

for (int64_t xyz_1=0; xyz_1<3; xyz_1++) {
for (int64_t xyz_2=0; xyz_2<3; xyz_2++) {
auto dot = 0.0;
for (int64_t i=0; i<n_features; i++) {
auto sample_component_row = (grad_sample_i * 3 + xyz_1) * 3 + xyz_2;
dot += (
dA_dX_ptr[sample_i * n_features + i]
* dX_dH_ptr[sample_component_row * n_features + i]
);
auto dA_dH_local = dA_dH_multiple.get();
always_assert(dA_dH_local.is_contiguous() && dA_dH_local.is_cpu());
auto dA_dH_ptr = dA_dH_local.data_ptr<scalar_t>();

#pragma omp for
for (int64_t grad_sample_i=0; grad_sample_i<samples->count(); grad_sample_i++) {
auto sample_i = sample_ptr[grad_sample_i];
// we get the system index from the samples of the values
auto system_i = static_cast<int64_t>(systems[sample_i].item<int32_t>());

for (int64_t xyz_1=0; xyz_1<3; xyz_1++) {
for (int64_t xyz_2=0; xyz_2<3; xyz_2++) {
auto dot = 0.0;
for (int64_t i=0; i<n_features; i++) {
auto sample_component_row = (grad_sample_i * 3 + xyz_1) * 3 + xyz_2;
dot += (
dA_dX_ptr[sample_i * n_features + i]
* dX_dH_ptr[sample_component_row * n_features + i]
);
}
dA_dH_ptr[(system_i * 3 + xyz_1) * 3 + xyz_2] += dot;
}
cell_grad_ptr[(system_i * 3 + xyz_1) * 3 + xyz_2] += dot;
}
}
}
auto dA_dH = dA_dH_multiple.sum();

// ===================== data for double backward ======================= //
ctx->save_for_backward({all_cells, dA_dX, systems});
ctx->saved_data.emplace("cell_gradients", dX_dH);

return {cell_grad};
return {dA_dH};
}


@@ -520,30 +556,41 @@ std::vector<torch::Tensor> CellGrad<scalar_t>::backward(
// ============ gradient of B w.r.t. dA/dX (input of forward) =========== //
auto dB_d_dA_dX = torch::Tensor();
if (dA_dX.requires_grad()) {
dB_d_dA_dX = torch::zeros_like(dA_dX_cpu);
always_assert(dB_d_dA_dX.is_contiguous() && dB_d_dA_dX.is_cpu());
auto* dB_d_dA_dX_ptr = dB_d_dA_dX.data_ptr<scalar_t>();
auto dB_d_dA_dX_multiple = ThreadLocalTensor();

// dX_dH.shape == [cell gradient samples, 3, 3, features...]
// dB_d_dA_dH.shape == [systems, 3, 3]
// dB_d_dA_dX.shape == [samples, features...]
for (int64_t grad_sample_i=0; grad_sample_i<samples->count(); grad_sample_i++) {
auto sample_i = sample_ptr[grad_sample_i];
auto system_i = static_cast<int64_t>(systems[sample_i].item<int32_t>());
#pragma omp parallel
{
#pragma omp single
dB_d_dA_dX_multiple.init(omp_get_num_threads(), dA_dX_cpu.sizes(), dA_dX_cpu.options());

for (int64_t i=0; i<n_features; i++) {
auto dot = 0.0;
for (int64_t xyz_1=0; xyz_1<3; xyz_1++) {
for (int64_t xyz_2=0; xyz_2<3; xyz_2++) {
auto idx_1 = (system_i * 3 + xyz_1) * 3 + xyz_2;
auto idx_2 = (grad_sample_i * 3 + xyz_1) * 3 + xyz_2;
auto dB_d_dA_dX_local = dB_d_dA_dX_multiple.get();
always_assert(dB_d_dA_dX_local.is_contiguous() && dB_d_dA_dX_local.is_cpu());
auto* dB_d_dA_dX_ptr = dB_d_dA_dX_local.data_ptr<scalar_t>();

// dX_dH.shape == [cell gradient samples, 3, 3, features...]
// dB_d_dA_dH.shape == [systems, 3, 3]
// dB_d_dA_dX.shape == [samples, features...]
#pragma omp for
for (int64_t grad_sample_i=0; grad_sample_i<samples->count(); grad_sample_i++) {
auto sample_i = sample_ptr[grad_sample_i];
auto system_i = static_cast<int64_t>(systems[sample_i].item<int32_t>());

dot += dB_d_dA_dH_ptr[idx_1] * dX_dH_ptr[idx_2 * n_features + i];
for (int64_t i=0; i<n_features; i++) {
auto dot = 0.0;
for (int64_t xyz_1=0; xyz_1<3; xyz_1++) {
for (int64_t xyz_2=0; xyz_2<3; xyz_2++) {
auto idx_1 = (system_i * 3 + xyz_1) * 3 + xyz_2;
auto idx_2 = (grad_sample_i * 3 + xyz_1) * 3 + xyz_2;

dot += dB_d_dA_dH_ptr[idx_1] * dX_dH_ptr[idx_2 * n_features + i];
}
}
dB_d_dA_dX_ptr[sample_i * n_features + i] += dot;
}
dB_d_dA_dX_ptr[sample_i * n_features + i] += dot;
}
}

dB_d_dA_dX = dB_d_dA_dX_multiple.sum();
}

return {
39 changes: 39 additions & 0 deletions rascaline-torch/src/openmp.cpp
@@ -0,0 +1,39 @@
#include <cassert>

#include "openmp.hpp"

#include <torch/torch.h>


using namespace rascaline_torch;

#ifndef _OPENMP

int omp_get_num_threads() {
return 1;
}

int omp_get_thread_num() {
return 0;
}

#endif

void ThreadLocalTensor::init(int n_threads, at::IntArrayRef size, at::TensorOptions options) {
for (auto i=0; i<n_threads; i++) {
tensors_.emplace_back(torch::zeros(size, options));
}
}

at::Tensor ThreadLocalTensor::get() {
return tensors_.at(omp_get_thread_num());
}

at::Tensor ThreadLocalTensor::sum() {
assert(tensors_.size() > 0);
auto sum = torch::zeros_like(tensors_[0]);
for (const auto& tensor: tensors_) {
sum += tensor;
}
return sum;
}
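
// ---------------------------------------------------------------------------
// Minimal, hypothetical usage sketch for ThreadLocalTensor (not part of this
// commit), mirroring the pattern in autograd.cpp: each thread accumulates
// into its own zero-initialized buffer, and the partial results are reduced
// at the end. Assumes a contiguous CPU tensor of doubles.
static at::Tensor row_sums(const at::Tensor& matrix) {
    assert(matrix.is_contiguous() && matrix.is_cpu());
    auto n_rows = matrix.size(0);
    auto n_cols = matrix.size(1);
    const auto* data = matrix.data_ptr<double>();

    auto accumulator = ThreadLocalTensor();
    #pragma omp parallel
    {
        #pragma omp single
        accumulator.init(omp_get_num_threads(), {n_cols}, matrix.options());

        auto local = accumulator.get();
        auto* local_ptr = local.data_ptr<double>();

        // each thread sums its share of the rows into its own buffer
        #pragma omp for
        for (int64_t i = 0; i < n_rows; i++) {
            for (int64_t j = 0; j < n_cols; j++) {
                local_ptr[j] += data[i * n_cols + j];
            }
        }
    }
    // reduce the per-thread partial sums into a single tensor
    return accumulator.sum();
}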
37 changes: 37 additions & 0 deletions rascaline-torch/src/openmp.hpp
@@ -0,0 +1,37 @@
#ifndef RASCALINE_TORCH_OPENMP_HPP
#define RASCALINE_TORCH_OPENMP_HPP

#include <vector>

#include <ATen/Tensor.h>

#ifdef _OPENMP
#include <omp.h>
#else

int omp_get_num_threads();
int omp_get_thread_num();

#endif

namespace rascaline_torch {

class ThreadLocalTensor {
public:
/// Zero-initialize all the tensors with the given options
void init(int n_threads, at::IntArrayRef size, at::TensorOptions options = {});

/// Get the tensor for the current thread
at::Tensor get();

/// Sum all the thread local tensors and return the result
at::Tensor sum();

private:
std::vector<at::Tensor> tensors_;
};


}

#endif
1 change: 0 additions & 1 deletion rascaline-torch/tests/calculator.cpp
@@ -1,4 +1,3 @@
#include <c10/util/intrusive_ptr.h>
#include <torch/torch.h>

#include <rascaline.hpp>
