mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-10-27 08:21:30 +00:00
SYCL: Add GGML_OP_MEAN operator support (#16009)
* SYCL: Add GGML_OP_MEAN operator support
* SYCL: Fix formatting for GGML_OP_MEAN case
* Update ggml/src/ggml-sycl/ggml-sycl.cpp

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

---------

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
This commit is contained in:
@@ -2151,6 +2151,30 @@ inline void ggml_sycl_op_sum_rows(ggml_backend_sycl_context & ctx, ggml_tensor *
|
|||||||
sum_rows_f32_sycl(src0_dd, dst_dd, ncols, nrows, main_stream);
|
sum_rows_f32_sycl(src0_dd, dst_dd, ncols, nrows, main_stream);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GGML_OP_MEAN for F32 tensors: for every row of dst->src[0], one float is
// written to dst. Runs in two steps on the context's stream: a call to
// sum_rows_f32_sycl, then an in-place division of each output element by the
// row length (src0->ne[0]).
//
// ctx : backend context; supplies the stream and the device to activate.
// dst : result tensor; dst->src[0] is the input. Both must be GGML_TYPE_F32.
inline void ggml_sycl_op_mean(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
    GGML_ASSERT(dst->type == GGML_TYPE_F32);

    dpct::queue_ptr stream = ctx.stream();
    SYCL_CHECK(ggml_sycl_set_device(ctx.device));

    const ggml_tensor * src0 = dst->src[0];

    const int64_t row_len = src0->ne[0];
    const int64_t n_rows  = ggml_nrows(src0);

    const float * in  = static_cast<const float *>(src0->data);
    float *       out = static_cast<float *>(dst->data);

    // Step 1: presumably leaves the sum of each input row in out[row]
    // (going by the helper's name; its definition is outside this view).
    sum_rows_f32_sycl(in, out, row_len, n_rows, stream);

    // Step 2: turn each value into a mean by dividing by the row length.
    // NOTE(review): this reads what step 1 wrote, so it relies on the stream
    // executing submissions in order - confirm the queue is created in-order.
    stream->parallel_for(sycl::range<1>(n_rows), [=](sycl::id<1> row) {
        const size_t r = row[0];
        out[r] = out[r] / row_len;
    });
}
|
||||||
|
|
||||||
|
|
||||||
inline void ggml_sycl_op_argsort(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
inline void ggml_sycl_op_argsort(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||||
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
|
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
|
||||||
GGML_ASSERT(dst->type == GGML_TYPE_I32);
|
GGML_ASSERT(dst->type == GGML_TYPE_I32);
|
||||||
@@ -3535,6 +3559,12 @@ static void ggml_sycl_sum_rows(ggml_backend_sycl_context & ctx, ggml_tensor * ds
|
|||||||
ggml_sycl_op_sum_rows(ctx, dst);
|
ggml_sycl_op_sum_rows(ctx, dst);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Entry point for GGML_OP_MEAN: sets up the scoped debug print for this op,
// asserts that the single source tensor is contiguous, then forwards to
// ggml_sycl_op_mean, which does the actual computation.
static void ggml_sycl_mean(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
    // Scoped helper: lives until the function returns (one source tensor).
    scope_op_debug_print dbg_scope(__func__, dst, /*num_src=*/1);

    // Non-contiguous input is rejected here rather than handled.
    GGML_ASSERT(ggml_is_contiguous(dst->src[0]));

    ggml_sycl_op_mean(ctx, dst);
}
|
||||||
|
|
||||||
static void ggml_sycl_argsort(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
static void ggml_sycl_argsort(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||||
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
|
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
|
||||||
GGML_ASSERT(ggml_is_contiguous(dst->src[0]));
|
GGML_ASSERT(ggml_is_contiguous(dst->src[0]));
|
||||||
@@ -3784,6 +3814,9 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
|
|||||||
case GGML_OP_SUM_ROWS:
|
case GGML_OP_SUM_ROWS:
|
||||||
ggml_sycl_sum_rows(ctx, dst);
|
ggml_sycl_sum_rows(ctx, dst);
|
||||||
break;
|
break;
|
||||||
|
case GGML_OP_MEAN:
|
||||||
|
ggml_sycl_mean(ctx, dst);
|
||||||
|
break;
|
||||||
case GGML_OP_ARGSORT:
|
case GGML_OP_ARGSORT:
|
||||||
ggml_sycl_argsort(ctx, dst);
|
ggml_sycl_argsort(ctx, dst);
|
||||||
break;
|
break;
|
||||||
@@ -4431,6 +4464,7 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
|
|||||||
return op->src[0]->type == GGML_TYPE_F32 && op->op_params[0] == GGML_SCALE_MODE_NEAREST;
|
return op->src[0]->type == GGML_TYPE_F32 && op->op_params[0] == GGML_SCALE_MODE_NEAREST;
|
||||||
case GGML_OP_SUM:
|
case GGML_OP_SUM:
|
||||||
case GGML_OP_SUM_ROWS:
|
case GGML_OP_SUM_ROWS:
|
||||||
|
case GGML_OP_MEAN:
|
||||||
case GGML_OP_ARGSORT:
|
case GGML_OP_ARGSORT:
|
||||||
return ggml_is_contiguous(op->src[0]);
|
return ggml_is_contiguous(op->src[0]);
|
||||||
case GGML_OP_POOL_2D:
|
case GGML_OP_POOL_2D:
|
||||||
|
|||||||
Reference in New Issue
Block a user