mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-10-30 08:42:00 +00:00 
			
		
		
		
	vulkan: clamp matmul and FA results to the max finite value (#15652)
	* vulkan: clamp matmul and FA results to the max finite value
	* only clamp for fp16
This commit is contained in:
		| @@ -334,6 +334,9 @@ void main() { | |||||||
|     [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { |     [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { | ||||||
|         [[unroll]] for (uint32_t r = 0; r < Br; ++r) { |         [[unroll]] for (uint32_t r = 0; r < Br; ++r) { | ||||||
|             Of[r][d] *= Lfrcp[r]; |             Of[r][d] *= Lfrcp[r]; | ||||||
|  | #if defined(ACC_TYPE_MAX) | ||||||
|  |             Of[r][d] = clamp(Of[r][d], -vec4(ACC_TYPE_MAX), vec4(ACC_TYPE_MAX)); | ||||||
|  | #endif | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
|  |  | ||||||
|   | |||||||
| @@ -373,6 +373,9 @@ void main() { | |||||||
|     [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { |     [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { | ||||||
|         [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { |         [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { | ||||||
|             Of[r][d] *= ACC_TYPE(Lfrcp[r]); |             Of[r][d] *= ACC_TYPE(Lfrcp[r]); | ||||||
|  | #if defined(ACC_TYPE_MAX) | ||||||
|  |             Of[r][d] = clamp(Of[r][d], -ACC_TYPE_MAX, ACC_TYPE_MAX); | ||||||
|  | #endif | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
|  |  | ||||||
|   | |||||||
| @@ -283,6 +283,10 @@ void main() { | |||||||
|  |  | ||||||
|     O = Ldiag*O; |     O = Ldiag*O; | ||||||
|  |  | ||||||
|  | #if defined(ACC_TYPE_MAX) | ||||||
|  |     [[unroll]] for (uint i = 0; i < O.length(); ++i) { O[i] = clamp(O[i], -ACC_TYPE_MAX, ACC_TYPE_MAX); } | ||||||
|  | #endif | ||||||
|  |  | ||||||
|     uint32_t o_offset = iq3*p.ne2*p.ne1*HSV; |     uint32_t o_offset = iq3*p.ne2*p.ne1*HSV; | ||||||
|  |  | ||||||
|     coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator> O_D = coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator>(O); |     coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator> O_D = coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator>(O); | ||||||
|   | |||||||
| @@ -111,6 +111,10 @@ void main() { | |||||||
|             } |             } | ||||||
|         } |         } | ||||||
|         O *= L; |         O *= L; | ||||||
|  |  | ||||||
|  |         const float FLT_MAX = uintBitsToFloat(0x7F7FFFFF); | ||||||
|  |         O = clamp(O, -FLT_MAX, FLT_MAX); | ||||||
|  |  | ||||||
|         data_d[iq3 * D * N + D * n + d] = O; |         data_d[iq3 * D * N + D * n + d] = O; | ||||||
|     } |     } | ||||||
| } | } | ||||||
|   | |||||||
| @@ -891,6 +891,20 @@ void main() { | |||||||
|         barrier(); |         barrier(); | ||||||
|     } |     } | ||||||
|  |  | ||||||
|  | #if defined(ACC_TYPE_MAX) | ||||||
|  | #ifdef COOPMAT | ||||||
|  |     [[unroll]] for (uint j = 0; j < cms_per_row * cms_per_col; j++) { | ||||||
|  |         [[unroll]] for (uint i = 0; i < sums[j].length(); ++i) { | ||||||
|  |             sums[j][i] = clamp(sums[j][i], -ACC_TYPE_MAX, ACC_TYPE_MAX); | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  | #else | ||||||
|  |     [[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN; i++) { | ||||||
|  |         sums[i] = clamp(sums[i], -ACC_TYPE_MAX, ACC_TYPE_MAX); | ||||||
|  |     } | ||||||
|  | #endif | ||||||
|  | #endif | ||||||
|  |  | ||||||
|     const uint dr = ir * BM + warp_r * WM; |     const uint dr = ir * BM + warp_r * WM; | ||||||
|     const uint dc = ic * BN + warp_c * WN; |     const uint dc = ic * BN + warp_c * WN; | ||||||
|  |  | ||||||
|   | |||||||
| @@ -349,6 +349,10 @@ void main() { | |||||||
|                 sum = coopMatMulAdd(mat_a, mat_b, sum); |                 sum = coopMatMulAdd(mat_a, mat_b, sum); | ||||||
|                 block_k += BK; |                 block_k += BK; | ||||||
|             } |             } | ||||||
|  | #if defined(ACC_TYPE_MAX) | ||||||
|  |             [[unroll]] for (uint i = 0; i < sum.length(); ++i) { sum[i] = clamp(sum[i], -ACC_TYPE_MAX, ACC_TYPE_MAX); } | ||||||
|  | #endif | ||||||
|  |  | ||||||
|             coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BNover4, gl_MatrixUseAccumulator> mat_d = coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BNover4, gl_MatrixUseAccumulator>(sum); |             coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BNover4, gl_MatrixUseAccumulator> mat_d = coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BNover4, gl_MatrixUseAccumulator>(sum); | ||||||
|  |  | ||||||
|             coopMatStoreTensorNV(mat_d, data_d, pos_d, sliceTensorLayoutNV(tensorLayoutD, ic * BN, BNover4, ir * BM, BM), tensorViewTranspose); |             coopMatStoreTensorNV(mat_d, data_d, pos_d, sliceTensorLayoutNV(tensorLayoutD, ic * BN, BNover4, ir * BM, BM), tensorViewTranspose); | ||||||
| @@ -388,6 +392,10 @@ void main() { | |||||||
|                 sum = coopMatMulAdd(mat_a, mat_b, sum); |                 sum = coopMatMulAdd(mat_a, mat_b, sum); | ||||||
|                 block_k += BK; |                 block_k += BK; | ||||||
|             } |             } | ||||||
|  | #if defined(ACC_TYPE_MAX) | ||||||
|  |             [[unroll]] for (uint i = 0; i < sum.length(); ++i) { sum[i] = clamp(sum[i], -ACC_TYPE_MAX, ACC_TYPE_MAX); } | ||||||
|  | #endif | ||||||
|  |  | ||||||
|             coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BNover2, gl_MatrixUseAccumulator> mat_d = coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BNover2, gl_MatrixUseAccumulator>(sum); |             coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BNover2, gl_MatrixUseAccumulator> mat_d = coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BNover2, gl_MatrixUseAccumulator>(sum); | ||||||
|  |  | ||||||
|             coopMatStoreTensorNV(mat_d, data_d, pos_d, sliceTensorLayoutNV(tensorLayoutD, ic * BN, BNover2, ir * BM, BM), tensorViewTranspose); |             coopMatStoreTensorNV(mat_d, data_d, pos_d, sliceTensorLayoutNV(tensorLayoutD, ic * BN, BNover2, ir * BM, BM), tensorViewTranspose); | ||||||
| @@ -428,6 +436,10 @@ void main() { | |||||||
|                 sum = coopMatMulAdd(mat_a, mat_b, sum); |                 sum = coopMatMulAdd(mat_a, mat_b, sum); | ||||||
|                 block_k += BK; |                 block_k += BK; | ||||||
|             } |             } | ||||||
|  | #if defined(ACC_TYPE_MAX) | ||||||
|  |             [[unroll]] for (uint i = 0; i < sum.length(); ++i) { sum[i] = clamp(sum[i], -ACC_TYPE_MAX, ACC_TYPE_MAX); } | ||||||
|  | #endif | ||||||
|  |  | ||||||
|             coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator> mat_d = coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator>(sum); |             coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator> mat_d = coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator>(sum); | ||||||
|  |  | ||||||
|             coopMatStoreTensorNV(mat_d, data_d, pos_d, sliceTensorLayoutNV(tensorLayoutD, ic * BN, BN, ir * BM, BM), tensorViewTranspose); |             coopMatStoreTensorNV(mat_d, data_d, pos_d, sliceTensorLayoutNV(tensorLayoutD, ic * BN, BN, ir * BM, BM), tensorViewTranspose); | ||||||
| @@ -485,6 +497,9 @@ void main() { | |||||||
|                 sum = coopMatMulAdd(mat_a, mat_b, sum); |                 sum = coopMatMulAdd(mat_a, mat_b, sum); | ||||||
|             } |             } | ||||||
|         } |         } | ||||||
|  | #if defined(ACC_TYPE_MAX) | ||||||
|  |         [[unroll]] for (uint i = 0; i < sum.length(); ++i) { sum[i] = clamp(sum[i], -ACC_TYPE_MAX, ACC_TYPE_MAX); } | ||||||
|  | #endif | ||||||
|  |  | ||||||
|         // Convert from ACC_TYPE to D_TYPE |         // Convert from ACC_TYPE to D_TYPE | ||||||
|         coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator> mat_d; |         coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator> mat_d; | ||||||
|   | |||||||
| @@ -323,6 +323,9 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c | |||||||
|     } |     } | ||||||
|  |  | ||||||
|     base_dict["ACC_TYPE"] = f16acc ? "float16_t" : "float"; |     base_dict["ACC_TYPE"] = f16acc ? "float16_t" : "float"; | ||||||
|  |     if (f16acc) { | ||||||
|  |         base_dict["ACC_TYPE_MAX"] = "\"float16_t(65504.0)\""; | ||||||
|  |     } | ||||||
|  |  | ||||||
|     if (coopmat) { |     if (coopmat) { | ||||||
|         base_dict["COOPMAT"] = "1"; |         base_dict["COOPMAT"] = "1"; | ||||||
| @@ -437,8 +440,12 @@ void process_shaders() { | |||||||
|  |  | ||||||
|     // flash attention |     // flash attention | ||||||
|     for (const auto& f16acc : {false, true}) { |     for (const auto& f16acc : {false, true}) { | ||||||
|         std::string acctype = f16acc ? "float16_t" : "float"; |         std::map<std::string, std::string> fa_base_dict = base_dict; | ||||||
|         std::string acctypev4 = f16acc ? "f16vec4" : "vec4"; |         fa_base_dict["ACC_TYPE"] = f16acc ? "float16_t" : "float"; | ||||||
|  |         fa_base_dict["ACC_TYPEV4"] = f16acc ? "f16vec4" : "vec4"; | ||||||
|  |         if (f16acc) { | ||||||
|  |             fa_base_dict["ACC_TYPE_MAX"] = "\"float16_t(65504.0)\""; | ||||||
|  |         } | ||||||
|  |  | ||||||
|         for (const auto& tname : type_names) { |         for (const auto& tname : type_names) { | ||||||
|             if (tname == "f32") { |             if (tname == "f32") { | ||||||
| @@ -449,30 +456,30 @@ void process_shaders() { | |||||||
| #if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) | #if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) | ||||||
|             if (tname == "f16") { |             if (tname == "f16") { | ||||||
|                 string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp", |                 string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp", | ||||||
|                     merge_maps(base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}}), true, false, true, f16acc); |                     merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}}), true, false, true, f16acc); | ||||||
|             } else { |             } else { | ||||||
|                 std::string data_a_key = "DATA_A_" + to_uppercase(tname); |                 std::string data_a_key = "DATA_A_" + to_uppercase(tname); | ||||||
|                 string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp", |                 string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp", | ||||||
|                     merge_maps(base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}, {"DEQUANTFUNC", "dequantFunc"+to_uppercase(tname) }, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), true, false, true, f16acc); |                     merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"DEQUANTFUNC", "dequantFunc"+to_uppercase(tname) }, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), true, false, true, f16acc); | ||||||
|             } |             } | ||||||
| #endif | #endif | ||||||
| #if defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT) | #if defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT) | ||||||
|             if (tname == "f16") { |             if (tname == "f16") { | ||||||
|                 string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm1.comp", |                 string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm1.comp", | ||||||
|                     merge_maps(base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}, {"ACC_TYPEV4", acctypev4}, {"COOPMAT", "1"}}), true, true, false, f16acc); |                     merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"COOPMAT", "1"}}), true, true, false, f16acc); | ||||||
|             } else if (tname == "q4_0" || tname == "q8_0") { |             } else if (tname == "q4_0" || tname == "q8_0") { | ||||||
|                 std::string data_a_key = "DATA_A_" + to_uppercase(tname); |                 std::string data_a_key = "DATA_A_" + to_uppercase(tname); | ||||||
|                 string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm1.comp", |                 string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm1.comp", | ||||||
|                     merge_maps(base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}, {"ACC_TYPEV4", acctypev4}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname)}, {"COOPMAT", "1"}}), true, true, false, f16acc); |                     merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname)}, {"COOPMAT", "1"}}), true, true, false, f16acc); | ||||||
|             } |             } | ||||||
| #endif | #endif | ||||||
|             if (tname == "f16") { |             if (tname == "f16") { | ||||||
|                 string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp", |                 string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp", | ||||||
|                     merge_maps(base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}}), true, false, false, f16acc); |                     merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}}), true, false, false, f16acc); | ||||||
|             } else if (tname == "q4_0" || tname == "q8_0") { |             } else if (tname == "q4_0" || tname == "q8_0") { | ||||||
|                 std::string data_a_key = "DATA_A_" + to_uppercase(tname); |                 std::string data_a_key = "DATA_A_" + to_uppercase(tname); | ||||||
|                 string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp", |                 string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp", | ||||||
|                     merge_maps(base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), true, false, false, f16acc); |                     merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), true, false, false, f16acc); | ||||||
|             } |             } | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
|   | |||||||
		Reference in new issue
		Block a user
		Author: Jeff Bolz