When using grouped-query attention, we have one workgroup per KV batch, which can be very few workgroups (e.g. just 8 in some models). Enable split_k to spread the work across SMs; this helps a lot when the KV cache is large.
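The reduce shader below merges the per-split partial results with the usual numerically stable softmax combination. In formulas (notation mine; l_k, m_k and O_k are split k's partial exp-sum, running max and unnormalized output for row n, matching the l, m and O values the shader reads from data_a):

\[
m_{\max}(n) = \max_k m_k(n), \qquad
O(n,d) = \frac{\sum_k e^{\,m_k(n) - m_{\max}(n)}\, O_k(n,d)}{\sum_k e^{\,m_k(n) - m_{\max}(n)}\, l_k(n)}
\]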
#version 450

#extension GL_EXT_control_flow_attributes : enable

#define BLOCK_SIZE 32

layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;

layout (binding = 0) readonly buffer A {float data_a[];};
layout (binding = 1) writeonly buffer D {float data_d[];};

layout (push_constant) uniform parameter {
    uint D;
    uint N;
    uint k_num;
} p;

void main() {
    // Each workgroup handles a row
    const uint n = gl_WorkGroupID.x;
    const uint tid = gl_LocalInvocationID.x;

    uint D = p.D;
    uint N = p.N;
    uint k_num = p.k_num;

    // data_a holds k_num blocks of D*N partial O values, followed by
    // N L values and N M values per split.
    uint l_offset = D * N * k_num + n;
    uint m_offset = D * N * k_num + N + n;
    uint lm_stride = N * 2;

    // Compute the max m value for the row
    float m_max = -1.0/0.0; // -infinity
    [[unroll]] for (uint k = 0; k < k_num; ++k) {
        float m = data_a[m_offset + k * lm_stride];
        m_max = max(m_max, m);
    }

    // Compute L based on m_max
    float L = 0;
    [[unroll]] for (uint k = 0; k < k_num; ++k) {
        float l = data_a[l_offset + k * lm_stride];
        float m = data_a[m_offset + k * lm_stride];
        L += exp(m - m_max) * l;
    }

    L = 1.0 / L;

    // Scale and sum the O contributions based on m_max and store the result to memory
    for (uint d = tid; d < D; d += BLOCK_SIZE) {
        float O = 0.0;
        [[unroll]] for (uint k = 0; k < k_num; ++k) {
            uint o_offset = D * N * k + D * n + d;
            float m = data_a[m_offset + k * lm_stride];
            O += exp(m - m_max) * data_a[o_offset];
        }
        O *= L;
        data_d[D * n + d] = O;
    }
}
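For illustration, here is a minimal CPU-side sketch of the same reduction (hypothetical code, not taken from the repository; split_k_reduce_ref and the test values are made up). It assumes the buffer layout implied by the shader's offsets: k_num blocks of D*N partial O values, followed by N L values and N M values per split.

// Minimal CPU reference for the split_k reduction above (hypothetical helper,
// not part of llama.cpp). Assumed input layout, matching the shader:
//   data_a = [ O_0 | O_1 | ... | O_{k_num-1} | (L[N], M[N]) for split 0 | ... ]
// where each O_k block holds D*N floats, and each split stores N partial
// exp-sums (L) followed by N running maxima (M).
#include <cmath>
#include <cstdio>
#include <vector>

static void split_k_reduce_ref(const std::vector<float> &data_a, std::vector<float> &data_d,
                               unsigned D, unsigned N, unsigned k_num) {
    const unsigned lm_base   = D * N * k_num; // start of the L/M section
    const unsigned lm_stride = 2 * N;         // floats per split in the L/M section

    for (unsigned n = 0; n < N; ++n) {
        // Max of the per-split running maxima for this row.
        float m_max = -INFINITY;
        for (unsigned k = 0; k < k_num; ++k) {
            m_max = std::fmax(m_max, data_a[lm_base + k * lm_stride + N + n]);
        }
        // Combined softmax denominator.
        float L = 0.0f;
        for (unsigned k = 0; k < k_num; ++k) {
            const float l = data_a[lm_base + k * lm_stride + n];
            const float m = data_a[lm_base + k * lm_stride + N + n];
            L += std::exp(m - m_max) * l;
        }
        const float inv_L = 1.0f / L;
        // Rescale and accumulate the partial outputs.
        for (unsigned d = 0; d < D; ++d) {
            float O = 0.0f;
            for (unsigned k = 0; k < k_num; ++k) {
                const float m = data_a[lm_base + k * lm_stride + N + n];
                O += std::exp(m - m_max) * data_a[D * N * k + D * n + d];
            }
            data_d[D * n + d] = O * inv_L;
        }
    }
}

int main() {
    // Tiny smoke test: D=2, N=1, k_num=2 with hand-picked partials.
    const unsigned D = 2, N = 1, k_num = 2;
    std::vector<float> data_a = {
        1.0f, 2.0f,   // O_0 (D*N floats)
        3.0f, 4.0f,   // O_1
        1.0f, 0.0f,   // split 0: L[N], M[N]
        2.0f, 1.0f,   // split 1: L[N], M[N]
    };
    std::vector<float> data_d(D * N);
    split_k_reduce_ref(data_a, data_d, D, N, k_num);
    for (float v : data_d) std::printf("%f\n", v);
    return 0;
}

Compiled with any C++11-or-later compiler, this can be run against a dump of the shader's input buffer to cross-check the GPU result row by row.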