vulkan: apply MUL_MAT_ID subgroup optimization to non-coopmat devices (#15524)

* vulkan: use subgroup function for mul_mat_id shader even without coopmat

* vulkan: fix compile warnings

* vulkan: properly check for subgroup size control and require full subgroups for subgroup mul_mat_id

* vulkan: disable subgroup mul_mat_id on devices with subgroups < 16
This commit is contained in:
Ruben Ortlam
2025-08-24 19:36:36 +02:00
committed by GitHub
parent b730706a49
commit 043fb27d38
3 changed files with 282 additions and 195 deletions

View File

@@ -68,6 +68,12 @@ const std::vector<std::string> type_names = {
"bf16",
};
enum MatMulIdType {
NONE,
DEFAULT,
SUBGROUP,
};
namespace {
void execute_command(const std::string& command, std::string& stdout_str, std::string& stderr_str) {
#ifdef _WIN32
@@ -293,7 +299,7 @@ void string_to_spv(const std::string& _name, const std::string& in_fname, const
compiles.push_back(std::async(string_to_spv_func, _name, in_fname, defines, fp16, coopmat, coopmat2, f16acc));
}
void matmul_shaders(bool fp16, bool matmul_id, bool coopmat, bool coopmat2, bool f16acc) {
void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool coopmat2, bool f16acc) {
std::string load_vec = coopmat2 ? "1" : fp16 ? "8" : "4";
std::string aligned_b_type_f32 = coopmat2 ? "float" : fp16 ? "mat2x4" : "vec4";
std::string aligned_b_type_f16 = coopmat2 ? "float16_t" : fp16 ? "f16mat2x4" : "f16vec4";
@@ -303,9 +309,13 @@ void matmul_shaders(bool fp16, bool matmul_id, bool coopmat, bool coopmat2, bool
};
std::string shader_name = "matmul";
if (matmul_id) {
if (matmul_id_type == MatMulIdType::DEFAULT) {
base_dict["MUL_MAT_ID"] = "1";
shader_name = "matmul_id";
} else if (matmul_id_type == MatMulIdType::SUBGROUP) {
base_dict["MUL_MAT_ID"] = "1";
base_dict["MUL_MAT_ID_USE_SUBGROUPS"] = "1";
shader_name = "matmul_id_subgroup";
}
if (fp16) {
@@ -389,7 +399,7 @@ void matmul_shaders(bool fp16, bool matmul_id, bool coopmat, bool coopmat2, bool
}
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
if (!coopmat && !coopmat2 && !matmul_id && (tname == "q4_0" || tname == "q4_1" || tname == "q5_0" || tname == "q5_1" || tname == "q8_0")) {
if (!coopmat && !coopmat2 && matmul_id_type == MatMulIdType::NONE && (tname == "q4_0" || tname == "q4_1" || tname == "q5_0" || tname == "q5_1" || tname == "q8_0")) {
string_to_spv(shader_name + "_" + tname + "_q8_1", "mul_mmq.comp", merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE(tname)}, {data_a_key, "1"}, {"D_TYPE", "float"},}), fp16, coopmat, coopmat2, f16acc);
}
#endif
@@ -401,26 +411,28 @@ void process_shaders() {
std::map<std::string, std::string> base_dict = {{"FLOAT_TYPE", "float"}};
// matmul
for (const auto& matmul_id : {false, true}) {
for (const MatMulIdType& matmul_id_type : {MatMulIdType::NONE, MatMulIdType::DEFAULT, MatMulIdType::SUBGROUP}) {
// No coopmats
// fp32
matmul_shaders(false, matmul_id, false, false, false);
matmul_shaders(false, matmul_id_type, false, false, false);
// fp16, fp32acc and fp16acc
matmul_shaders(true, matmul_id, false, false, false);
matmul_shaders(true, matmul_id, false, false, true);
matmul_shaders(true, matmul_id_type, false, false, false);
matmul_shaders(true, matmul_id_type, false, false, true);
if (matmul_id_type != MatMulIdType::DEFAULT) {
#if defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
// Coopmat, fp32acc and fp16acc
matmul_shaders(true, matmul_id, true, false, false);
matmul_shaders(true, matmul_id, true, false, true);
// Coopmat, fp32acc and fp16acc
matmul_shaders(true, matmul_id_type, true, false, false);
matmul_shaders(true, matmul_id_type, true, false, true);
#endif
#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
// Coopmat2, fp32acc and fp16acc
matmul_shaders(true, matmul_id, false, true, false);
matmul_shaders(true, matmul_id, false, true, true);
// Coopmat2, fp32acc and fp16acc
matmul_shaders(true, matmul_id_type, false, true, false);
matmul_shaders(true, matmul_id_type, false, true, true);
#endif
}
}
// flash attention