opencl: add multi and vision rope, gelu_quick and im2col (#12600)

* opencl: add `im2col`
* opencl: add `gelu_quick`
* opencl: add mrope
* opencl: add vision rope
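For orientation, these ops enter the compute graph through the usual ggml API; the sketch below is illustrative only and assumes the current ggml.h declarations of ggml_gelu_quick and ggml_im2col (the context and tensors are placeholders, not part of this commit).

    #include "ggml.h"

    // Sketch only: x is the activation input, w the convolution filter.
    static struct ggml_tensor * build_example(struct ggml_context * ctx,
                                              struct ggml_tensor * x,
                                              struct ggml_tensor * w) {
        struct ggml_tensor * act = ggml_gelu_quick(ctx, x);          // GGML_UNARY_OP_GELU_QUICK
        struct ggml_tensor * col = ggml_im2col(ctx, w, act,
                                               1, 1, 0, 0, 1, 1,     // s0, s1, p0, p1, d0, d1
                                               true, GGML_TYPE_F16); // is_2D, dst type
        return col;                                                  // GGML_OP_IM2COL node
    }

With this commit, those nodes (plus the multi/vision ROPE variants) can be offloaded to the OpenCL backend instead of falling back to the CPU.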
		| @@ -63,6 +63,7 @@ set(GGML_OPENCL_KERNELS | ||||
|     ggml-opencl_transpose_16 | ||||
|     ggml-opencl_transpose_32 | ||||
|     ggml-opencl_transpose_32_16 | ||||
|     ggml-opencl_im2col | ||||
| ) | ||||
|  | ||||
| foreach (K ${GGML_OPENCL_KERNELS}) | ||||
|   | ||||
| @@ -224,12 +224,14 @@ struct ggml_backend_opencl_context { | ||||
|     cl_program program; | ||||
|     cl_program program_1; | ||||
|     cl_program program_2; | ||||
|     cl_program program_im2col; | ||||
|  | ||||
|     cl_kernel kernel_add, kernel_add_row; | ||||
|     cl_kernel kernel_mul, kernel_mul_row; | ||||
|     cl_kernel kernel_scale; | ||||
|     cl_kernel kernel_silu, kernel_silu_4; | ||||
|     cl_kernel kernel_gelu, kernel_gelu_4; | ||||
|     cl_kernel kernel_gelu_quick, kernel_gelu_quick_4; | ||||
|     cl_kernel kernel_relu; | ||||
|     cl_kernel kernel_clamp; | ||||
|     cl_kernel kernel_norm; | ||||
| @@ -239,6 +241,7 @@ struct ggml_backend_opencl_context { | ||||
|     cl_kernel kernel_soft_max_f16, kernel_soft_max_4_f16; | ||||
|     cl_kernel kernel_get_rows_f32, kernel_get_rows_f16, kernel_get_rows_q4_0; | ||||
|     cl_kernel kernel_rope_norm_f32, kernel_rope_norm_f16, kernel_rope_neox_f32, kernel_rope_neox_f16; | ||||
|     cl_kernel kernel_rope_multi_f32, kernel_rope_multi_f16, kernel_rope_vision_f32, kernel_rope_vision_f16; | ||||
|     cl_kernel kernel_cpy_f16_f16, kernel_cpy_f16_f32, kernel_cpy_f32_f16, kernel_cpy_f32_f32; | ||||
|     cl_kernel kernel_mul_mat_f32_f32; | ||||
|     cl_kernel kernel_mul_mat_f16_f16; | ||||
| @@ -252,6 +255,7 @@ struct ggml_backend_opencl_context { | ||||
|               kernel_mul_mat_q4_0_f32_flat_img_v0; | ||||
|     cl_kernel kernel_mul_mat_q4_0_f32_1d_8x_flat, kernel_mul_mat_q4_0_f32_1d_16x_flat; | ||||
|     cl_kernel kernel_mul_mv_q6_K_f32; | ||||
|     cl_kernel kernel_im2col_f32, kernel_im2col_f16; | ||||
|  | ||||
| #ifdef GGML_OPENCL_USE_ADRENO_KERNELS | ||||
|     // Transpose kernels | ||||
| @@ -708,6 +712,8 @@ static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) { | ||||
|     CL_CHECK((backend_ctx->kernel_silu_4             = clCreateKernel(backend_ctx->program, "kernel_silu_4", &err), err)); | ||||
|     CL_CHECK((backend_ctx->kernel_gelu               = clCreateKernel(backend_ctx->program, "kernel_gelu", &err), err)); | ||||
|     CL_CHECK((backend_ctx->kernel_gelu_4             = clCreateKernel(backend_ctx->program, "kernel_gelu_4", &err), err)); | ||||
|     CL_CHECK((backend_ctx->kernel_gelu_quick         = clCreateKernel(backend_ctx->program, "kernel_gelu_quick", &err), err)); | ||||
|     CL_CHECK((backend_ctx->kernel_gelu_quick_4       = clCreateKernel(backend_ctx->program, "kernel_gelu_quick_4", &err), err)); | ||||
|     CL_CHECK((backend_ctx->kernel_relu               = clCreateKernel(backend_ctx->program, "kernel_relu", &err), err)); | ||||
|     CL_CHECK((backend_ctx->kernel_clamp              = clCreateKernel(backend_ctx->program, "kernel_clamp", &err), err)); | ||||
|     CL_CHECK((backend_ctx->kernel_norm               = clCreateKernel(backend_ctx->program, "kernel_norm", &err), err)); | ||||
| @@ -722,6 +728,10 @@ static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) { | ||||
|     CL_CHECK((backend_ctx->kernel_rope_norm_f16      = clCreateKernel(backend_ctx->program, "kernel_rope_norm_f16", &err), err)); | ||||
|     CL_CHECK((backend_ctx->kernel_rope_neox_f32      = clCreateKernel(backend_ctx->program, "kernel_rope_neox_f32", &err), err)); | ||||
|     CL_CHECK((backend_ctx->kernel_rope_neox_f16      = clCreateKernel(backend_ctx->program, "kernel_rope_neox_f16", &err), err)); | ||||
|     CL_CHECK((backend_ctx->kernel_rope_multi_f32     = clCreateKernel(backend_ctx->program, "kernel_rope_multi_f32", &err), err)); | ||||
|     CL_CHECK((backend_ctx->kernel_rope_multi_f16     = clCreateKernel(backend_ctx->program, "kernel_rope_multi_f16", &err), err)); | ||||
|     CL_CHECK((backend_ctx->kernel_rope_vision_f32    = clCreateKernel(backend_ctx->program, "kernel_rope_vision_f32", &err), err)); | ||||
|     CL_CHECK((backend_ctx->kernel_rope_vision_f16    = clCreateKernel(backend_ctx->program, "kernel_rope_vision_f16", &err), err)); | ||||
|     CL_CHECK((backend_ctx->kernel_cpy_f16_f16        = clCreateKernel(backend_ctx->program, "kernel_cpy_f16_f16", &err), err)); | ||||
|     CL_CHECK((backend_ctx->kernel_cpy_f16_f32        = clCreateKernel(backend_ctx->program, "kernel_cpy_f16_f32", &err), err)); | ||||
|     CL_CHECK((backend_ctx->kernel_cpy_f32_f16        = clCreateKernel(backend_ctx->program, "kernel_cpy_f32_f16", &err), err)); | ||||
| @@ -769,6 +779,19 @@ static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) { | ||||
|  | ||||
|     CL_CHECK((backend_ctx->kernel_convert_block_q4_0_noshuffle     = clCreateKernel(backend_ctx->program_2, "kernel_convert_block_q4_0_noshuffle", &err), err)); | ||||
|  | ||||
|     // im2col kernels | ||||
| #ifdef GGML_OPENCL_EMBED_KERNELS | ||||
|     const std::string kernel_src_im2col { | ||||
|         #include "ggml-opencl_im2col.cl.h" | ||||
|     }; | ||||
| #else | ||||
|     const std::string kernel_src_im2col = read_file("ggml-opencl_im2col.cl"); | ||||
| #endif | ||||
|     backend_ctx->program_im2col = build_program_from_source(context, device, kernel_src_im2col.c_str(), compile_opts); | ||||
|  | ||||
|     CL_CHECK((backend_ctx->kernel_im2col_f32 = clCreateKernel(backend_ctx->program_im2col, "kernel_im2col_f32", &err), err)); | ||||
|     CL_CHECK((backend_ctx->kernel_im2col_f16 = clCreateKernel(backend_ctx->program_im2col, "kernel_im2col_f16", &err), err)); | ||||
|  | ||||
|     // Kernels for Adreno | ||||
| #ifdef GGML_OPENCL_USE_ADRENO_KERNELS | ||||
| #ifdef GGML_OPENCL_EMBED_KERNELS | ||||
| @@ -1187,6 +1210,7 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te | ||||
|                 case GGML_UNARY_OP_GELU: | ||||
|                 case GGML_UNARY_OP_SILU: | ||||
|                 case GGML_UNARY_OP_RELU: | ||||
|                 case GGML_UNARY_OP_GELU_QUICK: | ||||
|                    return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32; | ||||
|                 default: | ||||
|                     return false; | ||||
| @@ -1216,14 +1240,26 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te | ||||
|             return op->ne[3] == 1; | ||||
|         case GGML_OP_ROPE: { | ||||
|             const int mode = ((const int32_t *) op->op_params)[2]; | ||||
|             if (mode & GGML_ROPE_TYPE_MROPE) { | ||||
|             const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE; | ||||
|             const bool is_vision = mode == GGML_ROPE_TYPE_VISION; | ||||
|             if (is_mrope && !is_vision) { | ||||
|                 if (op->src[0]->type == GGML_TYPE_F32 || | ||||
|                     op->src[0]->type == GGML_TYPE_F16) { | ||||
|                     return true; | ||||
|                 } | ||||
|                 return false; | ||||
|             } | ||||
|             if (mode & GGML_ROPE_TYPE_VISION) { | ||||
|             if (is_vision) { | ||||
|                 if (op->src[0]->type == GGML_TYPE_F32 || | ||||
|                     op->src[0]->type == GGML_TYPE_F16) { | ||||
|                     return true; | ||||
|                 } | ||||
|                 return false; | ||||
|             } | ||||
|             return true; | ||||
|         } | ||||
|         case GGML_OP_IM2COL: | ||||
|             return true; | ||||
|         default: | ||||
|             return false; | ||||
|     } | ||||
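The mode checks above rely on the ROPE type flags being bit-compatible: the vision type includes the mrope bit, so `mode & GGML_ROPE_TYPE_MROPE` is also true for vision tensors, and the `is_mrope && !is_vision` guard is what keeps the two kernel families apart. A minimal illustration, assuming the flag values currently defined in ggml.h (NEOX = 2, MROPE = 8, VISION = 24):

    // Assumed values from ggml.h; shown only to motivate the is_mrope/is_vision split.
    enum { ROPE_NEOX = 2, ROPE_MROPE = 8, ROPE_VISION = 24 };

    static int classify_rope(int mode) {
        const int is_mrope  = (mode & ROPE_MROPE) != 0; // also true for vision (24 = 8 | 16)
        const int is_vision = (mode == ROPE_VISION);
        if (is_vision)        return 2;                 // kernel_rope_vision_*
        if (is_mrope)         return 1;                 // kernel_rope_multi_*
        if (mode & ROPE_NEOX) return 3;                 // kernel_rope_neox_*
        return 0;                                       // kernel_rope_norm_*
    }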
| @@ -2582,6 +2618,53 @@ static void ggml_cl_gelu(ggml_backend_t backend, const ggml_tensor * src0, const | ||||
| #endif | ||||
| } | ||||
|  | ||||
| static void ggml_cl_gelu_quick(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { | ||||
|     GGML_ASSERT(src0); | ||||
|     GGML_ASSERT(src0->extra); | ||||
|     GGML_ASSERT(dst); | ||||
|     GGML_ASSERT(dst->extra); | ||||
|  | ||||
|     UNUSED(src1); | ||||
|  | ||||
|     ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; | ||||
|     cl_command_queue queue = backend_ctx->queue; | ||||
|  | ||||
|     ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra; | ||||
|     ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; | ||||
|  | ||||
|     cl_ulong offset0 = extra0->offset + src0->view_offs; | ||||
|     cl_ulong offsetd = extrad->offset + dst->view_offs; | ||||
|  | ||||
|     cl_kernel kernel; | ||||
|  | ||||
|     int n = ggml_nelements(dst); | ||||
|  | ||||
|     if (n % 4 == 0) { | ||||
|         kernel = backend_ctx->kernel_gelu_quick_4; | ||||
|         n /= 4; | ||||
|     } else { | ||||
|         kernel = backend_ctx->kernel_gelu_quick; | ||||
|     } | ||||
|  | ||||
|     CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem),   &extra0->data_device)); | ||||
|     CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); | ||||
|     CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem),   &extrad->data_device)); | ||||
|     CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd)); | ||||
|  | ||||
|     size_t global_work_size[] = {(size_t)n, 1, 1}; | ||||
|     size_t local_work_size[] = {64, 1, 1}; | ||||
|  | ||||
| #ifdef GGML_OPENCL_PROFILING | ||||
|     cl_event evt; | ||||
|     clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt); | ||||
|  | ||||
|     g_profiling_info.emplace_back(); | ||||
|     populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size, dst); | ||||
| #else | ||||
|     clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, NULL); | ||||
| #endif | ||||
| } | ||||
|  | ||||
| static void ggml_cl_silu(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { | ||||
|     GGML_ASSERT(src0); | ||||
|     GGML_ASSERT(src0->extra); | ||||
| @@ -3980,6 +4063,7 @@ static void ggml_cl_rope(ggml_backend_t backend, const ggml_tensor * src0, const | ||||
|     float attn_factor; | ||||
|     float beta_fast; | ||||
|     float beta_slow; | ||||
|     int32_t sections[4]; | ||||
|  | ||||
|     memcpy(&freq_base,   (int32_t *) dst->op_params + 5, sizeof(float)); | ||||
|     memcpy(&freq_scale,  (int32_t *) dst->op_params + 6, sizeof(float)); | ||||
| @@ -3987,23 +4071,23 @@ static void ggml_cl_rope(ggml_backend_t backend, const ggml_tensor * src0, const | ||||
|     memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float)); | ||||
|     memcpy(&beta_fast,   (int32_t *) dst->op_params + 9, sizeof(float)); | ||||
|     memcpy(&beta_slow,   (int32_t *) dst->op_params + 10, sizeof(float)); | ||||
|     memcpy(&sections,    (int32_t *) dst->op_params + 11, sizeof(int32_t)*4); | ||||
|  | ||||
|     const bool is_neox = mode & 2; | ||||
|     const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE; | ||||
|     const bool is_vision = mode == GGML_ROPE_TYPE_VISION; | ||||
|  | ||||
|     if (is_mrope) { | ||||
|         GGML_ASSERT(sections[0] > 0 || sections[1] > 0 || sections[2] > 0); | ||||
|     } | ||||
|  | ||||
|     if (is_vision) { | ||||
|         GGML_ASSERT(n_dims == ne00/2); | ||||
|     } | ||||
|  | ||||
|     cl_kernel kernel; | ||||
|  | ||||
|     if (!is_neox) { | ||||
|         switch (src0->type) { | ||||
|             case GGML_TYPE_F32: | ||||
|                 kernel = backend_ctx->kernel_rope_norm_f32; | ||||
|                 break; | ||||
|             case GGML_TYPE_F16: | ||||
|                 kernel = backend_ctx->kernel_rope_norm_f16; | ||||
|                 break; | ||||
|             default: | ||||
|                 GGML_ASSERT(false); | ||||
|         }; | ||||
|     } else { | ||||
|     if (is_neox) { | ||||
|         switch (src0->type) { | ||||
|             case GGML_TYPE_F32: | ||||
|                 kernel = backend_ctx->kernel_rope_neox_f32; | ||||
| @@ -4014,6 +4098,39 @@ static void ggml_cl_rope(ggml_backend_t backend, const ggml_tensor * src0, const | ||||
|             default: | ||||
|                 GGML_ASSERT(false); | ||||
|         }; | ||||
|     } else if (is_mrope && !is_vision) { | ||||
|         switch (src0->type) { | ||||
|             case GGML_TYPE_F32: | ||||
|                 kernel = backend_ctx->kernel_rope_multi_f32; | ||||
|                 break; | ||||
|             case GGML_TYPE_F16: | ||||
|                 kernel = backend_ctx->kernel_rope_multi_f16; | ||||
|                 break; | ||||
|             default: | ||||
|                 GGML_ASSERT(false); | ||||
|         }; | ||||
|     } else if (is_vision) { | ||||
|         switch (src0->type) { | ||||
|             case GGML_TYPE_F32: | ||||
|                 kernel = backend_ctx->kernel_rope_vision_f32; | ||||
|                 break; | ||||
|             case GGML_TYPE_F16: | ||||
|                 kernel = backend_ctx->kernel_rope_vision_f16; | ||||
|                 break; | ||||
|             default: | ||||
|                 GGML_ASSERT(false); | ||||
|         } | ||||
|     } else { | ||||
|         switch (src0->type) { | ||||
|             case GGML_TYPE_F32: | ||||
|                 kernel = backend_ctx->kernel_rope_norm_f32; | ||||
|                 break; | ||||
|             case GGML_TYPE_F16: | ||||
|                 kernel = backend_ctx->kernel_rope_norm_f16; | ||||
|                 break; | ||||
|             default: | ||||
|                 GGML_ASSERT(false); | ||||
|         }; | ||||
|     } | ||||
|  | ||||
|     CL_CHECK(clSetKernelArg(kernel,  0, sizeof(cl_mem),   &extra0->data_device)); | ||||
| @@ -4049,6 +4166,9 @@ static void ggml_cl_rope(ggml_backend_t backend, const ggml_tensor * src0, const | ||||
|     CL_CHECK(clSetKernelArg(kernel, 30, sizeof(float),    &attn_factor)); | ||||
|     CL_CHECK(clSetKernelArg(kernel, 31, sizeof(float),    &beta_fast)); | ||||
|     CL_CHECK(clSetKernelArg(kernel, 32, sizeof(float),    &beta_slow)); | ||||
|     if (is_mrope || is_vision) { | ||||
|         CL_CHECK(clSetKernelArg(kernel, 33, sizeof(int32_t)*4, &sections)); | ||||
|     } | ||||
|  | ||||
|     size_t global_work_size[] = {(size_t)ne01*nth, (size_t)ne02, (size_t)ne03}; | ||||
|     size_t local_work_size[] = {(size_t)nth, 1, 1}; | ||||
| @@ -4064,6 +4184,98 @@ static void ggml_cl_rope(ggml_backend_t backend, const ggml_tensor * src0, const | ||||
| #endif | ||||
| } | ||||
|  | ||||
| static void ggml_cl_im2col(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { | ||||
|     GGML_ASSERT(src0); | ||||
|     GGML_ASSERT(src1); | ||||
|     GGML_ASSERT(src1->extra); | ||||
|     GGML_ASSERT(dst); | ||||
|     GGML_ASSERT(dst->extra); | ||||
|  | ||||
|     // src0 - filter, src1 - input | ||||
|     GGML_ASSERT(src1->type == GGML_TYPE_F32); | ||||
|     GGML_ASSERT(dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_F32); | ||||
|  | ||||
|     ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; | ||||
|     cl_command_queue queue = backend_ctx->queue; | ||||
|  | ||||
|     ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra; | ||||
|     ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; | ||||
|  | ||||
|     cl_ulong offset1 = extra1->offset + src1->view_offs; | ||||
|     cl_ulong offsetd = extrad->offset + dst->view_offs; | ||||
|  | ||||
|     const int32_t s0 = ((const int32_t*)(dst->op_params))[0]; | ||||
|     const int32_t s1 = ((const int32_t*)(dst->op_params))[1]; | ||||
|     const int32_t p0 = ((const int32_t*)(dst->op_params))[2]; | ||||
|     const int32_t p1 = ((const int32_t*)(dst->op_params))[3]; | ||||
|     const int32_t d0 = ((const int32_t*)(dst->op_params))[4]; | ||||
|     const int32_t d1 = ((const int32_t*)(dst->op_params))[5]; | ||||
|  | ||||
|     const bool is_2D = ((const int32_t*)(dst->op_params))[6] == 1; | ||||
|  | ||||
|     const cl_long IC = src1->ne[is_2D ? 2 : 1]; | ||||
|     const cl_long IH = is_2D ? src1->ne[1] : 1; | ||||
|     const cl_long IW =         src1->ne[0]; | ||||
|  | ||||
|     const cl_long KH = is_2D ? src0->ne[1] : 1; | ||||
|     const cl_long KW =         src0->ne[0]; | ||||
|  | ||||
|     const cl_long OH = is_2D ? dst->ne[2] : 1; | ||||
|     const cl_long OW =         dst->ne[1]; | ||||
|  | ||||
|     // nb is byte offset, src is type float32 | ||||
|     const cl_ulong delta_offset = src1->nb[is_2D ? 2 : 1]/4; | ||||
|     const cl_long  batch        = src1->ne[is_2D ? 3 : 2]; | ||||
|     const cl_ulong batch_offset = src1->nb[is_2D ? 3 : 2]/4; | ||||
|  | ||||
|     const cl_long pelements = OW*KW*KH; | ||||
|     const cl_long CHW       = IC*KH*KW; | ||||
|  | ||||
|     cl_kernel kernel; | ||||
|  | ||||
|     if(dst->type == GGML_TYPE_F16) { | ||||
|         kernel = backend_ctx->kernel_im2col_f16; | ||||
|     } else { | ||||
|         kernel = backend_ctx->kernel_im2col_f32; | ||||
|     } | ||||
|  | ||||
|     CL_CHECK(clSetKernelArg(kernel,   0, sizeof(cl_mem),   &extra1->data_device)); | ||||
|     CL_CHECK(clSetKernelArg(kernel,   1, sizeof(cl_ulong), &offset1)); | ||||
|     CL_CHECK(clSetKernelArg(kernel,   2, sizeof(cl_mem),   &extrad->data_device)); | ||||
|     CL_CHECK(clSetKernelArg(kernel,   3, sizeof(cl_ulong), &offsetd)); | ||||
|     CL_CHECK(clSetKernelArg(kernel,   4, sizeof(cl_ulong), &batch_offset)); | ||||
|     CL_CHECK(clSetKernelArg(kernel,   5, sizeof(cl_ulong), &delta_offset)); | ||||
|     CL_CHECK(clSetKernelArg(kernel,   6, sizeof(cl_long),  &IW)); | ||||
|     CL_CHECK(clSetKernelArg(kernel,   7, sizeof(cl_long),  &IH)); | ||||
|     CL_CHECK(clSetKernelArg(kernel,   8, sizeof(cl_long),  &IC)); | ||||
|     CL_CHECK(clSetKernelArg(kernel,   9, sizeof(cl_long),  &OW)); | ||||
|     CL_CHECK(clSetKernelArg(kernel,  10, sizeof(cl_long),  &OH)); | ||||
|     CL_CHECK(clSetKernelArg(kernel,  11, sizeof(cl_long),  &KW)); | ||||
|     CL_CHECK(clSetKernelArg(kernel,  12, sizeof(cl_long),  &KH)); | ||||
|     CL_CHECK(clSetKernelArg(kernel,  13, sizeof(cl_long),  &pelements)); | ||||
|     CL_CHECK(clSetKernelArg(kernel,  14, sizeof(cl_long),  &CHW)); | ||||
|     CL_CHECK(clSetKernelArg(kernel,  15, sizeof(int),      &s0)); | ||||
|     CL_CHECK(clSetKernelArg(kernel,  16, sizeof(int),      &s1)); | ||||
|     CL_CHECK(clSetKernelArg(kernel,  17, sizeof(int),      &p0)); | ||||
|     CL_CHECK(clSetKernelArg(kernel,  18, sizeof(int),      &p1)); | ||||
|     CL_CHECK(clSetKernelArg(kernel,  19, sizeof(int),      &d0)); | ||||
|     CL_CHECK(clSetKernelArg(kernel,  20, sizeof(int),      &d1)); | ||||
|  | ||||
|     const int num_blocks = (pelements + 256 - 1) / 256; | ||||
|     size_t global_work_size[] = {(size_t)num_blocks*256, (size_t)OH, (size_t)batch*IC}; | ||||
|     size_t local_work_size[] = {256, 1, 1}; | ||||
|  | ||||
| #ifdef GGML_OPENCL_PROFILING | ||||
|     cl_event evt; | ||||
|     CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt)); | ||||
|  | ||||
|     g_profiling_info.emplace_back(); | ||||
|     populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size, dst); | ||||
| #else | ||||
|     CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, NULL)); | ||||
| #endif | ||||
| } | ||||
|  | ||||
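Two details worth noting in ggml_cl_im2col: the output extents OW/OH are read from dst->ne (ggml has already computed them with the standard convolution size formula), and the first global work dimension is padded up to a multiple of the 256-wide work-group, which is why the kernels bounds-check against pelements. For reference, the standard output-size arithmetic (shown only for orientation; the values are already encoded in dst->ne):

    #include <assert.h>

    // Standard convolution output extent:
    // floor((in + 2*pad - dilation*(kernel - 1) - 1) / stride) + 1.
    static long conv_out_size(long in, long kern, int stride, int pad, int dilation) {
        assert(stride > 0);
        return (in + 2*(long)pad - (long)dilation*(kern - 1) - 1) / stride + 1;
    }
    // e.g. OW == conv_out_size(IW, KW, s0, p0, d0) and OH == conv_out_size(IH, KH, s1, p1, d1)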
| //------------------------------------------------------------------------------ | ||||
| // Op offloading | ||||
| //------------------------------------------------------------------------------ | ||||
| @@ -4122,6 +4334,12 @@ bool ggml_cl_compute_forward(ggml_backend_t backend, struct ggml_tensor * tensor | ||||
|                     } | ||||
|                     func = ggml_cl_gelu; | ||||
|                     break; | ||||
|                 case GGML_UNARY_OP_GELU_QUICK: | ||||
|                     if (!any_on_device) { | ||||
|                         return false; | ||||
|                     } | ||||
|                     func = ggml_cl_gelu_quick; | ||||
|                     break; | ||||
|                 case GGML_UNARY_OP_SILU: | ||||
|                     if (!any_on_device) { | ||||
|                         return false; | ||||
| @@ -4194,6 +4412,12 @@ bool ggml_cl_compute_forward(ggml_backend_t backend, struct ggml_tensor * tensor | ||||
|             } | ||||
|             func = ggml_cl_rope; | ||||
|             break; | ||||
|         case GGML_OP_IM2COL: | ||||
|             if (!any_on_device) { | ||||
|                 return false; | ||||
|             } | ||||
|             func = ggml_cl_im2col; | ||||
|             break; | ||||
|         default: | ||||
|             return false; | ||||
|     } | ||||
|   | ||||
| @@ -404,6 +404,7 @@ kernel void kernel_scale( | ||||
| // gelu | ||||
| //------------------------------------------------------------------------------ | ||||
| #define GELU_COEF_A     0.044715f | ||||
| #define GELU_QUICK_COEF -1.702f | ||||
| #define SQRT_2_OVER_PI  0.79788456080286535587989211986876f | ||||
|  | ||||
| kernel void kernel_gelu( | ||||
| @@ -434,6 +435,32 @@ kernel void kernel_gelu_4( | ||||
|     dst[get_global_id(0)] = 0.5f*x*(1.0f + tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x))); | ||||
| } | ||||
|  | ||||
| kernel void kernel_gelu_quick( | ||||
|     global float * src0, | ||||
|     ulong offset0, | ||||
|     global float * dst, | ||||
|     ulong offsetd | ||||
| ) { | ||||
|     src0 = (global float*)((global char*)src0 + offset0); | ||||
|     dst = (global float*)((global char*)dst + offsetd); | ||||
|  | ||||
|     float x = src0[get_global_id(0)]; | ||||
|     dst[get_global_id(0)] = x*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x))); | ||||
| } | ||||
|  | ||||
| kernel void kernel_gelu_quick_4( | ||||
|     global float4 * src0, | ||||
|     ulong offset0, | ||||
|     global float4 * dst, | ||||
|     ulong offsetd | ||||
| ) { | ||||
|     src0 = (global float4*)((global char*)src0 + offset0); | ||||
|     dst = (global float4*)((global char*)dst + offsetd); | ||||
|  | ||||
|     float4 x = src0[get_global_id(0)]; | ||||
|     dst[get_global_id(0)] = x*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x))); | ||||
| } | ||||
|  | ||||
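With GELU_QUICK_COEF = -1.702, both kernels compute the sigmoid-based GELU approximation x * sigmoid(1.702 * x). A scalar reference in plain C (hypothetical helper, handy for checking the OpenCL output element-wise):

    #include <math.h>

    // gelu_quick(x) = x * sigmoid(1.702 * x) = x / (1 + exp(-1.702 * x)),
    // matching dst[i] = x*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x))) above.
    static float gelu_quick_ref(float x) {
        return x / (1.0f + expf(-1.702f * x));
    }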
| //------------------------------------------------------------------------------ | ||||
| // silu | ||||
| //------------------------------------------------------------------------------ | ||||
| @@ -1325,6 +1352,368 @@ kernel void kernel_rope_neox_f16( | ||||
|     } | ||||
| } | ||||
|  | ||||
| kernel void kernel_rope_multi_f32( | ||||
|         global void * src0, | ||||
|         ulong offset0, | ||||
|         global int * src1, | ||||
|         ulong offset1, | ||||
|         global float * src2, | ||||
|         ulong offset2, | ||||
|         global float * dst, | ||||
|         ulong offsetd, | ||||
|         int ne00, | ||||
|         int ne01, | ||||
|         int ne02, | ||||
|         int ne03, | ||||
|         ulong nb00, | ||||
|         ulong nb01, | ||||
|         ulong nb02, | ||||
|         ulong nb03, | ||||
|         int ne0, | ||||
|         int ne1, | ||||
|         int ne2, | ||||
|         int ne3, | ||||
|         ulong nb0, | ||||
|         ulong nb1, | ||||
|         ulong nb2, | ||||
|         ulong nb3, | ||||
|         int n_past, | ||||
|         int n_dims, | ||||
|         int n_ctx_orig, | ||||
|         float freq_base, | ||||
|         float freq_scale, | ||||
|         float ext_factor, | ||||
|         float attn_factor, | ||||
|         float beta_fast, | ||||
|         float beta_slow, | ||||
|         int4 sections | ||||
| ) { | ||||
|     src0 = (global void*)((global char*)src0 + offset0); | ||||
|     src1 = (global int*)((global char*)src1 + offset1); | ||||
|     src2 = (global float*)((global char*)src2 + offset2); | ||||
|     dst = (global float*)((global char*)dst + offsetd); | ||||
|  | ||||
|     int i3 = get_group_id(2); | ||||
|     int i2 = get_group_id(1); | ||||
|     int i1 = get_group_id(0); | ||||
|  | ||||
|     float2 corr_dims = rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow); | ||||
|  | ||||
|     global int * pos = src1; | ||||
|  | ||||
|     const int sect_dims = sections.s0 + sections.s1 + sections.s2 + sections.s3; | ||||
|     const int sec_w = sections.s1 + sections.s0; | ||||
|  | ||||
|     float inv_ndims = -1.f/n_dims; | ||||
|  | ||||
|     for (int i0 = 2*get_local_id(0); i0 < ne0; i0 += 2*get_local_size(0)) { | ||||
|         if (i0 < n_dims) { | ||||
|             int ic = i0/2; | ||||
|  | ||||
|             const int sector = (i0 / 2) % sect_dims; | ||||
|             float theta_base = 0.0f; | ||||
|  | ||||
|             if (sector < sections.s0) { | ||||
|                 theta_base = pos[i2]; | ||||
|             } | ||||
|             else if (sector >= sections.s0 && sector < sec_w) { | ||||
|                 theta_base = pos[i2 + ne2 * 1]; | ||||
|             } | ||||
|             else if (sector >= sec_w && sector < sec_w + sections.s2) { | ||||
|                 theta_base = pos[i2 + ne2 * 2]; | ||||
|             } | ||||
|             else if (sector >= sec_w + sections.s2) { | ||||
|                 theta_base = pos[i2 + ne2 * 3]; | ||||
|             } | ||||
|  | ||||
|             const float theta = theta_base * pow(freq_base, inv_ndims*i0); | ||||
|  | ||||
|             const float freq_factor = src2 != src0 ? src2[ic] : 1.0f; | ||||
|  | ||||
|             float2 cos_sin_theta = rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor); | ||||
|  | ||||
|             global float * src      = (global float *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00); | ||||
|             global float * dst_data = (global float *)((global char *)  dst + i3*nb3  + i2*nb2  + i1*nb1  + ic*nb0); | ||||
|  | ||||
|             const float x0 = src[0]; | ||||
|             const float x1 = src[n_dims/2]; | ||||
|  | ||||
|             dst_data[0]        = x0*cos_sin_theta.s0 - x1*cos_sin_theta.s1; | ||||
|             dst_data[n_dims/2] = x0*cos_sin_theta.s1 + x1*cos_sin_theta.s0; | ||||
|         } else { | ||||
|             global float * const src = (global float *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); | ||||
|             global float * dst_data  = (global float *)((global char *)  dst + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0); | ||||
|  | ||||
|             dst_data[0] = src[0]; | ||||
|             dst_data[1] = src[1]; | ||||
|         } | ||||
|     } | ||||
| } | ||||
|  | ||||
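The sections argument partitions the rotary dimensions into up to four blocks, each driven by its own position stream in src1 (pos[i2], pos[i2 + ne2], pos[i2 + 2*ne2], pos[i2 + 3*ne2]); sect_dims is the total block width and the sector index wraps modulo that width. A small C transliteration of the sector-to-stream mapping (the example sections value is hypothetical):

    // Which of the four position streams a rotary pair belongs to,
    // mirroring the sector checks in kernel_rope_multi_* above.
    static int mrope_stream(int i0, const int sections[4]) {
        const int sect_dims = sections[0] + sections[1] + sections[2] + sections[3];
        const int sec_w     = sections[0] + sections[1];
        const int sector    = (i0/2) % sect_dims;
        if (sector < sections[0])         return 0; // e.g. temporal positions
        if (sector < sec_w)               return 1; // e.g. height positions
        if (sector < sec_w + sections[2]) return 2; // e.g. width positions
        return 3;                                   // remaining dims
    }
    // e.g. with a hypothetical sections = {16, 24, 24, 0}, pairs 0..15 use stream 0,
    // 16..39 stream 1, 40..63 stream 2, and the pattern repeats every 64 pairs.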
| kernel void kernel_rope_multi_f16( | ||||
|         global void * src0, | ||||
|         ulong offset0, | ||||
|         global int * src1, | ||||
|         ulong offset1, | ||||
|         global float * src2, | ||||
|         ulong offset2, | ||||
|         global half * dst, | ||||
|         ulong offsetd, | ||||
|         int ne00, | ||||
|         int ne01, | ||||
|         int ne02, | ||||
|         int ne03, | ||||
|         ulong nb00, | ||||
|         ulong nb01, | ||||
|         ulong nb02, | ||||
|         ulong nb03, | ||||
|         int ne0, | ||||
|         int ne1, | ||||
|         int ne2, | ||||
|         int ne3, | ||||
|         ulong nb0, | ||||
|         ulong nb1, | ||||
|         ulong nb2, | ||||
|         ulong nb3, | ||||
|         int n_past, | ||||
|         int n_dims, | ||||
|         int n_ctx_orig, | ||||
|         float freq_base, | ||||
|         float freq_scale, | ||||
|         float ext_factor, | ||||
|         float attn_factor, | ||||
|         float beta_fast, | ||||
|         float beta_slow, | ||||
|         int4 sections | ||||
| ) { | ||||
|     src0 = (global void*)((global char*)src0 + offset0); | ||||
|     src1 = (global int*)((global char*)src1 + offset1); | ||||
|     src2 = (global float*)((global char*)src2 + offset2); | ||||
|     dst = (global half*)((global char*)dst + offsetd); | ||||
|  | ||||
|     int i3 = get_group_id(2); | ||||
|     int i2 = get_group_id(1); | ||||
|     int i1 = get_group_id(0); | ||||
|  | ||||
|     float2 corr_dims = rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow); | ||||
|  | ||||
|     global int * pos = src1; | ||||
|  | ||||
|     const int sect_dims = sections.s0 + sections.s1 + sections.s2 + sections.s3; | ||||
|     const int sec_w = sections.s1 + sections.s0; | ||||
|  | ||||
|     float inv_ndims = -1.f/n_dims; | ||||
|  | ||||
|     for (int i0 = 2*get_local_id(0); i0 < ne0; i0 += 2*get_local_size(0)) { | ||||
|         if (i0 < n_dims) { | ||||
|             int ic = i0/2; | ||||
|  | ||||
|             const int sector = (i0 / 2) % sect_dims; | ||||
|             float theta_base = 0.0f; | ||||
|  | ||||
|             if (sector < sections.s0) { | ||||
|                 theta_base = pos[i2]; | ||||
|             } | ||||
|             else if (sector >= sections.s0 && sector < sec_w) { | ||||
|                 theta_base = pos[i2 + ne2 * 1]; | ||||
|             } | ||||
|             else if (sector >= sec_w && sector < sec_w + sections.s2) { | ||||
|                 theta_base = pos[i2 + ne2 * 2]; | ||||
|             } | ||||
|             else if (sector >= sec_w + sections.s2) { | ||||
|                 theta_base = pos[i2 + ne2 * 3]; | ||||
|             } | ||||
|  | ||||
|             const float theta = theta_base * pow(freq_base, inv_ndims*i0); | ||||
|  | ||||
|             const float freq_factor = src2 != src0 ? src2[ic] : 1.0f; | ||||
|  | ||||
|             float2 cos_sin_theta = rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor); | ||||
|  | ||||
|             global half * src      = (global half *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00); | ||||
|             global half * dst_data = (global half *)((global char *)  dst + i3*nb3  + i2*nb2  + i1*nb1  + ic*nb0); | ||||
|  | ||||
|             const float x0 = src[0]; | ||||
|             const float x1 = src[n_dims/2]; | ||||
|  | ||||
|             dst_data[0]        = x0*cos_sin_theta.s0 - x1*cos_sin_theta.s1; | ||||
|             dst_data[n_dims/2] = x0*cos_sin_theta.s1 + x1*cos_sin_theta.s0; | ||||
|         } else { | ||||
|             global half * const src = (global half *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); | ||||
|             global half * dst_data  = (global half *)((global char *)  dst + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0); | ||||
|  | ||||
|             dst_data[0] = src[0]; | ||||
|             dst_data[1] = src[1]; | ||||
|         } | ||||
|     } | ||||
| } | ||||
|  | ||||
| kernel void kernel_rope_vision_f32( | ||||
|         global void * src0, | ||||
|         ulong offset0, | ||||
|         global int * src1, | ||||
|         ulong offset1, | ||||
|         global float * src2, | ||||
|         ulong offset2, | ||||
|         global float * dst, | ||||
|         ulong offsetd, | ||||
|         int ne00, | ||||
|         int ne01, | ||||
|         int ne02, | ||||
|         int ne03, | ||||
|         ulong nb00, | ||||
|         ulong nb01, | ||||
|         ulong nb02, | ||||
|         ulong nb03, | ||||
|         int ne0, | ||||
|         int ne1, | ||||
|         int ne2, | ||||
|         int ne3, | ||||
|         ulong nb0, | ||||
|         ulong nb1, | ||||
|         ulong nb2, | ||||
|         ulong nb3, | ||||
|         int n_past, | ||||
|         int n_dims, | ||||
|         int n_ctx_orig, | ||||
|         float freq_base, | ||||
|         float freq_scale, | ||||
|         float ext_factor, | ||||
|         float attn_factor, | ||||
|         float beta_fast, | ||||
|         float beta_slow, | ||||
|         int4 sections | ||||
| ) { | ||||
|     src0 = (global void*)((global char*)src0 + offset0); | ||||
|     src1 = (global int*)((global char*)src1 + offset1); | ||||
|     src2 = (global float*)((global char*)src2 + offset2); | ||||
|     dst = (global float*)((global char*)dst + offsetd); | ||||
|  | ||||
|     int i3 = get_group_id(2); | ||||
|     int i2 = get_group_id(1); | ||||
|     int i1 = get_group_id(0); | ||||
|  | ||||
|     float2 corr_dims = rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow); | ||||
|  | ||||
|     global int * pos = src1; | ||||
|  | ||||
|     const int sect_dims = sections.s0 + sections.s1; | ||||
|     const int sec_w = sections.s1 + sections.s0; | ||||
|  | ||||
|     float inv_ndims = -1.f/n_dims; | ||||
|  | ||||
|     for (int i0 = 2*get_local_id(0); i0 < ne0; i0 += 2*get_local_size(0)) { | ||||
|         int ic = i0/2; | ||||
|  | ||||
|         const int sector = (i0/2) % sect_dims; | ||||
|         float theta_base = 0.0f; | ||||
|  | ||||
|         if (sector < sections.s0) { | ||||
|             const int p = sector; | ||||
|             theta_base = pos[i2] * pow(freq_base, inv_ndims*2.0f*p); | ||||
|         } else if (sector >= sections.s0 && sector < sec_w) { | ||||
|             const int p = sector - sections.s0; | ||||
|             theta_base = pos[i2 + ne2] * pow(freq_base, inv_ndims*2.0f*p); | ||||
|         } | ||||
|  | ||||
|         const float freq_factor = src2 != src0 ? src2[ic] : 1.0f; | ||||
|  | ||||
|         float2 cos_sin_theta = rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor); | ||||
|  | ||||
|         global float * src      = (global float *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00); | ||||
|         global float * dst_data = (global float *)((global char *)  dst + i3*nb3  + i2*nb2  + i1*nb1  + ic*nb0); | ||||
|  | ||||
|         const float x0 = src[0]; | ||||
|         const float x1 = src[n_dims]; | ||||
|  | ||||
|         dst_data[0]      = x0*cos_sin_theta.s0 - x1*cos_sin_theta.s1; | ||||
|         dst_data[n_dims] = x0*cos_sin_theta.s1 + x1*cos_sin_theta.s0; | ||||
|     } | ||||
| } | ||||
|  | ||||
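The vision variant differs from the multi variant in two ways: it only uses the first two sections (height- and width-style position streams), and it rotates pairs that are a full n_dims apart (x[ic], x[ic + n_dims]) rather than n_dims/2, which is why the host code asserts n_dims == ne00/2. The angle exponent is 2*p, where p is the pair index within its section. A small C transliteration of the angle computation (helper name is hypothetical):

    #include <math.h>

    // theta for one rotary pair in the vision kernels above.
    // pos points at the position data of src1, laid out as [stream][i2].
    static float rope_vision_theta(const int * pos, int i2, int ne2,
                                   int sector, const int sections[2],
                                   float freq_base, float inv_ndims) {
        const int stream = sector < sections[0] ? 0 : 1;
        const int p      = sector < sections[0] ? sector : sector - sections[0];
        return pos[i2 + ne2*stream] * powf(freq_base, inv_ndims * 2.0f * p);
    }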
| kernel void kernel_rope_vision_f16( | ||||
|         global void * src0, | ||||
|         ulong offset0, | ||||
|         global int * src1, | ||||
|         ulong offset1, | ||||
|         global float * src2, | ||||
|         ulong offset2, | ||||
|         global half * dst, | ||||
|         ulong offsetd, | ||||
|         int ne00, | ||||
|         int ne01, | ||||
|         int ne02, | ||||
|         int ne03, | ||||
|         ulong nb00, | ||||
|         ulong nb01, | ||||
|         ulong nb02, | ||||
|         ulong nb03, | ||||
|         int ne0, | ||||
|         int ne1, | ||||
|         int ne2, | ||||
|         int ne3, | ||||
|         ulong nb0, | ||||
|         ulong nb1, | ||||
|         ulong nb2, | ||||
|         ulong nb3, | ||||
|         int n_past, | ||||
|         int n_dims, | ||||
|         int n_ctx_orig, | ||||
|         float freq_base, | ||||
|         float freq_scale, | ||||
|         float ext_factor, | ||||
|         float attn_factor, | ||||
|         float beta_fast, | ||||
|         float beta_slow, | ||||
|         int4 sections | ||||
| ) { | ||||
|     src0 = (global void*)((global char*)src0 + offset0); | ||||
|     src1 = (global int*)((global char*)src1 + offset1); | ||||
|     src2 = (global float*)((global char*)src2 + offset2); | ||||
|     dst = (global half*)((global char*)dst + offsetd); | ||||
|  | ||||
|     int i3 = get_group_id(2); | ||||
|     int i2 = get_group_id(1); | ||||
|     int i1 = get_group_id(0); | ||||
|  | ||||
|     float2 corr_dims = rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow); | ||||
|  | ||||
|     global int * pos = src1; | ||||
|  | ||||
|     const int sect_dims = sections.s0 + sections.s1; | ||||
|     const int sec_w = sections.s1 + sections.s0; | ||||
|  | ||||
|     float inv_ndims = -1.f/n_dims; | ||||
|  | ||||
|     for (int i0 = 2*get_local_id(0); i0 < ne0; i0 += 2*get_local_size(0)) { | ||||
|         int ic = i0/2; | ||||
|  | ||||
|         const int sector = (i0/2) % sect_dims; | ||||
|         float theta_base = 0.0f; | ||||
|  | ||||
|         if (sector < sections.s0) { | ||||
|             const int p = sector; | ||||
|             theta_base = pos[i2] * pow(freq_base, inv_ndims*2.0f*p); | ||||
|         } else if (sector >= sections.s0 && sector < sec_w) { | ||||
|             const int p = sector - sections.s0; | ||||
|             theta_base = pos[i2 + ne2] * pow(freq_base, inv_ndims*2.0f*p); | ||||
|         } | ||||
|  | ||||
|         const float freq_factor = src2 != src0 ? src2[ic] : 1.0f; | ||||
|  | ||||
|         float2 cos_sin_theta = rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor); | ||||
|  | ||||
|         global half * src      = (global half *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00); | ||||
|         global half * dst_data = (global half *)((global char *)  dst + i3*nb3  + i2*nb2  + i1*nb1  + ic*nb0); | ||||
|  | ||||
|         const float x0 = src[0]; | ||||
|         const float x1 = src[n_dims]; | ||||
|  | ||||
|         dst_data[0]      = x0*cos_sin_theta.s0 - x1*cos_sin_theta.s1; | ||||
|         dst_data[n_dims] = x0*cos_sin_theta.s1 + x1*cos_sin_theta.s0; | ||||
|     } | ||||
| } | ||||
|  | ||||
| //------------------------------------------------------------------------------ | ||||
| // cpy | ||||
| //------------------------------------------------------------------------------ | ||||
|   | ||||
ggml/src/ggml-opencl/kernels/ggml-opencl_im2col.cl (new file, 146 lines)
							| @@ -0,0 +1,146 @@ | ||||
| #ifdef cl_khr_fp16 | ||||
| #pragma OPENCL EXTENSION cl_khr_fp16 : enable | ||||
| #elif defined(cl_amd_fp16) | ||||
| #pragma OPENCL EXTENSION cl_amd_fp16 : enable | ||||
| #else | ||||
| #error "Half precision floating point not supportedby OpenCL implementation on your device." | ||||
| #endif | ||||
|  | ||||
| #ifdef cl_khr_subgroups | ||||
| #pragma OPENCL EXTENSION cl_khr_subgroups : enable | ||||
| #elif defined(cl_intel_subgroups) | ||||
| #pragma OPENCL EXTENSION cl_intel_subgroups : enable | ||||
| #else | ||||
| #error "Subgroup not supported on your device." | ||||
| #endif | ||||
|  | ||||
| #ifdef cl_intel_required_subgroup_size | ||||
| // Always use subgroup size of 32 on Intel. | ||||
| #pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable | ||||
| #define INTEL_GPU 1 | ||||
| #define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16))) | ||||
| #define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32))) | ||||
| #elif defined(cl_qcom_reqd_sub_group_size) | ||||
| // Always use subgroup size of 64 on Adreno. | ||||
| #pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable | ||||
| #define ADRENO_GPU 1 | ||||
| #define REQD_SUBGROUP_SIZE_64  __attribute__((qcom_reqd_sub_group_size("half"))) | ||||
| #define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) | ||||
| #else | ||||
| // TODO: do not know how to choose subgroup size on other GPUs. | ||||
| #error "Selecting subgroup size is not supported on your device." | ||||
| #endif | ||||
|  | ||||
| kernel void kernel_im2col_f32( | ||||
|         global float * src1, | ||||
|         ulong offset1, | ||||
|         global float * dst, | ||||
|         ulong offsetd, | ||||
|         ulong batch_offset, | ||||
|         ulong delta_offset, | ||||
|         long IW, | ||||
|         long IH, | ||||
|         long IC, | ||||
|         long OW, | ||||
|         long OH, | ||||
|         long KW, | ||||
|         long KH, | ||||
|         long pelements, | ||||
|         long CHW, | ||||
|         int  s0, | ||||
|         int  s1, | ||||
|         int  p0, | ||||
|         int  p1, | ||||
|         int  d0, | ||||
|         int  d1 | ||||
| ) { | ||||
|     // threadIdx.x + blockIdx.x * blockDim.x | ||||
|     long i = get_global_id(0); | ||||
|     if (i >= pelements) { | ||||
|         return; | ||||
|     } | ||||
|  | ||||
|     src1 = (global float*)((global char*)src1 + offset1); | ||||
|     dst = (global float*)((global char*)dst + offsetd); | ||||
|  | ||||
|     long  ksize = OW * (KH > 1 ? KW : 1); | ||||
|     long  kx = i / ksize; | ||||
|     long  kd = kx * ksize; | ||||
|     long  ky = (i - kd) / OW; | ||||
|     long  ix = i % OW; | ||||
|  | ||||
|     long  oh = get_group_id(1); | ||||
|     long  batch = get_group_id(2) / IC; | ||||
|     long  ic = get_group_id(2) % IC; | ||||
|  | ||||
|     long iiw = ix * s0 + kx * d0 - p0; | ||||
|     long iih = oh * s1 + ky * d1 - p1; | ||||
|  | ||||
|     long offset_dst = | ||||
|         ((batch * OH + oh) * OW + ix) * CHW + | ||||
|         (ic * (KW * KH) + ky * KW + kx); | ||||
|  | ||||
|     if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) { | ||||
|         dst[offset_dst] = 0.0f; | ||||
|     } else { | ||||
|         long offset_src = ic * delta_offset + batch * batch_offset; | ||||
|         dst[offset_dst] = src1[offset_src + iih * IW + iiw]; | ||||
|     } | ||||
| } | ||||
|  | ||||
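Each work-item decodes its flattened index i into two kernel offsets and an output column, then writes one element of the [batch*OH*OW, IC*KH*KW] destination matrix, zero-filling positions that fall into the padding. A direct C transliteration of that decode for a single (batch, ic, oh) slice (hypothetical helper, useful for cross-checking against the CPU backend):

    // One (batch, ic, oh) slice of im2col; i plays the role of get_global_id(0).
    // batch_offset and delta_offset are in floats, matching the kernel arguments above.
    static void im2col_slice_ref(const float * src, float * dst,
                                 long batch, long ic, long oh,
                                 long IW, long IH, long OW, long OH,
                                 long KW, long KH, long CHW,
                                 long batch_offset, long delta_offset,
                                 int s0, int s1, int p0, int p1, int d0, int d1) {
        const long pelements = OW*KW*KH;
        const long ksize     = OW * (KH > 1 ? KW : 1);
        for (long i = 0; i < pelements; i++) {
            const long kx   = i / ksize;
            const long ky   = (i - kx*ksize) / OW;
            const long ix   = i % OW;
            const long iiw  = ix*s0 + kx*d0 - p0;
            const long iih  = oh*s1 + ky*d1 - p1;
            const long idst = ((batch*OH + oh)*OW + ix)*CHW + ic*(KW*KH) + ky*KW + kx;
            if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
                dst[idst] = 0.0f;
            } else {
                dst[idst] = src[batch*batch_offset + ic*delta_offset + iih*IW + iiw];
            }
        }
    }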
| kernel void kernel_im2col_f16( | ||||
|         global float * src1, | ||||
|         ulong offset1, | ||||
|         global half  * dst, | ||||
|         ulong offsetd, | ||||
|         ulong batch_offset, | ||||
|         ulong delta_offset, | ||||
|         long IW, | ||||
|         long IH, | ||||
|         long IC, | ||||
|         long OW, | ||||
|         long OH, | ||||
|         long KW, | ||||
|         long KH, | ||||
|         long pelements, | ||||
|         long CHW, | ||||
|         int  s0, | ||||
|         int  s1, | ||||
|         int  p0, | ||||
|         int  p1, | ||||
|         int  d0, | ||||
|         int  d1 | ||||
| ) { | ||||
|     long i = get_global_id(0); | ||||
|  | ||||
|     if (i >= pelements) { | ||||
|         return; | ||||
|     } | ||||
|  | ||||
|     src1 = (global float*)((global char*)src1 + offset1); | ||||
|     dst = (global half*)((global char*)dst + offsetd); | ||||
|  | ||||
|     long  ksize = OW * (KH > 1 ? KW : 1); | ||||
|     long  kx = i / ksize; | ||||
|     long  kd = kx * ksize; | ||||
|     long  ky = (i - kd) / OW; | ||||
|     long  ix = i % OW; | ||||
|  | ||||
|     long  oh = get_group_id(1); | ||||
|     long  batch = get_group_id(2) / IC; | ||||
|     long  ic = get_group_id(2) % IC; | ||||
|  | ||||
|     long iiw = ix * s0 + kx * d0 - p0; | ||||
|     long iih = oh * s1 + ky * d1 - p1; | ||||
|  | ||||
|     long offset_dst = | ||||
|         ((batch * OH + oh) * OW + ix) * CHW + | ||||
|         (ic * (KW * KH) + ky * KW + kx); | ||||
|  | ||||
|     if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) { | ||||
|         dst[offset_dst] = 0.0f; | ||||
|     } else { | ||||
|         long offset_src = ic * delta_offset + batch * batch_offset; | ||||
|         dst[offset_dst] = src1[offset_src + iih * IW + iiw]; | ||||
|     } | ||||
| } | ||||