mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-11-14 11:07:10 +00:00
SYCL: Remove misleading ggml_sycl_op_flatten function (#12387)
* SYCL: Remove misleading ggml_sycl_op_flatten function * remove trailing whitespace * Fix L2 norm from rebase * remove try catch block from element_wise.cpp * remove comment from common.hp * ggml-sycl.cpp: Add try catch sycl::exception block in compute_forward * norm.cpp: remove try catch exception block
This commit is contained in:
@@ -192,18 +192,15 @@ static void rope_neox_sycl(
|
||||
}
|
||||
}
|
||||
|
||||
void ggml_sycl_op_rope(
|
||||
ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst,
|
||||
const float *src0_dd, const float *src1_dd, float *dst_dd, const queue_ptr &main_stream) {
|
||||
const ggml_tensor * src2 = dst->src[2];
|
||||
void ggml_sycl_op_rope(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
|
||||
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
|
||||
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
|
||||
GGML_ASSERT( dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
|
||||
GGML_ASSERT(src0->type == dst->type);
|
||||
GGML_ASSERT(dst->src[0]->type == dst->type);
|
||||
|
||||
const int64_t ne00 = src0->ne[0];
|
||||
const int64_t ne01 = src0->ne[1];
|
||||
const int64_t nr = ggml_nrows(src0);
|
||||
const int64_t ne00 = dst->src[0]->ne[0];
|
||||
const int64_t ne01 = dst->src[0]->ne[1];
|
||||
const int64_t nr = ggml_nrows(dst->src[0]);
|
||||
|
||||
//const int n_past = ((int32_t *) dst->op_params)[0];
|
||||
const int n_dims = ((int32_t *) dst->op_params)[1];
|
||||
@@ -228,49 +225,47 @@ void ggml_sycl_op_rope(
|
||||
|
||||
const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
|
||||
|
||||
const int32_t * pos = (const int32_t *) src1_dd;
|
||||
const int32_t * pos = (const int32_t *) dst->src[1]->data;
|
||||
|
||||
const float * freq_factors = nullptr;
|
||||
if (src2 != nullptr) {
|
||||
freq_factors = (const float *) src2->data;
|
||||
if (dst->src[2] != nullptr) {
|
||||
freq_factors = (const float *) dst->src[2]->data;
|
||||
}
|
||||
|
||||
rope_corr_dims corr_dims;
|
||||
ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims.v);
|
||||
|
||||
dpct::queue_ptr main_stream = ctx.stream();
|
||||
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
|
||||
|
||||
// compute
|
||||
if (is_neox) {
|
||||
if (src0->type == GGML_TYPE_F32) {
|
||||
if (dst->src[0]->type == GGML_TYPE_F32) {
|
||||
rope_neox_sycl(
|
||||
(const float *)src0_dd, (float *)dst_dd, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
|
||||
(const float *)dst->src[0]->data, (float *)dst->data, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
|
||||
attn_factor, corr_dims, freq_factors, main_stream
|
||||
);
|
||||
} else if (src0->type == GGML_TYPE_F16) {
|
||||
} else if (dst->src[0]->type == GGML_TYPE_F16) {
|
||||
rope_neox_sycl(
|
||||
(const sycl::half *)src0_dd, (sycl::half *)dst_dd, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
|
||||
(const sycl::half *)dst->src[0]->data, (sycl::half *)dst->data, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
|
||||
attn_factor, corr_dims, freq_factors, main_stream
|
||||
);
|
||||
} else {
|
||||
GGML_ABORT("fatal error");
|
||||
}
|
||||
} else {
|
||||
if (src0->type == GGML_TYPE_F32) {
|
||||
if (dst->src[0]->type == GGML_TYPE_F32) {
|
||||
rope_norm_sycl(
|
||||
(const float *)src0_dd, (float *)dst_dd, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
|
||||
(const float *)dst->src[0]->data, (float *)dst->data, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
|
||||
attn_factor, corr_dims, freq_factors, main_stream
|
||||
);
|
||||
} else if (src0->type == GGML_TYPE_F16) {
|
||||
} else if (dst->src[0]->type == GGML_TYPE_F16) {
|
||||
rope_norm_sycl(
|
||||
(const sycl::half *)src0_dd, (sycl::half *)dst_dd, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
|
||||
(const sycl::half *)dst->src[0]->data, (sycl::half *)dst->data, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
|
||||
attn_factor, corr_dims, freq_factors, main_stream
|
||||
);
|
||||
} else {
|
||||
GGML_ABORT("fatal error");
|
||||
}
|
||||
}
|
||||
|
||||
GGML_UNUSED(src1);
|
||||
GGML_UNUSED(dst);
|
||||
GGML_UNUSED(src1_dd);
|
||||
GGML_UNUSED(ctx);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user