Skip to content

Commit

Permalink
MKL RandIntEngine (pyg-team#222)
Browse files Browse the repository at this point in the history
Add the `generate_range_of_ints` function, which uses MKL to generate integer numbers within a given range.

Results obtained using the neighbor.py benchmark showed an average speedup equal to 1.27x.
---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Matthias Fey <[email protected]>
Co-authored-by: Damian Szwichtenberg <[email protected]>
  • Loading branch information
4 people authored Aug 16, 2023
1 parent 6e9e3b8 commit 4f34071
Show file tree
Hide file tree
Showing 8 changed files with 146 additions and 28 deletions.
23 changes: 16 additions & 7 deletions .github/workflows/cpp_testing.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ jobs:

gtest:
runs-on: ubuntu-latest
env:
MKL_VERSION: 2023.1.0

steps:
- name: Checkout repostiory
Expand All @@ -22,32 +24,39 @@ jobs:

- name: Configure
run: |
pip install mkl-include==2021.4
pip install mkl==${MKL_VERSION} mkl-include==${MKL_VERSION} mkl-devel==${MKL_VERSION}
export _BLAS_INCLUDE_DIR=`python -c 'import os;import sysconfig;data=sysconfig.get_path("data");print(f"{data}{os.sep}include")'`
export LIBS_DIR=`python -c 'import os;import sysconfig;data=sysconfig.get_path("data");print(f"{data}{os.sep}lib")'`
export MKL_DIR=`python -c 'import os;import sysconfig;data=sysconfig.get_path("data");print(f"{data}{os.sep}lib{os.sep}cmake{os.sep}mkl")'`
cd $LIBS_DIR
for library in `ls *.so.2`; do
ln -s ${library} ${library::-2} || true
done
cd -
mkdir build
cd build
export _BLAS_INCLUDE_DIR=`python -c 'import os;import sysconfig;data=sysconfig.get_path("data");print(f"{data}{os.sep}include")'`
Torch_DIR=`python -c 'import torch;print(torch.utils.cmake_prefix_path)'` cmake .. -GNinja -DBUILD_TEST=ON -DWITH_COV=ON -DCMAKE_BUILD_TYPE=DEBUG -DUSE_MKL_BLAS=ON -DBLAS_INCLUDE_DIR=$_BLAS_INCLUDE_DIR
Torch_DIR=`python -c 'import torch;print(torch.utils.cmake_prefix_path)'` cmake .. -GNinja -DBUILD_TEST=ON -DWITH_COV=ON -DCMAKE_BUILD_TYPE=DEBUG -DUSE_MKL_BLAS=ON -DBLAS_INCLUDE_DIR=$_BLAS_INCLUDE_DIR -DMKL_DIR=${MKL_DIR}
unset _BLAS_INCLUDE_DIR
cd ..
- name: Build
run: |
cd build
cmake --build .
cd ..
- name: Run tests
run: |
cd build
ctest --verbose --output-on-failure
cd ..
- name: Collect coverage
run: |
sudo apt-get install lcov
lcov --directory . --capture --output-file .coverage.info
lcov --remove .coverage.info '*/test/*' --output-file .coverage.info
- name: Upload coverage
uses: codecov/codecov-action@v2
with:
Expand Down
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed performance issues reported by Coverity Tool ([#240](https://github.com/pyg-team/pyg-lib/pull/240))
- Updated `cutlass` version for speed boosts in `segment_matmul` and `grouped_matmul` ([#235](https://github.com/pyg-team/pyg-lib/pull/235))
- Drop nested tensor wrapper for `grouped_matmul` implementation ([#226](https://github.com/pyg-team/pyg-lib/pull/226))
- Added the `generate_range_of_ints` function to the `RandintEngine` class; it uses the MKL library to generate integers when available ([#222](https://github.com/pyg-team/pyg-lib/pull/222))
- Fixed TorchScript support in `grouped_matmul` ([#220](https://github.com/pyg-team/pyg-lib/pull/220))
### Removed

Expand Down
37 changes: 37 additions & 0 deletions benchmark/csrc/random/rand_engine.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
#include <vector>

#include <benchmark/benchmark.h>

#include <pyg_lib/csrc/random/cpu/rand_engine.h>

constexpr int64_t beg = 0;
constexpr int64_t end = 1 << 15;

// Benchmarks bulk generation: every iteration asks the engine for `count`
// integers in [beg, end) through a single `generate_range_of_ints` call.
// `count` is taken from the benchmark's range argument.
void BenchmarkRandEngineWithMKL(benchmark::State& state) {
  const int64_t count = state.range(0);
  pyg::random::RandintEngine<int64_t> generator;

  for (auto _ : state) {
    // The call already yields a temporary; wrapping it in `std::move` would
    // only force an extra move construction instead of eliding the copy.
    const auto out = generator.generate_range_of_ints(beg, end, count);
    benchmark::DoNotOptimize(out);
  }
}
BENCHMARK(BenchmarkRandEngineWithMKL)
    ->RangeMultiplier(2)
    ->Range(1 << 8, 1 << 12);

// Benchmarks one-at-a-time generation: every iteration draws as many values
// in [beg, end) as the benchmark's range argument says, each through a
// separate call to the engine's `operator()`.
void BenchmarkRandEngineWithPrefetching(benchmark::State& state) {
  const int64_t num_draws = state.range(0);
  pyg::random::RandintEngine<int64_t> engine;

  for (auto _ : state) {
    int64_t remaining = num_draws;
    while (remaining > 0) {
      const auto value = engine(beg, end);
      benchmark::DoNotOptimize(value);
      --remaining;
    }
  }
}
BENCHMARK(BenchmarkRandEngineWithPrefetching)
    ->RangeMultiplier(2)
    ->Range(1 << 8, 1 << 12);
3 changes: 3 additions & 0 deletions cmake/benchmark.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -13,5 +13,8 @@ foreach(benchmark ${ALL_BENCHMARKS})
get_filename_component(name ${benchmark} NAME_WE)
add_executable(${name} ${benchmark})
target_link_libraries(${name} ${PROJECT_NAME} benchmark::benchmark benchmark::benchmark_main torch)
if(MKL_INCLUDE_FOUND)
target_include_directories(${name} PRIVATE ${BLAS_INCLUDE_DIR})
endif()
target_include_directories(${name} PRIVATE ${PHMAP_DIR})
endforeach()
5 changes: 4 additions & 1 deletion cmake/test.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@ file(GLOB_RECURSE ALL_TESTS ${CTEST}/*.cpp)
# Create one executable per discovered test source file and hand its test
# cases to gtest_discover_tests.
foreach(test ${ALL_TESTS})
# The target is named after the source file (NAME_WE: no directory, no extension).
get_filename_component(name ${test} NAME_WE)
add_executable(${name} ${test})
# NOTE(review): both link lines below appear in this diff view; the second one
# additionally passes ${TORCH_LIBRARIES} and presumably supersedes the first - confirm.
target_link_libraries(${name} ${PROJECT_NAME} gtest_main torch)
target_link_libraries(${name} ${PROJECT_NAME} gtest_main torch ${TORCH_LIBRARIES})
# Add the BLAS include directory only when MKL headers were found.
if(MKL_INCLUDE_FOUND)
target_include_directories(${name} PRIVATE ${BLAS_INCLUDE_DIR})
endif()
gtest_discover_tests(${name})
endforeach()
50 changes: 47 additions & 3 deletions pyg_lib/csrc/random/cpu/rand_engine.h
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
#pragma once

#include <ATen/ATen.h>

#include <limits.h>
#include <limits>
#include <vector>

#include "pyg_lib/csrc/config.h"
#if WITH_MKL_BLAS()
#include <mkl.h>
#endif

namespace pyg {
namespace random {

Expand Down Expand Up @@ -99,18 +103,58 @@ class PrefetchedRandint {
template <typename T>
class RandintEngine {
public:
RandintEngine() : prefetched_(RAND_PREFETCH_SIZE, RAND_PREFETCH_BITS) {}
RandintEngine() {
#if WITH_MKL_BLAS()
vslNewStream(&stream_, VSL_BRNG_MT19937, 1);
#endif
}
~RandintEngine() {
#if WITH_MKL_BLAS()
vslDeleteStream(&stream_);
#endif
}

// Uniform random number within range [beg, end)
T operator()(T beg, T end) {
TORCH_CHECK(beg < end, "Randint engine illegal range");

T range = end - beg;
const T range = end - beg;
return prefetched_.next(range) + beg;
}

// Generates `count` numbers within range [beg, end). If
// possible, it will use specialized MKL implementation.
// It is user's responsibility to ensure that `beg` and `end`
// fit into int.
std::vector<int> generate_range_of_ints(T beg, T end, int64_t count) {
TORCH_CHECK(beg < end, "Randint engine illegal range");

std::vector<int> result(count);
const auto fallback_func = [this](T beg, T end, std::vector<int>& dst) {
for (auto& val : dst)
val = static_cast<int>((*this)(beg, end));
};
#if WITH_MKL_BLAS()
const bool use_fallback_func = count > std::numeric_limits<MKL_INT>::max();
if (use_fallback_func) {
fallback_func(beg, end, result);
} else {
const auto b = static_cast<int>(beg);
const auto e = static_cast<int>(end);
const auto c = static_cast<MKL_INT>(count);
viRngUniform(VSL_RNG_METHOD_UNIFORM_STD, stream_, c, result.data(), b, e);
}
#else
fallback_func(beg, end, result);
#endif
return result;
}

private:
PrefetchedRandint prefetched_;
#if WITH_MKL_BLAS()
VSLStreamStatePtr stream_;
#endif
};

/**
Expand Down
47 changes: 35 additions & 12 deletions pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -135,25 +135,48 @@ class NeighborSampler {

// Case 2: Sample with replacement:
else if (replace) {
for (size_t i = 0; i < count; ++i) {
const auto edge_id = generator(row_start, row_end);
add(edge_id, global_src_node, local_src_node, dst_mapper,
out_global_dst_nodes);
if (row_end < (1 << 16)) {
const auto arr = std::move(
generator.generate_range_of_ints(row_start, row_end, count));
for (const auto edge_id : arr)
add(edge_id, global_src_node, local_src_node, dst_mapper,
out_global_dst_nodes);
} else {
for (int64_t i = 0; i < count; ++i) {
const auto edge_id = generator(row_start, row_end);
add(edge_id, global_src_node, local_src_node, dst_mapper,
out_global_dst_nodes);
}
}
}

// Case 3: Sample without replacement:
else {
auto index_tracker = IndexTracker<scalar_t>(population);
for (size_t i = population - count; i < population; ++i) {
auto rnd = generator(0, i + 1);
if (!index_tracker.try_insert(rnd)) {
rnd = i;
index_tracker.insert(i);
if (population < (1 << 16)) {
const auto arr =
std::move(generator.generate_range_of_ints(0, population, count));
for (auto i = 0; i < arr.size(); ++i) {
auto rnd = arr[i];
if (!index_tracker.try_insert(rnd)) {
rnd = population - count + i;
index_tracker.insert(population - count + i);
}
const auto edge_id = row_start + rnd;
add(edge_id, global_src_node, local_src_node, dst_mapper,
out_global_dst_nodes);
}
} else {
for (auto i = population - count; i < population; ++i) {
auto rnd = generator(0, i + 1);
if (!index_tracker.try_insert(rnd)) {
rnd = i;
index_tracker.insert(i);
}
const auto edge_id = row_start + rnd;
add(edge_id, global_src_node, local_src_node, dst_mapper,
out_global_dst_nodes);
}
const auto edge_id = row_start + rnd;
add(edge_id, global_src_node, local_src_node, dst_mapper,
out_global_dst_nodes);
}
}
}
Expand Down
8 changes: 3 additions & 5 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,17 +93,15 @@ def maybe_append_with_mkl(dependencies):
torch_config = torch.__config__.show()
with_mkl_blas = 'BLAS_INFO=mkl' in torch_config
if torch.backends.mkl.is_available() and with_mkl_blas:
# product version is decoupled from library version. For older
# releases, where MKL was not a part of oneAPI, we can safely use
# 2021.4 as it is backward compatible (default for PyTorch conda
# distribution).
product_version = '2021.4'
product_version = '2023.1.0'
pattern = r'oneAPI Math Kernel Library Version [0-9]{4}\.[0-9]+'
match = re.search(pattern, torch_config)
if match:
product_version = match.group(0).split(' ')[-1]

dependencies.append(f'mkl-include=={product_version}')
dependencies.append(f'mkl=={product_version}')
dependencies.append(f'mkl-devel=={product_version}')


install_requires = []
Expand Down

0 comments on commit 4f34071

Please sign in to comment.