mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-10-31 08:51:55 +00:00 
			
		
		
		
	ggml : fix UB in IQ2_S and IQ3_S (#6012)
This commit is contained in:
		| @@ -9025,7 +9025,7 @@ void ggml_vec_dot_iq2_s_q8_K(int n, float * restrict s, size_t bs, const void * | |||||||
|                                      vld1_s8((const int8_t *)(iq2s_grid + (qs[7] | ((qh[ib32+1] << 2) & 0x300))))); |                                      vld1_s8((const int8_t *)(iq2s_grid + (qs[7] | ((qh[ib32+1] << 2) & 0x300))))); | ||||||
|             qs += 8; |             qs += 8; | ||||||
|  |  | ||||||
|             vs.val[0] = vreinterpretq_u8_u32(vdupq_n_u32(signs[0] | (signs[1] << 16))); |             vs.val[0] = vreinterpretq_u8_u32(vdupq_n_u32(signs[0] | ((uint32_t) signs[1] << 16))); | ||||||
|             vs.val[1] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[1]), mask2); |             vs.val[1] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[1]), mask2); | ||||||
|             vs.val[0] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[0]), mask2); |             vs.val[0] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[0]), mask2); | ||||||
|             vs.val[0] = vceqq_u8(vs.val[0], mask2); |             vs.val[0] = vceqq_u8(vs.val[0], mask2); | ||||||
| @@ -9034,7 +9034,7 @@ void ggml_vec_dot_iq2_s_q8_K(int n, float * restrict s, size_t bs, const void * | |||||||
|             q2s.val[0] = vmulq_s8(vreinterpretq_s8_u8(vorrq_u8(vs.val[0], m1)), q2s.val[0]); |             q2s.val[0] = vmulq_s8(vreinterpretq_s8_u8(vorrq_u8(vs.val[0], m1)), q2s.val[0]); | ||||||
|             q2s.val[1] = vmulq_s8(vreinterpretq_s8_u8(vorrq_u8(vs.val[1], m1)), q2s.val[1]); |             q2s.val[1] = vmulq_s8(vreinterpretq_s8_u8(vorrq_u8(vs.val[1], m1)), q2s.val[1]); | ||||||
|  |  | ||||||
|             vs.val[0] = vreinterpretq_u8_u32(vdupq_n_u32(signs[2] | (signs[3] << 16))); |             vs.val[0] = vreinterpretq_u8_u32(vdupq_n_u32(signs[2] | ((uint32_t) signs[3] << 16))); | ||||||
|             vs.val[1] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[1]), mask2); |             vs.val[1] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[1]), mask2); | ||||||
|             vs.val[0] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[0]), mask2); |             vs.val[0] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[0]), mask2); | ||||||
|             vs.val[0] = vceqq_u8(vs.val[0], mask2); |             vs.val[0] = vceqq_u8(vs.val[0], mask2); | ||||||
| @@ -9105,12 +9105,12 @@ void ggml_vec_dot_iq2_s_q8_K(int n, float * restrict s, size_t bs, const void * | |||||||
|                                                    iq2s_grid[qs[4] | ((qh[ib32+1] << 8) & 0x300)]); |                                                    iq2s_grid[qs[4] | ((qh[ib32+1] << 8) & 0x300)]); | ||||||
|             qs += 8; |             qs += 8; | ||||||
|  |  | ||||||
|             __m256i aux256 = _mm256_set1_epi32(signs[0] | (signs[1] << 16)); |             __m256i aux256 = _mm256_set1_epi32(signs[0] | ((uint32_t) signs[1] << 16)); | ||||||
|             aux256 = _mm256_and_si256(_mm256_shuffle_epi8(aux256,mask1), mask2); |             aux256 = _mm256_and_si256(_mm256_shuffle_epi8(aux256,mask1), mask2); | ||||||
|             const __m256i s2_1 = _mm256_cmpeq_epi8(aux256, mask2); |             const __m256i s2_1 = _mm256_cmpeq_epi8(aux256, mask2); | ||||||
|             const __m256i q8s_1 = _mm256_sub_epi8(_mm256_xor_si256(s2_1, q8_1), s2_1); |             const __m256i q8s_1 = _mm256_sub_epi8(_mm256_xor_si256(s2_1, q8_1), s2_1); | ||||||
|  |  | ||||||
|             aux256 = _mm256_set1_epi32(signs[2] | (signs[3] << 16)); |             aux256 = _mm256_set1_epi32(signs[2] | ((uint32_t) signs[3] << 16)); | ||||||
|             aux256 = _mm256_and_si256(_mm256_shuffle_epi8(aux256,mask1), mask2); |             aux256 = _mm256_and_si256(_mm256_shuffle_epi8(aux256,mask1), mask2); | ||||||
|             const __m256i s2_2 = _mm256_cmpeq_epi8(aux256, mask2); |             const __m256i s2_2 = _mm256_cmpeq_epi8(aux256, mask2); | ||||||
|             const __m256i q8s_2 = _mm256_sub_epi8(_mm256_xor_si256(s2_2, q8_2), s2_2); |             const __m256i q8s_2 = _mm256_sub_epi8(_mm256_xor_si256(s2_2, q8_2), s2_2); | ||||||
| @@ -9386,7 +9386,7 @@ void ggml_vec_dot_iq3_s_q8_K (int n, float * restrict s, size_t bs, const void * | |||||||
|                                                         iq3s_grid[idx.index[6]], iq3s_grid[idx.index[7]]); |                                                         iq3s_grid[idx.index[6]], iq3s_grid[idx.index[7]]); | ||||||
|  |  | ||||||
|  |  | ||||||
|             vs.val[0] = vreinterpretq_u8_u32(vdupq_n_u32(signs[0] | (signs[1] << 16))); |             vs.val[0] = vreinterpretq_u8_u32(vdupq_n_u32(signs[0] | ((uint32_t) signs[1] << 16))); | ||||||
|             vs.val[1] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[1]), mask2); |             vs.val[1] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[1]), mask2); | ||||||
|             vs.val[0] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[0]), mask2); |             vs.val[0] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[0]), mask2); | ||||||
|             vs.val[0] = vorrq_u8(vceqq_u8(vs.val[0], mask2), m1); |             vs.val[0] = vorrq_u8(vceqq_u8(vs.val[0], mask2), m1); | ||||||
| @@ -9395,7 +9395,7 @@ void ggml_vec_dot_iq3_s_q8_K (int n, float * restrict s, size_t bs, const void * | |||||||
|             q3s.val[0] = vmulq_s8(vreinterpretq_s8_u8(vs.val[0]), vreinterpretq_s8_u32(aux32x4_0)); |             q3s.val[0] = vmulq_s8(vreinterpretq_s8_u8(vs.val[0]), vreinterpretq_s8_u32(aux32x4_0)); | ||||||
|             q3s.val[1] = vmulq_s8(vreinterpretq_s8_u8(vs.val[1]), vreinterpretq_s8_u32(aux32x4_1)); |             q3s.val[1] = vmulq_s8(vreinterpretq_s8_u8(vs.val[1]), vreinterpretq_s8_u32(aux32x4_1)); | ||||||
|  |  | ||||||
|             vs.val[0] = vreinterpretq_u8_u32(vdupq_n_u32(signs[2] | (signs[3] << 16))); |             vs.val[0] = vreinterpretq_u8_u32(vdupq_n_u32(signs[2] | ((uint32_t) signs[3] << 16))); | ||||||
|             vs.val[1] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[1]), mask2); |             vs.val[1] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[1]), mask2); | ||||||
|             vs.val[0] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[0]), mask2); |             vs.val[0] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[0]), mask2); | ||||||
|             vs.val[0] = vorrq_u8(vceqq_u8(vs.val[0], mask2), m1); |             vs.val[0] = vorrq_u8(vceqq_u8(vs.val[0], mask2), m1); | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user
	 Georgi Gerganov
					Georgi Gerganov