vulkan: support arbitrary KV dimension in flash attention (#16160)

The "Clamp" spec constant is already based on whether KV is a multiple of Bc,
so use that to control whether bounds checking is performed. Add bounds checking
to the scalar and coopmat1 paths. Coopmat2 didn't need any changes (the K/V
tensors are already optionally clamped, nothing else needed to be changed).
This commit is contained in:
Jeff Bolz
2025-09-27 16:43:39 -04:00
committed by GitHub
parent 8656f5de68
commit e6d65fb02d
3 changed files with 38 additions and 9 deletions

View File

@@ -117,6 +117,9 @@ void main() {
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) {
continue;
}
[[unroll]] for (uint32_t d = 0; d < HSK_per_thread / 4; ++d) { [[unroll]] for (uint32_t d = 0; d < HSK_per_thread / 4; ++d) {
#if BLOCK_SIZE > 1 #if BLOCK_SIZE > 1
uint coord = (j * Bc + c * cols_per_iter + col_tid) * k_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid); uint coord = (j * Bc + c * cols_per_iter + col_tid) * k_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid);
@@ -155,7 +158,11 @@ void main() {
uint32_t c = (idx + tid) % Bc; uint32_t c = (idx + tid) % Bc;
uint32_t r = (idx + tid) / Bc; uint32_t r = (idx + tid) / Bc;
if (idx + tid < Bc * Br) { if (idx + tid < Bc * Br) {
masksh[c][r] = float(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)]); if (!KV_bounds_check || j * Bc + c < KV) {
masksh[c][r] = float(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)]);
} else {
masksh[c][r] = float(0);
}
} }
} }
barrier(); barrier();
@@ -172,8 +179,11 @@ void main() {
float rowmaxf[Br], Pf[Br][cols_per_thread], rowsumf[Br], eMf[Br], Moldf[Br]; float rowmaxf[Br], Pf[Br][cols_per_thread], rowsumf[Br], eMf[Br], Moldf[Br];
[[unroll]] for (uint32_t r = 0; r < Br; ++r) { [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
rowmaxf[r] = Sf[r][0]; rowmaxf[r] = NEG_FLT_MAX_OVER_2;
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) {
continue;
}
rowmaxf[r] = max(rowmaxf[r], Sf[r][c]); rowmaxf[r] = max(rowmaxf[r], Sf[r][c]);
} }
Moldf[r] = Mf[r]; Moldf[r] = Mf[r];
@@ -190,6 +200,9 @@ void main() {
// Compute sum across row of P // Compute sum across row of P
rowsumf[r] = 0.0; rowsumf[r] = 0.0;
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) {
continue;
}
rowsumf[r] += Pf[r][c]; rowsumf[r] += Pf[r][c];
} }
@@ -203,6 +216,9 @@ void main() {
} }
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) {
continue;
}
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
#if BLOCK_SIZE > 1 #if BLOCK_SIZE > 1
uint coord = (j * Bc + c * cols_per_iter + col_tid) * v_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid); uint coord = (j * Bc + c * cols_per_iter + col_tid) * v_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid);

View File

@@ -13,6 +13,8 @@ layout (constant_id = 6) const uint32_t D_split = 16;
const uint32_t HSK_pad = (HSK + 15) & ~15; const uint32_t HSK_pad = (HSK + 15) & ~15;
const uint32_t HSV_pad = (HSV + 15) & ~15; const uint32_t HSV_pad = (HSV + 15) & ~15;
const bool KV_bounds_check = Clamp != 0;
layout (push_constant) uniform parameter { layout (push_constant) uniform parameter {
uint32_t N; uint32_t N;
uint32_t KV; uint32_t KV;

View File

@@ -152,14 +152,17 @@ void main() {
uint32_t d = (idx + tid) % (HSK / 4); uint32_t d = (idx + tid) % (HSK / 4);
uint32_t c = (idx + tid) / (HSK / 4); uint32_t c = (idx + tid) / (HSK / 4);
if (c < Bc && d < HSK / 4) { if (c < Bc && d < HSK / 4) {
f16vec4 K_Tf = f16vec4(0);
if (!KV_bounds_check || j * Bc + c < KV) {
#if BLOCK_SIZE > 1 #if BLOCK_SIZE > 1
uint coord = (j * Bc + c) * k_stride * BLOCK_SIZE + 4 * d; uint coord = (j * Bc + c) * k_stride * BLOCK_SIZE + 4 * d;
uint ib = coord / BLOCK_SIZE; uint ib = coord / BLOCK_SIZE;
uint iqs = (coord % BLOCK_SIZE); uint iqs = (coord % BLOCK_SIZE);
f16vec4 K_Tf = f16vec4(dequantize4(ib, iqs, k_offset, BINDING_IDX_K)); K_Tf = f16vec4(dequantize4(ib, iqs, k_offset, BINDING_IDX_K));
#else #else
f16vec4 K_Tf = f16vec4(data_kv4[k_offset / 4 + (j * Bc + c) * k_stride / 4 + d]); K_Tf = f16vec4(data_kv4[k_offset / 4 + (j * Bc + c) * k_stride / 4 + d]);
#endif #endif
}
ksh[c * kshstride + d] = K_Tf; ksh[c * kshstride + d] = K_Tf;
} }
@@ -202,7 +205,9 @@ void main() {
uint32_t c = (idx + tid) % Bc; uint32_t c = (idx + tid) % Bc;
uint32_t r = (idx + tid) / Bc; uint32_t r = (idx + tid) / Bc;
if (idx + tid < Bc * Br || idx + gl_WorkGroupSize.x <= Bc * Br) { if (idx + tid < Bc * Br || idx + gl_WorkGroupSize.x <= Bc * Br) {
sfsh[c * sfshstride + r] += ACC_TYPE(slope[r] * float(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)])); if (!KV_bounds_check || j * Bc + c < KV) {
sfsh[c * sfshstride + r] += ACC_TYPE(slope[r] * float(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)]));
}
} }
} }
barrier(); barrier();
@@ -210,8 +215,11 @@ void main() {
float eMf[rows_per_thread]; float eMf[rows_per_thread];
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
float rowmaxf = sfsh[tile_row(r) + (0 * cols_per_iter + col_tid) * sfshstride]; float rowmaxf = NEG_FLT_MAX_OVER_2;
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) {
continue;
}
rowmaxf = max(rowmaxf, float(sfsh[tile_row(r) + (c * cols_per_iter + col_tid) * sfshstride])); rowmaxf = max(rowmaxf, float(sfsh[tile_row(r) + (c * cols_per_iter + col_tid) * sfshstride]));
} }
float Moldf = Mf[r]; float Moldf = Mf[r];
@@ -233,6 +241,9 @@ void main() {
} }
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) {
continue;
}
float Pf[rows_per_thread]; float Pf[rows_per_thread];
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
Pf[r] = exp(sfsh[tile_row(r) + (c * cols_per_iter + col_tid) * sfshstride] - Mf[r]); Pf[r] = exp(sfsh[tile_row(r) + (c * cols_per_iter + col_tid) * sfshstride] - Mf[r]);