mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-10-28 08:31:25 +00:00
vulkan: vec dot matrix multiplication fix (#16151)
* vulkan: fix matrix multiplication index calculation for odd m/n and odd k in combination with batching
* add odd m/n + odd k test with batching
This commit is contained in:
@@ -31,10 +31,22 @@
|
|||||||
#include "types.comp"
|
#include "types.comp"
|
||||||
|
|
||||||
#ifndef LOAD_VEC_A
|
#ifndef LOAD_VEC_A
|
||||||
#define LOAD_VEC_A 2
|
#define LOAD_VEC_A 1
|
||||||
#endif
|
#endif
|
||||||
#ifndef LOAD_VEC_B
|
#ifndef LOAD_VEC_B
|
||||||
#define LOAD_VEC_B 2
|
#define LOAD_VEC_B 1
|
||||||
|
#endif
|
||||||
|
|
||||||
|
// Load 2 values at once without affecting index calculations through LOAD_VEC
|
||||||
|
#if (defined(DATA_A_F32) || defined(DATA_A_F16) || defined(DATA_A_BF16)) && !defined(ALIGNED)
|
||||||
|
#define LOAD_VEC_BATCH_A 2
|
||||||
|
#else
|
||||||
|
#define LOAD_VEC_BATCH_A 1
|
||||||
|
#endif
|
||||||
|
#if !defined(ALIGNED)
|
||||||
|
#define LOAD_VEC_BATCH_B 2
|
||||||
|
#else
|
||||||
|
#define LOAD_VEC_BATCH_B 1
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#if !defined(TO_FLOAT_TYPE)
|
#if !defined(TO_FLOAT_TYPE)
|
||||||
@@ -236,13 +248,13 @@ void main() {
|
|||||||
const uint warp_r = warp_i % (BM / WM);
|
const uint warp_r = warp_i % (BM / WM);
|
||||||
const uint warp_c = warp_i / (BM / WM);
|
const uint warp_c = warp_i / (BM / WM);
|
||||||
|
|
||||||
const uint loadr_a = gl_LocalInvocationID.x % (BK / LOAD_VEC_A);
|
const uint loadr_a = gl_LocalInvocationID.x % (BK / LOAD_VEC_A / LOAD_VEC_BATCH_A);
|
||||||
const uint loadc_a = gl_LocalInvocationID.x / (BK / LOAD_VEC_A);
|
const uint loadc_a = gl_LocalInvocationID.x / (BK / LOAD_VEC_A / LOAD_VEC_BATCH_A);
|
||||||
const uint loadr_b = gl_LocalInvocationID.x % (BK / LOAD_VEC_B);
|
const uint loadr_b = gl_LocalInvocationID.x % (BK / LOAD_VEC_B / LOAD_VEC_BATCH_B);
|
||||||
const uint loadc_b = gl_LocalInvocationID.x / (BK / LOAD_VEC_B);
|
const uint loadc_b = gl_LocalInvocationID.x / (BK / LOAD_VEC_B / LOAD_VEC_BATCH_B);
|
||||||
|
|
||||||
const uint loadstride_a = gl_WorkGroupSize.x * LOAD_VEC_A / BK;
|
const uint loadstride_a = gl_WorkGroupSize.x * LOAD_VEC_A * LOAD_VEC_BATCH_A / BK;
|
||||||
const uint loadstride_b = gl_WorkGroupSize.x * LOAD_VEC_B / BK;
|
const uint loadstride_b = gl_WorkGroupSize.x * LOAD_VEC_B * LOAD_VEC_BATCH_B / BK;
|
||||||
|
|
||||||
#ifdef MUL_MAT_ID
|
#ifdef MUL_MAT_ID
|
||||||
#ifdef MUL_MAT_ID_USE_SUBGROUPS
|
#ifdef MUL_MAT_ID_USE_SUBGROUPS
|
||||||
|
|||||||
@@ -14,8 +14,8 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
|
|||||||
FLOAT_TYPE_VEC4 aa = FLOAT_TYPE_VEC4(data_a[idx]);
|
FLOAT_TYPE_VEC4 aa = FLOAT_TYPE_VEC4(data_a[idx]);
|
||||||
buf_a[buf_idx ] = aa.xy;
|
buf_a[buf_idx ] = aa.xy;
|
||||||
buf_a[buf_idx + 1] = aa.zw;
|
buf_a[buf_idx + 1] = aa.zw;
|
||||||
#else // LOAD_VEC_A == 2
|
#else // LOAD_VEC_BATCH_A == 2
|
||||||
const uint idx = pos_a * 2 + col * p.stride_a + row * 2;
|
const uint idx = pos_a + col * p.stride_a + row * 2;
|
||||||
const uint buf_idx = col * SHMEM_STRIDE + row;
|
const uint buf_idx = col * SHMEM_STRIDE + row;
|
||||||
if (idx_m < p.M && block + row * 2 + 1 < end_k) {
|
if (idx_m < p.M && block + row * 2 + 1 < end_k) {
|
||||||
buf_a[buf_idx] = FLOAT_TYPE_VEC2(data_a[idx],
|
buf_a[buf_idx] = FLOAT_TYPE_VEC2(data_a[idx],
|
||||||
@@ -33,8 +33,8 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
|
|||||||
FLOAT_TYPE_VEC4 aa = FLOAT_TYPE_VEC4(TO_FLOAT_TYPE(data_a[idx]));
|
FLOAT_TYPE_VEC4 aa = FLOAT_TYPE_VEC4(TO_FLOAT_TYPE(data_a[idx]));
|
||||||
buf_a[buf_idx ] = aa.xy;
|
buf_a[buf_idx ] = aa.xy;
|
||||||
buf_a[buf_idx + 1] = aa.zw;
|
buf_a[buf_idx + 1] = aa.zw;
|
||||||
#else // LOAD_VEC_A == 2
|
#else // LOAD_VEC_BATCH_A == 2
|
||||||
const uint idx = pos_a * 2 + col * p.stride_a + row * 2;
|
const uint idx = pos_a + col * p.stride_a + row * 2;
|
||||||
const uint buf_idx = col * SHMEM_STRIDE + row;
|
const uint buf_idx = col * SHMEM_STRIDE + row;
|
||||||
if (idx_m < p.M && block + row * 2 + 1 < end_k) {
|
if (idx_m < p.M && block + row * 2 + 1 < end_k) {
|
||||||
buf_a[buf_idx] = FLOAT_TYPE_VEC2(TO_FLOAT_TYPE(data_a[idx]),
|
buf_a[buf_idx] = FLOAT_TYPE_VEC2(TO_FLOAT_TYPE(data_a[idx]),
|
||||||
@@ -500,8 +500,8 @@ void load_b_to_shmem(const uint pos_b, const uint row, const uint col, const uin
|
|||||||
#endif
|
#endif
|
||||||
buf_b[buf_idx + 0] = bb.xy;
|
buf_b[buf_idx + 0] = bb.xy;
|
||||||
buf_b[buf_idx + 1] = bb.zw;
|
buf_b[buf_idx + 1] = bb.zw;
|
||||||
#else // LOAD_VEC_B == 2
|
#else // LOAD_VEC_BATCH_B == 2
|
||||||
const uint idx = pos_b * 2 + col * p.stride_b + row * 2;
|
const uint idx = pos_b + col * p.stride_b + row * 2;
|
||||||
const uint buf_idx = col * SHMEM_STRIDE + row;
|
const uint buf_idx = col * SHMEM_STRIDE + row;
|
||||||
if (idx_n < p.N && block + row * 2 + 1 < end_k) {
|
if (idx_n < p.N && block + row * 2 + 1 < end_k) {
|
||||||
buf_b[buf_idx] = FLOAT_TYPE_VEC2(TO_FLOAT_TYPE(data_b[idx]),
|
buf_b[buf_idx] = FLOAT_TYPE_VEC2(TO_FLOAT_TYPE(data_b[idx]),
|
||||||
@@ -536,17 +536,17 @@ void load_b_to_shmem(const uint pos_b, const uint row, const uint col, const uin
|
|||||||
#endif
|
#endif
|
||||||
buf_b[buf_idx + 0] = bb.xy;
|
buf_b[buf_idx + 0] = bb.xy;
|
||||||
buf_b[buf_idx + 1] = bb.zw;
|
buf_b[buf_idx + 1] = bb.zw;
|
||||||
#else // LOAD_VEC_B == 2
|
#else // LOAD_VEC_BATCH_B == 2
|
||||||
const uint row_i = ic * BN + col;
|
const uint row_i = ic * BN + col;
|
||||||
const uint buf_idx = col * SHMEM_STRIDE + row;
|
const uint buf_idx = col * SHMEM_STRIDE + row;
|
||||||
if (row_i < _ne1 && block + row * 2 + 1 < end_k) {
|
if (row_i < _ne1 && block + row * 2 + 1 < end_k) {
|
||||||
const u16vec2 row_idx = row_ids[col];
|
const u16vec2 row_idx = row_ids[col];
|
||||||
const uint idx = pos_b * 2 + row_idx.y * p.batch_stride_b + (row_idx.x % p.ne11) * p.stride_b + row * 2;
|
const uint idx = pos_b + row_idx.y * p.batch_stride_b + (row_idx.x % p.ne11) * p.stride_b + row * 2;
|
||||||
buf_b[buf_idx] = FLOAT_TYPE_VEC2(TO_FLOAT_TYPE(data_b[idx]),
|
buf_b[buf_idx] = FLOAT_TYPE_VEC2(TO_FLOAT_TYPE(data_b[idx]),
|
||||||
TO_FLOAT_TYPE(data_b[idx + 1]));
|
TO_FLOAT_TYPE(data_b[idx + 1]));
|
||||||
} else if (row_i < _ne1 && block + row * 2 < end_k) {
|
} else if (row_i < _ne1 && block + row * 2 < end_k) {
|
||||||
const u16vec2 row_idx = row_ids[col];
|
const u16vec2 row_idx = row_ids[col];
|
||||||
const uint idx = pos_b * 2 + row_idx.y * p.batch_stride_b + (row_idx.x % p.ne11) * p.stride_b + row * 2;
|
const uint idx = pos_b + row_idx.y * p.batch_stride_b + (row_idx.x % p.ne11) * p.stride_b + row * 2;
|
||||||
buf_b[buf_idx] = FLOAT_TYPE_VEC2(TO_FLOAT_TYPE(data_b[idx]), 0.0f);
|
buf_b[buf_idx] = FLOAT_TYPE_VEC2(TO_FLOAT_TYPE(data_b[idx]), 0.0f);
|
||||||
} else {
|
} else {
|
||||||
buf_b[buf_idx] = FLOAT_TYPE_VEC2(0.0f);
|
buf_b[buf_idx] = FLOAT_TYPE_VEC2(0.0f);
|
||||||
|
|||||||
@@ -454,7 +454,7 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c
|
|||||||
|
|
||||||
std::string data_a_key = "DATA_A_" + to_uppercase(tname);
|
std::string data_a_key = "DATA_A_" + to_uppercase(tname);
|
||||||
// For unaligned, load one at a time for f32/f16, or two at a time for quants
|
// For unaligned, load one at a time for f32/f16, or two at a time for quants
|
||||||
std::string load_vec_a_unaligned = coopmat2 ? "1" : (tname == "f32" || tname == "f16" || tname == "bf16") ? "2" : load_vec_quant;
|
std::string load_vec_a_unaligned = (coopmat2 || tname == "f32" || tname == "f16" || tname == "bf16") ? "1" : load_vec_quant;
|
||||||
// For aligned matmul loads
|
// For aligned matmul loads
|
||||||
std::string load_vec_a = (coopmat2 || tname == "f32" || tname == "f16" || tname == "bf16") ? load_vec : load_vec_quant;
|
std::string load_vec_a = (coopmat2 || tname == "f32" || tname == "f16" || tname == "bf16") ? load_vec : load_vec_quant;
|
||||||
|
|
||||||
|
|||||||
@@ -6231,6 +6231,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
|||||||
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 1056, 1, 193, {1, 1}, {4, 1}, {0, 2, 1, 3}));
|
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 1056, 1, 193, {1, 1}, {4, 1}, {0, 2, 1, 3}));
|
||||||
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 1056, 1, 67, {1, 1}, {4, 1}, {0, 2, 1, 3}));
|
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 1056, 1, 67, {1, 1}, {4, 1}, {0, 2, 1, 3}));
|
||||||
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F32, GGML_TYPE_F32, 16, 32, 32, { 1, 1}, {1, 1}, {0, 1, 2, 3}, true, 3));
|
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F32, GGML_TYPE_F32, 16, 32, 32, { 1, 1}, {1, 1}, {0, 1, 2, 3}, true, 3));
|
||||||
|
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F32, GGML_TYPE_F32, 64, 77, 77, {12,1}, {1,1}));
|
||||||
|
|
||||||
for (auto bs2 : {1,3}) {
|
for (auto bs2 : {1,3}) {
|
||||||
for (auto bs : {1,2,4,8}) {
|
for (auto bs : {1,2,4,8}) {
|
||||||
|
|||||||
Reference in New Issue
Block a user