This commit is contained in:
Xuan Son Nguyen
2025-07-08 23:27:32 +02:00
parent 92a8738452
commit a28df6f00c

View File

@@ -1695,7 +1695,7 @@ static void diag_mask_inf_f32(const float * x, float * dst, const int ncols, con
dst[i] = x[i] - (col > n_past + row % rows_per_channel) * FLT_MAX; dst[i] = x[i] - (col > n_past + row % rows_per_channel) * FLT_MAX;
} }
static void scale_f32(const float * x, float * dst, const float scale, const int k, static void scale_f32(const float * x, float * dst, const float scale, const float bias, const int k,
const sycl::nd_item<3> &item_ct1) { const sycl::nd_item<3> &item_ct1) {
const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
item_ct1.get_local_id(2); item_ct1.get_local_id(2);
@@ -1704,7 +1704,7 @@ static void scale_f32(const float * x, float * dst, const float scale, const int
return; return;
} }
dst[i] = scale * x[i]; dst[i] = scale * x[i] + bias;
} }
@@ -1842,7 +1842,7 @@ static void ggml_mul_mat_vec_nc_f16_f32_sycl(
static void scale_f32_sycl(const float *x, float *dst, const float scale, static void scale_f32_sycl(const float *x, float *dst, const float scale, const float bias,
const int k, queue_ptr stream) { const int k, queue_ptr stream) {
const int num_blocks = (k + SYCL_SCALE_BLOCK_SIZE - 1) / SYCL_SCALE_BLOCK_SIZE; const int num_blocks = (k + SYCL_SCALE_BLOCK_SIZE - 1) / SYCL_SCALE_BLOCK_SIZE;
stream->parallel_for( stream->parallel_for(
@@ -1850,7 +1850,7 @@ static void scale_f32_sycl(const float *x, float *dst, const float scale,
sycl::range<3>(1, 1, SYCL_SCALE_BLOCK_SIZE), sycl::range<3>(1, 1, SYCL_SCALE_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_SCALE_BLOCK_SIZE)), sycl::range<3>(1, 1, SYCL_SCALE_BLOCK_SIZE)),
[=](sycl::nd_item<3> item_ct1) { [=](sycl::nd_item<3> item_ct1) {
scale_f32(x, dst, scale, k, item_ct1); scale_f32(x, dst, scale, bias, k, item_ct1);
}); });
} }
@@ -2318,10 +2318,10 @@ inline void ggml_sycl_op_scale(ggml_backend_sycl_context & ctx, ggml_tensor * ds
const float * src0_dd = static_cast<const float *>(dst->src[0]->data); const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
float * dst_dd = static_cast<float *>(dst->data); float * dst_dd = static_cast<float *>(dst->data);
float scale; float scale = ((const float *)(dst->op_params))[0];
memcpy(&scale, dst->op_params, sizeof(float)); float bias = ((const float *)(dst->op_params))[1];
scale_f32_sycl(src0_dd, dst_dd, scale, ggml_nelements(dst->src[0]), main_stream); scale_f32_sycl(src0_dd, dst_dd, scale, bias, ggml_nelements(dst->src[0]), main_stream);
/* /*
DPCT1010:87: SYCL uses exceptions to report errors and does not use the DPCT1010:87: SYCL uses exceptions to report errors and does not use the
error codes. The call was replaced with 0. You need to rewrite this code. error codes. The call was replaced with 0. You need to rewrite this code.