mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-10-30 08:42:00 +00:00 
			
		
		
		
	~7% faster Q5_1 AVX2 code (#1477)
This commit is contained in:
		
							
								
								
									
										39
									
								
								ggml.c
									
									
									
									
									
								
							
							
						
						
									
										39
									
								
								ggml.c
									
									
									
									
									
								
							| @@ -543,12 +543,7 @@ static inline __m256 sum_i16_pairs_float(const __m256i x) { | |||||||
|     return _mm256_cvtepi32_ps(summed_pairs); |     return _mm256_cvtepi32_ps(summed_pairs); | ||||||
| } | } | ||||||
|  |  | ||||||
| // multiply int8_t, add results pairwise twice and return as float vector | static inline __m256 mul_sum_us8_pairs_float(const __m256i ax, const __m256i sy) { | ||||||
| static inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) { |  | ||||||
|     // Get absolute values of x vectors |  | ||||||
|     const __m256i ax = _mm256_sign_epi8(x, x); |  | ||||||
|     // Sign the values of the y vectors |  | ||||||
|     const __m256i sy = _mm256_sign_epi8(y, x); |  | ||||||
| #if __AVXVNNI__ | #if __AVXVNNI__ | ||||||
|     const __m256i zero = _mm256_setzero_si256(); |     const __m256i zero = _mm256_setzero_si256(); | ||||||
|     const __m256i summed_pairs = _mm256_dpbusd_epi32(zero, ax, sy); |     const __m256i summed_pairs = _mm256_dpbusd_epi32(zero, ax, sy); | ||||||
| @@ -560,6 +555,21 @@ static inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) { | |||||||
| #endif | #endif | ||||||
| } | } | ||||||
|  |  | ||||||
|  | // multiply int8_t, add results pairwise twice and return as float vector | ||||||
|  | static inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) { | ||||||
|  | #if __AVXVNNIINT8__ | ||||||
|  |     const __m256i zero = _mm256_setzero_si256(); | ||||||
|  |     const __m256i summed_pairs = _mm256_dpbssd_epi32(zero, x, y); | ||||||
|  |     return _mm256_cvtepi32_ps(summed_pairs); | ||||||
|  | #else | ||||||
|  |     // Get absolute values of x vectors | ||||||
|  |     const __m256i ax = _mm256_sign_epi8(x, x); | ||||||
|  |     // Sign the values of the y vectors | ||||||
|  |     const __m256i sy = _mm256_sign_epi8(y, x); | ||||||
|  |     return mul_sum_us8_pairs_float(ax, sy); | ||||||
|  | #endif | ||||||
|  | } | ||||||
|  |  | ||||||
| static inline __m128i packNibbles( __m256i bytes ) | static inline __m128i packNibbles( __m256i bytes ) | ||||||
| { | { | ||||||
|     // Move bits within 16-bit lanes from 0000_abcd_0000_efgh into 0000_0000_abcd_efgh |     // Move bits within 16-bit lanes from 0000_abcd_0000_efgh into 0000_0000_abcd_efgh | ||||||
| @@ -619,6 +629,17 @@ static inline __m256 sum_i16_pairs_float(const __m128i xh, const __m128i xl) { | |||||||
|     return _mm256_cvtepi32_ps(summed_pairs); |     return _mm256_cvtepi32_ps(summed_pairs); | ||||||
| } | } | ||||||
|  |  | ||||||
|  | static inline __m256 mul_sum_us8_pairs_float(const __m256i ax, const __m256i sy) { | ||||||
|  |     const __m128i axl = _mm256_castsi256_si128(ax); | ||||||
|  |     const __m128i axh = _mm256_extractf128_si256(ax, 1); | ||||||
|  |     const __m128i syl = _mm256_castsi256_si128(sy); | ||||||
|  |     const __m128i syh = _mm256_extractf128_si256(sy, 1); | ||||||
|  |     // Perform multiplication and create 16-bit values | ||||||
|  |     const __m128i dotl = _mm_maddubs_epi16(axl, syl); | ||||||
|  |     const __m128i doth = _mm_maddubs_epi16(axh, syh); | ||||||
|  |     return sum_i16_pairs_float(doth, dotl); | ||||||
|  | } | ||||||
|  |  | ||||||
| // multiply int8_t, add results pairwise twice and return as float vector | // multiply int8_t, add results pairwise twice and return as float vector | ||||||
| static inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) { | static inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) { | ||||||
|     const __m128i xl = _mm256_castsi256_si128(x); |     const __m128i xl = _mm256_castsi256_si128(x); | ||||||
| @@ -2434,7 +2455,7 @@ static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void * | |||||||
|         const __m256i bx = bytes_from_nibbles_32(x[i].qs); |         const __m256i bx = bytes_from_nibbles_32(x[i].qs); | ||||||
|         const __m256i by = _mm256_loadu_si256( (const __m256i *)y[i].qs ); |         const __m256i by = _mm256_loadu_si256( (const __m256i *)y[i].qs ); | ||||||
|  |  | ||||||
|         const __m256 xy = mul_sum_i8_pairs_float(bx, by); |         const __m256 xy = mul_sum_us8_pairs_float(bx, by); | ||||||
|  |  | ||||||
|         // Accumulate d0*d1*x*y |         // Accumulate d0*d1*x*y | ||||||
| #if defined(__AVX2__) | #if defined(__AVX2__) | ||||||
| @@ -2906,7 +2927,7 @@ static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void * | |||||||
|         const __m256 dy = _mm256_broadcast_ss(&y[i].d); |         const __m256 dy = _mm256_broadcast_ss(&y[i].d); | ||||||
|         const __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs); |         const __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs); | ||||||
|  |  | ||||||
|         const __m256 q = mul_sum_i8_pairs_float(bx, by); |         const __m256 q = mul_sum_us8_pairs_float(bx, by); | ||||||
|  |  | ||||||
|         acc = _mm256_fmadd_ps(q, _mm256_mul_ps(dx, dy), acc); |         acc = _mm256_fmadd_ps(q, _mm256_mul_ps(dx, dy), acc); | ||||||
|     } |     } | ||||||
| @@ -2940,7 +2961,7 @@ static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void * | |||||||
|         const __m256 dy = _mm256_broadcast_ss(&y[i].d); |         const __m256 dy = _mm256_broadcast_ss(&y[i].d); | ||||||
|         const __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs); |         const __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs); | ||||||
|  |  | ||||||
|         const __m256 q = mul_sum_i8_pairs_float(bx, by); |         const __m256 q = mul_sum_us8_pairs_float(bx, by); | ||||||
|  |  | ||||||
|         acc = _mm256_add_ps(_mm256_mul_ps(q, _mm256_mul_ps(dx, dy)), acc); |         acc = _mm256_add_ps(_mm256_mul_ps(q, _mm256_mul_ps(dx, dy)), acc); | ||||||
|     } |     } | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user
	 Ilya Kurdyukov
					Ilya Kurdyukov