Skip to content

Commit

Permalink
Merge pull request #150 from lukasm91/add_fft_ungrouped
Browse files Browse the repository at this point in the history
Add FFT implementation without graphs
  • Loading branch information
samhatfield authored Sep 20, 2024
2 parents 4a86858 + db24ae0 commit f6b40d6
Show file tree
Hide file tree
Showing 5 changed files with 152 additions and 108 deletions.
5 changes: 5 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,11 @@ ecbuild_add_option( FEATURE GPU_GRAPHS_GEMM
CONDITION HAVE_GPU
DESCRIPTION "Enable graph-based optimisation of Legendre transform GEMM kernel" )

ecbuild_add_option( FEATURE GPU_GRAPHS_FFT
DEFAULT ON
CONDITION HAVE_GPU
DESCRIPTION "Enable graph-based optimisation of FFT kernels" )

if( BUILD_SHARED_LIBS )
set( GPU_STATIC_DEFAULT OFF )
else()
Expand Down
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ Extra options can be added to the `cmake` command to control the build:
Specific extra options exist for GPU installation:
- `-DENABLE_GPU_AWARE_MPI=<ON|OFF>` default=OF
- `-DENABLE_GPU_GRAPHS_GEMM=<ON|OFF>` default=ON
- `-DENABLE_GPU_GRAPHS_FFT=<ON|OFF>` default=ON
- `-DENABLE_CUTLASS=<ON|OFF>` default=OFF
- `-DENABLE_3XTF32=<ON|OFF>` default=OFF

Expand Down
2 changes: 2 additions & 0 deletions src/trans/gpu/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ ecbuild_add_library(
$<${HAVE_CUTLASS}:USE_CUTLASS>
$<${HAVE_CUTLASS_3XTF32}:USE_CUTLASS_3XTF32>
$<${HAVE_GPU_GRAPHS_GEMM}:USE_GRAPHS_GEMM>
$<${HAVE_GPU_GRAPHS_FFT}:USE_GRAPHS_FFT>
)

ectrans_target_fortran_module_directory(
Expand Down Expand Up @@ -150,6 +151,7 @@ foreach( prec dp sp )
$<${HAVE_CUTLASS}:USE_CUTLASS>
$<${HAVE_CUTLASS_3XTF32}:USE_CUTLASS_3XTF32>
$<${HAVE_GPU_GRAPHS_GEMM}:USE_GRAPHS_GEMM>
$<${HAVE_GPU_GRAPHS_FFT}:USE_GRAPHS_FFT>
$<${HAVE_GPU_AWARE_MPI}:USE_GPU_AWARE_MPI>
ECTRANS_HAVE_MPI=${ectrans_HAVE_MPI}
)
Expand Down
249 changes: 141 additions & 108 deletions src/trans/gpu/algor/hicfft.hip.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,38 @@ struct Float {
using cmplx = hipfftComplex;
};

template <class Type, hipfftType Direction> class hicfft_plan {
using real = typename Type::real;
using cmplx = typename Type::cmplx;

public:
void exec(real *data_real, cmplx *data_complex) const {
real *data_real_l = &data_real[offset];
cmplx *data_complex_l = &data_complex[offset / 2];
if constexpr (Direction == HIPFFT_R2C)
fftSafeCall(hipfftExecR2C(handle, data_real_l, data_complex_l));
else if constexpr (Direction == HIPFFT_C2R)
fftSafeCall(hipfftExecC2R(handle, data_complex_l, data_real_l));
else if constexpr (Direction == HIPFFT_D2Z)
fftSafeCall(hipfftExecD2Z(handle, data_real_l, data_complex_l));
else if constexpr (Direction == HIPFFT_Z2D)
fftSafeCall(hipfftExecZ2D(handle, data_complex_l, data_real_l));
}
void set_stream(hipStream_t stream) {
fftSafeCall(hipfftSetStream(handle, stream));
}
hicfft_plan(hipfftHandle handle_, int offset_)
: handle(handle_), offset(offset_) {}

private:
hipfftHandle handle;
int offset;
};

