From ba1ceb34566c889a1fc500efa79799ffed25d9b0 Mon Sep 17 00:00:00 2001 From: Jeff Bolz Date: Tue, 15 Jul 2025 14:51:09 -0500 Subject: [PATCH] vulkan: fix noncontig check for mat_mul_id splitting (#14683) * vulkan: fix noncontig check for mat_mul_id splitting Remove supports_op check for > 4096 (splitting fixes this) * vulkan: fix batched matmul dequant for Q*_K --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 6 +----- ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp | 2 +- ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp | 2 +- ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp | 2 +- ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp | 2 +- ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp | 2 +- 6 files changed, 6 insertions(+), 10 deletions(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 9f5646bf29..3019a545d5 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -4922,7 +4922,7 @@ static bool ggml_vk_dim01_contiguous(const ggml_tensor * tensor) { return tensor->nb[0] == ggml_type_size(tensor->type) && tensor->nb[1] == (tensor->nb[0]*tensor->ne[0])/ggml_blck_size(tensor->type) && - tensor->nb[3] == tensor->nb[2]*tensor->ne[2]; + (tensor->ne[3] == 1 || tensor->nb[3] == tensor->nb[2]*tensor->ne[2]); } static vk_pipeline ggml_vk_get_cpy_pipeline(ggml_backend_vk_context * ctx, const ggml_tensor * src, const ggml_tensor * dst, ggml_type to) { @@ -10356,10 +10356,6 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm // If there's not enough shared memory for row_ids and the result tile, fallback to CPU return false; } - // Check against size of shared memory variable - if (op->src[2]->ne[0] > 4096) { - return false; - } } switch (src0_type) { case GGML_TYPE_F32: diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp index 157154af3a..d4e4e6bae6 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp @@ -10,7 +10,7 @@ layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; void main() { [[unroll]] for (uint wgy = 0; wgy < 256; wgy++) { const uint i = gl_WorkGroupID.x * 256 + wgy; - if (i >= p.M * p.K / QUANT_K) { + if (i >= p.nel / QUANT_K) { return; } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp index c17dd0d999..3661f771c7 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp @@ -10,7 +10,7 @@ layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; void main() { [[unroll]] for (uint wgy = 0; wgy < 256; wgy++) { const uint i = uint(gl_WorkGroupID.x * 256 + wgy); - if (i >= p.M * p.K / QUANT_K) { + if (i >= p.nel / QUANT_K) { return; } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp index 987f113a35..1370db3654 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp @@ -10,7 +10,7 @@ layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; void main() { [[unroll]] for (uint wgy = 0; wgy < 256; wgy++) { const uint ib = gl_WorkGroupID.x * 256 + wgy; - if (ib >= p.M * p.K / QUANT_K) { + if (ib >= p.nel / QUANT_K) { return; } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp index 6db5403b66..3f3b839e11 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp @@ -10,7 +10,7 @@ layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; void main() { [[unroll]] for (uint wgy = 0; wgy < 256; wgy++) { const uint ib = gl_WorkGroupID.x * 256 + wgy; - if (ib >= p.M * p.K / QUANT_K) { + if (ib >= p.nel / QUANT_K) { return; } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp index 0b91317550..9cf34256e8 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp @@ -10,7 +10,7 @@ layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; void main() { [[unroll]] for (uint wgy = 0; wgy < 256; wgy++) { const uint i = gl_WorkGroupID.x * 256 + wgy; - if (i >= p.M * p.K / QUANT_K) { + if (i >= p.nel / QUANT_K) { return; } const uint tid = gl_LocalInvocationID.x;