mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-11-12 10:47:01 +00:00
vulkan: support fattn sinks (#15126)
This commit is contained in:
@@ -305,6 +305,27 @@ void main() {
|
||||
return;
|
||||
}
|
||||
|
||||
if ((p.mask_n_head_log2 & SINK_ENABLE_BIT) != 0) {
|
||||
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
|
||||
float sink = perElemOpGetSink(r, 0u, ACC_TYPE(0), iq2);
|
||||
|
||||
float ms = 1.0f;
|
||||
float vs = 1.0f;
|
||||
|
||||
if (sink > Mf[r]) {
|
||||
ms = exp(Mf[r] - sink);
|
||||
|
||||
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
|
||||
Of[r][d] *= ms;
|
||||
}
|
||||
} else {
|
||||
vs = exp(sink - Mf[r]);
|
||||
}
|
||||
|
||||
Lf[r] = Lf[r]*ms + vs;
|
||||
}
|
||||
}
|
||||
|
||||
float Lfrcp[Br];
|
||||
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
|
||||
Lfrcp[r] = 1.0 / Lf[r];
|
||||
|
||||
@@ -50,10 +50,13 @@ layout (push_constant) uniform parameter {
|
||||
uint32_t k_num;
|
||||
} p;
|
||||
|
||||
#define SINK_ENABLE_BIT (1<<24)
|
||||
#define MASK_ENABLE_BIT (1<<16)
|
||||
#define N_LOG2_MASK 0xFFFF
|
||||
|
||||
layout (binding = 4) writeonly buffer O {D_TYPE data_o[];};
|
||||
layout (binding = 4) readonly buffer S {float data_s[];};
|
||||
|
||||
layout (binding = 5) writeonly buffer O {D_TYPE data_o[];};
|
||||
|
||||
#if defined(A_TYPE_PACKED16)
|
||||
#define BINDING_IDX_K 0
|
||||
@@ -111,6 +114,14 @@ ACC_TYPE perElemOpComputeSlope(const in uint32_t r, const in uint32_t c, const i
|
||||
return ACC_TYPE(pow(base, ACC_TYPE(exph)));
|
||||
}
|
||||
|
||||
// Load the sink value, indexed by Q's dimension 2.
|
||||
ACC_TYPE perElemOpGetSink(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t iq2)
|
||||
{
|
||||
const uint32_t h = iq2 + (r % p.gqa_ratio);
|
||||
|
||||
return ACC_TYPE(data_s[h]);
|
||||
}
|
||||
|
||||
uint32_t i, N, KV, split_k_index, Tr, start_j, end_j,
|
||||
iq2, iq3, rk2, rk3, rv2, rv3, ik2, ik3, iv2, iv3,
|
||||
q_stride, k_stride, v_stride, m_stride;
|
||||
|
||||
@@ -329,6 +329,27 @@ void main() {
|
||||
return;
|
||||
}
|
||||
|
||||
if ((p.mask_n_head_log2 & SINK_ENABLE_BIT) != 0) {
|
||||
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
|
||||
float sink = perElemOpGetSink(r, 0u, ACC_TYPE(0), iq2);
|
||||
|
||||
float ms = 1.0f;
|
||||
float vs = 1.0f;
|
||||
|
||||
if (sink > Mf[r]) {
|
||||
ms = exp(Mf[r] - sink);
|
||||
|
||||
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
|
||||
Of[r][d] *= ACC_TYPE(ms);
|
||||
}
|
||||
} else {
|
||||
vs = exp(sink - Mf[r]);
|
||||
}
|
||||
|
||||
Lf[r] = Lf[r]*ms + vs;
|
||||
}
|
||||
}
|
||||
|
||||
float Lfrcp[rows_per_thread];
|
||||
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
||||
Lfrcp[r] = 1.0 / Lf[r];
|
||||
|
||||
@@ -248,6 +248,34 @@ void main() {
|
||||
// resize L by using smear/reduce
|
||||
coopMatReduceNV(Ldiag, L, gl_CooperativeMatrixReduceRowNV, smearReduce);
|
||||
|
||||
if ((p.mask_n_head_log2 & SINK_ENABLE_BIT) != 0) {
|
||||
coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator> S;
|
||||
coopMatPerElementNV(S, S, perElemOpGetSink, iq2);
|
||||
|
||||
coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator> Mr;
|
||||
|
||||
// resize M by using smear/reduce
|
||||
coopMatReduceNV(Mr, M, gl_CooperativeMatrixReduceRowNV, smearReduce);
|
||||
|
||||
// O, Ldiag, Mr all have the same type so all element locations match
|
||||
[[unroll]] for (uint32_t i = 0; i < Ldiag.length(); ++i) {
|
||||
ACC_TYPE sink = S[i];
|
||||
|
||||
ACC_TYPE ms = ACC_TYPE(1.0f);
|
||||
ACC_TYPE vs = ACC_TYPE(1.0f);
|
||||
|
||||
if (sink > Mr[i]) {
|
||||
ms = exp(Mr[i] - sink);
|
||||
|
||||
O[i] *= ms;
|
||||
} else {
|
||||
vs = exp(sink - Mr[i]);
|
||||
}
|
||||
|
||||
Ldiag[i] = Ldiag[i]*ms + vs;
|
||||
}
|
||||
}
|
||||
|
||||
[[unroll]]
|
||||
for (int k = 0; k < Ldiag.length(); ++k) {
|
||||
Ldiag[k] = ACC_TYPE(1.0) / Ldiag[k];
|
||||
|
||||
@@ -7,13 +7,15 @@ layout(constant_id = 0) const uint BLOCK_SIZE = 32;
|
||||
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
layout (binding = 0) readonly buffer A {float data_a[];};
|
||||
layout (binding = 1) writeonly buffer D {float data_d[];};
|
||||
layout (binding = 1) readonly buffer B {float data_s[];};
|
||||
layout (binding = 2) writeonly buffer D {float data_d[];};
|
||||
|
||||
layout (push_constant) uniform parameter {
|
||||
uint D;
|
||||
uint N;
|
||||
uint ne3;
|
||||
uint k_num;
|
||||
uint sinks;
|
||||
} p;
|
||||
|
||||
shared float tmpsh[BLOCK_SIZE];
|
||||
@@ -73,6 +75,22 @@ void main() {
|
||||
}
|
||||
L = tmpsh[0];
|
||||
|
||||
float sink;
|
||||
if (p.sinks != 0) {
|
||||
sink = data_s[n];
|
||||
|
||||
float ms = 1.0f;
|
||||
float vs = 1.0f;
|
||||
|
||||
if (sink > m_max) {
|
||||
ms = exp(m_max - sink);
|
||||
} else {
|
||||
vs = exp(sink - m_max);
|
||||
}
|
||||
|
||||
L = L*ms + vs;
|
||||
}
|
||||
|
||||
L = 1.0 / L;
|
||||
|
||||
// D dimension is split across workgroups in the y dimension
|
||||
@@ -85,6 +103,13 @@ void main() {
|
||||
float m = data_a[m_offset + k * lm_stride];
|
||||
O += exp(m - m_max) * data_a[o_offset];
|
||||
}
|
||||
if (p.sinks != 0) {
|
||||
if (sink > m_max) {
|
||||
float ms = 1.0f;
|
||||
ms = exp(m_max - sink);
|
||||
O *= ms;
|
||||
}
|
||||
}
|
||||
O *= L;
|
||||
data_d[iq3 * D * N + D * n + d] = O;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user