mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-11-03 09:22:01 +00:00 
			
		
		
		
	ggml: aarch64: implement SVE kernels for q2_k_q8_k vector dot (#12064)
* Added SVE Support for Q2_K Quantized Models * Use 4-space indentation in the switch cases * removed comments lines * Remove the loop Retain the curly bracess for better understanding of code * Remove the comment like added for q3_k_q8_k kernel --------- Co-authored-by: vithulep <p.m.vithule1517@gmail.com>
This commit is contained in:
		@@ -4587,7 +4587,252 @@ void ggml_vec_dot_q2_K_q8_K(int n, float * restrict s, size_t bs, const void * r
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    const int nb = n / QK_K;
 | 
					    const int nb = n / QK_K;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
#ifdef __ARM_NEON
 | 
					#ifdef __ARM_FEATURE_SVE
 | 
				
			||||||
 | 
					    const int vector_length = svcntb()*8;
 | 
				
			||||||
 | 
					    const svuint8_t m3s = svdup_n_u8(0x3);
 | 
				
			||||||
 | 
					    const svuint32_t m4s = svdup_n_u32(0xF);
 | 
				
			||||||
 | 
					    const svint32_t vzero_sv = svdup_n_s32(0);
 | 
				
			||||||
 | 
					    svfloat32_t acc_sum = svdup_n_f32(0);
 | 
				
			||||||
 | 
					    svbool_t pred_s32 = svptrue_pat_b32(SV_VL4);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    switch (vector_length) {
 | 
				
			||||||
 | 
					        case 128:
 | 
				
			||||||
 | 
					            for (int i = 0; i < nb; ++i) {
 | 
				
			||||||
 | 
					                const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
 | 
				
			||||||
 | 
					                svfloat32_t d_broad = svdup_n_f32((float32_t)d);
 | 
				
			||||||
 | 
					                const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
 | 
				
			||||||
 | 
					                svfloat32_t dmin_broad = svdup_n_f32((float32_t)dmin);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                const uint8_t * restrict q2 = x[i].qs;
 | 
				
			||||||
 | 
					                const int8_t  * restrict q8_sv = y[i].qs;
 | 
				
			||||||
 | 
					                const uint8_t * restrict sc = x[i].scales;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                svuint32_t mins_and_scales_sve = svld1ub_u32(svptrue_b32(), sc);
 | 
				
			||||||
 | 
					                const svint32_t mins_sv_1 = svreinterpret_s32_u32(svlsr_n_u32_x(svptrue_b32(), mins_and_scales_sve, 4));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                mins_and_scales_sve = svld1ub_u32(svptrue_b32(), sc+4);
 | 
				
			||||||
 | 
					                const svint32_t mins_sv_2 = svreinterpret_s32_u32(svlsr_n_u32_x(svptrue_b32(), mins_and_scales_sve, 4));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                svint32_t q8sums_sv_1 = svld1sh_s32(svptrue_b32(), y[i].bsums);
 | 
				
			||||||
 | 
					                svint32_t q8sums_sv_2 = svld1sh_s32(svptrue_b32(), y[i].bsums+4);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                const svint32_t s0 = svadd_s32_x(svptrue_b32(), svmul_s32_x(svptrue_b32(), mins_sv_1, q8sums_sv_1), svmul_s32_x(svptrue_b32(), mins_sv_2, q8sums_sv_2));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                mins_and_scales_sve = svld1ub_u32(svptrue_b32(), sc+8);
 | 
				
			||||||
 | 
					                const svint32_t mins_sv_3 = svreinterpret_s32_u32(svlsr_n_u32_x(svptrue_b32(), mins_and_scales_sve, 4));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                mins_and_scales_sve = svld1ub_u32(svptrue_b32(), sc+12);
 | 
				
			||||||
 | 
					                const svint32_t mins_sv_4 = svreinterpret_s32_u32(svlsr_n_u32_x(svptrue_b32(), mins_and_scales_sve, 4));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                q8sums_sv_1 = svld1sh_s32(svptrue_b32(), y[i].bsums+8);
 | 
				
			||||||
 | 
					                q8sums_sv_2 = svld1sh_s32(svptrue_b32(), y[i].bsums+12);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                svint32_t s1 = svadd_s32_x(svptrue_b32(), svmul_s32_x(svptrue_b32(), mins_sv_3, q8sums_sv_1), svmul_s32_x(svptrue_b32(), mins_sv_4, q8sums_sv_2));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                svfloat32_t temp = svcvt_f32_s32_x(svptrue_b32(), svadd_s32_x(svptrue_b32(), s0, s1));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                acc_sum = svmla_f32_m(svptrue_b32(), acc_sum, temp, dmin_broad);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                svint32_t sumi1 = svdup_n_s32(0);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                {
 | 
				
			||||||
 | 
					                    const svuint8_t q2bits_1 = svld1_u8(svptrue_b8(), q2);
 | 
				
			||||||
 | 
					                    svint8_t q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), q2bits_1, m3s));
 | 
				
			||||||
 | 
					                    svint8_t q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
 | 
				
			||||||
 | 
					                    const svint32_t scales_sv = svreinterpret_s32_u32(svand_u32_m(svptrue_b32(), svld1ub_u32(svptrue_b32(), sc), m4s));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                    sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv, 0));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                    const svuint8_t q2bits_3 = svld1_u8(svptrue_b8(), q2+16);
 | 
				
			||||||
 | 
					                    q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), q2bits_3, m3s));
 | 
				
			||||||
 | 
					                    q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                    sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv, 1));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                    q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q2bits_1, 2), m3s));
 | 
				
			||||||
 | 
					                    q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                    sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv, 2));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                    q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q2bits_3, 2), m3s));
 | 
				
			||||||
 | 
					                    q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                    sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv, 3));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                    const svint32_t scales_sv_1 = svreinterpret_s32_u32(svand_u32_m(svptrue_b32(), svld1ub_u32(svptrue_b32(), sc+4), m4s));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                    q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q2bits_1, 4), m3s));
 | 
				
			||||||
 | 
					                    q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                    sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv_1, 0));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                    q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q2bits_3, 4), m3s));
 | 
				
			||||||
 | 
					                    q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                    sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv_1, 1));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                    q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q2bits_1, 6), m3s));
 | 
				
			||||||
 | 
					                    q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                    sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv_1, 2));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                    q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q2bits_3, 6), m3s));
 | 
				
			||||||
 | 
					                    q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                    sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv_1, 3));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                    //-------------------------------
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                    q2 += 32;
 | 
				
			||||||
 | 
					                    const svint32_t scales_sv_2 = svreinterpret_s32_u32(svand_u32_m(svptrue_b32(), svld1ub_u32(svptrue_b32(), sc+8), m4s));
 | 
				
			||||||
 | 
					                    const svuint8_t q2bits_2 = svld1_u8(svptrue_b8(), q2);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                    q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), q2bits_2, m3s));
 | 
				
			||||||
 | 
					                    q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                    sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv_2, 0));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                    const svuint8_t q2bits_4 = svld1_u8(svptrue_b8(), q2+16);
 | 
				
			||||||
 | 
					                    q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), q2bits_4, m3s));
 | 
				
			||||||
 | 
					                    q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                    sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv_2, 1));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                    q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q2bits_2, 2), m3s));
 | 
				
			||||||
 | 
					                    q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                    sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv_2, 2));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                    q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q2bits_4, 2), m3s));
 | 
				
			||||||
 | 
					                    q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                    sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv_2, 3));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                    const svint32_t scales_sv_3 = svreinterpret_s32_u32(svand_u32_m(svptrue_b32(), svld1ub_u32(svptrue_b32(), sc+12), m4s));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                    q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q2bits_2, 4), m3s));
 | 
				
			||||||
 | 
					                    q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                    sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv_3, 0));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                    q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q2bits_4, 4), m3s));
 | 
				
			||||||
 | 
					                    q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                    sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv_3, 1));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                    q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q2bits_2, 6), m3s));
 | 
				
			||||||
 | 
					                    q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                    sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv_3, 2));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                    q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q2bits_4, 6), m3s));
 | 
				
			||||||
 | 
					                    q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                    sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv_3, 3));
 | 
				
			||||||
 | 
					                }
 | 
				
			||||||
 | 
					                acc_sum = svmla_f32_m(svptrue_b32(), acc_sum, svcvt_f32_s32_x(svptrue_b32(), sumi1), d_broad);
 | 
				
			||||||
 | 
					            }
 | 
				
			||||||
 | 
					            *s = svaddv_f32(svptrue_b32(), acc_sum);
 | 
				
			||||||
 | 
					            break;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        case 256:
 | 
				
			||||||
 | 
					        case 512:
 | 
				
			||||||
 | 
					            for (int i = 0; i < nb; ++i) {
 | 
				
			||||||
 | 
					                const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
 | 
				
			||||||
 | 
					                svfloat32_t d_broad = svdup_n_f32((float32_t)d);
 | 
				
			||||||
 | 
					                const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
 | 
				
			||||||
 | 
					                svfloat32_t dmin_broad = svdup_n_f32((float32_t)dmin);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                const uint8_t * restrict q2 = x[i].qs;
 | 
				
			||||||
 | 
					                const int8_t  * restrict q8_sv = y[i].qs;
 | 
				
			||||||
 | 
					                const uint8_t * restrict sc = x[i].scales;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                const svuint32_t mins_and_scales_sve = svld1ub_u32(svptrue_pat_b32(SV_VL8), sc); sc += 8;
 | 
				
			||||||
 | 
					                const svint32_t scales_sv = svreinterpret_s32_u32(svand_u32_m(svptrue_pat_b32(SV_VL8), mins_and_scales_sve, m4s));
 | 
				
			||||||
 | 
					                const svint32_t mins_sv_1 = svreinterpret_s32_u32(svlsr_n_u32_x(svptrue_pat_b32(SV_VL8), mins_and_scales_sve, 4));
 | 
				
			||||||
 | 
					                svint32_t q8sums_sv_1 = svld1sh_s32(svptrue_pat_b32(SV_VL8), y[i].bsums);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                const svuint32_t mins_and_scales_sve_1 = svld1ub_u32(svptrue_pat_b32(SV_VL8), sc);
 | 
				
			||||||
 | 
					                const svint32_t scales_sv_1 = svreinterpret_s32_u32(svand_u32_m(svptrue_pat_b32(SV_VL8), mins_and_scales_sve_1, m4s));
 | 
				
			||||||
 | 
					                const svint32_t mins_sv_2 = svreinterpret_s32_u32(svlsr_n_u32_x(svptrue_pat_b32(SV_VL8), mins_and_scales_sve_1, 4));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                svint32_t q8sums_sv_2 = svld1sh_s32(svptrue_pat_b32(SV_VL8), y[i].bsums+8);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                svfloat32_t temp = svcvt_f32_s32_x(svptrue_pat_b32(SV_VL8), svadd_s32_x(svptrue_pat_b32(SV_VL8), svmul_s32_x(svptrue_pat_b32(SV_VL8), mins_sv_1, q8sums_sv_1), svmul_s32_x(svptrue_pat_b32(SV_VL8), mins_sv_2, q8sums_sv_2)));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                acc_sum = svmla_f32_m(svptrue_pat_b32(SV_VL8), acc_sum, temp, dmin_broad);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                svint32_t sumi1 = svdup_n_s32(0);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                {
 | 
				
			||||||
 | 
					                    const svuint8_t q2bits_1 = svld1_u8(svptrue_pat_b8(SV_VL32), q2);
 | 
				
			||||||
 | 
					                    svint8_t q2bytes_sv = svreinterpret_s8_u8(svand_u8_m(svptrue_pat_b8(SV_VL32), q2bits_1, m3s));
 | 
				
			||||||
 | 
					                    svint8_t q8bytes_sv = svld1_s8(svptrue_pat_b8(SV_VL32), q8_sv); q8_sv += 32;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                    svint32_t scale_1 = svsel(pred_s32, svdup_lane_s32(scales_sv, 0), svdup_lane_s32(scales_sv, 1));
 | 
				
			||||||
 | 
					                    sumi1 = svmla_s32_m(svptrue_pat_b32(SV_VL8), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), scale_1);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                    q2bytes_sv = svreinterpret_s8_u8(svand_u8_m(svptrue_pat_b8(SV_VL32), svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), q2bits_1, 2), m3s));
 | 
				
			||||||
 | 
					                    q8bytes_sv = svld1_s8(svptrue_pat_b8(SV_VL32), q8_sv); q8_sv += 32;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                    svint32_t scale_2 = svsel(pred_s32, svdup_lane_s32(scales_sv, 2), svdup_lane_s32(scales_sv, 3));
 | 
				
			||||||
 | 
					                    sumi1 = svmla_s32_m(svptrue_pat_b32(SV_VL8), sumi1, svdot_s32(svdup_n_s32(0), q2bytes_sv, q8bytes_sv), scale_2);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                    q2bytes_sv = svreinterpret_s8_u8(svand_u8_m(svptrue_pat_b8(SV_VL32), svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), q2bits_1, 4), m3s));
 | 
				
			||||||
 | 
					                    q8bytes_sv = svld1_s8(svptrue_pat_b8(SV_VL32), q8_sv); q8_sv += 32;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                    scale_1 = svsel(pred_s32, svdup_lane_s32(scales_sv, 4), svdup_lane_s32(scales_sv, 5));
 | 
				
			||||||
 | 
					                    sumi1 = svmla_s32_m(svptrue_pat_b32(SV_VL8), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), scale_1);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                    q2bytes_sv = svreinterpret_s8_u8(svand_u8_m(svptrue_pat_b8(SV_VL32), svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), q2bits_1, 6), m3s));
 | 
				
			||||||
 | 
					                    q8bytes_sv = svld1_s8(svptrue_pat_b8(SV_VL32), q8_sv); q8_sv += 32;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                    scale_2 = svsel(pred_s32, svdup_lane_s32(scales_sv, 6), svdup_lane_s32(scales_sv, 7));
 | 
				
			||||||
 | 
					                    sumi1 = svmla_s32_m(svptrue_pat_b32(SV_VL8), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), scale_2);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                    q2 += 32;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                    const svuint8_t q2bits_2 = svld1_u8(svptrue_pat_b8(SV_VL32), q2);
 | 
				
			||||||
 | 
					                    q2bytes_sv = svreinterpret_s8_u8(svand_u8_m(svptrue_pat_b8(SV_VL32), q2bits_2, m3s));
 | 
				
			||||||
 | 
					                    q8bytes_sv = svld1_s8(svptrue_pat_b8(SV_VL32), q8_sv); q8_sv += 32;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                    scale_1 = svsel(pred_s32, svdup_lane_s32(scales_sv_1, 0), svdup_lane_s32(scales_sv_1, 1));
 | 
				
			||||||
 | 
					                    sumi1 = svmla_s32_m(svptrue_pat_b32(SV_VL8), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), scale_1);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                    q2bytes_sv = svreinterpret_s8_u8(svand_u8_m(svptrue_pat_b8(SV_VL32), svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), q2bits_2, 2), m3s));
 | 
				
			||||||
 | 
					                    q8bytes_sv = svld1_s8(svptrue_pat_b8(SV_VL32), q8_sv); q8_sv += 32;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                    scale_2 = svsel(pred_s32, svdup_lane_s32(scales_sv_1, 2), svdup_lane_s32(scales_sv_1, 3));
 | 
				
			||||||
 | 
					                    sumi1 = svmla_s32_m(svptrue_pat_b32(SV_VL8), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), scale_2);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                    q2bytes_sv = svreinterpret_s8_u8(svand_u8_m(svptrue_pat_b8(SV_VL32), svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), q2bits_2, 4), m3s));
 | 
				
			||||||
 | 
					                    q8bytes_sv = svld1_s8(svptrue_pat_b8(SV_VL32), q8_sv); q8_sv += 32;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                    scale_1 = svsel(pred_s32, svdup_lane_s32(scales_sv_1, 4), svdup_lane_s32(scales_sv_1, 5));
 | 
				
			||||||
 | 
					                    sumi1 = svmla_s32_m(svptrue_pat_b32(SV_VL8), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), scale_1);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                    q2bytes_sv = svreinterpret_s8_u8(svand_u8_m(svptrue_pat_b8(SV_VL32), svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), q2bits_2, 6), m3s));
 | 
				
			||||||
 | 
					                    q8bytes_sv = svld1_s8(svptrue_pat_b8(SV_VL32), q8_sv); q8_sv += 32;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                    scale_2 = svsel(pred_s32, svdup_lane_s32(scales_sv_1, 6), svdup_lane_s32(scales_sv_1, 7));
 | 
				
			||||||
 | 
					                    sumi1 = svmla_s32_m(svptrue_pat_b32(SV_VL8), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), scale_2);
 | 
				
			||||||
 | 
					                }
 | 
				
			||||||
 | 
					                acc_sum = svmla_f32_m(svptrue_pat_b32(SV_VL8), acc_sum, svcvt_f32_s32_x(svptrue_pat_b32(SV_VL8), sumi1), d_broad);
 | 
				
			||||||
 | 
					            }
 | 
				
			||||||
 | 
					            *s = svaddv_f32(svptrue_pat_b32(SV_VL8), acc_sum);
 | 
				
			||||||
 | 
					            break;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        default:
 | 
				
			||||||
 | 
					            assert(false && "Unsupported vector length");
 | 
				
			||||||
 | 
					            break;
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#elif __ARM_NEON
 | 
				
			||||||
    const uint8x16_t m3 = vdupq_n_u8(0x3);
 | 
					    const uint8x16_t m3 = vdupq_n_u8(0x3);
 | 
				
			||||||
    const uint8x16_t m4 = vdupq_n_u8(0xF);
 | 
					    const uint8x16_t m4 = vdupq_n_u8(0xF);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user