diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp index 38a4d07d03..3cb24412d5 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp @@ -31,10 +31,22 @@ #include "types.comp" #ifndef LOAD_VEC_A -#define LOAD_VEC_A 2 +#define LOAD_VEC_A 1 #endif #ifndef LOAD_VEC_B -#define LOAD_VEC_B 2 +#define LOAD_VEC_B 1 +#endif + +// Load 2 values at once without affecting index calculations through LOAD_VEC +#if (defined(DATA_A_F32) || defined(DATA_A_F16) || defined(DATA_A_BF16)) && !defined(ALIGNED) +#define LOAD_VEC_BATCH_A 2 +#else +#define LOAD_VEC_BATCH_A 1 +#endif +#if !defined(ALIGNED) +#define LOAD_VEC_BATCH_B 2 +#else +#define LOAD_VEC_BATCH_B 1 #endif #if !defined(TO_FLOAT_TYPE) @@ -236,13 +248,13 @@ void main() { const uint warp_r = warp_i % (BM / WM); const uint warp_c = warp_i / (BM / WM); - const uint loadr_a = gl_LocalInvocationID.x % (BK / LOAD_VEC_A); - const uint loadc_a = gl_LocalInvocationID.x / (BK / LOAD_VEC_A); - const uint loadr_b = gl_LocalInvocationID.x % (BK / LOAD_VEC_B); - const uint loadc_b = gl_LocalInvocationID.x / (BK / LOAD_VEC_B); + const uint loadr_a = gl_LocalInvocationID.x % (BK / LOAD_VEC_A / LOAD_VEC_BATCH_A); + const uint loadc_a = gl_LocalInvocationID.x / (BK / LOAD_VEC_A / LOAD_VEC_BATCH_A); + const uint loadr_b = gl_LocalInvocationID.x % (BK / LOAD_VEC_B / LOAD_VEC_BATCH_B); + const uint loadc_b = gl_LocalInvocationID.x / (BK / LOAD_VEC_B / LOAD_VEC_BATCH_B); - const uint loadstride_a = gl_WorkGroupSize.x * LOAD_VEC_A / BK; - const uint loadstride_b = gl_WorkGroupSize.x * LOAD_VEC_B / BK; + const uint loadstride_a = gl_WorkGroupSize.x * LOAD_VEC_A * LOAD_VEC_BATCH_A / BK; + const uint loadstride_b = gl_WorkGroupSize.x * LOAD_VEC_B * LOAD_VEC_BATCH_B / BK; #ifdef MUL_MAT_ID #ifdef MUL_MAT_ID_USE_SUBGROUPS diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.comp index 69d0e64c35..0ebfbd6462 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.comp @@ -14,8 +14,8 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin FLOAT_TYPE_VEC4 aa = FLOAT_TYPE_VEC4(data_a[idx]); buf_a[buf_idx ] = aa.xy; buf_a[buf_idx + 1] = aa.zw; -#else // LOAD_VEC_A == 2 - const uint idx = pos_a * 2 + col * p.stride_a + row * 2; +#else // LOAD_VEC_BATCH_A == 2 + const uint idx = pos_a + col * p.stride_a + row * 2; const uint buf_idx = col * SHMEM_STRIDE + row; if (idx_m < p.M && block + row * 2 + 1 < end_k) { buf_a[buf_idx] = FLOAT_TYPE_VEC2(data_a[idx], @@ -33,8 +33,8 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin FLOAT_TYPE_VEC4 aa = FLOAT_TYPE_VEC4(TO_FLOAT_TYPE(data_a[idx])); buf_a[buf_idx ] = aa.xy; buf_a[buf_idx + 1] = aa.zw; -#else // LOAD_VEC_A == 2 - const uint idx = pos_a * 2 + col * p.stride_a + row * 2; +#else // LOAD_VEC_BATCH_A == 2 + const uint idx = pos_a + col * p.stride_a + row * 2; const uint buf_idx = col * SHMEM_STRIDE + row; if (idx_m < p.M && block + row * 2 + 1 < end_k) { buf_a[buf_idx] = FLOAT_TYPE_VEC2(TO_FLOAT_TYPE(data_a[idx]), @@ -500,8 +500,8 @@ void load_b_to_shmem(const uint pos_b, const uint row, const uint col, const uin #endif buf_b[buf_idx + 0] = bb.xy; buf_b[buf_idx + 1] = bb.zw; -#else // LOAD_VEC_B == 2 - const uint idx = pos_b * 2 + col * p.stride_b + row * 2; +#else // LOAD_VEC_BATCH_B == 2 + const uint idx = pos_b + col * p.stride_b + row * 2; const uint buf_idx = col * SHMEM_STRIDE + row; if (idx_n < p.N && block + row * 2 + 1 < end_k) { buf_b[buf_idx] = FLOAT_TYPE_VEC2(TO_FLOAT_TYPE(data_b[idx]), @@ -536,17 +536,17 @@ void load_b_to_shmem(const uint pos_b, const uint row, const uint col, const uin #endif buf_b[buf_idx + 0] = bb.xy; buf_b[buf_idx + 1] = bb.zw; -#else // LOAD_VEC_B == 2 +#else // LOAD_VEC_BATCH_B == 2 const uint row_i = ic * BN + col; const uint buf_idx = col * SHMEM_STRIDE + row; if (row_i < _ne1 && block + row * 2 + 1 < end_k) { const u16vec2 row_idx = row_ids[col]; - const uint idx = pos_b * 2 + row_idx.y * p.batch_stride_b + (row_idx.x % p.ne11) * p.stride_b + row * 2; + const uint idx = pos_b + row_idx.y * p.batch_stride_b + (row_idx.x % p.ne11) * p.stride_b + row * 2; buf_b[buf_idx] = FLOAT_TYPE_VEC2(TO_FLOAT_TYPE(data_b[idx]), TO_FLOAT_TYPE(data_b[idx + 1])); } else if (row_i < _ne1 && block + row * 2 < end_k) { const u16vec2 row_idx = row_ids[col]; - const uint idx = pos_b * 2 + row_idx.y * p.batch_stride_b + (row_idx.x % p.ne11) * p.stride_b + row * 2; + const uint idx = pos_b + row_idx.y * p.batch_stride_b + (row_idx.x % p.ne11) * p.stride_b + row * 2; buf_b[buf_idx] = FLOAT_TYPE_VEC2(TO_FLOAT_TYPE(data_b[idx]), 0.0f); } else { buf_b[buf_idx] = FLOAT_TYPE_VEC2(0.0f); diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index 74a4794d34..8e2507ad8e 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -454,7 +454,7 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c std::string data_a_key = "DATA_A_" + to_uppercase(tname); // For unaligned, load one at a time for f32/f16, or two at a time for quants - std::string load_vec_a_unaligned = coopmat2 ? "1" : (tname == "f32" || tname == "f16" || tname == "bf16") ? "2" : load_vec_quant; + std::string load_vec_a_unaligned = (coopmat2 || tname == "f32" || tname == "f16" || tname == "bf16") ? "1" : load_vec_quant; // For aligned matmul loads std::string load_vec_a = (coopmat2 || tname == "f32" || tname == "f16" || tname == "bf16") ? load_vec : load_vec_quant; diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 507b691dc9..ef6594ea59 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -6231,6 +6231,7 @@ static std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 1056, 1, 193, {1, 1}, {4, 1}, {0, 2, 1, 3})); test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 1056, 1, 67, {1, 1}, {4, 1}, {0, 2, 1, 3})); test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F32, GGML_TYPE_F32, 16, 32, 32, { 1, 1}, {1, 1}, {0, 1, 2, 3}, true, 3)); + test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F32, GGML_TYPE_F32, 64, 77, 77, {12,1}, {1,1})); for (auto bs2 : {1,3}) { for (auto bs : {1,2,4,8}) {