
(WIP) bitnet and t-mac #23540

Open · wants to merge 6 commits into main
Conversation

liqunfu
Contributor

@liqunfu liqunfu commented Jan 30, 2025

Preparation for 2-bit T-MAC and ternary BitNet implementation.

  1. The 2-bit T-MAC implementation is to be added at onnxruntime/core/mlas/lib/sqnbitgemm_bitnet_kernel_avx2.cpp and qnbitgemm_kernel_neon.cpp (Q2BitGemmXXX).
  • Q2BitGemmPackQuantBDataSize returns the size of the packed quantized weights so that MLAS can allocate memory for them.
  • SQ2BitGemmPackQuantBData packs the quantized weights. See SQ4BitGemmPackQuantBData for reference.
  • Q2BitGemmPerGemmWorkspaceSize returns the size of the workspace needed for the activation. It is likely the same as for 4-bit.
  • SQ2BitGemmKernel_CompInt8 does the matmul compute: it takes quantA and quantB and computes the output. SQ2BitGemmKernel_CompInt8 shall be called from SQ2BitGemm_CompInt8, which needs to be implemented too; see SQ4BitGemm_CompInt8 for reference.
  2. The BitNet implementation can be added later.
  3. Tests for the MLAS functions are at onnxruntime\test\mlas\unittest\test_sqnbitgemm.cpp; enable them by uncommenting SQNBitGemmShortExecuteTest<2, blklen>::RegisterShortExecuteTests();.
  4. The MatMulNBits kernel implementation is at onnxruntime\contrib_ops\cpu\quantization\matmul_nbits.cc.
  5. Tests for the MatMulNBits kernel are at onnxruntime\test\contrib_ops\matmul_4bits_test.cc; enable them by enabling DISABLED_Float32_Accuracy4_Q2.
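The packing helpers described in step 1 can be sketched roughly as follows. This is an illustrative sketch, not the actual MLAS implementation: DivRoundup stands in for MlasDivRoundup, and the packing layout (four 2-bit values per byte, lowest bits first) is an assumption, not confirmed by the PR.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical helper mirroring MlasDivRoundup from the source tree.
inline size_t DivRoundup(size_t a, size_t b) { return (a + b - 1) / b; }

// In the spirit of Q2BitGemmPackQuantBDataSize: N columns, K rows,
// 2 bits per element, so 4 elements per byte.
size_t Q2BitPackedDataSize(size_t N, size_t K) {
    return N * DivRoundup(K, 4);  // bytes needed for 2-bit packed weights
}

// In the spirit of SQ2BitGemmPackQuantBData: pack four 2-bit values
// (each 0..3) into one byte, lowest-order pair first.
std::vector<uint8_t> PackQuant2Bit(const std::vector<uint8_t>& q) {
    std::vector<uint8_t> packed(DivRoundup(q.size(), 4), 0);
    for (size_t i = 0; i < q.size(); ++i) {
        packed[i / 4] |= static_cast<uint8_t>((q[i] & 0x3) << ((i % 4) * 2));
    }
    return packed;
}
```

The real kernels also interleave scales and reorder blocks for SIMD-friendly access; this sketch covers only the bit packing itself.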

@liqunfu liqunfu requested a review from a team as a code owner January 30, 2025 03:12
@jywu-msft
Member

Can you add description/context?

@liqunfu liqunfu changed the title bitnet and t-mac (WIP) bitnet and t-mac Jan 31, 2025
@@ -402,7 +402,8 @@
struct BlockwiseQuantizer {
// To support other qbits, need to add bit packing code for
// storing to dst and zero points
static_assert(qbits == 4, "Only 4b block quantization is supported!");
static_assert(qbits == 4 || qbits == 2, "Only 4b block quantization is supported!");
//static_assert(qbits != 2 || Columnwise, "Only support Columnwise in qbits == 2 case.");

Check notice

Code scanning / CodeQL

Commented-out code Note

This comment appears to contain commented-out code.
Signed-off-by: Liqun Fu <[email protected]>
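The widened static_assert in the diff above can be exercised with a minimal stand-in. BlockwiseQuantizerSketch below is hypothetical, not the real class; it only illustrates how the qbits template parameter gates instantiation to 2 or 4 bits and determines how many elements fit per byte.

```cpp
// Simplified stand-in for the quantizer template touched by the diff.
// Note the assert message is updated to match the widened condition.
template <typename ElementT, int qbits, bool Columnwise>
struct BlockwiseQuantizerSketch {
    static_assert(qbits == 4 || qbits == 2,
                  "Only 4b and 2b block quantization are supported!");
    // Elements packed per byte follow directly from the bit width.
    static constexpr int kElemsPerByte = 8 / qbits;
};
```

Instantiating the template with any other qbits value fails at compile time, which is the point of the guard.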
Comment on lines +63 to +74
switch (ComputeType) {
case SQNBIT_CompInt8: {
// workspace buffer is used for block quantization of A to int8
const size_t BlockCountK = MlasDivRoundup(K, BlkLen);
// QuantData + Scale
const size_t PerGemmWorkspaceSize = M * BlockCountK * Q8BlkSize(BlkLen);
return PerGemmWorkspaceSize;
}
default: {
return 0;
}
}

Check notice

Code scanning / CodeQL

No trivial switch statements Note

This switch statement should either handle more cases, or be rewritten as an if statement.
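The workspace computation in the snippet above can be restated as a standalone sketch. The assumption, suggested by the "QuantData + Scale" comment but not confirmed by the source, is that each int8 block stores one float scale followed by BlkLen quantized values; Q8BlkSizeSketch and PerGemmWorkspaceSizeSketch are illustrative names, not the MLAS helpers.

```cpp
#include <cstddef>

// Stand-in for MlasDivRoundup.
inline size_t RoundUpDiv(size_t a, size_t b) { return (a + b - 1) / b; }

// Assumed block layout: one float scale, then BlkLen int8 values.
inline size_t Q8BlkSizeSketch(size_t BlkLen) {
    return sizeof(float) + BlkLen;
}

// Per-GEMM workspace for SQNBIT_CompInt8: M rows of A, each quantized
// blockwise along the K dimension.
size_t PerGemmWorkspaceSizeSketch(size_t M, size_t K, size_t BlkLen) {
    const size_t BlockCountK = RoundUpDiv(K, BlkLen);
    return M * BlockCountK * Q8BlkSizeSketch(BlkLen);
}
```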
Signed-off-by: Liqun Fu <[email protected]>
Signed-off-by: Liqun Fu <[email protected]>
}
const uint8_t vi1 = weights[j * q_rows + i / 2] >> 4;
const float v1 = (static_cast<float>(vi1) - zp1) * scale1;
dst[j * rows + (i + 1)] = static_cast<ElementT>(v1);
}
}
Contributor

maybe separate the template specializations for qnbits so the code would be cleaner?
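For comparison with the 4-bit unpacking shown in the diff (two weights per byte, extracted with a shift), a 2-bit dequantization loop might look like the sketch below. Dequant2Bit is hypothetical and simplifies zero-point and scale handling to a single value each; the real code tracks them per block.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// 2-bit analogue of the 4-bit dequantization above: four weights per
// byte, each extracted with a shift and a 0x3 mask.
std::vector<float> Dequant2Bit(const std::vector<uint8_t>& packed,
                               size_t count, float zp, float scale) {
    std::vector<float> dst(count);
    for (size_t i = 0; i < count; ++i) {
        const uint8_t vi = (packed[i / 4] >> ((i % 4) * 2)) & 0x3;
        dst[i] = (static_cast<float>(vi) - zp) * scale;
    }
    return dst;
}
```

Separate template specializations per qbits, as the reviewer suggests, would keep the 2-bit and 4-bit extraction paths from sharing one loop full of conditionals.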

@fajin-corp
Contributor

                    range2scale<Tin, 4, signed_quant>(vmin_t[i + 1], vmax_t[i + 1], scale1_tt);

this might be wrong


Refers to: onnxruntime/core/mlas/lib/q4_dq.cpp:983 in b4aad01. [](commit_id = b4aad01, deletion_comment = False)
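The concern appears to be that the bit width 4 is hardcoded where a 2-bit path may also run. A typical range-to-scale mapping depends on the number of quantization levels, 2^qbits − 1, as this illustrative formula shows; RangeToScaleSketch is not the actual range2scale implementation.

```cpp
// Illustrative asymmetric range-to-scale mapping over [0, 2^qbits - 1].
// A hardcoded qbits of 4 would compute the wrong scale on a 2-bit path.
template <int qbits>
float RangeToScaleSketch(float vmin, float vmax) {
    constexpr float levels = static_cast<float>((1 << qbits) - 1);
    return (vmax - vmin) / levels;
}
```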


template <typename Tin, bool signed_quant>
struct BlockwiseQDQQuantizer<Tin, 4, signed_quant> {
struct BlockwiseQDQQuantizer {
Contributor

would it be better to separate the specializations for different qbits?

}

size_t
SQ2BitGemmKernel_CompInt8_avx2(
Contributor

SQ2BitGemmKernel_CompInt8_avx2

avx2 kernel should not appear in this file

@fajin-corp
Contributor

typedef size_t(SQ4BitGemmKernel_CompInt8_Fn)(

rename to SQNBitGemmKernel_CompInt8_Fn?


Refers to: onnxruntime/core/mlas/lib/qnbitgemm.h:338 in b4aad01. [](commit_id = b4aad01, deletion_comment = False)

@fajin-corp
Contributor

typedef size_t(Q4BitGemmPackQuantBDataSize_Fn)(

rename to QNBitGemmPackQuantBDataSize_Fn?


Refers to: onnxruntime/core/mlas/lib/qnbitgemm.h:94 in b4aad01. [](commit_id = b4aad01, deletion_comment = False)

@@ -113,6 +120,7 @@ struct MLAS_QNBIT_GEMM_DISPATCH {

Q4BitGemmPackQuantBData_Fn* SQ4BitGemmPackQuantBData = nullptr;
Q4BitGemmPackQuantBData_Fn* HQ4BitGemmPackQuantBData = nullptr;
Q4BitGemmPackQuantBData_Fn* SQ2BitGemmPackQuantBData = nullptr;
Contributor

Q4BitGemmPackQuantBData_Fn

rename to QNBitGemmPackQuantBData_Fn
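The rename suggestions above concern a dispatch table of function pointers shared across bit widths: one typedef serves 4-bit, 2-bit, and future kernels, hence the proposed QNBit prefix. A minimal sketch of that pattern, with all names as illustrative stand-ins rather than the real MLAS_QNBIT_GEMM_DISPATCH members:

```cpp
#include <cstddef>

// One shared signature for all bit widths (the motivation for the rename).
typedef size_t (QNBitGemmPackQuantBDataSizeSketch_Fn)(size_t N, size_t K);

// Per-width implementations: 2 elements/byte at 4 bits, 4 at 2 bits.
size_t Pack4BitSizeSketch(size_t N, size_t K) { return N * ((K + 1) / 2); }
size_t Pack2BitSizeSketch(size_t N, size_t K) { return N * ((K + 3) / 4); }

// Dispatch struct with one slot per bit width, filled in per platform.
struct QNBitGemmDispatchSketch {
    QNBitGemmPackQuantBDataSizeSketch_Fn* SQ4BitPackSize = nullptr;
    QNBitGemmPackQuantBDataSizeSketch_Fn* SQ2BitPackSize = nullptr;
};
```

Each platform backend (AVX2, NEON, ...) populates its own instance of the struct, so the shared typedef name should not mention a specific bit width.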
