mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-10-31 08:51:55 +00:00 
			
		
		
		
	vulkan: Use coopmat2 for conv2d (#14982)
This commit is contained in:
		| @@ -3096,6 +3096,10 @@ static void ggml_vk_load_shaders(vk_device& device) { | ||||
|         uint32_t conv2d_SHMEM_PAD = 4; | ||||
|         bool conv2d_UNROLL = true; | ||||
|  | ||||
|         if (device->coopmat2) { | ||||
|             conv2d_SHMEM_PAD = 8; // 8 float16_t | ||||
|         } | ||||
|  | ||||
|         if (device->vendor_id == VK_VENDOR_ID_INTEL) { | ||||
|             conv2d_SHMEM_PAD = 0; | ||||
|             conv2d_UNROLL = false; | ||||
| @@ -3154,7 +3158,14 @@ static void ggml_vk_load_shaders(vk_device& device) { | ||||
|         std::array<uint32_t, 3> wg_denoms = { conv2d_BS_K, conv2d_BS_NPQ, 1 }; | ||||
|         std::vector<uint32_t> spec_constants = { conv2d_WG_SIZE, conv2d_BS_K, conv2d_BS_CRS, conv2d_BS_NPQ, conv2d_TS_K, use_collectives, conv2d_SHMEM_PAD }; | ||||
|  | ||||
|         if (conv2d_UNROLL) { | ||||
|         if (device->coopmat2) { | ||||
|             ggml_vk_create_pipeline( | ||||
|                 device, device->pipeline_conv2d_f32[s], "conv2d_f32", conv2d_f32_cm2_len, conv2d_f32_cm2_data, "main", 3, | ||||
|                 sizeof(vk_op_conv2d_push_constants), wg_denoms, spec_constants, 1, true, use_collectives); | ||||
|             ggml_vk_create_pipeline( | ||||
|                 device, device->pipeline_conv2d_f16_f32[s], "conv2d_f16_f32", conv2d_f16_f32_cm2_len, conv2d_f16_f32_cm2_data, "main", 3, | ||||
|                 sizeof(vk_op_conv2d_push_constants), wg_denoms, spec_constants, 1, true, use_collectives); | ||||
|         } else if (conv2d_UNROLL) { | ||||
|             ggml_vk_create_pipeline( | ||||
|                 device, device->pipeline_conv2d_f32[s], "conv2d_f32", conv2d_f32_unroll_len, conv2d_f32_unroll_data, "main", 3, | ||||
|                 sizeof(vk_op_conv2d_push_constants), wg_denoms, spec_constants, 1, true, use_collectives); | ||||
|   | ||||
| @@ -1,6 +1,11 @@ | ||||
| #version 450 | ||||
|  | ||||
| #extension GL_EXT_control_flow_attributes : enable | ||||
| #ifdef COOPMAT2 | ||||
| #extension GL_NV_cooperative_matrix2 : enable | ||||
| #extension GL_EXT_shader_explicit_arithmetic_types_float16 : require | ||||
| #extension GL_KHR_memory_scope_semantics : enable | ||||
| #endif | ||||
|  | ||||
| #ifdef USE_COLLECTIVES | ||||
| #    extension GL_KHR_shader_subgroup_shuffle : enable | ||||
| @@ -91,6 +96,12 @@ uint32_t n_elems_out = K * NPQ; | ||||
| // Number of blocktiles per input | ||||
| uint32_t NB_CRS = splitWork(CRS, BS_CRS); | ||||
|  | ||||
| #ifdef COOPMAT2 | ||||
| #define SHMEM_TYPE float16_t | ||||
| #else | ||||
| #define SHMEM_TYPE float | ||||
| #endif | ||||
|  | ||||
| const uint32_t Ash_stride = BS_CRS + SHMEM_PAD; | ||||
| const uint32_t Bsh_stride = BS_NPQ + SHMEM_PAD; | ||||
|  | ||||
| @@ -100,8 +111,8 @@ const uint32_t Bsh_numel = BS_CRS * BS_NPQ; | ||||
| const uint32_t Ash_len = BS_K * Ash_stride; | ||||
| const uint32_t Bsh_len = BS_CRS * Bsh_stride; | ||||
|  | ||||
| shared float Ash[Ash_len];  // K x CRS | ||||
| shared float Bsh[Bsh_len];  // CRS x NPQ | ||||
| shared SHMEM_TYPE Ash[Ash_len];  // K x CRS | ||||
| shared SHMEM_TYPE Bsh[Bsh_len];  // CRS x NPQ | ||||
|  | ||||
| // Threadtile sizes | ||||
| const uint32_t TS_NPQ = BS_K * BS_NPQ / WG_SIZE / TS_K; | ||||
| @@ -110,10 +121,6 @@ const uint32_t TS_NPQ = BS_K * BS_NPQ / WG_SIZE / TS_K; | ||||
| const uint32_t NT_K   = BS_K / TS_K; | ||||
| const uint32_t NT_NPQ = BS_NPQ / TS_NPQ; | ||||
|  | ||||
| float regA[TS_K]; | ||||
| float regB[TS_NPQ]; | ||||
| float regC[TS_K][TS_NPQ]; | ||||
|  | ||||
| /* | ||||
| Compute | ||||
| KxCRS @ CRSxNPQ = K x NPQ | ||||
| @@ -145,12 +152,36 @@ uint fastdiv(uint n, uint mp, uint L) { | ||||
|     return (msbs + n) >> L; | ||||
| } | ||||
|  | ||||
| #ifdef COOPMAT2 | ||||
| #define ACC_TYPE float16_t | ||||
|  | ||||
| ACC_TYPE perElemOpStore(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem) | ||||
| { | ||||
|     uint32_t K_idx   = B_idx_K * BS_K + r; | ||||
|     uint32_t NPQ_idx = B_idx_NPQ * BS_NPQ + c; | ||||
|     uint32_t N_idx   = fastdiv(NPQ_idx, p.OWOHmp, p.OWOHL); // N_idx = NPQ_idx / (p.OH * p.OW), computed through fastdiv with p.OWOHmp and p.OWOHL | ||||
|     uint32_t OH_idx  = fastdiv(NPQ_idx - N_idx * p.OH * p.OW, p.OWmp, p.OWL); // OH_idx = (NPQ_idx minus the N_idx * p.OH * p.OW part) / p.OW, computed through fastdiv with p.OWmp and p.OWL | ||||
|     uint32_t OW_idx  = NPQ_idx - N_idx * p.OH * p.OW - OH_idx * p.OW; | ||||
|     uint32_t dst_idx = OW_idx + OH_idx * p.nb1 + K_idx * p.nb2 + N_idx * p.nb3; | ||||
|     if (K_idx < K && NPQ_idx < NPQ) { | ||||
|         dst_data[dst_idx] = D_TYPE(elem); | ||||
|     } | ||||
|     return elem; | ||||
| } | ||||
| #endif | ||||
|  | ||||
| void main() { | ||||
| #ifdef COOPMAT2 | ||||
|     coopmat<ACC_TYPE, gl_ScopeWorkgroup, BS_K, BS_NPQ, gl_MatrixUseAccumulator> matC; | ||||
|     matC = coopmat<ACC_TYPE, gl_ScopeWorkgroup, BS_K, BS_NPQ, gl_MatrixUseAccumulator>(0.0); | ||||
| #else | ||||
|     float regC[TS_K][TS_NPQ]; | ||||
|     for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) { | ||||
|         for (uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++) { | ||||
|             regC[T_ly][T_lx] = 0.0; | ||||
|         } | ||||
|     } | ||||
| #endif | ||||
|     /* Advance block in CRS dim */ | ||||
|     for (uint32_t B_idx_CRS = 0; B_idx_CRS < NB_CRS; B_idx_CRS++) { | ||||
|         uint32_t CRS_idx_a; | ||||
| @@ -199,7 +230,7 @@ void main() { | ||||
|             if (K_idx >= K || CRS_idx_a >= CRS) { | ||||
|                 val = 0.0; | ||||
|             } | ||||
|             Ash[B_ly * Ash_stride + B_lx] = val; | ||||
|             Ash[B_ly * Ash_stride + B_lx] = SHMEM_TYPE(val); | ||||
|         } | ||||
|         /* Load input to B_block: (BS_CRS x BS_NPQ) */ | ||||
|         UNROLL for (uint32_t r_offset = 0; r_offset < BS_CRS; r_offset += BrpWg) { | ||||
| @@ -244,11 +275,21 @@ void main() { | ||||
|             if (CRS_idx_b >= CRS || NPQ_idx >= NPQ || H_idx < 0 || H_idx >= p.H || W_idx < 0 || W_idx >= p.W) { | ||||
|                 val = 0.0; | ||||
|             } | ||||
|             Bsh[B_ly * Bsh_stride + B_lx] = val; | ||||
|             Bsh[B_ly * Bsh_stride + B_lx] = SHMEM_TYPE(val); | ||||
|         } | ||||
|         barrier(); | ||||
| #ifdef COOPMAT2 | ||||
|         coopmat<float16_t, gl_ScopeWorkgroup, BS_K, BS_CRS, gl_MatrixUseA> matA; | ||||
|         coopmat<float16_t, gl_ScopeWorkgroup, BS_CRS, BS_NPQ, gl_MatrixUseB> matB; | ||||
|  | ||||
|         coopMatLoad(matA, Ash, 0, Ash_stride, gl_CooperativeMatrixLayoutRowMajor); | ||||
|         coopMatLoad(matB, Bsh, 0, Bsh_stride, gl_CooperativeMatrixLayoutRowMajor); | ||||
|         matC = coopMatMulAdd(matA, matB, matC); | ||||
| #else | ||||
|         if (T_y * TS_K < K) { | ||||
|             UNROLL for (uint32_t CRS_lidx = 0; CRS_lidx < BS_CRS; CRS_lidx++) { | ||||
|                 float regA[TS_K]; | ||||
|                 float regB[TS_NPQ]; | ||||
|                 for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) { | ||||
|                     regA[T_ly] = Ash[(T_y * TS_K + T_ly) * Ash_stride + CRS_lidx]; | ||||
|                 } | ||||
| @@ -262,9 +303,13 @@ void main() { | ||||
|                 } | ||||
|             } | ||||
|         } | ||||
| #endif | ||||
|         barrier(); | ||||
|     } | ||||
|     /* Save C* */ | ||||
| #ifdef COOPMAT2 | ||||
|     coopMatPerElementNV(matC, matC, perElemOpStore); | ||||
| #else | ||||
|     if (T_y * TS_K < K) { | ||||
|         for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) { | ||||
|             for (uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++) { | ||||
| @@ -280,4 +325,5 @@ void main() { | ||||
|             } | ||||
|         } | ||||
|     } | ||||
| #endif | ||||
| } | ||||
|   | ||||
| @@ -661,6 +661,9 @@ void process_shaders() { | ||||
|     string_to_spv("conv2d_f32", "conv2d_mm.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"USE_COLLECTIVES", "1"}, {"UNROLL", ""}}); | ||||
|     string_to_spv("conv2d_f16_f32", "conv2d_mm.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"USE_COLLECTIVES", "1"}, {"UNROLL", ""}}); | ||||
|  | ||||
|     string_to_spv("conv2d_f32", "conv2d_mm.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"USE_COLLECTIVES", "1"}, {"UNROLL", "[[unroll]]"}, {"COOPMAT2", "1"}}, true, false, true); | ||||
|     string_to_spv("conv2d_f16_f32", "conv2d_mm.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"USE_COLLECTIVES", "1"}, {"UNROLL", "[[unroll]]"}, {"COOPMAT2", "1"}}, true, false, true); | ||||
|  | ||||
|     string_to_spv("conv2d_dw_whcn_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"WHCN", "1"}})); | ||||
|     string_to_spv("conv2d_dw_cwhn_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"CWHN", "1"}})); | ||||
|  | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Jeff Bolz
					Jeff Bolz