// kfield -> handles
template <class Type, hipfftType Direction> auto &get_fft_plan_cache() {
static std::unordered_map<int, std::vector<hipfftHandle>> fftPlansCache;
static std::unordered_map<int, std::vector<hicfft_plan<Type, Direction>>>
fftPlansCache;
return fftPlansCache;
}
// kfield -> graphs
Expand All @@ -58,137 +87,141 @@ void free_fft_cache(float *, size_t) {
get_ptr_cache<Type, Direction>().clear();
}


template <class Type, hipfftType Direction>
std::vector<hicfft_plan<Type, Direction>> plan_all(int kfield, int *loens,
int nfft, int *offsets) {
static constexpr bool is_forward =
Direction == HIPFFT_R2C || Direction == HIPFFT_D2Z;

auto key = kfield;
auto &fftPlansCache = get_fft_plan_cache<Type, Direction>();
auto fftPlans = fftPlansCache.find(key);
if (fftPlans == fftPlansCache.end()) {
// the fft plans do not exist yet
std::vector<hicfft_plan<Type, Direction>> newPlans;
newPlans.reserve(nfft);
for (int i = 0; i < nfft; ++i) {
int nloen = loens[i];

hipfftHandle plan;
fftSafeCall(hipfftCreate(&plan));
int dist = offsets[i + 1] - offsets[i];
int embed[] = {1};
fftSafeCall(hipfftPlanMany(
&plan, 1, &nloen, embed, 1, is_forward ? dist : dist / 2, embed, 1,
is_forward ? dist / 2 : dist, Direction, kfield));
newPlans.emplace_back(plan, kfield * offsets[i]);
}
fftPlansCache.insert({key, newPlans});
}
return fftPlansCache.find(key)->second;
}

