	ggml-cpu: replace NEON asm with intrinsics in ggml_gemv_q4_0_4x8_q8_0() (#10874)
* ggml-cpu: replace NEON asm with intrinsics in ggml_gemv_q4_0_4x8_q8_0()

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

* ggml-cpu: format code

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

---------

Signed-off-by: Adrien Gallouët <angt@huggingface.co>
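The intrinsics path keeps the same nibble trick the removed asm used: shifting a packed byte left by 4 moves its low nibble into the high bits (the asm's `sshl ..., v2.16b` with v2 = 4), while masking with `0xf0` keeps the high nibble there (the asm's `and ..., v1.16b` with v1 = 0xf0). Either way the signed 4-bit weight enters the `sdot`/`vdotq_s32` accumulation scaled by 16, and the final fixed-point conversion (`scvtf ..., #0x4` in the asm, `vcvtq_n_f32_s32(ret, 4)` in the intrinsics) divides by 2^4 to cancel that factor. A minimal scalar sketch of the idea, with made-up values (an illustration only, not the kernel itself):

#include <stdint.h>
#include <stdio.h>

int main(void) {
    // One byte packs two signed 4-bit weights: low nibble 0xd (-3), high nibble 0x3 (+3).
    uint8_t packed = 0x3d;

    // Both extractions leave the weight in the top 4 bits, i.e. multiplied by 16
    // (two's-complement wraparound on the int8_t cast is assumed, as the kernel assumes).
    int8_t lo16 = (int8_t) (packed << 4);    // -3 * 16 = -48
    int8_t hi16 = (int8_t) (packed & 0xf0);  //  3 * 16 =  48

    int8_t a_lo = 5, a_hi = 7;               // made-up int8 activations

    // Accumulate in int32, then shift the factor of 16 back out, the same
    // cancellation vcvtq_n_f32_s32(ret, 4) performs during int->float conversion.
    int32_t sum = (lo16 * a_lo + hi16 * a_hi) >> 4;

    printf("dot = %d (expect -3*5 + 3*7 = 6)\n", sum);
    return 0;
}

The full diff follows.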
@@ -564,21 +564,21 @@ static void ggml_gemv_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, c
 
 #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
     if (ggml_cpu_has_neon() && ggml_cpu_has_dotprod()) {
-        const block_q4_0x4 * b_ptr = (const block_q4_0x4 *)vx;
+        const block_q4_0x4 * b_ptr = (const block_q4_0x4 *) vx;
 
         for (int c = 0; c < nc; c += ncols_interleaved) {
-            const block_q8_0 * a_ptr = (const block_q8_0 *)vy;
+            const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
             float32x4_t acc = vdupq_n_f32(0);
             for (int b = 0; b < nb; b++) {
-                int8x16_t b0 = vld1q_s8((const int8_t *)b_ptr->qs);
-                int8x16_t b1 = vld1q_s8((const int8_t *)b_ptr->qs + 16);
-                int8x16_t b2 = vld1q_s8((const int8_t *)b_ptr->qs + 32);
-                int8x16_t b3 = vld1q_s8((const int8_t *)b_ptr->qs + 48);
-                float16x4_t bd = vld1_f16((const __fp16 *)b_ptr->d);
+                int8x16_t b0 = vld1q_s8((const int8_t *) b_ptr->qs);
+                int8x16_t b1 = vld1q_s8((const int8_t *) b_ptr->qs + 16);
+                int8x16_t b2 = vld1q_s8((const int8_t *) b_ptr->qs + 32);
+                int8x16_t b3 = vld1q_s8((const int8_t *) b_ptr->qs + 48);
+                float16x4_t bd = vld1_f16((const __fp16 *) b_ptr->d);
 
                 int8x16_t a0 = vld1q_s8(a_ptr->qs);
                 int8x16_t a1 = vld1q_s8(a_ptr->qs + qk/2);
-                float16x4_t ad = vld1_dup_f16((const __fp16 *)&a_ptr->d);
+                float16x4_t ad = vld1_dup_f16((const __fp16 *) &a_ptr->d);
 
                 int32x4_t ret = vdupq_n_s32(0);
 
@@ -647,72 +647,52 @@ static void ggml_gemv_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, c
     UNUSED(ncols_interleaved);
     UNUSED(blocklen);
 
-#if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
-    if (ggml_cpu_has_neon() && ggml_cpu_has_matmul_int8()) {
-        const void * b_ptr = vx;
-        const void * a_ptr = vy;
-        float * res_ptr = s;
-
-        __asm__ __volatile__(
-            "movi v2.16b, #0x4\n"
-            "movi v1.16b, #0xf0\n"
-            "add %x[b_ptr], %x[b_ptr], #0x8\n"
-            "1:"  // Column loop
-            "add x23, %x[a_ptr], #0x2\n"
-            "movi v0.16b, #0x0\n"
-            "mov x22, %x[nb]\n"
-            "2:"  // Block loop
-            "ldr q31, [%x[b_ptr], #0x0]\n"
-            "ldr q30, [%x[b_ptr], #0x10]\n"
-            "mov x21, x23\n"
-            "movi v29.4s, #0x0\n"
-            "ldr q28, [%x[b_ptr], #0x20]\n"
-            "ldr q27, [%x[b_ptr], #0x30]\n"
-            "movi v26.4s, #0x0\n"
-            "sub x20, x23, #0x2\n"
-            "ld1r { v25.8h }, [x20]\n"
-            "ldr q24, [%x[b_ptr], #-0x8]\n"
-            "sub x22, x22, #0x1\n"
-            "add x23, x23, #0x22\n"
-            "ld1r { v23.2d }, [x21], #0x8\n"
-            "sshl v22.16b, v31.16b, v2.16b\n"
-            "sshl v16.16b, v30.16b, v2.16b\n"
-            "add %x[b_ptr], %x[b_ptr], #0x48\n"
-            "ld1r { v21.2d }, [x21], #0x8\n"
-            "sshl v20.16b, v28.16b, v2.16b\n"
-            "sshl v19.16b, v27.16b, v2.16b\n"
-            "ld1r { v18.2d }, [x21], #0x8\n"
-            "ld1r { v17.2d }, [x21], #0x8\n"
-            "and v31.16b, v31.16b, v1.16b\n"
-            "and v30.16b, v30.16b, v1.16b\n"
-            ".inst 0x4e9796dd  // sdot v29.4s, v22.16b, v23.16b\n"
-            ".inst 0x4e97961a  // sdot v26.4s, v16.16b, v23.16b\n"
-            "and v28.16b, v28.16b, v1.16b\n"
-            "and v27.16b, v27.16b, v1.16b\n"
-            "fcvtl v25.4s, v25.4h\n"
-            "fcvtl v16.4s, v24.4h\n"
-            ".inst 0x4e95969d  // sdot v29.4s, v20.16b, v21.16b\n"
-            ".inst 0x4e95967a  // sdot v26.4s, v19.16b, v21.16b\n"
-            "fmul v16.4s, v16.4s, v25.4s\n"
-            ".inst 0x4e9297fd  // sdot v29.4s, v31.16b, v18.16b\n"
-            ".inst 0x4e9297da  // sdot v26.4s, v30.16b, v18.16b\n"
-            ".inst 0x4e91979d  // sdot v29.4s, v28.16b, v17.16b\n"
-            ".inst 0x4e91977a  // sdot v26.4s, v27.16b, v17.16b\n"
-            "addp v29.4s, v29.4s, v26.4s\n"
-            "scvtf v29.4s, v29.4s, #0x4\n"
-            "fmla v0.4s, v29.4s, v16.4s\n"
-            "cbnz x22, 2b\n"
-            "sub %x[nc], %x[nc], #0x4\n"
-            "str q0, [%x[res_ptr], #0x0]\n"
-            "add %x[res_ptr], %x[res_ptr], #0x10\n"
-            "cbnz %x[nc], 1b\n"
-            : [b_ptr] "+&r" (b_ptr), [res_ptr] "+&r" (res_ptr), [nc] "+&r" (nc)
-            : [a_ptr] "r" (a_ptr), [nb] "r" (nb)
-            : "memory", "v0", "v1", "v2", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x20", "x21", "x22", "x23"
-        );
+#if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
+    if (ggml_cpu_has_neon() && ggml_cpu_has_dotprod()) {
+        const block_q4_0x4 * b_ptr = (const block_q4_0x4 *) vx;
+
+        for (int c = 0; c < nc; c += ncols_interleaved) {
+            const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
+            float32x4_t acc = vdupq_n_f32(0);
+            for (int b = 0; b < nb; b++) {
+                int8x16_t b0 = vld1q_s8((const int8_t *) b_ptr->qs);
+                int8x16_t b1 = vld1q_s8((const int8_t *) b_ptr->qs + 16);
+                int8x16_t b2 = vld1q_s8((const int8_t *) b_ptr->qs + 32);
+                int8x16_t b3 = vld1q_s8((const int8_t *) b_ptr->qs + 48);
+                float16x4_t bd = vld1_f16((const __fp16 *) b_ptr->d);
+
+                int8x16_t a0 = (int8x16_t) vld1q_dup_s64((const int64_t *) a_ptr->qs);
+                int8x16_t a1 = (int8x16_t) vld1q_dup_s64((const int64_t *) a_ptr->qs + 1);
+                int8x16_t a2 = (int8x16_t) vld1q_dup_s64((const int64_t *) a_ptr->qs + 2);
+                int8x16_t a3 = (int8x16_t) vld1q_dup_s64((const int64_t *) a_ptr->qs + 3);
+                float16x4_t ad = vld1_dup_f16((const __fp16 *) &a_ptr->d);
+
+                int32x4_t ret0 = vdupq_n_s32(0);
+                int32x4_t ret1 = vdupq_n_s32(0);
+
+                ret0 = vdotq_s32(ret0, b0 << 4, a0);
+                ret1 = vdotq_s32(ret1, b1 << 4, a0);
+                ret0 = vdotq_s32(ret0, b2 << 4, a1);
+                ret1 = vdotq_s32(ret1, b3 << 4, a1);
+
+                ret0 = vdotq_s32(ret0, b0 & 0xf0U, a2);
+                ret1 = vdotq_s32(ret1, b1 & 0xf0U, a2);
+                ret0 = vdotq_s32(ret0, b2 & 0xf0U, a3);
+                ret1 = vdotq_s32(ret1, b3 & 0xf0U, a3);
+
+                int32x4_t ret = vpaddq_s32(ret0, ret1);
+
+                acc = vfmaq_f32(acc, vcvtq_n_f32_s32(ret, 4),
+                        vmulq_f32(vcvt_f32_f16(ad), vcvt_f32_f16(bd)));
+                a_ptr++;
+                b_ptr++;
+            }
+            vst1q_f32(s, acc);
+            s += ncols_interleaved;
+        }
         return;
     }
-#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
+#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
     float sumf[4];
     int sumi;
 
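For reference, `vdotq_s32` adds, into each 32-bit lane of the accumulator, the dot product of the corresponding 4-byte group of its two `int8x16_t` operands; `vpaddq_s32(ret0, ret1)` then pairs the partial sums the same way the removed `addp v29.4s, v29.4s, v26.4s` did. A standalone sanity check of the `vdotq_s32` pattern, with made-up inputs (assumes a dotprod-capable toolchain, e.g. built with -march=armv8.2-a+dotprod):

#include <arm_neon.h>
#include <stdio.h>

int main(void) {
    int8_t a[16], b[16];
    for (int i = 0; i < 16; i++) {
        a[i] = (int8_t) (i - 8);  // -8..7
        b[i] = 2;
    }

    // Each output lane i receives sum(a[4i..4i+3] * b[4i..4i+3]).
    int32x4_t acc = vdupq_n_s32(0);
    acc = vdotq_s32(acc, vld1q_s8(a), vld1q_s8(b));

    int32_t out[4];
    vst1q_s32(out, acc);
    printf("%d %d %d %d\n", out[0], out[1], out[2], out[3]);  // -52 -20 12 44
    return 0;
}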