mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-11-13 10:57:15 +00:00
vulkan: support softmax/FA batch and broadcast (#14449)
This commit is contained in:
committed by
Georgi Gerganov
parent
ec68e84c32
commit
8875523eb3
@@ -12,6 +12,7 @@ layout (binding = 1) writeonly buffer D {float data_d[];};
|
||||
layout (push_constant) uniform parameter {
|
||||
uint D;
|
||||
uint N;
|
||||
uint ne3;
|
||||
uint k_num;
|
||||
} p;
|
||||
|
||||
@@ -19,13 +20,14 @@ void main() {
|
||||
// Each workgroup handles a row
|
||||
const uint n = gl_WorkGroupID.x;
|
||||
const uint tid = gl_LocalInvocationID.x;
|
||||
const uint iq3 = gl_WorkGroupID.z;
|
||||
|
||||
uint D = p.D;
|
||||
uint N = p.N;
|
||||
uint k_num = p.k_num;
|
||||
|
||||
uint l_offset = D * N * k_num + n;
|
||||
uint m_offset = D * N * k_num + N + n;
|
||||
uint l_offset = D * N * p.ne3 * k_num + N * iq3 * k_num * 2 + n;
|
||||
uint m_offset = D * N * p.ne3 * k_num + N * iq3 * k_num * 2 + N + n;
|
||||
uint lm_stride = N * 2;
|
||||
|
||||
// Compute the max m value for the row
|
||||
@@ -49,11 +51,11 @@ void main() {
|
||||
for (uint d = tid; d < D; d += BLOCK_SIZE) {
|
||||
float O = 0.0;
|
||||
[[unroll]] for (uint k = 0; k < k_num; ++k) {
|
||||
uint o_offset = D * N * k + D * n + d;
|
||||
uint o_offset = D * N * (k + iq3 * k_num) + D * n + d;
|
||||
float m = data_a[m_offset + k * lm_stride];
|
||||
O += exp(m - m_max) * data_a[o_offset];
|
||||
}
|
||||
O *= L;
|
||||
data_d[D * n + d] = O;
|
||||
data_d[iq3 * D * N + D * n + d] = O;
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user