mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-11-04 09:32:00 +00:00 
			
		
		
		
	feat: Use local variable for state recursion
Branch: GraniteFourPerf
Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
This commit is contained in:
		@@ -1788,8 +1788,11 @@ kernel void kernel_ssm_scan_f32_group(
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    device const int32_t * ids = (device const int32_t *) src6;
 | 
					    device const int32_t * ids = (device const int32_t *) src6;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    device const float * s0 = (device const float *) ((device const char *) src0 + ir*args.nb02 + ids[i3]*args.nb03);
 | 
					    device const float * s0_buff = (device const float *) ((device const char *) src0 + ir*args.nb02 + ids[i3]*args.nb03);
 | 
				
			||||||
    device       float * s  = (device       float *) ((device       char *) dst  + ir*args.nb02 +      i3*args.nb03 + s_off);
 | 
					    device       float * s_buff  = (device       float *) ((device       char *) dst  + ir*args.nb02 +      i3*args.nb03 + s_off);
 | 
				
			||||||
 | 
					    const int64_t i = tpitg.x + i1*nc;
 | 
				
			||||||
 | 
					    float s0 = s0_buff[i];
 | 
				
			||||||
 | 
					    float s  = s_buff[i];
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    device const float * A        = (device const float *) ((device const char *) src3 + ir*args.nb31); // {1, nh}
 | 
					    device const float * A        = (device const float *) ((device const char *) src3 + ir*args.nb31); // {1, nh}
 | 
				
			||||||
    device const float * x_block  = (device const float *) ((device const char *) src1 + i1*nb10 + ir*args.nb11 + i3*args.nb13);
 | 
					    device const float * x_block  = (device const float *) ((device const char *) src1 + i1*nb10 + ir*args.nb11 + i3*args.nb13);
 | 
				
			||||||
@@ -1809,9 +1812,8 @@ kernel void kernel_ssm_scan_f32_group(
 | 
				
			|||||||
        const float x_dt = x[0] * dt_soft_plus;
 | 
					        const float x_dt = x[0] * dt_soft_plus;
 | 
				
			||||||
        const float dA = exp(dt_soft_plus * A[0]);
 | 
					        const float dA = exp(dt_soft_plus * A[0]);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        const int64_t i = tpitg.x + i1*nc;
 | 
					        const float state = (s0 * dA) + (B[tpitg.x] * x_dt);
 | 
				
			||||||
        const float state = (s0[i] * dA) + (B[tpitg.x] * x_dt);
 | 
					        s = state;
 | 
				
			||||||
        s[i] = state;
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
        // Parallel sum: This relies on the fact that this kernel will be
 | 
					        // Parallel sum: This relies on the fact that this kernel will be
 | 
				
			||||||
        // dispatched with each threadgroup having (d_state, 1, 1) threads which
 | 
					        // dispatched with each threadgroup having (d_state, 1, 1) threads which
 | 
				
			||||||
@@ -1851,6 +1853,9 @@ kernel void kernel_ssm_scan_f32_group(
 | 
				
			|||||||
        // recurse
 | 
					        // recurse
 | 
				
			||||||
        s0 = s;
 | 
					        s0 = s;
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    // Assign the final state to the output buffer
 | 
				
			||||||
 | 
					    s_buff[i] = s;
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
kernel void kernel_rwkv_wkv6_f32(
 | 
					kernel void kernel_rwkv_wkv6_f32(
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user