From 12557d8dfa345be39c8e913bcc4181d6fac9b27a Mon Sep 17 00:00:00 2001 From: "y01000.you" Date: Thu, 31 Oct 2024 17:17:49 +0900 Subject: [PATCH 1/6] [luci-pass] Introduce GPTQ pass This pr introduces quantize weight with GPTQ. ONE-DCO-1.0-Signed-off-by: Banseok Lee --- .../luci/Pass/QuantizeWeightsWithGPTQPass.h | 82 +++ .../pass/src/QuantizeWeightsWithGPTQPass.cpp | 674 ++++++++++++++++++ .../src/QuantizeWeightsWithGPTQPass.test.cpp | 206 ++++++ 3 files changed, 962 insertions(+) create mode 100644 compiler/luci/pass/include/luci/Pass/QuantizeWeightsWithGPTQPass.h create mode 100644 compiler/luci/pass/src/QuantizeWeightsWithGPTQPass.cpp create mode 100644 compiler/luci/pass/src/QuantizeWeightsWithGPTQPass.test.cpp diff --git a/compiler/luci/pass/include/luci/Pass/QuantizeWeightsWithGPTQPass.h b/compiler/luci/pass/include/luci/Pass/QuantizeWeightsWithGPTQPass.h new file mode 100644 index 00000000000..afef945bf98 --- /dev/null +++ b/compiler/luci/pass/include/luci/Pass/QuantizeWeightsWithGPTQPass.h @@ -0,0 +1,82 @@ +/* + * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef __LUCI_QUANTIZE_WEIGHTS_WITH_GPTQ_PASS_H__ +#define __LUCI_QUANTIZE_WEIGHTS_WITH_GPTQ_PASS_H__ + +#include + +#include + +#include +#include +#include + +namespace luci +{ + +/** + * @brief Pass to quantize weights + */ +class QuantizeWeightsWithGPTQPass : public logo::Pass +{ +public: + struct Context + { + loco::DataType input_model_dtype = loco::DataType::Unknown; + loco::DataType output_model_dtype = loco::DataType::Unknown; + QuantizationGranularity granularity = QuantizationGranularity::ChannelWise; + std::vector layers_info; + }; + +public: + QuantizeWeightsWithGPTQPass(std::unique_ptr &&ctx) : _ctx{std::move(ctx)} + { + // DO NOTHING + } + + QuantizeWeightsWithGPTQPass( + std::unique_ptr &&ctx, + std::unordered_map> *hessian_map) + : _ctx{std::move(ctx)}, _hessian_map{hessian_map} + { + // DO NOTHING + } + +public: + QuantizeWeightsWithGPTQPass(loco::DataType input_model_dtype, loco::DataType output_model_dtype, + QuantizationGranularity granularity) + { + _ctx = std::make_unique(); + { + _ctx->input_model_dtype = input_model_dtype; + _ctx->output_model_dtype = output_model_dtype; + _ctx->granularity = granularity; + } + } + virtual const char *name(void) const { return "luci::QuantizeWeightsWithGPTQPass"; } + +public: + bool run(loco::Graph *graph); + +private: + std::unique_ptr _ctx; + std::unordered_map> *_hessian_map; +}; + +} // namespace luci + +#endif //__LUCI_QUANTIZE_WEIGHTS_WITH_GPTQ_PASS_H__ diff --git a/compiler/luci/pass/src/QuantizeWeightsWithGPTQPass.cpp b/compiler/luci/pass/src/QuantizeWeightsWithGPTQPass.cpp new file mode 100644 index 00000000000..cd9267117a0 --- /dev/null +++ b/compiler/luci/pass/src/QuantizeWeightsWithGPTQPass.cpp @@ -0,0 +1,674 @@ +/* + * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved + * Copyright 2019 The TensorFlow Authors. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Pass/QuantizeWeightsWithGPTQPass.h" +#include "QuantizationUtils.h" +#include "helpers/LayerInfoMap.h" + +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +namespace +{ + +using namespace luci; +using IterFunc = std::function; + +void iterate_per_channel_with_order(CircleConst *node, IterFunc func, bool reverse) +{ + loco::TensorShape dimension; + dimension.rank(4); + uint32_t indices[4] = {0}; + int32_t channel_dim_index{0}; + uint32_t num_dims[4]; + if (!get_channel_dim_index(node, dimension, channel_dim_index)) + { + throw std::runtime_error("Failed to get channel dim index."); + } + + auto order = reverse ? 
std::vector{3, 1, 2, 0} : std::vector{0, 1, 2, 3}; + + for (uint32_t i = 0; i < 4; ++i) + { + num_dims[i] = dimension.dim(order[i]).value(); + } + + for (uint32_t i = 0; i < num_dims[0]; i++) + { + for (uint32_t j = 0; j < num_dims[1]; j++) + { + for (uint32_t s = 0; s < num_dims[2]; s++) + { + for (uint32_t t = 0; t < num_dims[3]; t++) + { + indices[order[0]] = i; + indices[order[1]] = j; + indices[order[2]] = s; + indices[order[3]] = t; + func(indices, dimension, channel_dim_index); + } + } + } + } +} + +} // namespace + +namespace luci +{ + +namespace +{ + +size_t calculate_qauntized_value(CircleConst *node, uint32_t *indices, loco::TensorShape &dimension, + int channel_dim_index, std::vector &scaling_factor, + std::vector &max, std::vector &min) +{ + int channel_idx = indices[channel_dim_index]; + const float scaling_factor_inv = 1.0 / scaling_factor[channel_idx]; + auto data = node->at(cal_offset(dimension, indices)); + auto data_clipped = data < min[channel_idx] ? min[channel_idx] : data; + data_clipped = data_clipped > max[channel_idx] ? 
max[channel_idx] : data_clipped; + + return static_cast(std::round((data_clipped - min[channel_idx]) * scaling_factor_inv)); +} + +void cholesky_decomposition(std::vector &src, uint32_t num_size) +{ + for (uint32_t i = 0; i < num_size; i++) + { + for (uint32_t j = 0; j <= i; j++) + { + double sum = 0; + for (uint32_t k = 0; k < j; k++) + { + sum += src[i * num_size + k] * src[j * num_size + k]; + } + if (i == j) + { + if (src[i * num_size + i] - sum <= 0) + { + std::cout << "Error: Matrix is not positive definite.\n" << std::endl; + return; + } + src[i * num_size + i] = sqrt(src[i * num_size + i] - sum); + } + else + { + src[i * num_size + j] = (src[i * num_size + j] - sum) / src[j * num_size + j]; + } + } + } + for (uint32_t i = 0; i < num_size; i++) + { + for (uint32_t j = 0; j < num_size; j++) + { + if (i < j) + { + src[i * num_size + j] = 0.0; + } + } + } + return; +} + +void forward_substitution(const std::vector &L, const std::vector &b, + std::vector &y, int num_size) +{ + for (int i = 0; i < num_size; ++i) + { + y[i] = b[i]; + for (int j = 0; j < i; ++j) + { + y[i] -= L[i * num_size + j] * y[j]; + } + y[i] /= L[i * num_size + i]; + } +} + +void backward_substitution(const std::vector &U, const std::vector &y, + std::vector &x, int num_size) +{ + for (int i = num_size - 1; i >= 0; --i) + { + x[i] = y[i]; + for (int j = i + 1; j < num_size; ++j) + { + x[i] -= U[i * num_size + j] * x[j]; + } + x[i] /= U[i * num_size + i]; + } +} + +void cholesky_inverse(std::vector &L, uint32_t num_size) +{ + std::vector L_inv(L.size()); + std::vector H_inv(L.size()); + + std::vector e(num_size, 0); + std::vector col(num_size, 0); + std::vector temp(num_size, 0); + + for (uint32_t i = 0; i < num_size; ++i) + { + fill(e.begin(), e.end(), 0.0); + e[i] = 1.0; + + forward_substitution(L, e, temp, num_size); + + for (uint32_t j = 0; j < num_size; ++j) + { + L_inv[j * num_size + i] = temp[j]; + } + } + + for (uint32_t i = 0; i < num_size; i++) + { + for (uint32_t j = 0; j < i; j++) 
+ { + float tmp = L[i * num_size + j]; + L[i * num_size + j] = L[j * num_size + i]; + L[j * num_size + i] = tmp; + } + } + + for (uint32_t i = 0; i < num_size; ++i) + { + fill(e.begin(), e.end(), 0.0); + fill(col.begin(), col.end(), 0.0); + e[i] = 1.0; + for (uint32_t j = 0; j < num_size; j++) + { + col[j] = L_inv[j * num_size + i]; + } + backward_substitution(L, col, temp, num_size); + for (uint32_t j = 0; j < num_size; ++j) + { + H_inv[j * num_size + i] = temp[j]; + } + } + for (uint32_t i = 0; i < L.size(); i++) + { + L[i] = H_inv[i]; + } +} + +void cal_minmax_per_channel(CircleConst *node, std::vector &min, std::vector &max) +{ + loco::TensorShape dimension; + dimension.rank(4); + int32_t channel_dim_index{0}; + + if (!get_channel_dim_index(node, dimension, channel_dim_index)) + { + assert(false); + return; + } + auto size = dimension.dim(channel_dim_index).value(); + + std::vector has_min_max_value(size, false); + min.resize(size); + max.resize(size); + + auto cal_minmax = [&](uint32_t *indices, loco::TensorShape &dimension, int channel_dim_index) { + int channel_idx = indices[channel_dim_index]; + auto data = node->at(cal_offset(dimension, indices)); + if (has_min_max_value[channel_idx]) + { + min[channel_idx] = data < min[channel_idx] ? data : min[channel_idx]; + max[channel_idx] = data > max[channel_idx] ? 
data : max[channel_idx]; + } + else + { + min[channel_idx] = data; + max[channel_idx] = data; + has_min_max_value[channel_idx] = true; + } + }; + + iterate_per_channel_with_order(node, cal_minmax, false); +} + +void compute_asym_scale_zp(float min, float max, float &scaling_factor, int64_t &zp, + float &nudged_min, float &nudged_max, int32_t k_max_scale) +{ + LOGGER(l); + + assert(min <= max); + const int32_t kMinScale = 0; + const int32_t kMaxScale = k_max_scale; + const double qmin_double = kMinScale; + const double qmax_double = kMaxScale; + const double rmin = std::fmin(0, min); + const double rmax = std::fmax(0, max); + + double scale = (rmax - rmin) / (qmax_double - qmin_double); + double zero_point_double = 0; + uint8_t nudged_zero_point = 0; + if (scale == 0) + { + WARN(l) << "The minimum and maximum values are the same." << std::endl; + if (min >= 0 && max >= 0) + zero_point_double = kMinScale; + else + zero_point_double = kMaxScale; + } + else + zero_point_double = qmin_double - rmin / scale; + if (min >= 0) + { + assert(min >= 0 && max >= 0); + nudged_zero_point = kMinScale; + scale = max / (qmax_double - qmin_double); + if (min > 0 && max > 0) + WARN(l) << "The minimum and maximum values are all positive." << std::endl; + } + else if (max < 0) + { + assert(min < 0 && max < 0); + nudged_zero_point = kMaxScale; + scale = -min / (qmax_double - qmin_double); + WARN(l) << "The minimum and maximum values are all negative." 
<< std::endl; + } + else + { + assert(min < 0 && max >= 0); + nudged_zero_point = fp32_to_uint8_cast(std::round(zero_point_double)); + } + + // protect scale from being very low due to overflow + if (scale < 1e-5) + { + scale = 1e-5; + nudged_zero_point = fp32_to_uint8_cast(std::round(qmin_double - rmin / scale)); + } + + nudged_min = static_cast((qmin_double - nudged_zero_point) * scale); + nudged_max = static_cast((qmax_double - nudged_zero_point) * scale); + + scaling_factor = scale; + zp = nudged_zero_point; +} + +void asymmetric_wquant_per_channel(CircleConst *node, std::vector &min, + std::vector &max, std::vector &scaling_factor, + std::vector &zp, std::vector &nudged_min, + std::vector &nudged_max, loco::DataType output_type, + std::vector &hessian) +{ + assert(node->dtype() == loco::DataType::FLOAT32); + + IterFunc quantize; + + const int32_t kMinScale = 0; + const int32_t kMaxScale = output_type == loco::DataType::U4 ? 15 : 255; + + uint32_t size = node->size(); + std::vector quantized_values(size); + + for (size_t i = 0; i < min.size(); ++i) + { + compute_asym_scale_zp(min[i], max[i], scaling_factor[i], zp[i], nudged_min[i], nudged_max[i], + kMaxScale); + } + + if (hessian.empty()) // Cases where gptq is not applied + { + quantize = [&](uint32_t *indices, loco::TensorShape &dimension, int channel_dim_index) { + quantized_values[cal_offset(dimension, indices)] = calculate_qauntized_value( + node, indices, dimension, channel_dim_index, scaling_factor, nudged_max, nudged_min); + }; + iterate_per_channel_with_order(node, quantize, false); + } + else // Cases where gptq is applied + { + uint32_t size_hessian = static_cast(sqrt(hessian.size())); + float percdamp = .01; + float damp = 0; + + for (uint32_t i = 0; i < size_hessian; i++) + { + damp += hessian[i * size_hessian + i]; + } + damp /= size_hessian; + damp *= percdamp; + + for (uint32_t i = 0; i < size_hessian; i++) + { + hessian[i * size_hessian + i] += damp; + } + + // calculate hessian inverse + 
cholesky_decomposition(hessian, size_hessian); + cholesky_inverse(hessian, size_hessian); + cholesky_decomposition(hessian, size_hessian); + + // transpose hessian to make upper trangular + for (uint32_t i = 0; i < size_hessian; i++) + { + for (uint32_t j = 0; j < i; j++) + { + float tmp = hessian[i * size_hessian + j]; + hessian[i * size_hessian + j] = hessian[j * size_hessian + i]; + hessian[j * size_hessian + i] = tmp; + } + } + + std::vector error(size); + + loco::TensorShape dimension_channel_last; + dimension_channel_last.rank(4); + + loco::TensorShape dimension_hessian; + dimension_hessian.rank(2); + dimension_hessian.dim(0).set(size_hessian); + dimension_hessian.dim(1).set(size_hessian); + + quantize = [&](uint32_t *indices, loco::TensorShape &dimension, int channel_dim_index) { + quantized_values[cal_offset(dimension, indices)] = calculate_qauntized_value( + node, indices, dimension, channel_dim_index, scaling_factor, nudged_max, nudged_min); + + uint32_t indices_channel_last[4] = { + indices[0], indices[3], indices[1], indices[2] // ohwi -> oihw + }; + + uint32_t dimension_channel_last[4] = {dimension.dim(0).value(), dimension.dim(3).value(), + dimension.dim(1).value(), dimension.dim(2).value()}; + + uint32_t idx_quant_column = + dimension_channel_last[2] * dimension_channel_last[3] * indices_channel_last[1] + + dimension_channel_last[3] * indices_channel_last[2] + indices_channel_last[3]; + + uint32_t indices_diag_hessian[2] = {idx_quant_column, idx_quant_column}; + + uint32_t channel_idx = indices[channel_dim_index]; + auto data = node->at(cal_offset(dimension, indices)); + + error[cal_offset(dimension, indices)] = + (data - (quantized_values[cal_offset(dimension, indices)] - zp[channel_idx]) * + scaling_factor[channel_idx]) / + hessian[cal_offset_2d(dimension_hessian, indices_diag_hessian)]; + + if (channel_idx == (dimension.dim(channel_dim_index).value() - 1)) + { + for (uint32_t o = 0; o < dimension_channel_last[0]; o++) + { + for (uint32_t i = 0; i 
< dimension_channel_last[1]; i++) + { + for (uint32_t h = 0; h < dimension_channel_last[2]; h++) + { + for (uint32_t w = 0; w < dimension_channel_last[3]; w++) + { + // convert coordination + uint32_t indices_channel_first[4] = {o, h, w, i}; + uint32_t indices_error[4] = {o, indices[1], indices[2], indices[3]}; + uint32_t idx_ihw = dimension_channel_last[2] * dimension_channel_last[3] * i + + dimension_channel_last[3] * h + w; + uint32_t indices_hessain[2] = {idx_quant_column, idx_ihw}; + + node->at(cal_offset(dimension, indices_channel_first)) -= + error[cal_offset(dimension, indices_error)] * + hessian[cal_offset_2d(dimension_hessian, indices_hessain)]; + } + } + } + } + } + }; + iterate_per_channel_with_order(node, quantize, true); + } + + node->dtype(loco::DataType::U8); // change the type of tensor + node->size(size); // resize tensor + for (uint32_t i = 0; i < size; ++i) + { + node->at(i) = std::min(kMaxScale, std::max(kMinScale, quantized_values[i])); + } +} + +void asymmetric_wdequant_per_channel(CircleConst *node, std::vector &scaling_factor, + std::vector &nudged_min) +{ + assert(node->dtype() == loco::DataType::U8); + uint32_t size = node->size(); + std::vector dequantized_values(size); + + auto dequantize = [&](uint32_t *indices, loco::TensorShape &dimension, int channel_dim_index) { + int channel_idx = indices[channel_dim_index]; + auto data = node->at(cal_offset(dimension, indices)); + dequantized_values[cal_offset(dimension, indices)] = + static_cast(data) * scaling_factor[channel_idx] + nudged_min[channel_idx]; + }; + + iterate_per_channel_with_order(node, dequantize, false); + + node->dtype(loco::DataType::FLOAT32); // change the type of tensor + node->size(size); // resize tensor + for (uint32_t i = 0; i < size; ++i) + { + node->at(i) = dequantized_values[i]; + } +} + +/** + * @brief QuantizeDequantizeWeights quantizes and dequantizes tensors for weights + * @details Find min/max values on the fly, quantize the model, and dequantize the model + */ 
+struct QuantizeWeightsWithGPTQ final : public luci::CircleNodeMutableVisitor +{ + QuantizeWeightsWithGPTQ( + loco::DataType input, loco::DataType output, QuantizationGranularity granularity, + std::unordered_map> *hessian_map) + : input_type(input), output_type(output), granularity(granularity), hessian_map(hessian_map) + { + } + + loco::DataType input_type; + loco::DataType output_type; + QuantizationGranularity granularity; + std::unordered_map> *hessian_map; + +private: + void fake_quantize_cwq(luci::CircleConst *weights, std::vector &hessian) const + { + // assert(output_type == loco::DataType::U8); // FIX_CALLER_UNLESS + if (output_type != loco::DataType::U8) + { + throw std::runtime_error("GPTQ quantization supports u8"); + } + // Find min/max per channel + std::vector min; + std::vector max; + + cal_minmax_per_channel(weights, min, max); + + std::vector nudged_min(min.size()); + std::vector nudged_max(min.size()); + std::vector scaling_factor(min.size()); + std::vector zp(min.size()); + + asymmetric_wquant_per_channel(weights, min, max, scaling_factor, zp, nudged_min, nudged_max, + output_type, hessian); + asymmetric_wdequant_per_channel(weights, scaling_factor, nudged_min); + + auto quantparam = std::make_unique(); + quantparam->min = nudged_min; + quantparam->max = nudged_max; + quantparam->scale = scaling_factor; + quantparam->zerop = zp; + + weights->quantparam(std::move(quantparam)); + } + + void fake_quantize(luci::CircleConst *weights, std::vector &hessian) const + { + switch (granularity) + { + case luci::QuantizationGranularity::ChannelWise: + fake_quantize_cwq(weights, hessian); + break; + default: + throw std::invalid_argument("Unsupported granularity"); + } + } + +private: + // Check if + // 1. node is const + // 2. 
node's dtype is float32 + bool is_quantizable(loco::Node *node) + { + auto const_node = dynamic_cast(node); + if (not const_node) + return false; + + // Skip if this is not float32 + if (const_node->dtype() != loco::DataType::FLOAT32) + return false; + + return true; + } + + // Default behavior (Do nothing) + void visit(luci::CircleNode *) {} + + void visit(luci::CircleConv2D *node) + { + LOGGER(l); + INFO(l) << "QuantizeWeightsWithGPTQPass visit node: " << node->name() << std::endl; + + if (not is_quantizable(node->filter())) + return; + + auto weights = loco::must_cast(node->filter()); + auto new_weights = luci::clone(weights); + node->filter(new_weights); + + auto hessian = (*hessian_map)[node]; + + fake_quantize(new_weights, hessian); + } + + void visit(luci::CircleDepthwiseConv2D *node) + { + LOGGER(l); + INFO(l) << "QuantizeWeightsWithGPTQPass visit node: " << node->name() << std::endl; + + if (not is_quantizable(node->filter())) + return; + + auto weights = loco::must_cast(node->filter()); + auto new_weights = luci::clone(weights); + node->filter(new_weights); + + std::vector empty_vector; + + fake_quantize(new_weights, empty_vector); + } + + void visit(luci::CircleTransposeConv *node) + { + LOGGER(l); + INFO(l) << "QuantizeDequantizeWeights visit node: " << node->name() << std::endl; + + if (not is_quantizable(node->filter())) + return; + + auto weights = loco::must_cast(node->filter()); + auto new_weights = luci::clone(weights); + node->filter(new_weights); + + std::vector empty_vector; + + fake_quantize(new_weights, empty_vector); + } + + void visit(luci::CircleFullyConnected *node) + { + LOGGER(l); + INFO(l) << "QuantizeDequantizeWeights visit node: " << node->name() << std::endl; + if (not is_quantizable(node->weights())) + return; + + auto weights = loco::must_cast(node->weights()); + auto new_weights = luci::clone(weights); + node->weights(new_weights); + + auto hessian = (*hessian_map)[node]; + + fake_quantize(new_weights, hessian); + } +}; + +} // 
namespace + +bool QuantizeWeightsWithGPTQPass::run(loco::Graph *g) +{ + LOGGER(l); + INFO(l) << "QuantizeWeightsWithGPTQPass Start" << std::endl; + + if (_ctx->input_model_dtype != loco::DataType::FLOAT32) + throw std::runtime_error("Weights-only quantization supports float32 input only"); + + if (_ctx->output_model_dtype != loco::DataType::U8) + throw std::runtime_error("GPTQ quantization supports uint8 output only"); + + auto info_by_name = layer_info_map(g, _ctx->layers_info); + + auto quantize_dtype = [&](const luci::CircleNode *node) { + auto iter = info_by_name.find(node->name()); + + // Return designated quantization dtype + if (iter != info_by_name.end()) + return iter->second.dtype; + + // Return default quantization dtype + return _ctx->output_model_dtype; + }; + + auto quantize_granularity = [&](const luci::CircleNode *node) { + auto iter = info_by_name.find(node->name()); + + // Return designated quantization granularity + if (iter != info_by_name.end()) + return iter->second.granularity; + + // Return default quantization granularity + return _ctx->granularity; + }; + + // Quantize weights + for (auto node : loco::active_nodes(loco::output_nodes(g))) + { + auto circle_node = loco::must_cast(node); + QuantizeWeightsWithGPTQ qw(_ctx->input_model_dtype, quantize_dtype(circle_node), + quantize_granularity(circle_node), _hessian_map); + circle_node->accept(&qw); + } + + INFO(l) << "QuantizeWeightsWithGPTQPass End" << std::endl; + return false; // one time run +} + +} // namespace luci diff --git a/compiler/luci/pass/src/QuantizeWeightsWithGPTQPass.test.cpp b/compiler/luci/pass/src/QuantizeWeightsWithGPTQPass.test.cpp new file mode 100644 index 00000000000..367433115e3 --- /dev/null +++ b/compiler/luci/pass/src/QuantizeWeightsWithGPTQPass.test.cpp @@ -0,0 +1,206 @@ +#include "luci/Pass/QuantizeWeightsWithGPTQPass.h" +#include +#include +#include + +namespace +{ +struct QuantizeWeightsWithGPTQPassTest : public ::testing::Test +{ + /** + * nconv graph + * + * 
[CircleInput] + * | + * | + * [CircleConv2D] + * | + * | + * [CircleOutput] + */ + void MakeGraph() + { + const int N = 1; + const int H = 4; + const int W = 4; + const int C = 3; // IC = OC + + // graph input and output + auto graph_input = _g.inputs()->create(); + auto graph_output = _g.outputs()->create(); + + // CircleInput + auto input = _g.nodes()->create(); + input->index(graph_input->index()); + input->shape({N, H, W, C}); + input->dtype(loco::DataType::FLOAT32); + input->name("input"); + + // CircleConv2D + auto conv = _g.nodes()->create(); + conv->input(input); + auto bias = _g.nodes()->create(); + bias->dtype(loco::DataType::FLOAT32); + bias->shape({C}); + bias->name("conv_bias"); + conv->bias(bias); + auto weight = _g.nodes()->create(); + weight->dtype(loco::DataType::FLOAT32); + weight->shape({C, H, W, C}); + weight->size(C * H * W * C); + weight->name("nconv/filter"); + conv->filter(weight); + conv->padding(luci::Padding::SAME); + conv->fusedActivationFunction(luci::FusedActFunc::NONE); + conv->dtype(loco::DataType::FLOAT32); + conv->name("nconv"); + + // CircleOutput + auto output = _g.nodes()->create(); + output->index(graph_output->index()); + output->from(conv); + output->shape({N, H, W, C}); + output->dtype(loco::DataType::FLOAT32); + output->name("output"); + } + virtual void SetUp() { MakeGraph(); } + loco::Graph _g; +}; +} // namespace + +TEST_F(QuantizeWeightsWithGPTQPassTest, name) +{ + luci::QuantizeWeightsWithGPTQPass pass(loco::DataType::FLOAT32, loco::DataType::U8, + luci::QuantizationGranularity::ChannelWise); + auto const name = pass.name(); + ASSERT_NE(nullptr, name); +} + +TEST_F(QuantizeWeightsWithGPTQPassTest, name_ctx) +{ + auto ctx = std::make_unique(); + { + ctx->input_model_dtype = loco::DataType::FLOAT32; + ctx->output_model_dtype = loco::DataType::U8; + ctx->granularity = luci::QuantizationGranularity::ChannelWise; + } + + luci::QuantizeWeightsWithGPTQPass pass(std::move(ctx)); + auto const name = pass.name(); + 
ASSERT_NE(nullptr, name); +} + +// Negative test: Unsupported granularity - Invalid value +TEST_F(QuantizeWeightsWithGPTQPassTest, run_granularity_invalid_NEG) +{ + auto invalid_granularity = static_cast(999); + luci::QuantizeWeightsWithGPTQPass pass(loco::DataType::FLOAT32, loco::DataType::U8, + invalid_granularity); + ASSERT_EXIT(pass.run(&_g), ::testing::KilledBySignal(SIGSEGV), ".*"); +} + +// Negative test: Unsupported output data type - FLOAT32 +TEST_F(QuantizeWeightsWithGPTQPassTest, run_output_f32_NEG) +{ + luci::QuantizeWeightsWithGPTQPass pass(loco::DataType::FLOAT32, loco::DataType::FLOAT32, + luci::QuantizationGranularity::ChannelWise); + // Since output type is FLOAT32, an exception is expected + EXPECT_THROW(pass.run(&_g), std::runtime_error); +} + +// Negative test: Provide an empty hessian map +TEST_F(QuantizeWeightsWithGPTQPassTest, run_with_empty_hessian_map_NEG) +{ + std::unordered_map> hessian_map; + auto ctx = std::make_unique(); + { + ctx->input_model_dtype = loco::DataType::FLOAT32; + ctx->output_model_dtype = loco::DataType::U8; + ctx->granularity = luci::QuantizationGranularity::ChannelWise; + } + + luci::QuantizeWeightsWithGPTQPass pass(std::move(ctx), &hessian_map); + // Expect no exception, pass should handle empty hessian map gracefully + EXPECT_NO_THROW(pass.run(&_g)); +} + +// Negative test: Weights are not of type FLOAT32 +TEST_F(QuantizeWeightsWithGPTQPassTest, run_with_non_float_weights_NEG) +{ + // Change the weights to non-float32 + luci::CircleConst *weight = nullptr; + for (auto node : loco::all_nodes(&_g)) + { + if (auto const_node = dynamic_cast(node)) + { + if (const_node->name() == "nconv/filter") + { + weight = const_node; + break; + } + } + } + ASSERT_NE(weight, nullptr); + // Set dtype to INT32 + weight->dtype(loco::DataType::S32); + + luci::QuantizeWeightsWithGPTQPass pass(loco::DataType::FLOAT32, loco::DataType::U8, + luci::QuantizationGranularity::ChannelWise); + // The pass should skip this node without exception + 
EXPECT_NO_THROW(pass.run(&_g)); +} + +// Positive test: Run pass with valid hessian map +TEST_F(QuantizeWeightsWithGPTQPassTest, run_with_valid_hessian) +{ + // Create a hessian map with valid data + std::unordered_map> hessian_map; + // Find the conv node + luci::CircleConv2D *conv = nullptr; + for (auto node : loco::all_nodes(&_g)) + { + if (auto conv_node = dynamic_cast(node)) + { + conv = conv_node; + break; + } + } + ASSERT_NE(conv, nullptr); + const auto node_filter = loco::must_cast( + loco::must_cast(conv)->filter()); + // Create a dummy hessian vector + size_t weight_size = node_filter->size(); + std::vector hessian(weight_size * weight_size, 0.0f); + for (size_t i = 0; i < weight_size; ++i) + { + hessian[i * weight_size + i] = 1.0f; // Identity matrix + } + + hessian_map[conv] = hessian; + + auto ctx = std::make_unique(); + { + ctx->input_model_dtype = loco::DataType::FLOAT32; + ctx->output_model_dtype = loco::DataType::U8; + ctx->granularity = luci::QuantizationGranularity::ChannelWise; + } + + luci::QuantizeWeightsWithGPTQPass pass(std::move(ctx), &hessian_map); + EXPECT_NO_THROW(pass.run(&_g)); +} + +// Negative test: Input model data type is U8 (unsupported) +TEST_F(QuantizeWeightsWithGPTQPassTest, run_input_U8_NEG) +{ + luci::QuantizeWeightsWithGPTQPass pass(loco::DataType::U8, loco::DataType::U8, + luci::QuantizationGranularity::ChannelWise); + EXPECT_THROW(pass.run(&_g), std::runtime_error); +} + +// Negative test: Output model data type is S32 (unsupported) +TEST_F(QuantizeWeightsWithGPTQPassTest, run_output_S32_NEG) +{ + luci::QuantizeWeightsWithGPTQPass pass(loco::DataType::FLOAT32, loco::DataType::S32, + luci::QuantizationGranularity::ChannelWise); + EXPECT_THROW(pass.run(&_g), std::runtime_error); +} From 313d3b8ba07e2b71ce92a999e14d78989177d7ef Mon Sep 17 00:00:00 2001 From: "y01000.you" Date: Thu, 31 Oct 2024 17:21:51 +0900 Subject: [PATCH 2/6] [luci-pass] Add cal_offset_2d Add a function for calculating offset of 2d tensor. 
ONE-DCO-1.0-Signed-off-by: young cheon --- compiler/luci/pass/src/QuantizationUtils.cpp | 5 +++++ compiler/luci/pass/src/QuantizationUtils.h | 3 +++ 2 files changed, 8 insertions(+) diff --git a/compiler/luci/pass/src/QuantizationUtils.cpp b/compiler/luci/pass/src/QuantizationUtils.cpp index 33d75ce04b4..e3099936100 100644 --- a/compiler/luci/pass/src/QuantizationUtils.cpp +++ b/compiler/luci/pass/src/QuantizationUtils.cpp @@ -292,6 +292,11 @@ uint32_t cal_offset(loco::TensorShape &dimension, uint32_t *indices) indices[2] * dimension.dim(3).value() + indices[3]; } +uint32_t cal_offset_2d(loco::TensorShape &dimension, uint32_t *indices) +{ + return indices[0] * dimension.dim(1).value() + indices[1]; +} + // Activation (ofm) qtype is determined in different ways. // 1. Pre-defined values: Some Ops have pre-defined qparams (ex: LOGISTIC, TANH) // 2. Integer scale: Output of some Ops should be integers (ex: FLOOR, CEIL) diff --git a/compiler/luci/pass/src/QuantizationUtils.h b/compiler/luci/pass/src/QuantizationUtils.h index 0bf3270d5ec..290ee1786db 100644 --- a/compiler/luci/pass/src/QuantizationUtils.h +++ b/compiler/luci/pass/src/QuantizationUtils.h @@ -53,6 +53,7 @@ bool get_channel_dim_index(CircleConst *node, loco::TensorShape &dimension, // Calculate offset of the given indices in dimension uint32_t cal_offset(loco::TensorShape &dimension, uint32_t *indices); +uint32_t cal_offset_2d(loco::TensorShape &dimension, uint32_t *indices); // Backward propagation of concatenation qparam void propagate_concat_quantparam(luci::CircleConcatenation *concat); @@ -63,6 +64,8 @@ void propagate_pad_v2_quantparam(luci::CirclePadV2 *pad_v2); // Return true if the node is quantized bool is_quantized(const CircleNode *node); +uint8_t fp32_to_uint8_cast(float f); + // Return true if the node is fp32 bool is_fp32(const CircleNode *node); From bce1b10b1c5b1e317217896febedc5e5d17ad742 Mon Sep 17 00:00:00 2001 From: "y01000.you" Date: Thu, 31 Oct 2024 18:49:15 +0900 Subject: [PATCH 3/6] 
Refactoring and clean-up This patch includes minor refactoring to improve overall code quality. ONE-DCO-1.0-Signed-off-by: y01000.you --- .../luci/Pass/QuantizeWeightsWithGPTQPass.h | 14 ++-- .../pass/src/QuantizeWeightsWithGPTQPass.cpp | 76 ++++++++++--------- .../src/QuantizeWeightsWithGPTQPass.test.cpp | 1 + 3 files changed, 49 insertions(+), 42 deletions(-) diff --git a/compiler/luci/pass/include/luci/Pass/QuantizeWeightsWithGPTQPass.h b/compiler/luci/pass/include/luci/Pass/QuantizeWeightsWithGPTQPass.h index afef945bf98..a4ac4620404 100644 --- a/compiler/luci/pass/include/luci/Pass/QuantizeWeightsWithGPTQPass.h +++ b/compiler/luci/pass/include/luci/Pass/QuantizeWeightsWithGPTQPass.h @@ -17,17 +17,19 @@ #ifndef __LUCI_QUANTIZE_WEIGHTS_WITH_GPTQ_PASS_H__ #define __LUCI_QUANTIZE_WEIGHTS_WITH_GPTQ_PASS_H__ -#include +#include +#include #include +#include -#include -#include #include namespace luci { +using HessianMap = std::unordered_map>; + /** * @brief Pass to quantize weights */ @@ -48,9 +50,7 @@ class QuantizeWeightsWithGPTQPass : public logo::Pass // DO NOTHING } - QuantizeWeightsWithGPTQPass( - std::unique_ptr &&ctx, - std::unordered_map> *hessian_map) + QuantizeWeightsWithGPTQPass(std::unique_ptr &&ctx, HessianMap *hessian_map) : _ctx{std::move(ctx)}, _hessian_map{hessian_map} { // DO NOTHING @@ -74,7 +74,7 @@ class QuantizeWeightsWithGPTQPass : public logo::Pass private: std::unique_ptr _ctx; - std::unordered_map> *_hessian_map; + HessianMap *_hessian_map; }; } // namespace luci diff --git a/compiler/luci/pass/src/QuantizeWeightsWithGPTQPass.cpp b/compiler/luci/pass/src/QuantizeWeightsWithGPTQPass.cpp index cd9267117a0..79e5a2bf9c2 100644 --- a/compiler/luci/pass/src/QuantizeWeightsWithGPTQPass.cpp +++ b/compiler/luci/pass/src/QuantizeWeightsWithGPTQPass.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved + * Copyright (c) 2024 Samsung Electronics Co., Ltd. 
All Rights Reserved * Copyright 2019 The TensorFlow Authors. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -27,16 +27,19 @@ #include #include #include -#include + +namespace luci +{ namespace { -using namespace luci; using IterFunc = std::function; void iterate_per_channel_with_order(CircleConst *node, IterFunc func, bool reverse) { + assert(node != nullptr); + loco::TensorShape dimension; dimension.rank(4); uint32_t indices[4] = {0}; @@ -44,7 +47,7 @@ void iterate_per_channel_with_order(CircleConst *node, IterFunc func, bool rever uint32_t num_dims[4]; if (!get_channel_dim_index(node, dimension, channel_dim_index)) { - throw std::runtime_error("Failed to get channel dim index."); + throw std::runtime_error("GPTQPass: Failed to get channel dim index."); } auto order = reverse ? std::vector{3, 1, 2, 0} : std::vector{0, 1, 2, 3}; @@ -73,19 +76,15 @@ void iterate_per_channel_with_order(CircleConst *node, IterFunc func, bool rever } } -} // namespace - -namespace luci -{ - -namespace -{ - size_t calculate_qauntized_value(CircleConst *node, uint32_t *indices, loco::TensorShape &dimension, int channel_dim_index, std::vector &scaling_factor, std::vector &max, std::vector &min) { + assert(node != nullptr); + int channel_idx = indices[channel_dim_index]; + + assert(scaling_factor[channel_idx] > 0); const float scaling_factor_inv = 1.0 / scaling_factor[channel_idx]; auto data = node->at(cal_offset(dimension, indices)); auto data_clipped = data < min[channel_idx] ? min[channel_idx] : data; @@ -109,7 +108,7 @@ void cholesky_decomposition(std::vector &src, uint32_t num_size) { if (src[i * num_size + i] - sum <= 0) { - std::cout << "Error: Matrix is not positive definite.\n" << std::endl; + std::cout << "Error: Matrix is not positive definite." 
<< std::endl; return; } src[i * num_size + i] = sqrt(src[i * num_size + i] - sum); @@ -143,6 +142,7 @@ void forward_substitution(const std::vector &L, const std::vector { y[i] -= L[i * num_size + j] * y[j]; } + assert(L[i * num_size + i] != 0); y[i] /= L[i * num_size + i]; } } @@ -157,6 +157,7 @@ void backward_substitution(const std::vector &U, const std::vector { x[i] -= U[i * num_size + j] * x[j]; } + assert(U[i * num_size + i] != 0); x[i] /= U[i * num_size + i]; } } @@ -222,8 +223,7 @@ void cal_minmax_per_channel(CircleConst *node, std::vector &min, std::vec if (!get_channel_dim_index(node, dimension, channel_dim_index)) { - assert(false); - return; + throw std::runtime_error("GPTQPass: Failed to get channel dim index."); } auto size = dimension.dim(channel_dim_index).value(); @@ -262,10 +262,13 @@ void compute_asym_scale_zp(float min, float max, float &scaling_factor, int64_t const double qmax_double = kMaxScale; const double rmin = std::fmin(0, min); const double rmax = std::fmax(0, max); + const double qrange = qmax_double - qmin_double; + assert(qrange > 0); - double scale = (rmax - rmin) / (qmax_double - qmin_double); + double scale = (rmax - rmin) / qrange; double zero_point_double = 0; uint8_t nudged_zero_point = 0; + if (scale == 0) { WARN(l) << "The minimum and maximum values are the same." << std::endl; @@ -280,7 +283,7 @@ void compute_asym_scale_zp(float min, float max, float &scaling_factor, int64_t { assert(min >= 0 && max >= 0); nudged_zero_point = kMinScale; - scale = max / (qmax_double - qmin_double); + scale = max / qrange; if (min > 0 && max > 0) WARN(l) << "The minimum and maximum values are all positive." << std::endl; } @@ -288,7 +291,7 @@ void compute_asym_scale_zp(float min, float max, float &scaling_factor, int64_t { assert(min < 0 && max < 0); nudged_zero_point = kMaxScale; - scale = -min / (qmax_double - qmin_double); + scale = -min / qrange; WARN(l) << "The minimum and maximum values are all negative." 
<< std::endl; } else @@ -318,6 +321,7 @@ void asymmetric_wquant_per_channel(CircleConst *node, std::vector &min, std::vector &hessian) { assert(node->dtype() == loco::DataType::FLOAT32); + assert(output_type == loco::DataType::U8 || output_type == loco::DataType::U4); IterFunc quantize; @@ -333,7 +337,7 @@ void asymmetric_wquant_per_channel(CircleConst *node, std::vector &min, kMaxScale); } - if (hessian.empty()) // Cases where gptq is not applied + if (hessian.empty()) // Case where GPTQ is not applied { quantize = [&](uint32_t *indices, loco::TensorShape &dimension, int channel_dim_index) { quantized_values[cal_offset(dimension, indices)] = calculate_qauntized_value( @@ -341,7 +345,7 @@ void asymmetric_wquant_per_channel(CircleConst *node, std::vector &min, }; iterate_per_channel_with_order(node, quantize, false); } - else // Cases where gptq is applied + else // Case where GPTQ is applied { uint32_t size_hessian = static_cast(sqrt(hessian.size())); float percdamp = .01; @@ -364,7 +368,7 @@ void asymmetric_wquant_per_channel(CircleConst *node, std::vector &min, cholesky_inverse(hessian, size_hessian); cholesky_decomposition(hessian, size_hessian); - // transpose hessian to make upper trangular + // transpose hessian to make upper triangular for (uint32_t i = 0; i < size_hessian; i++) { for (uint32_t j = 0; j < i; j++) @@ -405,10 +409,11 @@ void asymmetric_wquant_per_channel(CircleConst *node, std::vector &min, uint32_t channel_idx = indices[channel_dim_index]; auto data = node->at(cal_offset(dimension, indices)); + auto h_offset = cal_offset_2d(dimension_hessian, indices_diag_hessian); error[cal_offset(dimension, indices)] = (data - (quantized_values[cal_offset(dimension, indices)] - zp[channel_idx]) * scaling_factor[channel_idx]) / - hessian[cal_offset_2d(dimension_hessian, indices_diag_hessian)]; + hessian[h_offset]; if (channel_idx == (dimension.dim(channel_dim_index).value() - 1)) { @@ -426,10 +431,10 @@ void asymmetric_wquant_per_channel(CircleConst *node, 
std::vector &min, uint32_t idx_ihw = dimension_channel_last[2] * dimension_channel_last[3] * i + dimension_channel_last[3] * h + w; uint32_t indices_hessain[2] = {idx_quant_column, idx_ihw}; + auto _h_offset = cal_offset_2d(dimension_hessian, indices_hessain); node->at(cal_offset(dimension, indices_channel_first)) -= - error[cal_offset(dimension, indices_error)] * - hessian[cal_offset_2d(dimension_hessian, indices_hessain)]; + error[cal_offset(dimension, indices_error)] * hessian[_h_offset]; } } } @@ -439,8 +444,8 @@ void asymmetric_wquant_per_channel(CircleConst *node, std::vector &min, iterate_per_channel_with_order(node, quantize, true); } - node->dtype(loco::DataType::U8); // change the type of tensor - node->size(size); // resize tensor + node->dtype(loco::DataType::U8); // Change the type of tensor + node->size(size); // Resize tensor for (uint32_t i = 0; i < size; ++i) { node->at(i) = std::min(kMaxScale, std::max(kMinScale, quantized_values[i])); @@ -472,27 +477,28 @@ void asymmetric_wdequant_per_channel(CircleConst *node, std::vector &scal } /** - * @brief QuantizeDequantizeWeights quantizes and dequantizes tensors for weights - * @details Find min/max values on the fly, quantize the model, and dequantize the model + * @brief QuantizeWeightsWithGPTQ quantizes and dequantizes tensors for weights using GPTQ algorithm + * @details Compensate for the quantization error and update weights using Hessian matrix + * */ -struct QuantizeWeightsWithGPTQ final : public luci::CircleNodeMutableVisitor +class QuantizeWeightsWithGPTQ final : public luci::CircleNodeMutableVisitor { +public: QuantizeWeightsWithGPTQ( loco::DataType input, loco::DataType output, QuantizationGranularity granularity, std::unordered_map> *hessian_map) - : input_type(input), output_type(output), granularity(granularity), hessian_map(hessian_map) + : input_type(input), output_type(output), granularity(granularity), _hessian_map(hessian_map) { } +private: loco::DataType input_type; loco::DataType 
output_type; QuantizationGranularity granularity; - std::unordered_map> *hessian_map; + std::unordered_map> *_hessian_map; -private: void fake_quantize_cwq(luci::CircleConst *weights, std::vector &hessian) const { - // assert(output_type == loco::DataType::U8); // FIX_CALLER_UNLESS if (output_type != loco::DataType::U8) { throw std::runtime_error("GPTQ quantization supports u8"); @@ -565,7 +571,7 @@ struct QuantizeWeightsWithGPTQ final : public luci::CircleNodeMutableVisitorfilter(new_weights); - auto hessian = (*hessian_map)[node]; + auto hessian = (*_hessian_map)[node]; fake_quantize(new_weights, hessian); } @@ -615,7 +621,7 @@ struct QuantizeWeightsWithGPTQ final : public luci::CircleNodeMutableVisitorweights(new_weights); - auto hessian = (*hessian_map)[node]; + auto hessian = (*_hessian_map)[node]; fake_quantize(new_weights, hessian); } diff --git a/compiler/luci/pass/src/QuantizeWeightsWithGPTQPass.test.cpp b/compiler/luci/pass/src/QuantizeWeightsWithGPTQPass.test.cpp index 367433115e3..8c00e46bcab 100644 --- a/compiler/luci/pass/src/QuantizeWeightsWithGPTQPass.test.cpp +++ b/compiler/luci/pass/src/QuantizeWeightsWithGPTQPass.test.cpp @@ -5,6 +5,7 @@ namespace { + struct QuantizeWeightsWithGPTQPassTest : public ::testing::Test { /** From e0fb564b4fb6e0e301c442f721a2f74a68b40c5e Mon Sep 17 00:00:00 2001 From: "y01000.you" Date: Fri, 1 Nov 2024 10:23:19 +0900 Subject: [PATCH 4/6] Add description, and extract variables in GPTQPass This commit adds a description to the QuantizeWeightsWithGPTQPass class. Additionally, some variables have been extracted from the function parameters to improve readability and maintainability. 
ONE-DCO-1.0-Signed-off-by: y01000.you --- .../luci/Pass/QuantizeWeightsWithGPTQPass.h | 4 +- .../pass/src/QuantizeWeightsWithGPTQPass.cpp | 57 ++++++++++++------- .../src/QuantizeWeightsWithGPTQPass.test.cpp | 2 + 3 files changed, 41 insertions(+), 22 deletions(-) diff --git a/compiler/luci/pass/include/luci/Pass/QuantizeWeightsWithGPTQPass.h b/compiler/luci/pass/include/luci/Pass/QuantizeWeightsWithGPTQPass.h index a4ac4620404..2f4ca0ecf63 100644 --- a/compiler/luci/pass/include/luci/Pass/QuantizeWeightsWithGPTQPass.h +++ b/compiler/luci/pass/include/luci/Pass/QuantizeWeightsWithGPTQPass.h @@ -31,7 +31,7 @@ namespace luci using HessianMap = std::unordered_map>; /** - * @brief Pass to quantize weights + * @brief Pass to quantize weights with GPTQ algorithm */ class QuantizeWeightsWithGPTQPass : public logo::Pass { @@ -74,7 +74,7 @@ class QuantizeWeightsWithGPTQPass : public logo::Pass private: std::unique_ptr _ctx; - HessianMap *_hessian_map; + HessianMap *_hessian_map = nullptr; }; } // namespace luci diff --git a/compiler/luci/pass/src/QuantizeWeightsWithGPTQPass.cpp b/compiler/luci/pass/src/QuantizeWeightsWithGPTQPass.cpp index 79e5a2bf9c2..1fa7d45b961 100644 --- a/compiler/luci/pass/src/QuantizeWeightsWithGPTQPass.cpp +++ b/compiler/luci/pass/src/QuantizeWeightsWithGPTQPass.cpp @@ -129,6 +129,7 @@ void cholesky_decomposition(std::vector &src, uint32_t num_size) } } } + return; } @@ -250,8 +251,18 @@ void cal_minmax_per_channel(CircleConst *node, std::vector &min, std::vec iterate_per_channel_with_order(node, cal_minmax, false); } -void compute_asym_scale_zp(float min, float max, float &scaling_factor, int64_t &zp, - float &nudged_min, float &nudged_max, int32_t k_max_scale) +/** + * @brief Compute the scale and zero point for the given range of values + * @param min: The minimum value in the range of values to be quantized. + * @param max: The maximum value in the range of values to be quantized. 
+ * @param k_max_scale: The maximum value of the quantization scale. + * @param scaling_factor: The computed scaling factor for the quantization. + * @param zp: The computed zero point for the quantization. + * @param nudged_min: The nudged minimum value after applying the scaling factor and zero point. + * @param nudged_max: The nudged maximum value after applying the scaling factor and zero point. + */ +void compute_asym_scale_zp(float min, float max, int32_t k_max_scale, float &scaling_factor, + int64_t &zp, float &nudged_min, float &nudged_max) { LOGGER(l); @@ -333,8 +344,8 @@ void asymmetric_wquant_per_channel(CircleConst *node, std::vector &min, for (size_t i = 0; i < min.size(); ++i) { - compute_asym_scale_zp(min[i], max[i], scaling_factor[i], zp[i], nudged_min[i], nudged_max[i], - kMaxScale); + compute_asym_scale_zp(min[i], max[i], kMaxScale, scaling_factor[i], zp[i], nudged_min[i], + nudged_max[i]); } if (hessian.empty()) // Case where GPTQ is not applied @@ -355,6 +366,7 @@ void asymmetric_wquant_per_channel(CircleConst *node, std::vector &min, { damp += hessian[i * size_hessian + i]; } + assert(size_hessian != 0); damp /= size_hessian; damp *= percdamp; @@ -407,13 +419,15 @@ void asymmetric_wquant_per_channel(CircleConst *node, std::vector &min, uint32_t indices_diag_hessian[2] = {idx_quant_column, idx_quant_column}; uint32_t channel_idx = indices[channel_dim_index]; - auto data = node->at(cal_offset(dimension, indices)); - auto h_offset = cal_offset_2d(dimension_hessian, indices_diag_hessian); - error[cal_offset(dimension, indices)] = - (data - (quantized_values[cal_offset(dimension, indices)] - zp[channel_idx]) * - scaling_factor[channel_idx]) / - hessian[h_offset]; + auto data_indices = cal_offset(dimension, indices); + auto hessian_indices = cal_offset_2d(dimension_hessian, indices_diag_hessian); + + auto data = node->at(data_indices); + auto quantized_rvalue = + (quantized_values[data_indices] - zp[channel_idx]) * scaling_factor[channel_idx]; + + 
error[data_indices] = (data - quantized_rvalue) / hessian[hessian_indices]; if (channel_idx == (dimension.dim(channel_dim_index).value() - 1)) { @@ -431,10 +445,13 @@ void asymmetric_wquant_per_channel(CircleConst *node, std::vector &min, uint32_t idx_ihw = dimension_channel_last[2] * dimension_channel_last[3] * i + dimension_channel_last[3] * h + w; uint32_t indices_hessain[2] = {idx_quant_column, idx_ihw}; - auto _h_offset = cal_offset_2d(dimension_hessian, indices_hessain); - node->at(cal_offset(dimension, indices_channel_first)) -= - error[cal_offset(dimension, indices_error)] * hessian[_h_offset]; + auto _h_indices = cal_offset_2d(dimension_hessian, indices_hessain); + auto _data_indices = cal_offset(dimension, indices_channel_first); + auto _error_indices = cal_offset(dimension, indices_error); + + node->at(_data_indices) -= + error[_error_indices] * hessian[_h_indices]; } } } @@ -487,19 +504,19 @@ class QuantizeWeightsWithGPTQ final : public luci::CircleNodeMutableVisitor> *hessian_map) - : input_type(input), output_type(output), granularity(granularity), _hessian_map(hessian_map) + : _input_type(input), _output_type(output), _granularity(granularity), _hessian_map(hessian_map) { } private: - loco::DataType input_type; - loco::DataType output_type; - QuantizationGranularity granularity; + loco::DataType _input_type; + loco::DataType _output_type; + QuantizationGranularity _granularity; std::unordered_map> *_hessian_map; void fake_quantize_cwq(luci::CircleConst *weights, std::vector &hessian) const { - if (output_type != loco::DataType::U8) + if (_output_type != loco::DataType::U8) { throw std::runtime_error("GPTQ quantization supports u8"); } @@ -515,7 +532,7 @@ class QuantizeWeightsWithGPTQ final : public luci::CircleNodeMutableVisitor zp(min.size()); asymmetric_wquant_per_channel(weights, min, max, scaling_factor, zp, nudged_min, nudged_max, - output_type, hessian); + _output_type, hessian); asymmetric_wdequant_per_channel(weights, scaling_factor, 
nudged_min); auto quantparam = std::make_unique(); @@ -529,7 +546,7 @@ class QuantizeWeightsWithGPTQ final : public luci::CircleNodeMutableVisitor &hessian) const { - switch (granularity) + switch (_granularity) { case luci::QuantizationGranularity::ChannelWise: fake_quantize_cwq(weights, hessian); diff --git a/compiler/luci/pass/src/QuantizeWeightsWithGPTQPass.test.cpp b/compiler/luci/pass/src/QuantizeWeightsWithGPTQPass.test.cpp index 8c00e46bcab..c2b7b3b30b9 100644 --- a/compiler/luci/pass/src/QuantizeWeightsWithGPTQPass.test.cpp +++ b/compiler/luci/pass/src/QuantizeWeightsWithGPTQPass.test.cpp @@ -1,6 +1,8 @@ #include "luci/Pass/QuantizeWeightsWithGPTQPass.h" #include + #include + #include namespace From 9387b968249fd3d8b7ecc69301dac4c3915ae649 Mon Sep 17 00:00:00 2001 From: "y01000.you" Date: Fri, 1 Nov 2024 16:16:56 +0900 Subject: [PATCH 5/6] Add GPTQ algorithm and hessian map support for circle quantizer This commit adds support for the GPTQ algorithm and hessian map to the CircleQuantizer class in LUCI. 
ONE-DCO-1.0-Signed-off-by: y01000.you --- compiler/luci/pass/include/luci/CircleQuantizer.h | 10 ++++++++++ compiler/luci/pass/src/QuantizationUtils.cpp | 3 +++ compiler/luci/pass/src/QuantizeWeightsWithGPTQPass.cpp | 4 ++-- 3 files changed, 15 insertions(+), 2 deletions(-) diff --git a/compiler/luci/pass/include/luci/CircleQuantizer.h b/compiler/luci/pass/include/luci/CircleQuantizer.h index 3ec50986ec6..29ce5a26b73 100644 --- a/compiler/luci/pass/include/luci/CircleQuantizer.h +++ b/compiler/luci/pass/include/luci/CircleQuantizer.h @@ -17,15 +17,19 @@ #ifndef __LUCI_CIRCLE_QUANTIZER_H__ #define __LUCI_CIRCLE_QUANTIZER_H__ +#include #include #include #include +#include #include namespace luci { +using HessianMap = std::unordered_map>; + class CircleQuantizer final { public: @@ -59,6 +63,7 @@ class CircleQuantizer final enum Algorithm { QuantizeDequantizeWeights, + QuantizeWeightsWithGPTQ, QuantizeWithMinMax, Requantize, CopyQuantParam, @@ -111,6 +116,10 @@ class CircleQuantizer final public: void quantize(loco::Graph *) const; + void setHessianMap(std::unique_ptr &hessian_map) + { + _hessian_map = std::move(hessian_map); + } private: void quantize_dequantize_weight(loco::Graph *) const; @@ -124,6 +133,7 @@ class CircleQuantizer final private: std::unique_ptr _options; + std::unique_ptr _hessian_map; }; } // namespace luci diff --git a/compiler/luci/pass/src/QuantizationUtils.cpp b/compiler/luci/pass/src/QuantizationUtils.cpp index e3099936100..143b12ccb29 100644 --- a/compiler/luci/pass/src/QuantizationUtils.cpp +++ b/compiler/luci/pass/src/QuantizationUtils.cpp @@ -294,6 +294,9 @@ uint32_t cal_offset(loco::TensorShape &dimension, uint32_t *indices) uint32_t cal_offset_2d(loco::TensorShape &dimension, uint32_t *indices) { + assert(dimension.rank() == 2); + assert(indices != nullptr); + return indices[0] * dimension.dim(1).value() + indices[1]; } diff --git a/compiler/luci/pass/src/QuantizeWeightsWithGPTQPass.cpp 
b/compiler/luci/pass/src/QuantizeWeightsWithGPTQPass.cpp index 1fa7d45b961..47b3cb0ad26 100644 --- a/compiler/luci/pass/src/QuantizeWeightsWithGPTQPass.cpp +++ b/compiler/luci/pass/src/QuantizeWeightsWithGPTQPass.cpp @@ -516,9 +516,9 @@ class QuantizeWeightsWithGPTQ final : public luci::CircleNodeMutableVisitor &hessian) const { - if (_output_type != loco::DataType::U8) + if (_output_type != loco::DataType::U4 && _output_type != loco::DataType::U8) { - throw std::runtime_error("GPTQ quantization supports u8"); + throw std::runtime_error("GPTQ quantization supports U4/U8"); } // Find min/max per channel std::vector min; From 3574826f453de9f7d2736bd5fdefdc3002cf2e0b Mon Sep 17 00:00:00 2001 From: "y01000.you" Date: Fri, 1 Nov 2024 19:26:21 +0900 Subject: [PATCH 6/6] Refactor GPTQPass to improve readability and maintainability This commit refactors the QuantizeWeightsWithGPTQPass.cpp file to improve its readability and maintainability. ONE-DCO-1.0-Signed-off-by: y01000.you --- .../pass/src/QuantizeWeightsWithGPTQPass.cpp | 167 +++++++++--------- 1 file changed, 88 insertions(+), 79 deletions(-) diff --git a/compiler/luci/pass/src/QuantizeWeightsWithGPTQPass.cpp b/compiler/luci/pass/src/QuantizeWeightsWithGPTQPass.cpp index 47b3cb0ad26..6f20f083a23 100644 --- a/compiler/luci/pass/src/QuantizeWeightsWithGPTQPass.cpp +++ b/compiler/luci/pass/src/QuantizeWeightsWithGPTQPass.cpp @@ -43,9 +43,9 @@ void iterate_per_channel_with_order(CircleConst *node, IterFunc func, bool rever loco::TensorShape dimension; dimension.rank(4); uint32_t indices[4] = {0}; - int32_t channel_dim_index{0}; + int32_t index_channel_dim{0}; uint32_t num_dims[4]; - if (!get_channel_dim_index(node, dimension, channel_dim_index)) + if (!get_channel_dim_index(node, dimension, index_channel_dim)) { throw std::runtime_error("GPTQPass: Failed to get channel dim index."); } @@ -69,7 +69,7 @@ void iterate_per_channel_with_order(CircleConst *node, IterFunc func, bool rever indices[order[1]] = j; 
indices[order[2]] = s; indices[order[3]] = t; - func(indices, dimension, channel_dim_index); + func(indices, dimension, index_channel_dim); } } } @@ -77,20 +77,20 @@ void iterate_per_channel_with_order(CircleConst *node, IterFunc func, bool rever } size_t calculate_qauntized_value(CircleConst *node, uint32_t *indices, loco::TensorShape &dimension, - int channel_dim_index, std::vector &scaling_factor, + int index_channel_dim, std::vector &scaling_factor, std::vector &max, std::vector &min) { assert(node != nullptr); - int channel_idx = indices[channel_dim_index]; + int idx_channel = indices[index_channel_dim]; - assert(scaling_factor[channel_idx] > 0); - const float scaling_factor_inv = 1.0 / scaling_factor[channel_idx]; + assert(scaling_factor[idx_channel] > 0); + const float scaling_factor_inv = 1.0 / scaling_factor[idx_channel]; auto data = node->at(cal_offset(dimension, indices)); - auto data_clipped = data < min[channel_idx] ? min[channel_idx] : data; - data_clipped = data_clipped > max[channel_idx] ? max[channel_idx] : data_clipped; + auto data_clipped = data < min[idx_channel] ? min[idx_channel] : data; + data_clipped = data_clipped > max[idx_channel] ? 
max[idx_channel] : data_clipped; - return static_cast(std::round((data_clipped - min[channel_idx]) * scaling_factor_inv)); + return static_cast(std::round((data_clipped - min[idx_channel]) * scaling_factor_inv)); } void cholesky_decomposition(std::vector &src, uint32_t num_size) @@ -220,31 +220,31 @@ void cal_minmax_per_channel(CircleConst *node, std::vector &min, std::vec { loco::TensorShape dimension; dimension.rank(4); - int32_t channel_dim_index{0}; + int32_t index_channel_dim{0}; - if (!get_channel_dim_index(node, dimension, channel_dim_index)) + if (!get_channel_dim_index(node, dimension, index_channel_dim)) { throw std::runtime_error("GPTQPass: Failed to get channel dim index."); } - auto size = dimension.dim(channel_dim_index).value(); + auto size = dimension.dim(index_channel_dim).value(); std::vector has_min_max_value(size, false); min.resize(size); max.resize(size); - auto cal_minmax = [&](uint32_t *indices, loco::TensorShape &dimension, int channel_dim_index) { - int channel_idx = indices[channel_dim_index]; + auto cal_minmax = [&](uint32_t *indices, loco::TensorShape &dimension, int index_channel_dim) { + int idx_channel = indices[index_channel_dim]; auto data = node->at(cal_offset(dimension, indices)); - if (has_min_max_value[channel_idx]) + if (has_min_max_value[idx_channel]) { - min[channel_idx] = data < min[channel_idx] ? data : min[channel_idx]; - max[channel_idx] = data > max[channel_idx] ? data : max[channel_idx]; + min[idx_channel] = data < min[idx_channel] ? data : min[idx_channel]; + max[idx_channel] = data > max[idx_channel] ? 
data : max[idx_channel]; } else { - min[channel_idx] = data; - max[channel_idx] = data; - has_min_max_value[channel_idx] = true; + min[idx_channel] = data; + max[idx_channel] = data; + has_min_max_value[idx_channel] = true; } }; @@ -325,6 +325,39 @@ void compute_asym_scale_zp(float min, float max, int32_t k_max_scale, float &sca zp = nudged_zero_point; } +void apply_damping_to_hessian(std::vector &hessian, uint32_t num_size) +{ + float damp = 0; + float percdamp = .01; + + for (uint32_t i = 0; i < num_size; i++) + { + damp += hessian[i * num_size + i]; + } + + assert(num_size != 0); + damp /= num_size; + damp *= percdamp; + + for (uint32_t i = 0; i < num_size; i++) + { + hessian[i * num_size + i] += damp; + } +} + +void transpose_to_upper_triangular(std::vector &matrix, uint32_t num_size) +{ + for (uint32_t i = 0; i < num_size; i++) + { + for (uint32_t j = 0; j < i; j++) + { + float tmp = matrix[i * num_size + j]; + matrix[i * num_size + j] = matrix[j * num_size + i]; + matrix[j * num_size + i] = tmp; + } + } +} + void asymmetric_wquant_per_channel(CircleConst *node, std::vector &min, std::vector &max, std::vector &scaling_factor, std::vector &zp, std::vector &nudged_min, @@ -339,8 +372,8 @@ void asymmetric_wquant_per_channel(CircleConst *node, std::vector &min, const int32_t kMinScale = 0; const int32_t kMaxScale = output_type == loco::DataType::U4 ? 
15 : 255; - uint32_t size = node->size(); - std::vector quantized_values(size); + uint32_t input_size = node->size(); + std::vector quantized_values(input_size); for (size_t i = 0; i < min.size(); ++i) { @@ -350,48 +383,24 @@ void asymmetric_wquant_per_channel(CircleConst *node, std::vector &min, if (hessian.empty()) // Case where GPTQ is not applied { - quantize = [&](uint32_t *indices, loco::TensorShape &dimension, int channel_dim_index) { + quantize = [&](uint32_t *indices, loco::TensorShape &dimension, int index_channel_dim) { quantized_values[cal_offset(dimension, indices)] = calculate_qauntized_value( - node, indices, dimension, channel_dim_index, scaling_factor, nudged_max, nudged_min); + node, indices, dimension, index_channel_dim, scaling_factor, nudged_max, nudged_min); }; iterate_per_channel_with_order(node, quantize, false); } else // Case where GPTQ is applied { uint32_t size_hessian = static_cast(sqrt(hessian.size())); - float percdamp = .01; - float damp = 0; - - for (uint32_t i = 0; i < size_hessian; i++) - { - damp += hessian[i * size_hessian + i]; - } - assert(size_hessian != 0); - damp /= size_hessian; - damp *= percdamp; - - for (uint32_t i = 0; i < size_hessian; i++) - { - hessian[i * size_hessian + i] += damp; - } - // calculate hessian inverse + // Calculate hessian inverse + apply_damping_to_hessian(hessian, size_hessian); cholesky_decomposition(hessian, size_hessian); cholesky_inverse(hessian, size_hessian); cholesky_decomposition(hessian, size_hessian); + transpose_to_upper_triangular(hessian, size_hessian); - // transpose hessian to make upper triangular - for (uint32_t i = 0; i < size_hessian; i++) - { - for (uint32_t j = 0; j < i; j++) - { - float tmp = hessian[i * size_hessian + j]; - hessian[i * size_hessian + j] = hessian[j * size_hessian + i]; - hessian[j * size_hessian + i] = tmp; - } - } - - std::vector error(size); + std::vector error(input_size); loco::TensorShape dimension_channel_last; dimension_channel_last.rank(4); @@ 
-401,35 +410,34 @@ void asymmetric_wquant_per_channel(CircleConst *node, std::vector &min, dimension_hessian.dim(0).set(size_hessian); dimension_hessian.dim(1).set(size_hessian); - quantize = [&](uint32_t *indices, loco::TensorShape &dimension, int channel_dim_index) { - quantized_values[cal_offset(dimension, indices)] = calculate_qauntized_value( - node, indices, dimension, channel_dim_index, scaling_factor, nudged_max, nudged_min); + quantize = [&](uint32_t *indices, loco::TensorShape &dimension_input, int index_channel_dim) { + quantized_values[cal_offset(dimension_input, indices)] = calculate_qauntized_value( + node, indices, dimension_input, index_channel_dim, scaling_factor, nudged_max, nudged_min); uint32_t indices_channel_last[4] = { indices[0], indices[3], indices[1], indices[2] // ohwi -> oihw }; - - uint32_t dimension_channel_last[4] = {dimension.dim(0).value(), dimension.dim(3).value(), - dimension.dim(1).value(), dimension.dim(2).value()}; + uint32_t dimension_channel_last[4] = { + dimension_input.dim(0).value(), dimension_input.dim(3).value(), + dimension_input.dim(1).value(), dimension_input.dim(2).value()}; uint32_t idx_quant_column = dimension_channel_last[2] * dimension_channel_last[3] * indices_channel_last[1] + dimension_channel_last[3] * indices_channel_last[2] + indices_channel_last[3]; + uint32_t idx_channel = indices[index_channel_dim]; uint32_t indices_diag_hessian[2] = {idx_quant_column, idx_quant_column}; - uint32_t channel_idx = indices[channel_dim_index]; + auto idx_input_data = cal_offset(dimension_input, indices); + auto idx_hessian = cal_offset_2d(dimension_hessian, indices_diag_hessian); - auto data_indices = cal_offset(dimension, indices); - auto hessian_indices = cal_offset_2d(dimension_hessian, indices_diag_hessian); - - auto data = node->at(data_indices); + auto input_data = node->at(idx_input_data); auto quantized_rvalue = - (quantized_values[data_indices] - zp[channel_idx]) * scaling_factor[channel_idx]; + 
(quantized_values[idx_input_data] - zp[idx_channel]) * scaling_factor[idx_channel]; - error[data_indices] = (data - quantized_rvalue) / hessian[hessian_indices]; + error[idx_input_data] = (input_data - quantized_rvalue) / hessian[idx_hessian]; - if (channel_idx == (dimension.dim(channel_dim_index).value() - 1)) + if (idx_channel == (dimension_input.dim(index_channel_dim).value() - 1)) { for (uint32_t o = 0; o < dimension_channel_last[0]; o++) { @@ -439,19 +447,20 @@ void asymmetric_wquant_per_channel(CircleConst *node, std::vector &min, { for (uint32_t w = 0; w < dimension_channel_last[3]; w++) { - // convert coordination + // Convert coordination uint32_t indices_channel_first[4] = {o, h, w, i}; uint32_t indices_error[4] = {o, indices[1], indices[2], indices[3]}; uint32_t idx_ihw = dimension_channel_last[2] * dimension_channel_last[3] * i + dimension_channel_last[3] * h + w; uint32_t indices_hessain[2] = {idx_quant_column, idx_ihw}; - auto _h_indices = cal_offset_2d(dimension_hessian, indices_hessain); - auto _data_indices = cal_offset(dimension, indices_channel_first); - auto _error_indices = cal_offset(dimension, indices_error); + auto _idx_h = cal_offset_2d(dimension_hessian, indices_hessain); + auto _idx_input_data = cal_offset(dimension_input, indices_channel_first); + auto _idx_error = cal_offset(dimension_input, indices_error); - node->at(_data_indices) -= - error[_error_indices] * hessian[_h_indices]; + // Compensate quantize error + node->at(_idx_input_data) -= + error[_idx_error] * hessian[_idx_h]; } } } @@ -461,9 +470,9 @@ void asymmetric_wquant_per_channel(CircleConst *node, std::vector &min, iterate_per_channel_with_order(node, quantize, true); } - node->dtype(loco::DataType::U8); // Change the type of tensor - node->size(size); // Resize tensor - for (uint32_t i = 0; i < size; ++i) + node->dtype(loco::DataType::U8); // Change the type of tensor + node->size(input_size); // Resize tensor + for (uint32_t i = 0; i < input_size; ++i) { node->at(i) = 
std::min(kMaxScale, std::max(kMinScale, quantized_values[i])); } @@ -476,11 +485,11 @@ void asymmetric_wdequant_per_channel(CircleConst *node, std::vector &scal uint32_t size = node->size(); std::vector dequantized_values(size); - auto dequantize = [&](uint32_t *indices, loco::TensorShape &dimension, int channel_dim_index) { - int channel_idx = indices[channel_dim_index]; + auto dequantize = [&](uint32_t *indices, loco::TensorShape &dimension, int index_channel_dim) { + int idx_channel = indices[index_channel_dim]; auto data = node->at(cal_offset(dimension, indices)); dequantized_values[cal_offset(dimension, indices)] = - static_cast(data) * scaling_factor[channel_idx] + nudged_min[channel_idx]; + static_cast(data) * scaling_factor[idx_channel] + nudged_min[idx_channel]; }; iterate_per_channel_with_order(node, dequantize, false);