-
Notifications
You must be signed in to change notification settings - Fork 13
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
5 changed files
with
217 additions
and
74 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
#include <cassert> | ||
|
||
#include "openmp.hpp" | ||
|
||
#include <torch/torch.h> | ||
|
||
|
||
using namespace rascaline_torch; | ||
|
||
#ifndef _OPENMP | ||
|
||
int omp_get_num_threads() { | ||
return 1; | ||
} | ||
|
||
int omp_get_thread_num() { | ||
return 0; | ||
} | ||
|
||
#endif | ||
|
||
void ThreadLocalTensor::init(int n_threads, at::IntArrayRef size, at::TensorOptions options) { | ||
for (auto i=0; i<n_threads; i++) { | ||
tensors_.emplace_back(torch::zeros(size, options)); | ||
} | ||
} | ||
|
||
at::Tensor ThreadLocalTensor::get() { | ||
return tensors_.at(omp_get_thread_num()); | ||
} | ||
|
||
at::Tensor ThreadLocalTensor::sum() { | ||
assert(tensors_.size() > 0); | ||
auto sum = torch::zeros_like(tensors_[0]); | ||
for (const auto& tensor: tensors_) { | ||
sum += tensor; | ||
} | ||
return sum; | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
#ifndef RASCALINE_TORCH_OPENMP_HPP | ||
#define RASCALINE_TORCH_OPENMP_HPP | ||
|
||
#include <vector> | ||
|
||
#include <ATen/Tensor.h> | ||
|
||
#ifdef _OPENMP | ||
#include <omp.h> | ||
#else | ||
|
||
int omp_get_num_threads(); | ||
int omp_get_thread_num(); | ||
|
||
#endif | ||
|
||
namespace rascaline_torch { | ||
|
||
class ThreadLocalTensor { | ||
public: | ||
/// Zero-initialize all the tensors with the given options | ||
void init(int n_threads, at::IntArrayRef size, at::TensorOptions options = {}); | ||
|
||
/// Get the tensor for the current thread | ||
at::Tensor get(); | ||
|
||
/// Sum all the thread local tensors and return the result | ||
at::Tensor sum(); | ||
|
||
private: | ||
std::vector<at::Tensor> tensors_; | ||
}; | ||
|
||
|
||
} | ||
|
||
#endif |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,3 @@ | ||
#include <c10/util/intrusive_ptr.h> | ||
#include <torch/torch.h> | ||
|
||
#include <rascaline.hpp> | ||
|