One bit kernels (#95)
* two-bit template redesign

* 1-bit kernel

* 1-bit dequant

* black

* 1.1.6dev

* 1.1.6
BlackSamorez authored May 31, 2024
1 parent 026b97d commit bf7d348
Showing 4 changed files with 217 additions and 138 deletions.
2 changes: 1 addition & 1 deletion inference_lib/setup.cfg
@@ -1,6 +1,6 @@
[metadata]
name = aqlm
-version = 1.1.5
+version = 1.1.6
author = AQLM paper authors
author_email = [email protected]
description = Efficiently run models quantized with AQLM
183 changes: 137 additions & 46 deletions inference_lib/src/aqlm/inference_kernels/cuda_kernel.cpp
@@ -24,33 +24,43 @@ inline bool check_use_bfloat16(const torch::Tensor& input) {
}
}

-void code1x16_matvec_cuda(
+
+template <bool use_bfloat16, size_t group_size>
+void code1x16_matvec_cuda(
const void* A,
const void* B,
void* C,
const void* codebook,
int prob_m,
-int prob_k,
-bool use_bfloat16
+int prob_k
);
+extern template void code1x16_matvec_cuda<false, 8>(const void*, const void*, void*, const void*, int, int);
+extern template void code1x16_matvec_cuda<true, 8>(const void*, const void*, void*, const void*, int, int);
+extern template void code1x16_matvec_cuda<false, 16>(const void*, const void*, void*, const void*, int, int);
+extern template void code1x16_matvec_cuda<true, 16>(const void*, const void*, void*, const void*, int, int);

+template <size_t group_size>
void code1x16_dequant_cuda(
const void* A,
void* C,
const void* codebook,
int prob_m,
int prob_k
);
+extern template void code1x16_dequant_cuda<8>(const void*, void*, const void*, int, int);
+extern template void code1x16_dequant_cuda<16>(const void*, void*, const void*, int, int);

+template <bool use_bfloat16>
void code2x8_matvec_cuda(
const void* A,
const void* B,
void* C,
const void* codebook,
int prob_m,
-int prob_k,
-bool use_bfloat16
+int prob_k
);
+extern template void code2x8_matvec_cuda<false>(const void*, const void*, void*, const void*, int, int);
+extern template void code2x8_matvec_cuda<true>(const void*, const void*, void*, const void*, int, int);
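
The extern template declarations above tell the compiler that these instantiations are defined in another translation unit (presumably the accompanying .cu file), so this C++ wrapper only links against the pre-built (use_bfloat16, group_size) variants instead of passing dtype and group size at runtime. A minimal sketch of the pattern; my_kernel_cuda, kernels.h and kernels.cu are illustrative names, not the project's actual files:

// kernels.h -- declare the template and promise that instantiations exist elsewhere.
template <bool use_bfloat16, size_t group_size>
void my_kernel_cuda(const void* in, void* out, int n);
extern template void my_kernel_cuda<false, 8>(const void*, void*, int);
extern template void my_kernel_cuda<true, 16>(const void*, void*, int);

// kernels.cu -- define the template and explicitly instantiate the supported variants.
template <bool use_bfloat16, size_t group_size>
void my_kernel_cuda(const void* in, void* out, int n) {
  // launch a kernel specialized at compile time for (use_bfloat16, group_size)
}
template void my_kernel_cuda<false, 8>(const void*, void*, int);
template void my_kernel_cuda<true, 16>(const void*, void*, int);

Callers in the wrapper then pick a variant with ordinary template arguments, as the dispatch code in code1x16_matvec below does.
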

void code2x8_dequant_cuda(
const void* A,
@@ -89,15 +99,29 @@ void code1x16_matvec(
const at::cuda::OptionalCUDAGuard device_guard(device_of(A));
int prob_m = C.size(0);
int prob_k = B.size(0);
-  code1x16_matvec_cuda(
-    A.data_ptr(),
-    B.data_ptr(),
-    C.data_ptr(),
-    codebook.data_ptr(),
-    prob_m,
-    prob_k,
-    use_bfloat16
-  );
+
+  if (codebook.size(3) == 8) {
+    if (use_bfloat16) {
+      code1x16_matvec_cuda<true, 8>(A.data_ptr(), B.data_ptr(), C.data_ptr(), codebook.data_ptr(), prob_m, prob_k);
+    } else {
+      code1x16_matvec_cuda<false, 8>(A.data_ptr(), B.data_ptr(), C.data_ptr(), codebook.data_ptr(), prob_m, prob_k);
+    }
+  } else if (codebook.size(3) == 16) {
+    if (use_bfloat16) {
+      code1x16_matvec_cuda<true, 16>(A.data_ptr(), B.data_ptr(), C.data_ptr(), codebook.data_ptr(), prob_m, prob_k);
+    } else {
+      code1x16_matvec_cuda<false, 16>(A.data_ptr(), B.data_ptr(), C.data_ptr(), codebook.data_ptr(), prob_m, prob_k);
+    }
+  } else {
+    throw c10::NotImplementedError(
+      {__func__, __FILE__, static_cast<uint32_t>(__LINE__)},
+      c10::str(
+        "AQLM CUDA kernels only support codebooks with 8 or 16 features. Got ",
+        codebook.size(3),
+        "."
+      )
+    );
+  }
}
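
The wrapper now resolves both runtime properties into template arguments before launching: the dtype flag from check_use_bfloat16 and the group size read from codebook.size(3), i.e. the number of weights covered by each 16-bit code. Assuming a 2^16-entry codebook, that works out to 2 bits per weight for group size 8 and 1 bit per weight for group size 16, which is presumably what the 1-bit kernels in the commit title refer to. A small hypothetical helper (not part of the commit) that spells out the arithmetic:

// Average code bits per weight for a 1x16 AQLM layout, assuming one 16-bit
// code per group of in_group_size weights (8 -> 2.0, 16 -> 1.0).
inline double bits_per_weight_1x16(long in_group_size) {
  return 16.0 / static_cast<double>(in_group_size);
}
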

torch::Tensor code1x16_matmat(
@@ -142,21 +166,40 @@ torch::Tensor code1x16_dequant(
const torch::Tensor& scales
) {
check_use_bfloat16(codebooks);
-auto in_features = codes.size(1) * 8;
+auto in_features = codes.size(1) * codebooks.size(3);
auto out_features = scales.size(0);

auto weight = torch::empty({out_features, in_features},
torch::TensorOptions()
.dtype(codebooks.dtype())
.device(codebooks.device())
);
-  code1x16_dequant_cuda(
-    codes.data_ptr(),
-    weight.data_ptr(),
-    codebooks.data_ptr(),
-    out_features,
-    in_features
-  );
+  if (codebooks.size(3) == 8) {
+    code1x16_dequant_cuda<8>(
+      codes.data_ptr(),
+      weight.data_ptr(),
+      codebooks.data_ptr(),
+      out_features,
+      in_features
+    );
+  } else if (codebooks.size(3) == 16) {
+    code1x16_dequant_cuda<16>(
+      codes.data_ptr(),
+      weight.data_ptr(),
+      codebooks.data_ptr(),
+      out_features,
+      in_features
+    );
+  } else {
+    throw c10::NotImplementedError(
+      {__func__, __FILE__, static_cast<uint32_t>(__LINE__)},
+      c10::str(
+        "AQLM CUDA kernels only support codebooks with 8 or 16 features. Got ",
+        codebooks.size(3),
+        "."
+      )
+    );
+  }
weight *= scales.index({"...", 0, 0});

return weight;
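
code1x16_dequant no longer hard-codes 8 input features per code: the reconstructed width is codes.size(1) * codebooks.size(3), and each output row is then rescaled by its entry in scales. A hypothetical shape sketch with illustrative numbers, not taken from the commit:

// codebooks: [num_codebooks, codebook_size, out_group, in_group]  e.g. [1, 65536, 1, 16]
// codes:     [num_out_groups, num_in_groups]                      e.g. [4096, 256]
// in_features  = num_in_groups * in_group  = 256 * 16 = 4096
// weight:    [out_features, in_features], then scaled row-wise by scales
inline long dequant_in_features(long num_in_groups, long in_group_size) {
  return num_in_groups * in_group_size;
}
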
@@ -191,7 +234,7 @@ torch::Tensor code1x16_matmat_dequant(
) {
bool use_bfloat16 = check_use_bfloat16(input);
auto input_sizes = input.sizes();
-auto in_features = codes.size(1) * 8;
+auto in_features = codes.size(1) * codebooks.size(3);
auto out_features = codes.size(0) * codebooks.size(2);
auto flat_input = input.reshape({-1, input.size(-1)});

@@ -200,13 +243,32 @@
.dtype(codebooks.dtype())
.device(codebooks.device())
);
-  code1x16_dequant_cuda(
-    codes.data_ptr(),
-    weight.data_ptr(),
-    codebooks.data_ptr(),
-    out_features,
-    in_features
-  );
+  if (codebooks.size(3) == 8) {
+    code1x16_dequant_cuda<8>(
+      codes.data_ptr(),
+      weight.data_ptr(),
+      codebooks.data_ptr(),
+      out_features,
+      in_features
+    );
+  } else if (codebooks.size(3) == 16) {
+    code1x16_dequant_cuda<16>(
+      codes.data_ptr(),
+      weight.data_ptr(),
+      codebooks.data_ptr(),
+      out_features,
+      in_features
+    );
+  } else {
+    throw c10::NotImplementedError(
+      {__func__, __FILE__, static_cast<uint32_t>(__LINE__)},
+      c10::str(
+        "AQLM CUDA kernels only support codebooks with 8 or 16 features. Got ",
+        codebooks.size(3),
+        "."
+      )
+    );
+  }

auto flat_output = F::linear(flat_input, weight);
return scale_bias_unflatten_output(
@@ -235,13 +297,32 @@ torch::Tensor code1x16_matmat_dequant_transposed(
.dtype(codebooks.dtype())
.device(codebooks.device())
);
-  code1x16_dequant_cuda(
-    codes.data_ptr(),
-    weight.data_ptr(),
-    codebooks.data_ptr(),
-    out_features,
-    in_features
-  );
+  if (codebooks.size(3) == 8) {
+    code1x16_dequant_cuda<8>(
+      codes.data_ptr(),
+      weight.data_ptr(),
+      codebooks.data_ptr(),
+      out_features,
+      in_features
+    );
+  } else if (codebooks.size(3) == 16) {
+    code1x16_dequant_cuda<16>(
+      codes.data_ptr(),
+      weight.data_ptr(),
+      codebooks.data_ptr(),
+      out_features,
+      in_features
+    );
+  } else {
+    throw c10::NotImplementedError(
+      {__func__, __FILE__, static_cast<uint32_t>(__LINE__)},
+      c10::str(
+        "AQLM CUDA kernels only support codebooks with 8 or 16 features. Got ",
+        codebooks.size(3),
+        "."
+      )
+    );
+  }

torch::Tensor bias_2{};
if (bias.has_value()) {
@@ -261,15 +342,25 @@ void code2x8_matvec(
const at::cuda::OptionalCUDAGuard device_guard(device_of(A));
int prob_m = C.size(0);
int prob_k = B.size(0);
-  code2x8_matvec_cuda(
-    A.data_ptr(),
-    B.data_ptr(),
-    C.data_ptr(),
-    codebook.data_ptr(),
-    prob_m,
-    prob_k,
-    use_bfloat16
-  );
+  if (use_bfloat16) {
+    code2x8_matvec_cuda<true>(
+      A.data_ptr(),
+      B.data_ptr(),
+      C.data_ptr(),
+      codebook.data_ptr(),
+      prob_m,
+      prob_k
+    );
+  } else {
+    code2x8_matvec_cuda<false>(
+      A.data_ptr(),
+      B.data_ptr(),
+      C.data_ptr(),
+      codebook.data_ptr(),
+      prob_m,
+      prob_k
+    );
+  }
}
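
code2x8_matvec still accepts use_bfloat16 as a runtime flag at the Torch boundary, but, like the 1x16 path, it now collapses the flag into a compile-time template argument before the launch. A generic sketch of that boolean-to-template dispatch; dispatch_bool is a hypothetical helper, not part of the commit:

#include <type_traits>
#include <utility>

// Invoke f with std::integral_constant<bool, ...> so the callee sees a
// compile-time bool even though the caller only has a runtime flag.
template <typename F>
auto dispatch_bool(bool flag, F&& f) {
  if (flag) {
    return std::forward<F>(f)(std::integral_constant<bool, true>{});
  }
  return std::forward<F>(f)(std::integral_constant<bool, false>{});
}

// Usage sketch:
//   dispatch_bool(use_bfloat16, [&](auto bf16) {
//     code2x8_matvec_cuda<decltype(bf16)::value>(A.data_ptr(), B.data_ptr(), C.data_ptr(),
//                                                codebook.data_ptr(), prob_m, prob_k);
//   });
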

torch::Tensor code2x8_matmat(