mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-10-31 08:51:55 +00:00 
			
		
		
		
	vulkan: support im2col_3d (#15795)
This commit is contained in:
		| @@ -554,6 +554,7 @@ struct vk_device_struct { | ||||
|     vk_pipeline pipeline_argmax_f32; | ||||
|     vk_pipeline pipeline_count_equal_i32; | ||||
|     vk_pipeline pipeline_im2col_f32, pipeline_im2col_f32_f16; | ||||
|     vk_pipeline pipeline_im2col_3d_f32, pipeline_im2col_3d_f32_f16; | ||||
|     vk_pipeline pipeline_timestep_embedding_f32; | ||||
|     vk_pipeline pipeline_conv_transpose_1d_f32; | ||||
|     vk_pipeline pipeline_pool2d_f32; | ||||
| @@ -982,6 +983,37 @@ struct vk_op_im2col_push_constants { | ||||
|     int32_t d0; int32_t d1; | ||||
| }; | ||||
|  | ||||
| // Push-constant block for the im2col_3d pipelines; populated by ggml_vk_im2col_3d. | ||||
| // Member order defines the layout handed to the shader, so it must not be changed. | ||||
| struct vk_op_im2col_3d_push_constants { | ||||
|     // src1 strides, expressed in elements (byte stride / element size) | ||||
|     uint32_t nb10, nb11, nb12, nb13; | ||||
|     // dst->op_params[0..2], [3..5] and [6..8] (presumably stride / padding / dilation - confirm against ggml.h) | ||||
|     uint32_t s0, s1, s2; | ||||
|     uint32_t p0, p1, p2; | ||||
|     uint32_t d0, d1, d2; | ||||
|     // src1 extents ne10 / ne11 / ne12, and IC taken from dst->op_params[9] | ||||
|     uint32_t IW, IH, ID, IC; | ||||
|     // src0 extent ne00 and dst extent ne2 | ||||
|     uint32_t KW; | ||||
|     uint32_t OH; | ||||
|     // Products of the extents named in each identifier, precomputed on the host | ||||
|     uint32_t KD_KH_KW; | ||||
|     uint32_t KH_KW; | ||||
|     uint32_t IC_KD_KH_KW; | ||||
|     uint32_t N_OD_OH; | ||||
|     uint32_t OD_OH; | ||||
|     uint32_t OD_OH_OW_IC_KD_KH_KW; | ||||
|     uint32_t OH_OW_IC_KD_KH_KW; | ||||
|     uint32_t OW_IC_KD_KH_KW; | ||||
|     // (src1 misalignment << 16) | dst misalignment, set by init_pushconst_tensor_offsets | ||||
|     uint32_t misalign_offsets; | ||||
| }; | ||||
|  | ||||
| struct vk_op_timestep_embedding_push_constants { | ||||
|     uint32_t nb1; | ||||
|     uint32_t dim; | ||||
| @@ -3380,10 +3412,13 @@ static void ggml_vk_load_shaders(vk_device& device) { | ||||
|     ggml_vk_create_pipeline(device, device->pipeline_count_equal_i32, "count_equal_i32", count_equal_i32_len, count_equal_i32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, { device->subgroup_size }, 1); | ||||
|  | ||||
|     ggml_vk_create_pipeline(device, device->pipeline_im2col_f32, "im2col_f32", im2col_f32_len, im2col_f32_data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true); | ||||
|     ggml_vk_create_pipeline(device, device->pipeline_im2col_3d_f32, "im2col_3d_f32", im2col_3d_f32_len, im2col_3d_f32_data, "main", 2, sizeof(vk_op_im2col_3d_push_constants), {512, 1, 1}, { 512 }, 1, true); | ||||
|     if (device->float_controls_rte_fp16) { | ||||
|         ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16_rte_len, im2col_f32_f16_rte_data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true); | ||||
|         ggml_vk_create_pipeline(device, device->pipeline_im2col_3d_f32_f16, "im2col_3d_f32_f16", im2col_3d_f32_f16_rte_len, im2col_3d_f32_f16_rte_data, "main", 2, sizeof(vk_op_im2col_3d_push_constants), {512, 1, 1}, { 512 }, 1, true); | ||||
|     } else { | ||||
|         ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16_len, im2col_f32_f16_data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true); | ||||
|         ggml_vk_create_pipeline(device, device->pipeline_im2col_3d_f32_f16, "im2col_3d_f32_f16", im2col_3d_f32_f16_len, im2col_3d_f32_f16_data, "main", 2, sizeof(vk_op_im2col_3d_push_constants), {512, 1, 1}, { 512 }, 1, true); | ||||
|     } | ||||
|  | ||||
|     ggml_vk_create_pipeline(device, device->pipeline_timestep_embedding_f32, "timestep_embedding_f32", timestep_embedding_f32_len, timestep_embedding_f32_data, "main", 2, sizeof(vk_op_timestep_embedding_push_constants), {256, 1, 1}, {}, 1); | ||||
| @@ -7717,6 +7752,14 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const | ||||
|             return ctx->device->pipeline_im2col_f32_f16; | ||||
|         } | ||||
|         return nullptr; | ||||
|     case GGML_OP_IM2COL_3D: | ||||
|         if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { | ||||
|             return ctx->device->pipeline_im2col_3d_f32; | ||||
|         } | ||||
|         if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16) { | ||||
|             return ctx->device->pipeline_im2col_3d_f32_f16; | ||||
|         } | ||||
|         return nullptr; | ||||
|     case GGML_OP_TIMESTEP_EMBEDDING: | ||||
|         if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { | ||||
|             return ctx->device->pipeline_timestep_embedding_f32; | ||||
| @@ -7832,6 +7875,7 @@ static bool ggml_vk_op_supports_incontiguous(ggml_op op) { | ||||
|     case GGML_OP_RMS_NORM: | ||||
|     case GGML_OP_CONV_2D_DW: | ||||
|     case GGML_OP_IM2COL: | ||||
|     case GGML_OP_IM2COL_3D: | ||||
|     case GGML_OP_SET_ROWS: | ||||
|     case GGML_OP_SUM: | ||||
|     case GGML_OP_SUM_ROWS: | ||||
| @@ -7890,6 +7934,16 @@ template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk | ||||
|     GGML_UNUSED(src2); | ||||
| } | ||||
|  | ||||
| // im2col_3d specialization: stores the misalignment of src1 and dst, converted from bytes to elements, in p.misalign_offsets. | ||||
| template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_im2col_3d_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) { | ||||
|     // Only src1 and dst are consulted here. | ||||
|     GGML_UNUSED(src0); | ||||
|     GGML_UNUSED(src2); | ||||
|  | ||||
|     const uint32_t src_misalign = get_misalign_bytes(ctx, src1) / ggml_type_size(src1->type); | ||||
|     const uint32_t dst_misalign = get_misalign_bytes(ctx, dst) / ggml_type_size(dst->type); | ||||
|  | ||||
|     // src1 offset goes in the upper 16 bits, dst offset in the lower 16 bits | ||||
|     p.misalign_offsets = (src_misalign << 16) | dst_misalign; | ||||
| } | ||||
|  | ||||
| template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_binary_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) { | ||||
|     const uint32_t a_offset = get_misalign_bytes(ctx, src0) / ggml_type_size(src0->type); | ||||
|     const uint32_t b_offset = get_misalign_bytes(ctx, src1) / ggml_type_size(src1->type); | ||||
| @@ -8130,6 +8184,26 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co | ||||
|  | ||||
|             elements = { OW * KW * KH, OH, batch * IC }; | ||||
|         } break; | ||||
|     case GGML_OP_IM2COL_3D: | ||||
|         { | ||||
|             const uint32_t IC = ((const uint32_t *)(dst->op_params))[9]; | ||||
|  | ||||
|             const uint32_t N  = ne13 / IC; | ||||
|  | ||||
|             const uint32_t KD = ne02; | ||||
|             const uint32_t KH = ne01; | ||||
|             const uint32_t KW = ne00; | ||||
|  | ||||
|             const uint32_t OD = ned3 / N; | ||||
|             const uint32_t OH = ned2; | ||||
|             const uint32_t OW = ned1; | ||||
|  | ||||
|             const uint32_t IC_KD_KH_KW = IC*KD*KH*KW; | ||||
|             const uint32_t N_OD_OH = N*OD*OH; | ||||
|  | ||||
|             elements = { IC_KD_KH_KW, OW, N_OD_OH }; | ||||
|             elements[2] = std::min(elements[2], ctx->device->properties.limits.maxComputeWorkGroupCount[2]); | ||||
|         } break; | ||||
|     case GGML_OP_TIMESTEP_EMBEDDING: | ||||
|         { | ||||
|             const uint32_t dim = dst->op_params[0]; | ||||
| @@ -8286,7 +8360,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co | ||||
|         } | ||||
|  | ||||
|         ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, subbuf_z, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements); | ||||
|     } else if (op == GGML_OP_IM2COL) { | ||||
|     } else if (op == GGML_OP_IM2COL || op == GGML_OP_IM2COL_3D) { | ||||
|         // im2col uses only src1 and dst buffers | ||||
|         ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements); | ||||
|     } else if (op == GGML_OP_COUNT_EQUAL) { | ||||
| @@ -9147,6 +9221,66 @@ static void ggml_vk_im2col(ggml_backend_vk_context * ctx, vk_context& subctx, co | ||||
|     }, dryrun); | ||||
| } | ||||
|  | ||||
| // Builds the push constants for GGML_OP_IM2COL_3D from src0, src1 and dst->op_params, then forwards to ggml_vk_op_f32. | ||||
| // src0 is passed through; only its extents (ne00..ne02) are read here. | ||||
| static void ggml_vk_im2col_3d(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { | ||||
|     GGML_TENSOR_BINARY_OP_LOCALS | ||||
|  | ||||
|     // op_params layout: [0..2] = s0..s2, [3..5] = p0..p2, [6..8] = d0..d2, [9] = IC | ||||
|     const int32_t * params = (const int32_t *)(dst->op_params); | ||||
|     const int32_t IC = params[9]; | ||||
|  | ||||
|     // Dim 3 of src1 holds N*IC and dim 3 of dst holds N*OD, so both are recovered by division. | ||||
|     const int64_t N  = ne13 / IC; | ||||
|     const int64_t OD = ne3 / N; | ||||
|     const int64_t OH = ne2; | ||||
|     const int64_t OW = ne1; | ||||
|  | ||||
|     // Extent products are built up incrementally in 64 bits; they are narrowed to 32 bits when stored below. | ||||
|     const int64_t KH_KW                = ne01 * ne00; | ||||
|     const int64_t KD_KH_KW             = ne02 * KH_KW; | ||||
|     const int64_t IC_KD_KH_KW          = IC * KD_KH_KW; | ||||
|     const int64_t OD_OH                = OD * OH; | ||||
|     const int64_t OW_IC_KD_KH_KW       = OW * IC_KD_KH_KW; | ||||
|     const int64_t OH_OW_IC_KD_KH_KW    = OH * OW_IC_KD_KH_KW; | ||||
|     const int64_t OD_OH_OW_IC_KD_KH_KW = OD * OH_OW_IC_KD_KH_KW; | ||||
|  | ||||
|     // Byte strides of src1 are converted to element strides for the shader. | ||||
|     const size_t src1_type_size = ggml_type_size(src1->type); | ||||
|  | ||||
|     vk_op_im2col_3d_push_constants pc {}; | ||||
|  | ||||
|     pc.nb10 = nb10 / src1_type_size; | ||||
|     pc.nb11 = nb11 / src1_type_size; | ||||
|     pc.nb12 = nb12 / src1_type_size; | ||||
|     pc.nb13 = nb13 / src1_type_size; | ||||
|  | ||||
|     pc.s0 = params[0]; | ||||
|     pc.s1 = params[1]; | ||||
|     pc.s2 = params[2]; | ||||
|     pc.p0 = params[3]; | ||||
|     pc.p1 = params[4]; | ||||
|     pc.p2 = params[5]; | ||||
|     pc.d0 = params[6]; | ||||
|     pc.d1 = params[7]; | ||||
|     pc.d2 = params[8]; | ||||
|  | ||||
|     pc.IW = ne10; | ||||
|     pc.IH = ne11; | ||||
|     pc.ID = ne12; | ||||
|     pc.IC = IC; | ||||
|     pc.KW = ne00; | ||||
|     pc.OH = OH; | ||||
|  | ||||
|     pc.KD_KH_KW             = KD_KH_KW; | ||||
|     pc.KH_KW                = KH_KW; | ||||
|     pc.IC_KD_KH_KW          = IC_KD_KH_KW; | ||||
|     pc.N_OD_OH              = N * OD_OH; | ||||
|     pc.OD_OH                = OD_OH; | ||||
|     pc.OD_OH_OW_IC_KD_KH_KW = OD_OH_OW_IC_KD_KH_KW; | ||||
|     pc.OH_OW_IC_KD_KH_KW    = OH_OW_IC_KD_KH_KW; | ||||
|     pc.OW_IC_KD_KH_KW       = OW_IC_KD_KH_KW; | ||||
|  | ||||
|     // pc.misalign_offsets is left zero-initialized here. | ||||
|     ggml_vk_op_f32<vk_op_im2col_3d_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_IM2COL_3D, std::move(pc), dryrun); | ||||
| } | ||||
|  | ||||
| static void ggml_vk_timestep_embedding(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { | ||||
|     const uint32_t dim = dst->op_params[0]; | ||||
|     const uint32_t max_period = dst->op_params[1]; | ||||
| @@ -10352,6 +10486,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr | ||||
|     case GGML_OP_ARGMAX: | ||||
|     case GGML_OP_COUNT_EQUAL: | ||||
|     case GGML_OP_IM2COL: | ||||
|     case GGML_OP_IM2COL_3D: | ||||
|     case GGML_OP_TIMESTEP_EMBEDDING: | ||||
|     case GGML_OP_CONV_TRANSPOSE_1D: | ||||
|     case GGML_OP_POOL_2D: | ||||
| @@ -10422,6 +10557,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr | ||||
|         case GGML_OP_ARGMAX: | ||||
|         case GGML_OP_COUNT_EQUAL: | ||||
|         case GGML_OP_IM2COL: | ||||
|         case GGML_OP_IM2COL_3D: | ||||
|         case GGML_OP_TIMESTEP_EMBEDDING: | ||||
|         case GGML_OP_CONV_TRANSPOSE_1D: | ||||
|         case GGML_OP_POOL_2D: | ||||
| @@ -10717,6 +10853,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr | ||||
|     case GGML_OP_IM2COL: | ||||
|         ggml_vk_im2col(ctx, compute_ctx, src0, src1, node, dryrun); | ||||
|  | ||||
|         break; | ||||
|     case GGML_OP_IM2COL_3D: | ||||
|         ggml_vk_im2col_3d(ctx, compute_ctx, src0, src1, node, dryrun); | ||||
|  | ||||
|         break; | ||||
|     case GGML_OP_TIMESTEP_EMBEDDING: | ||||
|         ggml_vk_timestep_embedding(ctx, compute_ctx, src0, node, dryrun); | ||||
| @@ -10868,6 +11008,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph * | ||||
|     case GGML_OP_ARGMAX: | ||||
|     case GGML_OP_COUNT_EQUAL: | ||||
|     case GGML_OP_IM2COL: | ||||
|     case GGML_OP_IM2COL_3D: | ||||
|     case GGML_OP_TIMESTEP_EMBEDDING: | ||||
|     case GGML_OP_CONV_TRANSPOSE_1D: | ||||
|     case GGML_OP_POOL_2D: | ||||
| @@ -12150,6 +12291,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm | ||||
|         case GGML_OP_ARGMAX: | ||||
|         case GGML_OP_COUNT_EQUAL: | ||||
|         case GGML_OP_IM2COL: | ||||
|         case GGML_OP_IM2COL_3D: | ||||
|         case GGML_OP_TIMESTEP_EMBEDDING: | ||||
|         case GGML_OP_CONV_2D_DW: | ||||
|         case GGML_OP_POOL_2D: | ||||
| @@ -12725,6 +12867,19 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph * | ||||
|  | ||||
|         const bool is_2D = tensor->op_params[6] == 1; | ||||
|         tensor_clone = ggml_im2col(ggml_ctx, src_clone[0], src_clone[1], s0, s1, p0, p1, d0, d1, is_2D, tensor->type); | ||||
|     } else if (tensor->op == GGML_OP_IM2COL_3D) { | ||||
|         // op_params layout (same indices ggml_vk_im2col_3d reads): | ||||
|         // [0..2] = s0..s2, [3..5] = p0..p2, [6..8] = d0..d2, [9] = IC | ||||
|         const int32_t s0 = tensor->op_params[0]; | ||||
|         const int32_t s1 = tensor->op_params[1]; | ||||
|         const int32_t s2 = tensor->op_params[2]; | ||||
|         const int32_t p0 = tensor->op_params[3]; | ||||
|         const int32_t p1 = tensor->op_params[4]; | ||||
|         const int32_t p2 = tensor->op_params[5]; | ||||
|         const int32_t d0 = tensor->op_params[6]; | ||||
|         const int32_t d1 = tensor->op_params[7]; | ||||
|         const int32_t d2 = tensor->op_params[8]; | ||||
|         const int32_t IC = tensor->op_params[9]; | ||||
|  | ||||
|         // Use the 3D builder: it takes IC plus one s/p/d value per dimension, unlike the 2D ggml_im2col. | ||||
|         tensor_clone = ggml_im2col_3d(ggml_ctx, src_clone[0], src_clone[1], IC, s0, s1, s2, p0, p1, p2, d0, d1, d2, tensor->type); | ||||
|     } else if (tensor->op == GGML_OP_TIMESTEP_EMBEDDING) { | ||||
|         const int32_t dim = tensor->op_params[0]; | ||||
|         const int32_t max_period = tensor->op_params[1]; | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Jeff Bolz
					Jeff Bolz