SYCL: Add COUNT_EQUAL operator support (#15991)

* SYCL: Add COUNT_EQUAL operator support (rebased on master)

* SYCL: remove duplicate op_count_equal definition

* tests: remove test_count_equal_typed and use test_count_equal for all cases

* tests: keep only I32 case for COUNT_EQUAL as suggested

* tests: keep only I32 case for COUNT_EQUAL as requested
This commit is contained in:
yael-works
2025-09-15 19:51:35 +03:00
committed by GitHub
parent 28c39da7c6
commit b907255f4b
3 changed files with 19 additions and 0 deletions

View File

@@ -303,6 +303,10 @@ inline void ggml_sycl_op_sub(ggml_backend_sycl_context & ctx, ggml_tensor *dst)
ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_sub>>(ctx, dst->src[0], dst->src[1], dst);
}
inline void ggml_sycl_op_count_equal(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_count_equal>>(ctx, dst->src[0], dst->src[1], dst);
}
inline void ggml_sycl_op_mul(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_mul>>(ctx, dst->src[0], dst->src[1], dst);
@@ -328,6 +332,11 @@ void ggml_sycl_sub(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
ggml_sycl_op_sub(ctx, dst);
}
void ggml_sycl_count_equal(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2);
ggml_sycl_op_count_equal(ctx, dst);
}
void ggml_sycl_mul(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2);
ggml_sycl_op_mul(ctx, dst);

View File

@@ -16,6 +16,12 @@ static __dpct_inline__ float op_sub(const float a, const float b) {
return a - b;
}
static __dpct_inline__ float op_count_equal(const float a, const float b) {
return (a == b) ? 1.0f : 0.0f;
}
void ggml_sycl_count_equal(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
static __dpct_inline__ float op_mul(const float a, const float b) {
return a * b;
}

View File

@@ -3577,6 +3577,9 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
case GGML_OP_SUB:
ggml_sycl_sub(ctx, dst);
break;
case GGML_OP_COUNT_EQUAL:
ggml_sycl_count_equal(ctx, dst);
break;
case GGML_OP_ACC:
ggml_sycl_acc(ctx, dst);
break;
@@ -4356,6 +4359,7 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
case GGML_OP_ADD:
case GGML_OP_ADD1:
case GGML_OP_SUB:
case GGML_OP_COUNT_EQUAL:
case GGML_OP_MUL:
case GGML_OP_DIV:
case GGML_OP_REPEAT: