mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-10-31 08:51:55 +00:00 
			
		
		
		
	| @@ -3168,11 +3168,6 @@ static void ggml_sycl_diag_mask_inf(ggml_backend_sycl_context & ctx, ggml_tensor | |||||||
|     ggml_sycl_op_diag_mask_inf(ctx, dst); |     ggml_sycl_op_diag_mask_inf(ctx, dst); | ||||||
| } | } | ||||||
|  |  | ||||||
| static void ggml_sycl_rope(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { |  | ||||||
|     GGML_ASSERT(ggml_is_contiguous(dst->src[0])); // TODO: this restriction is temporary until non-cont support is implemented |  | ||||||
|     ggml_sycl_op_rope(ctx, dst); |  | ||||||
| } |  | ||||||
|  |  | ||||||
| static void ggml_sycl_pool2d(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { | static void ggml_sycl_pool2d(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { | ||||||
|     ggml_sycl_op_pool2d(ctx, dst); |     ggml_sycl_op_pool2d(ctx, dst); | ||||||
| } | } | ||||||
| @@ -4002,7 +3997,7 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g | |||||||
|                 if (mode == GGML_ROPE_TYPE_MROPE) { |                 if (mode == GGML_ROPE_TYPE_MROPE) { | ||||||
|                     return false; |                     return false; | ||||||
|                 } |                 } | ||||||
|                 return ggml_is_contiguous(op->src[0]); |                 return true; | ||||||
|             } |             } | ||||||
|         case GGML_OP_IM2COL: |         case GGML_OP_IM2COL: | ||||||
|             return true; |             return true; | ||||||
|   | |||||||
| @@ -34,23 +34,21 @@ static void rope_yarn( | |||||||
|     *sin_theta = sycl::sin(theta) * mscale; |     *sin_theta = sycl::sin(theta) * mscale; | ||||||
| } | } | ||||||
|  |  | ||||||
| template<typename T, bool has_ff> | template <typename T, bool has_ff> | ||||||
| static void rope_norm( | static void rope_norm(const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims, | ||||||
|     const T * x, T * dst, int ne0, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows, |                       const int32_t * pos, float freq_scale, float ext_factor, float attn_factor, | ||||||
|     float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, const float * freq_factors, |                       const rope_corr_dims corr_dims, const float theta_scale, const float * freq_factors, | ||||||
|     const sycl::nd_item<3> &item_ct1) { |                       const sycl::nd_item<3> & item_ct1) { | ||||||
|     const int i0 = 2 * (item_ct1.get_local_range(1) * item_ct1.get_group(1) + |     const int i0 = 2 * (item_ct1.get_local_range(1) * item_ct1.get_group(1) + item_ct1.get_local_id(1)); | ||||||
|                          item_ct1.get_local_id(1)); |  | ||||||
|  |  | ||||||
|     if (i0 >= ne0) { |     if (i0 >= ne0) { | ||||||
|         return; |         return; | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     const int row = item_ct1.get_local_range(2) * item_ct1.get_group(2) + |     const int row = item_ct1.get_local_range(2) * item_ct1.get_group(2) + item_ct1.get_local_id(2); | ||||||
|                     item_ct1.get_local_id(2); |  | ||||||
|  |  | ||||||
|     if (i0 >= n_dims) { |     if (i0 >= n_dims) { | ||||||
|         const int i = row*ne0 + i0; |         const int i = row * ne0 + i0; | ||||||
|  |  | ||||||
|         dst[i + 0] = x[i + 0]; |         dst[i + 0] = x[i + 0]; | ||||||
|         dst[i + 1] = x[i + 1]; |         dst[i + 1] = x[i + 1]; | ||||||
| @@ -58,42 +56,43 @@ static void rope_norm( | |||||||
|         return; |         return; | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     const int i = row*ne0 + i0; |     const int row0     = row % ne1; | ||||||
|     const int i2 = row/p_delta_rows; |     const int channel0 = row / ne1; | ||||||
|  |  | ||||||
|     const float theta_base = pos[i2] * sycl::pow(theta_scale, i0 / 2.0f); |     const int i  = row * ne0 + i0; | ||||||
|  |     const int i2 = channel0 * s2 + row0 * s1 + i0; | ||||||
|  |  | ||||||
|     const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f; |     const float theta_base = pos[channel0] * sycl::pow(theta_scale, i0 / 2.0f); | ||||||
|  |  | ||||||
|  |     const float freq_factor = has_ff ? freq_factors[i0 / 2] : 1.0f; | ||||||
|  |  | ||||||
|     float cos_theta; |     float cos_theta; | ||||||
|     float sin_theta; |     float sin_theta; | ||||||
|  |  | ||||||
|     rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta); |     rope_yarn(theta_base / freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta); | ||||||
|  |  | ||||||
|     const float x0 = x[i + 0]; |     const float x0 = x[i2 + 0]; | ||||||
|     const float x1 = x[i + 1]; |     const float x1 = x[i2 + 1]; | ||||||
|  |  | ||||||
|     dst[i + 0] = x0*cos_theta - x1*sin_theta; |     dst[i + 0] = x0 * cos_theta - x1 * sin_theta; | ||||||
|     dst[i + 1] = x0*sin_theta + x1*cos_theta; |     dst[i + 1] = x0 * sin_theta + x1 * cos_theta; | ||||||
| } | } | ||||||
|  |  | ||||||
| template<typename T, bool has_ff> | template <typename T, bool has_ff> | ||||||
| static void rope_neox( | static void rope_neox(const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims, | ||||||
|     const T * x, T * dst, int ne0, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows, |                       const int32_t * pos, const float freq_scale, const float ext_factor, const float attn_factor, | ||||||
|     float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, const float * freq_factors, |                       const rope_corr_dims corr_dims, const float theta_scale, const float * freq_factors, | ||||||
|     const sycl::nd_item<3> &item_ct1) { |                       const sycl::nd_item<3> & item_ct1) { | ||||||
|     const int i0 = 2 * (item_ct1.get_local_range(1) * item_ct1.get_group(1) + |     const int i0 = 2 * (item_ct1.get_local_range(1) * item_ct1.get_group(1) + item_ct1.get_local_id(1)); | ||||||
|                          item_ct1.get_local_id(1)); |  | ||||||
|  |  | ||||||
|     if (i0 >= ne0) { |     if (i0 >= ne0) { | ||||||
|         return; |         return; | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     const int row = item_ct1.get_local_range(2) * item_ct1.get_group(2) + |     const int row = item_ct1.get_local_range(2) * item_ct1.get_group(2) + item_ct1.get_local_id(2); | ||||||
|                     item_ct1.get_local_id(2); |  | ||||||
|  |  | ||||||
|     if (i0 >= n_dims) { |     if (i0 >= n_dims) { | ||||||
|         const int i = row*ne0 + i0; |         const int i = row * ne0 + i0; | ||||||
|  |  | ||||||
|         dst[i + 0] = x[i + 0]; |         dst[i + 0] = x[i + 0]; | ||||||
|         dst[i + 1] = x[i + 1]; |         dst[i + 1] = x[i + 1]; | ||||||
| @@ -101,23 +100,26 @@ static void rope_neox( | |||||||
|         return; |         return; | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     const int i  = row*ne0 + i0/2; |     const int row0     = row % ne1; | ||||||
|     const int i2 = row/p_delta_rows; |     const int channel0 = row / ne1; | ||||||
|  |  | ||||||
|     const float theta_base = pos[i2] * sycl::pow(theta_scale, i0 / 2.0f); |     const int i  = row * ne0 + i0 / 2; | ||||||
|  |     const int i2 = channel0 * s2 + row0 * s1 + i0 / 2; | ||||||
|  |  | ||||||
|     const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f; |     const float theta_base = pos[channel0] * sycl::pow(theta_scale, i0 / 2.0f); | ||||||
|  |  | ||||||
|  |     const float freq_factor = has_ff ? freq_factors[i0 / 2] : 1.0f; | ||||||
|  |  | ||||||
|     float cos_theta; |     float cos_theta; | ||||||
|     float sin_theta; |     float sin_theta; | ||||||
|  |  | ||||||
|     rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta); |     rope_yarn(theta_base / freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta); | ||||||
|  |  | ||||||
|     const float x0 = x[i + 0]; |     const float x0 = x[i2 + 0]; | ||||||
|     const float x1 = x[i + n_dims/2]; |     const float x1 = x[i2 + n_dims / 2]; | ||||||
|  |  | ||||||
|     dst[i + 0]        = x0*cos_theta - x1*sin_theta; |     dst[i + 0]          = x0 * cos_theta - x1 * sin_theta; | ||||||
|     dst[i + n_dims/2] = x0*sin_theta + x1*cos_theta; |     dst[i + n_dims / 2] = x0 * sin_theta + x1 * cos_theta; | ||||||
| } | } | ||||||
|  |  | ||||||
| template <typename T, bool has_ff> | template <typename T, bool has_ff> | ||||||
| @@ -163,18 +165,18 @@ static void rope_vision(const T * x, T * dst, const int ne0, const int ne1, cons | |||||||
| } | } | ||||||
|  |  | ||||||
| template <typename T> | template <typename T> | ||||||
| static void rope_norm_sycl( | static void rope_norm_sycl(const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2, | ||||||
|     const T *x, T *dst, int ne0, int n_dims, int nr, const int32_t *pos, float freq_scale, int p_delta_rows, |                            const int n_dims, int nr, const int32_t * pos, const float freq_scale, const float freq_base, | ||||||
|     float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, queue_ptr stream) { |                            const float ext_factor, const float attn_factor, const rope_corr_dims corr_dims, | ||||||
|  |                            const float * freq_factors, queue_ptr stream) { | ||||||
|     GGML_ASSERT(ne0 % 2 == 0); |     GGML_ASSERT(ne0 % 2 == 0); | ||||||
|     const sycl::range<3> block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1); |     const sycl::range<3> block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1); | ||||||
|     const int num_blocks_x = (ne0 + 2*SYCL_ROPE_BLOCK_SIZE - 1) / (2*SYCL_ROPE_BLOCK_SIZE); |     const int            num_blocks_x = (ne0 + 2 * SYCL_ROPE_BLOCK_SIZE - 1) / (2 * SYCL_ROPE_BLOCK_SIZE); | ||||||
|     const sycl::range<3> block_nums(1, num_blocks_x, nr); |     const sycl::range<3> block_nums(1, num_blocks_x, nr); | ||||||
|  |  | ||||||
|     const float theta_scale = powf(freq_base, -2.0f/n_dims); |     const float theta_scale = powf(freq_base, -2.0f / n_dims); | ||||||
|  |  | ||||||
|     dpct::has_capability_or_fail(stream->get_device(), |     dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 }); | ||||||
|                                      {sycl::aspect::fp16}); |  | ||||||
|  |  | ||||||
|     if (freq_factors == nullptr) { |     if (freq_factors == nullptr) { | ||||||
|         /* |         /* | ||||||
| @@ -182,12 +184,9 @@ static void rope_norm_sycl( | |||||||
|         the limit. To get the device limit, query |         the limit. To get the device limit, query | ||||||
|         info::device::max_work_group_size. Adjust the work-group size if needed. |         info::device::max_work_group_size. Adjust the work-group size if needed. | ||||||
|         */ |         */ | ||||||
|         stream->parallel_for( |         stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) { | ||||||
|             sycl::nd_range<3>(block_nums * block_dims, block_dims), |             rope_norm<T, false>(x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims, | ||||||
|             [=](sycl::nd_item<3> item_ct1) { |                                 theta_scale, freq_factors, item_ct1); | ||||||
|                 rope_norm<T, false>(x, dst, ne0, n_dims, pos, freq_scale, p_delta_rows, |  | ||||||
|                                ext_factor, attn_factor, corr_dims, theta_scale, freq_factors, |  | ||||||
|                                item_ct1); |  | ||||||
|         }); |         }); | ||||||
|     } else { |     } else { | ||||||
|         /* |         /* | ||||||
| @@ -195,47 +194,36 @@ static void rope_norm_sycl( | |||||||
|         the limit. To get the device limit, query |         the limit. To get the device limit, query | ||||||
|         info::device::max_work_group_size. Adjust the work-group size if needed. |         info::device::max_work_group_size. Adjust the work-group size if needed. | ||||||
|         */ |         */ | ||||||
|         stream->parallel_for( |         stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) { | ||||||
|             sycl::nd_range<3>(block_nums * block_dims, block_dims), |             rope_norm<T, true>(x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims, | ||||||
|             [=](sycl::nd_item<3> item_ct1) { |                                theta_scale, freq_factors, item_ct1); | ||||||
|                 rope_norm<T, true>(x, dst, ne0, n_dims, pos, freq_scale, p_delta_rows, |  | ||||||
|                               ext_factor, attn_factor, corr_dims, theta_scale, freq_factors, |  | ||||||
|                               item_ct1); |  | ||||||
|         }); |         }); | ||||||
|     } |     } | ||||||
| } | } | ||||||
|  |  | ||||||
| template <typename T> | template <typename T> | ||||||
| static void rope_neox_sycl( | static void rope_neox_sycl(const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2, | ||||||
|     const T *x, T *dst, int ne0, int n_dims, int nr, const int32_t *pos, float freq_scale, int p_delta_rows, |                            const int n_dims, const int nr, const int32_t * pos, const float freq_scale, | ||||||
|     float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, queue_ptr stream) { |                            const float freq_base, const float ext_factor, const float attn_factor, | ||||||
|  |                            const rope_corr_dims corr_dims, const float * freq_factors, queue_ptr stream) { | ||||||
|     GGML_ASSERT(ne0 % 2 == 0); |     GGML_ASSERT(ne0 % 2 == 0); | ||||||
|     const sycl::range<3> block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1); |     const sycl::range<3> block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1); | ||||||
|     const int num_blocks_x = (ne0 + 2*SYCL_ROPE_BLOCK_SIZE - 1) / (2*SYCL_ROPE_BLOCK_SIZE); |     const int            num_blocks_x = (ne0 + 2 * SYCL_ROPE_BLOCK_SIZE - 1) / (2 * SYCL_ROPE_BLOCK_SIZE); | ||||||
|     const sycl::range<3> block_nums(1, num_blocks_x, nr); |     const sycl::range<3> block_nums(1, num_blocks_x, nr); | ||||||
|  |  | ||||||
|     const float theta_scale = powf(freq_base, -2.0f/n_dims); |     const float theta_scale = powf(freq_base, -2.0f / n_dims); | ||||||
|  |  | ||||||
|     dpct::has_capability_or_fail(stream->get_device(), |     dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 }); | ||||||
|                                     {sycl::aspect::fp16}); |  | ||||||
|  |  | ||||||
|     if (freq_factors == nullptr) { |     if (freq_factors == nullptr) { | ||||||
|         stream->parallel_for( |         stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) { | ||||||
|             sycl::nd_range<3>(block_nums * block_dims, block_dims), |             rope_neox<T, false>(x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims, | ||||||
|             [=](sycl::nd_item<3> item_ct1) { |                                 theta_scale, freq_factors, item_ct1); | ||||||
|                 rope_neox<T, false>(x, dst, ne0, n_dims, pos, freq_scale, |  | ||||||
|                                     p_delta_rows, ext_factor, attn_factor, |  | ||||||
|                                     corr_dims, theta_scale, freq_factors, |  | ||||||
|                                     item_ct1); |  | ||||||
|         }); |         }); | ||||||
|     } else { |     } else { | ||||||
|         stream->parallel_for( |         stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) { | ||||||
|             sycl::nd_range<3>(block_nums * block_dims, block_dims), |             rope_neox<T, true>(x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims, | ||||||
|             [=](sycl::nd_item<3> item_ct1) { |                                theta_scale, freq_factors, item_ct1); | ||||||
|                 rope_neox<T, true>(x, dst, ne0, n_dims, pos, freq_scale, |  | ||||||
|                                     p_delta_rows, ext_factor, attn_factor, |  | ||||||
|                                     corr_dims, theta_scale, freq_factors, |  | ||||||
|                                     item_ct1); |  | ||||||
|         }); |         }); | ||||||
|     } |     } | ||||||
| } | } | ||||||
| @@ -272,7 +260,7 @@ static void rope_vision_sycl(const T * x, T * dst, const int ne0, const int ne1, | |||||||
|     } |     } | ||||||
| } | } | ||||||
|  |  | ||||||
| void ggml_sycl_op_rope(ggml_backend_sycl_context & ctx, ggml_tensor *dst) { | inline void ggml_sycl_op_rope(ggml_backend_sycl_context & ctx, ggml_tensor *dst) { | ||||||
|  |  | ||||||
|     GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16); |     GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16); | ||||||
|     GGML_ASSERT( dst->type == GGML_TYPE_F32 ||  dst->type == GGML_TYPE_F16); |     GGML_ASSERT( dst->type == GGML_TYPE_F32 ||  dst->type == GGML_TYPE_F16); | ||||||
| @@ -329,43 +317,46 @@ void ggml_sycl_op_rope(ggml_backend_sycl_context & ctx, ggml_tensor *dst) { | |||||||
|     if (is_neox) { |     if (is_neox) { | ||||||
|         GGML_SYCL_DEBUG("%s: neox path\n", __func__); |         GGML_SYCL_DEBUG("%s: neox path\n", __func__); | ||||||
|         if (dst->src[0]->type == GGML_TYPE_F32) { |         if (dst->src[0]->type == GGML_TYPE_F32) { | ||||||
|             rope_neox_sycl( |             rope_neox_sycl((const float *) dst->src[0]->data, (float *) dst->data, ne00, ne01, s01, s02, n_dims, nr, | ||||||
|                 (const float *)dst->src[0]->data, (float *)dst->data, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor, |                            pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, main_stream); | ||||||
|                 attn_factor, corr_dims, freq_factors, main_stream |  | ||||||
|             ); |  | ||||||
|         } else if (dst->src[0]->type == GGML_TYPE_F16) { |         } else if (dst->src[0]->type == GGML_TYPE_F16) { | ||||||
|             rope_neox_sycl( |             rope_neox_sycl((const sycl::half *) dst->src[0]->data, (sycl::half *) dst->data, ne00, ne01, s01, s02, | ||||||
|                 (const sycl::half *)dst->src[0]->data, (sycl::half *)dst->data, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor, |                            n_dims, nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, | ||||||
|                 attn_factor, corr_dims, freq_factors, main_stream |                            main_stream); | ||||||
|             ); |  | ||||||
|         } else { |         } else { | ||||||
|             GGML_ABORT("fatal error"); |             GGML_ABORT("fatal error"); | ||||||
|         } |         } | ||||||
|     } else if (is_vision) { |     } else if (is_vision) { | ||||||
|         GGML_SYCL_DEBUG("%s: vision path\n", __func__); |         GGML_SYCL_DEBUG("%s: vision path\n", __func__); | ||||||
|         if (dst->src[0]->type == GGML_TYPE_F16) { |         if (dst->src[0]->type == GGML_TYPE_F16) { | ||||||
|             rope_vision_sycl((const sycl::half *)dst->src[0]->data, (sycl::half *)dst->data, ne00, ne01, ne02, s01, s02, n_dims, nr, pos, freq_scale, |             rope_vision_sycl((const sycl::half *) dst->src[0]->data, (sycl::half *) dst->data, ne00, ne01, ne02, s01, | ||||||
|                 freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, main_stream); |                              s02, n_dims, nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, | ||||||
|  |                              freq_factors, sections, main_stream); | ||||||
|         } else if (dst->src[0]->type == GGML_TYPE_F32) { |         } else if (dst->src[0]->type == GGML_TYPE_F32) { | ||||||
|             rope_vision_sycl((const float *) dst->src[0]->data, (float *)dst->data, ne00, ne01, ne02, s01, s02, n_dims, nr, pos, freq_scale, |             rope_vision_sycl((const float *) dst->src[0]->data, (float *) dst->data, ne00, ne01, ne02, s01, s02, n_dims, | ||||||
|                 freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, main_stream); |                              nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, | ||||||
|  |                              main_stream); | ||||||
|         } else { |         } else { | ||||||
|             GGML_ABORT("Fatal error: Tensor type unsupported!"); |             GGML_ABORT("Fatal error: Tensor type unsupported!"); | ||||||
|         } |         } | ||||||
|     } else { |     } else { | ||||||
|         GGML_SYCL_DEBUG("%s: norm path\n", __func__); |         GGML_SYCL_DEBUG("%s: norm path\n", __func__); | ||||||
|         if (dst->src[0]->type == GGML_TYPE_F32) { |         if (dst->src[0]->type == GGML_TYPE_F32) { | ||||||
|             rope_norm_sycl( |             rope_norm_sycl((const float *) dst->src[0]->data, (float *) dst->data, ne00, ne01, s01, s02, n_dims, nr, | ||||||
|                 (const float *)dst->src[0]->data, (float *)dst->data, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor, |                            pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, main_stream); | ||||||
|                 attn_factor, corr_dims, freq_factors, main_stream |  | ||||||
|             ); |  | ||||||
|         } else if (dst->src[0]->type == GGML_TYPE_F16) { |         } else if (dst->src[0]->type == GGML_TYPE_F16) { | ||||||
|             rope_norm_sycl( |             rope_norm_sycl((const sycl::half *) dst->src[0]->data, (sycl::half *) dst->data, ne00, ne01, s01, s02, | ||||||
|                 (const sycl::half *)dst->src[0]->data, (sycl::half *)dst->data, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor, |                            n_dims, nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, | ||||||
|                 attn_factor, corr_dims, freq_factors, main_stream |                            main_stream); | ||||||
|             ); |  | ||||||
|         } else { |         } else { | ||||||
|             GGML_ABORT("fatal error"); |             GGML_ABORT("fatal error"); | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
| } | } | ||||||
|  |  | ||||||
|  | void ggml_sycl_rope(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { | ||||||
|  |     GGML_SYCL_DEBUG("call %s\n", __func__); | ||||||
|  |     ggml_sycl_op_rope(ctx, dst); | ||||||
|  |     GGML_SYCL_DEBUG("call %s done\n", __func__); | ||||||
|  | } | ||||||
|  |  | ||||||
|   | |||||||
| @@ -15,6 +15,6 @@ | |||||||
|  |  | ||||||
| #include "common.hpp" | #include "common.hpp" | ||||||
|  |  | ||||||
| void ggml_sycl_op_rope(ggml_backend_sycl_context & ctx, ggml_tensor *dst); | void ggml_sycl_rope(ggml_backend_sycl_context & ctx, ggml_tensor *dst); | ||||||
|  |  | ||||||
| #endif // GGML_SYCL_ROPE_HPP | #endif // GGML_SYCL_ROPE_HPP | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user
	 Akarshan Biswas
					Akarshan Biswas