mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-10-31 08:51:55 +00:00 
			
		
		
		
	ggml : add AVX2 implementation of quantize_row_q4_1 (#515)
* Add AVX2 implementation of quantize_row_q4_1 * Actually use AVX2 * Make quantize_row_q4_1 static Co-authored-by: Georgi Gerganov <ggerganov@gmail.com> --------- Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
This commit is contained in:
		
							
								
								
									
										91
									
								
								ggml.c
									
									
									
									
									
								
							
							
						
						
									
										91
									
								
								ggml.c
									
									
									
									
									
								
							| @@ -688,7 +688,7 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int | |||||||
| #endif | #endif | ||||||
| } | } | ||||||
|  |  | ||||||
| static void quantize_row_q4_1(const float * restrict x, void * restrict vy, int k) { | static void quantize_row_q4_1_reference(const float * restrict x, void * restrict vy, int k) { | ||||||
|     assert(k % QK == 0); |     assert(k % QK == 0); | ||||||
|     const int nb = k / QK; |     const int nb = k / QK; | ||||||
|  |  | ||||||
| @@ -729,6 +729,93 @@ static void quantize_row_q4_1(const float * restrict x, void * restrict vy, int | |||||||
|     } |     } | ||||||
| } | } | ||||||
|  |  | ||||||
|  | static void quantize_row_q4_1(const float * restrict x, void * restrict vy, int k) { | ||||||
|  |     assert(k % QK == 0); | ||||||
|  |  | ||||||
|  | #if defined(__AVX2__) | ||||||
|  |     const int nb = k / QK; | ||||||
|  |  | ||||||
|  |     block_q4_1 * restrict y = vy; | ||||||
|  |  | ||||||
|  |     for (int i = 0; i < nb; i++) { | ||||||
|  |         // Load elements into 4 AVX vectors | ||||||
|  |         __m256 v0 = _mm256_loadu_ps( x ); | ||||||
|  |         __m256 v1 = _mm256_loadu_ps( x + 8 ); | ||||||
|  |         __m256 v2 = _mm256_loadu_ps( x + 16 ); | ||||||
|  |         __m256 v3 = _mm256_loadu_ps( x + 24 ); | ||||||
|  |         x += 32; | ||||||
|  |  | ||||||
|  |         // Compute max for the block | ||||||
|  |         __m256 vmax; | ||||||
|  |         vmax = _mm256_max_ps( v0, v1 ); | ||||||
|  |         vmax = _mm256_max_ps( vmax, v2 ); | ||||||
|  |         vmax = _mm256_max_ps( vmax, v3 ); | ||||||
|  |  | ||||||
|  |         __m128 max4 = _mm_max_ps( _mm256_extractf128_ps( vmax, 1 ), _mm256_castps256_ps128( vmax ) ); | ||||||
|  |         max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) ); | ||||||
|  |         max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) ); | ||||||
|  |         const float maxScalar = _mm_cvtss_f32( max4 ); | ||||||
|  |  | ||||||
|  |         // Compute min for the block | ||||||
|  |         __m256 vmin; | ||||||
|  |         vmin = _mm256_min_ps( v0, v1 ); | ||||||
|  |         vmin = _mm256_min_ps( vmin, v2 ); | ||||||
|  |         vmin = _mm256_min_ps( vmin, v3 ); | ||||||
|  |  | ||||||
|  |         __m128 min4 = _mm_min_ps( _mm256_extractf128_ps( vmin, 1 ), _mm256_castps256_ps128( vmin ) ); | ||||||
|  |         min4 = _mm_min_ps( min4, _mm_movehl_ps( min4, min4 ) ); | ||||||
|  |         min4 = _mm_min_ss( min4, _mm_movehdup_ps( min4 ) ); | ||||||
|  |         const float minScalar = _mm_cvtss_f32( min4 ); | ||||||
|  |  | ||||||
|  |         // Quantize these floats | ||||||
|  |         const float d = (maxScalar - minScalar) / ((1 << 4) - 1); | ||||||
|  |         const float id = d ? 1.0f/d : 0.0f; | ||||||
|  |  | ||||||
|  |         y[i].m = minScalar; | ||||||
|  |         y[i].d = d; | ||||||
|  |  | ||||||
|  |         // x = (x-min)*id | ||||||
|  |         const __m256 mul = _mm256_set1_ps( id ); | ||||||
|  |         const __m256 off = _mm256_set1_ps( minScalar ); | ||||||
|  |         v0 = _mm256_mul_ps( _mm256_sub_ps( v0, off ), mul ); | ||||||
|  |         v1 = _mm256_mul_ps( _mm256_sub_ps( v1, off ), mul ); | ||||||
|  |         v2 = _mm256_mul_ps( _mm256_sub_ps( v2, off ), mul ); | ||||||
|  |         v3 = _mm256_mul_ps( _mm256_sub_ps( v3, off ), mul ); | ||||||
|  |  | ||||||
|  |         // Round to nearest integer | ||||||
|  |         v0 = _mm256_round_ps( v0, _MM_ROUND_NEAREST ); | ||||||
|  |         v1 = _mm256_round_ps( v1, _MM_ROUND_NEAREST ); | ||||||
|  |         v2 = _mm256_round_ps( v2, _MM_ROUND_NEAREST ); | ||||||
|  |         v3 = _mm256_round_ps( v3, _MM_ROUND_NEAREST ); | ||||||
|  |  | ||||||
|  |         // Convert floats to integers | ||||||
|  |         __m256i i0 = _mm256_cvtps_epi32( v0 ); | ||||||
|  |         __m256i i1 = _mm256_cvtps_epi32( v1 ); | ||||||
|  |         __m256i i2 = _mm256_cvtps_epi32( v2 ); | ||||||
|  |         __m256i i3 = _mm256_cvtps_epi32( v3 ); | ||||||
|  |  | ||||||
|  |         // Convert int32 to int16 | ||||||
|  |         i0 = _mm256_packs_epi32( i0, i1 );	// 0, 1, 2, 3,  8, 9, 10, 11,  4, 5, 6, 7, 12, 13, 14, 15 | ||||||
|  |         i2 = _mm256_packs_epi32( i2, i3 );	// 16, 17, 18, 19,  24, 25, 26, 27,  20, 21, 22, 23, 28, 29, 30, 31 | ||||||
|  |                                             // Convert int16 to int8 | ||||||
|  |         i0 = _mm256_packs_epi16( i0, i2 );	// 0, 1, 2, 3,  8, 9, 10, 11,  16, 17, 18, 19,  24, 25, 26, 27,  4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31 | ||||||
|  |  | ||||||
|  |         // We got our precious signed bytes, but the order is now wrong | ||||||
|  |         // These AVX2 pack instructions process 16-byte pieces independently | ||||||
|  |         // The following instruction is fixing the order | ||||||
|  |         const __m256i perm = _mm256_setr_epi32( 0, 4, 1, 5, 2, 6, 3, 7 ); | ||||||
|  |         i0 = _mm256_permutevar8x32_epi32( i0, perm ); | ||||||
|  |  | ||||||
|  |         // Compress the vector into 4 bit/value, and store | ||||||
|  |         __m128i res = packNibbles( i0 ); | ||||||
|  |         _mm_storeu_si128( ( __m128i* )y[i].qs, res ); | ||||||
|  |     } | ||||||
|  | #else | ||||||
|  |     // scalar | ||||||
|  |     quantize_row_q4_1_reference(x, vy, k); | ||||||
|  | #endif | ||||||
|  | } | ||||||
|  |  | ||||||
| static void dequantize_row_q4_0(const void * restrict vx, float * restrict y, int k) { | static void dequantize_row_q4_0(const void * restrict vx, float * restrict y, int k) { | ||||||
|     assert(k % QK == 0); |     assert(k % QK == 0); | ||||||
|     const int nb = k / QK; |     const int nb = k / QK; | ||||||
| @@ -10135,7 +10222,7 @@ size_t ggml_quantize_q4_1(const float * src, void * dst, int n, int k, int64_t * | |||||||
|     for (int j = 0; j < n; j += k) { |     for (int j = 0; j < n; j += k) { | ||||||
|         block_q4_1 * restrict y = (block_q4_1 *)dst + j/QK; |         block_q4_1 * restrict y = (block_q4_1 *)dst + j/QK; | ||||||
|  |  | ||||||
|         quantize_row_q4_1(src + j, y, k); |         quantize_row_q4_1_reference(src + j, y, k); | ||||||
|  |  | ||||||
|         for (int i = 0; i < nb; i++) { |         for (int i = 0; i < nb; i++) { | ||||||
|             for (int l = 0; l < QK; l += 2) { |             for (int l = 0; l < QK; l += 2) { | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user
	 slaren
					slaren