ggml : fix SSM_SCAN for n_groups > 1 (#15625)

This commit is contained in:
compilade
2025-08-28 10:11:36 -04:00
committed by GitHub
parent c8d0d14e77
commit 73804145ab
3 changed files with 18 additions and 15 deletions

View File

@@ -129,7 +129,7 @@ __global__ void __launch_bounds__(d_state, 1)
const int head_off = ((blockIdx.x * splitH) % d_head) * sizeof(float);
const int seq_idx = blockIdx.y;
const int group_off = (head_idx & (n_group - 1)) * d_state * sizeof(float);
const int group_off = (head_idx / (n_head / n_group)) * d_state * sizeof(float);
const float * s0_block = (const float *) ((const char *) src0 + src6[seq_idx] * src0_nb3 + head_idx * src0_nb2 + head_off * d_state);
const float * x_block = (const float *) ((const char *) src1 + (seq_idx * src1_nb3) + blockIdx.x * splitH * sizeof(float));