
Commit

single attention: add assembly
huangyuyang committed Mar 22, 2024
1 parent 8aa007f commit 932dee0
Showing 2 changed files with 36 additions and 6 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -33,3 +33,4 @@ token
 /example/Android/LLMAssistant/local.properties
 /test/cmmlu/results/
 /models/
+/localtest/
41 changes: 35 additions & 6 deletions src/devices/cpu/cpudevice.cpp
@@ -241,15 +241,44 @@ namespace fastllm {
             qk[j] = -10000;
             continue;
         }
-        float sum = 0.0f;
-        for (int l = 0; l < q2; l++) {
-            sum += qd[i * q2 + l] * kd[j * q2 + l];
+        float now = 0.0f;
+        int l = 0;
+#ifdef __aarch64__
+        float32x4_t sum = {0, 0, 0, 0};
+        for (; l + 3 < q2; l += 4) {
+            sum = vaddq_f32(sum, vmulq_f32(vld1q_f32(qd + i * q2 + l),
+                                           vld1q_f32(kd + j * q2 + l)));
+        }
+        now += sum[0] + sum[1] + sum[2] + sum[3];
+#elif defined(__AVX__)
+        __m256 vsum = _mm256_set1_ps(0.0f);
+        for (; l + 7 < q2; l += 8) {
+            __m256 vx = _mm256_loadu_ps((const float *) (qd + i * q2 + l));
+            __m256 vy = _mm256_loadu_ps((const float *) (kd + j * q2 + l));
+            vsum = _mm256_add_ps(vsum, _mm256_mul_ps(vx, vy));
+        }
+        now += Floatsum(vsum);
+#endif
+        for (; l < q2; l++) {
+            now += qd[i * q2 + l] * kd[j * q2 + l];
         }
-        qk[j] = sum * scale;
-        maxValue = std::max(maxValue, sum * scale);
+        qk[j] = now * scale;
+        maxValue = std::max(maxValue, now * scale);
     }
-    for (int j = 0; j < k1; j++) {
+
+    int j = 0;
+#ifdef __aarch64__
+    float32x4_t vmax = vdupq_n_f32(maxValue);
+    for (; j + 3 < k1; j += 4) {
+        vst1q_f32(temp + j, exp_ps(vsubq_f32(vld1q_f32(qk + j), vmax)));
+    }
+#endif
+    for (; j < k1; j++) {
         temp[j] = expf(qk[j] - maxValue);
     }
+
     sum = 0.0f;
     for (int j = 0; j < k1; j++) {
         sum += temp[j];
     }
     sum = std::max(sum, 0.1f);
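For reference, the NEON and AVX branches in the hunk above vectorize the same attention-score step as the scalar fallback loops beside them: a dot product between one query row and each key row, scaled, with a running maximum tracked for a numerically stable softmax. Below is a condensed scalar sketch of that step, reusing the names from the diff (qd, kd, qk, temp, q2, k1, scale) and omitting the masking and final normalization that surround it in cpudevice.cpp; the function name and the initial maxValue are placeholders, not taken from the commit.

#include <algorithm>
#include <cmath>

// Scalar reference of the step the SIMD paths accelerate:
// qk[j] = scale * dot(q_i, k_j), then temp[j] = expf(qk[j] - max_j qk[j]).
void SingleAttentionScoresSketch(const float *qd, const float *kd, float *qk, float *temp,
                                 int i, int q2, int k1, float scale) {
    float maxValue = -1e30f; // placeholder; the real initial value is outside the shown hunk
    for (int j = 0; j < k1; j++) {
        float now = 0.0f;
        for (int l = 0; l < q2; l++) {
            now += qd[i * q2 + l] * kd[j * q2 + l];
        }
        qk[j] = now * scale;
        maxValue = std::max(maxValue, qk[j]);
    }
    // Subtracting the maximum before exponentiating keeps expf() from overflowing;
    // the aarch64 path does the same thing four lanes at a time with exp_ps.
    for (int j = 0; j < k1; j++) {
        temp[j] = expf(qk[j] - maxValue);
    }
}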

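The AVX branch accumulates eight partial products per iteration in vsum and then collapses the vector with Floatsum, whose definition is not part of this hunk. A minimal sketch of such a horizontal reduction, assuming Floatsum simply returns the sum of the eight float lanes of an __m256 (the name FloatsumSketch is hypothetical):

#include <immintrin.h>

// Hypothetical stand-in for Floatsum: sum the eight float lanes of an __m256.
static inline float FloatsumSketch(__m256 v) {
    // Fold the upper 128-bit half onto the lower half: 8 lanes -> 4 lanes.
    __m128 lo = _mm256_castps256_ps128(v);
    __m128 hi = _mm256_extractf128_ps(v, 1);
    __m128 s = _mm_add_ps(lo, hi);
    // 4 lanes -> 2 lanes: add the high pair onto the low pair.
    s = _mm_add_ps(s, _mm_movehl_ps(s, s));
    // 2 lanes -> 1 lane: add lane 1 onto lane 0.
    s = _mm_add_ss(s, _mm_shuffle_ps(s, s, 0x55));
    return _mm_cvtss_f32(s);
}

Under that assumption, now += Floatsum(vsum) folds the vectorized partial dot product into the scalar accumulator, and the trailing for (; l < q2; l++) loop handles the last q2 % 8 elements; the NEON branch does the equivalent reduction directly with sum[0] + sum[1] + sum[2] + sum[3].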