vulkan: support softmax/FA batch and broadcast (#14449)

This commit is contained in:
Jeff Bolz
2025-07-01 03:32:56 -05:00
committed by Georgi Gerganov
parent ec68e84c32
commit 8875523eb3
7 changed files with 80 additions and 44 deletions

View File

@@ -123,6 +123,10 @@ void main() {
uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / 2;
uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / 2;
#endif
uint32_t m_offset = 0;
if (p.nem2 != 1) {
m_offset = (iq3 % p.nem2) * p.nem1 * KV;
}
[[dont_unroll]]
for (uint32_t j = start_j; j < end_j; ++j) {
@@ -181,7 +185,7 @@ void main() {
uint32_t c = (idx + tid) % Bc;
uint32_t r = (idx + tid) / Bc;
if (idx + tid < Bc * Br || idx + gl_WorkGroupSize.x <= Bc * Br) {
sfsh[c * sfshstride + r] += ACC_TYPE(slope[r] * float(data_m[(i * Br + r) * m_stride + (j * Bc + c)]));
sfsh[c * sfshstride + r] += ACC_TYPE(slope[r] * float(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)]));
}
}
barrier();
@@ -300,7 +304,7 @@ void main() {
// If there is split_k, then the split_k resolve shader does the final
// division by L. Store the intermediate O value and per-row m and L values.
if (p.k_num > 1) {
uint32_t o_offset = D * p.ne1 * split_k_index;
uint32_t o_offset = D * p.ne1 * (split_k_index + iq3 * p.k_num);
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
if (tile_row(r) < N) {
@@ -312,7 +316,7 @@ void main() {
}
}
o_offset = D * p.ne1 * p.k_num + p.ne1 * split_k_index * 2;
o_offset = D * p.ne1 * p.ne3 * p.k_num + p.ne1 * (split_k_index + iq3 * p.k_num) * 2;
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
if (tile_row(r) < N) {
perElemOpStoreCol0(tile_row(r), 0u, ACC_TYPE(Lf[r]), o_offset, iq2, N);
@@ -334,7 +338,7 @@ void main() {
}
}
uint32_t o_offset = iq3*p.ne2*p.ne1;
uint32_t o_offset = iq3*p.ne2*p.ne1*D;
if (p.gqa_ratio > 1) {
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {