mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-10-30 08:42:00 +00:00 
			
		
		
		
	ggml : optimize ggml_vec_dot_q4_0_q8_0() using vectorized accumulators
This commit is contained in:
		
							
								
								
									
										39
									
								
								ggml.c
									
									
									
									
									
								
							
							
						
						
									
										39
									
								
								ggml.c
									
									
									
									
									
								
							| @@ -2766,8 +2766,8 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * | |||||||
|     float sumf = 0.0; |     float sumf = 0.0; | ||||||
|  |  | ||||||
| #if defined(__ARM_NEON) | #if defined(__ARM_NEON) | ||||||
|     float sum0 = 0.0f; |     float32x4_t sumv0 = vdupq_n_f32(0.0f); | ||||||
|     float sum1 = 0.0f; |     float32x4_t sumv1 = vdupq_n_f32(0.0f); | ||||||
|  |  | ||||||
|     for (int i = 0; i < nb; i += 2) { |     for (int i = 0; i < nb; i += 2) { | ||||||
|         const block_q4_0 * restrict x0 = &x[i + 0]; |         const block_q4_0 * restrict x0 = &x[i + 0]; | ||||||
| @@ -2807,14 +2807,17 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * | |||||||
|  |  | ||||||
| #if defined(__ARM_FEATURE_DOTPROD) | #if defined(__ARM_FEATURE_DOTPROD) | ||||||
|         // dot product into int32x4_t |         // dot product into int32x4_t | ||||||
|         int32x4_t p_0 = vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0ls); |         const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0ls), v0_0hs, v1_0hs); | ||||||
|         int32x4_t p_1 = vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1ls); |         const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1ls), v0_1hs, v1_1hs); | ||||||
|  |  | ||||||
|         p_0 = vdotq_s32(p_0, v0_0hs, v1_0hs); | #if 0 | ||||||
|         p_1 = vdotq_s32(p_1, v0_1hs, v1_1hs); |         // note: this is faster for 4-6 threads but slower for more threads | ||||||
|  |         sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), x0->d*y0->d); | ||||||
|         sum0 += x0->d*y0->d*vaddvq_s32(p_0); |         sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), x1->d*y1->d); | ||||||
|         sum1 += x1->d*y1->d*vaddvq_s32(p_1); | #else | ||||||
|  |         sumv0 = vaddq_f32(sumv0, vmulq_f32(vcvtq_f32_s32(p_0), vdupq_n_f32(x0->d*y0->d))); | ||||||
|  |         sumv1 = vaddq_f32(sumv1, vmulq_f32(vcvtq_f32_s32(p_1), vdupq_n_f32(x1->d*y1->d))); | ||||||
|  | #endif | ||||||
| #else | #else | ||||||
|         const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0ls), vget_low_s8 (v1_0ls)); |         const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0ls), vget_low_s8 (v1_0ls)); | ||||||
|         const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0ls), vget_high_s8(v1_0ls)); |         const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0ls), vget_high_s8(v1_0ls)); | ||||||
| @@ -2826,21 +2829,17 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * | |||||||
|         const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hs), vget_low_s8 (v1_1hs)); |         const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hs), vget_low_s8 (v1_1hs)); | ||||||
|         const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hs), vget_high_s8(v1_1hs)); |         const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hs), vget_high_s8(v1_1hs)); | ||||||
|  |  | ||||||
|         const int16x8_t pl_0 = vaddq_s16(pl0l, pl0h); |         const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h)); | ||||||
|         const int16x8_t ph_0 = vaddq_s16(ph0l, ph0h); |         const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h)); | ||||||
|  |         const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h)); | ||||||
|  |         const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h)); | ||||||
|  |  | ||||||
|         const int16x8_t pl_1 = vaddq_s16(pl1l, pl1h); |         sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), x0->d*y0->d); | ||||||
|         const int16x8_t ph_1 = vaddq_s16(ph1l, ph1h); |         sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), x1->d*y1->d); | ||||||
|  |  | ||||||
|         const int16x8_t p_0 = vaddq_s16(pl_0, ph_0); |  | ||||||
|         const int16x8_t p_1 = vaddq_s16(pl_1, ph_1); |  | ||||||
|  |  | ||||||
|         sum0 += x0->d*y0->d*vaddvq_s16(p_0); |  | ||||||
|         sum1 += x1->d*y1->d*vaddvq_s16(p_1); |  | ||||||
| #endif | #endif | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     sumf = sum0 + sum1; |     sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1); | ||||||
| #elif defined(__AVX2__) | #elif defined(__AVX2__) | ||||||
|     // Initialize accumulator with zeros |     // Initialize accumulator with zeros | ||||||
|     __m256 acc = _mm256_setzero_ps(); |     __m256 acc = _mm256_setzero_ps(); | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user
	 Georgi Gerganov
					Georgi Gerganov