Fix CUDA arch support for DeepEP #71481

Open · wants to merge 1 commit into base: develop
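The diff below wraps Hopper-only PTX (acquire/release scopes and L1::no_allocate loads/stores) in __CUDA_ARCH__ >= 900 guards so utils.cuh still compiles when the build also targets older architectures. The following is a minimal sketch of that pattern, kept outside the actual diff: the function name ld_acquire_sys_global_sketch and the plain-CUDA fallback path are illustrative assumptions only; the PR itself simply compiles the guarded PTX out and adds no alternate path.

#include <cuda_runtime.h>

// Sketch only: guard sm_90+ PTX behind __CUDA_ARCH__ so the same translation
// unit still builds when compiled for older GPU architectures.
__device__ __forceinline__ int ld_acquire_sys_global_sketch(const int *ptr) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
  int ret;
  // sm_90+: system-scope acquire load, the same PTX the diff guards.
  asm volatile("ld.acquire.sys.global.s32 %0, [%1];" : "=r"(ret) : "l"(ptr));
  return ret;
#else
  // Assumed fallback for this sketch only: a volatile load followed by a
  // system-scope fence approximates acquire ordering. The PR does not add
  // such a path; it only disables the PTX body on pre-sm_90 targets.
  int ret = *reinterpret_cast<const volatile int *>(ptr);
  __threadfence_system();
  return ret;
#endif
}

__global__ void acquire_demo(const int *flag, int *out) {
  *out = ld_acquire_sys_global_sketch(flag);
}

int main() {
  int *flag = nullptr, *out = nullptr;
  cudaMalloc(&flag, sizeof(int));
  cudaMalloc(&out, sizeof(int));
  cudaMemset(flag, 0, sizeof(int));
  acquire_demo<<<1, 1>>>(flag, out);
  cudaDeviceSynchronize();
  cudaFree(flag);
  cudaFree(out);
  return 0;
}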
47 changes: 47 additions & 0 deletions paddle/fluid/distributed/collective/deep_ep/kernels/utils.cuh
@@ -53,45 +53,61 @@ __device__ __forceinline__ void trap() {
}

__device__ __forceinline__ void memory_fence() {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
asm volatile("fence.acq_rel.sys;":: : "memory");
#endif
}

__device__ __forceinline__ void memory_fence_gpu() {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
asm volatile("fence.acq_rel.gpu;":: : "memory");
#endif
}

__device__ __forceinline__ void memory_fence_cta() {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
asm volatile("fence.acq_rel.cta;":: : "memory");
#endif
}

__device__ __forceinline__ void st_relaxed_sys_global(const int *ptr, int val) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
asm volatile("st.relaxed.sys.global.s32 [%0], %1;"::"l"(ptr), "r"(val) : "memory");
#endif
}

__device__ __forceinline__ void st_release_sys_global(const int *ptr, int val) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
asm volatile("st.release.sys.global.s32 [%0], %1;"::"l"(ptr), "r"(val) : "memory");
#endif
}

__device__ __forceinline__ void st_release_cta(const int *ptr, int val) {
asm volatile("st.release.cta.s32 [%0], %1;"::"l"(ptr), "r"(val) : "memory");
}

__device__ __forceinline__ int ld_acquire_sys_global(const int *ptr) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
int ret;
asm volatile("ld.acquire.sys.global.s32 %0, [%1];" : "=r"(ret) : "l"(ptr));
return ret;
#endif
}

__device__ __forceinline__ uint64_t ld_acquire_sys_global(const uint64_t *ptr) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
uint64_t ret;
asm volatile("ld.acquire.sys.global.u64 %0, [%1];" : "=l"(ret) : "l"(ptr));
return ret;
#endif
}

__device__ __forceinline__ int ld_acquire_global(const int *ptr) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
int ret;
asm volatile("ld.acquire.gpu.global.s32 %0, [%1];" : "=r"(ret) : "l"(ptr));
return ret;
#endif
}

__device__ __forceinline__ int atomic_add_release_sys_global(const int* ptr, int value) {
@@ -113,27 +129,35 @@ __device__ __forceinline__ int ld_acquire_cta(const int *ptr) {
}

__device__ __forceinline__ uint8_t ld_na_relaxed(const uint8_t *ptr) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
uint16_t ret;
asm volatile("ld.relaxed.gpu.global.L1::no_allocate.b8 %0, [%1];" : "=h"(ret) : "l"(ptr));
return static_cast<uint8_t>(ret);
#endif
}

__device__ __forceinline__ uint16_t ld_na_relaxed(const uint16_t *ptr) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
uint16_t ret;
asm volatile("ld.relaxed.gpu.global.L1::no_allocate.b16 %0, [%1];" : "=h"(ret) : "l"(ptr));
return ret;
#endif
}

__device__ __forceinline__ uint32_t ld_na_relaxed(const uint32_t *ptr) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
uint32_t ret;
asm volatile("ld.relaxed.gpu.global.L1::no_allocate.b32 %0, [%1];" : "=r"(ret) : "l"(ptr));
return ret;
#endif
}

__device__ __forceinline__ uint64_t ld_na_relaxed(const uint64_t *ptr) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
uint64_t ret;
asm volatile("ld.relaxed.gpu.global.L1::no_allocate.b64 %0, [%1];" : "=l"(ret) : "l"(ptr));
return ret;
#endif
}

__device__ __forceinline__ int ld_volatile_global(const int *ptr) {
@@ -160,6 +184,11 @@ __device__ __forceinline__ int64_t ld_volatile_global(const uint64_t *ptr) {
return ret;
}

#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
#else
#define DISABLE_AGGRESSIVE_PTX_INSTRS
#endif

#ifndef DISABLE_AGGRESSIVE_PTX_INSTRS
#define LD_NC_FUNC "ld.global.nc.L1::no_allocate.L2::256B"
#else
@@ -220,36 +249,52 @@ __device__ __forceinline__ int4 ld_nc_global(const int4 *ptr) {
}

__device__ __forceinline__ void st_na_relaxed(const uint8_t *ptr, uint8_t val) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
asm volatile("st.relaxed.gpu.global.L1::no_allocate.b8 [%0], %1;" : : "l"(ptr), "h"(static_cast<uint16_t>(val)));
#endif
}

__device__ __forceinline__ void st_na_relaxed(const uint16_t *ptr, uint16_t val) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
asm volatile("st.relaxed.gpu.global.L1::no_allocate.b16 [%0], %1;" : : "l"(ptr), "h"(val));
#endif
}

__device__ __forceinline__ void st_na_relaxed(const uint32_t *ptr, uint32_t val) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
asm volatile("st.relaxed.gpu.global.L1::no_allocate.b32 [%0], %1;" : : "l"(ptr), "r"(val));
#endif
}

__device__ __forceinline__ void st_na_relaxed(const int *ptr, int val) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
asm volatile("st.relaxed.gpu.global.L1::no_allocate.b32 [%0], %1;" : : "l"(ptr), "r"(val));
#endif
}

__device__ __forceinline__ void st_na_relaxed(const int4 *ptr, int4 val) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
asm volatile("st.relaxed.gpu.global.L1::no_allocate.v4.s32 [%0], {%1, %2, %3, %4};"
: : "l"(ptr), "r"(val.x), "r"(val.y), "r"(val.z), "r"(val.w));
#endif
}

__device__ __forceinline__ void st_na_release(const int *ptr, int val) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
asm volatile("st.release.gpu.global.L1::no_allocate.b32 [%0], %1;" : : "l"(ptr), "r"(val));
#endif
}

__device__ __forceinline__ void st_na_release(const uint32_t *ptr, uint32_t val) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
asm volatile("st.release.gpu.global.L1::no_allocate.b32 [%0], %1;" : : "l"(ptr), "r"(val));
#endif
}

__device__ __forceinline__ void st_na_release(const uint64_t *ptr, uint64_t val) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
asm volatile("st.release.gpu.global.L1::no_allocate.b64 [%0], %1;" : : "l"(ptr), "l"(val));
#endif
}

// `st.global.L1::no_allocate` will be translated into `ST.E.NA.[width]` in SASS,
@@ -385,6 +430,7 @@ timeout_check(int **task_fifo_ptrs, int head, int rank, int expected, int tag =
template <int kNumRanks>
__forceinline__ __device__ void
barrier_device(int **task_fifo_ptrs, int head, int rank, int tag = 0) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
auto thread_id = static_cast<int>(threadIdx.x);
EP_DEVICE_ASSERT(kNumRanks <= 32);

@@ -394,6 +440,7 @@ barrier_device(int **task_fifo_ptrs, int head, int rank, int tag = 0) {
atomicSub_system(task_fifo_ptrs[thread_id] + head + rank, FINISHED_SUM_TAG);
}
timeout_check<kNumRanks>(task_fifo_ptrs, head, rank, 0, tag);
#endif
}

} // namespace deep_ep
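For context on the DISABLE_AGGRESSIVE_PTX_INSTRS toggle added above: when __CUDA_ARCH__ is below 900 (or undefined), the new guard defines the macro, and the pre-existing #ifndef then swaps LD_NC_FUNC away from the L1::no_allocate variant. A self-contained illustration of that selection follows; the fallback string "ld.volatile.global" matches upstream DeepEP, but the corresponding #else branch is collapsed in the diff above, so treat it as an assumption rather than a quotation of this file.

// Illustration of the compile-time selection the new guard feeds into.
// When __CUDA_ARCH__ is below 900 or undefined (e.g. the host pass),
// DISABLE_AGGRESSIVE_PTX_INSTRS is defined and LD_NC_FUNC degrades to a
// plain volatile load.
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
#else
#define DISABLE_AGGRESSIVE_PTX_INSTRS
#endif

#ifndef DISABLE_AGGRESSIVE_PTX_INSTRS
#define LD_NC_FUNC "ld.global.nc.L1::no_allocate.L2::256B"
#else
// Assumed fallback (as in upstream DeepEP); the collapsed diff hides it here.
#define LD_NC_FUNC "ld.volatile.global"
#endif

__device__ __forceinline__ int ld_nc_global_sketch(const int *ptr) {
  int ret;
  asm volatile(LD_NC_FUNC ".s32 %0, [%1];" : "=r"(ret) : "l"(ptr));
  return ret;
}

__global__ void ld_nc_demo(const int *src, int *dst) {
  *dst = ld_nc_global_sketch(src);
}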