mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-11-15 11:17:31 +00:00
ggml : fix SSM_SCAN for n_groups > 1 (#15625)
This commit is contained in:
@@ -129,7 +129,7 @@ __global__ void __launch_bounds__(d_state, 1)
|
||||
const int head_off = ((blockIdx.x * splitH) % d_head) * sizeof(float);
|
||||
const int seq_idx = blockIdx.y;
|
||||
|
||||
const int group_off = (head_idx & (n_group - 1)) * d_state * sizeof(float);
|
||||
const int group_off = (head_idx / (n_head / n_group)) * d_state * sizeof(float);
|
||||
|
||||
const float * s0_block = (const float *) ((const char *) src0 + src6[seq_idx] * src0_nb3 + head_idx * src0_nb2 + head_off * d_state);
|
||||
const float * x_block = (const float *) ((const char *) src1 + (seq_idx * src1_nb3) + blockIdx.x * splitH * sizeof(float));
|
||||
|
||||
Reference in New Issue
Block a user