mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-11-04 09:32:00 +00:00 
			
		
		
		
	metal : add missing args for nb references in ssm_scan_f32_group
This commit is contained in:
		@@ -1350,16 +1350,16 @@ kernel void kernel_ssm_scan_f32_group(
 | 
			
		||||
 | 
			
		||||
    device const int32_t * ids = (device const int32_t *) src6;
 | 
			
		||||
 | 
			
		||||
    device const float * s0 = (device const float *) ((device const char *) src0 + ir*nb02 + ids[i3]*nb03);
 | 
			
		||||
    device       float * s  = (device       float *) ((device       char *) dst  + ir*nb02 +      i3*nb03 + s_off);
 | 
			
		||||
    device const float * s0 = (device const float *) ((device const char *) src0 + ir*args.nb02 + ids[i3]*args.nb03);
 | 
			
		||||
    device       float * s  = (device       float *) ((device       char *) dst  + ir*args.nb02 +      i3*args.nb03 + s_off);
 | 
			
		||||
 | 
			
		||||
    for (int64_t i2 = 0; i2 < n_t; ++i2) {
 | 
			
		||||
        device const float * x  = (device const float *) ((device const char *) src1 + i1*nb10 + ir*nb11 + i2*nb12 + i3*nb13); // {dim, nh, nt, ns}
 | 
			
		||||
        device const float * dt = (device const float *) ((device const char *) src2 + ir*nb20 + i2*nb21 + i3*nb22); // {nh, nt, ns}
 | 
			
		||||
        device const float * A  = (device const float *) ((device const char *) src3 + ir*nb31); // {1, nh}
 | 
			
		||||
        device const float * B  = (device const float *) ((device const char *) src4 + (ir & (ng - 1))*nb41 + i2*nb42 + i3*nb43); // {d_state, ng, nt, ns}
 | 
			
		||||
        device const float * C  = (device const float *) ((device const char *) src5 + (ir & (ng - 1))*nb51 + i2*nb52 + i3*nb53); // {d_state, ng, nt, ns}
 | 
			
		||||
        device       float * y  = (device       float *) ((device       char *) dst  + (i1 + ir*(nr) + i2*(nh*nr) + i3*(n_t*nh*nr))*nb00); // {dim, nh, nt, ns}
 | 
			
		||||
        device const float * x  = (device const float *) ((device const char *) src1 + i1*args.nb10 + ir*args.nb11 + i2*args.nb12 + i3*args.nb13); // {dim, nh, nt, ns}
 | 
			
		||||
        device const float * dt = (device const float *) ((device const char *) src2 + ir*args.nb20 + i2*args.nb21 + i3*args.nb22); // {nh, nt, ns}
 | 
			
		||||
        device const float * A  = (device const float *) ((device const char *) src3 + ir*args.nb31); // {1, nh}
 | 
			
		||||
        device const float * B  = (device const float *) ((device const char *) src4 + (ir & (ng - 1))*args.nb41 + i2*args.nb42 + i3*args.nb43); // {d_state, ng, nt, ns}
 | 
			
		||||
        device const float * C  = (device const float *) ((device const char *) src5 + (ir & (ng - 1))*args.nb51 + i2*args.nb52 + i3*args.nb53); // {d_state, ng, nt, ns}
 | 
			
		||||
        device       float * y  = (device       float *) ((device       char *) dst  + (i1 + ir*(nr) + i2*(nh*nr) + i3*(n_t*nh*nr))*args.nb00); // {dim, nh, nt, ns}
 | 
			
		||||
 | 
			
		||||
        const float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0];
 | 
			
		||||
        const float x_dt = x[0] * dt_soft_plus;
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user