vulkan: clamp matmul and FA results to the max finite value (#15652)

* vulkan: clamp matmul and FA results to the max finite value

* only clamp for fp16
This commit is contained in:
Jeff Bolz
2025-08-31 01:27:57 -05:00
committed by GitHub
parent 4d74393bcc
commit 94e82c7ead
7 changed files with 58 additions and 8 deletions

View File

@@ -323,6 +323,9 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c
}
base_dict["ACC_TYPE"] = f16acc ? "float16_t" : "float";
if (f16acc) {
base_dict["ACC_TYPE_MAX"] = "\"float16_t(65504.0)\"";
}
if (coopmat) {
base_dict["COOPMAT"] = "1";
@@ -437,8 +440,12 @@ void process_shaders() {
// flash attention
for (const auto& f16acc : {false, true}) {
std::string acctype = f16acc ? "float16_t" : "float";
std::string acctypev4 = f16acc ? "f16vec4" : "vec4";
std::map<std::string, std::string> fa_base_dict = base_dict;
fa_base_dict["ACC_TYPE"] = f16acc ? "float16_t" : "float";
fa_base_dict["ACC_TYPEV4"] = f16acc ? "f16vec4" : "vec4";
if (f16acc) {
fa_base_dict["ACC_TYPE_MAX"] = "\"float16_t(65504.0)\"";
}
for (const auto& tname : type_names) {
if (tname == "f32") {
@@ -449,30 +456,30 @@ void process_shaders() {
#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
if (tname == "f16") {
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp",
merge_maps(base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}}), true, false, true, f16acc);
merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}}), true, false, true, f16acc);
} else {
std::string data_a_key = "DATA_A_" + to_uppercase(tname);
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp",
merge_maps(base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}, {"DEQUANTFUNC", "dequantFunc"+to_uppercase(tname) }, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), true, false, true, f16acc);
merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"DEQUANTFUNC", "dequantFunc"+to_uppercase(tname) }, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), true, false, true, f16acc);
}
#endif
#if defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
if (tname == "f16") {
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm1.comp",
merge_maps(base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}, {"ACC_TYPEV4", acctypev4}, {"COOPMAT", "1"}}), true, true, false, f16acc);
merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"COOPMAT", "1"}}), true, true, false, f16acc);
} else if (tname == "q4_0" || tname == "q8_0") {
std::string data_a_key = "DATA_A_" + to_uppercase(tname);
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm1.comp",
merge_maps(base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}, {"ACC_TYPEV4", acctypev4}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname)}, {"COOPMAT", "1"}}), true, true, false, f16acc);
merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname)}, {"COOPMAT", "1"}}), true, true, false, f16acc);
}
#endif
if (tname == "f16") {
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp",
merge_maps(base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}}), true, false, false, f16acc);
merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}}), true, false, false, f16acc);
} else if (tname == "q4_0" || tname == "q8_0") {
std::string data_a_key = "DATA_A_" + to_uppercase(tname);
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp",
merge_maps(base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), true, false, false, f16acc);
merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), true, false, false, f16acc);
}
}
}