vulkan: Add bfloat16 support (#12554)

* vulkan: Add bfloat16 support

This adds bfloat16 matrix multiply support based on VK_KHR_shader_bfloat16.
The extension is required for coopmat multiply support, but matrix-vector
multiply trivially promotes bf16 to fp32 and doesn't require the extension.
The copy/get_rows shaders also don't require the extension.

It's probably possible to fall back to non-coopmat and promote to fp32 when
the extension isn't supported, but this change doesn't do that.

The coopmat support also requires a glslc that supports the extension, which
currently requires a custom build.

* vulkan: Support bf16 tensors without the bf16 extension or coopmat support

Compile a variant of the scalar mul_mm shader that will promote the bf16
values to float, and use that when either the bf16 extension or the coopmat
extensions aren't available.

* vulkan: bfloat16 fixes (really works without bfloat16 support now)

* vulkan: fix spirv-val failure and reenable -O
This commit is contained in:
Jeff Bolz
2025-05-01 13:49:39 -05:00
committed by GitHub
parent fc727bcdd5
commit 79f26e9e12
13 changed files with 368 additions and 67 deletions

View File

@@ -10,6 +10,10 @@
#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
#endif
// Turn on the GLSL bfloat16 extension only when matrix A is bf16 AND the
// cooperative-matrix (COOPMAT) variant is being compiled. BF16 builds without
// COOPMAT are compiled without this extension.
#if defined(DATA_A_BF16) && defined(COOPMAT)
#extension GL_EXT_bfloat16 : enable
#endif
#ifdef COOPMAT
#extension GL_KHR_cooperative_matrix : enable
#extension GL_KHR_memory_scope_semantics : enable
@@ -29,6 +33,10 @@
#define LOAD_VEC_B 1
#endif
// TO_FLOAT_TYPE is the conversion applied to raw data_a/data_b elements before
// they are stored into the buf_a/buf_b staging arrays. If the build did not
// supply one, fall back to a plain FLOAT_TYPE(...) constructor.
#if !defined(TO_FLOAT_TYPE)
#define TO_FLOAT_TYPE FLOAT_TYPE
#endif
// Workgroup X size is taken from specialization constant 0; Y and Z are fixed at 1.
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
// Read-only buffer at binding 0 exposing matrix A as an unsized array of A_TYPE.
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
@@ -202,8 +210,8 @@ void main() {
#endif
#ifdef COOPMAT
coopmat<float16_t, gl_ScopeSubgroup, TM, TK, gl_MatrixUseA> cache_a;
coopmat<float16_t, gl_ScopeSubgroup, TK, TN, gl_MatrixUseB> cache_b;
coopmat<FLOAT_TYPE, gl_ScopeSubgroup, TM, TK, gl_MatrixUseA> cache_a;
coopmat<FLOAT_TYPE, gl_ScopeSubgroup, TK, TN, gl_MatrixUseB> cache_b;
coopmat<ACC_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator> sums[cms_per_row * cms_per_col];
[[unroll]] for (uint i = 0; i < cms_per_row * cms_per_col; i++) {
@@ -248,6 +256,21 @@ void main() {
buf_a[(loadc_a + l) * SHMEM_STRIDE + loadr_a] = FLOAT_TYPE(0.0f);
}
#endif
#elif defined(DATA_A_BF16)
#if LOAD_VEC_A == 4
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
buf_a[buf_idx ] = TO_FLOAT_TYPE(data_a[idx].x);
buf_a[buf_idx + 1] = TO_FLOAT_TYPE(data_a[idx].y);
buf_a[buf_idx + 2] = TO_FLOAT_TYPE(data_a[idx].z);
buf_a[buf_idx + 3] = TO_FLOAT_TYPE(data_a[idx].w);
#else
if (ir * BM + loadc_a + l < p.M && block + loadr_a < end_k) {
buf_a[(loadc_a + l) * SHMEM_STRIDE + loadr_a] = TO_FLOAT_TYPE(data_a[pos_a + (loadc_a + l) * p.stride_a + loadr_a]);
} else {
buf_a[(loadc_a + l) * SHMEM_STRIDE + loadr_a] = TO_FLOAT_TYPE(uint16_t(0));
}
#endif
#elif defined(DATA_A_Q4_0)
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + 4 * loadr_a;
@@ -695,13 +718,13 @@ void main() {
const uint idx = pos_b + (loadc_b + l) * p.stride_b / LOAD_VEC_B + loadr_b;
#endif
const uint buf_idx = (loadc_b + l) * SHMEM_STRIDE + loadr_b * LOAD_VEC_B;
buf_b[buf_idx + 0] = FLOAT_TYPE(data_b[idx].x);
buf_b[buf_idx + 1] = FLOAT_TYPE(data_b[idx].y);
buf_b[buf_idx + 2] = FLOAT_TYPE(data_b[idx].z);
buf_b[buf_idx + 3] = FLOAT_TYPE(data_b[idx].w);
buf_b[buf_idx + 0] = TO_FLOAT_TYPE(data_b[idx].x);
buf_b[buf_idx + 1] = TO_FLOAT_TYPE(data_b[idx].y);
buf_b[buf_idx + 2] = TO_FLOAT_TYPE(data_b[idx].z);
buf_b[buf_idx + 3] = TO_FLOAT_TYPE(data_b[idx].w);
#elif !MUL_MAT_ID
if (ic * BN + loadc_b + l < p.N && block + loadr_b < end_k) {
buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = FLOAT_TYPE(data_b[pos_b + (loadc_b + l) * p.stride_b + loadr_b]);
buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = TO_FLOAT_TYPE(data_b[pos_b + (loadc_b + l) * p.stride_b + loadr_b]);
} else {
buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = FLOAT_TYPE(0.0f);
}
@@ -709,7 +732,7 @@ void main() {
const uint row_i = ic * BN + loadc_b + l;
if (row_i < _ne1) {
const u16vec2 row_idx = row_ids[row_i];
buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = FLOAT_TYPE(data_b[pos_b + row_idx.y * p.batch_stride_b + (row_idx.x % p.ne11) * p.stride_b + loadr_b]);
buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = TO_FLOAT_TYPE(data_b[pos_b + row_idx.y * p.batch_stride_b + (row_idx.x % p.ne11) * p.stride_b + loadr_b]);
} else {
buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = FLOAT_TYPE(0.0f);
}