template <class Type, hipfftType Direction>
void execute_fft_new(typename Type::real *data_real, typename Type::cmplx *data_complex,
int kfield, int *loens, int *offsets, int nfft, void *growing_allocator) {
void run_group_graph(typename Type::real *data_real,
typename Type::cmplx *data_complex, int kfield, int *loens,
int *offsets, int nfft, void *growing_allocator) {

growing_allocator_register_free_c(growing_allocator,
free_fft_cache<Type, Direction>);

constexpr bool is_forward = Direction == HIPFFT_R2C || Direction == HIPFFT_D2Z;
using real = typename Type::real;
using cmplx = typename Type::cmplx;

/* static std::unordered_map<int, void *> allocationCache; // nloens -> ptr */
//* static std::unordered_map<int, std::vector<hipfftHandle>> fftPlansCache; // kfield -> handles
//* static std::unordered_map<int, hipGraphExec_t> graphCache; // kfield -> graphs

// if the pointers are changed, we need to update the graph
//* static std::unordered_map<int, std::pair<real *, cmplx *>> ptrCache; // kfield -> ptrs
auto &ptrCache = get_ptr_cache<Type, Direction>(); // kfield -> ptrs
auto &graphCache = get_graph_cache<Type, Direction>(); // kfield -> graphs

auto ptrs = ptrCache.find(kfield);
if (ptrs != ptrCache.end() && (
ptrs->second.first != data_real || ptrs->second.second != data_complex)) {
// the plan is cached, but the pointers are not correct. we remove and delete the graph,
// but we keep the FFT plans, if this happens more often, we should cache this...
std::cout << "WARNING FFT: POINTER CHANGE --> THIS MIGHT BE SLOW"
<< std::endl;
HIC_CHECK(hipGraphExecDestroy(graphCache[kfield]));
graphCache.erase(kfield);
ptrCache.erase(kfield);
auto key = kfield;
auto ptrs = ptrCache.find(key);
if (ptrs != ptrCache.end() && (ptrs->second.first != data_real ||
ptrs->second.second != data_complex)) {
// the plan is cached, but the pointers are not correct. we remove and
// delete the graph, but we keep the FFT plans, if this happens more often,
// we should cache this...
std::cout << "WARNING FFT: POINTER CHANGE --> THIS MIGHT BE SLOW"
<< std::endl;
HIC_CHECK(hipGraphExecDestroy(graphCache[key]));
graphCache.erase(key);
ptrCache.erase(key);
}

//* auto &fftPlansCache =
//* get_fft_plan_cache<Type, Direction>(); // kfield -> handles
auto graph = graphCache.find(kfield);
auto graph = graphCache.find(key);
if (graph == graphCache.end()) {
// this graph does not exist yet

auto &fftPlansCache =
get_fft_plan_cache<Type, Direction>(); // kfield -> handles
auto fftPlans = fftPlansCache.find(kfield);
if (fftPlans == fftPlansCache.end()) {
// the fft plans do not exist yet
std::vector<hipfftHandle> newPlans;
newPlans.resize(nfft);
for (int i = 0; i < nfft; ++i) {
int nloen = loens[i];

hipfftHandle plan;
fftSafeCall(hipfftCreate(&plan));
int dist = offsets[i+1] - offsets[i];
int embed[] = {1};
//fftSafeCall(hipfftPlanMany(&plan, 1, &nloen, embed, 1, dist, embed,
// 1, dist / 2, Direction, kfield));
fftSafeCall(hipfftPlanMany(&plan, 1, &nloen, embed, 1, is_forward ? dist : dist / 2, embed,
1, is_forward ? dist / 2 : dist, Direction, kfield));
newPlans[i] = plan;
}
fftPlansCache.insert({kfield, newPlans});
}
fftPlans = fftPlansCache.find(kfield);

// create a temporary stream
hipStream_t stream;
HIC_CHECK(hipStreamCreate(&stream));

for (auto &plan : fftPlans->second) // set the streams
fftSafeCall(hipfftSetStream(plan, stream));

// now create the graph
hipGraph_t new_graph;
hipGraphCreate(&new_graph, 0);
for (int i = 0; i < nfft; ++i) {
int offset = offsets[i];
real *data_real_l = &data_real[kfield * offset];
cmplx *data_complex_l = &data_complex[kfield * offset / 2];
HIC_CHECK(hipStreamBeginCapture(stream, hipStreamCaptureModeGlobal));
if constexpr(Direction == HIPFFT_R2C)
fftSafeCall(hipfftExecR2C(fftPlans->second[i], data_real_l, data_complex_l));
else if constexpr(Direction == HIPFFT_C2R)
fftSafeCall(hipfftExecC2R(fftPlans->second[i], data_complex_l, data_real_l));
else if constexpr(Direction == HIPFFT_D2Z)
fftSafeCall(hipfftExecD2Z(fftPlans->second[i], data_real_l, data_complex_l));
else if constexpr(Direction == HIPFFT_Z2D)
fftSafeCall(hipfftExecZ2D(fftPlans->second[i], data_complex_l, data_real_l));
hipGraph_t my_graph;
HIC_CHECK(hipStreamEndCapture(stream, &my_graph));
hipGraphNode_t my_node;
HIC_CHECK(hipGraphAddChildGraphNode(&my_node, new_graph, nullptr, 0, my_graph));
}
hipGraphExec_t instance;
HIC_CHECK(hipGraphInstantiate(&instance, new_graph, NULL, NULL, 0));
HIC_CHECK(hipStreamDestroy(stream));
HIC_CHECK(hipGraphDestroy(new_graph));

graphCache.insert({kfield, instance});
ptrCache.insert({kfield, std::make_pair(data_real, data_complex)});
// this graph does not exist yet
auto plans = plan_all<Type, Direction>(kfield, loens, nfft, offsets);

// create a temporary stream
hipStream_t stream;
HIC_CHECK(hipStreamCreate(&stream));

for (auto &plan : plans) // set the streams
plan.set_stream(stream);

// now create the graph
hipGraph_t new_graph;
hipGraphCreate(&new_graph, 0);
for (auto &plan : plans) {
HIC_CHECK(hipStreamBeginCapture(stream, hipStreamCaptureModeGlobal));
plan.exec(data_real, data_complex);
hipGraph_t my_graph;
HIC_CHECK(hipStreamEndCapture(stream, &my_graph));
hipGraphNode_t my_node;
HIC_CHECK(
hipGraphAddChildGraphNode(&my_node, new_graph, nullptr, 0, my_graph));
}
hipGraphExec_t instance;
HIC_CHECK(hipGraphInstantiate(&instance, new_graph, NULL, NULL, 0));
HIC_CHECK(hipStreamDestroy(stream));
HIC_CHECK(hipGraphDestroy(new_graph));

graphCache.insert({key, instance});
ptrCache.insert({key, std::make_pair(data_real, data_complex)});
}

HIC_CHECK(hipGraphLaunch(graphCache.at(kfield), 0));
HIC_CHECK(hipGraphLaunch(graphCache.at(key), 0));
HIC_CHECK(hipDeviceSynchronize());
}
} // namespace

template <class Type, hipfftType Direction>
void run_group(typename Type::real *data_real,
typename Type::cmplx *data_complex, int kfield, int *loens,
int *offsets, int nfft, void *growing_allocator) {
auto plans = plan_all<Type, Direction>(kfield, loens, nfft, offsets);

for (auto &plan : plans)
plan.exec(data_real, data_complex);
HIC_CHECK(hipDeviceSynchronize());
}
} // namespace

