mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-11-03 09:22:01 +00:00 
			
		
		
		
	ggml : fix mamba2 ssm scan when compiled with SVE
This commit is contained in:
		@@ -7664,6 +7664,37 @@ static void ggml_compute_forward_ssm_scan_f32(
 | 
				
			|||||||
                        const float x_dt = x[ii] * dt_soft_plus;
 | 
					                        const float x_dt = x[ii] * dt_soft_plus;
 | 
				
			||||||
                        float sumf = 0.0f;
 | 
					                        float sumf = 0.0f;
 | 
				
			||||||
#if defined(GGML_SIMD)
 | 
					#if defined(GGML_SIMD)
 | 
				
			||||||
 | 
					    #if defined(__ARM_FEATURE_SVE)
 | 
				
			||||||
 | 
					                        const int ggml_f32_epr = svcntw();
 | 
				
			||||||
 | 
					                        const int ggml_f32_step = 1 * ggml_f32_epr;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                        const int np = (nc & ~(ggml_f32_step - 1));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                        GGML_F32_VEC sum = GGML_F32_VEC_ZERO;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                        GGML_F32_VEC adA = GGML_F32_VEC_SET1(dA);
 | 
				
			||||||
 | 
					                        GGML_F32_VEC axdt = GGML_F32_VEC_SET1(x_dt);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                        for (int i = 0; i < np; i += ggml_f32_step) {
 | 
				
			||||||
 | 
					                            // TODO: maybe unroll more?
 | 
				
			||||||
 | 
					                            for (int j = 0; j < 1; j++) {
 | 
				
			||||||
 | 
					                                GGML_F32_VEC t0 = GGML_F32_VEC_LOAD(s0 + i + j*ggml_f32_epr + ii*nc);
 | 
				
			||||||
 | 
					                                GGML_F32_VEC t1 = GGML_F32_VEC_LOAD(B + i + j*ggml_f32_epr + (h & (ng - 1))*nc);
 | 
				
			||||||
 | 
					                                GGML_F32_VEC t2 = GGML_F32_VEC_LOAD(C + i + j*ggml_f32_epr + (h & (ng - 1))*nc);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                                t0 = GGML_F32_VEC_MUL(t0, adA);
 | 
				
			||||||
 | 
					                                t1 = GGML_F32_VEC_MUL(t1, axdt);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                                t0 = GGML_F32_VEC_ADD(t0, t1);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                                sum = GGML_F32_VEC_FMA(sum, t0, t2);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                                GGML_F32_VEC_STORE(s + i + j*ggml_f32_epr + ii*nc, t0);
 | 
				
			||||||
 | 
					                            }
 | 
				
			||||||
 | 
					                        }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                        sumf = GGML_F32xt_REDUCE_ONE(sum);
 | 
				
			||||||
 | 
					    #else
 | 
				
			||||||
                        const int np = (nc & ~(GGML_F32_STEP - 1));
 | 
					                        const int np = (nc & ~(GGML_F32_STEP - 1));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
                        GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };
 | 
					                        GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };
 | 
				
			||||||
@@ -7694,6 +7725,7 @@ static void ggml_compute_forward_ssm_scan_f32(
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
                        // reduce sum0..sum3 to sum0
 | 
					                        // reduce sum0..sum3 to sum0
 | 
				
			||||||
                        GGML_F32_VEC_REDUCE(sumf, sum);
 | 
					                        GGML_F32_VEC_REDUCE(sumf, sum);
 | 
				
			||||||
 | 
					    #endif
 | 
				
			||||||
#else
 | 
					#else
 | 
				
			||||||
                        const int np = 0;
 | 
					                        const int np = 0;
 | 
				
			||||||
#endif
 | 
					#endif
 | 
				
			||||||
@@ -7722,7 +7754,7 @@ static void ggml_compute_forward_ssm_scan_f32(
 | 
				
			|||||||
                    for (int i1 = 0; i1 < nr; ++i1) {
 | 
					                    for (int i1 = 0; i1 < nr; ++i1) {
 | 
				
			||||||
                        const int ii = i1 + h*nr;
 | 
					                        const int ii = i1 + h*nr;
 | 
				
			||||||
                        const float x_dt = x[ii] * dt_soft_plus;
 | 
					                        const float x_dt = x[ii] * dt_soft_plus;
 | 
				
			||||||
#ifdef __ARM_FEATURE_SVE
 | 
					#if defined(__ARM_FEATURE_SVE)
 | 
				
			||||||
                        svfloat32_t vx_dt = GGML_F32_VEC_SET1(x_dt);
 | 
					                        svfloat32_t vx_dt = GGML_F32_VEC_SET1(x_dt);
 | 
				
			||||||
                        svfloat32_t vdt_soft_plus = GGML_F32_VEC_SET1(dt_soft_plus);
 | 
					                        svfloat32_t vdt_soft_plus = GGML_F32_VEC_SET1(dt_soft_plus);
 | 
				
			||||||
                        svfloat32_t r1_vector = GGML_F32_VEC_ZERO;
 | 
					                        svfloat32_t r1_vector = GGML_F32_VEC_ZERO;
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user