Change torchao quantization types from int to size_t and preface vars with "preferred_"

Differential Revision: D63873383

Pull Request resolved: #1041
keyan authored Oct 10, 2024
1 parent 0f6bae5 commit 76b6e36
Showing 13 changed files with 68 additions and 66 deletions.
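
The pattern repeated across the hunks below is twofold: helpers that return byte counts (activation_data_size, weight_data_size, get_activation_data_buffer_size, and their *_impl variants) now return size_t instead of int, and alignment fields and getters gain a "preferred_" prefix. As a hedged illustration of why size_t is the natural type for a byte count, here is a small standalone C++ snippet; it is not torchao code, and weight_bytes is a hypothetical stand-in for helpers such as weight_data_size:

#include <cstddef>
#include <cstdio>

// Hypothetical stand-in, not torchao code: returning int from a byte-count
// helper overflows once the result exceeds INT_MAX (about 2.1 GB), whereas
// size_t spans the platform's full object-size range.
size_t weight_bytes(size_t n, size_t k) {
  return n * k * sizeof(float);  // bytes for an n x k float weight matrix
}

int main() {
  // 40000 x 20000 floats = 3.2 GB: fine as size_t on a 64-bit target, but it
  // would wrap around if the helper returned int.
  std::printf("%zu bytes\n", weight_bytes(40000, 20000));
  return 0;
}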

@@ -181,7 +181,7 @@ void kernel_impl(
// The groupi_zero is only present if has_weight_zeros = true.

// Returns number of bytes required for weight_data
-int inline weight_data_size_impl(
+size_t inline weight_data_size_impl(
int n,
int k,
int group_size,
@@ -270,7 +270,7 @@ void prepare_weight_data_impl(

// Activation functions
template <bool has_weight_zeros>
-int torchao::kernels::cpu::aarch64::linear::
+size_t torchao::kernels::cpu::aarch64::linear::
channelwise_8bit_activation_groupwise_lowbit_weight_1x1x32_f32_neondot::
activation_data_size(int m, int k, int group_size) {
return torchao::kernels::cpu::aarch64::linear::
@@ -297,7 +297,7 @@ void torchao::kernels::cpu::aarch64::linear::

// Weight functions
template <int weight_nbit, bool has_weight_zeros>
-int torchao::kernels::cpu::aarch64::linear::
+size_t torchao::kernels::cpu::aarch64::linear::
channelwise_8bit_activation_groupwise_lowbit_weight_1x1x32_f32_neondot::
weight_data_size(int n, int k, int group_size) {
return torchao::kernels::cpu::aarch64::linear::

@@ -248,7 +248,7 @@ void kernel_impl(
// Prepares weight data for kernel_impl.

// Returns number of bytes required for weight_data
-int inline weight_data_size_impl(
+size_t inline weight_data_size_impl(
int n,
int k,
int group_size,
@@ -397,7 +397,7 @@ void prepare_weight_data_impl(

// Activation functions
template <bool has_weight_zeros>
-int torchao::kernels::cpu::aarch64::linear::
+size_t torchao::kernels::cpu::aarch64::linear::
channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16_f32_neondot::
activation_data_size(int m, int k, int group_size) {
return torchao::kernels::cpu::aarch64::linear::
@@ -424,7 +424,7 @@ void torchao::kernels::cpu::aarch64::linear::

// Weight functions
template <int weight_nbit, bool has_weight_zeros>
-int torchao::kernels::cpu::aarch64::linear::
+size_t torchao::kernels::cpu::aarch64::linear::
channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16_f32_neondot::
weight_data_size(int n, int k, int group_size) {
return torchao::kernels::cpu::aarch64::linear::

@@ -333,7 +333,7 @@ void kernel_impl(
// Prepares weight data for kernel_impl.

// Returns number of bytes required for weight_data
-int inline weight_data_size_impl(
+size_t inline weight_data_size_impl(
int n,
int k,
int group_size,
@@ -483,7 +483,7 @@ void prepare_weight_data_impl(

// Activation functions
template <bool has_weight_zeros>
-int torchao::kernels::cpu::aarch64::linear::
+size_t torchao::kernels::cpu::aarch64::linear::
channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot::
activation_data_size(int m, int k, int group_size) {
return torchao::kernels::cpu::aarch64::linear::
@@ -510,7 +510,7 @@ void torchao::kernels::cpu::aarch64::linear::

// Weight functions
template <int weight_nbit, bool has_weight_zeros>
-int torchao::kernels::cpu::aarch64::linear::
+size_t torchao::kernels::cpu::aarch64::linear::
channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot::
weight_data_size(int n, int k, int group_size) {
return torchao::kernels::cpu::aarch64::linear::

@@ -25,7 +25,7 @@ namespace channelwise_8bit_activation_prepare_activation_data_1xk_f32::
// The groupi_qvals_sum is only present if has_weight_zeros = true.

// Returns number of bytes required for activation_data
-int inline activation_data_size_impl(
+size_t inline activation_data_size_impl(
int m,
int k,
// Ignored if has_weight_zeros = false

13 changes: 7 additions & 6 deletions torchao/experimental/kernels/cpu/aarch64/linear/linear.h
@@ -9,13 +9,14 @@
#if defined(__aarch64__) || defined(__ARM_NEON)

#include <arm_neon.h>
+#include <stddef.h>

namespace torchao::kernels::cpu::aarch64::linear {

namespace channelwise_8bit_activation_groupwise_lowbit_weight_1x1x32_f32_neondot {

template <bool has_weight_zeros>
-int activation_data_size(int m, int k, int group_size);
+size_t activation_data_size(int m, int k, int group_size);

template <bool has_weight_zeros>
void prepare_activation_data(
@@ -28,7 +29,7 @@ void prepare_activation_data(
const float* activations);

template <int weight_nbit, bool has_weight_zeros>
-int weight_data_size(int n, int k, int group_size);
+size_t weight_data_size(int n, int k, int group_size);

template <int weight_nbit, bool has_weight_zeros>
void prepare_weight_data(
@@ -65,7 +66,7 @@ void kernel(
namespace channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16_f32_neondot {

template <bool has_weight_zeros>
-int activation_data_size(int m, int k, int group_size);
+size_t activation_data_size(int m, int k, int group_size);

template <bool has_weight_zeros>
void prepare_activation_data(
@@ -78,7 +79,7 @@ void prepare_activation_data(
const float* activations);

template <int weight_nbit, bool has_weight_zeros>
-int weight_data_size(int n, int k, int group_size);
+size_t weight_data_size(int n, int k, int group_size);

template <int weight_nbit, bool has_weight_zeros>
void prepare_weight_data(
@@ -115,7 +116,7 @@ void kernel(
namespace channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot {

template <bool has_weight_zeros>
-int activation_data_size(int m, int k, int group_size);
+size_t activation_data_size(int m, int k, int group_size);

template <bool has_weight_zeros>
void prepare_activation_data(
@@ -128,7 +129,7 @@ void prepare_activation_data(
const float* activations);

template <int weight_nbit, bool has_weight_zeros>
-int weight_data_size(int n, int k, int group_size);
+size_t weight_data_size(int n, int k, int group_size);

template <int weight_nbit, bool has_weight_zeros>
void prepare_weight_data(

@@ -24,12 +24,12 @@ UKernelConfig get_ukernel_config() {
config.nr = 8;
config.activation_data_size_fn =
&ukernel::activation_data_size<has_weight_zeros>;
-config.activation_data_alignment = 16; // size of neon register
+config.preferred_activation_data_alignment = 16; // size of neon register
config.prepare_activation_data_fn =
&ukernel::prepare_activation_data<has_weight_zeros>;
config.weight_data_size_fn =
&ukernel::weight_data_size<weight_nbit, has_weight_zeros>;
-config.weight_data_alignment = 16; // size of neon register
+config.preferred_weight_data_alignment = 16; // size of neon register
config.prepare_weight_data_fn =
&ukernel::prepare_weight_data<weight_nbit, has_weight_zeros>;
config.kernel_fn =
@@ -85,13 +85,13 @@ static void linear_8bit_act_xbit_weight(benchmark::State& state) {
// Pack test case weights
size_t packed_weight_data_size =
get_packed_weight_data_size(ukernel_config, n, k, group_size);
-size_t packed_weight_data_alignment =
-get_packed_weight_data_alignment(ukernel_config);
+size_t preferred_packed_weight_data_alignment =
+get_preferred_packed_weight_data_alignment(ukernel_config);

std::vector<std::unique_ptr<char[], void (*)(void*)>> packed_weight_data;
for (int i = 0; i < test_cases.size(); i++) {
packed_weight_data.emplace_back(torchao::make_aligned_byte_ptr(
-packed_weight_data_alignment, packed_weight_data_size));
+preferred_packed_weight_data_alignment, packed_weight_data_size));
pack_weight_data_operator(
ukernel_config,
pack_weight_data_tiling_params,
@@ -112,11 +112,11 @@ static void linear_8bit_act_xbit_weight(benchmark::State& state) {
m,
k,
group_size);
-size_t activation_data_buffer_alignment =
-get_activation_data_buffer_alignment(ukernel_config);
+size_t preferred_activation_data_buffer_alignment =
+get_preferred_activation_data_buffer_alignment(ukernel_config);

auto activation_data_buffer = torchao::make_aligned_byte_ptr(
-activation_data_buffer_alignment, activation_data_buffer_size);
+preferred_activation_data_buffer_alignment, activation_data_buffer_size);

auto output = std::vector<float>(m * n);
for (auto _ : state) {

@@ -17,7 +17,7 @@ class Linear8BitActXBitWeightOperator {
private:
torchao::aligned_byte_ptr packed_weight_data_{nullptr, nullptr};
int packed_weight_data_size_{0};
-int packed_weight_data_alignment_{0};
+int preferred_packed_weight_data_alignment_{0};

torchao::aligned_byte_ptr activation_data_buffer_{nullptr, nullptr};

@@ -107,13 +107,13 @@ class Linear8BitActXBitWeightOperator {
// Pack weight data
auto packed_weight_data_size =
get_packed_weight_data_size(ukernel_config_, n_, k_, group_size_);
-auto packed_weight_data_alignment =
-get_packed_weight_data_alignment(ukernel_config_);
+auto preferred_packed_weight_data_alignment =
+get_preferred_packed_weight_data_alignment(ukernel_config_);

packed_weight_data_size_ = packed_weight_data_size;
-packed_weight_data_alignment_ = packed_weight_data_alignment;
+preferred_packed_weight_data_alignment_ = preferred_packed_weight_data_alignment;
packed_weight_data_ = torchao::make_aligned_byte_ptr(
-packed_weight_data_alignment, packed_weight_data_size);
+preferred_packed_weight_data_alignment, packed_weight_data_size);

pack_weight_data_operator(
ukernel_config_,
@@ -136,7 +136,7 @@ class Linear8BitActXBitWeightOperator {
k_,
group_size_);
auto activation_data_buffer_alignment =
-get_activation_data_buffer_alignment(ukernel_config_);
+get_preferred_activation_data_buffer_alignment(ukernel_config_);
activation_data_buffer_ = torchao::make_aligned_byte_ptr(
activation_data_buffer_alignment, activation_data_buffer_size);

@@ -168,7 +168,7 @@ class Linear8BitActXBitWeightOperator {
k_,
group_size_);
auto activation_data_buffer_alignment =
-get_activation_data_buffer_alignment(ukernel_config_);
+get_preferred_activation_data_buffer_alignment(ukernel_config_);
activation_data_buffer_ = torchao::make_aligned_byte_ptr(
activation_data_buffer_alignment, activation_data_buffer_size);
}

@@ -34,12 +34,12 @@ UKernelConfig get_ukernel_config() {
config.nr = 8;
config.activation_data_size_fn =
&ukernel::activation_data_size<has_weight_zeros>;
-config.activation_data_alignment = 16; // size of neon register
+config.preferred_activation_data_alignment = 16; // size of neon register
config.prepare_activation_data_fn =
&ukernel::prepare_activation_data<has_weight_zeros>;
config.weight_data_size_fn =
&ukernel::weight_data_size<weight_nbit, has_weight_zeros>;
-config.weight_data_alignment = 16; // size of neon register
+config.preferred_weight_data_alignment = 16; // size of neon register
config.prepare_weight_data_fn =
&ukernel::prepare_weight_data<weight_nbit, has_weight_zeros>;
config.kernel_fn =
@@ -67,10 +67,10 @@ torchao::aligned_byte_ptr pack_weight_data_operator(

auto packed_weight_data_size =
get_packed_weight_data_size(ukernel_config, n, k, group_size);
-auto packed_weight_data_alignment =
-get_packed_weight_data_alignment(ukernel_config);
+auto preferred_packed_weight_data_alignment =
+get_preferred_packed_weight_data_alignment(ukernel_config);
auto packed_weight_data = torchao::make_aligned_byte_ptr(
-packed_weight_data_alignment, packed_weight_data_size);
+preferred_packed_weight_data_alignment, packed_weight_data_size);

pack_weight_data_operator(
ukernel_config,
@@ -118,7 +118,7 @@ void linear_operator(
auto activation_data_buffer_size = get_activation_data_buffer_size(
ukernel_config, tiling_params_, scheduling_policy_, m, k, group_size);
auto activation_data_buffer_alignment =
-get_activation_data_buffer_alignment(ukernel_config);
+get_preferred_activation_data_buffer_alignment(ukernel_config);
auto activation_data_buffer = torchao::make_aligned_byte_ptr(
activation_data_buffer_alignment, activation_data_buffer_size);


@@ -34,12 +34,12 @@ UKernelConfig get_ukernel_config() {
config.nr = 8;
config.activation_data_size_fn =
&ukernel::activation_data_size<has_weight_zeros>;
-config.activation_data_alignment = 16; // size of neon register
+config.preferred_activation_data_alignment = 16; // size of neon register
config.prepare_activation_data_fn =
&ukernel::prepare_activation_data<has_weight_zeros>;
config.weight_data_size_fn =
&ukernel::weight_data_size<weight_nbit, has_weight_zeros>;
-config.weight_data_alignment = 16; // size of neon register
+config.preferred_weight_data_alignment = 16; // size of neon register
config.prepare_weight_data_fn =
&ukernel::prepare_weight_data<weight_nbit, has_weight_zeros>;
config.kernel_fn =

@@ -117,7 +117,7 @@ LinearTilingParams get_default_linear_tiling_params(

namespace internal {

-inline int
+inline size_t
get_activation_data_buffer_size_with_tile_schedule_policy_single_mc_parallel_nc(
const UKernelConfig& ukernel_config,
const LinearTilingParams& tiling_params,
@@ -128,7 +128,7 @@ get_activation_data_buffer_size_with_tile_schedule_policy_single_mc_parallel_nc(
tiling_params.mc_by_mr * ukernel_config.mr, k, group_size);
}

-inline int
+inline size_t
get_activation_data_buffer_size_with_tile_schedule_policy_parallel_mc_parallel_nc(
const UKernelConfig& ukernel_config,
const LinearTilingParams& tiling_params,
@@ -162,7 +162,7 @@ inline void linear_operator_with_tile_schedule_policy_single_mc_parallel_nc(
int nc = std::min(n, tiling_params.nc_by_nr * ukernel_config.nr);
int num_mc_panels = (m + mc - 1) / mc;
int num_nc_panels = (n + nc - 1) / nc;
-int weight_data_size = ukernel_config.weight_data_size_fn(nr, k, group_size);
+size_t weight_data_size = ukernel_config.weight_data_size_fn(nr, k, group_size);

for (int mc_tile_idx = 0; mc_tile_idx < num_mc_panels; mc_tile_idx++) {
int m_idx = mc_tile_idx * mc;
@@ -223,8 +223,8 @@ inline void linear_operator_with_tile_schedule_policy_parallel_mc_parallel_nc(
int num_mc_panels = (m + mc - 1) / mc;
int num_nc_panels = (n + nc - 1) / nc;

-int weight_data_size = ukernel_config.weight_data_size_fn(nr, k, group_size);
-int activation_data_size =
+size_t weight_data_size = ukernel_config.weight_data_size_fn(nr, k, group_size);
+size_t activation_data_size =
ukernel_config.activation_data_size_fn(mr, k, group_size);

torchao::parallel_1d(0, num_mc_panels, [&](int64_t idx) {
@@ -332,7 +332,7 @@ void linear_operator(
}
}

-int get_activation_data_buffer_size(
+size_t get_activation_data_buffer_size(
const UKernelConfig& ukernel_config,
const LinearTilingParams& tiling_params,
LinearTileSchedulingPolicy scheduling_policy,
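
The "preferred_" prefix on the renamed alignment getters suggests the returned alignment is an allocation hint rather than a hard requirement, and the benchmark and operator hunks above pass it directly to torchao::make_aligned_byte_ptr together with a size_t byte count. That helper's definition is not part of this diff; the sketch below is a plausible stand-in, written only to make the calling convention concrete (C++17 std::aligned_alloc; make_aligned_byte_ptr_sketch is a hypothetical name):

#include <cstddef>
#include <cstdlib>
#include <memory>

// Hypothetical stand-in for torchao::make_aligned_byte_ptr, for illustration
// only; the real implementation is not shown in this diff.
using aligned_byte_ptr = std::unique_ptr<char[], void (*)(void*)>;

inline aligned_byte_ptr make_aligned_byte_ptr_sketch(size_t alignment, size_t nbytes) {
  // std::aligned_alloc requires the size to be a multiple of the alignment,
  // so round the request up before allocating.
  size_t rounded = (nbytes + alignment - 1) / alignment * alignment;
  return aligned_byte_ptr(
      static_cast<char*>(std::aligned_alloc(alignment, rounded)), std::free);
}

// Usage mirroring the benchmark hunk above:
//   auto packed_weight_data = make_aligned_byte_ptr_sketch(
//       preferred_packed_weight_data_alignment, packed_weight_data_size);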