mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-10-31 08:51:55 +00:00 
			
		
		
		
	Vulkan: Add VK_EXT_subgroup_size_control support to ensure full subgroups for coopmats (#10721)
* Vulkan: Add VK_EXT_subgroup_size_control support to ensure full subgroups for coopmats * Fix subgroup size control extension support check Add accf32 and accf16 checks for coopmats * Also disable coopmats on amdvlk
This commit is contained in:
		| @@ -163,7 +163,11 @@ struct vk_device_struct { | |||||||
|     uint32_t shader_core_count; |     uint32_t shader_core_count; | ||||||
|     bool uma; |     bool uma; | ||||||
|     bool float_controls_rte_fp16; |     bool float_controls_rte_fp16; | ||||||
|     bool coopmat2; |  | ||||||
|  |     bool subgroup_size_control; | ||||||
|  |     uint32_t subgroup_min_size; | ||||||
|  |     uint32_t subgroup_max_size; | ||||||
|  |     bool subgroup_require_full_support; | ||||||
|  |  | ||||||
|     bool coopmat_support; |     bool coopmat_support; | ||||||
|     bool coopmat_acc_f32_support; |     bool coopmat_acc_f32_support; | ||||||
| @@ -171,6 +175,7 @@ struct vk_device_struct { | |||||||
|     uint32_t coopmat_m; |     uint32_t coopmat_m; | ||||||
|     uint32_t coopmat_n; |     uint32_t coopmat_n; | ||||||
|     uint32_t coopmat_k; |     uint32_t coopmat_k; | ||||||
|  |     bool coopmat2; | ||||||
|  |  | ||||||
|     size_t idx; |     size_t idx; | ||||||
|  |  | ||||||
| @@ -749,8 +754,12 @@ static uint32_t compile_count = 0; | |||||||
| static std::mutex compile_count_mutex; | static std::mutex compile_count_mutex; | ||||||
| static std::condition_variable compile_count_cond; | static std::condition_variable compile_count_cond; | ||||||
|  |  | ||||||
| static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipeline, const std::string name, size_t spv_size, const void* spv_data, const std::string entrypoint, uint32_t parameter_count, uint32_t push_constant_size, std::array<uint32_t, 3> wg_denoms, std::vector<uint32_t> specialization_constants, uint32_t align, bool disable_robustness) { | static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipeline, const std::string name, size_t spv_size, const void* spv_data, const std::string entrypoint, | ||||||
|     VK_LOG_DEBUG("ggml_vk_create_pipeline(" << device->name << ", " << name << ", " << entrypoint << ", " << parameter_count << ", " << push_constant_size << ", (" << wg_denoms[0] << "," << wg_denoms[1] << "," << wg_denoms[2] << "), specialization_constants, " << align << ")"); |                                          uint32_t parameter_count, uint32_t push_constant_size, std::array<uint32_t, 3> wg_denoms, std::vector<uint32_t> specialization_constants, | ||||||
|  |                                          uint32_t align, bool disable_robustness, bool require_full_subgroups, uint32_t required_subgroup_size) { | ||||||
|  |     VK_LOG_DEBUG("ggml_vk_create_pipeline(" << device->name << ", " << name << ", " << entrypoint << ", " << parameter_count << ", " << push_constant_size << | ||||||
|  |                  ", (" << wg_denoms[0] << "," << wg_denoms[1] << "," << wg_denoms[2] << "), specialization_constants, " << align << | ||||||
|  |                  ", " << disable_robustness << ", " << require_full_subgroups << ", " << required_subgroup_size << ")"); | ||||||
|     GGML_ASSERT(parameter_count > 0); |     GGML_ASSERT(parameter_count > 0); | ||||||
|     GGML_ASSERT(wg_denoms[0] > 0 && wg_denoms[1] > 0 && wg_denoms[2] > 0); // NOLINT |     GGML_ASSERT(wg_denoms[0] > 0 && wg_denoms[1] > 0 && wg_denoms[2] > 0); // NOLINT | ||||||
|  |  | ||||||
| @@ -809,14 +818,28 @@ static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipelin | |||||||
|         specialization_constants.data() |         specialization_constants.data() | ||||||
|     ); |     ); | ||||||
|  |  | ||||||
|  |     vk::PipelineShaderStageCreateFlags pipeline_shader_stage_create_flags{}; | ||||||
|  |  | ||||||
|  |     if (device->subgroup_require_full_support && require_full_subgroups) { | ||||||
|  |         pipeline_shader_stage_create_flags |= vk::PipelineShaderStageCreateFlagBits::eRequireFullSubgroupsEXT; | ||||||
|  |     } | ||||||
|  |  | ||||||
|     vk::PipelineShaderStageCreateInfo pipeline_shader_create_info( |     vk::PipelineShaderStageCreateInfo pipeline_shader_create_info( | ||||||
|             vk::PipelineShaderStageCreateFlags(), |             pipeline_shader_stage_create_flags, | ||||||
|             vk::ShaderStageFlagBits::eCompute, |             vk::ShaderStageFlagBits::eCompute, | ||||||
|             pipeline->shader_module, |             pipeline->shader_module, | ||||||
|             entrypoint.c_str(), |             entrypoint.c_str(), | ||||||
|             &specialization_info); |             &specialization_info); | ||||||
|  |  | ||||||
|  |     vk::PipelineShaderStageRequiredSubgroupSizeCreateInfoEXT pipeline_shader_stage_required_subgroup_size_create_info; | ||||||
|  |     pipeline_shader_stage_required_subgroup_size_create_info.requiredSubgroupSize = required_subgroup_size; | ||||||
|  |     if (device->subgroup_size_control && required_subgroup_size > 0) { | ||||||
|  |         GGML_ASSERT(device->subgroup_min_size <= required_subgroup_size && required_subgroup_size <= device->subgroup_max_size); | ||||||
|  |         pipeline_shader_create_info.setPNext(&pipeline_shader_stage_required_subgroup_size_create_info); | ||||||
|  |     } | ||||||
|  |  | ||||||
|     vk::ComputePipelineCreateInfo compute_pipeline_create_info( |     vk::ComputePipelineCreateInfo compute_pipeline_create_info( | ||||||
|         vk::PipelineCreateFlags(), |         vk::PipelineCreateFlags{}, | ||||||
|         pipeline_shader_create_info, |         pipeline_shader_create_info, | ||||||
|         pipeline->layout); |         pipeline->layout); | ||||||
|  |  | ||||||
| @@ -1496,7 +1519,9 @@ static void ggml_vk_load_shaders(vk_device& device) { | |||||||
|     device->pipeline_matmul_id_f32 = std::make_shared<vk_matmul_pipeline_struct>(); |     device->pipeline_matmul_id_f32 = std::make_shared<vk_matmul_pipeline_struct>(); | ||||||
|  |  | ||||||
|     std::vector<std::future<void>> compiles; |     std::vector<std::future<void>> compiles; | ||||||
|     auto const &ggml_vk_create_pipeline = [&](vk_device& device, vk_pipeline& pipeline, const std::string &name, size_t spv_size, const void* spv_data, const std::string &entrypoint, uint32_t parameter_count, uint32_t push_constant_size, std::array<uint32_t, 3> wg_denoms, const std::vector<uint32_t>& specialization_constants, uint32_t align, bool disable_robustness = false) { |     auto const &ggml_vk_create_pipeline = [&](vk_device& device, vk_pipeline& pipeline, const std::string &name, size_t spv_size, const void* spv_data, const std::string &entrypoint, | ||||||
|  |                                               uint32_t parameter_count, uint32_t push_constant_size, std::array<uint32_t, 3> wg_denoms, const std::vector<uint32_t>& specialization_constants, | ||||||
|  |                                               uint32_t align, bool disable_robustness = false, bool require_full_subgroups = false, uint32_t required_subgroup_size = 0) { | ||||||
|         { |         { | ||||||
|             // wait until fewer than N compiles are in progress |             // wait until fewer than N compiles are in progress | ||||||
|             uint32_t N = std::max(1u, std::thread::hardware_concurrency()); |             uint32_t N = std::max(1u, std::thread::hardware_concurrency()); | ||||||
| @@ -1506,7 +1531,8 @@ static void ggml_vk_load_shaders(vk_device& device) { | |||||||
|             } |             } | ||||||
|             compile_count++; |             compile_count++; | ||||||
|         } |         } | ||||||
|         compiles.push_back(std::async(ggml_vk_create_pipeline_func, std::ref(device), std::ref(pipeline), name, spv_size, spv_data, entrypoint, parameter_count, push_constant_size, wg_denoms, specialization_constants, align, disable_robustness)); |         compiles.push_back(std::async(ggml_vk_create_pipeline_func, std::ref(device), std::ref(pipeline), name, spv_size, spv_data, entrypoint, | ||||||
|  |                                       parameter_count, push_constant_size, wg_denoms, specialization_constants, align, disable_robustness, require_full_subgroups, required_subgroup_size)); | ||||||
|     }; |     }; | ||||||
|  |  | ||||||
| #if defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) | #if defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) | ||||||
| @@ -1612,40 +1638,59 @@ static void ggml_vk_load_shaders(vk_device& device) { | |||||||
|         // Create 6 variants, {s,m,l}x{unaligned,aligned} |         // Create 6 variants, {s,m,l}x{unaligned,aligned} | ||||||
| #define CREATE_MM(PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \ | #define CREATE_MM(PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \ | ||||||
|         if (device->mul_mat ## ID ## _l) \ |         if (device->mul_mat ## ID ## _l) \ | ||||||
|             ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _coopmat_len, NAMELC ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1);   \ |             ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _coopmat_len, NAMELC ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, false, true);   \ | ||||||
|         if (device->mul_mat ## ID ## _m) \ |         if (device->mul_mat ## ID ## _m) \ | ||||||
|             ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _coopmat_len, NAMELC ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1);   \ |             ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _coopmat_len, NAMELC ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1, false, true);   \ | ||||||
|         if (device->mul_mat ## ID ## _s) \ |         if (device->mul_mat ## ID ## _s) \ | ||||||
|             ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _coopmat_len, NAMELC ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1);   \ |             ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _coopmat_len, NAMELC ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1, false, true);   \ | ||||||
|         if (device->mul_mat ## ID ## _l) \ |         if (device->mul_mat ## ID ## _l) \ | ||||||
|             ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _coopmat_len, NAMELC ## _aligned ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align);   \ |             ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _coopmat_len, NAMELC ## _aligned ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align, false, true);   \ | ||||||
|         if (device->mul_mat ## ID ## _m) \ |         if (device->mul_mat ## ID ## _m) \ | ||||||
|             ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _coopmat_len, NAMELC ## _aligned ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align);   \ |             ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _coopmat_len, NAMELC ## _aligned ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align, false, true);   \ | ||||||
|         if (device->mul_mat ## ID ## _s) \ |         if (device->mul_mat ## ID ## _s) \ | ||||||
|             ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _coopmat_len, NAMELC ## _aligned ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align);   \ |             ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _coopmat_len, NAMELC ## _aligned ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, false, true);   \ | ||||||
|  |  | ||||||
|         // Create 2 variants, {f16,f32} accumulator |         // Create 2 variants, {f16,f32} accumulator | ||||||
| #define CREATE_MM2(PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \ | #define CREATE_MM2(PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \ | ||||||
|         CREATE_MM(PIPELINE_NAME . f16acc, NAMELC, _f16acc, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \ |         if (device->coopmat_acc_f16_support) { \ | ||||||
|         CREATE_MM(PIPELINE_NAME . f32acc, NAMELC, , WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \ |             CREATE_MM(PIPELINE_NAME . f16acc, NAMELC, _f16acc, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \ | ||||||
|  |         } \ | ||||||
|  |         if (device->coopmat_acc_f32_support) { \ | ||||||
|  |             CREATE_MM(PIPELINE_NAME . f32acc, NAMELC, , WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \ | ||||||
|  |         } \ | ||||||
|  |  | ||||||
|         CREATE_MM(pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); |         CREATE_MM(pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); | ||||||
|         CREATE_MM(pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); |         CREATE_MM(pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); | ||||||
|         CREATE_MM2(pipeline_matmul_f16, matmul_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); |         CREATE_MM2(pipeline_matmul_f16, matmul_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); | ||||||
|         CREATE_MM2(pipeline_matmul_f16_f32, matmul_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); |         CREATE_MM2(pipeline_matmul_f16_f32, matmul_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); | ||||||
|  |  | ||||||
|         CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); |         if (device->coopmat_acc_f16_support) { | ||||||
|         CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); |             CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); | ||||||
|         CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); |             CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); | ||||||
|         CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f16acc, matmul_q5_1_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); |             CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); | ||||||
|         CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f16acc, matmul_q8_0_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); |             CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f16acc, matmul_q5_1_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); | ||||||
|  |             CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f16acc, matmul_q8_0_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); | ||||||
|  |  | ||||||
|         CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f16acc, matmul_q2_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); |             CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f16acc, matmul_q2_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); | ||||||
|         CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f16acc, matmul_q3_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); |             CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f16acc, matmul_q3_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); | ||||||
|         CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f16acc, matmul_q4_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); |             CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f16acc, matmul_q4_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); | ||||||
|         CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f16acc, matmul_q5_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); |             CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f16acc, matmul_q5_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); | ||||||
|         CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f16acc, matmul_q6_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); |             CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f16acc, matmul_q6_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); | ||||||
|         CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); |             CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); | ||||||
|  |         } else { | ||||||
|  |             CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); | ||||||
|  |             CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); | ||||||
|  |             CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); | ||||||
|  |             CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f16acc, matmul_q5_1_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); | ||||||
|  |             CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f16acc, matmul_q8_0_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); | ||||||
|  |  | ||||||
|  |             CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f16acc, matmul_q2_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); | ||||||
|  |             CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f16acc, matmul_q3_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); | ||||||
|  |             CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f16acc, matmul_q4_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); | ||||||
|  |             CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f16acc, matmul_q5_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); | ||||||
|  |             CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f16acc, matmul_q6_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); | ||||||
|  |             CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); | ||||||
|  |         } | ||||||
|  |  | ||||||
|         // If there's not enough shared memory for row_ids and the result tile, don't create these pipelines. |         // If there's not enough shared memory for row_ids and the result tile, don't create these pipelines. | ||||||
|         if (device->mul_mat_id_s || device->mul_mat_id_m || device->mul_mat_id_l) { |         if (device->mul_mat_id_s || device->mul_mat_id_m || device->mul_mat_id_l) { | ||||||
| @@ -1653,19 +1698,35 @@ static void ggml_vk_load_shaders(vk_device& device) { | |||||||
|             CREATE_MM2(pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); |             CREATE_MM2(pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); | ||||||
|             CREATE_MM2(pipeline_matmul_id_f16_f32, matmul_id_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); |             CREATE_MM2(pipeline_matmul_id_f16_f32, matmul_id_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); | ||||||
|  |  | ||||||
|             CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); |             if (device->coopmat_acc_f16_support) { | ||||||
|             CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); |                 CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); | ||||||
|             CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc, matmul_id_q5_0_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); |                 CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); | ||||||
|             CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f16acc, matmul_id_q5_1_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); |                 CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc, matmul_id_q5_0_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); | ||||||
|             CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f16acc, matmul_id_q8_0_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); |                 CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f16acc, matmul_id_q5_1_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); | ||||||
|  |                 CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f16acc, matmul_id_q8_0_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); | ||||||
|  |  | ||||||
|             CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f16acc, matmul_id_q2_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); |                 CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f16acc, matmul_id_q2_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); | ||||||
|             CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f16acc, matmul_id_q3_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); |                 CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f16acc, matmul_id_q3_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); | ||||||
|             CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f16acc, matmul_id_q4_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); |                 CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f16acc, matmul_id_q4_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); | ||||||
|             CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f16acc, matmul_id_q5_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); |                 CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f16acc, matmul_id_q5_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); | ||||||
|             CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc, matmul_id_q6_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); |                 CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc, matmul_id_q6_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); | ||||||
|             CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); |                 CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); | ||||||
|  |             } else { | ||||||
|  |                 CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); | ||||||
|  |                 CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); | ||||||
|  |                 CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc, matmul_id_q5_0_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); | ||||||
|  |                 CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f16acc, matmul_id_q5_1_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); | ||||||
|  |                 CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f16acc, matmul_id_q8_0_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); | ||||||
|  |  | ||||||
|  |                 CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f16acc, matmul_id_q2_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); | ||||||
|  |                 CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f16acc, matmul_id_q3_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); | ||||||
|  |                 CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f16acc, matmul_id_q4_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); | ||||||
|  |                 CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f16acc, matmul_id_q5_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); | ||||||
|  |                 CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc, matmul_id_q6_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); | ||||||
|  |                 CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); | ||||||
|  |             } | ||||||
|         } |         } | ||||||
|  | #undef CREATE_MM2 | ||||||
| #undef CREATE_MM | #undef CREATE_MM | ||||||
|     } else if (device->fp16) { |     } else if (device->fp16) { | ||||||
|         // Create 6 variants, {s,m,l}x{unaligned,aligned} |         // Create 6 variants, {s,m,l}x{unaligned,aligned} | ||||||
| @@ -1683,6 +1744,11 @@ static void ggml_vk_load_shaders(vk_device& device) { | |||||||
|         if (device->mul_mat ## ID ## _s) \ |         if (device->mul_mat ## ID ## _s) \ | ||||||
|             ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align);   \ |             ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align);   \ | ||||||
|  |  | ||||||
|  |         // Create 2 variants, {f16,f32} accumulator | ||||||
|  | #define CREATE_MM2(PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \ | ||||||
|  |         CREATE_MM(PIPELINE_NAME . f16acc, NAMELC, _f16acc, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \ | ||||||
|  |         CREATE_MM(PIPELINE_NAME . f32acc, NAMELC, , WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \ | ||||||
|  |  | ||||||
|         CREATE_MM(pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); |         CREATE_MM(pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); | ||||||
|         CREATE_MM(pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); |         CREATE_MM(pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); | ||||||
|         CREATE_MM2(pipeline_matmul_f16, matmul_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); |         CREATE_MM2(pipeline_matmul_f16, matmul_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); | ||||||
| @@ -1720,6 +1786,7 @@ static void ggml_vk_load_shaders(vk_device& device) { | |||||||
|             CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc, matmul_id_q6_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); |             CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc, matmul_id_q6_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); | ||||||
|             CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); |             CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); | ||||||
|         } |         } | ||||||
|  | #undef CREATE_MM2 | ||||||
| #undef CREATE_MM | #undef CREATE_MM | ||||||
|     } else { |     } else { | ||||||
|         // Create 6 variants, {s,m,l}x{unaligned,aligned} |         // Create 6 variants, {s,m,l}x{unaligned,aligned} | ||||||
| @@ -1774,7 +1841,6 @@ static void ggml_vk_load_shaders(vk_device& device) { | |||||||
|             CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f32acc, matmul_id_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); |             CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f32acc, matmul_id_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); | ||||||
|             CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f32acc, matmul_id_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); |             CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f32acc, matmul_id_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); | ||||||
|         } |         } | ||||||
| #undef CREATE_MM2 |  | ||||||
| #undef CREATE_MM | #undef CREATE_MM | ||||||
|     } |     } | ||||||
|  |  | ||||||
| @@ -1998,6 +2064,8 @@ static vk_device ggml_vk_get_device(size_t idx) { | |||||||
|                 amd_shader_core_properties2 = true; |                 amd_shader_core_properties2 = true; | ||||||
|             } else if (strcmp("VK_EXT_pipeline_robustness", properties.extensionName) == 0) { |             } else if (strcmp("VK_EXT_pipeline_robustness", properties.extensionName) == 0) { | ||||||
|                 pipeline_robustness = true; |                 pipeline_robustness = true; | ||||||
|  |             } else if (strcmp("VK_EXT_subgroup_size_control", properties.extensionName) == 0) { | ||||||
|  |                 device->subgroup_size_control = true; | ||||||
|             } else if (strcmp("VK_KHR_cooperative_matrix", properties.extensionName) == 0 && |             } else if (strcmp("VK_KHR_cooperative_matrix", properties.extensionName) == 0 && | ||||||
|                        !getenv("GGML_VK_DISABLE_COOPMAT")) { |                        !getenv("GGML_VK_DISABLE_COOPMAT")) { | ||||||
|                 device->coopmat_support = true; |                 device->coopmat_support = true; | ||||||
| @@ -2018,6 +2086,8 @@ static vk_device ggml_vk_get_device(size_t idx) { | |||||||
|         vk::PhysicalDeviceShaderSMBuiltinsPropertiesNV sm_props; |         vk::PhysicalDeviceShaderSMBuiltinsPropertiesNV sm_props; | ||||||
|         vk::PhysicalDeviceShaderCoreProperties2AMD amd_shader_core_properties2_props; |         vk::PhysicalDeviceShaderCoreProperties2AMD amd_shader_core_properties2_props; | ||||||
|         vk::PhysicalDeviceVulkan12Properties vk12_props; |         vk::PhysicalDeviceVulkan12Properties vk12_props; | ||||||
|  |         vk::PhysicalDeviceSubgroupSizeControlPropertiesEXT subgroup_size_control_props; | ||||||
|  |  | ||||||
|         props2.pNext = &props3; |         props2.pNext = &props3; | ||||||
|         props3.pNext = &subgroup_props; |         props3.pNext = &subgroup_props; | ||||||
|         subgroup_props.pNext = &driver_props; |         subgroup_props.pNext = &driver_props; | ||||||
| @@ -2037,6 +2107,10 @@ static vk_device ggml_vk_get_device(size_t idx) { | |||||||
|             last_struct->pNext = (VkBaseOutStructure *)&amd_shader_core_properties2_props; |             last_struct->pNext = (VkBaseOutStructure *)&amd_shader_core_properties2_props; | ||||||
|             last_struct = (VkBaseOutStructure *)&amd_shader_core_properties2_props; |             last_struct = (VkBaseOutStructure *)&amd_shader_core_properties2_props; | ||||||
|         } |         } | ||||||
|  |         if (device->subgroup_size_control) { | ||||||
|  |             last_struct->pNext = (VkBaseOutStructure *)&subgroup_size_control_props; | ||||||
|  |             last_struct = (VkBaseOutStructure *)&subgroup_size_control_props; | ||||||
|  |         } | ||||||
|  |  | ||||||
| #if defined(VK_NV_cooperative_matrix2) | #if defined(VK_NV_cooperative_matrix2) | ||||||
|         vk::PhysicalDeviceCooperativeMatrix2PropertiesNV coopmat2_props; |         vk::PhysicalDeviceCooperativeMatrix2PropertiesNV coopmat2_props; | ||||||
| @@ -2075,7 +2149,7 @@ static vk_device ggml_vk_get_device(size_t idx) { | |||||||
|  |  | ||||||
|         device->fp16 = !force_disable_f16 && fp16_storage && fp16_compute; |         device->fp16 = !force_disable_f16 && fp16_storage && fp16_compute; | ||||||
|  |  | ||||||
|         if (device->vendor_id == VK_VENDOR_ID_INTEL || (props2.properties.vendorID == VK_VENDOR_ID_AMD && driver_props.driverID == vk::DriverId::eAmdProprietary)) { |         if (device->vendor_id == VK_VENDOR_ID_INTEL || (device->vendor_id == VK_VENDOR_ID_AMD && (driver_props.driverID == vk::DriverId::eAmdProprietary || driver_props.driverID == vk::DriverId::eAmdOpenSource))) { | ||||||
|             // Intel drivers don't support coopmat properly yet |             // Intel drivers don't support coopmat properly yet | ||||||
|             // Only RADV supports coopmat properly on AMD |             // Only RADV supports coopmat properly on AMD | ||||||
|             device->coopmat_support = false; |             device->coopmat_support = false; | ||||||
| @@ -2131,6 +2205,17 @@ static vk_device ggml_vk_get_device(size_t idx) { | |||||||
|             device_extensions.push_back("VK_EXT_pipeline_robustness"); |             device_extensions.push_back("VK_EXT_pipeline_robustness"); | ||||||
|         } |         } | ||||||
|  |  | ||||||
|  |         VkPhysicalDeviceSubgroupSizeControlFeaturesEXT subgroup_size_control_features; | ||||||
|  |         subgroup_size_control_features.pNext = nullptr; | ||||||
|  |         subgroup_size_control_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SUBGROUP_SIZE_CONTROL_FEATURES_EXT; | ||||||
|  |         subgroup_size_control_features.computeFullSubgroups = false; | ||||||
|  |         subgroup_size_control_features.subgroupSizeControl = false; | ||||||
|  |  | ||||||
|  |         if (device->subgroup_size_control) { | ||||||
|  |             last_struct->pNext = (VkBaseOutStructure *)&subgroup_size_control_features; | ||||||
|  |             last_struct = (VkBaseOutStructure *)&subgroup_size_control_features; | ||||||
|  |         } | ||||||
|  |  | ||||||
|         VkPhysicalDeviceCooperativeMatrixFeaturesKHR coopmat_features; |         VkPhysicalDeviceCooperativeMatrixFeaturesKHR coopmat_features; | ||||||
|         coopmat_features.pNext = nullptr; |         coopmat_features.pNext = nullptr; | ||||||
|         coopmat_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_COOPERATIVE_MATRIX_FEATURES_KHR; |         coopmat_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_COOPERATIVE_MATRIX_FEATURES_KHR; | ||||||
| @@ -2158,6 +2243,17 @@ static vk_device ggml_vk_get_device(size_t idx) { | |||||||
|  |  | ||||||
|         device->pipeline_robustness = pl_robustness_features.pipelineRobustness; |         device->pipeline_robustness = pl_robustness_features.pipelineRobustness; | ||||||
|  |  | ||||||
|  |         device->subgroup_size_control = device->subgroup_size_control && | ||||||
|  |                 (subgroup_size_control_props.requiredSubgroupSizeStages & vk::ShaderStageFlagBits::eCompute) && | ||||||
|  |                 subgroup_size_control_features.subgroupSizeControl; | ||||||
|  |  | ||||||
|  |         if (device->subgroup_size_control) { | ||||||
|  |             device->subgroup_min_size = subgroup_size_control_props.minSubgroupSize; | ||||||
|  |             device->subgroup_max_size = subgroup_size_control_props.maxSubgroupSize; | ||||||
|  |             device->subgroup_require_full_support = subgroup_size_control_features.computeFullSubgroups; | ||||||
|  |             device_extensions.push_back("VK_EXT_subgroup_size_control"); | ||||||
|  |         } | ||||||
|  |  | ||||||
|         device->coopmat_support = device->coopmat_support && coopmat_features.cooperativeMatrix; |         device->coopmat_support = device->coopmat_support && coopmat_features.cooperativeMatrix; | ||||||
|  |  | ||||||
|         if (coopmat2_support) { |         if (coopmat2_support) { | ||||||
| @@ -2307,7 +2403,7 @@ static vk_device ggml_vk_get_device(size_t idx) { | |||||||
|                 } |                 } | ||||||
|             } |             } | ||||||
|  |  | ||||||
|             if (device->coopmat_m == 0) { |             if (device->coopmat_m == 0 || !device->coopmat_acc_f32_support) { | ||||||
|                 // No suitable matmul mode found |                 // No suitable matmul mode found | ||||||
|                 GGML_LOG_DEBUG("ggml_vulkan: WARNING: No suitable matrix core mode found. Disabling matrix cores.\n"); |                 GGML_LOG_DEBUG("ggml_vulkan: WARNING: No suitable matrix core mode found. Disabling matrix cores.\n"); | ||||||
|                 device->coopmat_support = false; |                 device->coopmat_support = false; | ||||||
| @@ -2440,7 +2536,7 @@ static void ggml_vk_print_gpu_info(size_t idx) { | |||||||
|         } |         } | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     if (props2.properties.vendorID == VK_VENDOR_ID_INTEL || (props2.properties.vendorID == VK_VENDOR_ID_AMD && driver_props.driverID == vk::DriverId::eAmdProprietary)) { |     if (props2.properties.vendorID == VK_VENDOR_ID_INTEL || (props2.properties.vendorID == VK_VENDOR_ID_AMD && (driver_props.driverID == vk::DriverId::eAmdProprietary || driver_props.driverID == vk::DriverId::eAmdOpenSource))) { | ||||||
|         // Intel drivers don't support coopmat properly yet |         // Intel drivers don't support coopmat properly yet | ||||||
|         // Only RADV supports coopmat properly on AMD |         // Only RADV supports coopmat properly on AMD | ||||||
|         coopmat_support = false; |         coopmat_support = false; | ||||||
| @@ -2727,7 +2823,7 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_pipeline(ggml_backend_vk_conte | |||||||
|     if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F16) { |     if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F16) { | ||||||
|         return ctx->device->pipeline_matmul_f32_f16; |         return ctx->device->pipeline_matmul_f32_f16; | ||||||
|     } |     } | ||||||
|     if (prec == GGML_PREC_DEFAULT && ctx->device->fp16) { |     if (prec == GGML_PREC_DEFAULT && ctx->device->fp16 && !(ctx->device->coopmat_support && !ctx->device->coopmat_acc_f16_support)) { | ||||||
|         if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) { |         if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) { | ||||||
|             return ctx->device->pipeline_matmul_f16_f32.f16acc; |             return ctx->device->pipeline_matmul_f16_f32.f16acc; | ||||||
|         } |         } | ||||||
| @@ -2802,7 +2898,7 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_id_pipeline(ggml_backend_vk_co | |||||||
|     if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F32) { |     if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F32) { | ||||||
|         return ctx->device->pipeline_matmul_id_f32; |         return ctx->device->pipeline_matmul_id_f32; | ||||||
|     } |     } | ||||||
|     if (prec == GGML_PREC_DEFAULT && ctx->device->fp16) { |     if (prec == GGML_PREC_DEFAULT && ctx->device->fp16 && !(ctx->device->coopmat_support && !ctx->device->coopmat_acc_f16_support)) { | ||||||
|         if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) { |         if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) { | ||||||
|             return ctx->device->pipeline_matmul_id_f16_f32.f16acc; |             return ctx->device->pipeline_matmul_id_f16_f32.f16acc; | ||||||
|         } |         } | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user
	 0cc4m
					0cc4m