mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-11-20 12:07:33 +00:00
ggml : fix SSM_SCAN for n_groups > 1 (#15625)
This commit is contained in:
@@ -1983,14 +1983,15 @@ kernel void kernel_ssm_scan_f32(
|
||||
device const float * s0_buff = (device const float *) ((device const char *) src0 + ir*args.nb02 + ids[i3]*args.nb03);
|
||||
device float * s_buff = (device float *) ((device char *) dst + ir*args.nb02 + i3*args.nb03 + s_off);
|
||||
const int64_t i = i0 + i1*nc;
|
||||
const int64_t g = ir / (nh / ng); // repeat_interleave
|
||||
float s0 = s0_buff[i];
|
||||
float s = s_buff[i];
|
||||
|
||||
device const float * A = (device const float *) ((device const char *) src3 + ir*args.nb31);
|
||||
device const float * x_block = (device const float *) ((device const char *) src1 + i1*nb10 + ir*args.nb11 + i3*args.nb13);
|
||||
device const float * dt_block = (device const float *) ((device const char *) src2 + ir*nb20 + i3*args.nb22);
|
||||
device const float * B_block = (device const float *) ((device const char *) src4 + (ir & (ng - 1))*args.nb41 + i3*args.nb43);
|
||||
device const float * C_block = (device const float *) ((device const char *) src5 + (ir & (ng - 1))*args.nb51 + i3*args.nb53);
|
||||
device const float * B_block = (device const float *) ((device const char *) src4 + g*args.nb41 + i3*args.nb43);
|
||||
device const float * C_block = (device const float *) ((device const char *) src5 + g*args.nb51 + i3*args.nb53);
|
||||
device float * y_block = (device float *) ((device char *) dst + (i1 + ir*(nr) + i3*(n_t*nh*nr))*nb00);
|
||||
|
||||
for (int64_t i2 = 0; i2 < n_t; ++i2) {
|
||||
@@ -2098,14 +2099,15 @@ kernel void kernel_ssm_scan_f32_group(
|
||||
device const float * s0_buff = (device const float *) ((device const char *) src0 + ir*args.nb02 + ids[i3]*args.nb03);
|
||||
device float * s_buff = (device float *) ((device char *) dst + ir*args.nb02 + i3*args.nb03 + s_off);
|
||||
const int64_t i = i0 + i1*nc;
|
||||
const int64_t g = ir / (nh / ng); // repeat_interleave
|
||||
float s0 = s0_buff[i];
|
||||
float s = s_buff[i];
|
||||
|
||||
device const float * A = (device const float *) ((device const char *) src3 + ir*args.nb31); // {1, nh}
|
||||
device const float * x_block = (device const float *) ((device const char *) src1 + i1*nb10 + ir*args.nb11 + i3*args.nb13);
|
||||
device const float * dt_block = (device const float *) ((device const char *) src2 + ir*nb20 + i3*args.nb22);
|
||||
device const float * B_block = (device const float *) ((device const char *) src4 + (ir & (ng - 1))*args.nb41 + i3*args.nb43);
|
||||
device const float * C_block = (device const float *) ((device const char *) src5 + (ir & (ng - 1))*args.nb51 + i3*args.nb53);
|
||||
device const float * B_block = (device const float *) ((device const char *) src4 + g*args.nb41 + i3*args.nb43);
|
||||
device const float * C_block = (device const float *) ((device const char *) src5 + g*args.nb51 + i3*args.nb53);
|
||||
device float * y_block = (device float *) ((device char *) dst + (i1 + ir*(nr) + i3*(n_t*nh*nr))*nb00);
|
||||
|
||||
for (int64_t i2 = 0; i2 < n_t; ++i2) {
|
||||
|
||||
Reference in New Issue
Block a user