mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-11-14 11:07:10 +00:00
opencl
This commit is contained in:
@@ -5586,8 +5586,8 @@ static void ggml_cl_scale(ggml_backend_t backend, const ggml_tensor * src0, cons
|
|||||||
|
|
||||||
ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
|
ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
|
||||||
|
|
||||||
float scale;
|
float scale = ((const float *)(dst->op_params))[0];
|
||||||
memcpy(&scale, dst->op_params, sizeof(scale));
|
float bias = ((const float *)(dst->op_params))[1];
|
||||||
|
|
||||||
ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
|
ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
|
||||||
ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
|
ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
|
||||||
@@ -5602,6 +5602,7 @@ static void ggml_cl_scale(ggml_backend_t backend, const ggml_tensor * src0, cons
|
|||||||
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device));
|
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device));
|
||||||
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd));
|
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd));
|
||||||
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(float), &scale));
|
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(float), &scale));
|
||||||
|
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(float), &bias));
|
||||||
|
|
||||||
int n = ggml_nelements(dst)/4;
|
int n = ggml_nelements(dst)/4;
|
||||||
|
|
||||||
|
|||||||
@@ -8,9 +8,10 @@ kernel void kernel_scale(
|
|||||||
ulong offset0,
|
ulong offset0,
|
||||||
global float4 * dst,
|
global float4 * dst,
|
||||||
ulong offsetd,
|
ulong offsetd,
|
||||||
float scale
|
float scale,
|
||||||
|
float bias
|
||||||
) {
|
) {
|
||||||
src0 = (global float4*)((global char*)src0 + offset0);
|
src0 = (global float4*)((global char*)src0 + offset0);
|
||||||
dst = (global float4*)((global char*)dst + offsetd);
|
dst = (global float4*)((global char*)dst + offsetd);
|
||||||
dst[get_global_id(0)] = src0[get_global_id(0)] * scale;
|
dst[get_global_id(0)] = src0[get_global_id(0)] * scale + bias;
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user