ggml: optimize some vec dot functions for LoongArch ASX (#11842)
* Optimize ggml_vec_dot_q3_K_q8_K for LoongArch ASX
* Optimize ggml_vec_dot_q4_K_q8_K for LoongArch ASX
* Optimize ggml_vec_dot_q6_K_q8_K for LoongArch ASX
* Optimize ggml_vec_dot_q5_K_q8_K for LoongArch ASX
* Optimize ggml_vec_dot_q2_K_q8_K for LoongArch ASX
* Optimize mul_sum_i8_pairs_float for LoongArch ASX
* Optimize ggml_vec_dot_iq4_xs_q8_K for LoongArch ASX
Author: Jinyang He
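The central change is the new lasx_madd_h_b helper: it multiplies signed 8-bit elements and accumulates adjacent products into signed 16-bit lanes, which removes the absolute-value/sign-copy dance previously needed around lasx_maddubs_h. A scalar sketch of what one 256-bit call computes (illustrative only; madd_h_b_ref is a made-up name, not part of the patch):

    #include <stdint.h>

    // out[i] = a[2*i] * b[2*i] + a[2*i + 1] * b[2*i + 1], as signed 16-bit.
    // The quantized inputs used by these kernels keep the sum within int16 range.
    static void madd_h_b_ref(const int8_t a[32], const int8_t b[32], int16_t out[16]) {
        for (int i = 0; i < 16; ++i) {
            out[i] = (int16_t)(a[2*i] * b[2*i] + a[2*i + 1] * b[2*i + 1]);
        }
    }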
@@ -562,6 +562,41 @@ static __m256i lasx_packs_h(__m256i a, __m256i b) {
     return __lasx_xvpickev_b(tmp1, tmp);
 }
 
+static inline __m256i lasx_madd_h_b(__m256i a, __m256i b) {
+    __m256i tmp1, tmp2;
+    tmp1 = __lasx_xvmulwev_h_b(a, b);
+    tmp2 = __lasx_xvmulwod_h_b(a, b);
+    return __lasx_xvadd_h(tmp1, tmp2);
+}
+
+static inline __m256i lasx_xvrepl128vei_h(__m256i a, const unsigned int b) {
+    switch (b) {
+        case 0: return __lasx_xvrepl128vei_h(a, 0);
+        case 1: return __lasx_xvrepl128vei_h(a, 1);
+        case 2: return __lasx_xvrepl128vei_h(a, 2);
+        case 3: return __lasx_xvrepl128vei_h(a, 3);
+        case 4: return __lasx_xvrepl128vei_h(a, 4);
+        case 5: return __lasx_xvrepl128vei_h(a, 5);
+        case 6: return __lasx_xvrepl128vei_h(a, 6);
+        case 7: return __lasx_xvrepl128vei_h(a, 7);
+        default: __builtin_unreachable();
+    }
+}
+
+static inline __m256i lasx_xvandi_b_bit(__m256i a, const unsigned int b) {
+    switch (b) {
+        case 0: return __lasx_xvandi_b(a, 1 << 0);
+        case 1: return __lasx_xvandi_b(a, 1 << 1);
+        case 2: return __lasx_xvandi_b(a, 1 << 2);
+        case 3: return __lasx_xvandi_b(a, 1 << 3);
+        case 4: return __lasx_xvandi_b(a, 1 << 4);
+        case 5: return __lasx_xvandi_b(a, 1 << 5);
+        case 6: return __lasx_xvandi_b(a, 1 << 6);
+        case 7: return __lasx_xvandi_b(a, 1 << 7);
+        default: __builtin_unreachable();
+    }
+}
+
 // multiply int8_t, add results pairwise twice
 static inline __m128i mul_sum_i8_pairs(const __m128i x, const __m128i y) {
     // Get absolute values of x vectors
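The switch-based wrappers above exist because __lasx_xvrepl128vei_h and __lasx_xvandi_b only accept compile-time immediates; dispatching over a small index lets the kernels pass loop-derived expressions such as 4 * j + 1, which should collapse back to a single immediate-form instruction once the fixed-trip loops are unrolled. A plain-C analogue of the pattern (illustrative only; shift_right_imm is a made-up name):

    #include <stdint.h>

    // The "immediate" is selected by a switch, so a constant-foldable index
    // can be passed in; the unreachable default documents the valid range.
    static inline uint32_t shift_right_imm(uint32_t v, const unsigned int n) {
        switch (n) {
            case 0: return v >> 0;
            case 1: return v >> 1;
            case 2: return v >> 2;
            case 3: return v >> 3;
            default: __builtin_unreachable();
        }
    }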
@@ -656,13 +691,8 @@ static inline __m256 mul_sum_us8_pairs_float(const __m256i ax, const __m256i sy)
 
 // multiply int8_t, add results pairwise twice and return as float vector
 static inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) {
-
-    // Get absolute values of x vectors
-    const __m256i ax = __lasx_xvsigncov_b(x, x);
-    // Sign the values of the y vectors
-    const __m256i sy = __lasx_xvsigncov_b(x, y);
-
-    return mul_sum_us8_pairs_float(ax, sy);
+    const __m256i dot = lasx_madd_h_b(x, y);
+    return sum_i16_pairs_float(dot);
 }
 
 static inline __m128i packNibbles( __m256i bytes ) {
@@ -4965,9 +4995,6 @@ void ggml_vec_dot_q2_K_q8_K(int n, float * restrict s, size_t bs, const void * r
 
 #elif defined __loongarch_asx
 
-    const __m256i m3 = __lasx_xvreplgr2vr_b(3);
-    const __m128i m4 = __lsx_vreplgr2vr_b(0xF);
-
     __m256 acc = (__m256)__lasx_xvldi(0);
 
     for (int i = 0; i < nb; ++i) {
@@ -4978,18 +5005,15 @@ void ggml_vec_dot_q2_K_q8_K(int n, float * restrict s, size_t bs, const void * r
         const uint8_t * restrict q2 = x[i].qs;
         const int8_t  * restrict q8 = y[i].qs;
 
-        const __m128i mins_and_scales = __lsx_vld((const __m128i*)x[i].scales, 0);
-        const __m128i scales8 = __lsx_vand_v(mins_and_scales, m4);
-        const __m128i mins8 = __lsx_vand_v(__lsx_vsrli_h(mins_and_scales, 4), m4);
-        const __m256i mins = lasx_ext8_16(mins8);
+        const __m128i mins_and_scales128 = __lsx_vld((const __m128i*)x[i].scales, 0);
+        const __m128i scales128 = __lsx_vandi_b(mins_and_scales128, 0xf);
+        const __m256i mins = lasx_ext8_16(__lsx_vsrli_b(mins_and_scales128, 4));
         const __m256i prod = lasx_madd_h(mins, __lasx_xvld((const __m256i*)y[i].bsums, 0));
 
         acc = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(dmin), __lasx_xvffint_s_w(prod), acc);
 
-        const __m256i all_scales = lasx_ext8_16(scales8);
-        const __m128i l_scales = lasx_extracti128(all_scales, 0);
-        const __m128i h_scales = lasx_extracti128(all_scales, 1);
-        const __m256i scales[2] = {lasx_insertf128(l_scales, l_scales), lasx_insertf128(h_scales, h_scales)};
+        const v16i8 shuffle_mask = {0, 2, 4, 6, 8, 10, 12, 14, 1, 3, 5, 7, 9, 11, 13, 15};
+        const __m256i scales_shuffled = lasx_ext8_16(__lsx_vshuf_b(scales128, scales128, (__m128i)shuffle_mask));
 
         __m256i sumi = __lasx_xvldi(0);
 
@@ -5002,20 +5026,20 @@ void ggml_vec_dot_q2_K_q8_K(int n, float * restrict s, size_t bs, const void * r
             const __m256i q8_2 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
             const __m256i q8_3 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
 
-            const __m256i q2_0 = __lasx_xvand_v(q2bits, m3);
-            const __m256i q2_1 = __lasx_xvand_v(__lasx_xvsrli_h(q2bits, 2), m3);
-            const __m256i q2_2 = __lasx_xvand_v(__lasx_xvsrli_h(q2bits, 4), m3);
-            const __m256i q2_3 = __lasx_xvand_v(__lasx_xvsrli_h(q2bits, 6), m3);
+            const __m256i q2_0 = __lasx_xvandi_b(q2bits, 3);
+            const __m256i q2_1 = __lasx_xvandi_b(__lasx_xvsrli_b(q2bits, 2), 3);
+            const __m256i q2_2 = __lasx_xvandi_b(__lasx_xvsrli_b(q2bits, 4), 3);
+            const __m256i q2_3 = __lasx_xvsrli_b(q2bits, 6);
 
-            __m256i p0 = lasx_maddubs_h(q2_0, q8_0);
-            __m256i p1 = lasx_maddubs_h(q2_1, q8_1);
-            __m256i p2 = lasx_maddubs_h(q2_2, q8_2);
-            __m256i p3 = lasx_maddubs_h(q2_3, q8_3);
+            __m256i p0 = lasx_madd_h_b(q2_0, q8_0);
+            __m256i p1 = lasx_madd_h_b(q2_1, q8_1);
+            __m256i p2 = lasx_madd_h_b(q2_2, q8_2);
+            __m256i p3 = lasx_madd_h_b(q2_3, q8_3);
 
-            p0 = lasx_madd_h(lasx_shuffle_b(scales[j], get_scale_shuffle_q3k(0)), p0);
-            p1 = lasx_madd_h(lasx_shuffle_b(scales[j], get_scale_shuffle_q3k(1)), p1);
-            p2 = lasx_madd_h(lasx_shuffle_b(scales[j], get_scale_shuffle_q3k(2)), p2);
-            p3 = lasx_madd_h(lasx_shuffle_b(scales[j], get_scale_shuffle_q3k(3)), p3);
+            p0 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 0), p0);
+            p1 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 1), p1);
+            p2 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 2), p2);
+            p3 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 3), p3);
 
             p0 = __lasx_xvadd_w(p0, p1);
             p2 = __lasx_xvadd_w(p2, p3);
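Worth noting for the q2_K hunk above (the same pattern recurs in the q3_K and q6_K kernels below): the byte shuffle {0, 2, 4, ..., 14, 1, 3, ..., 15} interleaves the 16 sub-block scales so that, after widening to 16 bits, even-indexed scales land in the low 128-bit lane and odd-indexed scales in the high lane. lasx_xvrepl128vei_h(scales_shuffled, k) then broadcasts scale 2k across the low lane and scale 2k+1 across the high lane, matching the two sub-blocks carried by each 256-bit product vector. A scalar sketch of the index mapping (illustrative only; shuffle_scales_ref is a made-up name):

    #include <stdint.h>

    // After __lsx_vshuf_b with {0,2,...,14, 1,3,...,15} and widening to 16 bits:
    // halfword k of the low 128-bit lane  holds scales[2*k],
    // halfword k of the high 128-bit lane holds scales[2*k + 1].
    static void shuffle_scales_ref(const int8_t scales[16], int16_t out[16]) {
        for (int k = 0; k < 8; ++k) {
            out[k]     = scales[2*k];
            out[k + 8] = scales[2*k + 1];
        }
    }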
@@ -5771,8 +5795,6 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, size_t bs, const void * r
 
 #elif defined __loongarch_asx
 
-    const __m256i m3 = __lasx_xvreplgr2vr_b(3);
-    const __m256i mone = __lasx_xvreplgr2vr_b(1);
     const __m128i m32 = __lsx_vreplgr2vr_b(32);
 
     __m256 acc = (__m256)__lasx_xvldi(0);
@@ -5792,10 +5814,9 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, size_t bs, const void * r
                 (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4),
                 (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4));
         scales128 = __lsx_vsub_b(scales128, m32);
-        const __m256i all_scales = lasx_ext8_16(scales128);
-        const __m128i l_scales = lasx_extracti128(all_scales, 0);
-        const __m128i h_scales = lasx_extracti128(all_scales, 1);
-        const __m256i scales[2] = {lasx_insertf128(l_scales, l_scales), lasx_insertf128(h_scales, h_scales)};
+
+        const v16i8 shuffle_mask = {0, 2, 4, 6, 8, 10, 12, 14, 1, 3, 5, 7, 9, 11, 13, 15};
+        const __m256i scales_shuffled = lasx_ext8_16(__lsx_vshuf_b(scales128, scales128, (__m128i)shuffle_mask));
 
         // high bit
         const __m256i hbits = __lasx_xvld((const __m256i*)x[i].hmask, 0);
@@ -5803,35 +5824,23 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, size_t bs, const void * r
         // integer accumulator
         __m256i sumi = __lasx_xvldi(0);
 
-        int bit = 0;
-        int is  = 0;
-        __m256i xvbit;
-
-
         for (int j = 0; j < QK_K/128; ++j) {
             // load low 2 bits
             const __m256i q3bits = __lasx_xvld((const __m256i*)q3, 0); q3 += 32;
 
-            xvbit = __lasx_xvreplgr2vr_h(bit);
             // prepare low and high bits
-            const __m256i q3l_0 = __lasx_xvand_v(q3bits, m3);
-            const __m256i q3h_0 = __lasx_xvslli_h(__lasx_xvsrl_h(__lasx_xvandn_v(hbits, __lasx_xvsll_h(mone, xvbit)), xvbit), 2);
-            ++bit;
-
-            xvbit = __lasx_xvreplgr2vr_h(bit);
-            const __m256i q3l_1 = __lasx_xvand_v(__lasx_xvsrli_h(q3bits, 2), m3);
-            const __m256i q3h_1 = __lasx_xvslli_h(__lasx_xvsrl_h(__lasx_xvandn_v(hbits, __lasx_xvsll_h(mone, xvbit)), xvbit), 2);
-            ++bit;
-
-            xvbit = __lasx_xvreplgr2vr_h(bit);
-            const __m256i q3l_2 = __lasx_xvand_v(__lasx_xvsrli_h(q3bits, 4), m3);
-            const __m256i q3h_2 = __lasx_xvslli_h(__lasx_xvsrl_h(__lasx_xvandn_v(hbits, __lasx_xvsll_h(mone, xvbit)), xvbit), 2);
-            ++bit;
-
-            xvbit = __lasx_xvreplgr2vr_h(bit);
-            const __m256i q3l_3 = __lasx_xvand_v(__lasx_xvsrli_h(q3bits, 6), m3);
-            const __m256i q3h_3 = __lasx_xvslli_h(__lasx_xvsrl_h(__lasx_xvandn_v(hbits, __lasx_xvsll_h(mone, xvbit)), xvbit), 2);
-            ++bit;
+            const __m256i q3l_0 = __lasx_xvandi_b(q3bits, 3);
+            const __m256i q3l_1 = __lasx_xvandi_b(__lasx_xvsrli_b(q3bits, 2), 3);
+            const __m256i q3l_2 = __lasx_xvandi_b(__lasx_xvsrli_b(q3bits, 4), 3);
+            const __m256i q3l_3 = __lasx_xvsrli_b(q3bits, 6);
+            const __m256i q3h_0 = __lasx_xvslli_b(__lasx_xvseqi_b(lasx_xvandi_b_bit(hbits, 4 * j + 0), 0), 2);
+            const __m256i q3h_1 = __lasx_xvslli_b(__lasx_xvseqi_b(lasx_xvandi_b_bit(hbits, 4 * j + 1), 0), 2);
+            const __m256i q3h_2 = __lasx_xvslli_b(__lasx_xvseqi_b(lasx_xvandi_b_bit(hbits, 4 * j + 2), 0), 2);
+            const __m256i q3h_3 = __lasx_xvslli_b(__lasx_xvseqi_b(lasx_xvandi_b_bit(hbits, 4 * j + 3), 0), 2);
+            const __m256i q3_0 = __lasx_xvor_v(q3h_0, q3l_0);
+            const __m256i q3_1 = __lasx_xvor_v(q3h_1, q3l_1);
+            const __m256i q3_2 = __lasx_xvor_v(q3h_2, q3l_2);
+            const __m256i q3_3 = __lasx_xvor_v(q3h_3, q3l_3);
 
             // load Q8 quants
             const __m256i q8_0 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
@@ -5839,29 +5848,16 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, size_t bs, const void * r
             const __m256i q8_2 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
             const __m256i q8_3 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
 
-            // Dot product: we multiply the 2 low bits and 1 high bit part separately, so we can use lasx_maddubs_h,
-            // and then subtract. The high bit part has the 2 already subtracted (and so, it is zero if the high bit was not set,
-            // and 2 if the high bit was set)
-            __m256i q8s_0 = lasx_maddubs_h(q3h_0, q8_0);
-            __m256i q8s_1 = lasx_maddubs_h(q3h_1, q8_1);
-            __m256i q8s_2 = lasx_maddubs_h(q3h_2, q8_2);
-            __m256i q8s_3 = lasx_maddubs_h(q3h_3, q8_3);
-
-            __m256i p16_0 = lasx_maddubs_h(q3l_0, q8_0);
-            __m256i p16_1 = lasx_maddubs_h(q3l_1, q8_1);
-            __m256i p16_2 = lasx_maddubs_h(q3l_2, q8_2);
-            __m256i p16_3 = lasx_maddubs_h(q3l_3, q8_3);
-
-            p16_0 = __lasx_xvsub_h(p16_0, q8s_0);
-            p16_1 = __lasx_xvsub_h(p16_1, q8s_1);
-            p16_2 = __lasx_xvsub_h(p16_2, q8s_2);
-            p16_3 = __lasx_xvsub_h(p16_3, q8s_3);
+            __m256i p16_0 = lasx_madd_h_b(q8_0, q3_0);
+            __m256i p16_1 = lasx_madd_h_b(q8_1, q3_1);
+            __m256i p16_2 = lasx_madd_h_b(q8_2, q3_2);
+            __m256i p16_3 = lasx_madd_h_b(q8_3, q3_3);
 
             // multiply with scales
-            p16_0 = lasx_madd_h(lasx_shuffle_b(scales[j], get_scale_shuffle_q3k(is + 0)), p16_0);
-            p16_1 = lasx_madd_h(lasx_shuffle_b(scales[j], get_scale_shuffle_q3k(is + 1)), p16_1);
-            p16_2 = lasx_madd_h(lasx_shuffle_b(scales[j], get_scale_shuffle_q3k(is + 2)), p16_2);
-            p16_3 = lasx_madd_h(lasx_shuffle_b(scales[j], get_scale_shuffle_q3k(is + 3)), p16_3);
+            p16_0 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 0), p16_0);
+            p16_1 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 1), p16_1);
+            p16_2 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 2), p16_2);
+            p16_3 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 3), p16_3);
 
             // accumulate
             p16_0 = __lasx_xvadd_w(p16_0, p16_1);
@@ -5869,7 +5865,7 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, size_t bs, const void * r
             sumi  = __lasx_xvadd_w(sumi, __lasx_xvadd_w(p16_0, p16_2));
         }
         // multiply with block scale and accumulate
-        acc = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(d), __lasx_xvffint_s_w(sumi), acc);//FIXME
+        acc = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(d), __lasx_xvffint_s_w(sumi), acc);
     }
 
     *s = hsum_float_8(acc);
@@ -6562,11 +6558,6 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, size_t bs, const void * r
     *s = vec_extract(vsumf0, 0);
 
 #elif defined __loongarch_asx
-    GGML_UNUSED(kmask1);
-    GGML_UNUSED(kmask2);
-    GGML_UNUSED(kmask3);
-
-    const __m256i m4 = __lasx_xvreplgr2vr_b(0xF);
 
     __m256 acc = (__m256)__lasx_xvldi(0);
     __m128 acc_m = (__m128)__lsx_vldi(0);
@@ -6586,33 +6577,34 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, size_t bs, const void * r
         const uint8_t * restrict q4 = x[i].qs;
         const int8_t  * restrict q8 = y[i].qs;
 
-        const __m256i mins_and_scales = lasx_extu8_16(lsx_set_w(utmp[3], utmp[2], utmp[1], utmp[0]));
+        const __m128i mins_and_scales128 = lsx_set_w(utmp[3], utmp[2], utmp[1], utmp[0]);
+        const __m128i mins128 = __lsx_vexth_h_b(mins_and_scales128);
+        const __m128i scales128 = __lsx_vsllwil_h_b(mins_and_scales128, 0);
 
         const __m256i q8sums = __lasx_xvld((const __m256i*)y[i].bsums, 0);
         const __m128i q8s = lsx_hadd_h(lasx_extracti128(q8sums, 0), lasx_extracti128(q8sums, 1));
-        const __m128i prod = lsx_madd_h(lasx_extracti128(mins_and_scales, 1), q8s);
+        const __m128i prod = lsx_madd_h(mins128, q8s);
         acc_m = __lsx_vfmadd_s(__lsx_vreplfr2vr_s(dmin), __lsx_vffint_s_w(prod), acc_m);
 
-        const __m128i sc128  = lasx_extracti128(mins_and_scales, 0);
-        const __m256i scales = lasx_insertf128(sc128, sc128);
+        const __m256i scales = lasx_insertf128(scales128, scales128);
 
         __m256i sumi = __lasx_xvldi(0);
 
         for (int j = 0; j < QK_K/64; ++j) {
 
-            const __m256i scale_l = lasx_shuffle_b(scales, get_scale_shuffle_k4(2*j+0));
-            const __m256i scale_h = lasx_shuffle_b(scales, get_scale_shuffle_k4(2*j+1));
+            const __m256i scale_l = lasx_xvrepl128vei_h(scales, 2 * j + 0);
+            const __m256i scale_h = lasx_xvrepl128vei_h(scales, 2 * j + 1);
 
             const __m256i q4bits = __lasx_xvld((const __m256i*)q4, 0); q4 += 32;
-            const __m256i q4l = __lasx_xvand_v(q4bits, m4);
-            const __m256i q4h = __lasx_xvand_v(__lasx_xvsrli_h(q4bits, 4), m4);
+            const __m256i q4l = __lasx_xvandi_b(q4bits, 0xf);
+            const __m256i q4h = __lasx_xvsrli_b(q4bits, 4);
 
             const __m256i q8l = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
-            __m256i p16l = lasx_maddubs_h(q4l, q8l);
+            __m256i p16l = lasx_madd_h_b(q4l, q8l);
             p16l = lasx_madd_h(scale_l, p16l);
 
             const __m256i q8h = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
-            __m256i p16h = lasx_maddubs_h(q4h, q8h);
+            __m256i p16h = lasx_madd_h_b(q4h, q8h);
             p16h = lasx_madd_h(scale_h, p16h);
             const __m256i sumj = __lasx_xvadd_w(p16l, p16h);
 
@@ -7289,19 +7281,11 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void * r
     *s = vec_extract(vsumf0, 0);
 
 #elif defined __loongarch_asx
-    GGML_UNUSED(kmask1);
-    GGML_UNUSED(kmask2);
-    GGML_UNUSED(kmask3);
-
-    const __m256i m4 = __lasx_xvreplgr2vr_b(0xF);
-    const __m128i mzero = __lsx_vldi(0);
-    const __m256i mone  = __lasx_xvreplgr2vr_b(1);
 
     __m256 acc = (__m256)__lasx_xvldi(0);
+    __m128 acc_m = (__m128)__lsx_vldi(0);
 
-    float summs = 0.f;
-
-   for (int i = 0; i < nb; ++i) {
+    for (int i = 0; i < nb; ++i) {
 
         const uint8_t * restrict q5 = x[i].qs;
         const int8_t  * restrict q8 = y[i].qs;
@@ -7316,49 +7300,40 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void * r
         utmp[2] = uaux;
         utmp[0] &= kmask1;
 
-        const __m256i mins_and_scales = lasx_extu8_16(lsx_set_w(utmp[3], utmp[2], utmp[1], utmp[0]));
+        const __m128i mins_and_scales128 = lsx_set_w(utmp[3], utmp[2], utmp[1], utmp[0]);
+        const __m128i mins128 = __lsx_vexth_h_b(mins_and_scales128);
+        const __m128i scales128 = __lsx_vsllwil_h_b(mins_and_scales128, 0);
 
         const __m256i q8sums = __lasx_xvld((const __m256i*)y[i].bsums, 0);
         const __m128i q8s = lsx_hadd_h(lasx_extracti128(q8sums, 0), lasx_extracti128(q8sums, 1));
-        const __m128i prod = lsx_madd_h(lasx_extracti128(mins_and_scales, 1), q8s);
-        const __m128i hsum = lsx_hadd_w(lsx_hadd_w(prod, mzero), mzero);
-        summs += dmin * __lsx_vpickve2gr_w(hsum, 0);    //TODO check
+        const __m128i prod = lsx_madd_h(mins128, q8s);
+        acc_m = __lsx_vfmadd_s(__lsx_vreplfr2vr_s(dmin), __lsx_vffint_s_w(prod), acc_m);
 
-        const __m128i sc128  = lasx_extracti128(mins_and_scales, 0);
-        const __m256i scales = lasx_insertf128(sc128, sc128);
+        const __m256i scales = lasx_insertf128(scales128, scales128);
 
         const __m256i hbits = __lasx_xvld((const __m256i*)x[i].qh, 0);
-        __m256i hmask = mone;
 
         __m256i sumi = __lasx_xvldi(0);
 
-        int bit = 0;
-        __m256i xvbit;
-
         for (int j = 0; j < QK_K/64; ++j) {
 
-            const __m256i scale_0 = lasx_shuffle_b(scales, get_scale_shuffle_k4(2*j+0));
-            const __m256i scale_1 = lasx_shuffle_b(scales, get_scale_shuffle_k4(2*j+1));
+            const __m256i scale_0 = lasx_xvrepl128vei_h(scales, 2 * j + 0);
+            const __m256i scale_1 = lasx_xvrepl128vei_h(scales, 2 * j + 1);
 
             const __m256i q5bits = __lasx_xvld((const __m256i*)q5, 0); q5 += 32;
 
-            xvbit = __lasx_xvreplgr2vr_h(bit++);
-            const __m256i q5l_0 = __lasx_xvand_v(q5bits, m4);
-            const __m256i q5h_0 = __lasx_xvslli_h(__lasx_xvsrl_h(__lasx_xvand_v(hbits, hmask), xvbit), 4);
-            const __m256i q5_0  = __lasx_xvadd_b(q5l_0, q5h_0);
-            hmask = __lasx_xvslli_h(hmask, 1);
-
-            xvbit = __lasx_xvreplgr2vr_h(bit++);
-            const __m256i q5l_1 = __lasx_xvand_v(__lasx_xvsrli_h(q5bits, 4), m4);
-            const __m256i q5h_1 = __lasx_xvslli_h(__lasx_xvsrl_h(__lasx_xvand_v(hbits, hmask), xvbit), 4);
-            const __m256i q5_1  = __lasx_xvadd_b(q5l_1, q5h_1);
-            hmask = __lasx_xvslli_h(hmask, 1);
+            const __m256i q5l_0 = __lasx_xvandi_b(q5bits, 0xf);
+            const __m256i q5l_1 = __lasx_xvsrli_b(q5bits, 4);
+            const __m256i q5h_0 = __lasx_xvnori_b(__lasx_xvseqi_b(lasx_xvandi_b_bit(hbits, 2 * j + 0), 0), 0xef);
+            const __m256i q5h_1 = __lasx_xvnori_b(__lasx_xvseqi_b(lasx_xvandi_b_bit(hbits, 2 * j + 1), 0), 0xef);
+            const __m256i q5_0  = __lasx_xvor_v(q5l_0, q5h_0);
+            const __m256i q5_1  = __lasx_xvor_v(q5l_1, q5h_1);
 
             const __m256i q8_0 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
             const __m256i q8_1 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
 
-            __m256i p16_0 = lasx_maddubs_h(q5_0, q8_0);
-            __m256i p16_1 = lasx_maddubs_h(q5_1, q8_1);
+            __m256i p16_0 = lasx_madd_h_b(q5_0, q8_0);
+            __m256i p16_1 = lasx_madd_h_b(q5_1, q8_1);
 
             p16_0 = lasx_madd_h(scale_0, p16_0);
             p16_1 = lasx_madd_h(scale_1, p16_1);
@@ -7372,7 +7347,10 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void * r
 
     }
 
-    *s = hsum_float_8(acc) + summs;
+    acc_m = __lsx_vfadd_s(acc_m, (__m128)__lsx_vbsrl_v(acc_m, 8));
+    acc_m = __lsx_vfadd_s(acc_m, (__m128)__lsx_vbsrl_v(acc_m, 4));
+
+    *s = hsum_float_8(acc) + ((v4f32)acc_m)[0];
 
 #else
 
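With the q5_K rewrite above, the scalar summs accumulator is gone: the mins contribution goes into the 4-float vector acc_m, and the tail reduces it with two byte-shift-and-add folds so the total ends up in element 0. A scalar view of that reduction (illustrative only; hsum4_ref is a made-up name):

    // What the two __lsx_vbsrl_v/__lsx_vfadd_s steps compute for acc_m = {v0, v1, v2, v3}:
    // after the 8-byte fold the lanes hold {v0 + v2, v1 + v3, ...}; after the
    // 4-byte fold, element 0 holds (v0 + v2) + (v1 + v3).
    static float hsum4_ref(const float v[4]) {
        return (v[0] + v[2]) + (v[1] + v[3]);
    }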
@@ -8033,8 +8011,6 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void * r
 
 #elif defined __loongarch_asx
 
-    const __m256i m4 = __lasx_xvreplgr2vr_b(0xF);
-    const __m256i m2 = __lasx_xvreplgr2vr_b(3);
     const __m256i m32s = __lasx_xvreplgr2vr_b(32);
 
     __m256 acc = (__m256)__lasx_xvldi(0);
@@ -8047,58 +8023,42 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void * r
         const uint8_t * restrict qh = x[i].qh;
         const int8_t  * restrict q8 = y[i].qs;
 
-        const __m128i scales = __lsx_vld((const __m128i*)x[i].scales, 0);
+        const __m128i scales128 = __lsx_vld((const __m128i*)x[i].scales, 0);
+        const v16i8 shuffle_mask = {0, 2, 4, 6, 8, 10, 12, 14, 1, 3, 5, 7, 9, 11, 13, 15};
+        const __m256i scales_shuffled = lasx_ext8_16(__lsx_vshuf_b(scales128, scales128, (__m128i)shuffle_mask));
 
         __m256i sumi = __lasx_xvldi(0);
 
-        int is = 0;
-
         for (int j = 0; j < QK_K/128; ++j) {
 
-            const __m128i scale_0 = lsx_shuffle_b(scales, get_scale_shuffle(is + 0));
-            const __m128i scale_1 = lsx_shuffle_b(scales, get_scale_shuffle(is + 1));
-            const __m128i scale_2 = lsx_shuffle_b(scales, get_scale_shuffle(is + 2));
-            const __m128i scale_3 = lsx_shuffle_b(scales, get_scale_shuffle(is + 3));
-            is += 4;
-
             const __m256i q4bits1 = __lasx_xvld((const __m256i*)q4, 0); q4 += 32;
             const __m256i q4bits2 = __lasx_xvld((const __m256i*)q4, 0); q4 += 32;
             const __m256i q4bitsH = __lasx_xvld((const __m256i*)qh, 0); qh += 32;
 
-            const __m256i q4h_0 = __lasx_xvslli_h(__lasx_xvand_v(q4bitsH, m2), 4);
-            const __m256i q4h_1 = __lasx_xvslli_h(__lasx_xvand_v(__lasx_xvsrli_h(q4bitsH, 2), m2), 4);
-            const __m256i q4h_2 = __lasx_xvslli_h(__lasx_xvand_v(__lasx_xvsrli_h(q4bitsH, 4), m2), 4);
-            const __m256i q4h_3 = __lasx_xvslli_h(__lasx_xvand_v(__lasx_xvsrli_h(q4bitsH, 6), m2), 4);
+            const __m256i q4h_0 = __lasx_xvslli_b(__lasx_xvandi_b(q4bitsH, 3), 4);
+            const __m256i q4h_1 = __lasx_xvslli_b(__lasx_xvandi_b(q4bitsH, 3 << 2), 2);
+            const __m256i q4h_2 = __lasx_xvandi_b(q4bitsH, 3 << 4);
+            const __m256i q4h_3 = __lasx_xvsrli_b(__lasx_xvandi_b(q4bitsH, 3 << 6), 2);
 
-            const __m256i q4_0 = __lasx_xvor_v(__lasx_xvand_v(q4bits1, m4), q4h_0);
-            const __m256i q4_1 = __lasx_xvor_v(__lasx_xvand_v(q4bits2, m4), q4h_1);
-            const __m256i q4_2 = __lasx_xvor_v(__lasx_xvand_v(__lasx_xvsrli_h(q4bits1, 4), m4), q4h_2);
-            const __m256i q4_3 = __lasx_xvor_v(__lasx_xvand_v(__lasx_xvsrli_h(q4bits2, 4), m4), q4h_3);
+            const __m256i q4_0 = __lasx_xvor_v(__lasx_xvandi_b(q4bits1, 0xf), q4h_0);
+            const __m256i q4_1 = __lasx_xvor_v(__lasx_xvandi_b(q4bits2, 0xf), q4h_1);
+            const __m256i q4_2 = __lasx_xvor_v(__lasx_xvsrli_b(q4bits1, 4), q4h_2);
+            const __m256i q4_3 = __lasx_xvor_v(__lasx_xvsrli_b(q4bits2, 4), q4h_3);
 
             const __m256i q8_0 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
             const __m256i q8_1 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
             const __m256i q8_2 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
             const __m256i q8_3 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
 
-            __m256i q8s_0 = lasx_maddubs_h(m32s, q8_0);
-            __m256i q8s_1 = lasx_maddubs_h(m32s, q8_1);
-            __m256i q8s_2 = lasx_maddubs_h(m32s, q8_2);
-            __m256i q8s_3 = lasx_maddubs_h(m32s, q8_3);
-
-            __m256i p16_0 = lasx_maddubs_h(q4_0, q8_0);
-            __m256i p16_1 = lasx_maddubs_h(q4_1, q8_1);
-            __m256i p16_2 = lasx_maddubs_h(q4_2, q8_2);
-            __m256i p16_3 = lasx_maddubs_h(q4_3, q8_3);
-
-            p16_0 = __lasx_xvsub_h(p16_0, q8s_0);
-            p16_1 = __lasx_xvsub_h(p16_1, q8s_1);
-            p16_2 = __lasx_xvsub_h(p16_2, q8s_2);
-            p16_3 = __lasx_xvsub_h(p16_3, q8s_3);
+            __m256i p16_0 = lasx_madd_h_b(__lasx_xvsub_b(q4_0, m32s), q8_0);
+            __m256i p16_1 = lasx_madd_h_b(__lasx_xvsub_b(q4_1, m32s), q8_1);
+            __m256i p16_2 = lasx_madd_h_b(__lasx_xvsub_b(q4_2, m32s), q8_2);
+            __m256i p16_3 = lasx_madd_h_b(__lasx_xvsub_b(q4_3, m32s), q8_3);
 
-            p16_0 = lasx_madd_h(lasx_ext8_16(scale_0), p16_0);
-            p16_1 = lasx_madd_h(lasx_ext8_16(scale_1), p16_1);
-            p16_2 = lasx_madd_h(lasx_ext8_16(scale_2), p16_2);
-            p16_3 = lasx_madd_h(lasx_ext8_16(scale_3), p16_3);
+            p16_0 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 0), p16_0);
+            p16_1 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 1), p16_1);
+            p16_2 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 2), p16_2);
+            p16_3 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 3), p16_3);
 
             sumi = __lasx_xvadd_w(sumi, __lasx_xvadd_w(p16_0, p16_1));
             sumi = __lasx_xvadd_w(sumi, __lasx_xvadd_w(p16_2, p16_3));
@@ -10423,13 +10383,9 @@ static inline __m256i mul_add_epi8(const __m256i x, const __m256i y) {
 }
 #elif defined(__loongarch_asx)
 static inline __m256i mul_add_epi8(const __m256i x, const __m256i y) {
-    const __m256i ax = __lasx_xvsigncov_b(x, x);
-    const __m256i sy = __lasx_xvsigncov_b(x, y);
-    __m256i tmp1, tmp2, tmp3;
-    tmp1 = __lasx_xvmulwev_h_bu_b(ax, sy);
-    tmp2 = __lasx_xvmulwod_h_bu_b(ax, sy);
-    tmp3 = __lasx_xvadd_h(tmp1, tmp2);
-    return __lasx_xvsat_h(tmp3, 15);
+    const __m256i a = __lasx_xvmulwev_h_b(x, y);
+    const __m256i b = __lasx_xvmulwod_h_b(x, y);
+    return __lasx_xvadd_h(a, b);
 }
 #endif
 
@@ -11479,67 +11435,31 @@ void ggml_vec_dot_iq4_xs_q8_K(int n, float * restrict s, size_t bs, const void *
 #elif defined(__loongarch_asx)
 
     const __m128i values128 = __lsx_vld((const __m128i*)kvalues_iq4nl, 0);
-    const __m128i m4b  = __lsx_vreplgr2vr_b(0x0f);
 
     __m256 accum = (__m256)__lasx_xvldi(0);
-    __m256i tmp1;
-    __m128i tmp0, tmp2, tmp3, tmp4, mask_8f, mask;
 
-    mask_8f = __lsx_vreplgr2vr_b(0x8f);
     for (int ibl = 0; ibl < nb; ++ibl) {
         const uint8_t * qs = x[ibl].qs;
         const int8_t  * q8 = y[ibl].qs;
         uint16_t sh = x[ibl].scales_h;
         __m256i sumi1 = __lasx_xvldi(0);
         __m256i sumi2 = __lasx_xvldi(0);
-        __m128i zero = __lsx_vldi(0);
         for (int ib = 0; ib < QK_K/32; ib += 2) {
-            const __m128i q4bits_1 = __lsx_vld((const __m128i*)qs, 0);  qs += 16;
-            const __m128i q4bits_2 = __lsx_vld((const __m128i*)qs, 0);  qs += 16;
+            const __m128i q4bits_1 = __lsx_vld((const __m128i*)qs, 0); qs += 16;
+            const __m128i q4bits_2 = __lsx_vld((const __m128i*)qs, 0); qs += 16;
             const __m256i q8b_1 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32;
             const __m256i q8b_2 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32;
-            tmp2 = __lsx_vand_v(__lsx_vand_v(__lsx_vsrli_h(q4bits_1, 4), m4b), mask_8f);
-            tmp0 = __lsx_vori_b(tmp2, 0x10);
-            mask = __lsx_vsle_b(zero, tmp2);
-            tmp3 = __lsx_vand_v(tmp0, mask);
-            tmp3 = __lsx_vshuf_b(values128, zero, tmp3);
-
-            tmp2 = __lsx_vand_v(__lsx_vand_v(q4bits_1, m4b), mask_8f);
-            tmp0 = __lsx_vori_b(tmp2, 0x10);
-            mask = __lsx_vsle_b(zero, tmp2);
-            tmp4 = __lsx_vand_v(tmp0, mask);
-            tmp4 = __lsx_vshuf_b(values128, zero, tmp4);
-
-            const __m256i q4b_1 = lasx_insertf128(tmp3, tmp4);
-
-            tmp2 = __lsx_vand_v(__lsx_vand_v(__lsx_vsrli_h(q4bits_2, 4), m4b), mask_8f);
-            tmp0 = __lsx_vori_b(tmp2, 0x10);
-            mask = __lsx_vsle_b(zero, tmp2);
-            tmp3 = __lsx_vand_v(tmp0, mask);
-            tmp3 = __lsx_vshuf_b(values128, zero, tmp3);
-
-            tmp2 = __lsx_vand_v(__lsx_vand_v(q4bits_2, m4b), mask_8f);
-            tmp0 = __lsx_vori_b(tmp2, 0x10);
-            mask = __lsx_vsle_b(zero, tmp2);
-            tmp4 = __lsx_vand_v(tmp0, mask);
-            tmp4 = __lsx_vshuf_b(values128, zero, tmp4);
-
-            const __m256i q4b_2 = lasx_insertf128(tmp3, tmp4);
-
+            const __m256i q4b_1 = lasx_insertf128(__lsx_vshuf_b(values128, values128, __lsx_vsrli_b(q4bits_1, 4)),
+                                                  __lsx_vshuf_b(values128, values128, __lsx_vandi_b(q4bits_1, 0xf)));
+            const __m256i q4b_2 = lasx_insertf128(__lsx_vshuf_b(values128, values128, __lsx_vsrli_b(q4bits_2, 4)),
+                                                  __lsx_vshuf_b(values128, values128, __lsx_vandi_b(q4bits_2, 0xf)));
             const __m256i p16_1 = mul_add_epi8(q4b_1, q8b_1);
             const __m256i p16_2 = mul_add_epi8(q4b_2, q8b_2);
             const int16_t ls1 = ((x[ibl].scales_l[ib/2] & 0xf) | ((sh << 4) & 0x30)) - 32;
             const int16_t ls2 = ((x[ibl].scales_l[ib/2] >>  4) | ((sh << 2) & 0x30)) - 32;
             sh >>= 4;
-            __m256i tmp5, tmp6;
-            tmp1 = __lasx_xvreplgr2vr_h(ls1);
-            tmp5 = __lasx_xvmulwev_w_h(p16_1, tmp1);
-            tmp6 = __lasx_xvmulwod_w_h(p16_1, tmp1);
-            const __m256i p_1 = __lasx_xvadd_w(tmp5, tmp6);
-            tmp1 = __lasx_xvreplgr2vr_h(ls2);
-            tmp5 = __lasx_xvmulwev_w_h(p16_2, tmp1);
-            tmp6 = __lasx_xvmulwod_w_h(p16_2, tmp1);
-            const __m256i p_2 = __lasx_xvadd_w(tmp5, tmp6);
+            const __m256i p_1 = lasx_madd_h(p16_1, __lasx_xvreplgr2vr_h(ls1));
+            const __m256i p_2 = lasx_madd_h(p16_2, __lasx_xvreplgr2vr_h(ls2));
             sumi1 = __lasx_xvadd_w(p_1, sumi1);
             sumi2 = __lasx_xvadd_w(p_2, sumi2);
         }
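In the iq4_xs hunk above, the per-nibble mask/compare sequence is replaced by a straight 16-entry table lookup: __lsx_vshuf_b indexes kvalues_iq4nl with each 4-bit quant. A scalar sketch of the lookup (illustrative only; iq4_lookup_ref is a made-up name):

    #include <stdint.h>

    // Each packed byte holds two 4-bit indices into a 16-entry signed codebook
    // (kvalues_iq4nl in ggml): low nibbles give the first 16 values of the
    // group, high nibbles the next 16.
    static void iq4_lookup_ref(const int8_t values[16], const uint8_t packed[16], int8_t out[32]) {
        for (int i = 0; i < 16; ++i) {
            out[i]      = values[packed[i] & 0x0f];
            out[i + 16] = values[packed[i] >> 4];
        }
    }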