mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-10-31 08:51:55 +00:00 
			
		
		
		
	opencl: add multi and vision rope, gelu_quick and im2col (#12600)
				
					
				
			* opencl: add `im2col` * opencl: add `gelu_quick` * opencl: add mrope * opencl: add vision rope
This commit is contained in:
		| @@ -63,6 +63,7 @@ set(GGML_OPENCL_KERNELS | |||||||
|     ggml-opencl_transpose_16 |     ggml-opencl_transpose_16 | ||||||
|     ggml-opencl_transpose_32 |     ggml-opencl_transpose_32 | ||||||
|     ggml-opencl_transpose_32_16 |     ggml-opencl_transpose_32_16 | ||||||
|  |     ggml-opencl_im2col | ||||||
| ) | ) | ||||||
|  |  | ||||||
| foreach (K ${GGML_OPENCL_KERNELS}) | foreach (K ${GGML_OPENCL_KERNELS}) | ||||||
|   | |||||||
| @@ -224,12 +224,14 @@ struct ggml_backend_opencl_context { | |||||||
|     cl_program program; |     cl_program program; | ||||||
|     cl_program program_1; |     cl_program program_1; | ||||||
|     cl_program program_2; |     cl_program program_2; | ||||||
|  |     cl_program program_im2col; | ||||||
|  |  | ||||||
|     cl_kernel kernel_add, kernel_add_row; |     cl_kernel kernel_add, kernel_add_row; | ||||||
|     cl_kernel kernel_mul, kernel_mul_row; |     cl_kernel kernel_mul, kernel_mul_row; | ||||||
|     cl_kernel kernel_scale; |     cl_kernel kernel_scale; | ||||||
|     cl_kernel kernel_silu, kernel_silu_4; |     cl_kernel kernel_silu, kernel_silu_4; | ||||||
|     cl_kernel kernel_gelu, kernel_gelu_4; |     cl_kernel kernel_gelu, kernel_gelu_4; | ||||||
|  |     cl_kernel kernel_gelu_quick, kernel_gelu_quick_4; | ||||||
|     cl_kernel kernel_relu; |     cl_kernel kernel_relu; | ||||||
|     cl_kernel kernel_clamp; |     cl_kernel kernel_clamp; | ||||||
|     cl_kernel kernel_norm; |     cl_kernel kernel_norm; | ||||||
| @@ -239,6 +241,7 @@ struct ggml_backend_opencl_context { | |||||||
|     cl_kernel kernel_soft_max_f16, kernel_soft_max_4_f16; |     cl_kernel kernel_soft_max_f16, kernel_soft_max_4_f16; | ||||||
|     cl_kernel kernel_get_rows_f32, kernel_get_rows_f16, kernel_get_rows_q4_0; |     cl_kernel kernel_get_rows_f32, kernel_get_rows_f16, kernel_get_rows_q4_0; | ||||||
|     cl_kernel kernel_rope_norm_f32, kernel_rope_norm_f16, kernel_rope_neox_f32, kernel_rope_neox_f16; |     cl_kernel kernel_rope_norm_f32, kernel_rope_norm_f16, kernel_rope_neox_f32, kernel_rope_neox_f16; | ||||||
|  |     cl_kernel kernel_rope_multi_f32, kernel_rope_multi_f16, kernel_rope_vision_f32, kernel_rope_vision_f16; | ||||||
|     cl_kernel kernel_cpy_f16_f16, kernel_cpy_f16_f32, kernel_cpy_f32_f16, kernel_cpy_f32_f32; |     cl_kernel kernel_cpy_f16_f16, kernel_cpy_f16_f32, kernel_cpy_f32_f16, kernel_cpy_f32_f32; | ||||||
|     cl_kernel kernel_mul_mat_f32_f32; |     cl_kernel kernel_mul_mat_f32_f32; | ||||||
|     cl_kernel kernel_mul_mat_f16_f16; |     cl_kernel kernel_mul_mat_f16_f16; | ||||||
| @@ -252,6 +255,7 @@ struct ggml_backend_opencl_context { | |||||||
|               kernel_mul_mat_q4_0_f32_flat_img_v0; |               kernel_mul_mat_q4_0_f32_flat_img_v0; | ||||||
|     cl_kernel kernel_mul_mat_q4_0_f32_1d_8x_flat, kernel_mul_mat_q4_0_f32_1d_16x_flat; |     cl_kernel kernel_mul_mat_q4_0_f32_1d_8x_flat, kernel_mul_mat_q4_0_f32_1d_16x_flat; | ||||||
|     cl_kernel kernel_mul_mv_q6_K_f32; |     cl_kernel kernel_mul_mv_q6_K_f32; | ||||||
|  |     cl_kernel kernel_im2col_f32, kernel_im2col_f16; | ||||||
|  |  | ||||||
| #ifdef GGML_OPENCL_USE_ADRENO_KERNELS | #ifdef GGML_OPENCL_USE_ADRENO_KERNELS | ||||||
|     // Transpose kernels |     // Transpose kernels | ||||||
| @@ -708,6 +712,8 @@ static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) { | |||||||
|     CL_CHECK((backend_ctx->kernel_silu_4             = clCreateKernel(backend_ctx->program, "kernel_silu_4", &err), err)); |     CL_CHECK((backend_ctx->kernel_silu_4             = clCreateKernel(backend_ctx->program, "kernel_silu_4", &err), err)); | ||||||
|     CL_CHECK((backend_ctx->kernel_gelu               = clCreateKernel(backend_ctx->program, "kernel_gelu", &err), err)); |     CL_CHECK((backend_ctx->kernel_gelu               = clCreateKernel(backend_ctx->program, "kernel_gelu", &err), err)); | ||||||
|     CL_CHECK((backend_ctx->kernel_gelu_4             = clCreateKernel(backend_ctx->program, "kernel_gelu_4", &err), err)); |     CL_CHECK((backend_ctx->kernel_gelu_4             = clCreateKernel(backend_ctx->program, "kernel_gelu_4", &err), err)); | ||||||
|  |     CL_CHECK((backend_ctx->kernel_gelu_quick         = clCreateKernel(backend_ctx->program, "kernel_gelu_quick", &err), err)); | ||||||
|  |     CL_CHECK((backend_ctx->kernel_gelu_quick_4       = clCreateKernel(backend_ctx->program, "kernel_gelu_quick_4", &err), err)); | ||||||
|     CL_CHECK((backend_ctx->kernel_relu               = clCreateKernel(backend_ctx->program, "kernel_relu", &err), err)); |     CL_CHECK((backend_ctx->kernel_relu               = clCreateKernel(backend_ctx->program, "kernel_relu", &err), err)); | ||||||
|     CL_CHECK((backend_ctx->kernel_clamp              = clCreateKernel(backend_ctx->program, "kernel_clamp", &err), err)); |     CL_CHECK((backend_ctx->kernel_clamp              = clCreateKernel(backend_ctx->program, "kernel_clamp", &err), err)); | ||||||
|     CL_CHECK((backend_ctx->kernel_norm               = clCreateKernel(backend_ctx->program, "kernel_norm", &err), err)); |     CL_CHECK((backend_ctx->kernel_norm               = clCreateKernel(backend_ctx->program, "kernel_norm", &err), err)); | ||||||
| @@ -722,6 +728,10 @@ static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) { | |||||||
|     CL_CHECK((backend_ctx->kernel_rope_norm_f16      = clCreateKernel(backend_ctx->program, "kernel_rope_norm_f16", &err), err)); |     CL_CHECK((backend_ctx->kernel_rope_norm_f16      = clCreateKernel(backend_ctx->program, "kernel_rope_norm_f16", &err), err)); | ||||||
|     CL_CHECK((backend_ctx->kernel_rope_neox_f32      = clCreateKernel(backend_ctx->program, "kernel_rope_neox_f32", &err), err)); |     CL_CHECK((backend_ctx->kernel_rope_neox_f32      = clCreateKernel(backend_ctx->program, "kernel_rope_neox_f32", &err), err)); | ||||||
|     CL_CHECK((backend_ctx->kernel_rope_neox_f16      = clCreateKernel(backend_ctx->program, "kernel_rope_neox_f16", &err), err)); |     CL_CHECK((backend_ctx->kernel_rope_neox_f16      = clCreateKernel(backend_ctx->program, "kernel_rope_neox_f16", &err), err)); | ||||||
|  |     CL_CHECK((backend_ctx->kernel_rope_multi_f32     = clCreateKernel(backend_ctx->program, "kernel_rope_multi_f32", &err), err)); | ||||||
|  |     CL_CHECK((backend_ctx->kernel_rope_multi_f16     = clCreateKernel(backend_ctx->program, "kernel_rope_multi_f16", &err), err)); | ||||||
|  |     CL_CHECK((backend_ctx->kernel_rope_vision_f32    = clCreateKernel(backend_ctx->program, "kernel_rope_vision_f32", &err), err)); | ||||||
|  |     CL_CHECK((backend_ctx->kernel_rope_vision_f16    = clCreateKernel(backend_ctx->program, "kernel_rope_vision_f16", &err), err)); | ||||||
|     CL_CHECK((backend_ctx->kernel_cpy_f16_f16        = clCreateKernel(backend_ctx->program, "kernel_cpy_f16_f16", &err), err)); |     CL_CHECK((backend_ctx->kernel_cpy_f16_f16        = clCreateKernel(backend_ctx->program, "kernel_cpy_f16_f16", &err), err)); | ||||||
|     CL_CHECK((backend_ctx->kernel_cpy_f16_f32        = clCreateKernel(backend_ctx->program, "kernel_cpy_f16_f32", &err), err)); |     CL_CHECK((backend_ctx->kernel_cpy_f16_f32        = clCreateKernel(backend_ctx->program, "kernel_cpy_f16_f32", &err), err)); | ||||||
|     CL_CHECK((backend_ctx->kernel_cpy_f32_f16        = clCreateKernel(backend_ctx->program, "kernel_cpy_f32_f16", &err), err)); |     CL_CHECK((backend_ctx->kernel_cpy_f32_f16        = clCreateKernel(backend_ctx->program, "kernel_cpy_f32_f16", &err), err)); | ||||||
| @@ -769,6 +779,19 @@ static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) { | |||||||
|  |  | ||||||
|     CL_CHECK((backend_ctx->kernel_convert_block_q4_0_noshuffle     = clCreateKernel(backend_ctx->program_2, "kernel_convert_block_q4_0_noshuffle", &err), err)); |     CL_CHECK((backend_ctx->kernel_convert_block_q4_0_noshuffle     = clCreateKernel(backend_ctx->program_2, "kernel_convert_block_q4_0_noshuffle", &err), err)); | ||||||
|  |  | ||||||
|  |     // im2col kernels | ||||||
|  | #ifdef GGML_OPENCL_EMBED_KERNELS | ||||||
|  |     const std::string kernel_src_im2col { | ||||||
|  |         #include "ggml-opencl_im2col.cl.h" | ||||||
|  |     }; | ||||||
|  | #else | ||||||
|  |     const std::string kernel_src_im2col = read_file("ggml-opencl_im2col.cl"); | ||||||
|  | #endif | ||||||
|  |     backend_ctx->program_im2col = build_program_from_source(context, device, kernel_src_im2col.c_str(), compile_opts); | ||||||
|  |  | ||||||
|  |     CL_CHECK((backend_ctx->kernel_im2col_f32 = clCreateKernel(backend_ctx->program_im2col, "kernel_im2col_f32", &err), err)); | ||||||
|  |     CL_CHECK((backend_ctx->kernel_im2col_f16 = clCreateKernel(backend_ctx->program_im2col, "kernel_im2col_f16", &err), err)); | ||||||
|  |  | ||||||
|     // Kernels for Adreno |     // Kernels for Adreno | ||||||
| #ifdef GGML_OPENCL_USE_ADRENO_KERNELS | #ifdef GGML_OPENCL_USE_ADRENO_KERNELS | ||||||
| #ifdef GGML_OPENCL_EMBED_KERNELS | #ifdef GGML_OPENCL_EMBED_KERNELS | ||||||
| @@ -1187,6 +1210,7 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te | |||||||
|                 case GGML_UNARY_OP_GELU: |                 case GGML_UNARY_OP_GELU: | ||||||
|                 case GGML_UNARY_OP_SILU: |                 case GGML_UNARY_OP_SILU: | ||||||
|                 case GGML_UNARY_OP_RELU: |                 case GGML_UNARY_OP_RELU: | ||||||
|  |                 case GGML_UNARY_OP_GELU_QUICK: | ||||||
|                    return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32; |                    return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32; | ||||||
|                 default: |                 default: | ||||||
|                     return false; |                     return false; | ||||||
| @@ -1216,14 +1240,26 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te | |||||||
|             return op->ne[3] == 1; |             return op->ne[3] == 1; | ||||||
|         case GGML_OP_ROPE: { |         case GGML_OP_ROPE: { | ||||||
|             const int mode = ((const int32_t *) op->op_params)[2]; |             const int mode = ((const int32_t *) op->op_params)[2]; | ||||||
|             if (mode & GGML_ROPE_TYPE_MROPE) { |             const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE; | ||||||
|  |             const bool is_vision = mode == GGML_ROPE_TYPE_VISION; | ||||||
|  |             if (is_mrope && !is_vision) { | ||||||
|  |                 if (op->src[0]->type == GGML_TYPE_F32 || | ||||||
|  |                     op->src[0]->type == GGML_TYPE_F16) { | ||||||
|  |                     return true; | ||||||
|  |                 } | ||||||
|                 return false; |                 return false; | ||||||
|             } |             } | ||||||
|             if (mode & GGML_ROPE_TYPE_VISION) { |             if (is_vision) { | ||||||
|  |                 if (op->src[0]->type == GGML_TYPE_F32 || | ||||||
|  |                     op->src[0]->type == GGML_TYPE_F16) { | ||||||
|  |                     return true; | ||||||
|  |                 } | ||||||
|                 return false; |                 return false; | ||||||
|             } |             } | ||||||
|             return true; |             return true; | ||||||
|         } |         } | ||||||
|  |         case GGML_OP_IM2COL: | ||||||
|  |             return true; | ||||||
|         default: |         default: | ||||||
|             return false; |             return false; | ||||||
|     } |     } | ||||||
| @@ -2582,6 +2618,53 @@ static void ggml_cl_gelu(ggml_backend_t backend, const ggml_tensor * src0, const | |||||||
| #endif | #endif | ||||||
| } | } | ||||||
|  |  | ||||||
|  | static void ggml_cl_gelu_quick(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { | ||||||
|  |     GGML_ASSERT(src0); | ||||||
|  |     GGML_ASSERT(src0->extra); | ||||||
|  |     GGML_ASSERT(dst); | ||||||
|  |     GGML_ASSERT(dst->extra); | ||||||
|  |  | ||||||
|  |     UNUSED(src1); | ||||||
|  |  | ||||||
|  |     ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; | ||||||
|  |     cl_command_queue queue = backend_ctx->queue; | ||||||
|  |  | ||||||
|  |     ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra; | ||||||
|  |     ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; | ||||||
|  |  | ||||||
|  |     cl_ulong offset0 = extra0->offset + src0->view_offs; | ||||||
|  |     cl_ulong offsetd = extrad->offset + dst->view_offs; | ||||||
|  |  | ||||||
|  |     cl_kernel kernel; | ||||||
|  |  | ||||||
|  |     int n = ggml_nelements(dst); | ||||||
|  |  | ||||||
|  |     if (n % 4 == 0) { | ||||||
|  |         kernel = backend_ctx->kernel_gelu_quick_4; | ||||||
|  |         n /= 4; | ||||||
|  |     } else { | ||||||
|  |         kernel = backend_ctx->kernel_gelu_quick; | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem),   &extra0->data_device)); | ||||||
|  |     CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); | ||||||
|  |     CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem),   &extrad->data_device)); | ||||||
|  |     CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd)); | ||||||
|  |  | ||||||
|  |     size_t global_work_size[] = {(size_t)n, 1, 1}; | ||||||
|  |     size_t local_work_size[] = {64, 1, 1}; | ||||||
|  |  | ||||||
|  | #ifdef GGML_OPENCL_PROFILING | ||||||
|  |     cl_event evt; | ||||||
|  |     clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt); | ||||||
|  |  | ||||||
|  |     g_profiling_info.emplace_back(); | ||||||
|  |     populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size, dst); | ||||||
|  | #else | ||||||
|  |     clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, NULL); | ||||||
|  | #endif | ||||||
|  | } | ||||||
|  |  | ||||||
| static void ggml_cl_silu(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { | static void ggml_cl_silu(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { | ||||||
|     GGML_ASSERT(src0); |     GGML_ASSERT(src0); | ||||||
|     GGML_ASSERT(src0->extra); |     GGML_ASSERT(src0->extra); | ||||||
| @@ -3980,6 +4063,7 @@ static void ggml_cl_rope(ggml_backend_t backend, const ggml_tensor * src0, const | |||||||
|     float attn_factor; |     float attn_factor; | ||||||
|     float beta_fast; |     float beta_fast; | ||||||
|     float beta_slow; |     float beta_slow; | ||||||
|  |     int32_t sections[4]; | ||||||
|  |  | ||||||
|     memcpy(&freq_base,   (int32_t *) dst->op_params + 5, sizeof(float)); |     memcpy(&freq_base,   (int32_t *) dst->op_params + 5, sizeof(float)); | ||||||
|     memcpy(&freq_scale,  (int32_t *) dst->op_params + 6, sizeof(float)); |     memcpy(&freq_scale,  (int32_t *) dst->op_params + 6, sizeof(float)); | ||||||
| @@ -3987,23 +4071,23 @@ static void ggml_cl_rope(ggml_backend_t backend, const ggml_tensor * src0, const | |||||||
|     memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float)); |     memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float)); | ||||||
|     memcpy(&beta_fast,   (int32_t *) dst->op_params + 9, sizeof(float)); |     memcpy(&beta_fast,   (int32_t *) dst->op_params + 9, sizeof(float)); | ||||||
|     memcpy(&beta_slow,   (int32_t *) dst->op_params + 10, sizeof(float)); |     memcpy(&beta_slow,   (int32_t *) dst->op_params + 10, sizeof(float)); | ||||||
|  |     memcpy(§ions,    (int32_t *) dst->op_params + 11, sizeof(int32_t)*4); | ||||||
|  |  | ||||||
|     const bool is_neox = mode & 2; |     const bool is_neox = mode & 2; | ||||||
|  |     const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE; | ||||||
|  |     const bool is_vision = mode == GGML_ROPE_TYPE_VISION; | ||||||
|  |  | ||||||
|  |     if (is_mrope) { | ||||||
|  |         GGML_ASSERT(sections[0] > 0 || sections[1] > 0 || sections[2] > 0); | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     if (is_vision) { | ||||||
|  |         GGML_ASSERT(n_dims == ne00/2); | ||||||
|  |     } | ||||||
|  |  | ||||||
|     cl_kernel kernel; |     cl_kernel kernel; | ||||||
|  |  | ||||||
|     if (!is_neox) { |     if (is_neox) { | ||||||
|         switch (src0->type) { |  | ||||||
|             case GGML_TYPE_F32: |  | ||||||
|                 kernel = backend_ctx->kernel_rope_norm_f32; |  | ||||||
|                 break; |  | ||||||
|             case GGML_TYPE_F16: |  | ||||||
|                 kernel = backend_ctx->kernel_rope_norm_f16; |  | ||||||
|                 break; |  | ||||||
|             default: |  | ||||||
|                 GGML_ASSERT(false); |  | ||||||
|         }; |  | ||||||
|     } else { |  | ||||||
|         switch (src0->type) { |         switch (src0->type) { | ||||||
|             case GGML_TYPE_F32: |             case GGML_TYPE_F32: | ||||||
|                 kernel = backend_ctx->kernel_rope_neox_f32; |                 kernel = backend_ctx->kernel_rope_neox_f32; | ||||||
| @@ -4014,6 +4098,39 @@ static void ggml_cl_rope(ggml_backend_t backend, const ggml_tensor * src0, const | |||||||
|             default: |             default: | ||||||
|                 GGML_ASSERT(false); |                 GGML_ASSERT(false); | ||||||
|         }; |         }; | ||||||
|  |     } else if (is_mrope && !is_vision) { | ||||||
|  |         switch (src0->type) { | ||||||
|  |             case GGML_TYPE_F32: | ||||||
|  |                 kernel = backend_ctx->kernel_rope_multi_f32; | ||||||
|  |                 break; | ||||||
|  |             case GGML_TYPE_F16: | ||||||
|  |                 kernel = backend_ctx->kernel_rope_multi_f16; | ||||||
|  |                 break; | ||||||
|  |             default: | ||||||
|  |                 GGML_ASSERT(false); | ||||||
|  |         }; | ||||||
|  |     } else if (is_vision) { | ||||||
|  |         switch (src0->type) { | ||||||
|  |             case GGML_TYPE_F32: | ||||||
|  |                 kernel = backend_ctx->kernel_rope_vision_f32; | ||||||
|  |                 break; | ||||||
|  |             case GGML_TYPE_F16: | ||||||
|  |                 kernel = backend_ctx->kernel_rope_vision_f16; | ||||||
|  |                 break; | ||||||
|  |             default: | ||||||
|  |                 GGML_ASSERT(false); | ||||||
|  |         } | ||||||
|  |     } else { | ||||||
|  |         switch (src0->type) { | ||||||
|  |             case GGML_TYPE_F32: | ||||||
|  |                 kernel = backend_ctx->kernel_rope_norm_f32; | ||||||
|  |                 break; | ||||||
|  |             case GGML_TYPE_F16: | ||||||
|  |                 kernel = backend_ctx->kernel_rope_norm_f16; | ||||||
|  |                 break; | ||||||
|  |             default: | ||||||
|  |                 GGML_ASSERT(false); | ||||||
|  |         }; | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     CL_CHECK(clSetKernelArg(kernel,  0, sizeof(cl_mem),   &extra0->data_device)); |     CL_CHECK(clSetKernelArg(kernel,  0, sizeof(cl_mem),   &extra0->data_device)); | ||||||
| @@ -4049,6 +4166,9 @@ static void ggml_cl_rope(ggml_backend_t backend, const ggml_tensor * src0, const | |||||||
|     CL_CHECK(clSetKernelArg(kernel, 30, sizeof(float),    &attn_factor)); |     CL_CHECK(clSetKernelArg(kernel, 30, sizeof(float),    &attn_factor)); | ||||||
|     CL_CHECK(clSetKernelArg(kernel, 31, sizeof(float),    &beta_fast)); |     CL_CHECK(clSetKernelArg(kernel, 31, sizeof(float),    &beta_fast)); | ||||||
|     CL_CHECK(clSetKernelArg(kernel, 32, sizeof(float),    &beta_slow)); |     CL_CHECK(clSetKernelArg(kernel, 32, sizeof(float),    &beta_slow)); | ||||||
|  |     if (is_mrope || is_vision) { | ||||||
|  |         CL_CHECK(clSetKernelArg(kernel, 33, sizeof(int32_t)*4, §ions)); | ||||||
|  |     } | ||||||
|  |  | ||||||
|     size_t global_work_size[] = {(size_t)ne01*nth, (size_t)ne02, (size_t)ne03}; |     size_t global_work_size[] = {(size_t)ne01*nth, (size_t)ne02, (size_t)ne03}; | ||||||
|     size_t local_work_size[] = {(size_t)nth, 1, 1}; |     size_t local_work_size[] = {(size_t)nth, 1, 1}; | ||||||
| @@ -4064,6 +4184,98 @@ static void ggml_cl_rope(ggml_backend_t backend, const ggml_tensor * src0, const | |||||||
| #endif | #endif | ||||||
| } | } | ||||||
|  |  | ||||||
|  | static void ggml_cl_im2col(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { | ||||||
|  |     GGML_ASSERT(src0); | ||||||
|  |     GGML_ASSERT(src1); | ||||||
|  |     GGML_ASSERT(src1->extra); | ||||||
|  |     GGML_ASSERT(dst); | ||||||
|  |     GGML_ASSERT(dst->extra); | ||||||
|  |  | ||||||
|  |     // src0 - filter, src1 - input | ||||||
|  |     GGML_ASSERT(src1->type == GGML_TYPE_F32); | ||||||
|  |     GGML_ASSERT(dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_F32); | ||||||
|  |  | ||||||
|  |     ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; | ||||||
|  |     cl_command_queue queue = backend_ctx->queue; | ||||||
|  |  | ||||||
|  |     ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra; | ||||||
|  |     ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; | ||||||
|  |  | ||||||
|  |     cl_ulong offset1 = extra1->offset + src1->view_offs; | ||||||
|  |     cl_ulong offsetd = extrad->offset + dst->view_offs; | ||||||
|  |  | ||||||
|  |     const int32_t s0 = ((const int32_t*)(dst->op_params))[0]; | ||||||
|  |     const int32_t s1 = ((const int32_t*)(dst->op_params))[1]; | ||||||
|  |     const int32_t p0 = ((const int32_t*)(dst->op_params))[2]; | ||||||
|  |     const int32_t p1 = ((const int32_t*)(dst->op_params))[3]; | ||||||
|  |     const int32_t d0 = ((const int32_t*)(dst->op_params))[4]; | ||||||
|  |     const int32_t d1 = ((const int32_t*)(dst->op_params))[5]; | ||||||
|  |  | ||||||
|  |     const bool is_2D = ((const int32_t*)(dst->op_params))[6] == 1; | ||||||
|  |  | ||||||
|  |     const cl_long IC = src1->ne[is_2D ? 2 : 1]; | ||||||
|  |     const cl_long IH = is_2D ? src1->ne[1] : 1; | ||||||
|  |     const cl_long IW =         src1->ne[0]; | ||||||
|  |  | ||||||
|  |     const cl_long KH = is_2D ? src0->ne[1] : 1; | ||||||
|  |     const cl_long KW =         src0->ne[0]; | ||||||
|  |  | ||||||
|  |     const cl_long OH = is_2D ? dst->ne[2] : 1; | ||||||
|  |     const cl_long OW =         dst->ne[1]; | ||||||
|  |  | ||||||
|  |     // nb is byte offset, src is type float32 | ||||||
|  |     const cl_ulong delta_offset = src1->nb[is_2D ? 2 : 1]/4; | ||||||
|  |     const cl_long  batch        = src1->ne[is_2D ? 3 : 2]; | ||||||
|  |     const cl_ulong batch_offset = src1->nb[is_2D ? 3 : 2]/4; | ||||||
|  |  | ||||||
|  |     const cl_long pelements = OW*KW*KH; | ||||||
|  |     const cl_long CHW       = IC*KH*KW; | ||||||
|  |  | ||||||
|  |     cl_kernel kernel; | ||||||
|  |  | ||||||
|  |     if(dst->type == GGML_TYPE_F16) { | ||||||
|  |         kernel = backend_ctx->kernel_im2col_f16; | ||||||
|  |     } else { | ||||||
|  |         kernel = backend_ctx->kernel_im2col_f32; | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     CL_CHECK(clSetKernelArg(kernel,   0, sizeof(cl_mem),   &extra1->data_device)); | ||||||
|  |     CL_CHECK(clSetKernelArg(kernel,   1, sizeof(cl_ulong), &offset1)); | ||||||
|  |     CL_CHECK(clSetKernelArg(kernel,   2, sizeof(cl_mem),   &extrad->data_device)); | ||||||
|  |     CL_CHECK(clSetKernelArg(kernel,   3, sizeof(cl_ulong), &offsetd)); | ||||||
|  |     CL_CHECK(clSetKernelArg(kernel,   4, sizeof(cl_ulong), &batch_offset)); | ||||||
|  |     CL_CHECK(clSetKernelArg(kernel,   5, sizeof(cl_ulong), &delta_offset)); | ||||||
|  |     CL_CHECK(clSetKernelArg(kernel,   6, sizeof(cl_long),  &IW)); | ||||||
|  |     CL_CHECK(clSetKernelArg(kernel,   7, sizeof(cl_long),  &IH)); | ||||||
|  |     CL_CHECK(clSetKernelArg(kernel,   8, sizeof(cl_long),  &IC)); | ||||||
|  |     CL_CHECK(clSetKernelArg(kernel,   9, sizeof(cl_long),  &OW)); | ||||||
|  |     CL_CHECK(clSetKernelArg(kernel,  10, sizeof(cl_long),  &OH)); | ||||||
|  |     CL_CHECK(clSetKernelArg(kernel,  11, sizeof(cl_long),  &KW)); | ||||||
|  |     CL_CHECK(clSetKernelArg(kernel,  12, sizeof(cl_long),  &KH)); | ||||||
|  |     CL_CHECK(clSetKernelArg(kernel,  13, sizeof(cl_long),  &pelements)); | ||||||
|  |     CL_CHECK(clSetKernelArg(kernel,  14, sizeof(cl_long),  &CHW)); | ||||||
|  |     CL_CHECK(clSetKernelArg(kernel,  15, sizeof(int),      &s0)); | ||||||
|  |     CL_CHECK(clSetKernelArg(kernel,  16, sizeof(int),      &s1)); | ||||||
|  |     CL_CHECK(clSetKernelArg(kernel,  17, sizeof(int),      &p0)); | ||||||
|  |     CL_CHECK(clSetKernelArg(kernel,  18, sizeof(int),      &p1)); | ||||||
|  |     CL_CHECK(clSetKernelArg(kernel,  19, sizeof(int),      &d0)); | ||||||
|  |     CL_CHECK(clSetKernelArg(kernel,  20, sizeof(int),      &d1)); | ||||||
|  |  | ||||||
|  |     const int num_blocks = (pelements + 256 - 1) / 256; | ||||||
|  |     size_t global_work_size[] = {(size_t)num_blocks*256, (size_t)OH, (size_t)batch*IC}; | ||||||
|  |     size_t local_work_size[] = {256, 1, 1}; | ||||||
|  |  | ||||||
|  | #ifdef GGML_OPENCL_PROFILING | ||||||
|  |     cl_event evt; | ||||||
|  |     CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt)); | ||||||
|  |  | ||||||
|  |     g_profiling_info.emplace_back(); | ||||||
|  |     populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size, dst); | ||||||
|  | #else | ||||||
|  |     CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, NULL)); | ||||||
|  | #endif | ||||||
|  | } | ||||||
|  |  | ||||||
| //------------------------------------------------------------------------------ | //------------------------------------------------------------------------------ | ||||||
| // Op offloading | // Op offloading | ||||||
| //------------------------------------------------------------------------------ | //------------------------------------------------------------------------------ | ||||||
| @@ -4122,6 +4334,12 @@ bool ggml_cl_compute_forward(ggml_backend_t backend, struct ggml_tensor * tensor | |||||||
|                     } |                     } | ||||||
|                     func = ggml_cl_gelu; |                     func = ggml_cl_gelu; | ||||||
|                     break; |                     break; | ||||||
|  |                 case GGML_UNARY_OP_GELU_QUICK: | ||||||
|  |                     if (!any_on_device) { | ||||||
|  |                         return false; | ||||||
|  |                     } | ||||||
|  |                     func = ggml_cl_gelu_quick; | ||||||
|  |                     break; | ||||||
|                 case GGML_UNARY_OP_SILU: |                 case GGML_UNARY_OP_SILU: | ||||||
|                     if (!any_on_device) { |                     if (!any_on_device) { | ||||||
|                         return false; |                         return false; | ||||||
| @@ -4194,6 +4412,12 @@ bool ggml_cl_compute_forward(ggml_backend_t backend, struct ggml_tensor * tensor | |||||||
|             } |             } | ||||||
|             func = ggml_cl_rope; |             func = ggml_cl_rope; | ||||||
|             break; |             break; | ||||||
|  |         case GGML_OP_IM2COL: | ||||||
|  |             if (!any_on_device) { | ||||||
|  |                 return false; | ||||||
|  |             } | ||||||
|  |             func = ggml_cl_im2col; | ||||||
|  |             break; | ||||||
|         default: |         default: | ||||||
|             return false; |             return false; | ||||||
|     } |     } | ||||||
|   | |||||||
| @@ -404,6 +404,7 @@ kernel void kernel_scale( | |||||||
| // gelu | // gelu | ||||||
| //------------------------------------------------------------------------------ | //------------------------------------------------------------------------------ | ||||||
| #define GELU_COEF_A     0.044715f | #define GELU_COEF_A     0.044715f | ||||||
|  | #define GELU_QUICK_COEF -1.702f | ||||||
| #define SQRT_2_OVER_PI  0.79788456080286535587989211986876f | #define SQRT_2_OVER_PI  0.79788456080286535587989211986876f | ||||||
|  |  | ||||||
| kernel void kernel_gelu( | kernel void kernel_gelu( | ||||||
| @@ -434,6 +435,32 @@ kernel void kernel_gelu_4( | |||||||
|     dst[get_global_id(0)] = 0.5f*x*(1.0f + tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x))); |     dst[get_global_id(0)] = 0.5f*x*(1.0f + tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x))); | ||||||
| } | } | ||||||
|  |  | ||||||
|  | kernel void kernel_gelu_quick( | ||||||
|  |     global float * src0, | ||||||
|  |     ulong offset0, | ||||||
|  |     global float * dst, | ||||||
|  |     ulong offsetd | ||||||
|  | ) { | ||||||
|  |     src0 = (global float*)((global char*)src0 + offset0); | ||||||
|  |     dst = (global float*)((global char*)dst + offsetd); | ||||||
|  |  | ||||||
|  |     float x = src0[get_global_id(0)]; | ||||||
|  |     dst[get_global_id(0)] = x*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x))); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | kernel void kernel_gelu_quick_4( | ||||||
|  |     global float4 * src0, | ||||||
|  |     ulong offset0, | ||||||
|  |     global float4 * dst, | ||||||
|  |     ulong offsetd | ||||||
|  | ) { | ||||||
|  |     src0 = (global float4*)((global char*)src0 + offset0); | ||||||
|  |     dst = (global float4*)((global char*)dst + offsetd); | ||||||
|  |  | ||||||
|  |     float4 x = src0[get_global_id(0)]; | ||||||
|  |     dst[get_global_id(0)] = x*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x))); | ||||||
|  | } | ||||||
|  |  | ||||||
| //------------------------------------------------------------------------------ | //------------------------------------------------------------------------------ | ||||||
| // silu | // silu | ||||||
| //------------------------------------------------------------------------------ | //------------------------------------------------------------------------------ | ||||||
| @@ -1325,6 +1352,368 @@ kernel void kernel_rope_neox_f16( | |||||||
|     } |     } | ||||||
| } | } | ||||||
|  |  | ||||||
|  | kernel void kernel_rope_multi_f32( | ||||||
|  |         global void * src0, | ||||||
|  |         ulong offset0, | ||||||
|  |         global int * src1, | ||||||
|  |         ulong offset1, | ||||||
|  |         global float * src2, | ||||||
|  |         ulong offset2, | ||||||
|  |         global float * dst, | ||||||
|  |         ulong offsetd, | ||||||
|  |         int ne00, | ||||||
|  |         int ne01, | ||||||
|  |         int ne02, | ||||||
|  |         int ne03, | ||||||
|  |         ulong nb00, | ||||||
|  |         ulong nb01, | ||||||
|  |         ulong nb02, | ||||||
|  |         ulong nb03, | ||||||
|  |         int ne0, | ||||||
|  |         int ne1, | ||||||
|  |         int ne2, | ||||||
|  |         int ne3, | ||||||
|  |         ulong nb0, | ||||||
|  |         ulong nb1, | ||||||
|  |         ulong nb2, | ||||||
|  |         ulong nb3, | ||||||
|  |         int n_past, | ||||||
|  |         int n_dims, | ||||||
|  |         int n_ctx_orig, | ||||||
|  |         float freq_base, | ||||||
|  |         float freq_scale, | ||||||
|  |         float ext_factor, | ||||||
|  |         float attn_factor, | ||||||
|  |         float beta_fast, | ||||||
|  |         float beta_slow, | ||||||
|  |         int4 sections | ||||||
|  | ) { | ||||||
|  |     src0 = (global void*)((global char*)src0 + offset0); | ||||||
|  |     src1 = (global int*)((global char*)src1 + offset1); | ||||||
|  |     src2 = (global float*)((global char*)src2 + offset2); | ||||||
|  |     dst = (global float*)((global char*)dst + offsetd); | ||||||
|  |  | ||||||
|  |     int i3 = get_group_id(2); | ||||||
|  |     int i2 = get_group_id(1); | ||||||
|  |     int i1 = get_group_id(0); | ||||||
|  |  | ||||||
|  |     float2 corr_dims = rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow); | ||||||
|  |  | ||||||
|  |     global int * pos = src1; | ||||||
|  |  | ||||||
|  |     const int sect_dims = sections.s0 + sections.s1 + sections.s2 + sections.s3; | ||||||
|  |     const int sec_w = sections.s1 + sections.s0; | ||||||
|  |  | ||||||
|  |     float inv_ndims = -1.f/n_dims; | ||||||
|  |  | ||||||
|  |     for (int i0 = 2*get_local_id(0); i0 < ne0; i0 += 2*get_local_size(0)) { | ||||||
|  |         if (i0 < n_dims) { | ||||||
|  |             int ic = i0/2; | ||||||
|  |  | ||||||
|  |             const int sector = (i0 / 2) % sect_dims; | ||||||
|  |             float theta_base = 0.0f; | ||||||
|  |  | ||||||
|  |             if (sector < sections.s0) { | ||||||
|  |                 theta_base = pos[i2]; | ||||||
|  |             } | ||||||
|  |             else if (sector >= sections.s0 && sector < sec_w) { | ||||||
|  |                 theta_base = pos[i2 + ne2 * 1]; | ||||||
|  |             } | ||||||
|  |             else if (sector >= sec_w && sector < sec_w + sections.s2) { | ||||||
|  |                 theta_base = pos[i2 + ne2 * 2]; | ||||||
|  |             } | ||||||
|  |             else if (sector >= sec_w + sections.s2) { | ||||||
|  |                 theta_base = pos[i2 + ne2 * 3]; | ||||||
|  |             } | ||||||
|  |  | ||||||
|  |             const float theta = theta_base * pow(freq_base, inv_ndims*i0); | ||||||
|  |  | ||||||
|  |             const float freq_factor = src2 != src0 ? src2[ic] : 1.0f; | ||||||
|  |  | ||||||
|  |             float2 cos_sin_theta = rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor); | ||||||
|  |  | ||||||
|  |             global float * src      = (global float *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00); | ||||||
|  |             global float * dst_data = (global float *)((global char *)  dst + i3*nb3  + i2*nb2  + i1*nb1  + ic*nb0); | ||||||
|  |  | ||||||
|  |             const float x0 = src[0]; | ||||||
|  |             const float x1 = src[n_dims/2]; | ||||||
|  |  | ||||||
|  |             dst_data[0]        = x0*cos_sin_theta.s0 - x1*cos_sin_theta.s1; | ||||||
|  |             dst_data[n_dims/2] = x0*cos_sin_theta.s1 + x1*cos_sin_theta.s0; | ||||||
|  |         } else { | ||||||
|  |             global float * const src = (global float *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); | ||||||
|  |             global float * dst_data  = (global float *)((global char *)  dst + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0); | ||||||
|  |  | ||||||
|  |             dst_data[0] = src[0]; | ||||||
|  |             dst_data[1] = src[1]; | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  | } | ||||||
|  |  | ||||||
|  | kernel void kernel_rope_multi_f16( | ||||||
|  |         global void * src0, | ||||||
|  |         ulong offset0, | ||||||
|  |         global int * src1, | ||||||
|  |         ulong offset1, | ||||||
|  |         global float * src2, | ||||||
|  |         ulong offset2, | ||||||
|  |         global half * dst, | ||||||
|  |         ulong offsetd, | ||||||
|  |         int ne00, | ||||||
|  |         int ne01, | ||||||
|  |         int ne02, | ||||||
|  |         int ne03, | ||||||
|  |         ulong nb00, | ||||||
|  |         ulong nb01, | ||||||
|  |         ulong nb02, | ||||||
|  |         ulong nb03, | ||||||
|  |         int ne0, | ||||||
|  |         int ne1, | ||||||
|  |         int ne2, | ||||||
|  |         int ne3, | ||||||
|  |         ulong nb0, | ||||||
|  |         ulong nb1, | ||||||
|  |         ulong nb2, | ||||||
|  |         ulong nb3, | ||||||
|  |         int n_past, | ||||||
|  |         int n_dims, | ||||||
|  |         int n_ctx_orig, | ||||||
|  |         float freq_base, | ||||||
|  |         float freq_scale, | ||||||
|  |         float ext_factor, | ||||||
|  |         float attn_factor, | ||||||
|  |         float beta_fast, | ||||||
|  |         float beta_slow, | ||||||
|  |         int4 sections | ||||||
|  | ) { | ||||||
|  |     src0 = (global void*)((global char*)src0 + offset0); | ||||||
|  |     src1 = (global int*)((global char*)src1 + offset1); | ||||||
|  |     src2 = (global float*)((global char*)src2 + offset2); | ||||||
|  |     dst = (global float*)((global char*)dst + offsetd); | ||||||
|  |  | ||||||
|  |     int i3 = get_group_id(2); | ||||||
|  |     int i2 = get_group_id(1); | ||||||
|  |     int i1 = get_group_id(0); | ||||||
|  |  | ||||||
|  |     float2 corr_dims = rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow); | ||||||
|  |  | ||||||
|  |     global int * pos = src1; | ||||||
|  |  | ||||||
|  |     const int sect_dims = sections.s0 + sections.s1 + sections.s2 + sections.s3; | ||||||
|  |     const int sec_w = sections.s1 + sections.s0; | ||||||
|  |  | ||||||
|  |     float inv_ndims = -1.f/n_dims; | ||||||
|  |  | ||||||
|  |     for (int i0 = 2*get_local_id(0); i0 < ne0; i0 += 2*get_local_size(0)) { | ||||||
|  |         if (i0 < n_dims) { | ||||||
|  |             int ic = i0/2; | ||||||
|  |  | ||||||
|  |             const int sector = (i0 / 2) % sect_dims; | ||||||
|  |             float theta_base = 0.0f; | ||||||
|  |  | ||||||
|  |             if (sector < sections.s0) { | ||||||
|  |                 theta_base = pos[i2]; | ||||||
|  |             } | ||||||
|  |             else if (sector >= sections.s0 && sector < sec_w) { | ||||||
|  |                 theta_base = pos[i2 + ne2 * 1]; | ||||||
|  |             } | ||||||
|  |             else if (sector >= sec_w && sector < sec_w + sections.s2) { | ||||||
|  |                 theta_base = pos[i2 + ne2 * 2]; | ||||||
|  |             } | ||||||
|  |             else if (sector >= sec_w + sections.s2) { | ||||||
|  |                 theta_base = pos[i2 + ne2 * 3]; | ||||||
|  |             } | ||||||
|  |  | ||||||
|  |             const float theta = theta_base * pow(freq_base, inv_ndims*i0); | ||||||
|  |  | ||||||
|  |             const float freq_factor = src2 != src0 ? src2[ic] : 1.0f; | ||||||
|  |  | ||||||
|  |             float2 cos_sin_theta = rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor); | ||||||
|  |  | ||||||
|  |             global half * src      = (global half *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00); | ||||||
|  |             global half * dst_data = (global half *)((global char *)  dst + i3*nb3  + i2*nb2  + i1*nb1  + ic*nb0); | ||||||
|  |  | ||||||
|  |             const float x0 = src[0]; | ||||||
|  |             const float x1 = src[n_dims/2]; | ||||||
|  |  | ||||||
|  |             dst_data[0]        = x0*cos_sin_theta.s0 - x1*cos_sin_theta.s1; | ||||||
|  |             dst_data[n_dims/2] = x0*cos_sin_theta.s1 + x1*cos_sin_theta.s0; | ||||||
|  |         } else { | ||||||
|  |             global half * const src = (global half *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); | ||||||
|  |             global half * dst_data  = (global half *)((global char *)  dst + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0); | ||||||
|  |  | ||||||
|  |             dst_data[0] = src[0]; | ||||||
|  |             dst_data[1] = src[1]; | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  | } | ||||||
|  |  | ||||||
|  | kernel void kernel_rope_vision_f32( | ||||||
|  |         global void * src0, | ||||||
|  |         ulong offset0, | ||||||
|  |         global int * src1, | ||||||
|  |         ulong offset1, | ||||||
|  |         global float * src2, | ||||||
|  |         ulong offset2, | ||||||
|  |         global float * dst, | ||||||
|  |         ulong offsetd, | ||||||
|  |         int ne00, | ||||||
|  |         int ne01, | ||||||
|  |         int ne02, | ||||||
|  |         int ne03, | ||||||
|  |         ulong nb00, | ||||||
|  |         ulong nb01, | ||||||
|  |         ulong nb02, | ||||||
|  |         ulong nb03, | ||||||
|  |         int ne0, | ||||||
|  |         int ne1, | ||||||
|  |         int ne2, | ||||||
|  |         int ne3, | ||||||
|  |         ulong nb0, | ||||||
|  |         ulong nb1, | ||||||
|  |         ulong nb2, | ||||||
|  |         ulong nb3, | ||||||
|  |         int n_past, | ||||||
|  |         int n_dims, | ||||||
|  |         int n_ctx_orig, | ||||||
|  |         float freq_base, | ||||||
|  |         float freq_scale, | ||||||
|  |         float ext_factor, | ||||||
|  |         float attn_factor, | ||||||
|  |         float beta_fast, | ||||||
|  |         float beta_slow, | ||||||
|  |         int4 sections | ||||||
|  | ) { | ||||||
|  |     src0 = (global void*)((global char*)src0 + offset0); | ||||||
|  |     src1 = (global int*)((global char*)src1 + offset1); | ||||||
|  |     src2 = (global float*)((global char*)src2 + offset2); | ||||||
|  |     dst = (global float*)((global char*)dst + offsetd); | ||||||
|  |  | ||||||
|  |     int i3 = get_group_id(2); | ||||||
|  |     int i2 = get_group_id(1); | ||||||
|  |     int i1 = get_group_id(0); | ||||||
|  |  | ||||||
|  |     float2 corr_dims = rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow); | ||||||
|  |  | ||||||
|  |     global int * pos = src1; | ||||||
|  |  | ||||||
|  |     const int sect_dims = sections.s0 + sections.s1; | ||||||
|  |     const int sec_w = sections.s1 + sections.s0; | ||||||
|  |  | ||||||
|  |     float inv_ndims = -1.f/n_dims; | ||||||
|  |  | ||||||
|  |     for (int i0 = 2*get_local_id(0); i0 < ne0; i0 += 2*get_local_size(0)) { | ||||||
|  |         int ic = i0/2; | ||||||
|  |  | ||||||
|  |         const int sector = (i0/2) % sect_dims; | ||||||
|  |         float theta_base = 0.0f; | ||||||
|  |  | ||||||
|  |         if (sector < sections.s0) { | ||||||
|  |             const int p = sector; | ||||||
|  |             theta_base = pos[i2] * pow(freq_base, inv_ndims*2.0f*p); | ||||||
|  |         } else if (sector >= sections.s0 && sector < sec_w) { | ||||||
|  |             const int p = sector - sections.s0; | ||||||
|  |             theta_base = pos[i2 + ne2] * pow(freq_base, inv_ndims*2.0f*p); | ||||||
|  |         } | ||||||
|  |  | ||||||
|  |         const float freq_factor = src2 != src0 ? src2[ic] : 1.0f; | ||||||
|  |  | ||||||
|  |         float2 cos_sin_theta = rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor); | ||||||
|  |  | ||||||
|  |         global float * src      = (global float *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00); | ||||||
|  |         global float * dst_data = (global float *)((global char *)  dst + i3*nb3  + i2*nb2  + i1*nb1  + ic*nb0); | ||||||
|  |  | ||||||
|  |         const float x0 = src[0]; | ||||||
|  |         const float x1 = src[n_dims]; | ||||||
|  |  | ||||||
|  |         dst_data[0]      = x0*cos_sin_theta.s0 - x1*cos_sin_theta.s1; | ||||||
|  |         dst_data[n_dims] = x0*cos_sin_theta.s1 + x1*cos_sin_theta.s0; | ||||||
|  |     } | ||||||
|  | } | ||||||
|  |  | ||||||
|  | kernel void kernel_rope_vision_f16( | ||||||
|  |         global void * src0, | ||||||
|  |         ulong offset0, | ||||||
|  |         global int * src1, | ||||||
|  |         ulong offset1, | ||||||
|  |         global float * src2, | ||||||
|  |         ulong offset2, | ||||||
|  |         global half * dst, | ||||||
|  |         ulong offsetd, | ||||||
|  |         int ne00, | ||||||
|  |         int ne01, | ||||||
|  |         int ne02, | ||||||
|  |         int ne03, | ||||||
|  |         ulong nb00, | ||||||
|  |         ulong nb01, | ||||||
|  |         ulong nb02, | ||||||
|  |         ulong nb03, | ||||||
|  |         int ne0, | ||||||
|  |         int ne1, | ||||||
|  |         int ne2, | ||||||
|  |         int ne3, | ||||||
|  |         ulong nb0, | ||||||
|  |         ulong nb1, | ||||||
|  |         ulong nb2, | ||||||
|  |         ulong nb3, | ||||||
|  |         int n_past, | ||||||
|  |         int n_dims, | ||||||
|  |         int n_ctx_orig, | ||||||
|  |         float freq_base, | ||||||
|  |         float freq_scale, | ||||||
|  |         float ext_factor, | ||||||
|  |         float attn_factor, | ||||||
|  |         float beta_fast, | ||||||
|  |         float beta_slow, | ||||||
|  |         int4 sections | ||||||
|  | ) { | ||||||
|  |     src0 = (global void*)((global char*)src0 + offset0); | ||||||
|  |     src1 = (global int*)((global char*)src1 + offset1); | ||||||
|  |     src2 = (global float*)((global char*)src2 + offset2); | ||||||
|  |     dst = (global float*)((global char*)dst + offsetd); | ||||||
|  |  | ||||||
|  |     int i3 = get_group_id(2); | ||||||
|  |     int i2 = get_group_id(1); | ||||||
|  |     int i1 = get_group_id(0); | ||||||
|  |  | ||||||
|  |     float2 corr_dims = rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow); | ||||||
|  |  | ||||||
|  |     global int * pos = src1; | ||||||
|  |  | ||||||
|  |     const int sect_dims = sections.s0 + sections.s1; | ||||||
|  |     const int sec_w = sections.s1 + sections.s0; | ||||||
|  |  | ||||||
|  |     float inv_ndims = -1.f/n_dims; | ||||||
|  |  | ||||||
|  |     for (int i0 = 2*get_local_id(0); i0 < ne0; i0 += 2*get_local_size(0)) { | ||||||
|  |         int ic = i0/2; | ||||||
|  |  | ||||||
|  |         const int sector = (i0/2) % sect_dims; | ||||||
|  |         float theta_base = 0.0f; | ||||||
|  |  | ||||||
|  |         if (sector < sections.s0) { | ||||||
|  |             const int p = sector; | ||||||
|  |             theta_base = pos[i2] * pow(freq_base, inv_ndims*2.0f*p); | ||||||
|  |         } else if (sector >= sections.s0 && sector < sec_w) { | ||||||
|  |             const int p = sector - sections.s0; | ||||||
|  |             theta_base = pos[i2 + ne2] * pow(freq_base, inv_ndims*2.0f*p); | ||||||
|  |         } | ||||||
|  |  | ||||||
|  |         const float freq_factor = src2 != src0 ? src2[ic] : 1.0f; | ||||||
|  |  | ||||||
|  |         float2 cos_sin_theta = rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor); | ||||||
|  |  | ||||||
|  |         global half * src      = (global half *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00); | ||||||
|  |         global half * dst_data = (global half *)((global char *)  dst + i3*nb3  + i2*nb2  + i1*nb1  + ic*nb0); | ||||||
|  |  | ||||||
|  |         const float x0 = src[0]; | ||||||
|  |         const float x1 = src[n_dims]; | ||||||
|  |  | ||||||
|  |         dst_data[0]      = x0*cos_sin_theta.s0 - x1*cos_sin_theta.s1; | ||||||
|  |         dst_data[n_dims] = x0*cos_sin_theta.s1 + x1*cos_sin_theta.s0; | ||||||
|  |     } | ||||||
|  | } | ||||||
|  |  | ||||||
| //------------------------------------------------------------------------------ | //------------------------------------------------------------------------------ | ||||||
| // cpy | // cpy | ||||||
| //------------------------------------------------------------------------------ | //------------------------------------------------------------------------------ | ||||||
|   | |||||||
							
								
								
									
										146
									
								
								ggml/src/ggml-opencl/kernels/ggml-opencl_im2col.cl
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										146
									
								
								ggml/src/ggml-opencl/kernels/ggml-opencl_im2col.cl
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,146 @@ | |||||||
|  | #ifdef cl_khr_fp16 | ||||||
|  | #pragma OPENCL EXTENSION cl_khr_fp16 : enable | ||||||
|  | #elif defined(cl_amd_fp16) | ||||||
|  | #pragma OPENCL EXTENSION cl_amd_fp16 : enable | ||||||
|  | #else | ||||||
|  | #error "Half precision floating point not supportedby OpenCL implementation on your device." | ||||||
|  | #endif | ||||||
|  |  | ||||||
|  | #ifdef cl_khr_subgroups | ||||||
|  | #pragma OPENCL EXTENSION cl_khr_subgroups : enable | ||||||
|  | #elif defined(cl_intel_subgroups) | ||||||
|  | #pragma OPENCL EXTENSION cl_intel_subgroups : enable | ||||||
|  | #else | ||||||
|  | #error "Subgroup not supported on your device." | ||||||
|  | #endif | ||||||
|  |  | ||||||
|  | #ifdef cl_intel_required_subgroup_size | ||||||
|  | // Always use subgroup size of 32 on Intel. | ||||||
|  | #pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable | ||||||
|  | #define INTEL_GPU 1 | ||||||
|  | #define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16))) | ||||||
|  | #define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32))) | ||||||
|  | #elif defined(cl_qcom_reqd_sub_group_size) | ||||||
|  | // Always use subgroups size of 64 on Adreno. | ||||||
|  | #pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable | ||||||
|  | #define ADRENO_GPU 1 | ||||||
|  | #define REQD_SUBGROUP_SIZE_64  __attribute__((qcom_reqd_sub_group_size("half"))) | ||||||
|  | #define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) | ||||||
|  | #else | ||||||
|  | // TODO: do not know how to choose subgroup size on other GPUs. | ||||||
|  | #error "Selecting subgroup size is not supported on your device." | ||||||
|  | #endif | ||||||
|  |  | ||||||
|  | kernel void kernel_im2col_f32( | ||||||
|  |         global float * src1, | ||||||
|  |         ulong offset1, | ||||||
|  |         global float * dst, | ||||||
|  |         ulong offsetd, | ||||||
|  |         ulong batch_offset, | ||||||
|  |         ulong delta_offset, | ||||||
|  |         long IW, | ||||||
|  |         long IH, | ||||||
|  |         long IC, | ||||||
|  |         long OW, | ||||||
|  |         long OH, | ||||||
|  |         long KW, | ||||||
|  |         long KH, | ||||||
|  |         long pelements, | ||||||
|  |         long CHW, | ||||||
|  |         int  s0, | ||||||
|  |         int  s1, | ||||||
|  |         int  p0, | ||||||
|  |         int  p1, | ||||||
|  |         int  d0, | ||||||
|  |         int  d1 | ||||||
|  | ) { | ||||||
|  |     // threadIdx.x + blockIdx.x * blockDim.x | ||||||
|  |     long i = get_global_id(0); | ||||||
|  |     if (i >= pelements) { | ||||||
|  |         return; | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     src1 = (global float*)((global char*)src1 + offset1); | ||||||
|  |     dst = (global float*)((global char*)dst + offsetd); | ||||||
|  |  | ||||||
|  |     long  ksize = OW * (KH > 1 ? KW : 1); | ||||||
|  |     long  kx = i / ksize; | ||||||
|  |     long  kd = kx * ksize; | ||||||
|  |     long  ky = (i - kd) / OW; | ||||||
|  |     long  ix = i % OW; | ||||||
|  |  | ||||||
|  |     long  oh = get_group_id(1); | ||||||
|  |     long  batch = get_group_id(2) / IC; | ||||||
|  |     long  ic = get_group_id(2) % IC; | ||||||
|  |  | ||||||
|  |     long iiw = ix * s0 + kx * d0 - p0; | ||||||
|  |     long iih = oh * s1 + ky * d1 - p1; | ||||||
|  |  | ||||||
|  |     long offset_dst = | ||||||
|  |         ((batch * OH + oh) * OW + ix) * CHW + | ||||||
|  |         (ic * (KW * KH) + ky * KW + kx); | ||||||
|  |  | ||||||
|  |     if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) { | ||||||
|  |         dst[offset_dst] = 0.0f; | ||||||
|  |     } else { | ||||||
|  |         long offset_src = ic * delta_offset + batch * batch_offset; | ||||||
|  |         dst[offset_dst] = src1[offset_src + iih * IW + iiw]; | ||||||
|  |     } | ||||||
|  | } | ||||||
|  |  | ||||||
|  | kernel void kernel_im2col_f16( | ||||||
|  |         global float * src1, | ||||||
|  |         ulong offset1, | ||||||
|  |         global half  * dst, | ||||||
|  |         ulong offsetd, | ||||||
|  |         ulong batch_offset, | ||||||
|  |         ulong delta_offset, | ||||||
|  |         long IW, | ||||||
|  |         long IH, | ||||||
|  |         long IC, | ||||||
|  |         long OW, | ||||||
|  |         long OH, | ||||||
|  |         long KW, | ||||||
|  |         long KH, | ||||||
|  |         long pelements, | ||||||
|  |         long CHW, | ||||||
|  |         int  s0, | ||||||
|  |         int  s1, | ||||||
|  |         int  p0, | ||||||
|  |         int  p1, | ||||||
|  |         int  d0, | ||||||
|  |         int  d1 | ||||||
|  | ) { | ||||||
|  |     long i = get_global_id(0); | ||||||
|  |  | ||||||
|  |     if (i >= pelements) { | ||||||
|  |         return; | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     src1 = (global float*)((global char*)src1 + offset1); | ||||||
|  |     dst = (global half*)((global char*)dst + offsetd); | ||||||
|  |  | ||||||
|  |     long  ksize = OW * (KH > 1 ? KW : 1); | ||||||
|  |     long  kx = i / ksize; | ||||||
|  |     long  kd = kx * ksize; | ||||||
|  |     long  ky = (i - kd) / OW; | ||||||
|  |     long  ix = i % OW; | ||||||
|  |  | ||||||
|  |     long  oh = get_group_id(1); | ||||||
|  |     long  batch = get_group_id(2) / IC; | ||||||
|  |     long  ic = get_group_id(2) % IC; | ||||||
|  |  | ||||||
|  |     long iiw = ix * s0 + kx * d0 - p0; | ||||||
|  |     long iih = oh * s1 + ky * d1 - p1; | ||||||
|  |  | ||||||
|  |     long offset_dst = | ||||||
|  |         ((batch * OH + oh) * OW + ix) * CHW + | ||||||
|  |         (ic * (KW * KH) + ky * KW + kx); | ||||||
|  |  | ||||||
|  |     if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) { | ||||||
|  |         dst[offset_dst] = 0.0f; | ||||||
|  |     } else { | ||||||
|  |         long offset_src = ic * delta_offset + batch * batch_offset; | ||||||
|  |         dst[offset_dst] = src1[offset_src + iih * IW + iiw]; | ||||||
|  |     } | ||||||
|  | } | ||||||
		Reference in New Issue
	
	Block a user
	 lhez
					lhez