extern "C" {
#ifdef USE_GRAPHS_FFT
#define RUN run_group_graph
#else
#define RUN run_group
#endif
void execute_dir_fft_float(float *data_real, hipfftComplex *data_complex,
int kfield, int *loens, int *offsets, int nfft,
void *growing_allocator) {
execute_fft_new<Float, HIPFFT_R2C>(data_real, data_complex, kfield, loens, offsets,
nfft, growing_allocator);
int kfield, int *loens, int *offsets, int nfft,
void *growing_allocator) {
RUN<Float, HIPFFT_R2C>(data_real, data_complex, kfield, loens, offsets, nfft,
growing_allocator);
}
void execute_inv_fft_float(hipfftComplex *data_complex, float *data_real,
int kfield, int *loens, int *offsets, int nfft,
void *growing_allocator) {
execute_fft_new<Float, HIPFFT_C2R>(data_real, data_complex, kfield, loens, offsets,
nfft, growing_allocator);
int kfield, int *loens, int *offsets, int nfft,
void *growing_allocator) {
RUN<Float, HIPFFT_C2R>(data_real, data_complex, kfield, loens, offsets, nfft,
growing_allocator);
}
void execute_dir_fft_double(double *data_real, hipfftDoubleComplex *data_complex,
int kfield, int *loens, int *offsets, int nfft,
void *growing_allocator) {
execute_fft_new<Double, HIPFFT_D2Z>(data_real, data_complex, kfield, loens,
offsets, nfft, growing_allocator);
void execute_dir_fft_double(double *data_real,
hipfftDoubleComplex *data_complex, int kfield,
int *loens, int *offsets, int nfft,
void *growing_allocator) {
RUN<Double, HIPFFT_D2Z>(data_real, data_complex, kfield, loens, offsets, nfft,
growing_allocator);
}
void execute_inv_fft_double(hipfftDoubleComplex *data_complex, double *data_real,
int kfield, int *loens, int *offsets, int nfft,
void *growing_allocator) {
execute_fft_new<Double, HIPFFT_Z2D>(data_real, data_complex, kfield, loens,
offsets, nfft, growing_allocator);
void execute_inv_fft_double(hipfftDoubleComplex *data_complex,
double *data_real, int kfield, int *loens,
int *offsets, int nfft, void *growing_allocator) {
RUN<Double, HIPFFT_Z2D>(data_real, data_complex, kfield, loens, offsets, nfft,
growing_allocator);
}
#undef RUN
}

3 changes: 3 additions & 0 deletions src/trans/gpu/external/setup_trans.F90
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,9 @@ SUBROUTINE SETUP_TRANS(KSMAX,KDGL,KDLON,KLOEN,LDSPLIT,PSTRET,&
#ifdef USE_GPU_AWARE_MPI
WRITE(NOUT,'(A)') " - GPU-aware MPI"
#endif
#ifdef USE_GRAPHS_FFT
WRITE(NOUT,'(A)') " - graph-based FFT scheduling"
#endif
#ifdef USE_GRAPHS_GEMM
WRITE(NOUT,'(A)') " - graph-based GEMM scheduling"
#endif
Expand Down

0 comments on commit f6b40d6

Please sign in to comment.