Mirror of https://github.com/ggml-org/llama.cpp.git (synced 2025-10-30 08:42:00 +00:00)

vulkan: clamp matmul and FA results to the max finite value (#15652)

* vulkan: clamp matmul and FA results to the max finite value
* only clamp for fp16
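
For context (an illustration, not code from this patch): 65504.0 is the largest finite value in IEEE 754 half precision. When matmul or flash-attention results are accumulated or stored as fp16, anything beyond that range overflows to ±inf, and infinities tend to propagate into NaNs in later steps. Clamping to ±65504 keeps the result at the largest finite value instead. A minimal standalone sketch of that saturation, using a plain float as a stand-in for the fp16 accumulator:

    // Minimal sketch, not from the patch: clamp a value to the finite fp16 range.
    // A plain float stands in for the shader-side fp16 accumulator.
    #include <algorithm>
    #include <cstdio>

    int main() {
        const float FP16_MAX = 65504.0f;   // largest finite IEEE 754 half-precision value
        const float samples[] = { 123.0f, 70000.0f, -1.0e6f };
        for (float v : samples) {
            // Without the clamp, 70000 and -1e6 would become +inf / -inf in fp16.
            float clamped = std::clamp(v, -FP16_MAX, FP16_MAX);
            std::printf("%g -> %g\n", v, clamped);
        }
        return 0;
    }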
@@ -323,6 +323,9 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c
     }
 
     base_dict["ACC_TYPE"] = f16acc ? "float16_t" : "float";
+    if (f16acc) {
+        base_dict["ACC_TYPE_MAX"] = "\"float16_t(65504.0)\"";
+    }
 
     if (coopmat) {
         base_dict["COOPMAT"] = "1";
@@ -437,8 +440,12 @@ void process_shaders() {
 
     // flash attention
     for (const auto& f16acc : {false, true}) {
-        std::string acctype = f16acc ? "float16_t" : "float";
-        std::string acctypev4 = f16acc ? "f16vec4" : "vec4";
+        std::map<std::string, std::string> fa_base_dict = base_dict;
+        fa_base_dict["ACC_TYPE"] = f16acc ? "float16_t" : "float";
+        fa_base_dict["ACC_TYPEV4"] = f16acc ? "f16vec4" : "vec4";
+        if (f16acc) {
+            fa_base_dict["ACC_TYPE_MAX"] = "\"float16_t(65504.0)\"";
+        }
 
         for (const auto& tname : type_names) {
             if (tname == "f32") {
@@ -449,30 +456,30 @@ void process_shaders() {
 #if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
             if (tname == "f16") {
                 string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp",
-                    merge_maps(base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}}), true, false, true, f16acc);
+                    merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}}), true, false, true, f16acc);
             } else {
                 std::string data_a_key = "DATA_A_" + to_uppercase(tname);
                 string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp",
-                    merge_maps(base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}, {"DEQUANTFUNC", "dequantFunc"+to_uppercase(tname) }, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), true, false, true, f16acc);
+                    merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"DEQUANTFUNC", "dequantFunc"+to_uppercase(tname) }, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), true, false, true, f16acc);
             }
 #endif
 #if defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
             if (tname == "f16") {
                 string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm1.comp",
-                    merge_maps(base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}, {"ACC_TYPEV4", acctypev4}, {"COOPMAT", "1"}}), true, true, false, f16acc);
+                    merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"COOPMAT", "1"}}), true, true, false, f16acc);
             } else if (tname == "q4_0" || tname == "q8_0") {
                 std::string data_a_key = "DATA_A_" + to_uppercase(tname);
                 string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm1.comp",
-                    merge_maps(base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}, {"ACC_TYPEV4", acctypev4}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname)}, {"COOPMAT", "1"}}), true, true, false, f16acc);
+                    merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname)}, {"COOPMAT", "1"}}), true, true, false, f16acc);
             }
 #endif
             if (tname == "f16") {
                 string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp",
-                    merge_maps(base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}}), true, false, false, f16acc);
+                    merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}}), true, false, false, f16acc);
             } else if (tname == "q4_0" || tname == "q8_0") {
                 std::string data_a_key = "DATA_A_" + to_uppercase(tname);
                 string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp",
-                    merge_maps(base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), true, false, false, f16acc);
+                    merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), true, false, false, f16acc);
             }
         }
     }
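
The second and third hunks are mostly a refactor in the generator: the per-call ACC_TYPE/ACC_TYPEV4 entries move into a shared fa_base_dict copied from base_dict, and ACC_TYPE_MAX is added only when accumulating in fp16, presumably so the shaders can make the clamp conditional on the macro being defined. A compressed standalone sketch of that pattern (merge_maps here is a hypothetical stand-in for the generator's helper, and the printf replaces the real string_to_spv call):

    #include <cstdio>
    #include <map>
    #include <string>

    // Hypothetical stand-in for the generator's merge_maps(): entries from `extra`
    // are layered on top of `base`.
    static std::map<std::string, std::string> merge_maps(std::map<std::string, std::string> base,
                                                          const std::map<std::string, std::string>& extra) {
        for (const auto& kv : extra) {
            base[kv.first] = kv.second;
        }
        return base;
    }

    int main() {
        const std::map<std::string, std::string> base_dict = {{"FLOAT_TYPE", "float"}};

        for (bool f16acc : {false, true}) {
            // Same shape as the patch: copy base_dict once, then add the accumulator defines.
            std::map<std::string, std::string> fa_base_dict = base_dict;
            fa_base_dict["ACC_TYPE"]   = f16acc ? "float16_t" : "float";
            fa_base_dict["ACC_TYPEV4"] = f16acc ? "f16vec4" : "vec4";
            if (f16acc) {
                // Only present for fp16 accumulation; the fp32 path stays unclamped.
                // (The patch stores this value with escaped quotes around it.)
                fa_base_dict["ACC_TYPE_MAX"] = "float16_t(65504.0)";
            }

            // Per-shader maps no longer need to repeat ACC_TYPE / ACC_TYPEV4.
            auto defs = merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}});
            std::printf("f16acc=%d:", f16acc ? 1 : 0);
            for (const auto& kv : defs) {
                std::printf(" %s=%s", kv.first.c_str(), kv.second.c_str());
            }
            std::printf("\n");
        }
        return 0;
    }

In this sketch, as in the patch, the fp32 path never defines ACC_TYPE_MAX, which matches the "only clamp for fp16" line of the commit message.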
Jeff Bolz