From dbed61294abbc91b93aef64047bf2fecbf44b48b Mon Sep 17 00:00:00 2001 From: Pavels Zaicenkovs Date: Sun, 16 Nov 2025 22:50:09 +0100 Subject: [PATCH] vulkan: add LOG operation support for F32 and F16 (#17183) * vulkan: add LOG operation support for F32 and F16 Part of #14909. * vulkan: Fix LOG operation types * docs: Update operation support documentation for Vulkan LOG operation * vulkan: fix log_f16 shader * docs: restore missing LOG test cases and regenerate ops.md --- docs/ops.md | 2 +- docs/ops/Vulkan.csv | 8 +++--- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 25 +++++++++++++++++++ ggml/src/ggml-vulkan/vulkan-shaders/log.comp | 17 +++++++++++++ .../vulkan-shaders/vulkan-shaders-gen.cpp | 3 +++ 5 files changed, 50 insertions(+), 5 deletions(-) create mode 100644 ggml/src/ggml-vulkan/vulkan-shaders/log.comp diff --git a/docs/ops.md b/docs/ops.md index 901bdf6c4b..4ada4384fc 100644 --- a/docs/ops.md +++ b/docs/ops.md @@ -63,7 +63,7 @@ Legend: | IM2COL_3D | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | | L2_NORM | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | | LEAKY_RELU | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | 🟡 | ❌ | -| LOG | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | 🟡 | ❌ | ❌ | +| LOG | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | 🟡 | ✅ | ❌ | | MEAN | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | | MUL | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | ❌ | | MUL_MAT | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | diff --git a/docs/ops/Vulkan.csv b/docs/ops/Vulkan.csv index cb34e77e43..290bdd1215 100644 --- a/docs/ops/Vulkan.csv +++ b/docs/ops/Vulkan.csv @@ -8627,7 +8627,7 @@ "Vulkan0","ADD_ID","type_a=f32,type_b=f32,n_embd=129,n_experts=8,n_experts_used=4,n_token=129","support","1","yes","Vulkan" "Vulkan0","SQR","type=f16,ne=[10,5,4,3]","support","0","no","Vulkan" "Vulkan0","SQRT","type=f16,ne=[10,3,3,2]","support","0","no","Vulkan" -"Vulkan0","LOG","type=f16,ne=[10,5,4,3]","support","0","no","Vulkan" +"Vulkan0","LOG","type=f16,ne=[10,5,4,3]","support","1","yes","Vulkan" "Vulkan0","SIN","type=f16,ne=[10,2,2,2]","support","0","no","Vulkan" "Vulkan0","COS","type=f16,ne=[10,2,2,2]","support","0","no","Vulkan" "Vulkan0","CLAMP","type=f16,ne=[10,5,4,3],min=-0.500000,max=0.500000","support","0","no","Vulkan" @@ -8638,7 +8638,7 @@ "Vulkan0","TRUNC","type=f16,ne=[10,2,2,2]","support","0","no","Vulkan" "Vulkan0","SQR","type=f16,ne=[7,1,5,3]","support","0","no","Vulkan" "Vulkan0","SQRT","type=f16,ne=[7,1,5,3]","support","0","no","Vulkan" -"Vulkan0","LOG","type=f16,ne=[7,1,5,3]","support","0","no","Vulkan" +"Vulkan0","LOG","type=f16,ne=[7,1,5,3]","support","1","yes","Vulkan" "Vulkan0","SIN","type=f16,ne=[7,1,5,3]","support","0","no","Vulkan" "Vulkan0","COS","type=f16,ne=[7,1,5,3]","support","0","no","Vulkan" "Vulkan0","CLAMP","type=f16,ne=[7,1,5,3],min=-0.500000,max=0.500000","support","0","no","Vulkan" @@ -8649,7 +8649,7 @@ "Vulkan0","TRUNC","type=f16,ne=[7,1,5,3]","support","0","no","Vulkan" "Vulkan0","SQR","type=f32,ne=[10,5,4,3]","support","1","yes","Vulkan" "Vulkan0","SQRT","type=f32,ne=[10,3,3,2]","support","1","yes","Vulkan" -"Vulkan0","LOG","type=f32,ne=[10,5,4,3]","support","0","no","Vulkan" +"Vulkan0","LOG","type=f32,ne=[10,5,4,3]","support","1","yes","Vulkan" "Vulkan0","SIN","type=f32,ne=[10,2,2,2]","support","1","yes","Vulkan" "Vulkan0","COS","type=f32,ne=[10,2,2,2]","support","1","yes","Vulkan" "Vulkan0","CLAMP","type=f32,ne=[10,5,4,3],min=-0.500000,max=0.500000","support","1","yes","Vulkan" @@ -8660,7 +8660,7 @@ "Vulkan0","TRUNC","type=f32,ne=[10,2,2,2]","support","0","no","Vulkan" "Vulkan0","SQR","type=f32,ne=[7,1,5,3]","support","1","yes","Vulkan" "Vulkan0","SQRT","type=f32,ne=[7,1,5,3]","support","1","yes","Vulkan" -"Vulkan0","LOG","type=f32,ne=[7,1,5,3]","support","0","no","Vulkan" +"Vulkan0","LOG","type=f32,ne=[7,1,5,3]","support","1","yes","Vulkan" "Vulkan0","SIN","type=f32,ne=[7,1,5,3]","support","1","yes","Vulkan" "Vulkan0","COS","type=f32,ne=[7,1,5,3]","support","1","yes","Vulkan" "Vulkan0","CLAMP","type=f32,ne=[7,1,5,3],min=-0.500000,max=0.500000","support","1","yes","Vulkan" diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 5bdc675cf6..bb3eb977ce 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -629,6 +629,7 @@ struct vk_device_struct { vk_pipeline pipeline_sqrt_f32; vk_pipeline pipeline_sin_f32; vk_pipeline pipeline_cos_f32; + vk_pipeline pipeline_log[2]; vk_pipeline pipeline_clamp_f32; vk_pipeline pipeline_pad_f32; vk_pipeline pipeline_roll_f32; @@ -3792,6 +3793,8 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_sqrt_f32, "sqrt_f32", sqrt_f32_len, sqrt_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_sin_f32, "sin_f32", sin_f32_len, sin_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_cos_f32, "cos_f32", cos_f32_len, cos_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_log[0], "log_f32", log_f32_len, log_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_log[1], "log_f16", log_f16_len, log_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_clamp_f32, "clamp_f32", clamp_f32_len, clamp_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); @@ -8126,6 +8129,12 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const return ctx->device->pipeline_cos_f32; } return nullptr; + case GGML_OP_LOG: + if (src0->type == dst->type && + (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16)) { + return ctx->device->pipeline_log[dst->type == GGML_TYPE_F16]; + } + return nullptr; case GGML_OP_CLAMP: if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { return ctx->device->pipeline_clamp_f32; @@ -8534,6 +8543,7 @@ static bool ggml_vk_op_supports_incontiguous(ggml_op op) { case GGML_OP_SQRT: case GGML_OP_SIN: case GGML_OP_COS: + case GGML_OP_LOG: case GGML_OP_CLAMP: case GGML_OP_PAD: case GGML_OP_REPEAT: @@ -8806,6 +8816,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co case GGML_OP_SQRT: case GGML_OP_SIN: case GGML_OP_COS: + case GGML_OP_LOG: case GGML_OP_CLAMP: case GGML_OP_PAD: case GGML_OP_ROLL: @@ -9414,6 +9425,10 @@ static void ggml_vk_cos(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_COS, vk_op_unary_push_constants_init(src0, dst)); } +static void ggml_vk_log(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) { + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_LOG, vk_op_unary_push_constants_init(src0, dst)); +} + static void ggml_vk_clamp(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) { vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst); p.param1 = ggml_get_op_params_f32(dst, 0); @@ -11209,6 +11224,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr case GGML_OP_SQRT: case GGML_OP_SIN: case GGML_OP_COS: + case GGML_OP_LOG: case GGML_OP_CLAMP: case GGML_OP_PAD: case GGML_OP_ROLL: @@ -11433,6 +11449,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr case GGML_OP_COS: ggml_vk_cos(ctx, compute_ctx, src0, node); + break; + case GGML_OP_LOG: + ggml_vk_log(ctx, compute_ctx, src0, node); + break; case GGML_OP_CLAMP: ggml_vk_clamp(ctx, compute_ctx, src0, node); @@ -11703,6 +11723,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph * case GGML_OP_SQRT: case GGML_OP_SIN: case GGML_OP_COS: + case GGML_OP_LOG: case GGML_OP_CLAMP: case GGML_OP_PAD: case GGML_OP_ROLL: @@ -13664,6 +13685,8 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm case GGML_OP_OPT_STEP_ADAMW: case GGML_OP_OPT_STEP_SGD: return op->src[0]->type == GGML_TYPE_F32; + case GGML_OP_LOG: + return op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16; case GGML_OP_ARGSORT: return op->ne[0] <= max_argsort_cols; case GGML_OP_UPSCALE: @@ -14159,6 +14182,8 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph * tensor_clone = ggml_sin(ggml_ctx, src_clone[0]); } else if (tensor->op == GGML_OP_COS) { tensor_clone = ggml_cos(ggml_ctx, src_clone[0]); + } else if (tensor->op == GGML_OP_LOG) { + tensor_clone = ggml_log(ggml_ctx, src_clone[0]); } else if (tensor->op == GGML_OP_CLAMP) { const float * params = (const float *)tensor->op_params; tensor_clone = ggml_clamp(ggml_ctx, src_clone[0], params[0], params[1]); diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/log.comp b/ggml/src/ggml-vulkan/vulkan-shaders/log.comp new file mode 100644 index 0000000000..4aef724e90 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/log.comp @@ -0,0 +1,17 @@ +#version 450 + +#include "types.glsl" +#include "generic_unary_head.glsl" + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +void main() { + const uint idx = get_idx(); + + if (idx >= p.ne) { + return; + } + + const FLOAT_TYPE val = FLOAT_TYPE(data_a[get_aoffset() + src0_idx(idx)]); + data_d[get_doffset() + dst_idx(idx)] = D_TYPE(log(val)); +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index 7623a36208..8a2ce321dc 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -802,6 +802,9 @@ void process_shaders() { string_to_spv("cos_f32", "cos.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); + string_to_spv("log_f32", "log.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); + string_to_spv("log_f16", "log.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"FLOAT_TYPE", "float"}}); + string_to_spv("clamp_f32", "clamp.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); string_to_spv("pad_f32", "pad.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});