From 5d8bb900bc7daa84bfa7bb1d25ab7e32394919f3 Mon Sep 17 00:00:00 2001 From: Jeff Bolz Date: Sat, 1 Nov 2025 00:52:14 -0500 Subject: [PATCH] vulkan: Fix multi_add invalid descriptor usage (#16899) --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 2 - .../ggml-vulkan/vulkan-shaders/multi_add.comp | 104 ++++++++++++++++-- 2 files changed, 94 insertions(+), 12 deletions(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 6a46d0889b..8d1a85c969 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -4274,8 +4274,6 @@ static vk_device ggml_vk_get_device(size_t idx) { device->multi_add = vk12_props.shaderRoundingModeRTEFloat16 && device->properties.limits.maxPushConstantsSize >= sizeof(vk_op_multi_add_push_constants) && - vk12_features.runtimeDescriptorArray && - device->vendor_id != VK_VENDOR_ID_INTEL && getenv("GGML_VK_DISABLE_MULTI_ADD") == nullptr; device->shader_int64 = device_features2.features.shaderInt64; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp b/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp index 1e8f694a72..10cf5202a4 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp @@ -23,16 +23,100 @@ layout (push_constant) uniform parameter2 uint rms_partials; } p; -// Workaround for MoltenVK Bug, see https://github.com/ggml-org/llama.cpp/issues/15498 -// layout (binding = 0) readonly buffer A {A_TYPE data_a[];} a[]; -// layout (binding = 0) writeonly buffer D {D_TYPE data_d[];} d[]; -layout (binding = 0) buffer A {A_TYPE data_a[];} a[]; -layout (binding = 0) buffer D {D_TYPE data_d[];} d[]; - -layout (binding = 0, std430) buffer PartialBuf {float partial_sums[];} partials[]; +// No readonly/writeonly decorations. Workaround for MoltenVK Bug, see https://github.com/ggml-org/llama.cpp/issues/15498 +layout (binding = 0) buffer A0 {A_TYPE data_a[];} a0; +layout (binding = 1) buffer A1 {A_TYPE data_a[];} a1; +layout (binding = 2) buffer A2 {A_TYPE data_a[];} a2; +layout (binding = 3) buffer A3 {A_TYPE data_a[];} a3; +layout (binding = 4) buffer A4 {A_TYPE data_a[];} a4; +layout (binding = 5) buffer A5 {A_TYPE data_a[];} a5; +layout (binding = 6) buffer A6 {A_TYPE data_a[];} a6; +layout (binding = 7) buffer A7 {A_TYPE data_a[];} a7; +layout (binding = 8) buffer A8 {A_TYPE data_a[];} a8; +layout (binding = 9) buffer A9 {A_TYPE data_a[];} a9; +layout (binding = 10) buffer A10 {A_TYPE data_a[];} a10; +layout (binding = 11) buffer A11 {A_TYPE data_a[];} a11; +layout (binding = 0) buffer D0 {D_TYPE data_d[];} d0; +layout (binding = 1) buffer D1 {D_TYPE data_d[];} d1; +layout (binding = 2) buffer D2 {D_TYPE data_d[];} d2; +layout (binding = 3) buffer D3 {D_TYPE data_d[];} d3; +layout (binding = 4) buffer D4 {D_TYPE data_d[];} d4; +layout (binding = 5) buffer D5 {D_TYPE data_d[];} d5; +layout (binding = 6) buffer D6 {D_TYPE data_d[];} d6; +layout (binding = 7) buffer D7 {D_TYPE data_d[];} d7; +layout (binding = 8) buffer D8 {D_TYPE data_d[];} d8; +layout (binding = 9) buffer D9 {D_TYPE data_d[];} d9; +layout (binding = 10) buffer D10 {D_TYPE data_d[];} d10; +layout (binding = 11) buffer D11 {D_TYPE data_d[];} d11; +layout (binding = 0, std430) buffer PartialBuf0 {float partial_sums[];} partials0; +layout (binding = 1, std430) buffer PartialBuf1 {float partial_sums[];} partials1; +layout (binding = 2, std430) buffer PartialBuf2 {float partial_sums[];} partials2; +layout (binding = 3, std430) buffer PartialBuf3 {float partial_sums[];} partials3; +layout (binding = 4, std430) buffer PartialBuf4 {float partial_sums[];} partials4; +layout (binding = 5, std430) buffer PartialBuf5 {float partial_sums[];} partials5; +layout (binding = 6, std430) buffer PartialBuf6 {float partial_sums[];} partials6; +layout (binding = 7, std430) buffer PartialBuf7 {float partial_sums[];} partials7; +layout (binding = 8, std430) buffer PartialBuf8 {float partial_sums[];} partials8; +layout (binding = 9, std430) buffer PartialBuf9 {float partial_sums[];} partials9; +layout (binding = 10, std430) buffer PartialBuf10 {float partial_sums[];} partials10; +layout (binding = 11, std430) buffer PartialBuf11 {float partial_sums[];} partials11; layout(constant_id = 0) const uint num_srcs = 2; +FLOAT_TYPE load_a(uint b, uint i) { + switch (b) { + case 0: return FLOAT_TYPE(a0.data_a[i]); + case 1: return FLOAT_TYPE(a1.data_a[i]); + case 2: return FLOAT_TYPE(a2.data_a[i]); + case 3: return FLOAT_TYPE(a3.data_a[i]); + case 4: return FLOAT_TYPE(a4.data_a[i]); + case 5: return FLOAT_TYPE(a5.data_a[i]); + case 6: return FLOAT_TYPE(a6.data_a[i]); + case 7: return FLOAT_TYPE(a7.data_a[i]); + case 8: return FLOAT_TYPE(a8.data_a[i]); + case 9: return FLOAT_TYPE(a9.data_a[i]); + case 10: return FLOAT_TYPE(a10.data_a[i]); + case 11: return FLOAT_TYPE(a11.data_a[i]); + default: return FLOAT_TYPE(0); + } +} + +void store_d(uint b, uint i, FLOAT_TYPE v) { + switch (b) { + case 0: d0.data_d[i] = D_TYPE(v); break; + case 1: d1.data_d[i] = D_TYPE(v); break; + case 2: d2.data_d[i] = D_TYPE(v); break; + case 3: d3.data_d[i] = D_TYPE(v); break; + case 4: d4.data_d[i] = D_TYPE(v); break; + case 5: d5.data_d[i] = D_TYPE(v); break; + case 6: d6.data_d[i] = D_TYPE(v); break; + case 7: d7.data_d[i] = D_TYPE(v); break; + case 8: d8.data_d[i] = D_TYPE(v); break; + case 9: d9.data_d[i] = D_TYPE(v); break; + case 10: d10.data_d[i] = D_TYPE(v); break; + case 11: d11.data_d[i] = D_TYPE(v); break; + default: break; + } +} + +void store_partial(uint b, uint i, float v) { + switch (b) { + case 0: partials0.partial_sums[i] = v; break; + case 1: partials1.partial_sums[i] = v; break; + case 2: partials2.partial_sums[i] = v; break; + case 3: partials3.partial_sums[i] = v; break; + case 4: partials4.partial_sums[i] = v; break; + case 5: partials5.partial_sums[i] = v; break; + case 6: partials6.partial_sums[i] = v; break; + case 7: partials7.partial_sums[i] = v; break; + case 8: partials8.partial_sums[i] = v; break; + case 9: partials9.partial_sums[i] = v; break; + case 10: partials10.partial_sums[i] = v; break; + case 11: partials11.partial_sums[i] = v; break; + default: break; + } +} + uint src_idx(uint s, uint i00, uint i01, uint i02, uint i03) { return i03*p.nb[s][3] + i02*p.nb[s][2] + i01*p.nb[s][1] + i00*p.nb[s][0]; } @@ -78,10 +162,10 @@ void main() { FLOAT_TYPE sum = FLOAT_TYPE(0); [[unroll]] for (uint s = 0; s < num_srcs; ++s) { - sum += FLOAT_TYPE(a[s].data_a[src_idx(s, i00, i01, i02, i03)]); + sum += load_a(s, src_idx(s, i00, i01, i02, i03)); } sum_sq += sum*sum; - d[num_srcs].data_d[dst_idx(i00, i01, i02, i03)] = D_TYPE(sum); + store_d(num_srcs, dst_idx(i00, i01, i02, i03), sum); idx += num_threads; } @@ -104,7 +188,7 @@ void main() { } if (gl_SubgroupID == 0 && gl_SubgroupInvocationID == 0) { - partials[num_srcs + 1].partial_sums[orig_idx / (num_iter * num_threads)] = sum_sq; + store_partial(num_srcs + 1, orig_idx / (num_iter * num_threads), sum_sq); } } #endif