vulkan: fix shmem overrun in mmq id shader (#16873)

* vulkan: fix shmem overrun in mmq id shader

* metal : fix mul_mm_id

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
Author: Ruben Ortlam
Date: 2025-10-31 08:14:49 +01:00
Committer: GitHub
Parent: 13002a0896
Commit: d2a2673dd1
4 changed files with 9 additions and 2 deletions

@@ -677,7 +677,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm_id_map0(ggml_metal_
     char name[256];
 
     snprintf(base, 256, "kernel_mul_mm_id_map0_ne20_%d", ne20);
-    snprintf(name, 256, "%s", base);
+    snprintf(name, 256, "%s_ne02=%d", base, ne02);
 
     ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
     if (res) {
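Note: the Metal change widens the pipeline cache key so that pipelines specialized for different ne02 values no longer collide on the same name. A minimal C++ sketch of that failure mode, assuming a simple name-keyed cache; the cache type and get_pipeline helper below are hypothetical stand-ins, not ggml's actual API:

```cpp
#include <cstdio>
#include <map>
#include <string>

// Hypothetical stand-in for a name-keyed pipeline cache.
static std::map<std::string, int> g_cache;

static int get_pipeline(int ne02, int ne20) {
    char name[256];
    // Before the fix the key only encoded ne20, so kernels that must be
    // specialized per ne02 shared one cache slot:
    //   snprintf(name, sizeof(name), "kernel_mul_mm_id_map0_ne20_%d", ne20);
    // After the fix ne02 is folded into the key as well:
    snprintf(name, sizeof(name), "kernel_mul_mm_id_map0_ne20_%d_ne02=%d", ne20, ne02);

    auto it = g_cache.find(name);
    if (it != g_cache.end()) {
        return it->second;                     // cache hit: reuse the compiled pipeline
    }
    const int pipeline = (int) g_cache.size(); // pretend this compiles a new pipeline
    g_cache[name] = pipeline;
    return pipeline;
}

int main() {
    // With ne02 in the key these are distinct pipelines; without it the
    // second call would silently reuse the first one.
    const int p0 = get_pipeline(/*ne02=*/8,  /*ne20=*/4);
    const int p1 = get_pipeline(/*ne02=*/16, /*ne20=*/4);
    printf("p0=%d p1=%d\n", p0, p1);
    return 0;
}
```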

@@ -82,9 +82,13 @@ layout (constant_id = 10) const uint WARP = 32;
 
 #include "mul_mmq_shmem_types.glsl"
 
+#ifdef MUL_MAT_ID
+#define BK_STEP 1
+#else
 #ifndef BK_STEP
 #define BK_STEP 4
 #endif
+#endif
 
 // Shared memory cache
 shared block_a_cache buf_a[BM * BK_STEP];
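Note: pinning BK_STEP to 1 in the MUL_MAT_ID path directly shrinks buf_a, which is sized BM * BK_STEP blocks, and keeps the shader within the device's shared memory limit. A back-of-the-envelope C++ sketch of that budget; all sizes below are placeholders chosen for illustration, not the shader's real tile configuration:

```cpp
#include <cstdio>

int main() {
    // Placeholder numbers; real tile sizes and per-block footprints depend
    // on the shader variant and the quantization type.
    const unsigned BM            = 64;        // A blocks cached per workgroup tile (assumed)
    const unsigned block_a_bytes = 48;        // bytes per block_a_cache entry (assumed)
    const unsigned other_shmem   = 24 * 1024; // B cache, expert-routing data, etc. (assumed)
    const unsigned shmem_limit   = 32 * 1024; // a common maxComputeSharedMemorySize

    const unsigned steps[2] = {4, 1};         // BK_STEP before / after the fix
    for (unsigned bk_step : steps) {
        const unsigned buf_a = BM * bk_step * block_a_bytes; // mirrors buf_a[BM * BK_STEP]
        const unsigned total = buf_a + other_shmem;
        printf("BK_STEP=%u: buf_a=%u B, total=%u B, limit=%u B%s\n",
               bk_step, buf_a, total, shmem_limit,
               total > shmem_limit ? "  <-- over budget" : "");
    }
    return 0;
}
```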

@@ -27,7 +27,7 @@ struct block_a_cache {
 #elif defined(DATA_A_Q8_0)
 #define QUANT_R_MMQ 1
 // AMD likes 4, Intel likes 1 and Nvidia likes 2
-#define BK_STEP 1
+// #define BK_STEP 1
 struct block_a_cache {
     int32_t qs[32/4];
     FLOAT_TYPE dm;

@@ -6880,6 +6880,9 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
     test_cases.emplace_back(new test_mul_mat_id(GGML_TYPE_F16, GGML_TYPE_F32, 1, 1, false, 8, 16, 1));
     test_cases.emplace_back(new test_mul_mat_id(GGML_TYPE_F16, GGML_TYPE_F32, 16, 16, false, 32, 32, 32, 3));
 
+    // gpt-oss issue with Vulkan mmq_id
+    test_cases.emplace_back(new test_mul_mat_id(GGML_TYPE_MXFP4, GGML_TYPE_F32, 32, 2, false, 2880, 32, 2880));
+
     for (ggml_type type_a : base_types) {
         for (ggml_type type_b : {GGML_TYPE_F32 /*, GGML_TYPE_F16 */}) {
             for (int n_mats : {4, 8}) {
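Note: the added test covers a gpt-oss-shaped MUL_MAT_ID case with 32 MXFP4 experts and 2880-wide matrices, assuming the usual n_mats/n_used/m/n/k argument order. For orientation, a tiny CPU reference of routed expert matmul in C++; the shapes, memory layout, and expert routing here are illustrative assumptions, not ggml's internal representation:

```cpp
#include <cstdio>
#include <vector>

int main() {
    // Illustrative shapes only; the new test uses n_experts = 32, n_used = 2,
    // m = k = 2880 against MXFP4 expert weights.
    const int n_experts = 4, n_used = 2, m = 3, k = 5, n_tokens = 2;

    std::vector<float> as(n_experts * m * k, 0.5f);  // expert weight matrices (row-major, illustrative)
    std::vector<float> b(n_tokens * k, 1.0f);        // per-token activations
    std::vector<int>   ids = {0, 2, 1, 3};           // n_used expert ids per token
    std::vector<float> out(n_tokens * n_used * m, 0.0f);

    for (int t = 0; t < n_tokens; ++t) {
        for (int u = 0; u < n_used; ++u) {
            const int e = ids[t*n_used + u];         // routed expert for this slot
            for (int i = 0; i < m; ++i) {
                float acc = 0.0f;
                for (int j = 0; j < k; ++j) {
                    acc += as[(e*m + i)*k + j] * b[t*k + j];
                }
                out[(t*n_used + u)*m + i] = acc;
            }
        }
    }

    printf("out[0] = %.1f (expected 2.5 with these fill values)\n", out[0]);
    return 0;
}
```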