	Faster AVX2 dot product for IQ2_XS (#5187)
* iq2xs: faster AVX2 dot product

* iq2xs: small AVX2 improvement

* Speed up computing sign bits in AVX2 iq2_xs dot product

---------

Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
Co-authored-by: Peter Reid <peter@peterreid.net>
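The heart of the change is how the per-group sign bits are obtained. Each 16-bit IQ2_XS value packs a 9-bit index into the 2-bit grid (the m511 mask) plus 7 explicit sign bits, with the 8th sign implied so that the total number of set sign bits is even; the old code looked the full sign byte up in the 128-entry keven_signs_q2xs table, while the new code reconstructs the missing bit in registers as the parity of the 7 stored bits via the k_bit_helper byte shuffle. The scalar sketch below only illustrates that reconstruction; the helper names are hypothetical and not part of the repository.

#include <stdint.h>

/* Scalar sketch (illustrative, not from the repository) of the sign
 * reconstruction the new AVX2 path performs per 16-bit IQ2_XS value. */
static inline uint16_t iq2xs_grid_index(uint16_t q2) {
    return q2 & 511;                              /* 9-bit grid index, same mask as m511 */
}

static inline uint8_t iq2xs_full_signs(uint16_t q2) {
    uint8_t signs7 = (uint8_t)(q2 >> 9);          /* 7 explicit sign bits */
    /* The SIMD code folds the 7 bits into a 4-bit index with the same parity
     * via (q2 >> 9) ^ (q2 >> 13) and looks the parity up with a byte shuffle
     * of k_bit_helper; 0x6996 is the equivalent 16-entry parity table. */
    uint8_t fold   = (uint8_t)((q2 >> 9) ^ (q2 >> 13)) & 0xf;
    uint8_t parity = (0x6996 >> fold) & 1;
    return (uint8_t)(signs7 | (parity << 7));     /* implied 8th sign bit: total parity even */
}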
@@ -8525,17 +8525,36 @@ void ggml_vec_dot_iq2_xs_q8_K(const int n, float * restrict s, const void * rest
 
     const __m128i m4 = _mm_set1_epi8(0xf);
     const __m128i m1 = _mm_set1_epi8(1);
-    const __m128i m511 = _mm_set1_epi16(511);
-    const __m128i m127 = _mm_set1_epi16(127);
+    const __m256i m511 = _mm256_set1_epi16(511);
+    const __m256i mone = _mm256_set1_epi8(1);
 
-    const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs;
+    static const uint8_t k_bit_helper[32] = {
+        0x00, 0x80, 0x80, 0x00, 0x80, 0x00, 0x00, 0x80, 0x80, 0x00, 0x00, 0x80, 0x00, 0x80, 0x80, 0x00,
+        0x00, 0x80, 0x80, 0x00, 0x80, 0x00, 0x00, 0x80, 0x80, 0x00, 0x00, 0x80, 0x00, 0x80, 0x80, 0x00,
+    };
+    static const char block_sign_shuffle_mask_1[32] = {
+        0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02,
+        0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06,
+    };
+    static const char block_sign_shuffle_mask_2[32] = {
+        0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a,
+        0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e,
+    };
+    static const uint8_t bit_selector_mask_bytes[32] = {
+        0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,
+        0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,
+    };
+
+    const __m256i bit_helper = _mm256_loadu_si256((const __m256i*)k_bit_helper);
+    const __m256i bit_selector_mask = _mm256_loadu_si256((const __m256i*)bit_selector_mask_bytes);
+    const __m256i block_sign_shuffle_1 = _mm256_loadu_si256((const __m256i*)block_sign_shuffle_mask_1);
+    const __m256i block_sign_shuffle_2 = _mm256_loadu_si256((const __m256i*)block_sign_shuffle_mask_2);
 
     uint64_t aux64;
 
     // somewhat hacky, but gives a significant boost in performance
-    __m128i aux_gindex, aux_sindex;
+    __m256i aux_gindex;
     const uint16_t * gindex = (const uint16_t *)&aux_gindex;
-    const uint16_t * sindex = (const uint16_t *)&aux_sindex;
 
     __m256 accumf = _mm256_setzero_ps();
     for (int i = 0; i < nb; ++i) {
@@ -8550,26 +8569,68 @@ void ggml_vec_dot_iq2_xs_q8_K(const int n, float * restrict s, const void * rest
 
         __m256i sumi1 = _mm256_setzero_si256();
         __m256i sumi2 = _mm256_setzero_si256();
-        for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
+        for (int ib32 = 0; ib32 < QK_K/32; ib32 += 4) {
+
+            const __m256i q2_data = _mm256_loadu_si256((const __m256i*)q2);  q2 += 16;
+            aux_gindex = _mm256_and_si256(q2_data, m511);
+
+            const __m256i partial_sign_bits = _mm256_srli_epi16(q2_data, 9);
+            const __m256i partial_sign_bits_upper = _mm256_srli_epi16(q2_data, 13);
+            const __m256i partial_sign_bits_for_counting = _mm256_xor_si256(partial_sign_bits, partial_sign_bits_upper);
+
+            const __m256i odd_bits = _mm256_shuffle_epi8(bit_helper, partial_sign_bits_for_counting);
+            const __m256i full_sign_bits = _mm256_or_si256(partial_sign_bits, odd_bits);
+
             const __m256i q8_1 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
             const __m256i q8_2 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
-            const __m128i q2_data = _mm_loadu_si128((const __m128i*)q2);  q2 +=  8;
-            aux_gindex = _mm_and_si128(q2_data, m511);
-            aux_sindex = _mm_and_si128(_mm_srli_epi16(q2_data, 9), m127);
-            const __m256i q2_1 = _mm256_set_epi64x(iq2xs_grid[gindex[3]], iq2xs_grid[gindex[2]], iq2xs_grid[gindex[1]], iq2xs_grid[gindex[0]]);
-            const __m256i q2_2 = _mm256_set_epi64x(iq2xs_grid[gindex[7]], iq2xs_grid[gindex[6]], iq2xs_grid[gindex[5]], iq2xs_grid[gindex[4]]);
-            const __m256i s2_1 = _mm256_set_epi64x(signs64[sindex[3]], signs64[sindex[2]], signs64[sindex[1]], signs64[sindex[0]]);
-            const __m256i s2_2 = _mm256_set_epi64x(signs64[sindex[7]], signs64[sindex[6]], signs64[sindex[5]], signs64[sindex[4]]);
-            const __m256i q8s_1 = _mm256_sign_epi8(q8_1, s2_1);
-            const __m256i q8s_2 = _mm256_sign_epi8(q8_2, s2_2);
+            const __m256i q8_3 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
+            const __m256i q8_4 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
+
+            const __m256i q2_1 = _mm256_set_epi64x(iq2xs_grid[gindex[ 3]], iq2xs_grid[gindex[ 2]],
+                                                   iq2xs_grid[gindex[ 1]], iq2xs_grid[gindex[ 0]]);
+            const __m256i q2_2 = _mm256_set_epi64x(iq2xs_grid[gindex[ 7]], iq2xs_grid[gindex[ 6]],
+                                                   iq2xs_grid[gindex[ 5]], iq2xs_grid[gindex[ 4]]);
+            const __m256i q2_3 = _mm256_set_epi64x(iq2xs_grid[gindex[11]], iq2xs_grid[gindex[10]],
+                                                   iq2xs_grid[gindex[ 9]], iq2xs_grid[gindex[ 8]]);
+            const __m256i q2_4 = _mm256_set_epi64x(iq2xs_grid[gindex[15]], iq2xs_grid[gindex[14]],
+                                                   iq2xs_grid[gindex[13]], iq2xs_grid[gindex[12]]);
+
+            const __m128i full_signs_l = _mm256_castsi256_si128(full_sign_bits);
+            const __m128i full_signs_h = _mm256_extractf128_si256(full_sign_bits, 1);
+            const __m256i full_signs_1 = _mm256_set_m128i(full_signs_l, full_signs_l);
+            const __m256i full_signs_2 = _mm256_set_m128i(full_signs_h, full_signs_h);
+
+            __m256i signs;
+            signs = _mm256_shuffle_epi8(full_signs_1, block_sign_shuffle_1);
+            signs = _mm256_cmpeq_epi8(_mm256_and_si256(signs, bit_selector_mask), bit_selector_mask);
+            const __m256i q8s_1 = _mm256_sign_epi8(q8_1, _mm256_or_si256(signs, mone));
+
+            signs = _mm256_shuffle_epi8(full_signs_1, block_sign_shuffle_2);
+            signs = _mm256_cmpeq_epi8(_mm256_and_si256(signs, bit_selector_mask), bit_selector_mask);
+            const __m256i q8s_2 = _mm256_sign_epi8(q8_2, _mm256_or_si256(signs, mone));
+
+            signs = _mm256_shuffle_epi8(full_signs_2, block_sign_shuffle_1);
+            signs = _mm256_cmpeq_epi8(_mm256_and_si256(signs, bit_selector_mask), bit_selector_mask);
+            const __m256i q8s_3 = _mm256_sign_epi8(q8_3, _mm256_or_si256(signs, mone));
+
+            signs = _mm256_shuffle_epi8(full_signs_2, block_sign_shuffle_2);
+            signs = _mm256_cmpeq_epi8(_mm256_and_si256(signs, bit_selector_mask), bit_selector_mask);
+            const __m256i q8s_4 = _mm256_sign_epi8(q8_4, _mm256_or_si256(signs, mone));
+
             const __m256i dot1  = _mm256_maddubs_epi16(q2_1, q8s_1);
             const __m256i dot2  = _mm256_maddubs_epi16(q2_2, q8s_2);
+            const __m256i dot3  = _mm256_maddubs_epi16(q2_3, q8s_3);
+            const __m256i dot4  = _mm256_maddubs_epi16(q2_4, q8s_4);
 
             const __m256i sc1 = _mm256_cvtepi8_epi16(_mm_shuffle_epi8(scales, get_scale_shuffle(ib32+0)));
             const __m256i sc2 = _mm256_cvtepi8_epi16(_mm_shuffle_epi8(scales, get_scale_shuffle(ib32+1)));
+            const __m256i sc3 = _mm256_cvtepi8_epi16(_mm_shuffle_epi8(scales, get_scale_shuffle(ib32+2)));
+            const __m256i sc4 = _mm256_cvtepi8_epi16(_mm_shuffle_epi8(scales, get_scale_shuffle(ib32+3)));
 
             sumi1 = _mm256_add_epi32(sumi1, _mm256_madd_epi16(dot1, sc1));
             sumi2 = _mm256_add_epi32(sumi2, _mm256_madd_epi16(dot2, sc2));
+            sumi1 = _mm256_add_epi32(sumi1, _mm256_madd_epi16(dot3, sc3));
+            sumi2 = _mm256_add_epi32(sumi2, _mm256_madd_epi16(dot4, sc4));
         }
 
         accumf = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(_mm256_add_epi32(sumi1, sumi2)), accumf);
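For applying the signs, the rewritten loop no longer materializes ±1 bytes from a lookup table: it broadcasts each reconstructed sign byte across a 32-byte lane with block_sign_shuffle_1/2, isolates one bit per byte with bit_selector_mask, turns the resulting 0x00/0xFF compare mask into ±1 by OR-ing with mone, and feeds that to _mm256_sign_epi8. A scalar sketch of the same per-byte operation follows; the function name is hypothetical and sign_bits is assumed to be the byte reconstructed above.

#include <stdint.h>

/* Illustrative scalar equivalent of the sign application done with
 * _mm256_sign_epi8 in the hunk above: bit j of the reconstructed sign byte
 * decides whether q8 byte j is negated before the dot product. */
static inline void iq2xs_apply_signs(const int8_t * q8, uint8_t sign_bits, int8_t * q8s) {
    for (int j = 0; j < 8; ++j) {
        /* cmpeq(and(signs, bit_selector_mask), bit_selector_mask) yields 0xFF
         * where the bit is set and 0x00 otherwise; OR-ing with 1 (mone) maps
         * that to -1 / +1, so _mm256_sign_epi8 never zeroes a lane. */
        int8_t s = (sign_bits & (1u << j)) ? -1 : 1;
        q8s[j] = (int8_t)(q8[j] * s);
    }
}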