Mirror of https://github.com/ggml-org/llama.cpp.git (synced 2025-10-30 08:42:00 +00:00)
	vulkan: optimize and reenable split_k (#10637)
Use vector loads when possible in mul_mat_split_k_reduce. Use split_k when there aren't enough workgroups to fill the shader cores.
The hunks below are from the Vulkan backend source (ggml-vulkan.cpp):
@@ -165,6 +165,7 @@ struct vk_device_struct {
     vk_queue transfer_queue;
     bool single_queue;
     uint32_t subgroup_size;
+    uint32_t shader_core_count;
     bool uma;
 
     size_t idx;
@@ -1498,7 +1499,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
     ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q8_0], "get_rows_q8_0_f32", get_rows_q8_0_f32_len, get_rows_q8_0_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ4_NL], "get_rows_iq4_nl_f32", get_rows_iq4_nl_f32_len, get_rows_iq4_nl_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
 
-    ggml_vk_create_pipeline(device, device->pipeline_matmul_split_k_reduce, "split_k_reduce", split_k_reduce_len, split_k_reduce_data, "main", 2, 2 * sizeof(uint32_t), {256, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_matmul_split_k_reduce, "split_k_reduce", split_k_reduce_len, split_k_reduce_data, "main", 2, 2 * sizeof(uint32_t), {256 * 4, 1, 1}, {}, 1);
 
     ggml_vk_create_pipeline(device, device->pipeline_mul_mat_vec_p021_f16_f32, "mul_mat_vec_p021_f16_f32", mul_mat_vec_p021_f16_f32_len, mul_mat_vec_p021_f16_f32_data, "main", 3, 6 * sizeof(uint32_t), {1, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_mul_mat_vec_nc_f16_f32, "mul_mat_vec_nc_f16_f32", mul_mat_vec_nc_f16_f32_len, mul_mat_vec_nc_f16_f32_data, "main", 3, 7 * sizeof(uint32_t), {1, 1, 1}, {}, 1);
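A note on the {256, 1, 1} -> {256 * 4, 1, 1} change above: the reduce shader still launches 256 invocations per workgroup, but each invocation now handles four elements, so the element count a single workgroup covers grows to 1024 and the dispatch needs a quarter as many workgroups. A minimal sketch of that arithmetic, assuming wg_denoms is used as the rounding divisor for the dispatch size as elsewhere in the backend (hypothetical standalone helper):

#include <cstdint>

// Round-up division, as the CEIL_DIV macro used by the backend.
#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b))

uint32_t reduce_workgroup_count(uint32_t ne) {
    const uint32_t elems_per_wg = 256 * 4; // 256 invocations x 4 elements each
    return CEIL_DIV(ne, elems_per_wg);     // e.g. ne = 4096 -> 4 workgroups
}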
@@ -1610,11 +1611,14 @@ static vk_device ggml_vk_get_device(size_t idx) {
         const std::vector<vk::ExtensionProperties> ext_props = device->physical_device.enumerateDeviceExtensionProperties();
 
         bool maintenance4_support = false;
+        bool sm_builtins = false;
 
         // Check if maintenance4 is supported
         for (const auto& properties : ext_props) {
             if (strcmp("VK_KHR_maintenance4", properties.extensionName) == 0) {
                 maintenance4_support = true;
+            } else if (strcmp("VK_NV_shader_sm_builtins", properties.extensionName) == 0) {
+                sm_builtins = true;
             }
         }
 
@@ -1622,11 +1626,21 @@ static vk_device ggml_vk_get_device(size_t idx) {
         vk::PhysicalDeviceMaintenance3Properties props3;
         vk::PhysicalDeviceMaintenance4Properties props4;
         vk::PhysicalDeviceSubgroupProperties subgroup_props;
+        vk::PhysicalDeviceShaderSMBuiltinsPropertiesNV sm_props;
         props2.pNext = &props3;
         props3.pNext = &subgroup_props;
+
+        VkBaseOutStructure * last_struct = (VkBaseOutStructure *)&subgroup_props;
+
         if (maintenance4_support) {
-            subgroup_props.pNext = &props4;
+            last_struct->pNext = (VkBaseOutStructure *)&props4;
+            last_struct = (VkBaseOutStructure *)&props4;
         }
+        if (sm_builtins) {
+            last_struct->pNext = (VkBaseOutStructure *)&sm_props;
+            last_struct = (VkBaseOutStructure *)&sm_props;
+        }
+
         device->physical_device.getProperties2(&props2);
         device->properties = props2.properties;
 
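The hunk above replaces the hard-wired subgroup_props.pNext assignment with a generic pNext chain: a running tail pointer appends each optional properties struct only when the matching extension was found, so the chain stays valid regardless of which extensions exist. A standalone sketch of the same pattern against the raw Vulkan C API (illustrative, not the backend's exact code):

#include <vulkan/vulkan.h>

uint32_t query_sm_count(VkPhysicalDevice phys_dev, bool has_sm_builtins) {
    VkPhysicalDeviceProperties2 props2 { VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_PROPERTIES_2 };
    VkPhysicalDeviceSubgroupProperties subgroup_props { VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SUBGROUP_PROPERTIES };
    VkPhysicalDeviceShaderSMBuiltinsPropertiesNV sm_props { VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_SM_BUILTINS_PROPERTIES_NV };

    // Fixed part of the chain, then a tail pointer for optional structs.
    props2.pNext = &subgroup_props;
    VkBaseOutStructure * last_struct = (VkBaseOutStructure *)&subgroup_props;

    if (has_sm_builtins) { // only chain what the driver actually reports
        last_struct->pNext = (VkBaseOutStructure *)&sm_props;
        last_struct = (VkBaseOutStructure *)&sm_props;
    }

    vkGetPhysicalDeviceProperties2(phys_dev, &props2);
    // 0 signals "unknown", matching the fallback in the next hunk.
    return has_sm_builtins ? sm_props.shaderSMCount : 0;
}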
@@ -1643,6 +1657,11 @@ static vk_device ggml_vk_get_device(size_t idx) {
         device->vendor_id = device->properties.vendorID;
         device->subgroup_size = subgroup_props.subgroupSize;
         device->uma = device->properties.deviceType == vk::PhysicalDeviceType::eIntegratedGpu;
+        if (sm_builtins) {
+            device->shader_core_count = sm_props.shaderSMCount;
+        } else {
+            device->shader_core_count = 0;
+        }
 
         bool fp16_storage = false;
         bool fp16_compute = false;
@@ -2732,15 +2751,25 @@ static void ggml_vk_buffer_memset(vk_buffer& dst, size_t offset, uint32_t c, siz
     dst->device->device.resetFences({ dst->device->fence });
 }
 
-static uint32_t ggml_vk_guess_split_k(int m, int n, int k) {
+static uint32_t ggml_vk_guess_split_k(ggml_backend_vk_context * ctx, int m, int n, int k, const vk_pipeline& pipeline) {
     VK_LOG_DEBUG("ggml_vk_guess_split_k(" << m << ", " << n << ", " << k << ")");
-    // if (k > 128 && (m < 128 || n < 128) && m > 2 && n > 2) {
-    //     return 4;
-    // }
 
-    return 1;
+    uint32_t split_k = 1;
+    if (ctx->device->shader_core_count != 0 && m >= (int)pipeline->wg_denoms[0] && n >= (int)pipeline->wg_denoms[1]) {
+        // If k is 'large' and the SMs will fill less than halfway, use split_k.
+        uint32_t m_tiles = CEIL_DIV(m, pipeline->wg_denoms[0]);
+        uint32_t n_tiles = CEIL_DIV(n, pipeline->wg_denoms[1]);
+        if (k >= 2048 && m_tiles * n_tiles < ctx->device->shader_core_count / 2) {
+            split_k = ctx->device->shader_core_count / (m_tiles * n_tiles);
+            // Clamp to 2 or 4
+            split_k = std::min(split_k, 4u);
+            if (split_k == 3) {
+                split_k = 2;
+            }
+        }
+    }
 
-    GGML_UNUSED(m); GGML_UNUSED(n); GGML_UNUSED(k);
+    return split_k;
 }
 
 static vk_pipeline ggml_vk_guess_matmul_pipeline_amd(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, bool aligned) {
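The new heuristic enables split_k only when the output spans at least one full workgroup tile, k is large (>= 2048), and the m x n tile grid would occupy fewer than half of the SMs; it then sizes split_k so that tiles x split_k roughly fills the device, clamped to 2 or 4 (a value of 3 rounds down to 2). This is also why the caller in the next hunk now selects the pipeline before guessing split_k: the guess reads the pipeline's wg_denoms. A worked example with assumed numbers (64x64 output tile, 32 SMs; both hypothetical):

#include <algorithm>
#include <cstdint>
#include <cstdio>

// Standalone restatement of the heuristic above, for experimentation.
static uint32_t guess_split_k(int m, int n, int k,
                              uint32_t wg_m, uint32_t wg_n, // output tile per workgroup
                              uint32_t shader_core_count) {
    uint32_t split_k = 1;
    if (shader_core_count != 0 && m >= (int)wg_m && n >= (int)wg_n) {
        uint32_t m_tiles = (m + wg_m - 1) / wg_m;
        uint32_t n_tiles = (n + wg_n - 1) / wg_n;
        if (k >= 2048 && m_tiles * n_tiles < shader_core_count / 2) {
            split_k = shader_core_count / (m_tiles * n_tiles);
            split_k = std::min(split_k, 4u); // clamp to 2 or 4
            if (split_k == 3) {
                split_k = 2;
            }
        }
    }
    return split_k;
}

int main() {
    // m=128, n=64, k=4096 with a 64x64 tile -> 2x1 = 2 tiles.
    // 2 < 32/2, so split_k = 32/2 = 16, clamped to 4.
    printf("%u\n", guess_split_k(128, 64, 4096, 64, 64, 32)); // prints 4
}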
@@ -2964,10 +2993,10 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
     const uint32_t kpad = ggml_vk_align_size(ne10, ggml_vk_guess_matmul_pipeline_align(ctx, mmp, ne01, ne11));
     const bool aligned = ne10 == kpad && ne01 > 8 && ne11 > 8;
 
-    const uint32_t split_k = ggml_vk_guess_split_k(ne01, ne11, ne10);
-
     vk_pipeline pipeline = ggml_vk_guess_matmul_pipeline(ctx, mmp, ne01, ne11, aligned);
 
+    const uint32_t split_k = ggml_vk_guess_split_k(ctx, ne01, ne11, ne10, pipeline);
+
     const uint64_t qx_sz = ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type);
     const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type);
     const uint64_t x_sz = !qx_needs_dequant ? qx_sz : sizeof(ggml_fp16_t) * x_ne;
@@ -2993,7 +3022,7 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
     if (dryrun) {
        const uint64_t x_sz_upd = x_sz * ne02 * ne03;
         const uint64_t y_sz_upd = y_sz * ne12 * ne13;
-        const uint64_t split_k_size = split_k > 1 ? d_sz * ne12 * ne13 * 4 : 0;
+        const uint64_t split_k_size = split_k > 1 ? d_sz * ne12 * ne13 * split_k : 0;
         if (
                 (qx_needs_dequant && x_sz_upd > ctx->device->max_memory_allocation_size) ||
                 (qy_needs_dequant && y_sz_upd > ctx->device->max_memory_allocation_size) ||
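The dryrun sizing change matches this: each of the split_k passes writes a full-size copy of the output, which the reduce pass then sums, so the scratch buffer now scales with the actual split_k instead of a hardcoded factor of 4 (which would over-reserve when split_k is 2). A sketch of the corrected arithmetic with assumed sizes:

#include <cstdint>

// Scratch bytes needed for split_k partial results: one full d-sized
// copy per pass, across the ne12 x ne13 batch dimensions.
uint64_t split_k_scratch_bytes(uint64_t d_sz, uint64_t ne12, uint64_t ne13, uint32_t split_k) {
    return split_k > 1 ? d_sz * ne12 * ne13 * split_k : 0;
    // e.g. d_sz = 128*64*sizeof(float) = 32768, ne12 = ne13 = 1,
    // split_k = 4 -> 131072 bytes
}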
|   | |||||||
@@ -5,7 +5,9 @@
 layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
 
 layout (binding = 0) readonly buffer A {float data_a[];};
+layout (binding = 0) readonly buffer A4 {vec4 data_a4[];};
 layout (binding = 1) writeonly buffer D {float data_d[];};
+layout (binding = 1) writeonly buffer D4 {vec4 data_d4[];};
 
 layout (push_constant) uniform parameter {
     uint ne;
@@ -13,17 +15,34 @@ layout (push_constant) uniform parameter {
 } p;
 
 void main() {
-    const uint idx = gl_GlobalInvocationID.x;
+    // Each invocation handles four consecutive components
+    const uint idx = gl_GlobalInvocationID.x * 4;
 
     if (idx >= p.ne) {
         return;
     }
 
-    float result = 0.0f;
+    // Check if all four components are in bounds and aligned,
+    // then use vector loads
+    if (idx + 3 < p.ne && (p.ne % 4) == 0) {
+        vec4 result = vec4(0.0f);
 
-    [[unroll]] for (uint i = 0; i < p.k_num; i++) {
-        result += data_a[i * p.ne + idx];
+        [[unroll]] for (uint i = 0; i < p.k_num; i++) {
+            result += data_a4[(i * p.ne + idx) / 4];
+        }
+
+        data_d4[idx / 4] = result;
+    } else {
+        [[unroll]] for (uint j = 0; j < 4; ++j) {
+            if (idx + j < p.ne) {
+                float result = 0.0f;
+
+                [[unroll]] for (uint i = 0; i < p.k_num; i++) {
+                    result += data_a[i * p.ne + idx + j];
+                }
+
+                data_d[idx + j] = result;
+            }
+        }
     }
-
-    data_d[idx] = result;
 }
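For reference, the reduction this shader performs is equivalent to the scalar CPU loop below (a sketch, not part of the commit): the split_k matmul passes write k_num full-size partial outputs back-to-back in one buffer, and the reduce pass sums them element-wise. The new vec4 fast path is exactly this loop processed four floats at a time, taken only when p.ne is a multiple of 4 so the aliased vec4 views stay in bounds and aligned.

#include <cstdint>

// CPU-side reference for split_k_reduce: partial output i starts at
// offset i * ne, mirroring the shader's data_a indexing.
void split_k_reduce_ref(const float * data_a, float * data_d,
                        uint32_t ne, uint32_t k_num) {
    for (uint32_t idx = 0; idx < ne; ++idx) {
        float result = 0.0f;
        for (uint32_t i = 0; i < k_num; ++i) {
            result += data_a[i * ne + idx];
        }
        data_d[idx] = result;
    }
}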
Author: Jeff Bolz