mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-10-27 08:21:30 +00:00
SYCL: Add GGML_OP_MEAN operator support (#16009)
* SYCL: Add GGML_OP_MEAN operator support * SYCL: Fix formatting for GGML_OP_MEAN case * Update ggml/src/ggml-sycl/ggml-sycl.cpp Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com> --------- Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
This commit is contained in:
@@ -2151,6 +2151,30 @@ inline void ggml_sycl_op_sum_rows(ggml_backend_sycl_context & ctx, ggml_tensor *
|
||||
sum_rows_f32_sycl(src0_dd, dst_dd, ncols, nrows, main_stream);
|
||||
}
|
||||
|
||||
inline void ggml_sycl_op_mean(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
||||
|
||||
dpct::queue_ptr main_stream = ctx.stream();
|
||||
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
|
||||
|
||||
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
|
||||
float * dst_dd = static_cast<float *>(dst->data);
|
||||
|
||||
const int64_t ncols = dst->src[0]->ne[0];
|
||||
const int64_t nrows = ggml_nrows(dst->src[0]);
|
||||
|
||||
sum_rows_f32_sycl(src0_dd, dst_dd, ncols, nrows, main_stream);
|
||||
|
||||
main_stream->parallel_for(
|
||||
sycl::range<1>(nrows),
|
||||
[=](sycl::id<1> row) {
|
||||
dst_dd[row] /= ncols;
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
|
||||
inline void ggml_sycl_op_argsort(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(dst->type == GGML_TYPE_I32);
|
||||
@@ -3535,6 +3559,12 @@ static void ggml_sycl_sum_rows(ggml_backend_sycl_context & ctx, ggml_tensor * ds
|
||||
ggml_sycl_op_sum_rows(ctx, dst);
|
||||
}
|
||||
|
||||
static void ggml_sycl_mean(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
|
||||
GGML_ASSERT(ggml_is_contiguous(dst->src[0]));
|
||||
ggml_sycl_op_mean(ctx, dst);
|
||||
}
|
||||
|
||||
static void ggml_sycl_argsort(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
|
||||
GGML_ASSERT(ggml_is_contiguous(dst->src[0]));
|
||||
@@ -3784,6 +3814,9 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
|
||||
case GGML_OP_SUM_ROWS:
|
||||
ggml_sycl_sum_rows(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_MEAN:
|
||||
ggml_sycl_mean(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_ARGSORT:
|
||||
ggml_sycl_argsort(ctx, dst);
|
||||
break;
|
||||
@@ -4431,6 +4464,7 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
|
||||
return op->src[0]->type == GGML_TYPE_F32 && op->op_params[0] == GGML_SCALE_MODE_NEAREST;
|
||||
case GGML_OP_SUM:
|
||||
case GGML_OP_SUM_ROWS:
|
||||
case GGML_OP_MEAN:
|
||||
case GGML_OP_ARGSORT:
|
||||
return ggml_is_contiguous(op->src[0]);
|
||||
case GGML_OP_POOL_2D:
|
||||
|
||||
Reference in New Issue
Block a user