From c5023daf607c578d6344c628eb7da18ac3d92d32 Mon Sep 17 00:00:00 2001 From: lhez Date: Mon, 3 Nov 2025 11:47:57 -0800 Subject: [PATCH] opencl: support imrope (#16914) * opencl: support imrope * opencl: fix whitespace --- ggml/src/ggml-opencl/ggml-opencl.cpp | 6 +++ ggml/src/ggml-opencl/kernels/rope.cl | 74 +++++++++++++++++++--------- 2 files changed, 56 insertions(+), 24 deletions(-) diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index 93a3600b63..3dc4d03550 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -8399,6 +8399,7 @@ static void ggml_cl_rope(ggml_backend_t backend, const ggml_tensor * src0, const const bool is_neox = mode & 2; const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE; const bool is_vision = mode == GGML_ROPE_TYPE_VISION; + const int is_imrope = mode == GGML_ROPE_TYPE_IMROPE; if (is_mrope) { GGML_ASSERT(sections[0] > 0 || sections[1] > 0 || sections[2] > 0); @@ -8489,9 +8490,14 @@ static void ggml_cl_rope(ggml_backend_t backend, const ggml_tensor * src0, const CL_CHECK(clSetKernelArg(kernel, 30, sizeof(float), &attn_factor)); CL_CHECK(clSetKernelArg(kernel, 31, sizeof(float), &beta_fast)); CL_CHECK(clSetKernelArg(kernel, 32, sizeof(float), &beta_slow)); + // both mrope and vision kernels have sections if (is_mrope || is_vision) { CL_CHECK(clSetKernelArg(kernel, 33, sizeof(int32_t)*4, §ions)); } + // only mrope has is_imrope + if (is_mrope && !is_vision) { + CL_CHECK(clSetKernelArg(kernel, 34, sizeof(int), &is_imrope)); + } size_t global_work_size[] = {(size_t)ne01*nth, (size_t)ne02, (size_t)ne03}; size_t local_work_size[] = {(size_t)nth, 1, 1}; diff --git a/ggml/src/ggml-opencl/kernels/rope.cl b/ggml/src/ggml-opencl/kernels/rope.cl index 0247730c03..82f4cd8740 100644 --- a/ggml/src/ggml-opencl/kernels/rope.cl +++ b/ggml/src/ggml-opencl/kernels/rope.cl @@ -392,7 +392,8 @@ kernel void kernel_rope_multi_f32( float attn_factor, float beta_fast, float beta_slow, - int4 sections + int4 sections, + int is_imrope ) { src0 = (global void*)((global char*)src0 + offset0); src1 = (global int*)((global char*)src1 + offset1); @@ -419,17 +420,29 @@ kernel void kernel_rope_multi_f32( const int sector = (i0 / 2) % sect_dims; float theta_base = 0.0f; - if (sector < sections.s0) { - theta_base = pos[i2]; - } - else if (sector >= sections.s0 && sector < sec_w) { - theta_base = pos[i2 + ne2 * 1]; - } - else if (sector >= sec_w && sector < sec_w + sections.s2) { - theta_base = pos[i2 + ne2 * 2]; - } - else if (sector >= sec_w + sections.s2) { - theta_base = pos[i2 + ne2 * 3]; + if (is_imrope) { + if (sector % 3 == 1 && sector < 3 * sections.s1) { // h + theta_base = (float) pos[i2 + ne02 * 1]; + } else if (sector % 3 == 2 && sector < 3 * sections.s2) { // w + theta_base = (float) pos[i2 + ne02 * 2]; + } else if (sector % 3 == 0 && sector < 3 * sections.s0) { // t + theta_base = (float) pos[i2 + ne02 * 0]; + } else { // e + theta_base = (float) pos[i2 + ne02 * 3]; + } + } else { + if (sector < sections.s0) { + theta_base = pos[i2]; + } + else if (sector >= sections.s0 && sector < sec_w) { + theta_base = pos[i2 + ne2 * 1]; + } + else if (sector >= sec_w && sector < sec_w + sections.s2) { + theta_base = pos[i2 + ne2 * 2]; + } + else if (sector >= sec_w + sections.s2) { + theta_base = pos[i2 + ne2 * 3]; + } } const float theta = theta_base * pow(freq_base, inv_ndims*i0); @@ -490,7 +503,8 @@ kernel void kernel_rope_multi_f16( float attn_factor, float beta_fast, float beta_slow, - int4 sections + int4 sections, + int is_imrope ) { src0 = (global void*)((global char*)src0 + offset0); src1 = (global int*)((global char*)src1 + offset1); @@ -517,17 +531,29 @@ kernel void kernel_rope_multi_f16( const int sector = (i0 / 2) % sect_dims; float theta_base = 0.0f; - if (sector < sections.s0) { - theta_base = pos[i2]; - } - else if (sector >= sections.s0 && sector < sec_w) { - theta_base = pos[i2 + ne2 * 1]; - } - else if (sector >= sec_w && sector < sec_w + sections.s2) { - theta_base = pos[i2 + ne2 * 2]; - } - else if (sector >= sec_w + sections.s2) { - theta_base = pos[i2 + ne2 * 3]; + if (is_imrope) { + if (sector % 3 == 1 && sector < 3 * sections.s1) { // h + theta_base = (float) pos[i2 + ne02 * 1]; + } else if (sector % 3 == 2 && sector < 3 * sections.s2) { // w + theta_base = (float) pos[i2 + ne02 * 2]; + } else if (sector % 3 == 0 && sector < 3 * sections.s0) { // t + theta_base = (float) pos[i2 + ne02 * 0]; + } else { // e + theta_base = (float) pos[i2 + ne02 * 3]; + } + } else { + if (sector < sections.s0) { + theta_base = pos[i2]; + } + else if (sector >= sections.s0 && sector < sec_w) { + theta_base = pos[i2 + ne2 * 1]; + } + else if (sector >= sec_w && sector < sec_w + sections.s2) { + theta_base = pos[i2 + ne2 * 2]; + } + else if (sector >= sec_w + sections.s2) { + theta_base = pos[i2 + ne2 * 3]; + } } const float theta = theta_base * pow(freq_base, inv_ndims*i0);