Mirror of https://github.com/ggml-org/llama.cpp.git (synced 2025-10-30 08:42:00 +00:00)
	vulkan: scalar flash attention implementation (#13324)
* vulkan: scalar flash attention implementation
* vulkan: always use fp32 for scalar flash attention
* vulkan: use vector loads in scalar flash attention shader
* vulkan: remove PV matrix, helps with register usage
* vulkan: reduce register usage in scalar FA, but perf may be slightly worse
* vulkan: load each Q value once. optimize O reduction. more tuning
* vulkan: support q4_0/q8_0 KV in scalar FA
* CI: increase timeout to accommodate newly-supported tests
* vulkan: for scalar FA, select between 1 and 8 rows
* vulkan: avoid using Float16 capability in scalar FA
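The scalar path computes the same fused attention as the coopmat2 path, but accumulates in fp32 without cooperative-matrix hardware. Below is a minimal CPU-side sketch of the per-row online-softmax accumulation that this kind of shader performs. It is illustrative only: the real shader tiles K/V, vectorizes loads, splits D across a subgroup and reduces O across the workgroup. All shapes and names here are assumptions, not code from this commit.

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

// Toy single-query flash attention with fp32 online-softmax accumulation.
static std::vector<float> flash_attn_row(const std::vector<float>& q,              // [D]
                                         const std::vector<std::vector<float>>& K, // [KV][D]
                                         const std::vector<std::vector<float>>& V, // [KV][D]
                                         float scale) {
    const size_t D  = q.size();
    const size_t KV = K.size();
    std::vector<float> O(D, 0.0f);   // running weighted sum of V rows
    float M = -INFINITY;             // running max of the scaled scores
    float L = 0.0f;                  // running sum of exp(score - M)

    for (size_t j = 0; j < KV; ++j) {
        float s = 0.0f;
        for (size_t d = 0; d < D; ++d) {
            s += q[d] * K[j][d];     // dot(Q row, K row)
        }
        s *= scale;

        const float M_new = std::max(M, s);
        const float corr  = std::exp(M - M_new);   // rescale previous accumulators
        const float p     = std::exp(s - M_new);

        for (size_t d = 0; d < D; ++d) {
            O[d] = O[d] * corr + p * V[j][d];
        }
        L = L * corr + p;
        M = M_new;
    }
    for (size_t d = 0; d < D; ++d) {
        O[d] /= L;                   // final softmax normalization
    }
    return O;
}

int main() {
    std::vector<float> q = {1.0f, 0.0f};
    std::vector<std::vector<float>> K = {{1.0f, 0.0f}, {0.0f, 1.0f}};
    std::vector<std::vector<float>> V = {{1.0f, 2.0f}, {3.0f, 4.0f}};
    std::vector<float> O = flash_attn_row(q, K, V, 1.0f);
    std::printf("%f %f\n", O[0], O[1]);   // softmax([1,0]) applied to the V rows
    return 0;
}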
@@ -275,6 +275,7 @@ struct vk_device_struct {
    bool prefer_host_memory;
    bool float_controls_rte_fp16;
    bool subgroup_add;
    bool subgroup_shuffle;

    bool integer_dot_product;

@@ -402,12 +403,20 @@ struct vk_device_struct {
    vk_pipeline pipeline_conv2d_dw_cwhn_f32;

    // [2][2][2] is for {f16acc,f32acc}x{large,small_rows}x{unaligned, aligned}
    vk_pipeline pipeline_flash_attn_f32_f16_D64_cm2[GGML_TYPE_COUNT][2][2][2];
    vk_pipeline pipeline_flash_attn_f32_f16_D80_cm2[GGML_TYPE_COUNT][2][2][2];
    vk_pipeline pipeline_flash_attn_f32_f16_D96_cm2[GGML_TYPE_COUNT][2][2][2];
    vk_pipeline pipeline_flash_attn_f32_f16_D112_cm2[GGML_TYPE_COUNT][2][2][2];
    vk_pipeline pipeline_flash_attn_f32_f16_D128_cm2[GGML_TYPE_COUNT][2][2][2];
    vk_pipeline pipeline_flash_attn_f32_f16_D256_cm2[GGML_TYPE_COUNT][2][2][2];

    vk_pipeline pipeline_flash_attn_f32_f16_D64[GGML_TYPE_COUNT][2][2][2];
    vk_pipeline pipeline_flash_attn_f32_f16_D80[GGML_TYPE_COUNT][2][2][2];
    vk_pipeline pipeline_flash_attn_f32_f16_D96[GGML_TYPE_COUNT][2][2][2];
    vk_pipeline pipeline_flash_attn_f32_f16_D112[GGML_TYPE_COUNT][2][2][2];
    vk_pipeline pipeline_flash_attn_f32_f16_D128[GGML_TYPE_COUNT][2][2][2];
    vk_pipeline pipeline_flash_attn_f32_f16_D256[GGML_TYPE_COUNT][2][2][2];

    vk_pipeline pipeline_flash_attn_split_k_reduce;

    std::unordered_map<std::string, vk_pipeline_ref> pipelines;
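As the comment in the struct notes, the trailing [2][2][2] of each pipeline array selects {f16acc,f32acc} x {large,small_rows} x {unaligned,aligned}; lookups later in the commit use the same order (e.g. [k->type][f32acc][small_rows] and then [aligned]). A toy sketch of that indexing convention, with the pipelines replaced by name strings since this does not touch Vulkan:

#include <cstdio>

// Illustrative only: mirrors the index order of the 4-D pipeline arrays above.
int main() {
    const char* names[2][2][2] = {};
    names[0][0][0] = "f16acc_large_unaligned";
    names[0][0][1] = "f16acc_large_aligned";
    names[0][1][0] = "f16acc_smallrows_unaligned";
    names[0][1][1] = "f16acc_smallrows_aligned";
    names[1][0][0] = "f32acc_large_unaligned";
    names[1][0][1] = "f32acc_large_aligned";
    names[1][1][0] = "f32acc_smallrows_unaligned";
    names[1][1][1] = "f32acc_smallrows_aligned";

    const bool f32acc = true, small_rows = true, aligned = false;
    std::printf("%s\n", names[f32acc][small_rows][aligned]);  // f32acc_smallrows_unaligned
    return 0;
}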
@@ -1581,13 +1590,29 @@ static void ggml_vk_wait_events(vk_context& ctx, std::vector<vk::Event>&& events

// number of rows/cols for flash attention shader
static constexpr uint32_t flash_attention_num_small_rows = 32;
static std::array<uint32_t, 2> fa_rows_cols(uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) {
static constexpr uint32_t scalar_flash_attention_num_small_rows = 1;
static constexpr uint32_t scalar_flash_attention_num_large_rows = 8;

static uint32_t get_fa_num_small_rows(bool scalar) {
    return scalar ? scalar_flash_attention_num_small_rows : flash_attention_num_small_rows;
}

static std::array<uint32_t, 2> fa_rows_cols(bool scalar, uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) {
    GGML_UNUSED(clamp);

    if (scalar) {
        if (small_rows) {
            return {scalar_flash_attention_num_small_rows, 64};
        } else {
            return {scalar_flash_attention_num_large_rows, 32};
        }
    }

    // small rows, large cols
    if (small_rows) {
        return {flash_attention_num_small_rows, 64};
        return {get_fa_num_small_rows(scalar), 32};
    }

    // small cols to reduce register count
    if (ggml_is_quantized(type) || D == 256) {
        return {64, 32};
@@ -1882,65 +1907,66 @@ static void ggml_vk_load_shaders(vk_device& device) {
                                      parameter_count, wg_denoms, specialization_constants, disable_robustness, require_full_subgroups, required_subgroup_size));
    };

    auto const &fa_wg_denoms = [&](bool scalar, uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) -> std::array<uint32_t, 3> {
        return {fa_rows_cols(scalar, D, clamp, type, small_rows)[0], 1, 1};
    };

    auto const &fa_spec_constants = [&](bool scalar, uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) -> std::vector<uint32_t> {
        // For large number of rows, 128 invocations seems to work best.
        // For small number of rows (e.g. N==1), 256 works better. But matrix granularity for 256 is 32, so we
        // can't use 256 for D==80.
        // For scalar, use 128 (arbitrary)
        uint32_t wg_size = scalar ? 128 : ((small_rows && (D % 32) == 0) ? 256 : 128);
        auto rows_cols = fa_rows_cols(scalar, D, clamp, type, small_rows);

        // D_split can't be larger than a subgroup because we use subgroupShuffle to reduce it.
        // D_split can't be larger than the LSB of D divided by 4 due to vectorization in the shader.
        const uint32_t D_lsb = D ^ (D & (D-1));
        uint32_t D_split = std::min(std::min(device->subgroup_size, 8u), D_lsb / 4);

        // mask dim1 is padded to 64, we rely on this to avoid clamping mask loads
        GGML_ASSERT((GGML_KQ_MASK_PAD % rows_cols[0]) == 0);
        return {wg_size, rows_cols[0], rows_cols[1], (D), clamp, D_split};
    };

#define CREATE_FA2(TYPE, NAMELC, SCALAR, SUFFIX, D) \
        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][0][0], "flash_attn_f32_f16_D" #D "_f16acc"         #NAMELC #SUFFIX,           flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len,  flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data,  "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(SCALAR, D,1,TYPE,false), fa_spec_constants(SCALAR, D,1,TYPE,false), 1, true);     \
        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][0][1], "flash_attn_f32_f16_D" #D "_aligned_f16acc" #NAMELC #SUFFIX,           flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len,  flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data,  "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(SCALAR, D,0,TYPE,false), fa_spec_constants(SCALAR, D,0,TYPE,false), fa_rows_cols(SCALAR,D,0,TYPE,false)[1], true);     \
        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][0][0], "flash_attn_f32_f16_D" #D "_f32acc"         #NAMELC #SUFFIX,           flash_attn_f32_f16_ ## NAMELC ##     SUFFIX ## _len,         flash_attn_f32_f16_ ## NAMELC ##     SUFFIX ## _data,         "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(SCALAR, D,1,TYPE,false), fa_spec_constants(SCALAR, D,1,TYPE,false), 1, true);     \
        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][0][1], "flash_attn_f32_f16_D" #D "_aligned_f32acc" #NAMELC #SUFFIX,           flash_attn_f32_f16_ ## NAMELC ##     SUFFIX ## _len,         flash_attn_f32_f16_ ## NAMELC ##     SUFFIX ## _data,         "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(SCALAR, D,0,TYPE,false), fa_spec_constants(SCALAR, D,0,TYPE,false), fa_rows_cols(SCALAR,D,0,TYPE,false)[1], true);     \
        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][1][0], "flash_attn_f32_f16_D" #D "_f16acc_smallrows"         #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len,  flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data,  "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(SCALAR, D,1,TYPE,true), fa_spec_constants(SCALAR, D,1,TYPE,true), 1, true);     \
        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][1][1], "flash_attn_f32_f16_D" #D "_aligned_f16acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len,  flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data,  "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(SCALAR, D,0,TYPE,true), fa_spec_constants(SCALAR, D,0,TYPE,true), fa_rows_cols(SCALAR,D,0,TYPE,true)[1], true);     \
        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][1][0], "flash_attn_f32_f16_D" #D "_f32acc_smallrows"         #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ##     SUFFIX ## _len,         flash_attn_f32_f16_ ## NAMELC ##     SUFFIX ## _data,         "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(SCALAR, D,1,TYPE,true), fa_spec_constants(SCALAR, D,1,TYPE,true), 1, true);     \
        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][1][1], "flash_attn_f32_f16_D" #D "_aligned_f32acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ##     SUFFIX ## _len,         flash_attn_f32_f16_ ## NAMELC ##     SUFFIX ## _data,         "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(SCALAR, D,0,TYPE,true), fa_spec_constants(SCALAR, D,0,TYPE,true), fa_rows_cols(SCALAR,D,0,TYPE,true)[1], true);     \

#define CREATE_FA(TYPE, NAMELC, SCALAR, SUFFIX) \
        CREATE_FA2(TYPE, NAMELC, SCALAR, SUFFIX, 64) \
        CREATE_FA2(TYPE, NAMELC, SCALAR, SUFFIX, 80) \
        CREATE_FA2(TYPE, NAMELC, SCALAR, SUFFIX, 96) \
        CREATE_FA2(TYPE, NAMELC, SCALAR, SUFFIX, 112) \
        CREATE_FA2(TYPE, NAMELC, SCALAR, SUFFIX, 128) \
        CREATE_FA2(TYPE, NAMELC, SCALAR, SUFFIX, 256)

    CREATE_FA(GGML_TYPE_F16, f16, true, )
    CREATE_FA(GGML_TYPE_Q4_0, q4_0, true, )
    CREATE_FA(GGML_TYPE_Q8_0, q8_0, true, )
#if defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
    if (device->coopmat2) {

        auto const &fa_wg_denoms = [&](uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) -> std::array<uint32_t, 3> {
            return {fa_rows_cols(D, clamp, type, small_rows)[0], 1, 1};
        };

        auto const &fa_spec_constants = [&](uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) -> std::vector<uint32_t> {
            // For large number of rows, 128 invocations seems to work best.
            // For small number of rows (e.g. N==1), 256 works better. But matrix granularity for 256 is 32, so we
            // can't use 256 for D==80.
            uint32_t wg_size = (small_rows && (D % 32) == 0) ? 256 : 128;
            auto rows_cols = fa_rows_cols(D, clamp, type, small_rows);
            // mask dim1 is padded to 64, we rely on this to avoid clamping mask loads
            GGML_ASSERT((GGML_KQ_MASK_PAD % rows_cols[0]) == 0);
            return {wg_size, rows_cols[0], rows_cols[1], (D), clamp};
        };

#define CREATE_FA2(TYPE, NAMELC, D) \
        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][0][0][0], "flash_attn_f32_f16_D" #D "_f16acc"         #NAMELC,           flash_attn_f32_f16_ ## NAMELC ## _f16acc_cm2_len,  flash_attn_f32_f16_ ## NAMELC ## _f16acc_cm2_data,  "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(D,1,TYPE,false), fa_spec_constants(D,1,TYPE,false), 1);     \
        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][0][0][1], "flash_attn_f32_f16_D" #D "_aligned_f16acc" #NAMELC,           flash_attn_f32_f16_ ## NAMELC ## _f16acc_cm2_len,  flash_attn_f32_f16_ ## NAMELC ## _f16acc_cm2_data,  "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(D,0,TYPE,false), fa_spec_constants(D,0,TYPE,false), fa_rows_cols(D,0,TYPE,false)[1]);     \
        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][1][0][0], "flash_attn_f32_f16_D" #D "_f32acc"         #NAMELC,           flash_attn_f32_f16_ ## NAMELC ## _cm2_len,         flash_attn_f32_f16_ ## NAMELC ## _cm2_data,         "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(D,1,TYPE,false), fa_spec_constants(D,1,TYPE,false), 1);     \
        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][1][0][1], "flash_attn_f32_f16_D" #D "_aligned_f32acc" #NAMELC,           flash_attn_f32_f16_ ## NAMELC ## _cm2_len,         flash_attn_f32_f16_ ## NAMELC ## _cm2_data,         "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(D,0,TYPE,false), fa_spec_constants(D,0,TYPE,false), fa_rows_cols(D,0,TYPE,false)[1]);     \
        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][0][1][0], "flash_attn_f32_f16_D" #D "_f16acc_smallrows"         #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc_cm2_len,  flash_attn_f32_f16_ ## NAMELC ## _f16acc_cm2_data,  "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(D,1,TYPE,true), fa_spec_constants(D,1,TYPE,true), 1);     \
        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][0][1][1], "flash_attn_f32_f16_D" #D "_aligned_f16acc_smallrows" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc_cm2_len,  flash_attn_f32_f16_ ## NAMELC ## _f16acc_cm2_data,  "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(D,0,TYPE,true), fa_spec_constants(D,0,TYPE,true), fa_rows_cols(D,0,TYPE,true)[1]);     \
        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][1][1][0], "flash_attn_f32_f16_D" #D "_f32acc_smallrows"         #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _cm2_len,         flash_attn_f32_f16_ ## NAMELC ## _cm2_data,         "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(D,1,TYPE,true), fa_spec_constants(D,1,TYPE,true), 1);     \
        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][1][1][1], "flash_attn_f32_f16_D" #D "_aligned_f32acc_smallrows" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _cm2_len,         flash_attn_f32_f16_ ## NAMELC ## _cm2_data,         "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(D,0,TYPE,true), fa_spec_constants(D,0,TYPE,true), fa_rows_cols(D,0,TYPE,true)[1]);     \

#define CREATE_FA(TYPE, NAMELC) \
        CREATE_FA2(TYPE, NAMELC, 64) \
        CREATE_FA2(TYPE, NAMELC, 80) \
        CREATE_FA2(TYPE, NAMELC, 96) \
        CREATE_FA2(TYPE, NAMELC, 112) \
        CREATE_FA2(TYPE, NAMELC, 128) \
        CREATE_FA2(TYPE, NAMELC, 256)

        CREATE_FA(GGML_TYPE_F16, f16)
        CREATE_FA(GGML_TYPE_Q4_0, q4_0)
        CREATE_FA(GGML_TYPE_Q4_1, q4_1)
        CREATE_FA(GGML_TYPE_Q5_0, q5_0)
        CREATE_FA(GGML_TYPE_Q5_1, q5_1)
        CREATE_FA(GGML_TYPE_Q8_0, q8_0)
        // K dequants currently disabled because D dimension is rounded up to 256 and runs inefficiently
        //CREATE_FA(GGML_TYPE_Q2_K, q2_k)
        //CREATE_FA(GGML_TYPE_Q3_K, q3_k)
        //CREATE_FA(GGML_TYPE_Q4_K, q4_k)
        //CREATE_FA(GGML_TYPE_Q5_K, q5_k)
        //CREATE_FA(GGML_TYPE_Q6_K, q6_k)
        //CREATE_FA(GGML_TYPE_IQ1_S, iq1_s)
        //CREATE_FA(GGML_TYPE_IQ1_M, iq1_m)
        //CREATE_FA(GGML_TYPE_IQ2_XXS, iq2_xxs)
        //CREATE_FA(GGML_TYPE_IQ2_XS, iq2_xs)
        //CREATE_FA(GGML_TYPE_IQ2_S, iq2_s)
        //CREATE_FA(GGML_TYPE_IQ3_XXS, iq3_xxs)
        //CREATE_FA(GGML_TYPE_IQ3_S, iq3_s)
        //CREATE_FA(GGML_TYPE_IQ4_XS, iq4_xs)
        CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl)
        CREATE_FA(GGML_TYPE_F16, f16, false, _cm2)
        CREATE_FA(GGML_TYPE_Q4_0, q4_0, false, _cm2)
        CREATE_FA(GGML_TYPE_Q4_1, q4_1, false, _cm2)
        CREATE_FA(GGML_TYPE_Q5_0, q5_0, false, _cm2)
        CREATE_FA(GGML_TYPE_Q5_1, q5_1, false, _cm2)
        CREATE_FA(GGML_TYPE_Q8_0, q8_0, false, _cm2)
        CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl, false, _cm2)
    }
#endif
#undef CREATE_FA2
#undef CREATE_FA

#if defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
    if (device->coopmat2) {

        // Create 6 variants, {s,m,l}x{unaligned,aligned}
#define CREATE_MM(PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT) \
        ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1);   \
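The D_split chosen in fa_spec_constants above is bounded both by the subgroup (capped at 8 here, since it is reduced with subgroupShuffle) and by the lowest set bit of D divided by 4, because of the 4-wide vector loads. A small standalone sketch of that selection, with the subgroup size assumed to be 32:

#include <algorithm>
#include <cstdint>
#include <cstdio>

// Reproduces the D_split selection logic; the subgroup size is an assumption,
// the real value comes from the Vulkan device properties.
static uint32_t pick_D_split(uint32_t D, uint32_t subgroup_size) {
    const uint32_t D_lsb = D ^ (D & (D - 1));          // isolate the lowest set bit of D
    return std::min(std::min(subgroup_size, 8u), D_lsb / 4);
}

int main() {
    const uint32_t subgroup_size = 32;                 // assumed
    for (uint32_t D : {64u, 80u, 96u, 112u, 128u, 256u}) {
        std::printf("D=%3u -> D_split=%u\n", D, pick_D_split(D, subgroup_size));
    }
    // D=64  -> 8  (lowest bit 64, 64/4=16, capped at 8)
    // D=80  -> 4  (lowest bit 16, 16/4=4)
    // D=96  -> 8  (lowest bit 32, 32/4=8)
    // D=112 -> 4  (lowest bit 16, 16/4=4)
    // D=128 -> 8
    // D=256 -> 8
    return 0;
}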
@@ -2837,6 +2863,9 @@ static vk_device ggml_vk_get_device(size_t idx) {
        device->subgroup_add = (vk11_props.subgroupSupportedStages & vk::ShaderStageFlagBits::eCompute) &&
                               (vk11_props.subgroupSupportedOperations & vk::SubgroupFeatureFlagBits::eArithmetic);

        device->subgroup_shuffle = (vk11_props.subgroupSupportedStages & vk::ShaderStageFlagBits::eCompute) &&
                                   (vk11_props.subgroupSupportedOperations & vk::SubgroupFeatureFlagBits::eShuffle);

        const bool force_disable_f16 = getenv("GGML_VK_DISABLE_F16") != nullptr;

        device->fp16 = !force_disable_f16 && fp16_storage && fp16_compute;
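The new subgroup_shuffle flag mirrors the existing subgroup_add check against the Vulkan 1.1 subgroup properties. A standalone Vulkan-Hpp sketch of the same query (minimal instance setup, error handling omitted; this is not code from this commit):

#include <vulkan/vulkan.hpp>
#include <cstdio>

int main() {
    vk::ApplicationInfo app_info("subgroup-check", 1, nullptr, 0, VK_API_VERSION_1_1);
    vk::InstanceCreateInfo inst_info({}, &app_info);
    vk::Instance instance = vk::createInstance(inst_info);

    for (const vk::PhysicalDevice& pdev : instance.enumeratePhysicalDevices()) {
        // Chain VkPhysicalDeviceSubgroupProperties behind the base properties struct.
        vk::PhysicalDeviceProperties2 props2;
        vk::PhysicalDeviceSubgroupProperties subgroup_props;
        props2.pNext = &subgroup_props;
        pdev.getProperties2(&props2);

        const bool shuffle_in_compute =
            (subgroup_props.supportedStages & vk::ShaderStageFlagBits::eCompute) &&
            (subgroup_props.supportedOperations & vk::SubgroupFeatureFlagBits::eShuffle);

        std::printf("%s: subgroup shuffle in compute: %s (subgroup size %u)\n",
                    props2.properties.deviceName.data(),
                    shuffle_in_compute ? "yes" : "no",
                    subgroup_props.subgroupSize);
    }
    instance.destroy();
    return 0;
}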
@@ -5709,20 +5738,57 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
    assert(q->type == GGML_TYPE_F32);
    assert(k->type == v->type);

    bool scalar = !ctx->device->coopmat2;

    uint32_t gqa_ratio = 1;
    uint32_t qk_ratio = neq2 / nek2;
    uint32_t workgroups_x = (uint32_t)neq1;
    uint32_t workgroups_y = (uint32_t)neq2;
    uint32_t workgroups_z = (uint32_t)neq3;

    // For scalar FA, we can use the "large" size to accommodate gqa.
    // For coopmat FA, we always use the small size (which is still pretty large for gqa).
    const uint32_t max_gqa = scalar ? scalar_flash_attention_num_large_rows : get_fa_num_small_rows(false);

    if (N == 1 && qk_ratio > 1 && qk_ratio <= max_gqa &&
        qk_ratio * nek2 == neq2 && nek2 == nev2 && neq3 == 1 && nek3 == 1 && nev3 == 1) {
        // grouped query attention - make the N dimension equal to gqa_ratio, reduce
        // workgroups proportionally in y dimension. The shader will detect gqa_ratio > 1
        // and change addressing calculations to index Q's dimension 2.
        gqa_ratio = qk_ratio;
        N = gqa_ratio;
        workgroups_y /= N;
    }

    vk_pipeline *pipelines;
    // XXX TODO other backends may be changing accumulator precision to default to f32 soon
    bool f32acc = dst->op_params[3] == GGML_PREC_F32;
    bool small_rows = N <= flash_attention_num_small_rows;
    switch (D) {
    case 64: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D64[k->type][f32acc][small_rows][0]; break;
    case 80: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D80[k->type][f32acc][small_rows][0]; break;
    case 96: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D96[k->type][f32acc][small_rows][0]; break;
    case 112: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D112[k->type][f32acc][small_rows][0]; break;
    case 128: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D128[k->type][f32acc][small_rows][0]; break;
    case 256: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D256[k->type][f32acc][small_rows][0]; break;
    default:
        assert(!"unsupported D value");
        return;
    bool f32acc = scalar || dst->op_params[3] == GGML_PREC_F32;
    bool small_rows = N <= get_fa_num_small_rows(scalar);

    if (scalar) {
        switch (D) {
        case 64: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D64[k->type][f32acc][small_rows][0]; break;
        case 80: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D80[k->type][f32acc][small_rows][0]; break;
        case 96: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D96[k->type][f32acc][small_rows][0]; break;
        case 112: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D112[k->type][f32acc][small_rows][0]; break;
        case 128: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D128[k->type][f32acc][small_rows][0]; break;
        case 256: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D256[k->type][f32acc][small_rows][0]; break;
        default:
            GGML_ASSERT(!"unsupported D value");
            return;
        }
    } else {
        switch (D) {
        case 64: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D64_cm2[k->type][f32acc][small_rows][0]; break;
        case 80: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D80_cm2[k->type][f32acc][small_rows][0]; break;
        case 96: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D96_cm2[k->type][f32acc][small_rows][0]; break;
        case 112: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D112_cm2[k->type][f32acc][small_rows][0]; break;
        case 128: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D128_cm2[k->type][f32acc][small_rows][0]; break;
        case 256: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D256_cm2[k->type][f32acc][small_rows][0]; break;
        default:
            GGML_ASSERT(!"unsupported D value");
            return;
        }
    }
    assert(pipelines);

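For the grouped-query-attention case in the hunk above, here is a quick numeric sketch of how the dispatch is reshaped when N == 1; the head counts are assumed example values, not taken from the commit:

#include <cstdint>
#include <cstdio>

// Example shape: single-token decode (N == 1) with 32 Q heads over 8 K/V heads,
// so qk_ratio == 4. With the scalar path, max_gqa is the "large rows" value (8).
int main() {
    uint32_t N            = 1;    // tokens per dispatch (decode)
    uint32_t neq2         = 32;   // Q heads (assumed)
    uint32_t nek2         = 8;    // K/V heads (assumed)
    uint32_t workgroups_y = neq2;

    const uint32_t max_gqa  = 8;  // scalar_flash_attention_num_large_rows
    const uint32_t qk_ratio = neq2 / nek2;

    if (N == 1 && qk_ratio > 1 && qk_ratio <= max_gqa && qk_ratio * nek2 == neq2) {
        // Fold the grouped Q heads into the shader's row dimension and
        // shrink the y dispatch accordingly.
        N = qk_ratio;
        workgroups_y /= N;
    }
    std::printf("N=%u workgroups_y=%u\n", N, workgroups_y);  // N=4, workgroups_y=8
    return 0;
}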
@@ -5740,27 +5806,14 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
    vk_pipeline pipeline = pipelines[aligned];
    assert(pipeline);

    uint32_t gqa_ratio = 1;
    uint32_t qk_ratio = neq2 / nek2;
    uint32_t workgroups_x = (uint32_t)neq1;
    uint32_t workgroups_y = (uint32_t)neq2;
    uint32_t workgroups_z = (uint32_t)neq3;

    if (N == 1 && qk_ratio > 1 && gqa_ratio <= flash_attention_num_small_rows &&
        qk_ratio * nek2 == neq2 && nek2 == nev2 && neq3 == 1 && nek3 == 1 && nev3 == 1) {
        // grouped query attention - make the N dimension equal to gqa_ratio, reduce
        // workgroups proportionally in y dimension. The shader will detect gqa_ratio > 1
        // and change addressing calculations to index Q's dimension 2.
        gqa_ratio = qk_ratio;
        N = gqa_ratio;
        workgroups_y /= N;
    }

    uint32_t split_kv = KV;
    uint32_t split_k = 1;

    // Use a placeholder core count if one isn't available. split_k is a big help for perf.
    const uint32_t shader_core_count = ctx->device->shader_core_count ? ctx->device->shader_core_count : 16;

    // Try to use split_k when KV is large enough to be worth the overhead
    if (workgroups_x == 1 && ctx->device->shader_core_count > 0 && KV >= 512) {
    if (workgroups_x == 1 && shader_core_count > 0 && KV >= 512) {
        // Try to run two workgroups per SM.
        split_k = ctx->device->shader_core_count * 2 / workgroups_y;
        if (split_k > 1) {
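And a numeric sketch of the split_k heuristic at the end of the hunk above, again with assumed core and head counts; the rest of the split_kv computation is cut off in this hunk and is not reproduced here:

#include <cstdint>
#include <cstdio>

// When a decode dispatch only fills one workgroup in x and the KV length is
// large, split the KV dimension so roughly two workgroups land on each shader
// core. The placeholder core count (16) keeps the condition meaningful when
// the real count is unknown.
int main() {
    const uint32_t shader_core_count = 16;   // placeholder / assumed
    const uint32_t workgroups_x      = 1;    // single Q row (decode)
    const uint32_t workgroups_y      = 8;    // e.g. K/V heads after the GQA fold (assumed)
    const uint32_t KV                = 4096;

    uint32_t split_k = 1;
    if (workgroups_x == 1 && KV >= 512) {
        split_k = shader_core_count * 2 / workgroups_y;   // -> 4 in this example
    }
    std::printf("split_k=%u\n", split_k);
    return 0;
}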
@@ -9530,9 +9583,8 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
        case GGML_OP_FLASH_ATTN_EXT:
            {
                ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
                if (!ggml_vk_get_device(ctx->device)->coopmat2) {
                    return false;
                }
                auto device = ggml_vk_get_device(ctx->device);
                bool coopmat2 = device->coopmat2;
                switch (op->src[0]->ne[0]) {
                case 64:
                case 80:
@@ -9540,7 +9592,6 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
                case 112:
                case 128:
                case 256:
                case 575: // DeepSeek MLA
                    break;
                default:
                    return false;
@@ -9566,10 +9617,12 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
                switch (op->src[1]->type) {
                case GGML_TYPE_F16:
                case GGML_TYPE_Q4_0:
                case GGML_TYPE_Q8_0:
                    // supported in scalar and coopmat2 paths
                    break;
                case GGML_TYPE_Q4_1:
                case GGML_TYPE_Q5_0:
                case GGML_TYPE_Q5_1:
                case GGML_TYPE_Q8_0:
                // K dequants currently disabled because D dimension is rounded up to 256 and runs inefficiently
                //case GGML_TYPE_Q2_K:
                //case GGML_TYPE_Q3_K:
@@ -9585,10 +9638,18 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
                //case GGML_TYPE_IQ3_S:
                //case GGML_TYPE_IQ4_XS:
                case GGML_TYPE_IQ4_NL:
                    // currently supported only in coopmat2 path
                    if (!coopmat2) {
                        return false;
                    }
                    break;
                default:
                    return false;
                }
                if (!coopmat2 && !device->subgroup_shuffle) {
                    // scalar FA uses subgroupShuffle
                    return false;
                }
                return true;
            }
        case GGML_OP_GET_ROWS:

Author: Jeff Bolz