mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-10-30 08:42:00 +00:00 
			
		
		
		
	Add AVX2 implementation of dequantize_row_q4_1 (#505)
This commit is contained in:
		
							
								
								
									
										34
									
								
								ggml.c
									
									
									
									
									
								
							
							
						
						
									
										34
									
								
								ggml.c
									
									
									
									
									
								
							| @@ -783,7 +783,7 @@ void dequantize_row_q4_0(const void * restrict x, float * restrict y, int k) { | |||||||
|  |  | ||||||
|             // Scale and store |             // Scale and store | ||||||
|             for (int j = 0; j < 4; j++) { |             for (int j = 0; j < 4; j++) { | ||||||
|                 __m256 result = _mm256_mul_ps(vf[j], d_v); |                 const __m256 result = _mm256_mul_ps(vf[j], d_v); | ||||||
|                 _mm256_storeu_ps(y + i * QK + l + j*8, result); |                 _mm256_storeu_ps(y + i * QK + l + j*8, result); | ||||||
|             } |             } | ||||||
|         } |         } | ||||||
| @@ -879,6 +879,37 @@ void dequantize_row_q4_1(const void * restrict x, float * restrict y, int k) { | |||||||
|     const uint8_t * restrict pm = ((const uint8_t *)x + 0*bs + sizeof(float)); |     const uint8_t * restrict pm = ((const uint8_t *)x + 0*bs + sizeof(float)); | ||||||
|     const uint8_t * restrict pb = ((const uint8_t *)x + 0*bs + 2*sizeof(float)); |     const uint8_t * restrict pb = ((const uint8_t *)x + 0*bs + 2*sizeof(float)); | ||||||
|  |  | ||||||
|  | #if defined(__AVX2__) | ||||||
|  |     for (int i = 0; i < nb; i++) { | ||||||
|  |         const __m256 d_v = _mm256_broadcast_ss((const float *) (pd + i*bs)); | ||||||
|  |         const __m256 d_m = _mm256_broadcast_ss((const float *) (pm + i*bs)); | ||||||
|  |  | ||||||
|  |         const uint8_t * restrict pp = pb + i*bs; | ||||||
|  |  | ||||||
|  |         for (int l = 0; l < QK; l += 32) { | ||||||
|  |             // Load 32x4-bit integers into 32x8-bit integers | ||||||
|  |             __m256i vx8 = bytesFromNibbles(pp+l/2); | ||||||
|  |  | ||||||
|  |             // Convert to 16-bit int | ||||||
|  |             const __m256i vx16_lo = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(vx8, 0)); | ||||||
|  |             const __m256i vx16_hi = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(vx8, 1)); | ||||||
|  |  | ||||||
|  |             // Convert to 32-bit int -> float 32 | ||||||
|  |             const __m256 vf[4] = { | ||||||
|  |                 _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(vx16_lo, 0))), | ||||||
|  |                 _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(vx16_lo, 1))), | ||||||
|  |                 _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(vx16_hi, 0))), | ||||||
|  |                 _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(vx16_hi, 1))) | ||||||
|  |             }; | ||||||
|  |  | ||||||
|  |             // Scale, add m and store | ||||||
|  |             for (int j = 0; j < 4; j++) { | ||||||
|  |                 const __m256 result = _mm256_add_ps(_mm256_mul_ps(vf[j], d_v), d_m); | ||||||
|  |                 _mm256_storeu_ps(y + i * QK + l + j*8, result); | ||||||
|  |             } | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  | #else | ||||||
|     for (int i = 0; i < nb; i++) { |     for (int i = 0; i < nb; i++) { | ||||||
|         const float d = *(const float *) (pd + i*bs); |         const float d = *(const float *) (pd + i*bs); | ||||||
|         const float m = *(const float *) (pm + i*bs); |         const float m = *(const float *) (pm + i*bs); | ||||||
| @@ -901,6 +932,7 @@ void dequantize_row_q4_1(const void * restrict x, float * restrict y, int k) { | |||||||
|             assert(!isnan(y[i*QK + l + 1])); |             assert(!isnan(y[i*QK + l + 1])); | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
|  | #endif | ||||||
| } | } | ||||||
|  |  | ||||||
| // | // | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user
	 slaren
					slaren