mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-10-31 08:51:55 +00:00 
			
		
		
		
	SYCL: Add GGML_OP_MEAN operator support (#16009)
* SYCL: Add GGML_OP_MEAN operator support * SYCL: Fix formatting for GGML_OP_MEAN case * Update ggml/src/ggml-sycl/ggml-sycl.cpp Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com> --------- Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
This commit is contained in:
		| @@ -2151,6 +2151,30 @@ inline void ggml_sycl_op_sum_rows(ggml_backend_sycl_context & ctx, ggml_tensor * | ||||
|     sum_rows_f32_sycl(src0_dd, dst_dd, ncols, nrows, main_stream); | ||||
| } | ||||
|  | ||||
| inline void ggml_sycl_op_mean(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { | ||||
|     GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); | ||||
|     GGML_ASSERT(dst->type == GGML_TYPE_F32); | ||||
|  | ||||
|     dpct::queue_ptr main_stream = ctx.stream(); | ||||
|     SYCL_CHECK(ggml_sycl_set_device(ctx.device)); | ||||
|  | ||||
|     const float * src0_dd = static_cast<const float *>(dst->src[0]->data); | ||||
|     float *       dst_dd  = static_cast<float *>(dst->data); | ||||
|  | ||||
|     const int64_t ncols = dst->src[0]->ne[0]; | ||||
|     const int64_t nrows = ggml_nrows(dst->src[0]); | ||||
|  | ||||
|     sum_rows_f32_sycl(src0_dd, dst_dd, ncols, nrows, main_stream); | ||||
|  | ||||
|     main_stream->parallel_for( | ||||
|         sycl::range<1>(nrows), | ||||
|         [=](sycl::id<1> row) { | ||||
|             dst_dd[row] /= ncols; | ||||
|         } | ||||
|     ); | ||||
| } | ||||
|  | ||||
|  | ||||
| inline void ggml_sycl_op_argsort(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { | ||||
|     GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); | ||||
|     GGML_ASSERT(dst->type == GGML_TYPE_I32); | ||||
| @@ -3535,6 +3559,12 @@ static void ggml_sycl_sum_rows(ggml_backend_sycl_context & ctx, ggml_tensor * ds | ||||
|     ggml_sycl_op_sum_rows(ctx, dst); | ||||
| } | ||||
|  | ||||
| static void ggml_sycl_mean(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { | ||||
|     scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1); | ||||
|     GGML_ASSERT(ggml_is_contiguous(dst->src[0])); | ||||
|     ggml_sycl_op_mean(ctx, dst); | ||||
| } | ||||
|  | ||||
| static void ggml_sycl_argsort(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { | ||||
|     scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1); | ||||
|     GGML_ASSERT(ggml_is_contiguous(dst->src[0])); | ||||
| @@ -3784,6 +3814,9 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg | ||||
|         case GGML_OP_SUM_ROWS: | ||||
|             ggml_sycl_sum_rows(ctx, dst); | ||||
|             break; | ||||
|         case GGML_OP_MEAN: | ||||
|             ggml_sycl_mean(ctx, dst); | ||||
|             break; | ||||
|         case GGML_OP_ARGSORT: | ||||
|             ggml_sycl_argsort(ctx, dst); | ||||
|             break; | ||||
| @@ -4431,6 +4464,7 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g | ||||
|             return op->src[0]->type == GGML_TYPE_F32 && op->op_params[0] == GGML_SCALE_MODE_NEAREST; | ||||
|         case GGML_OP_SUM: | ||||
|         case GGML_OP_SUM_ROWS: | ||||
|         case GGML_OP_MEAN: | ||||
|         case GGML_OP_ARGSORT: | ||||
|             return ggml_is_contiguous(op->src[0]); | ||||
|         case GGML_OP_POOL_2D: | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 yael-works
					yael-works