Mirror of https://github.com/ggml-org/llama.cpp.git (synced 2025-10-30 08:42:00 +00:00)
ggml : use 8-bit precision for Q4_1 intermediate results (#1047)

* ggml : use 8-bit precision for Q4_1 intermediate results (ARM)

* ggml : optimize ggml_vec_dot_q4_1_q8_0() via vmlaq_n_f32

  56 ms/token with Q4_1 !

* ggml : AVX2 implementation of ggml_vec_dot_q4_1_q8_0 (#1051)

* gitignore : ignore ppl-*.txt files

---------

Co-authored-by: slaren <2141330+slaren@users.noreply.github.com>
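For readers skimming the diff below: the patch drops the old Q4_1 x Q4_1 dot product and instead quantizes the activations to 8 bits (Q8_0), so each block contribution factors into two integer sums, roughly sum = d_x*d_y*SUM(qx*qy) + m_x*d_y*SUM(qy). The sketch that follows illustrates that per-block math in plain C; QK, block_q4_1 and block_q8_0 here are simplified stand-ins modeled on the structs in ggml.c, not the actual implementation.

// Minimal, self-contained sketch (assumed layouts; not the ggml code itself):
// a Q4_1 block stores a scale d, an offset m and 32 unsigned 4-bit weights;
// a Q8_0 block stores a scale d and 32 signed 8-bit values.
#include <stdint.h>

#define QK 32

typedef struct { float d, m; uint8_t qs[QK/2]; } block_q4_1; // two nibbles per byte
typedef struct { float d;    int8_t  qs[QK];   } block_q8_0;

// One block of the Q4_1 x Q8_0 dot product:
//   sum = d_x*d_y * SUM(qx*qy) + m_x*d_y * SUM(qy)
// so only two integer accumulators are needed per block.
static float dot_q4_1_q8_0_block(const block_q4_1 *x, const block_q8_0 *y) {
    int sumi = 0; // SUM(qx*qy)
    int sumy = 0; // SUM(qy) -- needed because Q4_1 weights carry the offset m
    for (int j = 0; j < QK/2; ++j) {
        const int x0 = x->qs[j] & 0x0F; // low nibble  -> element 2*j
        const int x1 = x->qs[j] >> 4;   // high nibble -> element 2*j + 1
        const int y0 = y->qs[2*j + 0];
        const int y1 = y->qs[2*j + 1];
        sumi += x0*y0 + x1*y1;
        sumy += y0 + y1;
    }
    return x->d*y->d*(float)sumi + x->m*y->d*(float)sumy;
}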
.gitignore (vendored): 15 changes
@@ -1,11 +1,15 @@
 *.o
 *.a
+.DS_Store
+.build/
 .cache/
+.direnv/
+.envrc
+.swiftpm
+.venv
 .vs/
 .vscode/
-.DS_Store
 
-.build/
 build/
 build-em/
 build-debug/
@@ -30,12 +34,9 @@ models/*
 arm_neon.h
 compile_commands.json
 
-.envrc
-.direnv/
-
-.venv
 __pycache__
-.swiftpm
 
 zig-out/
 zig-cache/
+
+ppl-*.txt
							
								
								
									
ggml.c: 371 changes
@@ -550,6 +550,18 @@ inline static uint16_t vaddvq_u8(uint8x16_t v) {
         (uint16_t)vgetq_lane_u8(v, 14) + (uint16_t)vgetq_lane_u8(v, 15);
 }
 
+inline static int16_t vaddvq_s8(int8x16_t v) {
+    return
+        (int16_t)vgetq_lane_s8(v, 0)  + (int16_t)vgetq_lane_s8(v, 1)  +
+        (int16_t)vgetq_lane_s8(v, 2)  + (int16_t)vgetq_lane_s8(v, 3)  +
+        (int16_t)vgetq_lane_s8(v, 4)  + (int16_t)vgetq_lane_s8(v, 5)  +
+        (int16_t)vgetq_lane_s8(v, 6)  + (int16_t)vgetq_lane_s8(v, 7)  +
+        (int16_t)vgetq_lane_s8(v, 8)  + (int16_t)vgetq_lane_s8(v, 9)  +
+        (int16_t)vgetq_lane_s8(v, 10) + (int16_t)vgetq_lane_s8(v, 11) +
+        (int16_t)vgetq_lane_s8(v, 12) + (int16_t)vgetq_lane_s8(v, 13) +
+        (int16_t)vgetq_lane_s8(v, 14) + (int16_t)vgetq_lane_s8(v, 15);
+}
+
 inline static int32_t vaddvq_s16(int16x8_t v) {
     return
         (int32_t)vgetq_lane_s16(v, 0) + (int32_t)vgetq_lane_s16(v, 1) +
@@ -1535,9 +1547,8 @@ static void dequantize_row_q4_2(const void * restrict vx, float * restrict y, in
     }
 }
 
-static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
 static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
-//static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
+static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
 static void ggml_vec_dot_q4_2_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
 
 static const quantize_fns_t quantize_fns[GGML_TYPE_COUNT] = {
@@ -1552,8 +1563,8 @@ static const quantize_fns_t quantize_fns[GGML_TYPE_COUNT] = {
         .dequantize_row_q         = dequantize_row_q4_1,
         .quantize_row_q           = quantize_row_q4_1,
         .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_1_reference,
-        .quantize_row_q_dot       = quantize_row_q4_1,
-        .vec_dot_q                = ggml_vec_dot_q4_1,
+        .quantize_row_q_dot       = quantize_row_q8_0,
+        .vec_dot_q                = ggml_vec_dot_q4_1_q8_0,
     },
     [GGML_TYPE_Q4_2] = {
         .dequantize_row_q         = dequantize_row_q4_2,
@@ -2170,189 +2181,6 @@ inline static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t
     *s = sumf;
 }
 
-static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
-    const int nb = n / QK4_1;
-
-    const block_q4_1 * restrict x = vx;
-    const block_q4_1 * restrict y = vy;
-
-    float sumf = 0.0;
-
-#if defined(__AVX2__)
-    // Initialize accumulator with zeros
-    __m256 acc = _mm256_setzero_ps();
-    // Accumulator for constant offsets
-    float acc_offset = 0.0f;
-
-    // Main loop
-    for (int i = 0; i < nb; ++i) {
-        const float * d0 = &x[i].d;
-        const float * d1 = &y[i].d;
-
-        const float * m0 = &x[i].m;
-        const float * m1 = &y[i].m;
-
-        const __m256 d0v = _mm256_broadcast_ss( d0 );
-        const __m256 d1v = _mm256_broadcast_ss( d1 );
-        const __m256 m0v = _mm256_broadcast_ss( m0 );
-        const __m256 m1v = _mm256_broadcast_ss( m1 );
-
-        // Compute combined scale for the block
-        const __m256 scale_01 = _mm256_mul_ps( d0v, d1v );
-
-        // Compute cross scales for the block
-        const __m256 scale_0 = _mm256_mul_ps( d0v, m1v );
-        const __m256 scale_1 = _mm256_mul_ps( m0v, d1v );
-        const __m256 cross_scales = _mm256_blend_ps( scale_0, scale_1, 0xAA /* 0b10101010 */ );
-
-        // Load 16 bytes, and unpack 4 bit fields into bytes, making 32 bytes
-        __m256i bx = bytesFromNibbles( x[i].qs );
-        __m256i by = bytesFromNibbles( y[i].qs );
-
-        // Now we have a vector with bytes in [ 0 .. 15 ] interval.
-
-        // Sign-extend first 16 signed bytes into int16_t
-        __m256i x16 = _mm256_cvtepi8_epi16( _mm256_castsi256_si128( bx ) );
-        __m256i y16 = _mm256_cvtepi8_epi16( _mm256_castsi256_si128( by ) );
-        // Compute products of int16_t integers, add pairwise
-        __m256i i32 = _mm256_madd_epi16( x16, y16 );
-
-        // Sign-extend last 16 signed bytes into int16_t vectors
-        __m256i x16_h = _mm256_cvtepi8_epi16( _mm256_extracti128_si256( bx, 1 ) );
-        __m256i y16_h = _mm256_cvtepi8_epi16( _mm256_extracti128_si256( by, 1 ) );
-        // Accumulate products of int16_t integers
-        i32 = _mm256_add_epi32( i32, _mm256_madd_epi16( x16_h, y16_h ) );
-
-        // compute sums of unsigned bytes in bx, by in blocks of 8.
-        // This results in a layout like X100 0000 X200 0000 X300 0000 X400 0000,
-        // which we then interleave as X100 Y100 X200 Y200 X300 Y300 X400 Y400.
-        // so if we then cast to 8 singles, we get 8 floats like [ x0_7, y0_7, x8_15, y8_15, x16_23, y16_23, x24_31, y24_31 ]
-        __m256i xsumi = _mm256_sad_epu8( bx, _mm256_setzero_si256() );
-        __m256i ysumi = _mm256_sad_epu8( by, _mm256_setzero_si256() );
-        __m256i sumsi = _mm256_or_si256( xsumi, _mm256_slli_si256( ysumi, 4 ) );
-        __m256  sums  = _mm256_cvtepi32_ps( sumsi );
-
-        // Convert int32_t to float
-        __m256 p = _mm256_cvtepi32_ps( i32 );
-        // Apply the scale, and accumulate
-        // acc += d0*d1*x*y + d0*m1*x + d1*m0*y
-        acc = _mm256_fmadd_ps( scale_01, p, acc );
-        acc = _mm256_fmadd_ps( cross_scales, sums, acc );
-        // acc_offset += m0*m1 (for each entry in the block)
-        acc_offset += (*m0)*(*m1);
-    }
-
-    // Return horizontal sum of the acc vector
-    __m128 res = _mm256_extractf128_ps( acc, 1 );
-    res = _mm_add_ps( res, _mm256_castps256_ps128( acc ) );
-    res = _mm_add_ps( res, _mm_movehl_ps( res, res ) );
-    res = _mm_add_ss( res, _mm_movehdup_ps( res ) );
-
-    sumf = _mm_cvtss_f32( res ) + acc_offset * QK4_1;
-#elif defined(__ARM_NEON)
-    float sum00 = 0.0f;
-    float sum01 = 0.0f;
-    float sum10 = 0.0f;
-    float sum11 = 0.0f;
-
-    for (int i = 0; i < nb; i += 2) {
-        const block_q4_1 * restrict x0 = &x[i + 0];
-        const block_q4_1 * restrict y0 = &y[i + 0];
-        const block_q4_1 * restrict x1 = &x[i + 1];
-        const block_q4_1 * restrict y1 = &y[i + 1];
-
-        const uint8x16_t m4b = vdupq_n_u8(0xf);
-
-        const uint8x16_t v0_0 = vld1q_u8(x0->qs);
-        const uint8x16_t v1_0 = vld1q_u8(y0->qs);
-        const uint8x16_t v0_1 = vld1q_u8(x1->qs);
-        const uint8x16_t v1_1 = vld1q_u8(y1->qs);
-
-        // 4-bit -> 8-bit
-        const uint8x16_t v0_0l = vandq_u8(v0_0, m4b);
-        const uint8x16_t v1_0l = vandq_u8(v1_0, m4b);
-        const uint8x16_t v0_0h = vshrq_n_u8(v0_0, 4);
-        const uint8x16_t v1_0h = vshrq_n_u8(v1_0, 4);
-
-        const uint8x16_t v0_1l = vandq_u8(v0_1, m4b);
-        const uint8x16_t v1_1l = vandq_u8(v1_1, m4b);
-        const uint8x16_t v0_1h = vshrq_n_u8(v0_1, 4);
-        const uint8x16_t v1_1h = vshrq_n_u8(v1_1, 4);
-
-        sum00 += x0->m*y0->m;
-        sum01 += y0->m*x0->d*((uint16_t)vaddvq_u8(v0_0l) + (uint16_t)vaddvq_u8(v0_0h));
-        sum10 += x0->m*y0->d*((uint16_t)vaddvq_u8(v1_0l) + (uint16_t)vaddvq_u8(v1_0h));
-
-        sum00 += x1->m*y1->m;
-        sum01 += y1->m*x1->d*((uint16_t)vaddvq_u8(v0_1l) + (uint16_t)vaddvq_u8(v0_1h));
-        sum10 += x1->m*y1->d*((uint16_t)vaddvq_u8(v1_1l) + (uint16_t)vaddvq_u8(v1_1h));
-
-#if defined(__ARM_FEATURE_DOTPROD)
-        // dot product into int32x4_t
-        uint32x4_t p_0 = vdotq_u32(vdupq_n_u32(0), v0_0l, v1_0l);
-        uint32x4_t p_1 = vdotq_u32(vdupq_n_u32(0), v0_1l, v1_1l);
-
-        p_0 = vdotq_u32(p_0, v0_0h, v1_0h);
-        p_1 = vdotq_u32(p_1, v0_1h, v1_1h);
-
-        sum11 += x0->d*y0->d*vaddvq_u32(p_0);
-        sum11 += x1->d*y1->d*vaddvq_u32(p_1);
-#else
-        const uint16x8_t pl0l = vmull_u8(vget_low_u8 (v0_0l), vget_low_u8 (v1_0l));
-        const uint16x8_t pl0h = vmull_u8(vget_high_u8(v0_0l), vget_high_u8(v1_0l));
-        const uint16x8_t ph0l = vmull_u8(vget_low_u8 (v0_0h), vget_low_u8 (v1_0h));
-        const uint16x8_t ph0h = vmull_u8(vget_high_u8(v0_0h), vget_high_u8(v1_0h));
-
-        const uint16x8_t pl1l = vmull_u8(vget_low_u8 (v0_1l), vget_low_u8 (v1_1l));
-        const uint16x8_t pl1h = vmull_u8(vget_high_u8(v0_1l), vget_high_u8(v1_1l));
-        const uint16x8_t ph1l = vmull_u8(vget_low_u8 (v0_1h), vget_low_u8 (v1_1h));
-        const uint16x8_t ph1h = vmull_u8(vget_high_u8(v0_1h), vget_high_u8(v1_1h));
-
-        const uint16x8_t pl_0 = vaddq_u16(pl0l, pl0h);
-        const uint16x8_t ph_0 = vaddq_u16(ph0l, ph0h);
-
-        const uint16x8_t pl_1 = vaddq_u16(pl1l, pl1h);
-        const uint16x8_t ph_1 = vaddq_u16(ph1l, ph1h);
-
-        const uint16x8_t p_0 = vaddq_u16(pl_0, ph_0);
-        const uint16x8_t p_1 = vaddq_u16(pl_1, ph_1);
-
-        sum11 += x0->d*y0->d*vaddvq_u16(p_0);
-        sum11 += x1->d*y1->d*vaddvq_u16(p_1);
-#endif
-    }
-
-    sumf = QK4_1*sum00 + sum01 + sum10 + sum11;
-#else
-    // scalar
-    for (int i = 0; i < nb; i++) {
-        const float d0 = x[i].d;
-        const float d1 = y[i].d;
-
-        const float m0 = x[i].m;
-        const float m1 = y[i].m;
-
-        const uint8_t * restrict p0 = x[i].qs;
-        const uint8_t * restrict p1 = y[i].qs;
-
-        for (int j = 0; j < QK4_1/2; j++) {
-            const uint8_t v0 = p0[j];
-            const uint8_t v1 = p1[j];
-
-            const float f0 = d0*(v0 & 0xf) + m0;
-            const float f1 = d0*(v0 >> 4)  + m0;
-
-            const float f2 = d1*(v1 & 0xf) + m1;
-            const float f3 = d1*(v1 >> 4)  + m1;
-
-            sumf += f0*f2 + f1*f3;
-        }
-    }
-#endif
-
-    *s = sumf;
-}
-
 static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
     const int nb = n / QK8_0;
 
@@ -2549,6 +2377,175 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
     *s = sumf;
 }
 
+static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
+    const int nb = n / QK8_0;
+
+    assert(n % QK8_0 == 0);
+    assert(nb % 2 == 0);
+
+    const block_q4_1 * restrict x = vx;
+    const block_q8_0 * restrict y = vy;
+
+    float sumf = 0.0;
+
+    // TODO: add AVX / WASM SIMD / etc
+#if defined(__ARM_NEON)
+    float32x4_t sumv0 = vdupq_n_f32(0.0f);
+    float32x4_t sumv1 = vdupq_n_f32(0.0f);
+
+    for (int i = 0; i < nb; i += 2) {
+        const block_q4_1 * restrict x0 = &x[i + 0];
+        const block_q4_1 * restrict x1 = &x[i + 1];
+        const block_q8_0 * restrict y0 = &y[i + 0];
+        const block_q8_0 * restrict y1 = &y[i + 1];
+
+        const uint8x16_t m4b = vdupq_n_u8(0xf);
+
+        const uint8x16_t v0_0 = vld1q_u8(x0->qs);
+        const uint8x16_t v0_1 = vld1q_u8(x1->qs);
+
+        // 4-bit -> 8-bit
+        const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8  (v0_0, m4b));
+        const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
+        const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8  (v0_1, m4b));
+        const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
+
+        // load y
+        const int8x16_t v1_0l = vld1q_s8(y0->qs);
+        const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
+        const int8x16_t v1_1l = vld1q_s8(y1->qs);
+        const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
+
+        // interleave
+        const int8x16_t v1_0ls = vuzp1q_s8(v1_0l, v1_0h);
+        const int8x16_t v1_0hs = vuzp2q_s8(v1_0l, v1_0h);
+        const int8x16_t v1_1ls = vuzp1q_s8(v1_1l, v1_1h);
+        const int8x16_t v1_1hs = vuzp2q_s8(v1_1l, v1_1h);
+
+        const int16x8_t s0i = vaddq_s16(
+                        vaddq_s16(vmovl_s8(vget_low_s8(v1_0ls)), vmovl_s8(vget_high_s8(v1_0ls))),
+                        vaddq_s16(vmovl_s8(vget_low_s8(v1_0hs)), vmovl_s8(vget_high_s8(v1_0hs))));
+
+        const int16x8_t s1i = vaddq_s16(
+                        vaddq_s16(vmovl_s8(vget_low_s8(v1_1ls)), vmovl_s8(vget_high_s8(v1_1ls))),
+                        vaddq_s16(vmovl_s8(vget_low_s8(v1_1hs)), vmovl_s8(vget_high_s8(v1_1hs))));
+
+        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddl_s16(vget_low_s16(s0i), vget_high_s16(s0i))), x0->m*y0->d);
+        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddl_s16(vget_low_s16(s1i), vget_high_s16(s1i))), x1->m*y1->d);
+
+#if defined(__ARM_FEATURE_DOTPROD)
+        // dot product into int32x4_t
+        const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0l, v1_0ls), v0_0h, v1_0hs);
+        const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1l, v1_1ls), v0_1h, v1_1hs);
+
+        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), x0->d*y0->d);
+        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), x1->d*y1->d);
+#else
+        const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0l), vget_low_s8 (v1_0ls));
+        const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0l), vget_high_s8(v1_0ls));
+        const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0h), vget_low_s8 (v1_0hs));
+        const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0h), vget_high_s8(v1_0hs));
+
+        const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1l), vget_low_s8 (v1_1ls));
+        const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1l), vget_high_s8(v1_1ls));
+        const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1h), vget_low_s8 (v1_1hs));
+        const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1h), vget_high_s8(v1_1hs));
+
+        const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
+        const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
+        const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
+        const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));
+
+        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), x0->d*y0->d);
+        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), x1->d*y1->d);
+#endif
+    }
+
+    sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
+#elif defined(__AVX2__)
+    // Initialize accumulator with zeros
+    __m256 acc = _mm256_setzero_ps();
+
+    // Main loop
+    for (int i = 0; i < nb; ++i) {
+        const float * d0 = &x[i].d;
+        const float * d1 = &y[i].d;
+        const float * m0 = &x[i].m;
+
+        const __m256 d0v = _mm256_broadcast_ss( d0 );
+        const __m256 d1v = _mm256_broadcast_ss( d1 );
+        const __m256 m0v = _mm256_broadcast_ss( m0 );
+
+        // Compute combined scales
+        const __m256 d0d1 = _mm256_mul_ps( d0v, d1v );
+        const __m256 d1m0 = _mm256_mul_ps( d1v, m0v );
+
+        // Load 16 bytes, and unpack 4 bit fields into bytes, making 32 bytes
+        const __m256i bx = bytesFromNibbles( x[i].qs );
+        const __m256i by = _mm256_loadu_si256( (const __m256i *)y[i].qs );
+
+        // Get absolute values of x vectors
+        const __m256i ax = _mm256_sign_epi8( bx, bx );
+
+        // Sign the values of the y vectors
+        const __m256i sy = _mm256_sign_epi8( by, bx );
+
+        // Perform multiplication and create 16-bit values
+        const __m256i dot = _mm256_maddubs_epi16( ax, sy );
+        const __m256i ones = _mm256_set1_epi16( 1 );
+        const __m256i xy_q = _mm256_madd_epi16( ones, dot );
+
+        // Convert to vector of 8 int32_t to 8 floats
+        const __m256 xy = _mm256_cvtepi32_ps( xy_q );
+
+        // Accumulate d0*d1*x*y
+        acc = _mm256_fmadd_ps( d0d1, xy, acc );
+
+        // Compute sum of y values
+        const __m256i y16_l = _mm256_cvtepi8_epi16( _mm256_castsi256_si128( by ) );
+        const __m256i y16_h = _mm256_cvtepi8_epi16( _mm256_extracti128_si256( by, 1 ) );
+        const __m256i ysumi = _mm256_madd_epi16( _mm256_add_epi16(y16_l, y16_h), ones );
+        const __m256 ysum = _mm256_cvtepi32_ps( ysumi );
+
+        // Accumulate d1*m0*y
+        acc = _mm256_fmadd_ps( d1m0, ysum, acc );
+    }
+
+    // Return horizontal sum of the acc vector
+    __m128 res = _mm256_extractf128_ps( acc, 1 );
+    res = _mm_add_ps( res, _mm256_castps256_ps128( acc ) );
+    res = _mm_add_ps( res, _mm_movehl_ps( res, res ) );
+    res = _mm_add_ss( res, _mm_movehdup_ps( res ) );
+
+    sumf = _mm_cvtss_f32( res );
+#else
+    // scalar
+    for (int i = 0; i < nb; i++) {
+        const float d0 = x[i].d;
+        const float m0 = x[i].m;
+        const float d1 = y[i].d;
+
+        const uint8_t * restrict p0 = x[i].qs;
+        const  int8_t * restrict p1 = y[i].qs;
+
+        // TODO: this is very slow ..
+        for (int j = 0; j < QK8_0/2; j++) {
+            const uint8_t v0 = p0[j];
+
+            const float f0 = d0*(v0 & 0xf) + m0;
+            const float f1 = d0*(v0 >> 4)  + m0;
+
+            const float f2 = d1*p1[2*j + 0];
+            const float f3 = d1*p1[2*j + 1];
+
+            sumf += f0*f2 + f1*f3;
+        }
+    }
+#endif
+
+    *s = sumf;
+}
+
 static void ggml_vec_dot_q4_2_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
     const int nb = n / QK8_0;
 
Author: Georgi Gerganov