mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-11-09 10:17:06 +00:00
opencl: support imrope (#16914)
* opencl: support imrope * opencl: fix whitespace
This commit is contained in:
@@ -8399,6 +8399,7 @@ static void ggml_cl_rope(ggml_backend_t backend, const ggml_tensor * src0, const
|
|||||||
const bool is_neox = mode & 2;
|
const bool is_neox = mode & 2;
|
||||||
const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE;
|
const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE;
|
||||||
const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
|
const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
|
||||||
|
const int is_imrope = mode == GGML_ROPE_TYPE_IMROPE;
|
||||||
|
|
||||||
if (is_mrope) {
|
if (is_mrope) {
|
||||||
GGML_ASSERT(sections[0] > 0 || sections[1] > 0 || sections[2] > 0);
|
GGML_ASSERT(sections[0] > 0 || sections[1] > 0 || sections[2] > 0);
|
||||||
@@ -8489,9 +8490,14 @@ static void ggml_cl_rope(ggml_backend_t backend, const ggml_tensor * src0, const
|
|||||||
CL_CHECK(clSetKernelArg(kernel, 30, sizeof(float), &attn_factor));
|
CL_CHECK(clSetKernelArg(kernel, 30, sizeof(float), &attn_factor));
|
||||||
CL_CHECK(clSetKernelArg(kernel, 31, sizeof(float), &beta_fast));
|
CL_CHECK(clSetKernelArg(kernel, 31, sizeof(float), &beta_fast));
|
||||||
CL_CHECK(clSetKernelArg(kernel, 32, sizeof(float), &beta_slow));
|
CL_CHECK(clSetKernelArg(kernel, 32, sizeof(float), &beta_slow));
|
||||||
|
// both mrope and vision kernels have sections
|
||||||
if (is_mrope || is_vision) {
|
if (is_mrope || is_vision) {
|
||||||
CL_CHECK(clSetKernelArg(kernel, 33, sizeof(int32_t)*4, &sections));
|
CL_CHECK(clSetKernelArg(kernel, 33, sizeof(int32_t)*4, &sections));
|
||||||
}
|
}
|
||||||
|
// only mrope has is_imrope
|
||||||
|
if (is_mrope && !is_vision) {
|
||||||
|
CL_CHECK(clSetKernelArg(kernel, 34, sizeof(int), &is_imrope));
|
||||||
|
}
|
||||||
|
|
||||||
size_t global_work_size[] = {(size_t)ne01*nth, (size_t)ne02, (size_t)ne03};
|
size_t global_work_size[] = {(size_t)ne01*nth, (size_t)ne02, (size_t)ne03};
|
||||||
size_t local_work_size[] = {(size_t)nth, 1, 1};
|
size_t local_work_size[] = {(size_t)nth, 1, 1};
|
||||||
|
|||||||
@@ -392,7 +392,8 @@ kernel void kernel_rope_multi_f32(
|
|||||||
float attn_factor,
|
float attn_factor,
|
||||||
float beta_fast,
|
float beta_fast,
|
||||||
float beta_slow,
|
float beta_slow,
|
||||||
int4 sections
|
int4 sections,
|
||||||
|
int is_imrope
|
||||||
) {
|
) {
|
||||||
src0 = (global void*)((global char*)src0 + offset0);
|
src0 = (global void*)((global char*)src0 + offset0);
|
||||||
src1 = (global int*)((global char*)src1 + offset1);
|
src1 = (global int*)((global char*)src1 + offset1);
|
||||||
@@ -419,17 +420,29 @@ kernel void kernel_rope_multi_f32(
|
|||||||
const int sector = (i0 / 2) % sect_dims;
|
const int sector = (i0 / 2) % sect_dims;
|
||||||
float theta_base = 0.0f;
|
float theta_base = 0.0f;
|
||||||
|
|
||||||
if (sector < sections.s0) {
|
if (is_imrope) {
|
||||||
theta_base = pos[i2];
|
if (sector % 3 == 1 && sector < 3 * sections.s1) { // h
|
||||||
}
|
theta_base = (float) pos[i2 + ne02 * 1];
|
||||||
else if (sector >= sections.s0 && sector < sec_w) {
|
} else if (sector % 3 == 2 && sector < 3 * sections.s2) { // w
|
||||||
theta_base = pos[i2 + ne2 * 1];
|
theta_base = (float) pos[i2 + ne02 * 2];
|
||||||
}
|
} else if (sector % 3 == 0 && sector < 3 * sections.s0) { // t
|
||||||
else if (sector >= sec_w && sector < sec_w + sections.s2) {
|
theta_base = (float) pos[i2 + ne02 * 0];
|
||||||
theta_base = pos[i2 + ne2 * 2];
|
} else { // e
|
||||||
}
|
theta_base = (float) pos[i2 + ne02 * 3];
|
||||||
else if (sector >= sec_w + sections.s2) {
|
}
|
||||||
theta_base = pos[i2 + ne2 * 3];
|
} else {
|
||||||
|
if (sector < sections.s0) {
|
||||||
|
theta_base = pos[i2];
|
||||||
|
}
|
||||||
|
else if (sector >= sections.s0 && sector < sec_w) {
|
||||||
|
theta_base = pos[i2 + ne2 * 1];
|
||||||
|
}
|
||||||
|
else if (sector >= sec_w && sector < sec_w + sections.s2) {
|
||||||
|
theta_base = pos[i2 + ne2 * 2];
|
||||||
|
}
|
||||||
|
else if (sector >= sec_w + sections.s2) {
|
||||||
|
theta_base = pos[i2 + ne2 * 3];
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
const float theta = theta_base * pow(freq_base, inv_ndims*i0);
|
const float theta = theta_base * pow(freq_base, inv_ndims*i0);
|
||||||
@@ -490,7 +503,8 @@ kernel void kernel_rope_multi_f16(
|
|||||||
float attn_factor,
|
float attn_factor,
|
||||||
float beta_fast,
|
float beta_fast,
|
||||||
float beta_slow,
|
float beta_slow,
|
||||||
int4 sections
|
int4 sections,
|
||||||
|
int is_imrope
|
||||||
) {
|
) {
|
||||||
src0 = (global void*)((global char*)src0 + offset0);
|
src0 = (global void*)((global char*)src0 + offset0);
|
||||||
src1 = (global int*)((global char*)src1 + offset1);
|
src1 = (global int*)((global char*)src1 + offset1);
|
||||||
@@ -517,17 +531,29 @@ kernel void kernel_rope_multi_f16(
|
|||||||
const int sector = (i0 / 2) % sect_dims;
|
const int sector = (i0 / 2) % sect_dims;
|
||||||
float theta_base = 0.0f;
|
float theta_base = 0.0f;
|
||||||
|
|
||||||
if (sector < sections.s0) {
|
if (is_imrope) {
|
||||||
theta_base = pos[i2];
|
if (sector % 3 == 1 && sector < 3 * sections.s1) { // h
|
||||||
}
|
theta_base = (float) pos[i2 + ne02 * 1];
|
||||||
else if (sector >= sections.s0 && sector < sec_w) {
|
} else if (sector % 3 == 2 && sector < 3 * sections.s2) { // w
|
||||||
theta_base = pos[i2 + ne2 * 1];
|
theta_base = (float) pos[i2 + ne02 * 2];
|
||||||
}
|
} else if (sector % 3 == 0 && sector < 3 * sections.s0) { // t
|
||||||
else if (sector >= sec_w && sector < sec_w + sections.s2) {
|
theta_base = (float) pos[i2 + ne02 * 0];
|
||||||
theta_base = pos[i2 + ne2 * 2];
|
} else { // e
|
||||||
}
|
theta_base = (float) pos[i2 + ne02 * 3];
|
||||||
else if (sector >= sec_w + sections.s2) {
|
}
|
||||||
theta_base = pos[i2 + ne2 * 3];
|
} else {
|
||||||
|
if (sector < sections.s0) {
|
||||||
|
theta_base = pos[i2];
|
||||||
|
}
|
||||||
|
else if (sector >= sections.s0 && sector < sec_w) {
|
||||||
|
theta_base = pos[i2 + ne2 * 1];
|
||||||
|
}
|
||||||
|
else if (sector >= sec_w && sector < sec_w + sections.s2) {
|
||||||
|
theta_base = pos[i2 + ne2 * 2];
|
||||||
|
}
|
||||||
|
else if (sector >= sec_w + sections.s2) {
|
||||||
|
theta_base = pos[i2 + ne2 * 3];
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
const float theta = theta_base * pow(freq_base, inv_ndims*i0);
|
const float theta = theta_base * pow(freq_base, inv_ndims*i0);
|
||||||
|
|||||||
Reference in New Issue
Block a user