From 50f88fc4caf5790e5902e5a63107b364f69f83a4 Mon Sep 17 00:00:00 2001
From: Xuan Son Nguyen
Date: Fri, 27 Jun 2025 11:21:26 +0200
Subject: [PATCH] ggml : add ggml_scale_bias

---
 ggml/include/ggml.h                  | 13 +++++++++++++
 ggml/src/ggml-cpu/ops.cpp            | 13 +++++++++----
 ggml/src/ggml-metal/ggml-metal.m     |  5 +++--
 ggml/src/ggml-metal/ggml-metal.metal |  6 ++++--
 ggml/src/ggml.c                      | 28 +++++++++++++++++++++++-----
 tests/test-backend-ops.cpp           | 11 +++++++----
 6 files changed, 59 insertions(+), 17 deletions(-)

diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h
index 9c4e24023b..236ac52eb3 100644
--- a/ggml/include/ggml.h
+++ b/ggml/include/ggml.h
@@ -1185,6 +1185,19 @@ extern "C" {
             struct ggml_tensor  * a,
             float                 s);
 
+    // x = s * a + b
+    GGML_API struct ggml_tensor * ggml_scale_bias(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            float                 s,
+            float                 b);
+
+    GGML_API struct ggml_tensor * ggml_scale_bias_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            float                 s,
+            float                 b);
+
     // b -> view(a,offset,nb1,nb2,3), return modified a
     GGML_API struct ggml_tensor * ggml_set(
             struct ggml_context * ctx,
diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp
index 8531baf6c5..bc61080797 100644
--- a/ggml/src/ggml-cpu/ops.cpp
+++ b/ggml/src/ggml-cpu/ops.cpp
@@ -3937,9 +3937,11 @@ static void ggml_compute_forward_scale_f32(
     GGML_ASSERT(ggml_is_contiguous(dst));
     GGML_ASSERT(ggml_are_same_shape(src0, dst));
 
-    // scale factor
-    float v;
-    memcpy(&v, dst->op_params, sizeof(float));
+    float s; // scale factor
+    float b; // bias
+
+    memcpy(&s, (float *) dst->op_params + 0, sizeof(float));
+    memcpy(&b, (float *) dst->op_params + 1, sizeof(float));
 
     const int ith = params->ith;
     const int nth = params->nth;
@@ -3963,7 +3965,10 @@ static void ggml_compute_forward_scale_f32(
             // src0 is same shape as dst => same indices
             memcpy((char *)dst->data + i1*nb1, (char *)src0->data + i1*nb01, nc * sizeof(float));
         }
-        ggml_vec_scale_f32(nc, (float *) ((char *) dst->data + i1*nb1), v);
+        ggml_vec_scale_f32(nc, (float *) ((char *) dst->data + i1*nb1), s);
+        if (b != 0.0f) {
+            ggml_vec_acc1_f32(nc, (float *) ((char *) dst->data + i1*nb1), b);
+        }
     }
 }
 
diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m
index d8d30cc0b4..69b8a268bf 100644
--- a/ggml/src/ggml-metal/ggml-metal.m
+++ b/ggml/src/ggml-metal/ggml-metal.m
@@ -2189,8 +2189,8 @@ static bool ggml_metal_encode_node(
             {
                 GGML_ASSERT(ggml_is_contiguous(src0));
 
-                float scale;
-                memcpy(&scale, dst->op_params, sizeof(scale));
+                float scale = ((const float *)(dst->op_params))[0];
+                float bias  = ((const float *)(dst->op_params))[1];
 
                 int64_t n = ggml_nelements(dst);
 
@@ -2207,6 +2207,7 @@ static bool ggml_metal_encode_node(
                 [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                 [encoder setBuffer:id_dst   offset:offs_dst  atIndex:1];
                 [encoder setBytes:&scale length:sizeof(scale) atIndex:2];
+                [encoder setBytes:&bias  length:sizeof(bias)  atIndex:3];
 
                 [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
             } break;
diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal
index 5f004a856b..ae012b1c79 100644
--- a/ggml/src/ggml-metal/ggml-metal.metal
+++ b/ggml/src/ggml-metal/ggml-metal.metal
@@ -810,16 +810,18 @@ kernel void kernel_scale(
         device const float * src0,
         device       float * dst,
         constant     float & scale,
+        constant     float & bias,
        uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = src0[tpig] * scale;
+    dst[tpig] = src0[tpig] * scale + bias;
 }
 
 kernel void kernel_scale_4(
        device const float4 * src0,
        device       float4 * dst,
        constant     float  & scale,
+       constant     float  & bias,
        uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = src0[tpig] * scale;
+    dst[tpig] = src0[tpig] * scale + bias;
 }
 
 kernel void kernel_clamp(
diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index ee605977f3..e77d33fc7a 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -2858,12 +2858,14 @@ static struct ggml_tensor * ggml_scale_impl(
         struct ggml_context * ctx,
         struct ggml_tensor  * a,
         float                 s,
+        float                 b,
         bool                  inplace) {
     GGML_ASSERT(ggml_is_padded_1d(a));
 
     struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
 
-    ggml_set_op_params(result, &s, sizeof(s));
+    float params[2] = { s, b };
+    ggml_set_op_params(result, &params, sizeof(params));
 
     result->op     = GGML_OP_SCALE;
     result->src[0] = a;
@@ -2875,14 +2877,30 @@ struct ggml_tensor * ggml_scale(
         struct ggml_context * ctx,
         struct ggml_tensor  * a,
         float                 s) {
-    return ggml_scale_impl(ctx, a, s, false);
+    return ggml_scale_impl(ctx, a, s, 0.0, false);
 }
 
 struct ggml_tensor * ggml_scale_inplace(
         struct ggml_context * ctx,
         struct ggml_tensor  * a,
         float                 s) {
-    return ggml_scale_impl(ctx, a, s, true);
+    return ggml_scale_impl(ctx, a, s, 0.0, true);
+}
+
+struct ggml_tensor * ggml_scale_bias(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        float                 s,
+        float                 b) {
+    return ggml_scale_impl(ctx, a, s, b, false);
+}
+
+struct ggml_tensor * ggml_scale_bias_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        float                 s,
+        float                 b) {
+    return ggml_scale_impl(ctx, a, s, b, true);
 }
 
 // ggml_set
@@ -5472,7 +5490,7 @@ static void ggml_compute_backward(
         } break;
         case GGML_OP_MEAN: {
             if (src0_needs_grads) {
-                ggml_add1_or_set(ctx, cgraph, isrc0, ggml_scale_impl(ctx, grad, 1.0f/src0->ne[0], false));
+                ggml_add1_or_set(ctx, cgraph, isrc0, ggml_scale_impl(ctx, grad, 1.0f/src0->ne[0], 0.0, false));
             }
         } break;
         case GGML_OP_REPEAT: {
@@ -5549,7 +5567,7 @@ static void ggml_compute_backward(
             if (src0_needs_grads) {
                 float s;
                 memcpy(&s, tensor->op_params, sizeof(float));
-                ggml_add_or_set(ctx, cgraph, isrc0, ggml_scale_impl(ctx, grad, s, false));
+                ggml_add_or_set(ctx, cgraph, isrc0, ggml_scale_impl(ctx, grad, s, 0.0, false));
             }
         } break;
         case GGML_OP_SET: {
diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp
index 615c2dc008..d1b2ff10d1 100644
--- a/tests/test-backend-ops.cpp
+++ b/tests/test-backend-ops.cpp
@@ -1655,22 +1655,24 @@ struct test_scale : public test_case {
     const ggml_type type;
     const std::array<int64_t, 4> ne;
     float scale;
+    float bias;
 
     std::string vars() override {
-        return VARS_TO_STR3(type, ne, scale);
+        return VARS_TO_STR4(type, ne, scale, bias);
     }
 
     test_scale(ggml_type type = GGML_TYPE_F32,
             std::array<int64_t, 4> ne = {10, 10, 10, 10},
-            float scale = 2.0f)
-        : type(type), ne(ne), scale(scale) {}
+            float scale = 2.0f,
+            float bias = 0.0f)
+        : type(type), ne(ne), scale(scale), bias(bias) {}
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
         ggml_set_param(a);
         ggml_set_name(a, "a");
 
-        ggml_tensor * out = ggml_scale(ctx, a, scale);
+        ggml_tensor * out = ggml_scale_bias(ctx, a, scale, bias);
         ggml_set_name(out, "out");
 
         return out;
@@ -4209,6 +4211,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
 
     test_cases.emplace_back(new test_add1());
     test_cases.emplace_back(new test_scale());
+    test_cases.emplace_back(new test_scale(GGML_TYPE_F32, {10, 10, 10, 10}, 2.0f, 1.0f));
     test_cases.emplace_back(new test_silu_back());
 
     for (float eps : {0.0f, 1e-6f, 1e-4f, 1e-1f}) {
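Usage note (reviewer sketch, not part of the diff): the new op can be exercised with the minimal CPU-only program below, which computes out = 2*x + 1 on a small f32 vector. It assumes the patched headers are in use; the compute entry point ggml_graph_compute_with_ctx is assumed to come from ggml-cpu.h here, which varies between ggml versions, so adjust the include if needed.

#include "ggml.h"
#include "ggml-cpu.h" // assumed location of ggml_graph_compute_with_ctx
#include <stdio.h>

int main(void) {
    // small CPU context; tensor data is allocated inside it (no_alloc = false)
    struct ggml_init_params ip = {
        /*.mem_size   =*/ 16*1024*1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(ip);

    struct ggml_tensor * x = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4);
    float * xd = (float *) x->data;
    for (int i = 0; i < 4; i++) {
        xd[i] = (float) i; // x = [0, 1, 2, 3]
    }

    // the new fused op: out = s*x + b, here 2*x + 1, as a single GGML_OP_SCALE node
    struct ggml_tensor * out = ggml_scale_bias(ctx, x, 2.0f, 1.0f);

    struct ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, out);
    ggml_graph_compute_with_ctx(ctx, gf, /*n_threads =*/ 1);

    const float * od = (const float *) out->data;
    for (int i = 0; i < 4; i++) {
        printf("%g ", od[i]); // expect: 1 3 5 7
    }
    printf("\n");

    ggml_free(ctx);
    return 0;
}

The fused form keeps scale-then-add as one graph node (and one kernel dispatch on Metal) instead of a GGML_OP_SCALE followed by a separate add, and avoids materializing a constant bias tensor.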