mirror of https://github.com/ggml-org/llama.cpp.git
	arm64: optimize q6_k_q8_k kernel with i8mm (#13519)
This PR improves the q6_k_q8_k GEMM kernel using the arm64 i8mm instruction.
Tested on Neoverse N2 with a Llama 3 8B Q6_K-quantized model:
- 40% ~ 54% S_PP (prompt processing) uplift for all batch sizes
- 16% ~ 47% S_TG (text generation) uplift for batch size 4 and above
Perplexity does not change with this PR.
```
// tested on neoverse-n2
$ llama-batched-bench \
      -m Meta-Llama-3-8B-Instruct-Q6_K.gguf \
      --no-mmap -fa \
      -c 8192 -b 4096 -ub 512 -npp 128 -ntg 128 \
      -npl 1,2,4,8,16,32 \
      -t 64
---------------------------------------------------------------------
|    PP |     TG |    B |       S_PP t/s      |       S_TG t/s      |
|       |        |      | original |  this pr | original |  this pr |
|-------|--------|------|----------|----------|----------|----------|
|   128 |    128 |    1 |    78.52 |   109.18 |    18.63 |    18.88 |
|   128 |    128 |    2 |    84.62 |   123.94 |    34.54 |    36.92 |
|   128 |    128 |    4 |    84.36 |   122.49 |    52.65 |    61.32 |
|   128 |    128 |    8 |    90.52 |   138.87 |    63.46 |    84.41 |
|   128 |    128 |   16 |    90.11 |   138.56 |    71.04 |   101.33 |
|   128 |    128 |   32 |    89.81 |   137.79 |    75.14 |   110.47 |
---------------------------------------------------------------------
```
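The optimization is built on the i8mm SMMLA instruction, exposed as the `vmmlaq_s32` intrinsic: both operands are treated as 2x8 int8 matrices and the 2x2 int32 product is accumulated into the destination, i.e. four independent 8-element dot products per instruction. The kernel in the diff below zips the low and high 8-byte halves of two x rows and two y rows so that a pair of SMMLAs yields all four x{0,1}·y{0,1} combinations at once. The following is a minimal standalone sketch of the intrinsic's semantics checked against a scalar loop; it is illustrative only and not part of the PR (build line is an assumption, e.g. `gcc -O2 -march=armv8.2-a+i8mm`):

```
// minimal sketch (not part of this PR): what SMMLA / vmmlaq_s32 computes.
// a and b are treated as 2x8 int8 matrices (row-major); the 2x2 int32
// product a * b^T is accumulated, i.e. lane[2*i + j] += dot(a_row_i, b_row_j).
#include <arm_neon.h>
#include <stdio.h>

int main(void) {
    int8_t a[16], b[16];
    for (int i = 0; i < 16; ++i) {
        a[i] = (int8_t)(i - 8);
        b[i] = (int8_t)(3 - i);
    }

    int32x4_t acc = vdupq_n_s32(0);
    acc = vmmlaq_s32(acc, vld1q_s8(a), vld1q_s8(b));

    int32_t out[4];
    vst1q_s32(out, acc);

    // scalar reference: out[2*i + j] == sum_k a[8*i + k] * b[8*j + k]
    for (int i = 0; i < 2; ++i) {
        for (int j = 0; j < 2; ++j) {
            int32_t ref = 0;
            for (int k = 0; k < 8; ++k) {
                ref += a[8*i + k] * b[8*j + k];
            }
            printf("lane %d: smmla = %d, scalar = %d\n", 2*i + j, out[2*i + j], ref);
        }
    }
    return 0;
}
```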
			
			
```
@@ -8519,7 +8519,11 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi

void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
    assert(n % QK_K == 0);
#ifdef __ARM_FEATURE_MATMUL_INT8
    assert((nrc == 2) || (nrc == 1));
#else
    assert(nrc == 1);
#endif
    UNUSED(nrc);
    UNUSED(bx);
    UNUSED(by);
@@ -8530,6 +8534,197 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi

    const int nb = n / QK_K;

#if defined(__ARM_FEATURE_MATMUL_INT8)
    if (nrc == 2) {
        const block_q6_K * GGML_RESTRICT x0 = x;
        const block_q6_K * GGML_RESTRICT x1 = (const block_q6_K *) ((const uint8_t *)vx + bx);
        const block_q8_K * GGML_RESTRICT y0 = y;
        const block_q8_K * GGML_RESTRICT y1 = (const block_q8_K *) ((const uint8_t *)vy + by);

        float32x4_t vfsum = vdupq_n_f32(0.0f);

        for (int i = 0; i < nb; ++i, ++x0, ++x1, ++y0, ++y1) {
            const uint8_t * GGML_RESTRICT ql0 = x0->ql;
            const uint8_t * GGML_RESTRICT ql1 = x1->ql;
            const uint8_t * GGML_RESTRICT qh0 = x0->qh;
            const uint8_t * GGML_RESTRICT qh1 = x1->qh;
            const  int8_t * GGML_RESTRICT qy0 = y0->qs;
            const  int8_t * GGML_RESTRICT qy1 = y1->qs;

            const uint8x16_t mone = vdupq_n_u8(0x30);
            const uint8x16_t  m4b = vdupq_n_u8(0x0f);

            int32x4_t visum = vdupq_n_s32(0);

            // process 8 blocks per iteration, 16 blocks in total
            for (int j = 0; j < 2; ++j, qh0 += 32, ql0 += 64, qh1 += 32, ql1 += 64) {
                int8x16_t vx0[8], vx1[8];

                // de-quantize vx0[8]
                {
                    const uint8x16x2_t qh_bits = vld1q_u8_x2(qh0);
                    const uint8x16x4_t ql_bits = vld1q_u8_x4(ql0);

                    uint8x16_t q6h_0 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[0], 4));
                    uint8x16_t q6h_1 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[1], 4));
                    uint8x16_t q6h_2 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[0], 2));
                    uint8x16_t q6h_3 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[1], 2));

                    vx0[0] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[0], m4b), q6h_0));
                    vx0[1] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[1], m4b), q6h_1));
                    vx0[2] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[2], m4b), q6h_2));
                    vx0[3] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[3], m4b), q6h_3));

                    q6h_0 = vandq_u8(mone, qh_bits.val[0]);
                    q6h_1 = vandq_u8(mone, qh_bits.val[1]);
                    q6h_2 = vandq_u8(mone, vshrq_n_u8(qh_bits.val[0], 2));
                    q6h_3 = vandq_u8(mone, vshrq_n_u8(qh_bits.val[1], 2));

                    vx0[4] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[0], 4), q6h_0));
                    vx0[5] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[1], 4), q6h_1));
                    vx0[6] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[2], 4), q6h_2));
                    vx0[7] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[3], 4), q6h_3));
                }

                // de-quantize vx1[8]
                {
                    const uint8x16x2_t qh_bits = vld1q_u8_x2(qh1);
                    const uint8x16x4_t ql_bits = vld1q_u8_x4(ql1);

                    uint8x16_t q6h_0 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[0], 4));
                    uint8x16_t q6h_1 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[1], 4));
                    uint8x16_t q6h_2 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[0], 2));
                    uint8x16_t q6h_3 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[1], 2));

                    vx1[0] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[0], m4b), q6h_0));
                    vx1[1] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[1], m4b), q6h_1));
                    vx1[2] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[2], m4b), q6h_2));
                    vx1[3] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[3], m4b), q6h_3));

                    q6h_0 = vandq_u8(mone, qh_bits.val[0]);
                    q6h_1 = vandq_u8(mone, qh_bits.val[1]);
                    q6h_2 = vandq_u8(mone, vshrq_n_u8(qh_bits.val[0], 2));
                    q6h_3 = vandq_u8(mone, vshrq_n_u8(qh_bits.val[1], 2));

                    vx1[4] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[0], 4), q6h_0));
                    vx1[5] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[1], 4), q6h_1));
                    vx1[6] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[2], 4), q6h_2));
                    vx1[7] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[3], 4), q6h_3));
                }

                // process 16 elements (one block with same scale) per iteration
                // - vx = concat(ql, qh) - 32
                // - r1,r2,r3,r4 = smmla(vx, vy)
                for (int k = 0; k < 8; ++k) {
                    const int blk = j * 8 + k;

                    const int8x16_t vy0 = vld1q_s8(qy0);
                    const int8x16_t vy1 = vld1q_s8(qy1);
                    qy0 += 16;
                    qy1 += 16;

                    const int32x4_t block_scale = {
                        x0->scales[blk],
                        x0->scales[blk],
                        x1->scales[blk],
                        x1->scales[blk],
                    };

                    // calculate four results at once with outer product
                    const int8x16_t vx_l = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(vx0[k]), vreinterpretq_s64_s8(vx1[k])));
                    const int8x16_t vx_h = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(vx0[k]), vreinterpretq_s64_s8(vx1[k])));
                    const int8x16_t vy_l = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(vy0), vreinterpretq_s64_s8(vy1)));
                    const int8x16_t vy_h = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(vy0), vreinterpretq_s64_s8(vy1)));
                    int32x4_t vr = vdupq_n_s32(0);
                    vr = vmmlaq_s32(vr, vx_l, vy_l);
                    vr = vmmlaq_s32(vr, vx_h, vy_h);

                    // apply block scale, will NOT overflow
                    // block_scale * sum_256(int6*int8) <= 2^(8+8+6+8) = 30 bits
                    visum = vmlaq_s32(visum, vr, block_scale);
                }
            }

            // adjust bias, apply superblock scale
            {
                int32_t bias[4];
#ifdef __ARM_FEATURE_SVE
                const svbool_t pg16_8 = svptrue_pat_b16(SV_VL8);
                const svbool_t pg8_8 = svptrue_pat_b8(SV_VL8);
                const svint16_t y0_q8sums_0 = svld1_s16(pg16_8, y0->bsums);
                const svint16_t y0_q8sums_1 = svld1_s16(pg16_8, y0->bsums + 8);
                const svint16_t y1_q8sums_0 = svld1_s16(pg16_8, y1->bsums);
                const svint16_t y1_q8sums_1 = svld1_s16(pg16_8, y1->bsums + 8);
                const svint16_t x0_q6scales_0 = svunpklo_s16(svld1_s8(pg8_8, x0->scales));
                const svint16_t x0_q6scales_1 = svunpklo_s16(svld1_s8(pg8_8, x0->scales + 8));
                const svint16_t x1_q6scales_0 = svunpklo_s16(svld1_s8(pg8_8, x1->scales));
                const svint16_t x1_q6scales_1 = svunpklo_s16(svld1_s8(pg8_8, x1->scales + 8));
                const svint64_t zero = svdup_n_s64(0);
                bias[0] = svaddv_s64(svptrue_b64(), svadd_s64_x(svptrue_b64(), svdot_s64(zero, y0_q8sums_0, x0_q6scales_0),
                                                                               svdot_s64(zero, y0_q8sums_1, x0_q6scales_1)));
                bias[1] = svaddv_s64(svptrue_b64(), svadd_s64_x(svptrue_b64(), svdot_s64(zero, y1_q8sums_0, x0_q6scales_0),
                                                                               svdot_s64(zero, y1_q8sums_1, x0_q6scales_1)));
                bias[2] = svaddv_s64(svptrue_b64(), svadd_s64_x(svptrue_b64(), svdot_s64(zero, y0_q8sums_0, x1_q6scales_0),
                                                                               svdot_s64(zero, y0_q8sums_1, x1_q6scales_1)));
                bias[3] = svaddv_s64(svptrue_b64(), svadd_s64_x(svptrue_b64(), svdot_s64(zero, y1_q8sums_0, x1_q6scales_0),
                                                                               svdot_s64(zero, y1_q8sums_1, x1_q6scales_1)));
#else
                // NEON doesn't support an int16 dot product, so fall back to separate multiplies and adds
                const int16x8x2_t q8sums0 = vld1q_s16_x2(y0->bsums);
                const int16x8x2_t q8sums1 = vld1q_s16_x2(y1->bsums);

                int8x16_t scales_s8 = vld1q_s8(x0->scales);
                const int16x8x2_t q6scales0 = {{vmovl_s8(vget_low_s8(scales_s8)), vmovl_s8(vget_high_s8(scales_s8))}};
                scales_s8 = vld1q_s8(x1->scales);
                const int16x8x2_t q6scales1 = {{vmovl_s8(vget_low_s8(scales_s8)), vmovl_s8(vget_high_s8(scales_s8))}};

                int32x4_t prod;
                prod = vaddq_s32(vaddq_s32(vmull_s16(vget_low_s16 (q8sums0.val[0]), vget_low_s16 (q6scales0.val[0])),
                                           vmull_s16(vget_high_s16(q8sums0.val[0]), vget_high_s16(q6scales0.val[0]))),
                                 vaddq_s32(vmull_s16(vget_low_s16 (q8sums0.val[1]), vget_low_s16 (q6scales0.val[1])),
                                           vmull_s16(vget_high_s16(q8sums0.val[1]), vget_high_s16(q6scales0.val[1]))));
                bias[0] = vaddvq_s32(prod);
                prod = vaddq_s32(vaddq_s32(vmull_s16(vget_low_s16 (q8sums1.val[0]), vget_low_s16 (q6scales0.val[0])),
                                           vmull_s16(vget_high_s16(q8sums1.val[0]), vget_high_s16(q6scales0.val[0]))),
                                 vaddq_s32(vmull_s16(vget_low_s16 (q8sums1.val[1]), vget_low_s16 (q6scales0.val[1])),
                                           vmull_s16(vget_high_s16(q8sums1.val[1]), vget_high_s16(q6scales0.val[1]))));
                bias[1] = vaddvq_s32(prod);
                prod = vaddq_s32(vaddq_s32(vmull_s16(vget_low_s16 (q8sums0.val[0]), vget_low_s16 (q6scales1.val[0])),
                                           vmull_s16(vget_high_s16(q8sums0.val[0]), vget_high_s16(q6scales1.val[0]))),
                                 vaddq_s32(vmull_s16(vget_low_s16 (q8sums0.val[1]), vget_low_s16 (q6scales1.val[1])),
                                           vmull_s16(vget_high_s16(q8sums0.val[1]), vget_high_s16(q6scales1.val[1]))));
                bias[2] = vaddvq_s32(prod);
                prod = vaddq_s32(vaddq_s32(vmull_s16(vget_low_s16 (q8sums1.val[0]), vget_low_s16 (q6scales1.val[0])),
                                           vmull_s16(vget_high_s16(q8sums1.val[0]), vget_high_s16(q6scales1.val[0]))),
                                 vaddq_s32(vmull_s16(vget_low_s16 (q8sums1.val[1]), vget_low_s16 (q6scales1.val[1])),
                                           vmull_s16(vget_high_s16(q8sums1.val[1]), vget_high_s16(q6scales1.val[1]))));
                bias[3] = vaddvq_s32(prod);

#endif
                const int32x4_t vibias = vmulq_n_s32(vld1q_s32(bias), 32);

                const float32x4_t superblock_scale = {
                    GGML_FP16_TO_FP32(x0->d) * y0->d,
                    GGML_FP16_TO_FP32(x0->d) * y1->d,
                    GGML_FP16_TO_FP32(x1->d) * y0->d,
                    GGML_FP16_TO_FP32(x1->d) * y1->d,
                };

                visum = vsubq_s32(visum, vibias);
                vfsum = vmlaq_f32(vfsum, vcvtq_f32_s32(visum), superblock_scale);
            }
        }

        // vfsum = ABCD -> ACBD
        // AC -> s, BD -> (s+bs)
        vfsum = vzip1q_f32(vfsum, vextq_f32(vfsum, vfsum, 2));
        vst1_f32(s,      vget_low_f32 (vfsum));
        vst1_f32(s + bs, vget_high_f32(vfsum));

        return;
    }
#endif

#ifdef __ARM_FEATURE_SVE
    const int vector_length = ggml_cpu_get_sve_cnt()*8;
    float sum = 0;
```
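The bias adjustment near the end of the hunk above exists because Q6_K stores each 6-bit code as an unsigned value in [0, 63], while the actual quant is that code minus 32. The SMMLA loop accumulates the unsigned codes directly, and the constant -32 offset is removed afterwards using the per-16-element sums already present in `block_q8_K::bsums`. A hedged scalar sketch of one x/y superblock pair follows; the function name and the pre-unpacked `q6[]` array are illustrative, not from the PR:

```
#include <stdint.h>

// Illustrative scalar reference for one superblock pair (not from the PR).
// q6[k] is the already-unpacked 6-bit code in [0, 63] (no -32 applied),
// scales[k/16] is the per-16-element block scale, bsums[b] is the sum of
// 16 consecutive q8 values, and d6/d8 are the superblock scales.
float q6_q8_superblock_ref(const uint8_t q6[256], const int8_t scales[16], float d6,
                           const int8_t q8[256], const int16_t bsums[16], float d8) {
    int32_t isum = 0;   // what the smmla loop accumulates into visum
    int32_t bias = 0;   // what the svdot/vmull section accumulates into bias[]
    for (int k = 0; k < 256; ++k) {
        isum += scales[k / 16] * q6[k] * q8[k];
    }
    for (int b = 0; b < 16; ++b) {
        bias += scales[b] * bsums[b];
    }
    // sum(scale * (q6 - 32) * q8) == isum - 32 * bias, then apply the superblock scales
    return d6 * d8 * (float)(isum - 32 * bias);
}
```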
```
@@ -282,7 +282,11 @@ static const struct ggml_type_traits_cpu type_traits_cpu[GGML_TYPE_COUNT] = {
        .from_float               = quantize_row_q6_K,
        .vec_dot                  = ggml_vec_dot_q6_K_q8_K,
        .vec_dot_type             = GGML_TYPE_Q8_K,
#if defined (__ARM_FEATURE_MATMUL_INT8)
        .nrows                    = 2,
#else
        .nrows                    = 1,
#endif
    },
    [GGML_TYPE_IQ2_XXS] = {
        .from_float               = NULL,
```
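The second hunk sets `.nrows = 2` for Q6_K when `__ARM_FEATURE_MATMUL_INT8` is available, so the CPU backend hands the kernel two result rows per call (`nrc == 2`) together with the row strides `bx`, `by`, and `bs`; the i8mm path then computes both dot products at once and writes them to `s[0]` and `s[bs]`, falling back to the existing single-row path otherwise.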
Yibo Cai