mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-11-08 10:07:01 +00:00
vulkan: support softmax/FA batch and broadcast (#14449)
This commit is contained in:
committed by
Georgi Gerganov
parent
ec68e84c32
commit
8875523eb3
@@ -99,6 +99,10 @@ void main() {
|
||||
uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / 2;
|
||||
uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / 2;
|
||||
#endif
|
||||
uint32_t m_offset = 0;
|
||||
if (p.nem2 != 1) {
|
||||
m_offset = (iq3 % p.nem2) * p.nem1 * KV;
|
||||
}
|
||||
|
||||
[[dont_unroll]]
|
||||
for (uint32_t j = start_j; j < end_j; ++j) {
|
||||
@@ -150,7 +154,7 @@ void main() {
|
||||
uint32_t c = (idx + tid) % Bc;
|
||||
uint32_t r = (idx + tid) / Bc;
|
||||
if (idx + tid < Bc * Br) {
|
||||
masksh[c][r] = float(data_m[(i * Br + r) * m_stride + (j * Bc + c)]);
|
||||
masksh[c][r] = float(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)]);
|
||||
}
|
||||
}
|
||||
barrier();
|
||||
@@ -277,7 +281,7 @@ void main() {
|
||||
// If there is split_k, then the split_k resolve shader does the final
|
||||
// division by L. Store the intermediate O value and per-row m and L values.
|
||||
if (p.k_num > 1) {
|
||||
uint32_t o_offset = D * p.ne1 * split_k_index;
|
||||
uint32_t o_offset = D * p.ne1 * (split_k_index + iq3 * p.k_num);
|
||||
|
||||
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
|
||||
if (r < N) {
|
||||
@@ -289,7 +293,7 @@ void main() {
|
||||
}
|
||||
}
|
||||
|
||||
o_offset = D * p.ne1 * p.k_num + p.ne1 * split_k_index * 2;
|
||||
o_offset = D * p.ne1 * p.ne3 * p.k_num + p.ne1 * (split_k_index + iq3 * p.k_num) * 2;
|
||||
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
|
||||
if (r < N) {
|
||||
perElemOpStoreCol0(r, 0u, ACC_TYPE(Lf[r]), o_offset, iq2, N);
|
||||
@@ -311,7 +315,7 @@ void main() {
|
||||
}
|
||||
}
|
||||
|
||||
uint32_t o_offset = iq3*p.ne2*p.ne1;
|
||||
uint32_t o_offset = iq3*p.ne2*p.ne1*D;
|
||||
|
||||
if (p.gqa_ratio > 1) {
|
||||
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
|
||||
|
||||
Reference in New Issue
Block a user