mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-11-13 10:57:15 +00:00
vulkan: support fattn sinks (#15126)
This commit is contained in:
@@ -7,13 +7,15 @@ layout(constant_id = 0) const uint BLOCK_SIZE = 32;
|
||||
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
layout (binding = 0) readonly buffer A {float data_a[];};
|
||||
layout (binding = 1) writeonly buffer D {float data_d[];};
|
||||
layout (binding = 1) readonly buffer B {float data_s[];};
|
||||
layout (binding = 2) writeonly buffer D {float data_d[];};
|
||||
|
||||
layout (push_constant) uniform parameter {
|
||||
uint D;
|
||||
uint N;
|
||||
uint ne3;
|
||||
uint k_num;
|
||||
uint sinks;
|
||||
} p;
|
||||
|
||||
shared float tmpsh[BLOCK_SIZE];
|
||||
@@ -73,6 +75,22 @@ void main() {
|
||||
}
|
||||
L = tmpsh[0];
|
||||
|
||||
float sink;
|
||||
if (p.sinks != 0) {
|
||||
sink = data_s[n];
|
||||
|
||||
float ms = 1.0f;
|
||||
float vs = 1.0f;
|
||||
|
||||
if (sink > m_max) {
|
||||
ms = exp(m_max - sink);
|
||||
} else {
|
||||
vs = exp(sink - m_max);
|
||||
}
|
||||
|
||||
L = L*ms + vs;
|
||||
}
|
||||
|
||||
L = 1.0 / L;
|
||||
|
||||
// D dimension is split across workgroups in the y dimension
|
||||
@@ -85,6 +103,13 @@ void main() {
|
||||
float m = data_a[m_offset + k * lm_stride];
|
||||
O += exp(m - m_max) * data_a[o_offset];
|
||||
}
|
||||
if (p.sinks != 0) {
|
||||
if (sink > m_max) {
|
||||
float ms = 1.0f;
|
||||
ms = exp(m_max - sink);
|
||||
O *= ms;
|
||||
}
|
||||
}
|
||||
O *= L;
|
||||
data_d[iq3 * D * N + D * n + d] = O;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user