mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-10-30 08:42:00 +00:00 
			
		
		
		
	vulkan: Dynamic subgroup size support for Q6_K mat_vec (#10536)
* subgroup 64 version with subgroup add; 15% faster scalable version, tested for subgroup sizes 16-128
* check for subgroup size that is a multiple of 16 and greater than or equal to 16
* subgroup sizes are always a power of 2 (https://github.com/KhronosGroup/GLSL/issues/45)
* force 16 sequential threads per block
* make 16 subgroup size a constant
This commit is contained in:
		| @@ -1231,6 +1231,9 @@ static void ggml_vk_load_shaders(vk_device& device) { | ||||
|  | ||||
|     std::cerr << "ggml_vulkan: Compiling shaders"; | ||||
|  | ||||
|     // some shaders require the subgroup size to be 16 or larger | ||||
|     const uint32_t subgroup_size_16 = std::max(device->subgroup_size, 16u); | ||||
|  | ||||
|     // mulmat | ||||
|     std::vector<uint32_t> l_warptile, m_warptile, s_warptile, | ||||
|                           l_warptile_mmq, m_warptile_mmq, s_warptile_mmq; | ||||
| @@ -1240,11 +1243,11 @@ static void ggml_vk_load_shaders(vk_device& device) { | ||||
|  | ||||
|     l_warptile = { 128, 128, 128, 16, device->subgroup_size * 2, 64, 2, 4, 4, device->subgroup_size }; | ||||
|     m_warptile = { 128,  64,  64, 16, device->subgroup_size, 32, 2, 4, 2, device->subgroup_size }; | ||||
|     s_warptile = { std::max(device->subgroup_size, 16u),  32,  32, 16, 32, 32, 2, 2, 2, device->subgroup_size }; | ||||
|     s_warptile = { subgroup_size_16,  32,  32, 16, 32, 32, 2, 2, 2, device->subgroup_size }; | ||||
|  | ||||
|     l_warptile_mmq = { 128, 128, 128, 32, device->subgroup_size * 2, 64, 2, 4, 4, device->subgroup_size }; | ||||
|     m_warptile_mmq = { 128,  64,  64, 32, device->subgroup_size, 32, 2, 4, 2, device->subgroup_size }; | ||||
|     s_warptile_mmq = { std::max(device->subgroup_size, 16u),  32,  32, 32, 32, 32, 2, 2, 2, device->subgroup_size }; | ||||
|     s_warptile_mmq = { subgroup_size_16,  32,  32, 32, 32, 32, 2, 2, 2, device->subgroup_size }; | ||||
|  | ||||
|     l_mmq_wg_denoms = l_wg_denoms = {128, 128, 1 }; | ||||
|     m_mmq_wg_denoms = m_wg_denoms = { 64,  64, 1 }; | ||||
| @@ -1431,7 +1434,7 @@ static void ggml_vk_load_shaders(vk_device& device) { | ||||
|     ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q3_K], "mul_mat_vec_q3_k_f32_f32", mul_mat_vec_q3_k_f32_f32_len, mul_mat_vec_q3_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {device->subgroup_size}, 1, true); | ||||
|     ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_K], "mul_mat_vec_q4_k_f32_f32", mul_mat_vec_q4_k_f32_f32_len, mul_mat_vec_q4_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {device->subgroup_size}, 1, true); | ||||
|     ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_K], "mul_mat_vec_q5_k_f32_f32", mul_mat_vec_q5_k_f32_f32_len, mul_mat_vec_q5_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {device->subgroup_size}, 1, true); | ||||
|     ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q6_K], "mul_mat_vec_q6_k_f32_f32", mul_mat_vec_q6_k_f32_f32_len, mul_mat_vec_q6_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {device->subgroup_size}, 1, true); | ||||
|     ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q6_K], "mul_mat_vec_q6_k_f32_f32", mul_mat_vec_q6_k_f32_f32_len, mul_mat_vec_q6_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {subgroup_size_16}, 1, true); | ||||
|     ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ4_NL], "mul_mat_vec_iq4_nl_f32_f32", mul_mat_vec_iq4_nl_f32_f32_len, mul_mat_vec_iq4_nl_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1, true); | ||||
|  | ||||
|     ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_F32 ], "mul_mat_vec_f32_f16_f32",  mul_mat_vec_f32_f16_f32_len,  mul_mat_vec_f32_f16_f32_data,  "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1); | ||||
| @@ -1445,7 +1448,7 @@ static void ggml_vk_load_shaders(vk_device& device) { | ||||
|     ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q3_K], "mul_mat_vec_q3_k_f16_f32", mul_mat_vec_q3_k_f16_f32_len, mul_mat_vec_q3_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {device->subgroup_size}, 1, true); | ||||
|     ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_K], "mul_mat_vec_q4_k_f16_f32", mul_mat_vec_q4_k_f16_f32_len, mul_mat_vec_q4_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {device->subgroup_size}, 1, true); | ||||
|     ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q5_K], "mul_mat_vec_q5_k_f16_f32", mul_mat_vec_q5_k_f16_f32_len, mul_mat_vec_q5_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {device->subgroup_size}, 1, true); | ||||
|     ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q6_K], "mul_mat_vec_q6_k_f16_f32", mul_mat_vec_q6_k_f16_f32_len, mul_mat_vec_q6_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {device->subgroup_size}, 1, true); | ||||
|     ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q6_K], "mul_mat_vec_q6_k_f16_f32", mul_mat_vec_q6_k_f16_f32_len, mul_mat_vec_q6_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {subgroup_size_16}, 1, true); | ||||
|     ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ4_NL], "mul_mat_vec_iq4_nl_f16_f32", mul_mat_vec_iq4_nl_f16_f32_len, mul_mat_vec_iq4_nl_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size}, 1, true); | ||||
|  | ||||
|     ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_F32 ], "mul_mat_vec_id_f32_f32",  mul_mat_vec_id_f32_f32_len,  mul_mat_vec_id_f32_f32_data,  "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1); | ||||
| @@ -1459,7 +1462,7 @@ static void ggml_vk_load_shaders(vk_device& device) { | ||||
|     ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q3_K], "mul_mat_vec_id_q3_k_f32", mul_mat_vec_id_q3_k_f32_len, mul_mat_vec_id_q3_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, {device->subgroup_size}, 1, true); | ||||
|     ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_K], "mul_mat_vec_id_q4_k_f32", mul_mat_vec_id_q4_k_f32_len, mul_mat_vec_id_q4_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, {device->subgroup_size}, 1, true); | ||||
|     ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q5_K], "mul_mat_vec_id_q5_k_f32", mul_mat_vec_id_q5_k_f32_len, mul_mat_vec_id_q5_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, {device->subgroup_size}, 1, true); | ||||
|     ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q6_K], "mul_mat_vec_id_q6_k_f32", mul_mat_vec_id_q6_k_f32_len, mul_mat_vec_id_q6_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, {device->subgroup_size}, 1, true); | ||||
|     ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q6_K], "mul_mat_vec_id_q6_k_f32", mul_mat_vec_id_q6_k_f32_len, mul_mat_vec_id_q6_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, {subgroup_size_16}, 1, true); | ||||
|     ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ4_NL], "mul_mat_vec_id_iq4_nl_f32", mul_mat_vec_id_iq4_nl_f32_len, mul_mat_vec_id_iq4_nl_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1, true); | ||||
|  | ||||
|     // dequant shaders | ||||
|   | ||||
| @@ -4,9 +4,11 @@ | ||||
|  | ||||
| #include "mul_mat_vec_base.comp" | ||||
|  | ||||
| layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in; | ||||
| layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; | ||||
|  | ||||
| shared FLOAT_TYPE tmp[32]; | ||||
| layout (constant_id = 0) const uint BLOCK_SIZE = 32; | ||||
|  | ||||
| shared FLOAT_TYPE tmp[BLOCK_SIZE]; | ||||
|  | ||||
| void main() { | ||||
|     const uint row = gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z; | ||||
| @@ -21,21 +23,19 @@ void main() { | ||||
|     const uint num_blocks_per_row = p.ncols / QUANT_K; | ||||
|     const uint ib0 = a_offset / QUANT_K + row*num_blocks_per_row; | ||||
|  | ||||
|     const uint tid = gl_LocalInvocationID.x/K_QUANTS_PER_ITERATION;  // 0...31 or 0...16 | ||||
|     const uint ix  = gl_LocalInvocationID.x%K_QUANTS_PER_ITERATION;  // 0 or 0, 1 | ||||
|     // 16 threads are used to process each block | ||||
|     const uint it_size = gl_WorkGroupSize.x/16; | ||||
|     const uint tid = gl_LocalInvocationID.x; | ||||
|     const uint itid = tid%16;  // 0...16 | ||||
|     const uint ix  = tid/16; | ||||
|  | ||||
|     const uint step = 16/K_QUANTS_PER_ITERATION;            // 16 or 8 | ||||
|     const uint step = 8; | ||||
|  | ||||
|     const uint v_im = tid/step;                             // 0 or 1. 0 computes 0..., 1 computes 128... | ||||
|     const uint v_in = tid - step*v_im;                      // 0...15 or 0...7 | ||||
|     const uint v_im = itid/step;                            // 0 or 1. 0 computes 0..., 1 computes 128... | ||||
|     const uint v_in = itid - step*v_im;                     // 0...15 or 0...7 | ||||
|  | ||||
| #if K_QUANTS_PER_ITERATION == 1 | ||||
|     const uint l0 = v_in;                                   // 0...15 | ||||
|     const uint is = 0; | ||||
| #else | ||||
|     const uint l0 = 4 * v_in;                               // 0, 4, 8, ..., 28 | ||||
|     const uint is = v_in / 4; | ||||
| #endif | ||||
|  | ||||
|     const uint ql_offset = 64*v_im + l0; | ||||
|     const uint qh_offset = 32*v_im + l0; | ||||
| @@ -44,7 +44,7 @@ void main() { | ||||
|  | ||||
|     FLOAT_TYPE temp = FLOAT_TYPE(0.0); // partial sum for thread in warp | ||||
|  | ||||
|     [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) { | ||||
|     [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += it_size) { | ||||
|         const uint y_idx   = i * QUANT_K + y_offset; | ||||
|  | ||||
|         const FLOAT_TYPE d = FLOAT_TYPE(data_a[ib0 + i].d); | ||||
| @@ -95,10 +95,10 @@ void main() { | ||||
|     } | ||||
|  | ||||
|     tmp[gl_LocalInvocationID.x] = temp; | ||||
|  | ||||
|     // sum up partial sums and write back result | ||||
|  | ||||
|     barrier(); | ||||
|     [[unroll]] for (uint s = 16; s > 0; s >>= 1) { | ||||
|     [[unroll]] for (uint s = gl_WorkGroupSize.x/2; s > 0; s >>= 1) { | ||||
|         if (tid < s) { | ||||
|             tmp[tid] += tmp[tid + s]; | ||||
|         } | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Eve
					Eve