	vulkan: Add VK_NV_cooperative_matrix2 support for mul_mat and flash attention (#10206)
		| @@ -167,6 +167,7 @@ struct vk_device_struct { | ||||
|     uint32_t subgroup_size; | ||||
|     uint32_t shader_core_count; | ||||
|     bool uma; | ||||
|     bool coopmat2; | ||||
|  | ||||
|     size_t idx; | ||||
|  | ||||
| @@ -176,6 +177,7 @@ struct vk_device_struct { | ||||
|     vk_matmul_pipeline2 pipeline_matmul_f16_f32; | ||||
|     vk_pipeline pipeline_matmul_split_k_reduce; | ||||
|  | ||||
|     vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_COUNT]; | ||||
|     vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat[GGML_TYPE_COUNT]; | ||||
|  | ||||
|     vk_matmul_pipeline pipeline_matmul_id_f32; | ||||
| @@ -229,6 +231,14 @@ struct vk_device_struct { | ||||
|     vk_pipeline pipeline_timestep_embedding_f32; | ||||
|     vk_pipeline pipeline_pool2d_f32; | ||||
|  | ||||
|     // [2][2][2] is for {f16acc,f32acc}x{large,small_rows}x{unaligned, aligned} | ||||
|     vk_pipeline pipeline_flash_attn_f32_f16_D64[GGML_TYPE_COUNT][2][2][2]; | ||||
|     vk_pipeline pipeline_flash_attn_f32_f16_D80[GGML_TYPE_COUNT][2][2][2]; | ||||
|     vk_pipeline pipeline_flash_attn_f32_f16_D96[GGML_TYPE_COUNT][2][2][2]; | ||||
|     vk_pipeline pipeline_flash_attn_f32_f16_D112[GGML_TYPE_COUNT][2][2][2]; | ||||
|     vk_pipeline pipeline_flash_attn_f32_f16_D128[GGML_TYPE_COUNT][2][2][2]; | ||||
|     vk_pipeline pipeline_flash_attn_f32_f16_D256[GGML_TYPE_COUNT][2][2][2]; | ||||
|  | ||||
|     std::unordered_map<std::string, vk_pipeline_ref> pipelines; | ||||
|     std::unordered_map<std::string, uint64_t> pipeline_descriptor_set_requirements; | ||||
|  | ||||
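For orientation, the index order of these flash-attention pipeline arrays follows the comment above: [kv_type][f16acc/f32acc][large/small_rows][unaligned/aligned]. A minimal sketch of picking one variant (illustrative only; the real selection happens in ggml_vk_flash_attn further down in this commit):

```cpp
#include <cstddef>

// Assumed index order from the comment: [kv_type][f32acc][small_rows][aligned].
template <typename P, std::size_t NT>
static P & select_fa_pipeline(P (&pipelines)[NT][2][2][2], std::size_t kv_type,
                              bool f32acc, bool small_rows, bool aligned) {
    return pipelines[kv_type][f32acc ? 1 : 0][small_rows ? 1 : 0][aligned ? 1 : 0];
}

// e.g. select_fa_pipeline(device->pipeline_flash_attn_f32_f16_D128,
//                         GGML_TYPE_F16, /*f32acc=*/false, /*small_rows=*/true, /*aligned=*/true)
```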
| @@ -340,6 +350,40 @@ struct vk_mat_vec_id_push_constants { | ||||
|     uint32_t nei0; uint32_t ne11; | ||||
| }; | ||||
|  | ||||
| struct vk_flash_attn_push_constants { | ||||
|     uint32_t N; | ||||
|     uint32_t KV; | ||||
|  | ||||
|     uint32_t ne1; | ||||
|     uint32_t ne2; | ||||
|     uint32_t ne3; | ||||
|  | ||||
|     uint32_t neq2; | ||||
|     uint32_t neq3; | ||||
|     uint32_t nek2; | ||||
|     uint32_t nek3; | ||||
|     uint32_t nev2; | ||||
|     uint32_t nev3; | ||||
|     uint32_t nem1; | ||||
|  | ||||
|     uint32_t nb02; | ||||
|     uint32_t nb03; | ||||
|     uint32_t nb12; | ||||
|     uint32_t nb13; | ||||
|     uint32_t nb22; | ||||
|     uint32_t nb23; | ||||
|     uint32_t nb31; | ||||
|  | ||||
|     float scale; | ||||
|     float max_bias; | ||||
|     float logit_softcap; | ||||
|  | ||||
|     uint32_t mask; | ||||
|     uint32_t n_head_log2; | ||||
|     float m0; | ||||
|     float m1; | ||||
| }; | ||||
|  | ||||
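To make the field names concrete, here is a hedged sketch of how a host might fill the most important members; the real values come from the tensor shapes in ggml_vk_flash_attn later in this commit, and the numbers below are purely illustrative:

```cpp
#include <cmath>
#include <cstdint>

static vk_flash_attn_push_constants fa_pc_example() {
    const uint32_t D = 128, n_q = 512, n_kv = 4096;   // head dim, query rows, key/value rows

    vk_flash_attn_push_constants pc {};
    pc.N     = n_q;                                   // rows of Q handled by the dispatch
    pc.KV    = n_kv;                                  // rows of K/V attended over
    pc.scale = 1.0f / sqrtf((float) D);               // conventional 1/sqrt(D) attention scale
    pc.mask  = 1;                                     // non-zero when a mask tensor is bound
    return pc;
}
```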
| struct vk_op_push_constants { | ||||
|     uint32_t KX; | ||||
|     uint32_t KY; | ||||
| @@ -1265,6 +1309,23 @@ static void ggml_vk_wait_events(vk_context& ctx, std::vector<vk::Event>&& events | ||||
|     ); | ||||
| } | ||||
|  | ||||
| // number of rows/cols for flash attention shader | ||||
| static constexpr uint32_t flash_attention_num_small_rows = 32; | ||||
| static std::array<uint32_t, 2> fa_rows_cols(uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) { | ||||
|     GGML_UNUSED(clamp); | ||||
|  | ||||
|     // small rows, large cols | ||||
|     if (small_rows) { | ||||
|         return {flash_attention_num_small_rows, 128}; | ||||
|     } | ||||
|     // small cols to reduce register count | ||||
|     if (ggml_is_quantized(type) || D == 256) { | ||||
|         return {64, 32}; | ||||
|     } | ||||
|     return {64, 64}; | ||||
| }; | ||||
|  | ||||
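As a quick sanity check of the branches above, the tile shapes returned look like this (a sketch assuming the ggml type enums are in scope; not part of the commit):

```cpp
static void fa_rows_cols_examples() {
    const auto large = fa_rows_cols(128, 0, GGML_TYPE_F16,  /*small_rows=*/false); // {64, 64}
    const auto small = fa_rows_cols(128, 0, GGML_TYPE_F16,  /*small_rows=*/true);  // {32, 128}
    const auto quant = fa_rows_cols(128, 0, GGML_TYPE_Q4_0, /*small_rows=*/false); // {64, 32}
    const auto d256  = fa_rows_cols(256, 0, GGML_TYPE_F16,  /*small_rows=*/false); // {64, 32}
    (void) large; (void) small; (void) quant; (void) d256;
}
```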
|  | ||||
| static void ggml_vk_load_shaders(vk_device& device) { | ||||
|     VK_LOG_DEBUG("ggml_vk_load_shaders(" << device->name << ")"); | ||||
|  | ||||
| @@ -1275,59 +1336,98 @@ static void ggml_vk_load_shaders(vk_device& device) { | ||||
|  | ||||
|     // mulmat | ||||
|     std::vector<uint32_t> l_warptile, m_warptile, s_warptile, | ||||
|                           l_warptile_mmq, m_warptile_mmq, s_warptile_mmq; | ||||
|                           l_warptile_mmq, m_warptile_mmq, s_warptile_mmq, | ||||
|                           l_warptile_mmq_k, m_warptile_mmq_k, s_warptile_mmq_k, | ||||
|                           l_warptile_mmqid, m_warptile_mmqid, s_warptile_mmqid; | ||||
|     std::array<uint32_t, 3> l_wg_denoms, m_wg_denoms, s_wg_denoms, | ||||
|                             l_mmq_wg_denoms, m_mmq_wg_denoms, s_mmq_wg_denoms; | ||||
|                             l_mmq_wg_denoms, m_mmq_wg_denoms, s_mmq_wg_denoms, | ||||
|                             l_mmq_wg_denoms_k, m_mmq_wg_denoms_k, s_mmq_wg_denoms_k, | ||||
|                             l_mmqid_wg_denoms, m_mmqid_wg_denoms, s_mmqid_wg_denoms; | ||||
|  | ||||
|     uint32_t l_align, m_align, s_align; | ||||
|     if (device->coopmat2) { | ||||
|         // spec constants and tile sizes for non-quant matmul/matmul_id | ||||
|         l_warptile = { 256, 128, 256, 64 }; | ||||
|         m_warptile = { 256, 128, 128, 64 }; | ||||
|         s_warptile = { 128,  32,  16, 64 }; | ||||
|         l_wg_denoms = {128, 256, 1 }; | ||||
|         m_wg_denoms = {128, 128, 1 }; | ||||
|         s_wg_denoms = { 32,  16, 1 }; | ||||
|  | ||||
|     l_warptile = { 128, 128, 128, 16, device->subgroup_size * 2, 64, 2, 4, 4, device->subgroup_size }; | ||||
|     m_warptile = { 128,  64,  64, 16, device->subgroup_size, 32, 2, 4, 2, device->subgroup_size }; | ||||
|     s_warptile = { subgroup_size_16,  32,  32, 16, 32, 32, 2, 2, 2, device->subgroup_size }; | ||||
|         // spec constants and tile sizes for quant matmul (non-Qi_K) | ||||
|         l_warptile_mmq = { 256, 128, 256, 64 }; | ||||
|         m_warptile_mmq = { 256, 128, 128, 64 }; | ||||
|         s_warptile_mmq = { 256, 128, 128, 64 }; | ||||
|         l_mmq_wg_denoms = { 128, 256, 1 }; | ||||
|         m_mmq_wg_denoms = { 128, 128, 1 }; | ||||
|         s_mmq_wg_denoms = { 128, 128, 1 }; | ||||
|  | ||||
|     l_warptile_mmq = { 128, 128, 128, 32, device->subgroup_size * 2, 64, 2, 4, 4, device->subgroup_size }; | ||||
|     m_warptile_mmq = { 128,  64,  64, 32, device->subgroup_size, 32, 2, 4, 2, device->subgroup_size }; | ||||
|     s_warptile_mmq = { subgroup_size_16,  32,  32, 32, 32, 32, 2, 2, 2, device->subgroup_size }; | ||||
|         // spec constants and tile sizes for quant matmul (Qi_K) | ||||
|         l_warptile_mmq_k = { 256, 128, 512, 16 }; | ||||
|         m_warptile_mmq_k = { 256, 128, 256, 16 }; | ||||
|         s_warptile_mmq_k = { 256, 32, 128, 64 }; | ||||
|         l_mmq_wg_denoms_k = { 128, 512, 1 }; | ||||
|         m_mmq_wg_denoms_k = { 128, 256, 1 }; | ||||
|         s_mmq_wg_denoms_k = { 32, 128, 1 }; | ||||
|  | ||||
|     l_mmq_wg_denoms = l_wg_denoms = {128, 128, 1 }; | ||||
|     m_mmq_wg_denoms = m_wg_denoms = { 64,  64, 1 }; | ||||
|     s_mmq_wg_denoms = s_wg_denoms = { 32,  32, 1 }; | ||||
|         // spec constants and tile sizes for quant matmul_id | ||||
|         l_warptile_mmqid = { 256, 128, 128, 16 }; | ||||
|         m_warptile_mmqid = { 256, 128, 64, 16 }; | ||||
|         s_warptile_mmqid = { 256, 64, 64, 16 }; | ||||
|         l_mmqid_wg_denoms = { 128, 128, 1 }; | ||||
|         m_mmqid_wg_denoms = { 128, 64, 1 }; | ||||
|         s_mmqid_wg_denoms = { 64, 64, 1 }; | ||||
|  | ||||
|     l_align = 128; | ||||
|     m_align =  64; | ||||
|     s_align =  32; | ||||
|         l_align = 128; | ||||
|         m_align =  64; | ||||
|         s_align =  32; | ||||
|     } else { | ||||
|         l_warptile = { 128, 128, 128, 16, device->subgroup_size * 2, 64, 2, 4, 4, device->subgroup_size }; | ||||
|         m_warptile = { 128,  64,  64, 16, device->subgroup_size, 32, 2, 4, 2, device->subgroup_size }; | ||||
|         s_warptile = { subgroup_size_16,  32,  32, 16, 32, 32, 2, 2, 2, device->subgroup_size }; | ||||
|         l_warptile_mmq = { 128, 128, 128, 32, device->subgroup_size * 2, 64, 2, 4, 4, device->subgroup_size }; | ||||
|         m_warptile_mmq = { 128,  64,  64, 32, device->subgroup_size, 32, 2, 4, 2, device->subgroup_size }; | ||||
|         s_warptile_mmq = { subgroup_size_16,  32,  32, 32, 32, 32, 2, 2, 2, device->subgroup_size }; | ||||
|         l_mmq_wg_denoms = l_wg_denoms = {128, 128, 1 }; | ||||
|         m_mmq_wg_denoms = m_wg_denoms = { 64,  64, 1 }; | ||||
|         s_mmq_wg_denoms = s_wg_denoms = { 32,  32, 1 }; | ||||
|         l_align = 128; | ||||
|         m_align =  64; | ||||
|         s_align =  32; | ||||
|  | ||||
|     // Fallback to smaller sizes if there's not enough shared memory. Given the current shaders | ||||
|     // and tile sizes, this should handle 16KB, 32KB, and 48KB+. | ||||
|     // This logic doesn't explicitly account for the 12KB row_ids in the mul_mat_mat_id shaders. | ||||
|     // But the numbers happen to work out for 32KB shared memory size that when using the medium | ||||
|     // size there's enough room for everything, and we assert for this. | ||||
|     uint32_t shmem_needed = (l_warptile[1] + l_warptile[2]) * (l_warptile[3] + 1) * sizeof(float); | ||||
|     if (shmem_needed > device->properties.limits.maxComputeSharedMemorySize) { | ||||
|         l_warptile = m_warptile; | ||||
|         l_wg_denoms = m_wg_denoms; | ||||
|         shmem_needed = (l_warptile[1] + l_warptile[2]) * (l_warptile[3] + 1) * sizeof(float); | ||||
|         GGML_ASSERT(shmem_needed <= device->properties.limits.maxComputeSharedMemorySize); | ||||
|     } | ||||
|     if (device->properties.limits.maxComputeSharedMemorySize >= 32768) { | ||||
|         // assert mul_mat_mat_id shaders will fit. | ||||
|         GGML_ASSERT(shmem_needed + 3072*4 <= device->properties.limits.maxComputeSharedMemorySize); | ||||
|     } | ||||
|  | ||||
|     shmem_needed = (l_warptile_mmq[1] + l_warptile_mmq[2]) * (l_warptile_mmq[3] + 1) * sizeof(float); | ||||
|     if (shmem_needed > device->properties.limits.maxComputeSharedMemorySize) { | ||||
|         if (device->properties.limits.maxComputeSharedMemorySize == 32768) { | ||||
|             l_warptile_mmq = m_warptile_mmq; | ||||
|             l_mmq_wg_denoms = m_mmq_wg_denoms; | ||||
|         } else { | ||||
|             l_warptile_mmq = s_warptile_mmq; | ||||
|             l_mmq_wg_denoms = s_mmq_wg_denoms; | ||||
|         // Fallback to smaller sizes if there's not enough shared memory. Given the current shaders | ||||
|         // and tile sizes, this should handle 16KB, 32KB, and 48KB+. | ||||
|         // This logic doesn't explicitly account for the 12KB row_ids in the mul_mat_mat_id shaders. | ||||
|         // But the numbers happen to work out for 32KB shared memory size that when using the medium | ||||
|         // size there's enough room for everything, and we assert for this. | ||||
|         uint32_t shmem_needed = (l_warptile[1] + l_warptile[2]) * (l_warptile[3] + 1) * sizeof(float); | ||||
|         if (shmem_needed > device->properties.limits.maxComputeSharedMemorySize) { | ||||
|             l_warptile = m_warptile; | ||||
|             l_wg_denoms = m_wg_denoms; | ||||
|             shmem_needed = (l_warptile[1] + l_warptile[2]) * (l_warptile[3] + 1) * sizeof(float); | ||||
|             GGML_ASSERT(shmem_needed <= device->properties.limits.maxComputeSharedMemorySize); | ||||
|         } | ||||
|         if (device->properties.limits.maxComputeSharedMemorySize >= 32768) { | ||||
|             // assert mul_mat_mat_id shaders will fit. | ||||
|             GGML_ASSERT(shmem_needed + 3072*4 <= device->properties.limits.maxComputeSharedMemorySize); | ||||
|         } | ||||
|  | ||||
|         shmem_needed = (l_warptile_mmq[1] + l_warptile_mmq[2]) * (l_warptile_mmq[3] + 1) * sizeof(float); | ||||
|         GGML_ASSERT(shmem_needed <= device->properties.limits.maxComputeSharedMemorySize); | ||||
|     } | ||||
|     if (device->properties.limits.maxComputeSharedMemorySize >= 32768) { | ||||
|         // assert mul_mat_mat_id shaders will fit. | ||||
|         GGML_ASSERT(shmem_needed + 3072*4 <= device->properties.limits.maxComputeSharedMemorySize); | ||||
|         if (shmem_needed > device->properties.limits.maxComputeSharedMemorySize) { | ||||
|             if (device->properties.limits.maxComputeSharedMemorySize == 32768) { | ||||
|                 l_warptile_mmq = m_warptile_mmq; | ||||
|                 l_mmq_wg_denoms = m_mmq_wg_denoms; | ||||
|             } else { | ||||
|                 l_warptile_mmq = s_warptile_mmq; | ||||
|                 l_mmq_wg_denoms = s_mmq_wg_denoms; | ||||
|             } | ||||
|             shmem_needed = (l_warptile_mmq[1] + l_warptile_mmq[2]) * (l_warptile_mmq[3] + 1) * sizeof(float); | ||||
|             GGML_ASSERT(shmem_needed <= device->properties.limits.maxComputeSharedMemorySize); | ||||
|         } | ||||
|         if (device->properties.limits.maxComputeSharedMemorySize >= 32768) { | ||||
|             // assert mul_mat_mat_id shaders will fit. | ||||
|             GGML_ASSERT(shmem_needed + 3072*4 <= device->properties.limits.maxComputeSharedMemorySize); | ||||
|         } | ||||
|     } | ||||
|  | ||||
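To make the fallback thresholds in the non-coopmat2 branch concrete, the same shared-memory estimate can be worked through at compile time (a sketch; the 3072*4 term is the 12KB of row_ids mentioned in the comment above):

```cpp
#include <cstdint>

constexpr uint32_t mm_shmem_bytes(uint32_t bm, uint32_t bn, uint32_t bk) {
    return (bm + bn) * (bk + 1) * sizeof(float);
}
static_assert(mm_shmem_bytes(128, 128, 16) == 17408, "large tile needs >16KB, so 16KB devices fall back");
static_assert(mm_shmem_bytes( 64,  64, 16) ==  8704, "medium tile fits within 16KB");
static_assert(mm_shmem_bytes( 64,  64, 32) + 3072 * 4 <= 32768, "medium mmq tile + 12KB row_ids fits in 32KB");
```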
|     device->pipeline_matmul_f32 = std::make_shared<vk_matmul_pipeline_struct>(); | ||||
| @@ -1362,6 +1462,105 @@ static void ggml_vk_load_shaders(vk_device& device) { | ||||
|         compiles.push_back(std::async(ggml_vk_create_pipeline_func, std::ref(device), std::ref(pipeline), name, spv_size, spv_data, entrypoint, parameter_count, push_constant_size, wg_denoms, specialization_constants, align, disable_robustness)); | ||||
|     }; | ||||
|  | ||||
| #if defined(VK_NV_cooperative_matrix2) | ||||
|     if (device->coopmat2) { | ||||
|  | ||||
|         auto const &fa_wg_denoms = [&](uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) -> std::array<uint32_t, 3> { | ||||
|             return {fa_rows_cols(D, clamp, type, small_rows)[0], 1, 1}; | ||||
|         }; | ||||
|  | ||||
|         auto const &fa_spec_constants = [&](uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) -> std::vector<uint32_t> { | ||||
|             // For large number of rows, 128 invocations seems to work best. | ||||
|             // For small number of rows (e.g. N==1), 256 works better. But matrix granularity for 256 is 32, so we | ||||
|             // can't use 256 for D==80. | ||||
|             uint32_t wg_size = (small_rows && (D % 32) == 0) ? 256 : 128; | ||||
|             auto rows_cols = fa_rows_cols(D, clamp, type, small_rows); | ||||
|             return {wg_size, rows_cols[0], rows_cols[1], (D), clamp}; | ||||
|         }; | ||||
|  | ||||
| #define CREATE_FA2(TYPE, NAMELC, D) \ | ||||
|         ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][0][0][0], "flash_attn_f32_f16_D" #D "_f16acc"         #NAMELC,           flash_attn_f32_f16_ ## NAMELC ## _f16acc_cm2_len,  flash_attn_f32_f16_ ## NAMELC ## _f16acc_cm2_data,  "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(D,1,TYPE,false), fa_spec_constants(D,1,TYPE,false), 1);     \ | ||||
|         ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][0][0][1], "flash_attn_f32_f16_D" #D "_aligned_f16acc" #NAMELC,           flash_attn_f32_f16_ ## NAMELC ## _f16acc_cm2_len,  flash_attn_f32_f16_ ## NAMELC ## _f16acc_cm2_data,  "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(D,0,TYPE,false), fa_spec_constants(D,0,TYPE,false), fa_rows_cols(D,0,TYPE,false)[1]);     \ | ||||
|         ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][1][0][0], "flash_attn_f32_f16_D" #D "_f32acc"         #NAMELC,           flash_attn_f32_f16_ ## NAMELC ## _cm2_len,         flash_attn_f32_f16_ ## NAMELC ## _cm2_data,         "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(D,1,TYPE,false), fa_spec_constants(D,1,TYPE,false), 1);     \ | ||||
|         ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][1][0][1], "flash_attn_f32_f16_D" #D "_aligned_f32acc" #NAMELC,           flash_attn_f32_f16_ ## NAMELC ## _cm2_len,         flash_attn_f32_f16_ ## NAMELC ## _cm2_data,         "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(D,0,TYPE,false), fa_spec_constants(D,0,TYPE,false), fa_rows_cols(D,0,TYPE,false)[1]);     \ | ||||
|         ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][0][1][0], "flash_attn_f32_f16_D" #D "_f16acc_smallrows"         #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc_cm2_len,  flash_attn_f32_f16_ ## NAMELC ## _f16acc_cm2_data,  "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(D,1,TYPE,true), fa_spec_constants(D,1,TYPE,true), 1);     \ | ||||
|         ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][0][1][1], "flash_attn_f32_f16_D" #D "_aligned_f16acc_smallrows" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc_cm2_len,  flash_attn_f32_f16_ ## NAMELC ## _f16acc_cm2_data,  "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(D,0,TYPE,true), fa_spec_constants(D,0,TYPE,true), fa_rows_cols(D,0,TYPE,true)[1]);     \ | ||||
|         ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][1][1][0], "flash_attn_f32_f16_D" #D "_f32acc_smallrows"         #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _cm2_len,         flash_attn_f32_f16_ ## NAMELC ## _cm2_data,         "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(D,1,TYPE,true), fa_spec_constants(D,1,TYPE,true), 1);     \ | ||||
|         ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][1][1][1], "flash_attn_f32_f16_D" #D "_aligned_f32acc_smallrows" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _cm2_len,         flash_attn_f32_f16_ ## NAMELC ## _cm2_data,         "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(D,0,TYPE,true), fa_spec_constants(D,0,TYPE,true), fa_rows_cols(D,0,TYPE,true)[1]);     \ | ||||
|  | ||||
| #define CREATE_FA(TYPE, NAMELC) \ | ||||
|         CREATE_FA2(TYPE, NAMELC, 64) \ | ||||
|         CREATE_FA2(TYPE, NAMELC, 80) \ | ||||
|         CREATE_FA2(TYPE, NAMELC, 96) \ | ||||
|         CREATE_FA2(TYPE, NAMELC, 112) \ | ||||
|         CREATE_FA2(TYPE, NAMELC, 128) \ | ||||
|         CREATE_FA2(TYPE, NAMELC, 256) | ||||
|  | ||||
|         CREATE_FA(GGML_TYPE_F16, f16) | ||||
|         CREATE_FA(GGML_TYPE_Q4_0, q4_0) | ||||
|         CREATE_FA(GGML_TYPE_Q4_1, q4_1) | ||||
|         CREATE_FA(GGML_TYPE_Q5_0, q5_0) | ||||
|         CREATE_FA(GGML_TYPE_Q5_1, q5_1) | ||||
|         CREATE_FA(GGML_TYPE_Q8_0, q8_0) | ||||
|         // K dequants currently disabled because D dimension is rounded up to 256 and runs inefficiently | ||||
|         //CREATE_FA(GGML_TYPE_Q2_K, q2_k) | ||||
|         //CREATE_FA(GGML_TYPE_Q3_K, q3_k) | ||||
|         //CREATE_FA(GGML_TYPE_Q4_K, q4_k) | ||||
|         //CREATE_FA(GGML_TYPE_Q5_K, q5_k) | ||||
|         //CREATE_FA(GGML_TYPE_Q6_K, q6_k) | ||||
|         CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl) | ||||
| #undef CREATE_FA | ||||
|  | ||||
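For reference, each CREATE_FA invocation above expands to 48 ggml_vk_create_pipeline calls per K/V type: 6 head sizes times the 8 variants enumerated by CREATE_FA2 (a compile-time note, not code from the commit):

```cpp
static_assert(6 /*head sizes*/ * 2 /*f16/f32 acc*/ * 2 /*rows*/ * 2 /*alignment*/ == 48,
              "flash-attention pipelines created per enabled K/V type");
```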
|         // Create 6 variants, {s,m,l}x{unaligned,aligned} | ||||
| #define CREATE_MM(PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT) \ | ||||
|         ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1);   \ | ||||
|         ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1);   \ | ||||
|         ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1);   \ | ||||
|         ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _cm2_len, NAMELC ## _aligned ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align);   \ | ||||
|         ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _cm2_len, NAMELC ## _aligned ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align);   \ | ||||
|         ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _cm2_len, NAMELC ## _aligned ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align);   \ | ||||
|  | ||||
|         // Create 2 variants, {f16,f32} accumulator | ||||
| #define CREATE_MM2(PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT) \ | ||||
|         CREATE_MM(PIPELINE_NAME . f16acc, NAMELC, _f16acc, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT)   \ | ||||
|         CREATE_MM(PIPELINE_NAME . f32acc, NAMELC, , WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT)   \ | ||||
|  | ||||
|         CREATE_MM(pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3) | ||||
|         CREATE_MM(pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3) | ||||
|  | ||||
|         CREATE_MM2(pipeline_matmul_f16, matmul_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 3) | ||||
|         CREATE_MM2(pipeline_matmul_f16_f32, matmul_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 3) | ||||
|         CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) | ||||
|         CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) | ||||
|         CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) | ||||
|         CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q5_1].f16acc, matmul_q5_1_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) | ||||
|         CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q8_0].f16acc, matmul_q8_0_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) | ||||
|         CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q2_K].f16acc, matmul_q2_k_f16, _f16acc, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3) | ||||
|         CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q3_K].f16acc, matmul_q3_k_f16, _f16acc, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3) | ||||
|         CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q4_K].f16acc, matmul_q4_k_f16, _f16acc, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3) | ||||
|         CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q5_K].f16acc, matmul_q5_k_f16, _f16acc, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3) | ||||
|         CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q6_K].f16acc, matmul_q6_k_f16, _f16acc, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3) | ||||
|         CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) | ||||
|  | ||||
|         CREATE_MM(pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4) | ||||
|         CREATE_MM(pipeline_matmul_id_f16, matmul_id_f16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4) | ||||
|         CREATE_MM(pipeline_matmul_id_f16_f32, matmul_id_f16_f32, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4) | ||||
|  | ||||
|         CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_q4_0_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) | ||||
|         CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1], matmul_id_q4_1_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) | ||||
|         CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0], matmul_id_q5_0_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) | ||||
|         CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1], matmul_id_q5_1_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) | ||||
|         CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0], matmul_id_q8_0_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) | ||||
|         CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K], matmul_id_q2_k_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) | ||||
|         CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K], matmul_id_q3_k_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) | ||||
|         CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K], matmul_id_q4_k_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) | ||||
|         CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K], matmul_id_q5_k_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) | ||||
|         CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K], matmul_id_q6_k_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) | ||||
|         CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_iq4_nl_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) | ||||
| #undef CREATE_MM | ||||
| #undef CREATE_MM2 | ||||
|     } else | ||||
| #endif | ||||
|     if (device->fp16) { | ||||
|         // Create 6 variants, {s,m,l}x{unaligned,aligned} | ||||
| #define CREATE_MM(PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT) \ | ||||
| @@ -1648,15 +1847,28 @@ static vk_device ggml_vk_get_device(size_t idx) { | ||||
|         device->physical_device = physical_devices[dev_num]; | ||||
|         const std::vector<vk::ExtensionProperties> ext_props = device->physical_device.enumerateDeviceExtensionProperties(); | ||||
|  | ||||
|         bool fp16_storage = false; | ||||
|         bool fp16_compute = false; | ||||
|         bool maintenance4_support = false; | ||||
|         bool sm_builtins = false; | ||||
|         bool pipeline_robustness = false; | ||||
|         bool coopmat2_support = false; | ||||
|  | ||||
|         // Check which device extensions are supported | ||||
|         for (const auto& properties : ext_props) { | ||||
|             if (strcmp("VK_KHR_maintenance4", properties.extensionName) == 0) { | ||||
|                 maintenance4_support = true; | ||||
|             } else if (strcmp("VK_KHR_16bit_storage", properties.extensionName) == 0) { | ||||
|                 fp16_storage = true; | ||||
|             } else if (strcmp("VK_KHR_shader_float16_int8", properties.extensionName) == 0) { | ||||
|                 fp16_compute = true; | ||||
|             } else if (strcmp("VK_NV_shader_sm_builtins", properties.extensionName) == 0) { | ||||
|                 sm_builtins = true; | ||||
|             } else if (strcmp("VK_EXT_pipeline_robustness", properties.extensionName) == 0) { | ||||
|                 pipeline_robustness = true; | ||||
|             } else if (strcmp("VK_NV_cooperative_matrix2", properties.extensionName) == 0 && | ||||
|                        !getenv("GGML_VULKAN_DISABLE_COOPMAT2")) { | ||||
|                 coopmat2_support = true; | ||||
|             } | ||||
|         } | ||||
|  | ||||
| @@ -1679,6 +1891,14 @@ static vk_device ggml_vk_get_device(size_t idx) { | ||||
|             last_struct = (VkBaseOutStructure *)&sm_props; | ||||
|         } | ||||
|  | ||||
| #if defined(VK_NV_cooperative_matrix2) | ||||
|         vk::PhysicalDeviceCooperativeMatrix2PropertiesNV coopmat2_props; | ||||
|         if (coopmat2_support) { | ||||
|             last_struct->pNext = (VkBaseOutStructure *)&coopmat2_props; | ||||
|             last_struct = (VkBaseOutStructure *)&coopmat2_props; | ||||
|         } | ||||
| #endif | ||||
|  | ||||
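The optional property structs above are appended to a pNext chain via the last_struct pointer, the same pattern used for the feature queries later in this function. A condensed sketch of that pattern (hypothetical helper, not present in the commit):

```cpp
static void chain_out_struct(VkBaseOutStructure *& last, void * next) {
    last->pNext = (VkBaseOutStructure *) next;   // link the new struct into the query chain
    last        = (VkBaseOutStructure *) next;   // it becomes the new tail
}
```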
|         device->physical_device.getProperties2(&props2); | ||||
|         device->properties = props2.properties; | ||||
|  | ||||
| @@ -1701,20 +1921,6 @@ static vk_device ggml_vk_get_device(size_t idx) { | ||||
|             device->shader_core_count = 0; | ||||
|         } | ||||
|  | ||||
|         bool fp16_storage = false; | ||||
|         bool fp16_compute = false; | ||||
|         bool pipeline_robustness = false; | ||||
|  | ||||
|         for (const auto& properties : ext_props) { | ||||
|             if (strcmp("VK_KHR_16bit_storage", properties.extensionName) == 0) { | ||||
|                 fp16_storage = true; | ||||
|             } else if (strcmp("VK_KHR_shader_float16_int8", properties.extensionName) == 0) { | ||||
|                 fp16_compute = true; | ||||
|             } else if (strcmp("VK_EXT_pipeline_robustness", properties.extensionName) == 0) { | ||||
|                 pipeline_robustness = true; | ||||
|             } | ||||
|         } | ||||
|  | ||||
|         const char* GGML_VK_DISABLE_F16 = getenv("GGML_VK_DISABLE_F16"); | ||||
|         const bool force_disable_f16 = GGML_VK_DISABLE_F16 != nullptr; | ||||
|  | ||||
| @@ -1757,22 +1963,112 @@ static vk_device ggml_vk_get_device(size_t idx) { | ||||
|         vk12_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_VULKAN_1_2_FEATURES; | ||||
|         vk11_features.pNext = &vk12_features; | ||||
|  | ||||
|         last_struct = (VkBaseOutStructure *)&vk12_features; | ||||
|  | ||||
|         VkPhysicalDevicePipelineRobustnessFeaturesEXT pl_robustness_features; | ||||
|         pl_robustness_features.pNext = nullptr; | ||||
|         pl_robustness_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_PIPELINE_ROBUSTNESS_FEATURES_EXT; | ||||
|         pl_robustness_features.pipelineRobustness = VK_FALSE; | ||||
|  | ||||
|         if (pipeline_robustness) { | ||||
|             vk12_features.pNext = &pl_robustness_features; | ||||
|             last_struct->pNext = (VkBaseOutStructure *)&pl_robustness_features; | ||||
|             last_struct = (VkBaseOutStructure *)&pl_robustness_features; | ||||
|             device_extensions.push_back("VK_EXT_pipeline_robustness"); | ||||
|         } | ||||
|  | ||||
| #if defined(VK_NV_cooperative_matrix2) | ||||
|         VkPhysicalDeviceCooperativeMatrix2FeaturesNV coopmat2_features {}; | ||||
|         coopmat2_features.pNext = nullptr; | ||||
|         coopmat2_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_COOPERATIVE_MATRIX_2_FEATURES_NV; | ||||
|         if (coopmat2_support) { | ||||
|             last_struct->pNext = (VkBaseOutStructure *)&coopmat2_features; | ||||
|             last_struct = (VkBaseOutStructure *)&coopmat2_features; | ||||
|             device_extensions.push_back("VK_NV_cooperative_matrix2"); | ||||
|         } | ||||
| #endif | ||||
|  | ||||
|         vkGetPhysicalDeviceFeatures2(device->physical_device, &device_features2); | ||||
|  | ||||
|         device->fp16 = device->fp16 && vk12_features.shaderFloat16; | ||||
|  | ||||
|         device->pipeline_robustness = pl_robustness_features.pipelineRobustness; | ||||
|  | ||||
|         if (coopmat2_support) { | ||||
| #if defined(VK_NV_cooperative_matrix2) | ||||
|             if (coopmat2_features.cooperativeMatrixWorkgroupScope && | ||||
|                 coopmat2_features.cooperativeMatrixFlexibleDimensions && | ||||
|                 coopmat2_features.cooperativeMatrixReductions && | ||||
|                 coopmat2_features.cooperativeMatrixConversions && | ||||
|                 coopmat2_features.cooperativeMatrixPerElementOperations && | ||||
|                 coopmat2_features.cooperativeMatrixTensorAddressing && | ||||
|                 coopmat2_features.cooperativeMatrixBlockLoads && | ||||
|                 vk12_features.bufferDeviceAddress) { | ||||
|  | ||||
|                 std::vector<VkCooperativeMatrixFlexibleDimensionsPropertiesNV> flexible_dimensions; | ||||
|                 uint32_t count = 0; | ||||
|  | ||||
|                 PFN_vkGetPhysicalDeviceCooperativeMatrixFlexibleDimensionsPropertiesNV | ||||
|                     _vkGetPhysicalDeviceCooperativeMatrixFlexibleDimensionsPropertiesNV = | ||||
|                         (PFN_vkGetPhysicalDeviceCooperativeMatrixFlexibleDimensionsPropertiesNV) | ||||
|                         vk_instance.instance.getProcAddr("vkGetPhysicalDeviceCooperativeMatrixFlexibleDimensionsPropertiesNV"); | ||||
|  | ||||
|                 _vkGetPhysicalDeviceCooperativeMatrixFlexibleDimensionsPropertiesNV(device->physical_device, &count, nullptr); | ||||
|  | ||||
|                 VkCooperativeMatrixFlexibleDimensionsPropertiesNV empty_prop {}; | ||||
|                 empty_prop.sType = VK_STRUCTURE_TYPE_COOPERATIVE_MATRIX_FLEXIBLE_DIMENSIONS_PROPERTIES_NV; | ||||
|                 flexible_dimensions.resize(count, empty_prop); | ||||
|  | ||||
|                 _vkGetPhysicalDeviceCooperativeMatrixFlexibleDimensionsPropertiesNV(device->physical_device, &count, flexible_dimensions.data()); | ||||
|  | ||||
|                 bool found_fp16_128 = false, | ||||
|                      found_fp16_256 = false, | ||||
|                      found_fp32_128 = false, | ||||
|                      found_fp32_256 = false; | ||||
|                 // need to support fp16*fp16 with fp16/fp32 accumulator, for workgroupsize 128 | ||||
|                 // with 32x16x16 and 256 with 32x32x16. | ||||
|                 for (auto &prop : flexible_dimensions) { | ||||
|                     if (prop.saturatingAccumulation == VK_FALSE && | ||||
|                         prop.scope == VK_SCOPE_WORKGROUP_KHR && | ||||
|                         prop.AType == VK_COMPONENT_TYPE_FLOAT16_KHR && | ||||
|                         prop.BType == VK_COMPONENT_TYPE_FLOAT16_KHR) { | ||||
|  | ||||
|                         if (prop.workgroupInvocations == 128 && | ||||
|                             prop.MGranularity <= 32 && | ||||
|                             prop.NGranularity <= 16 && | ||||
|                             prop.KGranularity <= 16) { | ||||
|                             if (prop.CType == VK_COMPONENT_TYPE_FLOAT16_KHR && | ||||
|                                 prop.ResultType == VK_COMPONENT_TYPE_FLOAT16_KHR) { | ||||
|                                 found_fp16_128 = true; | ||||
|                             } | ||||
|                             if (prop.CType == VK_COMPONENT_TYPE_FLOAT32_KHR && | ||||
|                                 prop.ResultType == VK_COMPONENT_TYPE_FLOAT32_KHR) { | ||||
|                                 found_fp32_128 = true; | ||||
|                             } | ||||
|                         } | ||||
|                         if (prop.workgroupInvocations == 256 && | ||||
|                             prop.MGranularity <= 32 && | ||||
|                             prop.NGranularity <= 32 && | ||||
|                             prop.KGranularity <= 16) { | ||||
|                             if (prop.CType == VK_COMPONENT_TYPE_FLOAT16_KHR && | ||||
|                                 prop.ResultType == VK_COMPONENT_TYPE_FLOAT16_KHR) { | ||||
|                                 found_fp16_256 = true; | ||||
|                             } | ||||
|                             if (prop.CType == VK_COMPONENT_TYPE_FLOAT32_KHR && | ||||
|                                 prop.ResultType == VK_COMPONENT_TYPE_FLOAT32_KHR) { | ||||
|                                 found_fp32_256 = true; | ||||
|                             } | ||||
|                         } | ||||
|                     } | ||||
|                 } | ||||
|                 if (found_fp16_128 && found_fp16_256 && | ||||
|                     found_fp32_128 && found_fp32_256 && | ||||
|                     coopmat2_props.cooperativeMatrixFlexibleDimensionsMaxDimension >= 512) { | ||||
|                     device->coopmat2 = true; | ||||
|                 } | ||||
|             } | ||||
| #endif | ||||
|         } | ||||
|  | ||||
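The per-entry test in the loop above boils down to a shape predicate per workgroup size; a condensed sketch using the same VK_NV_cooperative_matrix2 structs and enums (illustrative, not a refactor that exists in the commit):

```cpp
static bool fa_coopmat2_shape_ok(const VkCooperativeMatrixFlexibleDimensionsPropertiesNV & p,
                                 uint32_t invocations, uint32_t max_m, uint32_t max_n, uint32_t max_k) {
    return p.saturatingAccumulation == VK_FALSE &&
           p.scope == VK_SCOPE_WORKGROUP_KHR &&
           p.AType == VK_COMPONENT_TYPE_FLOAT16_KHR &&
           p.BType == VK_COMPONENT_TYPE_FLOAT16_KHR &&
           p.workgroupInvocations == invocations &&
           p.MGranularity <= max_m && p.NGranularity <= max_n && p.KGranularity <= max_k;
}
// The 128-invocation requirement corresponds to fa_coopmat2_shape_ok(p, 128, 32, 16, 16) and the
// 256-invocation one to fa_coopmat2_shape_ok(p, 256, 32, 32, 16), each needed with both fp16 and
// fp32 C/Result component types.
```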
|         if (!vk11_features.storageBuffer16BitAccess) { | ||||
|             std::cerr << "ggml_vulkan: device " << GGML_VK_NAME << idx << " does not support 16-bit storage." << std::endl; | ||||
|             throw std::runtime_error("Unsupported device"); | ||||
| @@ -2124,7 +2420,7 @@ static vk_pipeline ggml_vk_get_to_fp16(ggml_backend_vk_context * ctx, ggml_type | ||||
|     return ctx->device->pipeline_dequant[type]; | ||||
| } | ||||
|  | ||||
| static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_pipeline(ggml_backend_vk_context * ctx, ggml_type src0_type, ggml_type src1_type) { | ||||
| static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_pipeline(ggml_backend_vk_context * ctx, ggml_type src0_type, ggml_type src1_type, ggml_prec prec) { | ||||
|     VK_LOG_DEBUG("ggml_vk_get_mul_mat_mat_pipeline(" << ggml_type_name(src0_type) << ", " << ggml_type_name(src1_type) << ")"); | ||||
|     if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F32) { | ||||
|         return ctx->device->pipeline_matmul_f32; | ||||
| @@ -2132,14 +2428,23 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_pipeline(ggml_backend_vk_conte | ||||
|     if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F16) { | ||||
|         return ctx->device->pipeline_matmul_f32_f16; | ||||
|     } | ||||
|     if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) { | ||||
|         return ctx->device->pipeline_matmul_f16_f32.f32acc; | ||||
|     } | ||||
|     if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F16) { | ||||
|         return ctx->device->pipeline_matmul_f16.f32acc; | ||||
|     if (prec == GGML_PREC_DEFAULT && ctx->device->coopmat2) { | ||||
|         if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) { | ||||
|             return ctx->device->pipeline_matmul_f16_f32.f16acc; | ||||
|         } | ||||
|         if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F16) { | ||||
|             return ctx->device->pipeline_matmul_f16.f16acc; | ||||
|         } | ||||
|     } else { | ||||
|         if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) { | ||||
|             return ctx->device->pipeline_matmul_f16_f32.f32acc; | ||||
|         } | ||||
|         if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F16) { | ||||
|             return ctx->device->pipeline_matmul_f16.f32acc; | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     if (src1_type != GGML_TYPE_F32) { | ||||
|     if (src1_type != GGML_TYPE_F32 && !ctx->device->coopmat2) { | ||||
|         return nullptr; | ||||
|     } | ||||
|  | ||||
| @@ -2160,6 +2465,10 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_pipeline(ggml_backend_vk_conte | ||||
|             return nullptr; | ||||
|     } | ||||
|  | ||||
|     if (ctx->device->coopmat2) { | ||||
|         assert(src1_type == GGML_TYPE_F16); | ||||
|         return ctx->device->pipeline_dequant_mul_mat_mat_f16[src0_type].f16acc; | ||||
|     } | ||||
|     return ctx->device->pipeline_dequant_mul_mat_mat[src0_type].f32acc; | ||||
| } | ||||
|  | ||||
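Illustration of the precision plumbing introduced above: a graph node that stores GGML_PREC_F32 in op_params[0] keeps the fp32 accumulator even on a coopmat2 device, while GGML_PREC_DEFAULT lets fp16 x fp16 accumulate in fp16. A hedged sketch of a call site (not from this file):

```cpp
static vk_matmul_pipeline pick_f16_matmul(ggml_backend_vk_context * ctx, const ggml_tensor * dst) {
    // dst->op_params[0] carries the requested precision for GGML_OP_MUL_MAT
    return ggml_vk_get_mul_mat_mat_pipeline(ctx, GGML_TYPE_F16, GGML_TYPE_F16,
                                            (ggml_prec) dst->op_params[0]);
}
```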
| @@ -2844,6 +3153,16 @@ static vk_pipeline ggml_vk_guess_matmul_pipeline(ggml_backend_vk_context * ctx, | ||||
|         break; | ||||
|     } | ||||
|  | ||||
|     if (ctx->device->coopmat2) { | ||||
|         if ((m % mmp->l->wg_denoms[0]) == 0 && (n % mmp->l->wg_denoms[1]) == 0) { | ||||
|             return aligned ? mmp->a_l : mmp->l; | ||||
|         } | ||||
|         if ((m % mmp->m->wg_denoms[0]) == 0 && (n % mmp->m->wg_denoms[1]) == 0) { | ||||
|             return aligned ? mmp->a_m : mmp->m; | ||||
|         } | ||||
|         return aligned ? mmp->a_s : mmp->s; | ||||
|     } | ||||
|  | ||||
|     if (m <= 32 || n <= 32) { | ||||
|         return aligned ? mmp->a_s : mmp->s; | ||||
|     } | ||||
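A worked example of the coopmat2 tile selection above, using the large (128x256) and medium (128x128) wg_denoms set up in ggml_vk_load_shaders (compile-time arithmetic only, purely illustrative):

```cpp
static_assert(4096 % 128 == 0 && 512 % 256 == 0,
              "m=4096, n=512: divides the large tile, so the l / a_l pipeline is chosen");
static_assert(128 % 256 != 0 && 4096 % 128 == 0 && 128 % 128 == 0,
              "m=4096, n=128: misses the large tile but fits the medium one, so m / a_m is chosen");
```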
| @@ -3008,18 +3327,20 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub | ||||
|     } | ||||
|  | ||||
|     const bool x_non_contig = !ggml_vk_dim01_contiguous(src0); | ||||
|     const bool y_non_contig = !ggml_vk_dim01_contiguous(src1); | ||||
|     // Reformat and convert to fp16 if src1 is non-contiguous, or for coopmat2 for better perf | ||||
|     const bool y_non_contig = (ctx->device->coopmat2 && src1->type == GGML_TYPE_F32) || | ||||
|                               !ggml_vk_dim01_contiguous(src1); | ||||
|  | ||||
|     const bool y_f32_kernel = src1->type == GGML_TYPE_F32 && !y_non_contig; | ||||
|  | ||||
|     vk_matmul_pipeline mmp = ggml_vk_get_mul_mat_mat_pipeline(ctx, src0->type, y_non_contig ? GGML_TYPE_F16 : src1->type); | ||||
|     vk_matmul_pipeline mmp = ggml_vk_get_mul_mat_mat_pipeline(ctx, src0->type, y_non_contig ? GGML_TYPE_F16 : src1->type, (ggml_prec)dst->op_params[0]); | ||||
|  | ||||
|     const bool qx_needs_dequant = mmp == nullptr || x_non_contig; | ||||
|     const bool qy_needs_dequant = (src1->type != GGML_TYPE_F16 && !y_f32_kernel) || y_non_contig; | ||||
|  | ||||
|     if (qx_needs_dequant) { | ||||
|         // Fall back to dequant + f16 mulmat | ||||
|         mmp = ggml_vk_get_mul_mat_mat_pipeline(ctx, GGML_TYPE_F16, y_f32_kernel ? GGML_TYPE_F32 : GGML_TYPE_F16); | ||||
|         mmp = ggml_vk_get_mul_mat_mat_pipeline(ctx, GGML_TYPE_F16, y_f32_kernel ? GGML_TYPE_F32 : GGML_TYPE_F16, (ggml_prec)dst->op_params[0]); | ||||
|     } | ||||
|  | ||||
|     // Not implemented | ||||
| @@ -3930,6 +4251,167 @@ static void ggml_vk_mul_mat_id(ggml_backend_vk_context * ctx, vk_context& subctx | ||||
|     } | ||||
| } | ||||
|  | ||||
| static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * q, const ggml_tensor * k, const ggml_tensor * v, const ggml_tensor * mask, ggml_tensor * dst, bool dryrun = false) { | ||||
|     VK_LOG_DEBUG("ggml_vk_flash_attn((" << q << ", name=" << q->name << ", type=" << q->type << ", ne0=" << q->ne[0] << ", ne1=" << q->ne[1] << ", ne2=" << q->ne[2] << ", ne3=" << q->ne[3] << ", nb0=" << q->nb[0] << ", nb1=" << q->nb[1] << ", nb2=" << q->nb[2] << ", nb3=" << q->nb[3]; | ||||
|     std::cerr << "), (" << k << ", name=" << k->name << ", type=" << k->type << ", ne0=" << k->ne[0] << ", ne1=" << k->ne[1] << ", ne2=" << k->ne[2] << ", ne3=" << k->ne[3] << ", nb0=" << k->nb[0] << ", nb1=" << k->nb[1] << ", nb2=" << k->nb[2] << ", nb3=" << k->nb[3]; | ||||
|     std::cerr << "), (" << v << ", name=" << v->name << ", type=" << v->type << ", ne0=" << v->ne[0] << ", ne1=" << v->ne[1] << ", ne2=" << v->ne[2] << ", ne3=" << v->ne[3] << ", nb0=" << v->nb[0] << ", nb1=" << v->nb[1] << ", nb2=" << v->nb[2] << ", nb3=" << v->nb[3]; | ||||
|     std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3]; | ||||
|     std::cerr << "), " << (dryrun ? "dryrun" : "") << ")"); | ||||
|  | ||||
|     GGML_TENSOR_LOCALS(int64_t, neq, q,   ne) | ||||
|     GGML_TENSOR_LOCALS(size_t,  nbq, q,   nb) | ||||
|     GGML_TENSOR_LOCALS(int64_t, nek, k,   ne) | ||||
|     GGML_TENSOR_LOCALS(size_t,  nbk, k,   nb) | ||||
|     GGML_TENSOR_LOCALS(int64_t, nev, v,   ne) | ||||
|     GGML_TENSOR_LOCALS(size_t,  nbv, v,   nb) | ||||
|     GGML_TENSOR_LOCALS(int64_t, ne,  dst, ne) | ||||
|     GGML_TENSOR_LOCALS(size_t,  nb,  dst, nb) | ||||
|  | ||||
|     const uint32_t nem1 = mask ? mask->ne[1] : 0; | ||||
|     const uint32_t nbm1 = mask ? mask->nb[1] : 0; | ||||
|  | ||||
|     const uint32_t D = neq0; | ||||
|     const uint32_t N = neq1; | ||||
|     const uint32_t KV = nek1; | ||||
|  | ||||
|     GGML_ASSERT(ne0 == D); | ||||
|     GGML_ASSERT(ne2 == N); | ||||
|  | ||||
|     // input tensor rows must be contiguous | ||||
|     GGML_ASSERT(nbq0 == ggml_type_size(q->type)); | ||||
|     GGML_ASSERT(nbk0 == ggml_type_size(k->type)); | ||||
|     GGML_ASSERT(nbv0 == ggml_type_size(v->type)); | ||||
|  | ||||
|     GGML_ASSERT(neq0 == D); | ||||
|     GGML_ASSERT(nek0 == D); | ||||
|     GGML_ASSERT(nev0 == D); | ||||
|  | ||||
|     GGML_ASSERT(neq1 == N); | ||||
|     GGML_ASSERT(nev0 == D); | ||||
|  | ||||
|     GGML_ASSERT(nev1 == nek1); | ||||
|  | ||||
|     // dst cannot be transposed or permuted | ||||
|     GGML_ASSERT(nb0 == sizeof(float)); | ||||
|     GGML_ASSERT(nb0 <= nb1); | ||||
|     GGML_ASSERT(nb1 <= nb2); | ||||
|     GGML_ASSERT(nb2 <= nb3); | ||||
|  | ||||
|     assert(dst->type == GGML_TYPE_F32); | ||||
|     assert(q->type == GGML_TYPE_F32); | ||||
|     assert(k->type == v->type); | ||||
|  | ||||
|     vk_pipeline *pipelines; | ||||
|     // XXX TODO other backends may be changing accumulator precision to default to f32 soon | ||||
|     bool f32acc = dst->op_params[3] == GGML_PREC_F32; | ||||
|     bool small_rows = N <= flash_attention_num_small_rows; | ||||
|     switch (D) { | ||||
|     case 64: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D64[k->type][f32acc][small_rows][0]; break; | ||||
|     case 80: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D80[k->type][f32acc][small_rows][0]; break; | ||||
|     case 96: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D96[k->type][f32acc][small_rows][0]; break; | ||||
|     case 112: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D112[k->type][f32acc][small_rows][0]; break; | ||||
|     case 128: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D128[k->type][f32acc][small_rows][0]; break; | ||||
|     case 256: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D256[k->type][f32acc][small_rows][0]; break; | ||||
|     default: | ||||
|         assert(!"unsupported D value"); | ||||
|         return; | ||||
|     } | ||||
|     assert(pipelines); | ||||
|  | ||||
|     bool aligned = (KV % pipelines[1]->align) == 0; | ||||
|     vk_pipeline pipeline = pipelines[aligned]; | ||||
|     assert(pipeline); | ||||
|  | ||||
|     if (dryrun) { | ||||
|         // Request descriptor sets | ||||
|         ggml_pipeline_request_descriptor_sets(ctx->device, pipeline, 1); | ||||
|         return; | ||||
|     } | ||||
|  | ||||
|     float scale         = 1.0f; | ||||
|     float max_bias      = 0.0f; | ||||
|     float logit_softcap = 0.0f; | ||||
|  | ||||
|     memcpy(&scale,         (const float *) dst->op_params + 0, sizeof(float)); | ||||
|     memcpy(&max_bias,      (const float *) dst->op_params + 1, sizeof(float)); | ||||
|     memcpy(&logit_softcap, (const float *) dst->op_params + 2, sizeof(float)); | ||||
|  | ||||
|     if (logit_softcap != 0) { | ||||
|         scale /= logit_softcap; | ||||
|     } | ||||
|  | ||||
|     const uint32_t n_head_kv   = neq2; | ||||
|     const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head_kv)); | ||||
|     const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2); | ||||
|     const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); | ||||
|  | ||||
|     ggml_vk_sync_buffers(subctx); | ||||
|  | ||||
|     vk_buffer d_Q, d_K, d_V, d_D, d_M; | ||||
|     uint64_t q_buf_offset, k_buf_offset, v_buf_offset, d_buf_offset, m_buf_offset; | ||||
|  | ||||
|     bool Q_uma = false, K_uma = false, V_uma = false, D_uma = false, M_uma = false; | ||||
|  | ||||
|     if (ctx->device->uma) { | ||||
|         ggml_vk_host_get(ctx->device, q->data, d_Q, q_buf_offset); | ||||
|         ggml_vk_host_get(ctx->device, k->data, d_K, k_buf_offset); | ||||
|         ggml_vk_host_get(ctx->device, v->data, d_V, v_buf_offset); | ||||
|         ggml_vk_host_get(ctx->device, dst->data, d_D, d_buf_offset); | ||||
|         Q_uma = d_Q != nullptr; | ||||
|         K_uma = d_K != nullptr; | ||||
|         V_uma = d_V != nullptr; | ||||
|         D_uma = d_D != nullptr; | ||||
|         if (mask) { | ||||
|             ggml_vk_host_get(ctx->device, mask->data, d_M, m_buf_offset); | ||||
|             M_uma = d_M != nullptr; | ||||
|         } | ||||
|     } | ||||
|  | ||||
|  | ||||
|     ggml_backend_vk_buffer_context * d_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context; | ||||
|     ggml_backend_vk_buffer_context * q_buf_ctx = (ggml_backend_vk_buffer_context *)q->buffer->context; | ||||
|     ggml_backend_vk_buffer_context * k_buf_ctx = (ggml_backend_vk_buffer_context *)k->buffer->context; | ||||
|     ggml_backend_vk_buffer_context * v_buf_ctx = (ggml_backend_vk_buffer_context *)v->buffer->context; | ||||
|  | ||||
|     if (!Q_uma) { | ||||
|         d_Q = q_buf_ctx->dev_buffer; | ||||
|         q_buf_offset = vk_tensor_offset(q) + q->view_offs; | ||||
|     } | ||||
|     if (!K_uma) { | ||||
|         d_K = k_buf_ctx->dev_buffer; | ||||
|         k_buf_offset = vk_tensor_offset(k) + k->view_offs; | ||||
|     } | ||||
|     if (!V_uma) { | ||||
|         d_V = v_buf_ctx->dev_buffer; | ||||
|         v_buf_offset = vk_tensor_offset(v) + v->view_offs; | ||||
|     } | ||||
|     if (!D_uma) { | ||||
|         d_D = d_buf_ctx->dev_buffer; | ||||
|         d_buf_offset = vk_tensor_offset(dst) + dst->view_offs; | ||||
|     } | ||||
|  | ||||
|     if (!M_uma) { | ||||
|         d_M = d_Q; | ||||
|         m_buf_offset = q_buf_offset; | ||||
|         if (mask) { | ||||
|             ggml_backend_vk_buffer_context * m_buf_ctx = (ggml_backend_vk_buffer_context*)mask->buffer->context; | ||||
|             d_M = m_buf_ctx->dev_buffer; | ||||
|             m_buf_offset = vk_tensor_offset(mask) + mask->view_offs; | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     const vk_flash_attn_push_constants pc = { N, KV, (uint32_t)ne1, (uint32_t)ne2, (uint32_t)ne3, (uint32_t)neq2, (uint32_t)neq3, (uint32_t)nek2, (uint32_t)nek3, (uint32_t)nev2, (uint32_t)nev3, nem1, (uint32_t)nbq2, (uint32_t)nbq3, (uint32_t)nbk2, (uint32_t)nbk3, (uint32_t)nbv2, (uint32_t)nbv3, nbm1, scale, max_bias, logit_softcap, mask != nullptr, n_head_log2, m0, m1 }; | ||||
|     ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, | ||||
|                                 { | ||||
|                                     vk_subbuffer{d_Q, q_buf_offset, VK_WHOLE_SIZE}, | ||||
|                                     vk_subbuffer{d_K, k_buf_offset, VK_WHOLE_SIZE}, | ||||
|                                     vk_subbuffer{d_V, v_buf_offset, VK_WHOLE_SIZE}, | ||||
|                                     vk_subbuffer{d_M, m_buf_offset, VK_WHOLE_SIZE}, | ||||
|                                     vk_subbuffer{d_D, d_buf_offset, VK_WHOLE_SIZE}, | ||||
|                                 }, | ||||
|                                 sizeof(vk_flash_attn_push_constants), &pc, { (uint32_t)neq1, (uint32_t)neq2, (uint32_t)neq3 }); | ||||
| } | ||||
|  | ||||
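For context on the m0/m1 push constants computed in ggml_vk_flash_attn above: they are the ALiBi slope bases, following ggml's usual convention where the first n_head_log2 heads use powers of m0 and the remaining heads use odd powers of m1. A hedged sketch of the per-head slope (illustrative; the actual slope is applied in the shader):

```cpp
#include <cmath>
#include <cstdint>

static float alibi_slope_example(uint32_t h, float max_bias, uint32_t n_head_log2) {
    const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
    return h < n_head_log2 ? powf(m0, (float) (h + 1))
                           : powf(m1, (float) (2 * (h - n_head_log2) + 1));
}
// e.g. max_bias = 8.0f, n_head_kv = 20 -> n_head_log2 = 16, m0 = 2^-0.5 ~= 0.707, m1 = 2^-0.25 ~= 0.841
```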
| static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, ggml_op op) { | ||||
|     switch (op) { | ||||
|     case GGML_OP_GET_ROWS: | ||||
| @@ -5044,16 +5526,16 @@ static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t | ||||
|     ggml_vk_buffer_write(d_Y, 0, y, sizeof(Y_TYPE) * k * n * batch); | ||||
|  | ||||
|     vk_context subctx = ggml_vk_create_context(ctx, ctx->device->compute_queue); | ||||
|     ggml_vk_ctx_begin(ctx->device, subctx); | ||||
|     for (size_t i = 0; i < num_it; i++) { | ||||
|         ggml_vk_ctx_begin(ctx->device, subctx); | ||||
|         ggml_vk_matmul( | ||||
|             ctx, subctx, p, ggml_vk_subbuffer(d_X), ggml_vk_subbuffer(d_Y), ggml_vk_subbuffer(d_D), ggml_vk_subbuffer(ctx->prealloc_split_k), | ||||
|             m, n, k, | ||||
|             k, k, m, k*m, k*n, m*n, | ||||
|             split_k, batch, batch, batch, 1, 1 | ||||
|         ); | ||||
|         ggml_vk_ctx_end(subctx); | ||||
|     } | ||||
|     ggml_vk_ctx_end(subctx); | ||||
|  | ||||
|     auto begin = std::chrono::high_resolution_clock::now(); | ||||
|     ggml_vk_submit(subctx, ctx->fence); | ||||
| @@ -5391,16 +5873,16 @@ static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m, | ||||
|     ggml_vk_buffer_write(y_buf, 0, y, y_sz); | ||||
|  | ||||
|     vk_context subctx = ggml_vk_create_context(ctx, ctx->device->compute_queue); | ||||
|     ggml_vk_ctx_begin(ctx->device, subctx); | ||||
|     for (size_t i = 0; i < num_it; i++) { | ||||
|         ggml_vk_ctx_begin(ctx->device, subctx); | ||||
|         ggml_vk_matmul( | ||||
|             ctx, subctx, p, ggml_vk_subbuffer(qx_buf), ggml_vk_subbuffer(y_buf), ggml_vk_subbuffer(d_buf), ggml_vk_subbuffer(ctx->prealloc_split_k), | ||||
|             m, n, k, | ||||
|             k, k, m, k*m, k*n, m*n, | ||||
|             split_k, batch, batch, batch, 1, 1 | ||||
|         ); | ||||
|         ggml_vk_ctx_end(subctx); | ||||
|     } | ||||
|     ggml_vk_ctx_end(subctx); | ||||
|  | ||||
|     auto begin = std::chrono::high_resolution_clock::now(); | ||||
|  | ||||
| @@ -5621,7 +6103,8 @@ static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx) { | ||||
|         4096, 512, 11008, | ||||
|         32000, 512, 4096, | ||||
|     }; | ||||
|     const size_t num_it = 1; | ||||
|     const size_t num_it = 100; | ||||
|  | ||||
|     for (size_t i = 0; i < vals.size(); i += 3) { | ||||
|         ggml_vk_test_matmul<ggml_fp16_t, float>(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 0); | ||||
|         ggml_vk_test_matmul<ggml_fp16_t, float>(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 1); | ||||
| @@ -5676,6 +6159,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod | ||||
|     const ggml_tensor * src0 = node->src[0]; | ||||
|     const ggml_tensor * src1 = node->src[1]; | ||||
|     const ggml_tensor * src2 = node->src[2]; | ||||
|     const ggml_tensor * src3 = node->src[3]; | ||||
|  | ||||
|     switch (node->op) { | ||||
|     // Return on empty ops to avoid generating a compute_ctx and setting exit_tensor | ||||
| @@ -5728,6 +6212,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod | ||||
|     case GGML_OP_TIMESTEP_EMBEDDING: | ||||
|     case GGML_OP_POOL_2D: | ||||
|     case GGML_OP_LEAKY_RELU: | ||||
|     case GGML_OP_FLASH_ATTN_EXT: | ||||
|         break; | ||||
|     default: | ||||
|         std::cerr << "ggml_vulkan: Error: Missing op: " << ggml_op_name(node->op) << std::endl; | ||||
| @@ -5920,6 +6405,11 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod | ||||
|     case GGML_OP_MUL_MAT_ID: | ||||
|         ggml_vk_mul_mat_id(ctx, compute_ctx, src0, src1, src2, node, dryrun); | ||||
|  | ||||
|         break; | ||||
|  | ||||
|     case GGML_OP_FLASH_ATTN_EXT: | ||||
|         ggml_vk_flash_attn(ctx, compute_ctx, src0, src1, src2, src3, node, dryrun); | ||||
|  | ||||
|         break; | ||||
|     default: | ||||
|         return false; | ||||
| @@ -6020,6 +6510,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor * | ||||
|         break; | ||||
|     case GGML_OP_MUL_MAT: | ||||
|     case GGML_OP_MUL_MAT_ID: | ||||
|     case GGML_OP_FLASH_ATTN_EXT: | ||||
|         buf = tensor->buffer; | ||||
|  | ||||
|         break; | ||||
| @@ -6751,6 +7242,57 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm | ||||
|  | ||||
|                 return true; | ||||
|             } break; | ||||
|         case GGML_OP_FLASH_ATTN_EXT: | ||||
|             { | ||||
|                 ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context; | ||||
|                 if (!ggml_vk_get_device(ctx->device)->coopmat2) { | ||||
|                     return false; | ||||
|                 } | ||||
|                 switch (op->src[0]->ne[0]) { | ||||
|                 case 64: | ||||
|                 case 80: | ||||
|                 case 96: | ||||
|                 case 112: | ||||
|                 case 128: | ||||
|                 case 256: | ||||
|                     break; | ||||
|                 default: | ||||
|                     return false; | ||||
|                 } | ||||
|                 if (op->src[0]->type != GGML_TYPE_F32) { | ||||
|                     return false; | ||||
|                 } | ||||
|                 if (op->type != GGML_TYPE_F32) { | ||||
|                     return false; | ||||
|                 } | ||||
|                 if (op->src[3] && op->src[3]->type != GGML_TYPE_F16) { | ||||
|                     return false; | ||||
|                 } | ||||
|                 // It's straightforward to support different K/V dequant, but would | ||||
|                 // significantly increase the number of pipelines | ||||
|                 if (op->src[1]->type != op->src[2]->type) { | ||||
|                     return false; | ||||
|                 } | ||||
|                 switch (op->src[1]->type) { | ||||
|                 case GGML_TYPE_F16: | ||||
|                 case GGML_TYPE_Q4_0: | ||||
|                 case GGML_TYPE_Q4_1: | ||||
|                 case GGML_TYPE_Q5_0: | ||||
|                 case GGML_TYPE_Q5_1: | ||||
|                 case GGML_TYPE_Q8_0: | ||||
|                 // K dequants currently disabled because D dimension is rounded up to 256 and runs inefficiently | ||||
|                 //case GGML_TYPE_Q2_K: | ||||
|                 //case GGML_TYPE_Q3_K: | ||||
|                 //case GGML_TYPE_Q4_K: | ||||
|                 //case GGML_TYPE_Q5_K: | ||||
|                 //case GGML_TYPE_Q6_K: | ||||
|                 case GGML_TYPE_IQ4_NL: | ||||
|                     break; | ||||
|                 default: | ||||
|                     return false; | ||||
|                 } | ||||
|                 return true; | ||||
|             } | ||||
|         case GGML_OP_GET_ROWS: | ||||
|             { | ||||
|                 switch (op->src[0]->type) { | ||||
| @@ -7065,6 +7607,7 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) { | ||||
|     ggml_tensor * src0 = tensor->src[0]; | ||||
|     ggml_tensor * src1 = tensor->src[1]; | ||||
|     ggml_tensor * src2 = tensor->src[2]; | ||||
|     ggml_tensor * src3 = tensor->src[3]; | ||||
|  | ||||
|     struct ggml_init_params iparams = { | ||||
|         /*.mem_size   =*/ 2ul*1024ul*1024ul*1024ul, | ||||
| @@ -7077,15 +7620,18 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) { | ||||
|     struct ggml_tensor * src0_clone = nullptr; | ||||
|     struct ggml_tensor * src1_clone = nullptr; | ||||
|     struct ggml_tensor * src2_clone = nullptr; | ||||
|     struct ggml_tensor * src3_clone = nullptr; | ||||
|     struct ggml_tensor * tensor_clone = nullptr; | ||||
|  | ||||
|     size_t src0_size; | ||||
|     size_t src1_size; | ||||
|     size_t src2_size; | ||||
|     size_t src3_size; | ||||
|  | ||||
|     void * src0_buffer = nullptr; | ||||
|     void * src1_buffer = nullptr; | ||||
|     void * src2_buffer = nullptr; | ||||
|     void * src3_buffer = nullptr; | ||||
|  | ||||
|     if (src0 != nullptr) { | ||||
|         src0_clone = ggml_dup_tensor(ggml_ctx, src0); | ||||
| @@ -7213,8 +7759,53 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) { | ||||
|             ggml_vk_print_tensor(src2, "src2"); | ||||
|         } | ||||
|     } | ||||
|     if (src3 != nullptr) { | ||||
|         src3_clone = ggml_dup_tensor(ggml_ctx, src3); | ||||
|  | ||||
|     if (tensor->op == GGML_OP_MUL_MAT) { | ||||
|         src3_size = ggml_nbytes(src3); | ||||
|  | ||||
|         src3_buffer = malloc(src3_size); | ||||
|         src3_clone->data = src3_buffer; | ||||
|         if (ggml_backend_buffer_is_host(src3->buffer)) { | ||||
|             memcpy(src3_clone->data, src3->data, src3_size); | ||||
|             memcpy(src3_clone->nb, src3->nb, sizeof(size_t) * GGML_MAX_DIMS); | ||||
|         } else if (ggml_backend_buffer_is_vk(src3->buffer)) { | ||||
|             ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)src3->buffer->context; | ||||
|             vk_buffer& buffer_gpu = buf_ctx->dev_buffer; | ||||
|             uint64_t offset = vk_tensor_offset(src3) + src3->view_offs; | ||||
|             if (!ggml_is_contiguous(src3) && ggml_vk_dim01_contiguous(src3)) { | ||||
|                 for (int i3 = 0; i3 < src3->ne[3]; i3++) { | ||||
|                     for (int i2 = 0; i2 < src3->ne[2]; i2++) { | ||||
|                         const int idx = i3*src3->ne[2] + i2; | ||||
|                         ggml_vk_buffer_read(buffer_gpu, offset + idx * src3->nb[2], ((char *)src3_clone->data + idx * src3_clone->nb[2]), src3->ne[1] * src3->nb[1]); | ||||
|                     } | ||||
|                 } | ||||
|  | ||||
|                 src3_clone->nb[0] = src3->nb[0]; | ||||
|                 src3_clone->nb[1] = src3->nb[1]; | ||||
|                 for (int i = 2; i < GGML_MAX_DIMS; i++) { | ||||
|                     src3_clone->nb[i] = src3_clone->nb[i - 1]*src3_clone->ne[i - 1]; | ||||
|                 } | ||||
|             } else { | ||||
|                 if (offset + src3_size >= buffer_gpu->size) { | ||||
|                     src3_size = buffer_gpu->size - offset; | ||||
|                 } | ||||
|                 ggml_vk_buffer_read(buffer_gpu, offset, src3_clone->data, src3_size); | ||||
|                 memcpy(src3_clone->nb, src3->nb, sizeof(size_t) * GGML_MAX_DIMS); | ||||
|             } | ||||
|         } else { | ||||
|             GGML_ABORT("fatal error"); | ||||
|         } | ||||
|  | ||||
|         if (vk_output_tensor > 0 && vk_output_tensor == check_counter) { | ||||
|             ggml_vk_print_tensor(src3, "src3"); | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     if (tensor->op == GGML_OP_FLASH_ATTN_EXT) { | ||||
|         const float *params = (const float *)tensor->op_params; | ||||
|         tensor_clone = ggml_flash_attn_ext(ggml_ctx, src0_clone, src1_clone, src2_clone, src3_clone, params[0], params[1], params[2]); | ||||
|     } else if (tensor->op == GGML_OP_MUL_MAT) { | ||||
|         tensor_clone = ggml_mul_mat(ggml_ctx, src0_clone, src1_clone); | ||||
|     } else if (tensor->op == GGML_OP_MUL_MAT_ID) { | ||||
|         tensor_clone = ggml_mul_mat_id(ggml_ctx, src0_clone, src1_clone, src2_clone); | ||||
|   | ||||
| @@ -1,7 +1,9 @@ | ||||
| find_package (Threads REQUIRED) | ||||
| find_package(Vulkan COMPONENTS glslc REQUIRED) | ||||
|  | ||||
| set(TARGET vulkan-shaders-gen) | ||||
| add_executable(${TARGET} vulkan-shaders-gen.cpp) | ||||
| install(TARGETS ${TARGET} RUNTIME) | ||||
| target_compile_features(${TARGET} PRIVATE cxx_std_17) | ||||
| target_link_libraries(vulkan-shaders-gen PUBLIC Threads::Threads) | ||||
| target_link_libraries(vulkan-shaders-gen PRIVATE Vulkan::Vulkan) | ||||
|   | ||||
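vulkan-shaders-gen now links Vulkan::Vulkan so it can include vulkan_core.h (added in a hunk further down); as used later in process_shaders(), the VK_NV_cooperative_matrix2 macro from those headers decides whether the *_cm2 shader variants are generated at all, so builds against older Vulkan SDKs simply skip them. A minimal sketch of that guard, assuming only that the Vulkan headers are on the include path:

#include <vulkan/vulkan_core.h>

#if defined(VK_NV_cooperative_matrix2)
    // headers are new enough: also generate the coopmat2 (*_cm2.spv) shader variants
#else
    // older Vulkan headers: the extension macro is missing, skip coopmat2 generation
#endif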
							
								
								
									
ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp (new file, 305 lines)
							| @@ -0,0 +1,305 @@ | ||||
|  | ||||
| #include "types.comp" | ||||
|  | ||||
| layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ4_0 { | ||||
|    block_q4_0_packed16 block; | ||||
| }; | ||||
|  | ||||
| float16_t dequantFuncQ4_0(const in decodeBufQ4_0 bl, const in uint blockCoords[2], const in uint coordInBlock[2]) | ||||
| { | ||||
|     const float16_t d = bl.block.d; | ||||
|     const uint idx = coordInBlock[1]; | ||||
|     const uint shift = (idx & 0x10) >> 2; | ||||
|     uint32_t qs = unpack8(uint32_t(bl.block.qs[(idx & 0xE) >> 1]))[idx & 1]; | ||||
|     qs >>= shift; | ||||
|     qs &= 0xF; | ||||
|     float16_t ret = (float16_t(qs) - float16_t(8)) * d; | ||||
|     return ret; | ||||
| } | ||||
|  | ||||
| layout(buffer_reference, std430, buffer_reference_align = 4) buffer decodeBufQ4_1 { | ||||
|    block_q4_1 block; | ||||
| }; | ||||
|  | ||||
| float16_t dequantFuncQ4_1(const in decodeBufQ4_1 bl, const in uint blockCoords[2], const in uint coordInBlock[2]) | ||||
| { | ||||
|     const float16_t d = bl.block.d; | ||||
|     const float16_t m = bl.block.m; | ||||
|     const uint idx = coordInBlock[1]; | ||||
|     const uint iqs = idx & 0xF; | ||||
|     const uint shift = (idx & 0x10) >> 2; | ||||
|     uint32_t qs = bl.block.qs[iqs]; | ||||
|     qs >>= shift; | ||||
|     qs &= 0xF; | ||||
|     float16_t ret = float16_t(qs) * d + m; | ||||
|     return ret; | ||||
| } | ||||
|  | ||||
| layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ5_0 { | ||||
|    block_q5_0 block; | ||||
| }; | ||||
|  | ||||
| float16_t dequantFuncQ5_0(const in decodeBufQ5_0 bl, const in uint blockCoords[2], const in uint coordInBlock[2]) | ||||
| { | ||||
|     const float16_t d = bl.block.d; | ||||
|     const uint idx = coordInBlock[1]; | ||||
|     const uint iqs = idx & 0xF; | ||||
|  | ||||
|     const uint uint_qh = uint(bl.block.qh[1]) << 16 | bl.block.qh[0]; | ||||
|     const uint qh = ((uint_qh >> idx) << 4) & 0x10; | ||||
|  | ||||
|     const uint shift = (idx & 0x10) >> 2; | ||||
|     uint32_t qs = bl.block.qs[iqs]; | ||||
|     qs >>= shift; | ||||
|     qs &= 0xF; | ||||
|  | ||||
|     float16_t ret = (float16_t(qs | qh) - float16_t(16)) * d; | ||||
|     return ret; | ||||
| } | ||||
|  | ||||
| layout(buffer_reference, std430, buffer_reference_align = 8) buffer decodeBufQ5_1 { | ||||
|    block_q5_1 block; | ||||
| }; | ||||
|  | ||||
| float16_t dequantFuncQ5_1(const in decodeBufQ5_1 bl, const in uint blockCoords[2], const in uint coordInBlock[2]) | ||||
| { | ||||
|     const float16_t d = bl.block.d; | ||||
|     const float16_t m = bl.block.m; | ||||
|     const uint idx = coordInBlock[1]; | ||||
|     const uint iqs = idx & 0xF; | ||||
|  | ||||
|     const uint uint_qh = bl.block.qh; | ||||
|     const uint qh = ((uint_qh >> idx) << 4) & 0x10; | ||||
|  | ||||
|     const uint shift = (idx & 0x10) >> 2; | ||||
|     uint32_t qs = bl.block.qs[iqs]; | ||||
|     qs >>= shift; | ||||
|     qs &= 0xF; | ||||
|  | ||||
|     float16_t ret = float16_t(qs | qh) * d + m; | ||||
|     return ret; | ||||
| } | ||||
|  | ||||
| layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ8_0 { | ||||
|    block_q8_0_packed16 block; | ||||
| }; | ||||
|  | ||||
| float16_t dequantFuncQ8_0(const in decodeBufQ8_0 bl, const in uint blockCoords[2], const in uint coordInBlock[2]) | ||||
| { | ||||
|     const float16_t d = bl.block.d; | ||||
|     const uint idx = coordInBlock[1]; | ||||
|     const uint iqs = idx; | ||||
|  | ||||
|     // Load 16b and select the byte for this element | ||||
|     int32_t qs = unpack8(int32_t(bl.block.qs[(iqs & 0x1E) >> 1]))[iqs & 1]; | ||||
|     float16_t ret = float16_t(qs) * d; | ||||
|     return ret; | ||||
| } | ||||
|  | ||||
| layout(buffer_reference, std430, buffer_reference_align = 4) buffer decodeBufQ2_K { | ||||
|    block_q2_K block; | ||||
| }; | ||||
|  | ||||
| float16_t dequantFuncQ2_K(const in decodeBufQ2_K bl, const in uint blockCoords[2], const in uint coordInBlock[2]) | ||||
| { | ||||
|     const f16vec2 d = bl.block.d; | ||||
|     const uint idx = coordInBlock[1]; | ||||
|     const uint iqs = idx; | ||||
|  | ||||
|     const uint qsi = (iqs / 128) * 32 + (iqs % 32);     // 0..31 | ||||
|     const uint scalesi = iqs / 16;                      // 0..15 | ||||
|     const uint qsshift = ((iqs % 128) / 32) * 2;        // 0,2,4,6 | ||||
|  | ||||
|     uint32_t qs = bl.block.qs[qsi]; | ||||
|     const uint scales = bl.block.scales[scalesi]; | ||||
|     float16_t ret = d.x * float16_t(scales & 0xF) * float16_t((qs >> qsshift) & 3) - d.y * float16_t(scales >> 4); | ||||
|     return ret; | ||||
| } | ||||
|  | ||||
| layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ3_K { | ||||
|    block_q3_K block; | ||||
| }; | ||||
|  | ||||
| float16_t dequantFuncQ3_K(const in decodeBufQ3_K bl, const in uint blockCoords[2], const in uint coordInBlock[2]) | ||||
| { | ||||
|     const uint idx = coordInBlock[1]; | ||||
|     const uint iqs = idx; | ||||
|  | ||||
|     const uint n = iqs / 128;                    // 0,1 | ||||
|     const uint qsi = n * 32 + (iqs % 32);        // 0..63 | ||||
|     const uint hmi =          (iqs % 32);        // 0..31 | ||||
|     const uint j = (iqs % 128) / 8;              // 0..15 | ||||
|     const uint is = iqs / 16;                    // 0..15 | ||||
|     const uint halfsplit = ((iqs % 128) / 32);   // 0,1,2,3 | ||||
|     const uint qsshift = halfsplit * 2;          // 0,2,4,6 | ||||
|     const uint m = 1 << (4 * n + halfsplit);     // 1,2,4,8,16,32,64,128 | ||||
|  | ||||
|     uint32_t scaleidx0 = (is < 8) ? is : (is-8); | ||||
|     uint32_t scaleidx0shift = (is < 8) ? 0 : 4; | ||||
|     uint32_t scaleidx1 = is + 8 - (is/4)*4; | ||||
|     uint32_t scaleidx1shift = (is/4)*2; | ||||
|  | ||||
|     const int8_t us = int8_t(((bl.block.scales[scaleidx0] >> scaleidx0shift) & 0xF) | (((bl.block.scales[scaleidx1] >> scaleidx1shift) & 3) << 4)); | ||||
|  | ||||
|     const float16_t dl = bl.block.d * float16_t(us - 32); | ||||
|  | ||||
|     float16_t ret = dl * float16_t(int8_t((bl.block.qs[qsi    ] >> qsshift) & 3) - (((bl.block.hmask[hmi    ] & m) != 0) ? 0 : 4)); | ||||
|  | ||||
|     return ret; | ||||
| } | ||||
|  | ||||
| layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ4_K { | ||||
|    block_q4_K block; | ||||
| }; | ||||
|  | ||||
| float16_t dequantFuncQ4_K(const in decodeBufQ4_K bl, const in uint blockCoords[2], const in uint coordInBlock[2]) | ||||
| { | ||||
|     const uint idx = coordInBlock[1]; | ||||
|     const uint iqs = idx; | ||||
|  | ||||
|     const uint n = iqs / 64;                   // 0,1,2,3 | ||||
|     const uint b = (iqs % 64) / 32;            // 0,1 | ||||
|     const uint is = (idx & 0xE0) >> 5;         // 0..7 | ||||
|     const uint qsi = n * 32 + (iqs % 32);      // 0..127 | ||||
|  | ||||
|     const f16vec2 loadd = bl.block.d; | ||||
|  | ||||
|     uint32_t sc; | ||||
|     uint32_t mbyte; | ||||
|  | ||||
|     uint32_t scidx0 = (is < 4) ? is : (is + 4); | ||||
|     uint32_t scidx1 = (is < 4) ? is : (is - 4); | ||||
|     uint32_t scidxmask1 = (is < 4) ? 0x30 : 0xC0; | ||||
|     uint32_t scidxshift1 = (is < 4) ? 0 : 2; | ||||
|     uint32_t mbidx0 = is + 4; | ||||
|     uint32_t mbidx1 = (is < 4) ? is + 4 : is; | ||||
|     uint32_t mbidxmask0 = (is < 4) ? 0xF : 0xF0; | ||||
|     uint32_t mbidxshift0 = (is < 4) ? 0 : 4; | ||||
|     uint32_t mbidxmask1 = (is < 4) ? 0x30 : 0xC0; | ||||
|     uint32_t mbidxshift1 = (is < 4) ? 0 : 2; | ||||
|  | ||||
|     sc    = uint8_t((bl.block.scales[scidx0] & 0xF)                         | ((bl.block.scales[scidx1] & scidxmask1) >> scidxshift1)); | ||||
|     mbyte = uint8_t(((bl.block.scales[mbidx0] & mbidxmask0) >> mbidxshift0) | ((bl.block.scales[mbidx1] & mbidxmask1) >> mbidxshift1)); | ||||
|  | ||||
|     const float16_t d = loadd.x * float16_t(sc); | ||||
|     const float16_t m = loadd.y * float16_t(mbyte); | ||||
|  | ||||
|     uint32_t dmask = 0xF << (b * 4); | ||||
|  | ||||
|     float16_t ret = d * float16_t((bl.block.qs[qsi    ] & dmask) >> (b * 4)) - m; | ||||
|  | ||||
|     return ret; | ||||
| } | ||||
|  | ||||
| layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ5_K { | ||||
|    block_q5_K block; | ||||
| }; | ||||
|  | ||||
| float16_t dequantFuncQ5_K(const in decodeBufQ5_K bl, const in uint blockCoords[2], const in uint coordInBlock[2]) | ||||
| { | ||||
|     const uint idx = coordInBlock[1]; | ||||
|     const uint iqs = idx; | ||||
|  | ||||
|     const uint n = iqs / 64;                   // 0,1,2,3 | ||||
|     const uint b = (iqs % 64) / 32;            // 0,1 | ||||
|     const uint is = (idx & 0xE0) >> 5;         // 0..7 | ||||
|     const uint qsi = n * 32 + (iqs % 32);      // 0..127 | ||||
|     const uint qhi = (iqs % 32);               // 0..31 | ||||
|  | ||||
|     const uint8_t hm = uint8_t(1 << (iqs / 32)); | ||||
|  | ||||
|     const f16vec2 loadd = bl.block.d; | ||||
|  | ||||
|     uint32_t sc; | ||||
|     uint32_t mbyte; | ||||
|  | ||||
|     uint32_t scidx0 = (is < 4) ? is : (is + 4); | ||||
|     uint32_t scidx1 = (is < 4) ? is : (is - 4); | ||||
|     uint32_t scidxmask1 = (is < 4) ? 0x30 : 0xC0; | ||||
|     uint32_t scidxshift1 = (is < 4) ? 0 : 2; | ||||
|     uint32_t mbidx0 = is + 4; | ||||
|     uint32_t mbidx1 = (is < 4) ? is + 4 : is; | ||||
|     uint32_t mbidxmask0 = (is < 4) ? 0xF : 0xF0; | ||||
|     uint32_t mbidxshift0 = (is < 4) ? 0 : 4; | ||||
|     uint32_t mbidxmask1 = (is < 4) ? 0x30 : 0xC0; | ||||
|     uint32_t mbidxshift1 = (is < 4) ? 0 : 2; | ||||
|  | ||||
|     sc    = uint8_t((bl.block.scales[scidx0] & 0xF)                         | ((bl.block.scales[scidx1] & scidxmask1) >> scidxshift1)); | ||||
|     mbyte = uint8_t(((bl.block.scales[mbidx0] & mbidxmask0) >> mbidxshift0) | ((bl.block.scales[mbidx1] & mbidxmask1) >> mbidxshift1)); | ||||
|  | ||||
|     const float16_t d = loadd.x * float16_t(sc); | ||||
|     const float16_t m = loadd.y * float16_t(mbyte); | ||||
|  | ||||
|     uint32_t dmask = 0xF << (b * 4); | ||||
|  | ||||
|     float16_t ret = d * (float16_t((bl.block.qs[qsi    ] & dmask) >> (b * 4)) + float16_t((bl.block.qh[qhi    ] & hm) != 0 ? 16 : 0)) - m; | ||||
|  | ||||
|     return ret; | ||||
| } | ||||
|  | ||||
| layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ6_K { | ||||
|    block_q6_K block; | ||||
| }; | ||||
|  | ||||
| float16_t dequantFuncQ6_K(const in decodeBufQ6_K bl, const in uint blockCoords[2], const in uint coordInBlock[2]) | ||||
| { | ||||
|     const uint idx = coordInBlock[1]; | ||||
|     const uint iqs = idx; | ||||
|  | ||||
|     const uint n = iqs / 128;                   // 0,1 | ||||
|     const uint b = (iqs % 128) / 64;            // 0,1 | ||||
|     const uint is_b = (iqs % 32) / 16;          // 0,1 | ||||
|     const uint qhshift = ((iqs % 128) / 32) * 2;// 0,2,4,6 | ||||
|     const uint is = 8 * n + qhshift + is_b;     // 0..15 | ||||
|     const uint qsi = n * 64 + (iqs % 64);       // 0..127 | ||||
|     const uint qhi = n * 32 + (iqs % 32);       // 0..63 | ||||
|  | ||||
|     const float16_t dscale = bl.block.d * float16_t(bl.block.scales[is]); | ||||
|  | ||||
|     float16_t ret = dscale * float16_t(int8_t(((bl.block.ql[qsi    ] >> (b * 4)) & 0xF) | (((bl.block.qh[qhi    ] >> qhshift) & 3) << 4)) - 32); | ||||
|  | ||||
|     return ret; | ||||
| } | ||||
|  | ||||
| #if defined(DATA_A_IQ4_NL) | ||||
| layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ4_NL { | ||||
|    block_iq4_nl block; | ||||
| }; | ||||
|  | ||||
| float16_t dequantFuncIQ4_NL(const in decodeBufIQ4_NL bl, const in uint blockCoords[2], const in uint coordInBlock[2]) | ||||
| { | ||||
|     const float16_t d = bl.block.d; | ||||
|     const uint idx = coordInBlock[1]; | ||||
|     const uint iqs = idx & 0xF; | ||||
|     const uint shift = (idx & 0x10) >> 2; | ||||
|     uint32_t qs = bl.block.qs[iqs]; | ||||
|     qs >>= shift; | ||||
|     qs &= 0xF; | ||||
|     float16_t ret = float16_t(kvalues_iq4nl[qs]) * d; | ||||
|     return ret; | ||||
| } | ||||
| #endif | ||||
|  | ||||
| #if defined(DATA_A_Q4_0) | ||||
| #define dequantFuncA dequantFuncQ4_0 | ||||
| #elif defined(DATA_A_Q4_1) | ||||
| #define dequantFuncA dequantFuncQ4_1 | ||||
| #elif defined(DATA_A_Q5_0) | ||||
| #define dequantFuncA dequantFuncQ5_0 | ||||
| #elif defined(DATA_A_Q5_1) | ||||
| #define dequantFuncA dequantFuncQ5_1 | ||||
| #elif defined(DATA_A_Q8_0) | ||||
| #define dequantFuncA dequantFuncQ8_0 | ||||
| #elif defined(DATA_A_Q2_K) | ||||
| #define dequantFuncA dequantFuncQ2_K | ||||
| #elif defined(DATA_A_Q3_K) | ||||
| #define dequantFuncA dequantFuncQ3_K | ||||
| #elif defined(DATA_A_Q4_K) | ||||
| #define dequantFuncA dequantFuncQ4_K | ||||
| #elif defined(DATA_A_Q5_K) | ||||
| #define dequantFuncA dequantFuncQ5_K | ||||
| #elif defined(DATA_A_Q6_K) | ||||
| #define dequantFuncA dequantFuncQ6_K | ||||
| #elif defined(DATA_A_IQ4_NL) | ||||
| #define dequantFuncA dequantFuncIQ4_NL | ||||
| #endif | ||||
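Each of these decode callbacks maps a (block, element-in-block) coordinate to one dequantized float16 value, so the coopmat2 tensor loads can read quantized data directly. As a plain C++ cross-check of the bit layout used by dequantFuncQ4_0 above (a hypothetical host-side helper, not part of the patch):

#include <cstdint>

// Reference Q4_0 block: 32 elements, 16 bytes of packed 4-bit quants plus a scale.
struct block_q4_0_ref {
    float   d;        // scale (stored as fp16 in the real block)
    uint8_t qs[16];   // two 4-bit quants per byte
};

// idx is the element index within the block (0..31), i.e. coordInBlock[1].
static float dequant_q4_0_ref(const block_q4_0_ref & bl, uint32_t idx) {
    const uint32_t byte  = idx & 0xF;          // elements i and i+16 share byte i
    const uint32_t shift = (idx & 0x10) >> 2;  // high nibble for elements 16..31
    const uint32_t q     = (bl.qs[byte] >> shift) & 0xF;
    return (float(q) - 8.0f) * bl.d;           // value = (q - 8) * d
}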
							
								
								
									
ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp (new file, 289 lines)
							| @@ -0,0 +1,289 @@ | ||||
| #version 450 | ||||
|  | ||||
| #extension GL_EXT_control_flow_attributes : enable | ||||
| #extension GL_EXT_shader_16bit_storage : require | ||||
|  | ||||
| #extension GL_EXT_shader_explicit_arithmetic_types_float16 : require | ||||
| #extension GL_EXT_shader_explicit_arithmetic_types_int8 : require | ||||
| #extension GL_EXT_shader_explicit_arithmetic_types_int32 : require | ||||
| #extension GL_EXT_shader_explicit_arithmetic_types_int16 : require | ||||
|  | ||||
| #extension GL_KHR_memory_scope_semantics : enable | ||||
| #extension GL_KHR_cooperative_matrix : enable | ||||
| #extension GL_NV_cooperative_matrix2 : enable | ||||
| #extension GL_EXT_buffer_reference : enable | ||||
| #extension GL_KHR_shader_subgroup_ballot : enable | ||||
| #extension GL_KHR_shader_subgroup_vote : enable | ||||
| #extension GL_EXT_null_initializer : enable | ||||
|  | ||||
| #include "types.comp" | ||||
| #include "dequant_funcs_cm2.comp" | ||||
|  | ||||
| layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; | ||||
|  | ||||
| layout (constant_id = 1) const uint32_t Br = 32; | ||||
| layout (constant_id = 2) const uint32_t Bc = 32; | ||||
| layout (constant_id = 3) const uint32_t D = 32; | ||||
| layout (constant_id = 4) const uint32_t Clamp = gl_CooperativeMatrixClampModeConstantNV; | ||||
|  | ||||
| layout (push_constant) uniform parameter { | ||||
|     uint32_t N; | ||||
|     uint32_t KV; | ||||
|  | ||||
|     uint32_t ne1; | ||||
|     uint32_t ne2; | ||||
|     uint32_t ne3; | ||||
|  | ||||
|     uint32_t neq2; | ||||
|     uint32_t neq3; | ||||
|     uint32_t nek2; | ||||
|     uint32_t nek3; | ||||
|     uint32_t nev2; | ||||
|     uint32_t nev3; | ||||
|     uint32_t nem1; | ||||
|  | ||||
|     uint32_t nb02; | ||||
|     uint32_t nb03; | ||||
|     uint32_t nb12; | ||||
|     uint32_t nb13; | ||||
|     uint32_t nb22; | ||||
|     uint32_t nb23; | ||||
|     uint32_t nb31; | ||||
|  | ||||
|     float scale; | ||||
|     float max_bias; | ||||
|     float logit_softcap; | ||||
|  | ||||
|     uint32_t mask; | ||||
|     uint32_t n_head_log2; | ||||
|     float m0; | ||||
|     float m1; | ||||
| } p; | ||||
|  | ||||
| layout (binding = 0) readonly buffer Q {uint8_t data_q[];}; | ||||
| layout (binding = 1) readonly buffer K {uint8_t data_k[];}; | ||||
| layout (binding = 2) readonly buffer V {uint8_t data_v[];}; | ||||
| layout (binding = 3) readonly buffer M {uint8_t data_m[];}; | ||||
| layout (binding = 4) writeonly buffer O {D_TYPE data_o[];}; | ||||
|  | ||||
| #define CEIL_DIV(a, b) (((a) + (b) - 1) / (b)) | ||||
|  | ||||
| ACC_TYPE maxReduce(const in ACC_TYPE x, const in ACC_TYPE y) { | ||||
|     return max(x, y); | ||||
| } | ||||
|  | ||||
| ACC_TYPE smearReduce(const in ACC_TYPE x, const in ACC_TYPE y) { | ||||
|     return x; | ||||
| } | ||||
|  | ||||
| // Replace matrix elements with row >= numRows or col >= numCols with 'replace' | ||||
| ACC_TYPE replacePadding(const in uint32_t row, const in uint32_t col, const in ACC_TYPE elem, const in ACC_TYPE replace, const in uint32_t numRows, const in uint32_t numCols) { | ||||
|     if (row >= numRows || col >= numCols) { | ||||
|         return replace; | ||||
|     } | ||||
|     return elem; | ||||
| } | ||||
|  | ||||
| ACC_TYPE Exp(const in uint32_t row, const in uint32_t col, const in ACC_TYPE elem) | ||||
| { | ||||
|     return exp(elem); | ||||
| } | ||||
|  | ||||
| ACC_TYPE Max(const in uint32_t row, const in uint32_t col, const in ACC_TYPE elem0, const in ACC_TYPE elem1) | ||||
| { | ||||
|     return max(elem0, elem1); | ||||
| } | ||||
|  | ||||
| #if defined(BLOCK_SIZE) | ||||
| #define DECODEFUNC , DEQUANTFUNC | ||||
| #else | ||||
| #define DECODEFUNC | ||||
| #endif | ||||
|  | ||||
| void main() { | ||||
| #if defined(DATA_A_IQ4_NL) | ||||
|     init_iq4nl_shmem(); | ||||
| #endif | ||||
|  | ||||
|     const uint32_t N = p.N; | ||||
|     const uint32_t KV = p.KV; | ||||
|  | ||||
|     const uint32_t Tr = CEIL_DIV(N, Br); | ||||
|     const uint32_t Tc = CEIL_DIV(KV, Bc); | ||||
|  | ||||
|     const uint32_t i = gl_WorkGroupID.x; | ||||
|  | ||||
|     const uint32_t iq2 = gl_WorkGroupID.y; | ||||
|     const uint32_t iq3 = gl_WorkGroupID.z; | ||||
|  | ||||
|     // broadcast factors | ||||
|     const uint32_t rk2 = p.neq2/p.nek2; | ||||
|     const uint32_t rk3 = p.neq3/p.nek3; | ||||
|  | ||||
|     const uint32_t rv2 = p.neq2/p.nev2; | ||||
|     const uint32_t rv3 = p.neq3/p.nev3; | ||||
|  | ||||
|     // k indices | ||||
|     const uint32_t ik3 = iq3 / rk3; | ||||
|     const uint32_t ik2 = iq2 / rk2; | ||||
|  | ||||
|     // v indices | ||||
|     const uint32_t iv3 = iq3 / rv3; | ||||
|     const uint32_t iv2 = iq2 / rv2; | ||||
|  | ||||
|     tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutQ = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV); | ||||
|     tensorLayoutNV<2, Clamp> tensorLayoutK = createTensorLayoutNV(2, Clamp); | ||||
|     tensorLayoutNV<2, Clamp> tensorLayoutV = createTensorLayoutNV(2, Clamp); | ||||
|  | ||||
|     tensorViewNV<2, false, 1, 0> tensorViewTranspose = createTensorViewNV(2, false, 1, 0); | ||||
|  | ||||
| #if defined(BLOCK_SIZE) | ||||
|     tensorLayoutK = setTensorLayoutBlockSizeNV(tensorLayoutK, 1, BLOCK_SIZE); | ||||
|     tensorLayoutV = setTensorLayoutBlockSizeNV(tensorLayoutV, 1, BLOCK_SIZE); | ||||
| #endif | ||||
|  | ||||
|     tensorLayoutQ = setTensorLayoutDimensionNV(tensorLayoutQ, N, D); | ||||
|     tensorLayoutK = setTensorLayoutDimensionNV(tensorLayoutK, KV, D); | ||||
|     tensorLayoutV = setTensorLayoutDimensionNV(tensorLayoutV, KV, D); | ||||
|  | ||||
|     coopmat<Q_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseA> Q; | ||||
|     coopmat<float16_t, gl_ScopeWorkgroup, Br, D, gl_MatrixUseA> Qf16; | ||||
|  | ||||
|     uint32_t q_offset = iq2*p.nb02+iq3*p.nb03; | ||||
|     coopMatLoadTensorNV(Q, data_q, q_offset, sliceTensorLayoutNV(tensorLayoutQ, i * Br, Br, 0, D)); | ||||
|  | ||||
|     Qf16 = coopmat<float16_t, gl_ScopeWorkgroup, Br, D, gl_MatrixUseA>(Q); | ||||
|     Qf16 *= float16_t(p.scale); | ||||
|  | ||||
|     coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator> O = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator>(0); | ||||
|  | ||||
|     coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> L, M; | ||||
|  | ||||
|     L = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(0); | ||||
|     M = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(-1.0/0.0); | ||||
|  | ||||
|     ACC_TYPE slope = ACC_TYPE(1.0); | ||||
|  | ||||
|     // ALiBi | ||||
|     if (p.max_bias > 0.0f) { | ||||
|         const uint32_t h = iq2; | ||||
|  | ||||
|         const ACC_TYPE base = ACC_TYPE(h < p.n_head_log2 ? p.m0 : p.m1); | ||||
|         const int      exph = int(h < p.n_head_log2 ? h + 1 : 2*(h - p.n_head_log2) + 1); | ||||
|  | ||||
|         slope = pow(base, ACC_TYPE(exph)); | ||||
|     } | ||||
|  | ||||
|     [[dont_unroll]] | ||||
|     for (uint32_t j = 0; j < Tc; ++j) { | ||||
|  | ||||
|         coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> S = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(0); | ||||
|  | ||||
|         coopmat<float16_t, gl_ScopeWorkgroup, D, Bc, gl_MatrixUseB> K_T; | ||||
|  | ||||
|         uint32_t k_offset = ik2*p.nb12 + ik3*p.nb13; | ||||
|         coopMatLoadTensorNV(K_T, data_k, k_offset, sliceTensorLayoutNV(tensorLayoutK, j * Bc, Bc, 0, D), tensorViewTranspose DECODEFUNC); | ||||
|         S = coopMatMulAdd(Qf16, K_T, S); | ||||
|  | ||||
|         if (p.logit_softcap != 0.0f) { | ||||
|             [[unroll]] | ||||
|             for (int k = 0; k < S.length(); ++k) { | ||||
|                 S[k] = ACC_TYPE(p.logit_softcap)*tanh(S[k]); | ||||
|             } | ||||
|         } | ||||
|  | ||||
|         if (p.mask != 0) { | ||||
|             tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutM = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV); | ||||
|             tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, p.nem1, KV); | ||||
|  | ||||
|             coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mv; | ||||
|  | ||||
|             coopMatLoadTensorNV(mv, data_m, 0, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc)); | ||||
|  | ||||
|             S += slope*coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(mv); | ||||
|         } | ||||
|  | ||||
|         // Clear padding elements to -inf, so they don't contribute to rowmax | ||||
|         if (Clamp != 0 && | ||||
|             ((j + 1) * Bc > KV || | ||||
|              (i + 1) * Br > N)) { | ||||
|  | ||||
|             uint R = ((i + 1) * Br >  N) ?  (N % Br) : Br; | ||||
|             uint C = ((j + 1) * Bc > KV) ? (KV % Bc) : Bc; | ||||
|  | ||||
|             coopMatPerElementNV(S, S, replacePadding, ACC_TYPE(-1.0/0.0), R, C); | ||||
|         } | ||||
|  | ||||
|         coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> rowmax, P, rowsum, eM; | ||||
|  | ||||
|         coopMatReduceNV(rowmax, S, gl_CooperativeMatrixReduceRowNV, maxReduce); | ||||
|  | ||||
|         coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> Mold = M; | ||||
|  | ||||
|         // M = max(rowmax, Mold) | ||||
|         // P = e^(S - M) | ||||
|         // eM = e^(Mold - M) | ||||
|         coopMatPerElementNV(M, rowmax, Max, Mold); | ||||
|         coopMatPerElementNV(P, S - M, Exp); | ||||
|         coopMatPerElementNV(eM, Mold - M, Exp); | ||||
|  | ||||
|         // Clear padding elements to 0, so they don't contribute to rowsum | ||||
|         if (Clamp != 0 && | ||||
|             ((j + 1) * Bc > KV || | ||||
|              (i + 1) * Br > N)) { | ||||
|  | ||||
|             uint R = ((i + 1) * Br >  N) ?  (N % Br) : Br; | ||||
|             uint C = ((j + 1) * Bc > KV) ? (KV % Bc) : Bc; | ||||
|  | ||||
|             coopMatPerElementNV(P, P, replacePadding, ACC_TYPE(0.0), R, C); | ||||
|         } | ||||
|  | ||||
|         coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseA> P_A = coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseA>(P); | ||||
|  | ||||
|         // compute rowsum by multiplying by matrix of all ones. | ||||
|         coopmat<float16_t, gl_ScopeWorkgroup, Bc, Bc, gl_MatrixUseB> One = coopmat<float16_t, gl_ScopeWorkgroup, Bc, Bc, gl_MatrixUseB>(1.0); | ||||
|  | ||||
|         rowsum = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(0.0); | ||||
|         rowsum = coopMatMulAdd(P_A, One, rowsum); | ||||
|  | ||||
|         coopmat<float16_t, gl_ScopeWorkgroup, Bc, D, gl_MatrixUseB> V; | ||||
|         uint32_t v_offset = iv2*p.nb22 + iv3*p.nb23; | ||||
|         coopMatLoadTensorNV(V,  data_v, v_offset, sliceTensorLayoutNV(tensorLayoutV, j * Bc, Bc, 0, D) DECODEFUNC); | ||||
|  | ||||
|         L = eM*L + rowsum; | ||||
|  | ||||
|         // This is the "diagonal" matrix in the paper, but since we do componentwise | ||||
|         // multiply rather than matrix multiply it has the diagonal element smeared | ||||
|         // across the row | ||||
|         coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator> eMdiag; | ||||
|  | ||||
|         // resize eM by using smear/reduce | ||||
|         coopMatReduceNV(eMdiag, eM, gl_CooperativeMatrixReduceRowNV, smearReduce); | ||||
|  | ||||
|         O = eMdiag * O; | ||||
|  | ||||
|         O = coopMatMulAdd(P_A, V, O); | ||||
|     } | ||||
|  | ||||
|     coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator> Ldiag; | ||||
|  | ||||
|     // resize L by using smear/reduce | ||||
|     coopMatReduceNV(Ldiag, L, gl_CooperativeMatrixReduceRowNV, smearReduce); | ||||
|  | ||||
|     [[unroll]] | ||||
|     for (int k = 0; k < Ldiag.length(); ++k) { | ||||
|         Ldiag[k] = ACC_TYPE(1.0) / Ldiag[k]; | ||||
|     } | ||||
|  | ||||
|     O = Ldiag*O; | ||||
|  | ||||
|     tensorLayoutNV<3, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutD = createTensorLayoutNV(3, gl_CooperativeMatrixClampModeConstantNV); | ||||
|     tensorLayoutD = setTensorLayoutDimensionNV(tensorLayoutD, p.ne2, p.ne1, D); | ||||
|  | ||||
|     // permute dimensions | ||||
|     tensorViewNV<3, false, 1, 0, 2> tensorViewPermute = createTensorViewNV(3, false, 1, 0, 2); | ||||
|     uint32_t o_offset = iq3*p.ne2*p.ne1; | ||||
|  | ||||
|     coopmat<D_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator> O_D = coopmat<D_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator>(O); | ||||
|     coopMatStoreTensorNV(O_D, data_o, o_offset, sliceTensorLayoutNV(tensorLayoutD, i * Br, Br, iq2, 1, 0, D), tensorViewPermute); | ||||
| } | ||||
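The loop above is the standard streaming ("online") softmax used by flash attention: M tracks the running row maximum, L the running row sum of exponentials, and eM = e^(Mold - M) rescales the partial output O whenever a new KV tile raises the maximum; the single division by L happens at the end. A scalar sketch of the same recurrence for one query row (assumes scale, mask and logit softcap have already been applied to the per-tile scores S_tiles[j]):

#include <algorithm>
#include <cmath>
#include <vector>

// One query row: S_tiles[j] holds the Bc scores of KV tile j, V_tiles[j] the
// corresponding Bc x D slice of V (row-major). O receives the D output values.
void online_softmax_row(const std::vector<std::vector<float>> & S_tiles,
                        const std::vector<std::vector<float>> & V_tiles,
                        size_t D, std::vector<float> & O) {
    float M = -INFINITY;  // running row max
    float L = 0.0f;       // running row sum of exp
    O.assign(D, 0.0f);
    for (size_t j = 0; j < S_tiles.size(); ++j) {
        float rowmax = -INFINITY;
        for (float s : S_tiles[j]) {
            rowmax = std::max(rowmax, s);
        }
        const float Mnew = std::max(M, rowmax);   // M = max(rowmax, Mold)
        const float eM   = std::exp(M - Mnew);    // eM = e^(Mold - M)
        for (size_t d = 0; d < D; ++d) {
            O[d] *= eM;                           // O = eMdiag * O
        }
        float rowsum = 0.0f;
        for (size_t c = 0; c < S_tiles[j].size(); ++c) {
            const float p = std::exp(S_tiles[j][c] - Mnew);  // P = e^(S - M)
            rowsum += p;
            for (size_t d = 0; d < D; ++d) {
                O[d] += p * V_tiles[j][c*D + d];  // O = coopMatMulAdd(P, V, O)
            }
        }
        L = eM * L + rowsum;                      // L = eM*L + rowsum
        M = Mnew;
    }
    for (size_t d = 0; d < D; ++d) {
        O[d] /= L;                                // O = Ldiag^-1 * O
    }
}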
							
								
								
									
ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp (new file, 328 lines)
							| @@ -0,0 +1,328 @@ | ||||
| #version 450 | ||||
|  | ||||
| #extension GL_EXT_control_flow_attributes : enable | ||||
| #extension GL_EXT_shader_16bit_storage : require | ||||
|  | ||||
| #extension GL_EXT_shader_explicit_arithmetic_types_float16 : require | ||||
| #extension GL_EXT_shader_explicit_arithmetic_types_int8 : require | ||||
| #extension GL_EXT_shader_explicit_arithmetic_types_int32 : require | ||||
| #extension GL_EXT_shader_explicit_arithmetic_types_int16 : require | ||||
|  | ||||
| #extension GL_KHR_memory_scope_semantics : enable | ||||
| #extension GL_KHR_cooperative_matrix : enable | ||||
| #extension GL_NV_cooperative_matrix2 : enable | ||||
| #extension GL_EXT_buffer_reference : enable | ||||
| #extension GL_KHR_shader_subgroup_ballot : enable | ||||
| #extension GL_KHR_shader_subgroup_vote : enable | ||||
|  | ||||
| #include "types.comp" | ||||
|  | ||||
| layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; | ||||
|  | ||||
| layout (constant_id = 1) const uint BM = 64; | ||||
| layout (constant_id = 2) const uint BN = 64; | ||||
| layout (constant_id = 3) const uint BK = 16;  // Assumed to be 32 if working with a quant | ||||
|  | ||||
| layout (push_constant) uniform parameter | ||||
| { | ||||
|     uint M; | ||||
|     uint N; | ||||
|     uint K; | ||||
|     uint stride_a; | ||||
|     uint stride_b; | ||||
|     uint stride_d; | ||||
|  | ||||
|     uint batch_stride_a; | ||||
|     uint batch_stride_b; | ||||
|     uint batch_stride_d; | ||||
|  | ||||
| #ifdef MUL_MAT_ID | ||||
|     uint nei0; | ||||
|     uint nei1; | ||||
|     uint nbi1; | ||||
|     uint ne11; | ||||
| #else | ||||
|     uint k_split; | ||||
|     uint ne02; | ||||
|     uint ne12; | ||||
|     uint broadcast2; | ||||
|     uint broadcast3; | ||||
| #endif | ||||
| } p; | ||||
|  | ||||
|  | ||||
| layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; | ||||
| layout (binding = 1) readonly buffer B {B_TYPE data_b[];}; | ||||
| layout (binding = 2) writeonly buffer D {D_TYPE data_d[];}; | ||||
|  | ||||
| #if QUANT_K > 1 | ||||
| #define DECODEFUNCA , dequantFuncA | ||||
| #define MAT_A_TYPE float16_t | ||||
|  | ||||
| #include "dequant_funcs_cm2.comp" | ||||
|  | ||||
| #else | ||||
| #define DECODEFUNCA | ||||
| #define MAT_A_TYPE A_TYPE | ||||
| #endif | ||||
|  | ||||
| #define MAT_B_TYPE B_TYPE | ||||
|  | ||||
| #ifdef MUL_MAT_ID | ||||
| layout (binding = 3) readonly buffer IDS {int data_ids[];}; | ||||
|  | ||||
| shared u16vec4 row_ids[3072]; | ||||
|  | ||||
| layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufB { | ||||
|    B_TYPE b[]; | ||||
| }; | ||||
|  | ||||
| uint _ne1; | ||||
| shared uint _ne1_sh; | ||||
|  | ||||
| B_TYPE decodeFuncB(const in decodeBufB bl, const in uint blockCoords[2], const in uint coordInBlock[2]) | ||||
| { | ||||
|     const uint row_i = blockCoords[0]; | ||||
|  | ||||
|     if (row_i >= _ne1) { | ||||
|         return B_TYPE(0.0); | ||||
|     } | ||||
|  | ||||
|     const u16vec4 row_idx = row_ids[row_i]; | ||||
|     B_TYPE ret = data_b[row_idx.y * p.batch_stride_b + row_idx.x * p.stride_b + blockCoords[1]]; | ||||
|  | ||||
|     return ret; | ||||
| } | ||||
|  | ||||
| D_TYPE perElemOpD(const in uint32_t r, const in uint32_t c, const in D_TYPE elem, const in uint32_t ir, const in uint32_t ic) | ||||
| { | ||||
|     uint dr = ir * BM + r; | ||||
|     uint dc = ic * BN + c; | ||||
|  | ||||
|     if (dr < p.M && dc < _ne1) { | ||||
|         uint row_i = dc; | ||||
|         const u16vec4 row_idx = row_ids[row_i]; | ||||
|         data_d[row_idx.y * p.batch_stride_d + row_idx.z * p.stride_d + dr] = elem; | ||||
|     } | ||||
|     return elem; | ||||
| } | ||||
|  | ||||
| #endif | ||||
|  | ||||
| void main() { | ||||
| #if defined(DATA_A_IQ4_NL) | ||||
|     init_iq4nl_shmem(); | ||||
| #endif | ||||
|  | ||||
| #ifdef MUL_MAT_ID | ||||
|     const uint expert_idx = gl_GlobalInvocationID.z; | ||||
| #else | ||||
|     const uint batch_idx = gl_GlobalInvocationID.z; | ||||
|  | ||||
|     const uint i13 = batch_idx / p.ne12; | ||||
|     const uint i12 = batch_idx % p.ne12; | ||||
|  | ||||
|     const uint i03 = i13 / p.broadcast3; | ||||
|     const uint i02 = i12 / p.broadcast2; | ||||
|  | ||||
|     const uint batch_idx_a = i03 * p.ne02 + i02; | ||||
| #endif | ||||
|  | ||||
|     const uint blocks_m = (p.M + BM - 1) / BM; | ||||
|     const uint ir = gl_WorkGroupID.x % blocks_m; | ||||
|     const uint ik = gl_WorkGroupID.x / blocks_m; | ||||
|     const uint ic = gl_WorkGroupID.y; | ||||
|  | ||||
| #ifdef MUL_MAT_ID | ||||
|     // Spread the search across all elements in the first subgroup | ||||
|     if (gl_SubgroupID == 0) { | ||||
|         _ne1 = 0; | ||||
|         uint num_elements = p.nei1 * p.nei0; | ||||
|  | ||||
|         for (uint i = gl_SubgroupInvocationID; subgroupAny(i < num_elements); i += gl_SubgroupSize) { | ||||
|             bool in_range = i < num_elements; | ||||
|             uint ii0 = i % p.nei0; | ||||
|             uint ii1 = i / p.nei0; | ||||
|             uint id = in_range ? data_ids[ii1*p.nbi1 + ii0] : 0; | ||||
|             uvec4 ballot = subgroupBallot(in_range && id == expert_idx); | ||||
|             uint idx = subgroupBallotExclusiveBitCount(ballot); | ||||
|             if (in_range && id == expert_idx) { | ||||
|                 row_ids[_ne1 + idx] = u16vec4(ii0 % p.ne11, ii1, ii0, 0); | ||||
|             } | ||||
|             _ne1 += subgroupBallotBitCount(ballot); | ||||
|         } | ||||
|         _ne1_sh = _ne1; | ||||
|     } | ||||
|  | ||||
|     barrier(); | ||||
|  | ||||
|     _ne1 = _ne1_sh; | ||||
|  | ||||
|     // Workgroup has no work | ||||
|     if (ic * BN >= _ne1) return; | ||||
| #endif | ||||
|  | ||||
| #ifdef MUL_MAT_ID | ||||
|     uint start_k = 0; | ||||
|     const uint end_k = p.K; | ||||
| #else | ||||
|     uint start_k = ik * p.k_split; | ||||
|     const uint end_k = min(p.K, (ik + 1) * p.k_split); | ||||
| #endif | ||||
|  | ||||
|     coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator> sum; | ||||
|     sum = coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator>(0.0); | ||||
|  | ||||
| #ifdef MUL_MAT_ID | ||||
|     uint pos_a = (expert_idx * p.batch_stride_a) / QUANT_K; | ||||
|     uint pos_b = 0; | ||||
| #else | ||||
|     uint pos_a = (batch_idx_a * p.batch_stride_a) / QUANT_K; | ||||
|     uint pos_b = batch_idx * p.batch_stride_b; | ||||
| #endif | ||||
|  | ||||
|     uint stride_a = p.stride_a / QUANT_K; | ||||
|     uint stride_b = p.stride_b; | ||||
|  | ||||
|     // Hint to the compiler that values are aligned (want 16B alignment). | ||||
|     // Quants are always block-aligned, no alignment needed. | ||||
| #if ALIGNED | ||||
| #if QUANT_K == 1 | ||||
|     stride_a &= ~7; | ||||
| #endif | ||||
|     stride_b &= ~7; | ||||
| #endif | ||||
|  | ||||
|     // Create layouts for both clamped and unclamped accesses | ||||
|     tensorLayoutNV<2> tensorLayoutA = createTensorLayoutNV(2); | ||||
|     tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutAClamp = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV); | ||||
|     tensorLayoutNV<2> tensorLayoutB = createTensorLayoutNV(2); | ||||
|     tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutBClamp = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV); | ||||
|     tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutD = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV); | ||||
|  | ||||
| #if QUANT_K > 1 | ||||
|     tensorLayoutA = setTensorLayoutBlockSizeNV(tensorLayoutA, 1, QUANT_K); | ||||
|     tensorLayoutAClamp = setTensorLayoutBlockSizeNV(tensorLayoutAClamp, 1, QUANT_K); | ||||
| #endif | ||||
|  | ||||
|     // Use end_k rather than p.K as the dimension because that's what | ||||
|     // we need to bound check against when using split_k | ||||
|     tensorLayoutA = setTensorLayoutDimensionNV(tensorLayoutA, p.M, end_k); | ||||
|     tensorLayoutB = setTensorLayoutDimensionNV(tensorLayoutB, p.N, end_k); | ||||
|     tensorLayoutD = setTensorLayoutDimensionNV(tensorLayoutD, p.N, p.M); | ||||
|     tensorLayoutAClamp = setTensorLayoutDimensionNV(tensorLayoutAClamp, p.M, end_k); | ||||
|     tensorLayoutBClamp = setTensorLayoutDimensionNV(tensorLayoutBClamp, p.N, end_k); | ||||
|  | ||||
|     tensorViewNV<2, false, 1, 0> tensorViewTranspose = createTensorViewNV(2, false, 1, 0); | ||||
|  | ||||
| #if !defined(MUL_MAT_ID) | ||||
|     // Detect a fast path where all loads are entirely in bounds and no clamping is required | ||||
|     if ((ir + 1) * BM <= p.M && (ic + 1) * BN <= p.N && (start_k % BK) == 0 && (end_k % BK) == 0 && | ||||
| #if QUANT_K == 1 | ||||
|         (stride_a % 8) == 0 && | ||||
| #endif | ||||
|         (stride_b % 8) == 0 && (start_k % 8) == 0) { | ||||
|         // Hint to the compiler that values are aligned (want 16B alignment) | ||||
|         start_k &= ~7; | ||||
|         stride_b &= ~7; | ||||
| #if QUANT_K == 1 | ||||
|         stride_a &= ~7; | ||||
| #endif | ||||
|  | ||||
|         tensorLayoutA = setTensorLayoutStrideNV(tensorLayoutA, stride_a, 1); | ||||
|         tensorLayoutB = setTensorLayoutStrideNV(tensorLayoutB, stride_b, 1); | ||||
|  | ||||
|         uint k_iters = (end_k - start_k + BK - 1) / BK; | ||||
|  | ||||
|         for (uint block_k = start_k, i = 0; i < k_iters; block_k += BK, ++i) { | ||||
|  | ||||
|             coopmat<MAT_A_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a; | ||||
|             coopmat<MAT_B_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b; | ||||
|  | ||||
|             coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA); | ||||
|             coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a_ft = coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA>(mat_a); | ||||
|  | ||||
|             coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose); | ||||
|             coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b_ft = coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB>(mat_b); | ||||
|  | ||||
|             sum = coopMatMulAdd(mat_a_ft, mat_b_ft, sum); | ||||
|         } | ||||
|     } else | ||||
| #endif // !defined(MUL_MAT_ID) | ||||
|     { | ||||
|         tensorLayoutA = setTensorLayoutStrideNV(tensorLayoutA, stride_a, 1); | ||||
|  | ||||
|         tensorLayoutAClamp = setTensorLayoutStrideNV(tensorLayoutAClamp, stride_a, 1); | ||||
|  | ||||
|         tensorLayoutB = setTensorLayoutStrideNV(tensorLayoutB, stride_b, 1); | ||||
|  | ||||
|         tensorLayoutBClamp = setTensorLayoutStrideNV(tensorLayoutBClamp, stride_b, 1); | ||||
|  | ||||
|         [[dont_unroll]] | ||||
|         for (uint block_k = start_k; block_k < end_k; block_k += BK) { | ||||
|  | ||||
|             coopmat<MAT_A_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a; | ||||
|             coopmat<MAT_B_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b; | ||||
|             coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a_ft; | ||||
|             coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b_ft; | ||||
|  | ||||
|             // Clamping is expensive, so detect different code paths for each combination | ||||
|             // of A and B needing clamping. | ||||
|             bool unclampedA = (ir + 1) * BM <= p.M && block_k + BK <= end_k && (block_k % 8) == 0; | ||||
| #ifdef MUL_MAT_ID | ||||
|             bool unclampedB = true; | ||||
| #else | ||||
|             bool unclampedB = (ic + 1) * BN <= p.N && block_k + BK <= end_k && (block_k % 8) == 0; | ||||
| #endif | ||||
|             if (unclampedA && unclampedB) { | ||||
|                 coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, (block_k & ~7), BK) DECODEFUNCA); | ||||
| #ifdef MUL_MAT_ID | ||||
|                 coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose, decodeFuncB); | ||||
| #else | ||||
|                 coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, (block_k & ~7), BK), tensorViewTranspose); | ||||
| #endif | ||||
|                 mat_a_ft = coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA>(mat_a); | ||||
|                 mat_b_ft = coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB>(mat_b); | ||||
|                 sum = coopMatMulAdd(mat_a_ft, mat_b_ft, sum); | ||||
|             } else if (unclampedA && !unclampedB) { | ||||
|                 coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, (block_k & ~7), BK) DECODEFUNCA); | ||||
|                 coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutBClamp, ic * BN, BN, block_k, BK), tensorViewTranspose); | ||||
|  | ||||
|                 mat_a_ft = coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA>(mat_a); | ||||
|                 mat_b_ft = coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB>(mat_b); | ||||
|                 sum = coopMatMulAdd(mat_a_ft, mat_b_ft, sum); | ||||
|             } else if (!unclampedA && unclampedB) { | ||||
|                 coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutAClamp, ir * BM, BM, block_k, BK) DECODEFUNCA); | ||||
| #ifdef MUL_MAT_ID | ||||
|                 coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose, decodeFuncB); | ||||
| #else | ||||
|                 coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, (block_k & ~7), BK), tensorViewTranspose); | ||||
| #endif | ||||
|                 mat_a_ft = coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA>(mat_a); | ||||
|                 mat_b_ft = coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB>(mat_b); | ||||
|                 sum = coopMatMulAdd(mat_a_ft, mat_b_ft, sum); | ||||
|             } else if (!unclampedA && !unclampedB) { | ||||
|                 coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutAClamp, ir * BM, BM, block_k, BK) DECODEFUNCA); | ||||
|                 coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutBClamp, ic * BN, BN, block_k, BK), tensorViewTranspose); | ||||
|  | ||||
|                 mat_a_ft = coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA>(mat_a); | ||||
|                 mat_b_ft = coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB>(mat_b); | ||||
|                 sum = coopMatMulAdd(mat_a_ft, mat_b_ft, sum); | ||||
|             } | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     // Convert from ACC_TYPE to D_TYPE | ||||
|     coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator> mat_d; | ||||
|     mat_d = coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator>(sum); | ||||
|  | ||||
| #ifdef MUL_MAT_ID | ||||
|     // Call callback to store each element, remapping row through shared memory | ||||
|     coopMatPerElementNV(mat_d, mat_d, perElemOpD, ir, ic); | ||||
| #else | ||||
|     tensorLayoutD = setTensorLayoutStrideNV(tensorLayoutD, p.stride_d, 1); | ||||
|  | ||||
|     uint pos_d = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * gl_NumWorkGroups.z; | ||||
|     coopMatStoreTensorNV(mat_d, data_d, pos_d, sliceTensorLayoutNV(tensorLayoutD, ic * BN, BN, ir * BM, BM), tensorViewTranspose); | ||||
| #endif | ||||
| } | ||||
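For MUL_MAT_ID the shader first compacts, per expert, the list of rows that actually reference that expert into row_ids[] using subgroup ballots; decodeFuncB and perElemOpD then remap B loads and D stores through that list. A scalar sketch of the compaction (hypothetical host-side helper; the real shader does this cooperatively with subgroupBallot and subgroupBallotExclusiveBitCount):

#include <cstdint>
#include <vector>

struct RowId { uint16_t i11, i12, i1; };  // mirrors the u16vec4 entries above

// ids is the expert-id matrix: nei1 rows of nei0 entries, row stride nbi1.
std::vector<RowId> collect_rows_for_expert(const std::vector<int> & ids,
                                           uint32_t nei0, uint32_t nei1,
                                           uint32_t nbi1, uint32_t ne11,
                                           int expert_idx) {
    std::vector<RowId> row_ids;
    for (uint32_t i = 0; i < nei0 * nei1; ++i) {
        const uint32_t ii0 = i % nei0;
        const uint32_t ii1 = i / nei0;
        if (ids[ii1 * nbi1 + ii0] == expert_idx) {
            // same payload as row_ids[_ne1 + idx] = u16vec4(ii0 % ne11, ii1, ii0, 0)
            row_ids.push_back({ uint16_t(ii0 % ne11), uint16_t(ii1), uint16_t(ii0) });
        }
    }
    return row_ids;  // _ne1 is row_ids.size(); a workgroup exits early if ic*BN >= _ne1
}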
| @@ -30,6 +30,8 @@ | ||||
|     #include <fcntl.h> | ||||
| #endif | ||||
|  | ||||
| #include <vulkan/vulkan_core.h> | ||||
|  | ||||
| #define ASYNCIO_CONCURRENCY 64 | ||||
|  | ||||
| std::mutex lock; | ||||
| @@ -196,15 +198,17 @@ static uint32_t compile_count = 0; | ||||
| static std::mutex compile_count_mutex; | ||||
| static std::condition_variable compile_count_cond; | ||||
|  | ||||
| void string_to_spv_func(const std::string& _name, const std::string& in_fname, const std::map<std::string, std::string>& defines, bool fp16 = true) { | ||||
|     std::string name = _name + (fp16 ? "" : "_fp32"); | ||||
| void string_to_spv_func(const std::string& _name, const std::string& in_fname, const std::map<std::string, std::string>& defines, bool fp16 = true, bool coopmat2 = false, bool f16acc = false) { | ||||
|     std::string name = _name + (f16acc ? "_f16acc" : "") + (coopmat2 ? "_cm2" : (fp16 ? "" : "_fp32")); | ||||
|     std::string out_fname = join_paths(output_dir, name + ".spv"); | ||||
|     std::string in_path = join_paths(input_dir, in_fname); | ||||
|  | ||||
|     std::string target_env = (name.find("_cm2") != std::string::npos) ? "--target-env=vulkan1.3" : "--target-env=vulkan1.2"; | ||||
|  | ||||
|     #ifdef _WIN32 | ||||
|         std::vector<std::string> cmd = {GLSLC, "-fshader-stage=compute", "--target-env=vulkan1.2", "-O", "\"" + in_path + "\"", "-o", "\"" + out_fname + "\""}; | ||||
|         std::vector<std::string> cmd = {GLSLC, "-fshader-stage=compute", target_env, "-O", "\"" + in_path + "\"", "-o", "\"" + out_fname + "\""}; | ||||
|     #else | ||||
|         std::vector<std::string> cmd = {GLSLC, "-fshader-stage=compute", "--target-env=vulkan1.2", "-O", in_path, "-o",  out_fname}; | ||||
|         std::vector<std::string> cmd = {GLSLC, "-fshader-stage=compute", target_env, "-O", in_path, "-o",  out_fname}; | ||||
|     #endif | ||||
|  | ||||
|     #ifdef GGML_VULKAN_SHADER_DEBUG_INFO | ||||
| @@ -254,7 +258,7 @@ std::map<std::string, std::string> merge_maps(const std::map<std::string, std::s | ||||
| } | ||||
|  | ||||
| static std::vector<std::future<void>> compiles; | ||||
| void string_to_spv(const std::string& _name, const std::string& in_fname, const std::map<std::string, std::string>& defines, bool fp16 = true) { | ||||
| void string_to_spv(const std::string& _name, const std::string& in_fname, const std::map<std::string, std::string>& defines, bool fp16 = true, bool coopmat2 = false, bool f16acc = false) { | ||||
|     { | ||||
|         // wait until fewer than N compiles are in progress. | ||||
|         // 16 is an arbitrary limit, the goal is to avoid "failed to create pipe" errors. | ||||
| @@ -265,15 +269,15 @@ void string_to_spv(const std::string& _name, const std::string& in_fname, const | ||||
|         } | ||||
|         compile_count++; | ||||
|     } | ||||
|     compiles.push_back(std::async(string_to_spv_func, _name, in_fname, defines, fp16)); | ||||
|     compiles.push_back(std::async(string_to_spv_func, _name, in_fname, defines, fp16, coopmat2, f16acc)); | ||||
| } | ||||
|  | ||||
| void matmul_shaders(bool fp16, bool matmul_id) { | ||||
|     std::string load_vec = fp16 ? "8" : "4"; | ||||
|     std::string aligned_b_type_f32 = fp16 ? "mat2x4" : "vec4"; | ||||
|     std::string aligned_b_type_f16 = fp16 ? "f16mat2x4" : "f16vec4"; | ||||
| void matmul_shaders(bool fp16, bool matmul_id, bool coopmat2, bool f16acc) { | ||||
|     std::string load_vec = coopmat2 ? "1" : fp16 ? "8" : "4"; | ||||
|     std::string aligned_b_type_f32 = coopmat2 ? "float" : fp16 ? "mat2x4" : "vec4"; | ||||
|     std::string aligned_b_type_f16 = coopmat2 ? "float16_t" : fp16 ? "f16mat2x4" : "f16vec4"; | ||||
|  | ||||
|     std::map<std::string, std::string> base_dict = {{"FLOAT_TYPE", fp16 ? "float16_t" : "float"}}; | ||||
|     std::map<std::string, std::string> base_dict = {{"FLOAT_TYPE", (coopmat2 || fp16) ? "float16_t" : "float"}}; | ||||
|     std::string shader_name = "matmul"; | ||||
|  | ||||
|     if (matmul_id) { | ||||
| @@ -285,21 +289,31 @@ void matmul_shaders(bool fp16, bool matmul_id) { | ||||
|         base_dict["FLOAT16"] = "1"; | ||||
|     } | ||||
|  | ||||
|     // Shaders with f16 B_TYPE | ||||
|     string_to_spv(shader_name + "_f32_f16", "mul_mm.comp", merge_maps(base_dict, {{"DATA_A_F32", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}), fp16); | ||||
|     string_to_spv(shader_name + "_f32_f16_aligned", "mul_mm.comp", merge_maps(base_dict, {{"DATA_A_F32", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}}), fp16); | ||||
|     base_dict["ACC_TYPE"] = f16acc ? "float16_t" : "float"; | ||||
|  | ||||
|     string_to_spv(shader_name + "_f16", "mul_mm.comp", merge_maps(base_dict, {{"DATA_A_F16", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}), fp16); | ||||
|     string_to_spv(shader_name + "_f16_aligned", "mul_mm.comp", merge_maps(base_dict, {{"DATA_A_F16", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}}), fp16); | ||||
|     std::string source_name = coopmat2 ? "mul_mm_cm2.comp" : "mul_mm.comp"; | ||||
|  | ||||
|     // Shaders with f16 B_TYPE | ||||
|     string_to_spv(shader_name + "_f32_f16", source_name, merge_maps(base_dict, {{"DATA_A_F32", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}, }), fp16, coopmat2, f16acc); | ||||
|     string_to_spv(shader_name + "_f32_f16_aligned", source_name, merge_maps(base_dict, {{"DATA_A_F32", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat2, f16acc); | ||||
|  | ||||
|     string_to_spv(shader_name + "_f16_aligned", source_name, merge_maps(base_dict, {{"DATA_A_F16", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat2, f16acc); | ||||
|     string_to_spv(shader_name + "_f16", source_name, merge_maps(base_dict, {{"DATA_A_F16", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}), fp16, coopmat2, f16acc); | ||||
|  | ||||
|     for (const auto& tname : type_names) { | ||||
|         std::string data_a_key = "DATA_A_" + to_uppercase(tname); | ||||
|         // For unaligned, load one at a time for f32/f16, or two at a time for quants | ||||
|         std::string load_vec_a_unaligned = (tname == "f32" || tname == "f16") ? "1" : "2"; | ||||
|         std::string load_vec_a_unaligned = (coopmat2 || tname == "f32" || tname == "f16") ? "1" : "2"; | ||||
|         // For aligned matmul loads | ||||
|         std::string load_vec_a = (tname == "f32" || tname == "f16") ? load_vec : "2"; | ||||
|         string_to_spv(shader_name + "_" + tname + "_f32", "mul_mm.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}), fp16); | ||||
|         string_to_spv(shader_name + "_" + tname + "_f32_aligned", "mul_mm.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f32}, {"D_TYPE", "float"}}), fp16); | ||||
|         std::string load_vec_a = (coopmat2 || tname == "f32" || tname == "f16") ? load_vec : "2"; | ||||
|  | ||||
|         string_to_spv(shader_name + "_" + tname + "_f32", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}}), fp16, coopmat2, f16acc); | ||||
|         string_to_spv(shader_name + "_" + tname + "_f32_aligned", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f32}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"ALIGNED", "1"}}), fp16, coopmat2, f16acc); | ||||
|  | ||||
|         if (tname != "f16" && tname != "f32") { | ||||
|             string_to_spv(shader_name + "_" + tname + "_f16", source_name,          merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned},                           {"B_TYPE", "float16_t"},        {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}}), fp16, coopmat2, f16acc); | ||||
|             string_to_spv(shader_name + "_" + tname + "_f16_aligned", source_name,  merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a},           {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"ALIGNED", "1"}}), fp16, coopmat2, f16acc); | ||||
|         } | ||||
|     } | ||||
| } | ||||
|  | ||||
| @@ -307,11 +321,50 @@ void process_shaders() { | ||||
|     std::cout << "ggml_vulkan: Generating and compiling shaders to SPIR-V" << std::endl; | ||||
|     std::map<std::string, std::string> base_dict = {{"FLOAT_TYPE", "float"}}; | ||||
|  | ||||
|     // matmul | ||||
|     for (const auto& fp16 : {false, true}) { | ||||
|         matmul_shaders(fp16, false); | ||||
|         matmul_shaders(fp16, true); | ||||
|         for (const auto& matmul_id : {false, true}) { | ||||
|             for (const auto& coopmat2 : {false, true}) { | ||||
|                 for (const auto& f16acc : {false, true}) { | ||||
| #if !defined(VK_NV_cooperative_matrix2) | ||||
|                     if (coopmat2) { | ||||
|                         continue; | ||||
|                     } | ||||
| #endif | ||||
|                     if (coopmat2 && !fp16) { | ||||
|                         continue; | ||||
|                     } | ||||
|                     if (!coopmat2 && f16acc) { | ||||
|                         continue; | ||||
|                     } | ||||
|                     matmul_shaders(fp16, matmul_id, coopmat2, f16acc); | ||||
|                 } | ||||
|             } | ||||
|         } | ||||
|     } | ||||
|  | ||||
| #if defined(VK_NV_cooperative_matrix2) | ||||
|     // flash attention | ||||
|     for (const auto& f16acc : {false, true}) { | ||||
|         std::string acctype = f16acc ? "float16_t" : "float"; | ||||
|  | ||||
|         for (const auto& tname : type_names) { | ||||
|             if (tname == "f32") { | ||||
|                 continue; | ||||
|             } | ||||
|  | ||||
|             if (tname == "f16") { | ||||
|                 string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp", | ||||
|                     merge_maps(base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}}), true, true, f16acc); | ||||
|             } else { | ||||
|                 std::string data_a_key = "DATA_A_" + to_uppercase(tname); | ||||
|                 string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp", | ||||
|                     merge_maps(base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}, {"DEQUANTFUNC", "dequantFunc"+to_uppercase(tname) }, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), true, true, f16acc); | ||||
|             } | ||||
|         } | ||||
|     } | ||||
| #endif | ||||
|  | ||||
|     for (const auto& tname : type_names) { | ||||
|         // mul mat vec | ||||
|         std::string data_a_key = "DATA_A_" + to_uppercase(tname); | ||||
|   | ||||
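For readers who just want the effective build matrix from the nested loops in process_shaders() above: coopmat2 variants are only generated when fp16 is enabled (and only when the headers define VK_NV_cooperative_matrix2), and f16 accumulation is only offered for coopmat2 variants. Below is a minimal standalone sketch, not part of the commit, that mirrors those skip conditions and prints the combinations that would actually be passed to matmul_shaders(); the printout format is illustrative only.

#include <cstdio>

int main() {
    // Mirror of the variant-selection loop: enumerate (fp16, matmul_id,
    // coopmat2, f16acc) and apply the same skip rules as process_shaders().
    for (bool fp16 : {false, true}) {
        for (bool matmul_id : {false, true}) {
            for (bool coopmat2 : {false, true}) {
                for (bool f16acc : {false, true}) {
                    if (coopmat2 && !fp16) continue;   // coopmat2 shaders exist only in the fp16 path
                    if (!coopmat2 && f16acc) continue; // f16 accumulators are only offered with coopmat2
                    std::printf("fp16=%d matmul_id=%d coopmat2=%d f16acc=%d\n",
                                fp16, matmul_id, coopmat2, f16acc);
                }
            }
        }
    }
    return 0;
}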