ggml-cpu : optimize RVV q2_k and q3_k kernels (#16887)

Author:    xctan
Date:      2025-11-07 00:12:45 +08:00
Committer: GitHub
Parent:    aa374175c3
Commit:    7f09a680af


@@ -580,16 +580,19 @@ void ggml_vec_dot_q2_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
 const float dmin = -y[i].d * GGML_CPU_FP16_TO_FP32(x[i].dmin);
 uint8_t *patmp = atmp;
 int vsums;
-int tmp;
+int tmp, t1, t2, t3, t4, t5, t6, t7;
 __asm__ __volatile__(
 "vsetivli zero, 16, e8, m1\n\t"
 "vmv.v.x v8, zero\n\t"
+"lb zero, 15(%[sc])\n\t"
 "vle8.v v1, (%[sc])\n\t"
+"vle8.v v2, (%[bsums])\n\t"
+"addi %[tmp], %[bsums], 16\n\t"
 "vand.vi v0, v1, 0xF\n\t"
 "vsrl.vi v1, v1, 4\n\t"
+"vle8.v v3, (%[tmp])\n\t"
 "vse8.v v0, (%[scale])\n\t"
 "vsetivli zero, 16, e16, m2\n\t"
-"vle16.v v2, (%[bsums])\n\t"
 "vzext.vf2 v0, v1\n\t"
 "vwmul.vv v4, v0, v2\n\t"
 "vsetivli zero, 16, e32, m4\n\t"
@@ -608,46 +611,89 @@ void ggml_vec_dot_q2_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
 for (int j = 0; j < QK_K/128; ++j) {
 __asm__ __volatile__(
-"vsetvli zero, %[vl32], e8, m2\n\t"
+"lb zero, 31(%[q2])\n\t"
+"addi %[tmp], %[q2], 16\n\t"
+"addi %[t1], %[q8], 16\n\t"
+"vsetivli zero, 16, e8, m1\n\t"
 "vle8.v v0, (%[q2])\n\t"
+"vle8.v v1, (%[tmp])\n\t"
 "vsrl.vi v2, v0, 2\n\t"
+"vsrl.vi v3, v1, 2\n\t"
 "vsrl.vi v4, v0, 4\n\t"
-"vsrl.vi v6, v0, 6\n\t"
-"vand.vi v0, v0, 0x3\n\t"
-"vand.vi v2, v2, 0x3\n\t"
-"vand.vi v4, v4, 0x3\n\t"
-"vsetvli zero, %[vl128], e8, m8\n\t"
+"addi %[tmp], %[q8], 32\n\t"
 "vle8.v v8, (%[q8])\n\t"
-"vsetvli zero, %[vl64], e8, m4\n\t"
+"vle8.v v9, (%[t1])\n\t"
+"addi %[t1], %[t1], 32\n\t"
+"vsrl.vi v5, v1, 4\n\t"
+"vsrl.vi v6, v0, 6\n\t"
+"vsrl.vi v7, v1, 6\n\t"
+"vle8.v v10, (%[tmp])\n\t"
+"vle8.v v11, (%[t1])\n\t"
+"addi %[tmp], %[tmp], 32\n\t"
+"addi %[t1], %[t1], 32\n\t"
+"vand.vi v0, v0, 0x3\n\t"
+"vand.vi v1, v1, 0x3\n\t"
+"vand.vi v2, v2, 0x3\n\t"
+"vle8.v v12, (%[tmp])\n\t"
+"vle8.v v13, (%[t1])\n\t"
+"addi %[tmp], %[tmp], 32\n\t"
+"addi %[t1], %[t1], 32\n\t"
+"vand.vi v3, v3, 0x3\n\t"
+"vand.vi v4, v4, 0x3\n\t"
+"vand.vi v5, v5, 0x3\n\t"
+"vle8.v v14, (%[tmp])\n\t"
+"vle8.v v15, (%[t1])\n\t"
 "vwmul.vv v16, v0, v8\n\t"
+"vwmul.vv v18, v1, v9\n\t"
+"vwmul.vv v20, v2, v10\n\t"
+"vwmul.vv v22, v3, v11\n\t"
 "vwmul.vv v24, v4, v12\n\t"
-"vsetivli zero, 16, e16, m2\n\t"
+"vwmul.vv v26, v5, v13\n\t"
+"vwmul.vv v28, v6, v14\n\t"
+"vwmul.vv v30, v7, v15\n\t"
+"vsetivli zero, 8, e16, m1\n\t"
 "vmv.v.x v0, zero\n\t"
-"vwredsum.vs v10, v16, v0\n\t"
+"lbu %[tmp], 0(%[scale])\n\t"
+"vwredsum.vs v8, v16, v0\n\t"
 "vwredsum.vs v9, v18, v0\n\t"
-"vwredsum.vs v8, v20, v0\n\t"
-"vwredsum.vs v7, v22, v0\n\t"
-"vwredsum.vs v11, v24, v0\n\t"
-"vwredsum.vs v12, v26, v0\n\t"
-"vwredsum.vs v13, v28, v0\n\t"
-"vwredsum.vs v14, v30, v0\n\t"
+"lbu %[t1], 1(%[scale])\n\t"
+"vwredsum.vs v10, v20, v0\n\t"
+"vwredsum.vs v11, v22, v0\n\t"
+"lbu %[t2], 2(%[scale])\n\t"
+"vwredsum.vs v12, v24, v0\n\t"
+"vwredsum.vs v13, v26, v0\n\t"
+"lbu %[t3], 3(%[scale])\n\t"
+"vwredsum.vs v14, v28, v0\n\t"
+"vwredsum.vs v15, v30, v0\n\t"
+"lbu %[t4], 4(%[scale])\n\t"
+"vwredsum.vs v8, v17, v8\n\t"
+"vwredsum.vs v9, v19, v9\n\t"
+"lbu %[t5], 5(%[scale])\n\t"
+"vwredsum.vs v10, v21, v10\n\t"
+"vwredsum.vs v11, v23, v11\n\t"
+"lbu %[t6], 6(%[scale])\n\t"
+"vwredsum.vs v12, v25, v12\n\t"
+"vwredsum.vs v13, v27, v13\n\t"
+"lbu %[t7], 7(%[scale])\n\t"
+"vwredsum.vs v14, v29, v14\n\t"
+"vwredsum.vs v15, v31, v15\n\t"
 "vsetivli zero, 4, e32, m1\n\t"
-"vslideup.vi v10, v9, 1\n\t"
-"vslideup.vi v8, v7, 1\n\t"
-"vslideup.vi v11, v12, 1\n\t"
-"vslideup.vi v13, v14, 1\n\t"
-"vslideup.vi v10, v8, 2\n\t"
-"vslideup.vi v11, v13, 2\n\t"
-"vsetivli zero, 8, e32, m2\n\t"
-"vle8.v v15, (%[scale])\n\t"
-"vzext.vf4 v12, v15\n\t"
-"vmul.vv v10, v10, v12\n\t"
-"vredsum.vs v0, v10, v0\n\t"
+"vmul.vx v0, v8, %[tmp]\n\t"
+"vmul.vx v1, v9, %[t1]\n\t"
+"vmacc.vx v0, %[t2], v10\n\t"
+"vmacc.vx v1, %[t3], v11\n\t"
+"vmacc.vx v0, %[t4], v12\n\t"
+"vmacc.vx v1, %[t5], v13\n\t"
+"vmacc.vx v0, %[t6], v14\n\t"
+"vmacc.vx v1, %[t7], v15\n\t"
 "vmv.x.s %[tmp], v0\n\t"
-"add %[isum], %[isum], %[tmp]"
-: [tmp] "=&r" (tmp), [isum] "+&r" (isum)
+"vmv.x.s %[t1], v1\n\t"
+"add %[isum], %[isum], %[tmp]\n\t"
+"add %[isum], %[isum], %[t1]"
+: [tmp] "=&r" (tmp), [t1] "=&r" (t1), [t2] "=&r" (t2), [t3] "=&r" (t3)
+, [t4] "=&r" (t4), [t5] "=&r" (t5), [t6] "=&r" (t6), [t7] "=&r" (t7)
+, [isum] "+&r" (isum)
 : [q2] "r" (q2), [scale] "r" (patmp), [q8] "r" (q8)
-, [vl32] "r" (32), [vl64] "r" (64), [vl128] "r" (128)
 : "memory"
 , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7"
 , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15"
@@ -929,7 +975,7 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
 const int8_t * restrict q8 = y[i].qs;
 int8_t * scale = (int8_t *)utmp;
-int tmp;
+int tmp, t1, t2, t3, t4, t5, t6, t7;
 __asm__ __volatile__(
 "vsetivli zero, 12, e8, m1\n\t"
 "vle8.v v0, (%[s6b])\n\t"
@@ -967,19 +1013,23 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
 int isum = 0;
 for (int j = 0; j < QK_K; j += 128) {
 __asm__ __volatile__(
+"lb zero, 31(%[q3])\n\t"
 "vsetvli zero, %[vl32], e8, m2, ta, mu\n\t"
 "vle8.v v8, (%[q3])\n\t"
 "vsrl.vi v10, v8, 2\n\t"
 "vsrl.vi v12, v8, 4\n\t"
 "vsrl.vi v14, v8, 6\n\t"
+"lb zero, 64(%[q8])\n\t"
 "vand.vi v8, v8, 3\n\t"
 "vand.vi v10, v10, 3\n\t"
 "vand.vi v12, v12, 3\n\t"
 "vle8.v v2, (%[qh])\n\t"
+"lb zero, 127(%[q8])\n\t"
 "vand.vx v4, v2, %[m]\n\t"
 "slli %[m], %[m], 1\n\t"
 "vmseq.vx v0, v4, zero\n\t"
 "vadd.vi v8, v8, -4, v0.t\n\t"
+"lb zero, 0(%[q8])\n\t"
 "vand.vx v4, v2, %[m]\n\t"
 "slli %[m], %[m], 1\n\t"
 "vmseq.vx v0, v4, zero\n\t"
@@ -994,34 +1044,43 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
 "vadd.vi v14, v14, -4, v0.t\n\t"
 "vsetvli zero, %[vl128], e8, m8\n\t"
 "vle8.v v0, (%[q8])\n\t"
+"lb %[tmp], 0(%[scale])\n\t"
+"lb %[t1], 1(%[scale])\n\t"
+"lb %[t2], 2(%[scale])\n\t"
+"lb %[t3], 3(%[scale])\n\t"
 "vsetvli zero, %[vl64], e8, m4\n\t"
 "vwmul.vv v16, v0, v8\n\t"
 "vwmul.vv v24, v4, v12\n\t"
 "vsetivli zero, 16, e16, m2\n\t"
 "vmv.v.x v0, zero\n\t"
-"vwredsum.vs v10, v16, v0\n\t"
+"vwredsum.vs v8, v16, v0\n\t"
+"lb %[t4], 4(%[scale])\n\t"
+"lb %[t5], 5(%[scale])\n\t"
 "vwredsum.vs v9, v18, v0\n\t"
-"vwredsum.vs v8, v20, v0\n\t"
-"vwredsum.vs v7, v22, v0\n\t"
-"vwredsum.vs v11, v24, v0\n\t"
-"vwredsum.vs v12, v26, v0\n\t"
-"vwredsum.vs v13, v28, v0\n\t"
-"vwredsum.vs v14, v30, v0\n\t"
+"vwredsum.vs v10, v20, v0\n\t"
+"vwredsum.vs v11, v22, v0\n\t"
+"vwredsum.vs v12, v24, v0\n\t"
+"lb %[t6], 6(%[scale])\n\t"
+"lb %[t7], 7(%[scale])\n\t"
+"vwredsum.vs v13, v26, v0\n\t"
+"vwredsum.vs v14, v28, v0\n\t"
+"vwredsum.vs v15, v30, v0\n\t"
 "vsetivli zero, 4, e32, m1\n\t"
-"vslideup.vi v10, v9, 1\n\t"
-"vslideup.vi v8, v7, 1\n\t"
-"vslideup.vi v11, v12, 1\n\t"
-"vslideup.vi v13, v14, 1\n\t"
-"vslideup.vi v10, v8, 2\n\t"
-"vslideup.vi v11, v13, 2\n\t"
-"vsetivli zero, 8, e32, m2\n\t"
-"vle8.v v15, (%[scale])\n\t"
-"vsext.vf4 v12, v15\n\t"
-"vmul.vv v10, v10, v12\n\t"
-"vredsum.vs v0, v10, v0\n\t"
+"vmul.vx v0, v8, %[tmp]\n\t"
+"vmul.vx v1, v9, %[t1]\n\t"
+"vmacc.vx v0, %[t2], v10\n\t"
+"vmacc.vx v1, %[t3], v11\n\t"
+"vmacc.vx v0, %[t4], v12\n\t"
+"vmacc.vx v1, %[t5], v13\n\t"
+"vmacc.vx v0, %[t6], v14\n\t"
+"vmacc.vx v1, %[t7], v15\n\t"
 "vmv.x.s %[tmp], v0\n\t"
-"add %[isum], %[isum], %[tmp]"
-: [tmp] "=&r" (tmp), [m] "+&r" (m), [isum] "+&r" (isum)
+"vmv.x.s %[t1], v1\n\t"
+"add %[isum], %[isum], %[tmp]\n\t"
+"add %[isum], %[isum], %[t1]"
+: [tmp] "=&r" (tmp), [t1] "=&r" (t1), [t2] "=&r" (t2), [t3] "=&r" (t3)
+, [t4] "=&r" (t4), [t5] "=&r" (t5), [t6] "=&r" (t6), [t7] "=&r" (t7)
+, [m] "+&r" (m), [isum] "+&r" (isum)
 : [vl128] "r" (128), [vl64] "r" (64), [vl32] "r" (32)
 , [q3] "r" (q3), [qh] "r" (qh), [scale] "r" (scale), [q8] "r" (q8)
 : "memory"