Commit 3afb7d9: sampling

abcdabcd987 committed Apr 28, 2024
1 parent a9b4fe7
Showing 8 changed files with 373 additions and 12 deletions.

csrc/flashinfer_adapter/flashinfer_all.cu (0 additions & 10 deletions)
@@ -8,7 +8,6 @@
#include "flashinfer/decode_attention_decl.cuh"
#include "flashinfer/page.cuh"
#include "flashinfer/prefill_attention_decl.cuh"
#include "flashinfer_config.h"
#include "generated/dispatch.inc"

#define _DISPATCH_SWITCH(cond, ...) \
@@ -35,15 +34,6 @@
#define DISPATCH_head_dim(expr, ...) \
  _DISPATCH_SWITCH(expr, _DISPATCH_CASES_head_dim(__VA_ARGS__))

-namespace {
-template <typename T>
-inline T* alloc_from_buf(void** buf, int n) {
-  auto* p = (T*)*buf;
-  *buf = (void*)(p + n);
-  return p;
-}
-}  // namespace

template <typename T>
bool FlashInferBatchPrefillKernel(T* o, T* q, int32_t* qo_indptr, T** kv_ptrs,
                                  int32_t* kv_indptr, int32_t* last_page_offset,

csrc/flashinfer_adapter/flashinfer_decl.h (14 additions & 0 deletions)
@@ -25,3 +25,17 @@ bool FlashInferAppendKvKernel(T** kv_ptrs, int32_t* kv_indptr,
                              int32_t* last_page_offset, T* key, T* value,
                              int head_dim, int num_kv_heads, int page_size,
                              int batch_size);

template <typename T>
cudaError_t FlashInferSampleProb(T* probs, T* uniform_samples, int32_t* output,
                                 int batch_size, int dim);

template <typename T>
cudaError_t FlashInferSampleTopK(T* probs, T* uniform_samples, int32_t* output,
                                 int batch_size, int dim, int32_t k,
                                 int max_rounds);

template <typename T>
cudaError_t FlashInferSampleTopP(T* probs, T* uniform_samples, int32_t* output,
                                 int batch_size, int dim, float p,
                                 int max_rounds);

csrc/flashinfer_adapter/flashinfer_sampling.cu (55 additions & 0 deletions)
@@ -0,0 +1,55 @@
#include <cuda_bf16.h>
#include <cuda_fp16.h>

#include <cstdint>

#include "flashinfer/sampling.cuh"

template <typename T>
cudaError_t FlashInferSampleProb(T* probs, T* uniform_samples, int32_t* output,
                                 int batch_size, int dim) {
  return flashinfer::sampling::SamplingFromProb<T, int32_t>(
      probs, uniform_samples, output, batch_size, dim);
}
#define INST_FlashInferSampleProb(T)                                           \
  template cudaError_t FlashInferSampleProb<T>(T * probs, T * uniform_samples, \
                                               int32_t * output,               \
                                               int batch_size, int dim);
INST_FlashInferSampleProb(float);
INST_FlashInferSampleProb(nv_half);
INST_FlashInferSampleProb(nv_bfloat16);

template <typename T>
cudaError_t FlashInferSampleTopK(T* probs, T* uniform_samples, int32_t* output,
                                 int batch_size, int dim, int32_t k,
                                 int max_rounds) {
  bool* success = nullptr;  // Pass nullptr to skip the success check.
  auto status = flashinfer::sampling::TopKSamplingFromProb<T, int32_t>(
      probs, uniform_samples, output, success, k, batch_size, dim, max_rounds);
  return status;
}
#define INST_FlashInferSampleTopK(T)                                    \
  template cudaError_t FlashInferSampleTopK<T>(                         \
      T * probs, T * uniform_samples, int32_t * output, int batch_size, \
      int dim, int32_t k, int max_rounds);
INST_FlashInferSampleTopK(float);
INST_FlashInferSampleTopK(nv_half);
INST_FlashInferSampleTopK(nv_bfloat16);

template <typename T>
cudaError_t FlashInferSampleTopP(T* probs, T* uniform_samples, int32_t* output,
                                 int batch_size, int dim, float p,
                                 int max_rounds) {
  bool* success = nullptr;  // Pass nullptr to skip the success check.
  auto status = flashinfer::sampling::TopPSamplingFromProb<T, int32_t>(
      probs, uniform_samples, output, success, static_cast<T>(p), batch_size,
      dim, max_rounds);
  return status;
}
#define INST_FlashInferSampleTopP(T)                                    \
  template cudaError_t FlashInferSampleTopP<T>(                         \
      T * probs, T * uniform_samples, int32_t * output, int batch_size, \
      int dim, float p, int max_rounds);
INST_FlashInferSampleTopP(float);
INST_FlashInferSampleTopP(nv_half);
INST_FlashInferSampleTopP(nv_bfloat16);
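
For intuition, here is a minimal PyTorch sketch of the distributions these
top-k/top-p kernels draw from. The CUDA kernels consume uniform_samples over
up to max_rounds rejection rounds instead of sorting, so this is a semantic
reference only, and the ref_* names are illustrative:

import torch

def ref_sample_topk(probs: torch.Tensor, k: int) -> torch.Tensor:
    # Restrict to the k largest probabilities per row, renormalize, sample.
    vals, idx = probs.topk(k, dim=-1)
    choice = torch.multinomial(vals.float() / vals.float().sum(-1, keepdim=True), 1)
    return idx.gather(-1, choice).squeeze(-1).to(torch.int32)

def ref_sample_topp(probs: torch.Tensor, p: float) -> torch.Tensor:
    # Keep the smallest prefix of descending probabilities whose mass >= p.
    sorted_probs, idx = probs.sort(dim=-1, descending=True)
    exclusive = sorted_probs.cumsum(-1) - sorted_probs
    kept = torch.where(exclusive < p, sorted_probs, torch.zeros_like(sorted_probs))
    choice = torch.multinomial(kept.float() / kept.float().sum(-1, keepdim=True), 1)
    return idx.gather(-1, choice).squeeze(-1).to(torch.int32)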

csrc/punica_ops.cc (122 additions & 1 deletion)
@@ -6,7 +6,7 @@
#include <cstdint>

#include "bgmv/bgmv_config.h"
#include "flashinfer_adapter/flashinfer_config.h"
#include "flashinfer_adapter/flashinfer_decl.h"
#include "rms_norm/rms_norm.h"
#include "sgmv/sgmv.h"
#include "sgmv_flashinfer/sgmv_config.h"
@@ -80,6 +80,14 @@ torch::Tensor GetLayerKvPtrs(torch::Tensor kv_ptrs, int num_layers,
#define DISPATCH_TORCH_DTYPE(scalar_type, ...) \
  _DISPATCH_SWITCH(scalar_type, _DISPATCH_DTYPE_CASES(__VA_ARGS__))

#define _DISPATCH_DTYPE_WITH_FP32_CASES(...)                        \
  _DISPATCH_DTYPE_CASE(at::ScalarType::Float, float, __VA_ARGS__)   \
  _DISPATCH_DTYPE_CASE(at::ScalarType::Half, nv_half, __VA_ARGS__)  \
  _DISPATCH_DTYPE_CASE(at::ScalarType::BFloat16, nv_bfloat16, __VA_ARGS__)

#define DISPATCH_TORCH_DTYPE_WITH_FP32(scalar_type, ...) \
  _DISPATCH_SWITCH(scalar_type, _DISPATCH_DTYPE_WITH_FP32_CASES(__VA_ARGS__))

//====== flashinfer ======

void batch_prefill(torch::Tensor o, torch::Tensor q, torch::Tensor qo_indptr,
@@ -422,6 +430,115 @@ void dispatch_rms_norm(torch::Tensor output, torch::Tensor input,
" columns=", columns);
}

//====== sampling ======

int GetBatchSize(torch::Tensor x) {
  auto shape = x.sizes();
  switch (x.dim()) {
    case 0:
      return 0;
    case 1:
      return 1;
    case 2:
      return shape[0];
    case 3:
      return shape[0] * shape[1];
    case 4:
      return shape[0] * shape[1] * shape[2];
    default: {
      int batch_size = 1;
      for (size_t i = 0; i < shape.size() - 1; ++i) {
        batch_size *= shape[i];
      }
      return batch_size;
    }
  }
}

void sample_prob(torch::Tensor probs, torch::Tensor uniform_samples,
                 torch::Tensor output) {
  CHECK_INPUT(probs);
  CHECK_INPUT(uniform_samples);
  CHECK_INPUT(output);

  int batch_size = GetBatchSize(probs);
  int dim = probs.size(-1);
  int max_rounds = uniform_samples.numel() / batch_size;
  CHECK_GE(max_rounds, 32);
  CHECK_EQ(output.numel(), batch_size);
  CHECK_EQ(probs.scalar_type(), uniform_samples.scalar_type());

  cudaError_t status;
  bool dispatched = DISPATCH_TORCH_DTYPE_WITH_FP32(probs.scalar_type(), [&] {
    status =
        FlashInferSampleProb(static_cast<c_type*>(probs.data_ptr()),
                             static_cast<c_type*>(uniform_samples.data_ptr()),
                             output.data_ptr<int32_t>(), batch_size, dim);
    return true;
  });

  TORCH_CHECK(dispatched, "No suitable kernel.",
              " dtype=", probs.scalar_type());
  TORCH_CHECK(status == cudaSuccess,
              "Error in sample_prob. status=", cudaGetErrorString(status));
}

void sample_topk(torch::Tensor probs, torch::Tensor uniform_samples,
                 torch::Tensor output, int k) {
  CHECK_INPUT(probs);
  CHECK_INPUT(uniform_samples);
  CHECK_INPUT(output);

  int batch_size = GetBatchSize(probs);
  int dim = probs.size(-1);
  int max_rounds = uniform_samples.numel() / batch_size;
  CHECK_GE(max_rounds, 32);
  CHECK_EQ(output.numel(), batch_size);
  CHECK_EQ(probs.scalar_type(), uniform_samples.scalar_type());

  cudaError_t status;
  bool dispatched = DISPATCH_TORCH_DTYPE_WITH_FP32(probs.scalar_type(), [&] {
    status = FlashInferSampleTopK(
        static_cast<c_type*>(probs.data_ptr()),
        static_cast<c_type*>(uniform_samples.data_ptr()),
        output.data_ptr<int32_t>(), batch_size, dim, k, max_rounds);
    return true;
  });

  TORCH_CHECK(dispatched, "No suitable kernel.",
              " dtype=", probs.scalar_type());
  TORCH_CHECK(status == cudaSuccess,
              "Error in sample_topk. status=", cudaGetErrorString(status));
}

void sample_topp(torch::Tensor probs, torch::Tensor uniform_samples,
                 torch::Tensor output, float p) {
  CHECK_INPUT(probs);
  CHECK_INPUT(uniform_samples);
  CHECK_INPUT(output);

  int batch_size = GetBatchSize(probs);
  int dim = probs.size(-1);
  int max_rounds = uniform_samples.numel() / batch_size;
  CHECK_GE(max_rounds, 32);
  CHECK_EQ(output.numel(), batch_size);
  CHECK_EQ(probs.scalar_type(), uniform_samples.scalar_type());

  cudaError_t status;
  bool dispatched = DISPATCH_TORCH_DTYPE_WITH_FP32(probs.scalar_type(), [&] {
    status = FlashInferSampleTopP(
        static_cast<c_type*>(probs.data_ptr()),
        static_cast<c_type*>(uniform_samples.data_ptr()),
        output.data_ptr<int32_t>(), batch_size, dim, p, max_rounds);
    return true;
  });

  TORCH_CHECK(dispatched, "No suitable kernel.",
              " dtype=", probs.scalar_type());
  TORCH_CHECK(status == cudaSuccess,
              "Error in sample_topp. status=", cudaGetErrorString(status));
}

} // namespace

//====== pybind ======
@@ -438,4 +555,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("sgmv_cutlass_tmp_size", &sgmv_tmp_size, "");
m.def("sgmv_shrink", &dispatch_sgmv_shrink, "");
m.def("rms_norm", &dispatch_rms_norm, "");

m.def("sample_prob", &sample_prob, "");
m.def("sample_topk", &sample_topk, "");
m.def("sample_topp", &sample_topp, "");
}
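
On the Python side, the bindings treat every leading dimension as batch: as
GetBatchSize above shows, the batch size is the product of all dimensions
except the last. A sketch of the equivalent computation:

import math

def batch_size(shape) -> int:
    # Mirrors GetBatchSize: 0-d tensors get batch size 0; otherwise the
    # product of all leading dimensions (empty product == 1 for 1-d input).
    if len(shape) == 0:
        return 0
    return math.prod(shape[:-1])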

setup.py (1 addition & 0 deletions)
@@ -140,6 +140,7 @@ def generate_build_meta() -> None:
"csrc/punica_ops.cc",
"csrc/bgmv/bgmv_all.cu",
"csrc/flashinfer_adapter/flashinfer_all.cu",
"csrc/flashinfer_adapter/flashinfer_sampling.cu",
"csrc/rms_norm/rms_norm_cutlass.cu",
"csrc/sgmv/sgmv_cutlass.cu",
"csrc/sgmv_flashinfer/sgmv_all.cu",

src/punica/ops/__init__.py (61 additions & 0 deletions)
@@ -352,3 +352,64 @@ def rms_norm(
    o = torch.empty_like(x)
    _kernels.rms_norm(o, x, w, eps)
    return o


def sample_prob(
    prob: torch.Tensor,
    uniform_samples: torch.Tensor,
) -> torch.Tensor:
    """
    Sample from a probability distribution.

    Args:
      prob: Shape: `[...B, d]`. Probability distribution.
      uniform_samples: Uniform random samples in `[0, 1)`, same dtype as
        `prob`. At least `B*32` elements.

    Returns:
      Shape: `[...B]`. DType: int32. Sampled indices.
    """
    o = torch.empty(prob.shape[:-1], dtype=torch.int32, device=prob.device)
    _kernels.sample_prob(prob, uniform_samples, o)
    return o
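
A usage sketch (shapes and dtype are illustrative; both tensors must be
contiguous CUDA tensors of the same dtype):

import torch
from punica.ops import sample_prob

B, d = 4, 32000  # batch size and vocabulary size, chosen for illustration
prob = torch.softmax(
    torch.randn(B, d, device="cuda", dtype=torch.float16), dim=-1)
uniform = torch.rand(B * 32, device="cuda", dtype=torch.float16)
token_ids = sample_prob(prob, uniform)  # shape [B], dtype torch.int32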


def sample_topk(
    prob: torch.Tensor,
    uniform_samples: torch.Tensor,
    k: int,
) -> torch.Tensor:
    """
    Sample from the top-k elements of a probability distribution.

    Args:
      prob: Shape: `[...B, d]`. Probability distribution.
      uniform_samples: Uniform random samples in `[0, 1)`, same dtype as
        `prob`. At least `B*32` elements.
      k: Number of top-probability elements to sample from.

    Returns:
      Shape: `[...B]`. DType: int32. Sampled indices.
    """
    o = torch.empty(prob.shape[:-1], dtype=torch.int32, device=prob.device)
    _kernels.sample_topk(prob, uniform_samples, o, k)
    return o
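
The same pattern, with `k` limiting sampling to the highest-probability
entries (illustrative values):

import torch
from punica.ops import sample_topk

B, d = 4, 32000
prob = torch.softmax(
    torch.randn(B, d, device="cuda", dtype=torch.float16), dim=-1)
uniform = torch.rand(B * 32, device="cuda", dtype=torch.float16)
token_ids = sample_topk(prob, uniform, k=40)  # shape [B], dtype torch.int32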


def sample_topp(
    prob: torch.Tensor,
    uniform_samples: torch.Tensor,
    p: float,
) -> torch.Tensor:
    """
    Sample from the top-p (nucleus) elements of a probability distribution.

    Args:
      prob: Shape: `[...B, d]`. Probability distribution.
      uniform_samples: Uniform random samples in `[0, 1)`, same dtype as
        `prob`. At least `B*32` elements.
      p: Cumulative probability threshold.

    Returns:
      Shape: `[...B]`. DType: int32. Sampled indices.
    """
    o = torch.empty(prob.shape[:-1], dtype=torch.int32, device=prob.device)
    _kernels.sample_topp(prob, uniform_samples, o, p)
    return o
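
And for nucleus sampling, with `p` as the cumulative-mass cutoff
(illustrative values):

import torch
from punica.ops import sample_topp

B, d = 4, 32000
prob = torch.softmax(
    torch.randn(B, d, device="cuda", dtype=torch.float16), dim=-1)
uniform = torch.rand(B * 32, device="cuda", dtype=torch.float16)
token_ids = sample_topp(prob, uniform, p=0.9)  # shape [B], dtype torch.int32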