From ed710b36f51ab3f53fa13db15c1685dc8678a32a Mon Sep 17 00:00:00 2001 From: Reese Levine Date: Fri, 31 Oct 2025 17:35:00 -0700 Subject: [PATCH] Implement overlap binary operators --- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 179 +++++--- .../ggml-webgpu/wgsl-shaders/bin_op.tmpl.wgsl | 393 +++++++++++++++--- .../ggml-webgpu/wgsl-shaders/binary_head.tmpl | 45 -- tests/test-backend-ops.cpp | 2 +- 4 files changed, 448 insertions(+), 171 deletions(-) delete mode 100644 ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index 70e3013537..f6b939e140 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -252,10 +252,10 @@ struct webgpu_context_struct { webgpu_pipeline get_rows_pipeline[30]; webgpu_pipeline get_rows_f32_no_vec_pipeline; webgpu_pipeline cpy_pipeline[2][2]; // src type, dst type - webgpu_pipeline add_pipeline[2][2]; // type, inplace - webgpu_pipeline sub_pipeline[2][2]; // type, inplace - webgpu_pipeline mul_pipeline[2][2]; // type, inplace - webgpu_pipeline div_pipeline[2][2]; // type, inplace + webgpu_pipeline add_pipeline[2][2][2]; // type, inplace, overlap + webgpu_pipeline sub_pipeline[2][2][2]; // type, inplace, overlap + webgpu_pipeline mul_pipeline[2][2][2]; // type, inplace, overlap + webgpu_pipeline div_pipeline[2][2][2]; // type, inplace, overlap webgpu_pipeline rms_norm_pipeline[2]; // inplace webgpu_pipeline rope_pipeline[2][2][2]; // type, ff, inplace webgpu_pipeline glu_pipeline[7][2][2]; // glu-op, type, split @@ -677,9 +677,12 @@ static size_t ggml_webgpu_tensor_align_offset(webgpu_context & ctx, ggml_tensor return offset & ~(ctx->limits.minStorageBufferOffsetAlignment - 1); } +static size_t ggml_webgpu_tensor_align_binding_size(size_t size) { + return (size + WEBGPU_STORAGE_BUF_BINDING_MULT - 1) & ~(WEBGPU_STORAGE_BUF_BINDING_MULT - 1); +} + static size_t ggml_webgpu_tensor_binding_size(webgpu_context & ctx, ggml_tensor * t) { - return (ggml_nbytes(t) + ggml_webgpu_tensor_misalignment(ctx, t) + WEBGPU_STORAGE_BUF_BINDING_MULT - 1) & - ~(WEBGPU_STORAGE_BUF_BINDING_MULT - 1); + return ggml_webgpu_tensor_align_binding_size(ggml_nbytes(t) + ggml_webgpu_tensor_misalignment(ctx, t)); } // Used to determine if two tensors are the same for in-place operations @@ -688,6 +691,12 @@ static bool ggml_webgpu_tensor_equal(ggml_tensor * a, ggml_tensor * b) { (ggml_webgpu_tensor_offset(a) == ggml_webgpu_tensor_offset(b)); } +static bool ggml_webgpu_tensor_overlap(ggml_tensor * a, ggml_tensor * b) { + return (ggml_webgpu_tensor_buf(a).Get() == ggml_webgpu_tensor_buf(b).Get()) && + ggml_webgpu_tensor_offset(a) < (ggml_webgpu_tensor_offset(b) + ggml_nbytes(b)) && + ggml_webgpu_tensor_offset(b) < (ggml_webgpu_tensor_offset(a) + ggml_nbytes(a)); +} + static webgpu_command ggml_webgpu_cpy(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { uint32_t ne = (uint32_t) ggml_nelements(dst); @@ -870,16 +879,27 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx, return ggml_backend_webgpu_build(ctx, ctx->mul_mat_pipeline[src0->type][src1->type], params, entries, wg_x); } -static webgpu_command ggml_webgpu_binary_op(webgpu_context & ctx, - ggml_tensor * src0, - ggml_tensor * src1, - ggml_tensor * dst, - webgpu_pipeline & pipeline, - bool inplace) { +template +static webgpu_command ggml_webgpu_binary_op(webgpu_context & ctx, + ggml_tensor * src0, + ggml_tensor * src1, + ggml_tensor * dst, + webgpu_pipeline (&pipelines)[a][b][c]) { + int inplace = ggml_webgpu_tensor_equal(src0, dst); + int overlap = ggml_webgpu_tensor_overlap(src0, src1); + webgpu_pipeline pipeline = pipelines[dst->type][inplace][overlap]; + + uint32_t src1_offset = ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type); + if (overlap) { + // when overlapped, bind a single buffer covering both src0 and src1 + // TODO: Do other operations need this? + src1_offset = (uint32_t) ((ggml_webgpu_tensor_offset(src1) - ggml_webgpu_tensor_align_offset(ctx, src0)) / + ggml_type_size(src1->type)); + } std::vector params = { (uint32_t) ggml_nelements(dst), (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)), - (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)), + src1_offset, (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), (uint32_t) (src1->nb[0] / ggml_type_size(src1->type)), (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)), @@ -894,25 +914,36 @@ static webgpu_command ggml_webgpu_binary_op(webgpu_context & ctx, (uint32_t) src1->ne[3], }; + size_t src0_binding_size = ggml_webgpu_tensor_binding_size(ctx, src0); + if (overlap) { + const uint64_t base_align = ggml_webgpu_tensor_align_offset(ctx, src0); + // assume end of src1 is >= end of src0 + const uint64_t max_end = ggml_webgpu_tensor_offset(src1) + ggml_nbytes(src1); + src0_binding_size = ggml_webgpu_tensor_align_binding_size(max_end - base_align); + } std::vector entries = { { .binding = 0, .buffer = ggml_webgpu_tensor_buf(src0), .offset = ggml_webgpu_tensor_align_offset(ctx, src0), - .size = ggml_webgpu_tensor_binding_size(ctx, src0) }, - { .binding = 1, - .buffer = ggml_webgpu_tensor_buf(src1), - .offset = ggml_webgpu_tensor_align_offset(ctx, src1), - .size = ggml_webgpu_tensor_binding_size(ctx, src1) } + .size = src0_binding_size } }; + uint32_t binding_num = 1; + if (!overlap) { + entries.push_back({ .binding = binding_num, + .buffer = ggml_webgpu_tensor_buf(src1), + .offset = ggml_webgpu_tensor_align_offset(ctx, src1), + .size = ggml_webgpu_tensor_binding_size(ctx, src1) }); + binding_num++; + } if (!inplace) { - entries.push_back({ .binding = 2, + entries.push_back({ .binding = binding_num, .buffer = ggml_webgpu_tensor_buf(dst), .offset = ggml_webgpu_tensor_align_offset(ctx, dst), .size = ggml_webgpu_tensor_binding_size(ctx, dst) }); } - size_t max_wg_size = ctx->max_wg_size_x; - uint32_t wg_x = (ggml_nelements(dst) + max_wg_size - 1) / max_wg_size; + size_t max_wg_size = ctx->max_wg_size_x; + uint32_t wg_x = (ggml_nelements(dst) + max_wg_size - 1) / max_wg_size; return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x); } @@ -1232,25 +1263,13 @@ static std::optional ggml_webgpu_encode_node(webgpu_context ctx, case GGML_OP_MUL_MAT: return ggml_webgpu_mul_mat(ctx, src0, src1, node); case GGML_OP_ADD: - { - int inplace = ggml_webgpu_tensor_equal(src0, node); - return ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->add_pipeline[node->type][inplace], inplace); - } + return ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->add_pipeline); case GGML_OP_SUB: - { - int inplace = ggml_webgpu_tensor_equal(src0, node); - return ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->sub_pipeline[node->type][inplace], inplace); - } + return ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->sub_pipeline); case GGML_OP_MUL: - { - int inplace = ggml_webgpu_tensor_equal(src0, node); - return ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->mul_pipeline[node->type][inplace], inplace); - } + return ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->mul_pipeline); case GGML_OP_DIV: - { - int inplace = ggml_webgpu_tensor_equal(src0, node); - return ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->div_pipeline[node->type][inplace], inplace); - } + return ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->div_pipeline); case GGML_OP_RMS_NORM: return ggml_webgpu_rms_norm(ctx, src0, node); case GGML_OP_ROPE: @@ -1700,50 +1719,82 @@ static void ggml_webgpu_init_cpy_pipeline(webgpu_context & webgpu_ctx) { static void ggml_webgpu_init_add_pipeline(webgpu_context & webgpu_ctx) { std::vector constants = ggml_webgpu_wg_size_entry(webgpu_ctx->max_wg_size_x); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_pipeline[GGML_TYPE_F32][0], wgsl_add_f32, "add_f32", - constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_pipeline[GGML_TYPE_F16][0], wgsl_add_f16, "add_f16", - constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_pipeline[GGML_TYPE_F32][1], wgsl_add_f32_inplace, + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_pipeline[GGML_TYPE_F32][0][0], wgsl_add_f32, + "add_f32", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_pipeline[GGML_TYPE_F16][0][0], wgsl_add_f16, + "add_f16", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_pipeline[GGML_TYPE_F32][1][0], wgsl_add_f32_inplace, "add_f32_inplace", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_pipeline[GGML_TYPE_F16][1], wgsl_add_f16_inplace, + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_pipeline[GGML_TYPE_F16][1][0], wgsl_add_f16_inplace, "add_f16_inplace", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_pipeline[GGML_TYPE_F32][0][1], wgsl_add_f32_overlap, + "add_f32_overlap", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_pipeline[GGML_TYPE_F32][1][1], + wgsl_add_f32_inplace_overlap, "add_f32_inplace_overlap", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_pipeline[GGML_TYPE_F16][0][1], wgsl_add_f16_overlap, + "add_f16_overlap", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_pipeline[GGML_TYPE_F16][1][1], + wgsl_add_f16_inplace_overlap, "add_f16_inplace_overlap", constants); } static void ggml_webgpu_init_sub_pipeline(webgpu_context & webgpu_ctx) { std::vector constants = ggml_webgpu_wg_size_entry(webgpu_ctx->max_wg_size_x); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->sub_pipeline[GGML_TYPE_F32][0], wgsl_sub_f32, "sub_f32", - constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->sub_pipeline[GGML_TYPE_F16][0], wgsl_sub_f16, "sub_f16", - constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->sub_pipeline[GGML_TYPE_F32][1], wgsl_sub_f32_inplace, + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->sub_pipeline[GGML_TYPE_F32][0][0], wgsl_sub_f32, + "sub_f32", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->sub_pipeline[GGML_TYPE_F16][0][0], wgsl_sub_f16, + "sub_f16", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->sub_pipeline[GGML_TYPE_F32][1][0], wgsl_sub_f32_inplace, "sub_f32_inplace", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->sub_pipeline[GGML_TYPE_F16][1], wgsl_sub_f16_inplace, + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->sub_pipeline[GGML_TYPE_F16][1][0], wgsl_sub_f16_inplace, "sub_f16_inplace", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->sub_pipeline[GGML_TYPE_F32][0][1], wgsl_sub_f32_overlap, + "sub_f32_overlap", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->sub_pipeline[GGML_TYPE_F32][1][1], + wgsl_sub_f32_inplace_overlap, "sub_f32_inplace_overlap", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->sub_pipeline[GGML_TYPE_F16][0][1], wgsl_sub_f16_overlap, + "sub_f16_overlap", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->sub_pipeline[GGML_TYPE_F16][1][1], + wgsl_sub_f16_inplace_overlap, "sub_f16_inplace_overlap", constants); } static void ggml_webgpu_init_mul_pipeline(webgpu_context & webgpu_ctx) { std::vector constants = ggml_webgpu_wg_size_entry(webgpu_ctx->max_wg_size_x); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_pipeline[GGML_TYPE_F32][0], wgsl_mul_f32, "mul_f32", - constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_pipeline[GGML_TYPE_F16][0], wgsl_mul_f16, "mul_f16", - constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_pipeline[GGML_TYPE_F32][1], wgsl_mul_f32_inplace, + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_pipeline[GGML_TYPE_F32][0][0], wgsl_mul_f32, + "mul_f32", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_pipeline[GGML_TYPE_F16][0][0], wgsl_mul_f16, + "mul_f16", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_pipeline[GGML_TYPE_F32][1][0], wgsl_mul_f32_inplace, "mul_f32_inplace", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_pipeline[GGML_TYPE_F16][1], wgsl_mul_f16_inplace, + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_pipeline[GGML_TYPE_F16][1][0], wgsl_mul_f16_inplace, "mul_f16_inplace", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_pipeline[GGML_TYPE_F32][0][1], wgsl_mul_f32_overlap, + "mul_f32_overlap", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_pipeline[GGML_TYPE_F32][1][1], + wgsl_mul_f32_inplace_overlap, "mul_f32_inplace_overlap", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_pipeline[GGML_TYPE_F16][0][1], wgsl_mul_f16_overlap, + "mul_f16_overlap", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_pipeline[GGML_TYPE_F16][1][1], + wgsl_mul_f16_inplace_overlap, "mul_f16_inplace_overlap", constants); } static void ggml_webgpu_init_div_pipeline(webgpu_context & webgpu_ctx) { std::vector constants = ggml_webgpu_wg_size_entry(webgpu_ctx->max_wg_size_x); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->div_pipeline[GGML_TYPE_F32][0], wgsl_div_f32, "div_f32", - constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->div_pipeline[GGML_TYPE_F16][0], wgsl_div_f16, "div_f16", - constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->div_pipeline[GGML_TYPE_F32][1], wgsl_div_f32_inplace, + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->div_pipeline[GGML_TYPE_F32][0][0], wgsl_div_f32, + "div_f32", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->div_pipeline[GGML_TYPE_F16][0][0], wgsl_div_f16, + "div_f16", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->div_pipeline[GGML_TYPE_F32][1][0], wgsl_div_f32_inplace, "div_f32_inplace", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->div_pipeline[GGML_TYPE_F16][1], wgsl_div_f16_inplace, + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->div_pipeline[GGML_TYPE_F16][1][0], wgsl_div_f16_inplace, "div_f16_inplace", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->div_pipeline[GGML_TYPE_F32][0][1], wgsl_div_f32_overlap, + "div_f32_overlap", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->div_pipeline[GGML_TYPE_F32][1][1], + wgsl_div_f32_inplace_overlap, "div_f32_inplace_overlap", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->div_pipeline[GGML_TYPE_F16][0][1], wgsl_div_f16_overlap, + "div_f16_overlap", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->div_pipeline[GGML_TYPE_F16][1][1], + wgsl_div_f16_inplace_overlap, "div_f16_inplace_overlap", constants); } static void ggml_webgpu_init_rms_norm_pipeline(webgpu_context & webgpu_ctx) { @@ -2152,9 +2203,9 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t // TODO: Don't enable for WASM builds, they won't have an effect anyways // TODO: Maybe WebGPU needs a "fast" mode where you can request compilers skip adding checks like these, // only for native performance? - const char * const deviceEnabledToggles[] = { "skip_validation", "disable_robustness", "disable_workgroup_init", - "disable_polyfills_on_integer_div_and_mod" }; - const char * const deviceDisabledToggles[] = { "timestamp_quantization" }; + const char * const deviceEnabledToggles[] = { "disable_robustness", "disable_workgroup_init", + "disable_polyfills_on_integer_div_and_mod" }; + const char * const deviceDisabledToggles[] = { "timestamp_quantization" }; wgpu::DawnTogglesDescriptor deviceTogglesDesc; deviceTogglesDesc.enabledToggles = deviceEnabledToggles; deviceTogglesDesc.enabledToggleCount = 4; diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/bin_op.tmpl.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/bin_op.tmpl.wgsl index 1ce4d83fa8..5143a1bbf1 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/bin_op.tmpl.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/bin_op.tmpl.wgsl @@ -5,15 +5,10 @@ "SHADER_NAME": "add_f32", "REPLS": { "TYPE" : "f32", - "OP": "+" - }, - "DECLS": ["NOT_INPLACE"] - }, - { - "SHADER_NAME": "add_f16", - "REPLS": { - "TYPE" : "f16", - "OP": "+" + "SRC1_BUF": "src1", + "DST_BUF": "dst", + "OP": "+", + "PARAMS_BINDING": 3 }, "DECLS": ["NOT_INPLACE"] }, @@ -21,31 +16,87 @@ "SHADER_NAME": "add_f32_inplace", "REPLS": { "TYPE" : "f32", - "OP": "+" + "SRC1_BUF": "src1", + "DST_BUF": "src0", + "OP": "+", + "PARAMS_BINDING": 2 }, "DECLS": ["INPLACE"] }, + { + "SHADER_NAME": "add_f32_overlap", + "REPLS": { + "TYPE" : "f32", + "SRC1_BUF": "src0", + "DST_BUF": "dst", + "OP": "+", + "PARAMS_BINDING": 2 + }, + "DECLS": ["OVERLAP"] + }, + { + "SHADER_NAME": "add_f32_inplace_overlap", + "REPLS": { + "TYPE" : "f32", + "SRC1_BUF": "src0", + "DST_BUF": "src0", + "OP": "+", + "PARAMS_BINDING": 1 + }, + "DECLS": ["INPLACE_OVERLAP"] + }, + { + "SHADER_NAME": "add_f16", + "REPLS": { + "TYPE" : "f16", + "SRC1_BUF": "src1", + "DST_BUF": "dst", + "OP": "+", + "PARAMS_BINDING": 3 + }, + "DECLS": ["NOT_INPLACE"] + }, { "SHADER_NAME": "add_f16_inplace", "REPLS": { "TYPE" : "f16", - "OP": "+" + "SRC1_BUF": "src1", + "DST_BUF": "src0", + "OP": "+", + "PARAMS_BINDING": 2 }, "DECLS": ["INPLACE"] }, + { + "SHADER_NAME": "add_f16_overlap", + "REPLS": { + "TYPE" : "f16", + "SRC1_BUF": "src0", + "DST_BUF": "dst", + "OP": "+", + "PARAMS_BINDING": 2 + }, + "DECLS": ["OVERLAP"] + }, + { + "SHADER_NAME": "add_f16_inplace_overlap", + "REPLS": { + "TYPE" : "f16", + "SRC1_BUF": "src0", + "DST_BUF": "src0", + "OP": "+", + "PARAMS_BINDING": 1 + }, + "DECLS": ["INPLACE_OVERLAP"] + }, { "SHADER_NAME": "mul_f32", "REPLS": { "TYPE" : "f32", - "OP": "*" - }, - "DECLS": ["NOT_INPLACE"] - }, - { - "SHADER_NAME": "mul_f16", - "REPLS": { - "TYPE" : "f16", - "OP": "*" + "SRC1_BUF": "src1", + "DST_BUF": "dst", + "OP": "*", + "PARAMS_BINDING": 3 }, "DECLS": ["NOT_INPLACE"] }, @@ -53,31 +104,87 @@ "SHADER_NAME": "mul_f32_inplace", "REPLS": { "TYPE" : "f32", - "OP": "*" + "SRC1_BUF": "src1", + "DST_BUF": "src0", + "OP": "*", + "PARAMS_BINDING": 2 }, "DECLS": ["INPLACE"] }, + { + "SHADER_NAME": "mul_f32_overlap", + "REPLS": { + "TYPE" : "f32", + "SRC1_BUF": "src0", + "DST_BUF": "dst", + "OP": "*", + "PARAMS_BINDING": 2 + }, + "DECLS": ["OVERLAP"] + }, + { + "SHADER_NAME": "mul_f32_inplace_overlap", + "REPLS": { + "TYPE" : "f32", + "SRC1_BUF": "src0", + "DST_BUF": "src0", + "OP": "*", + "PARAMS_BINDING": 1 + }, + "DECLS": ["INPLACE_OVERLAP"] + }, + { + "SHADER_NAME": "mul_f16", + "REPLS": { + "TYPE" : "f16", + "SRC1_BUF": "src1", + "DST_BUF": "dst", + "OP": "*", + "PARAMS_BINDING": 3 + }, + "DECLS": ["NOT_INPLACE"] + }, { "SHADER_NAME": "mul_f16_inplace", "REPLS": { "TYPE" : "f16", - "OP": "*" + "SRC1_BUF": "src1", + "DST_BUF": "src0", + "OP": "*", + "PARAMS_BINDING": 2 }, "DECLS": ["INPLACE"] }, + { + "SHADER_NAME": "mul_f16_overlap", + "REPLS": { + "TYPE" : "f16", + "SRC1_BUF": "src0", + "DST_BUF": "dst", + "OP": "*", + "PARAMS_BINDING": 2 + }, + "DECLS": ["OVERLAP"] + }, + { + "SHADER_NAME": "mul_f16_inplace_overlap", + "REPLS": { + "TYPE" : "f16", + "SRC1_BUF": "src0", + "DST_BUF": "src0", + "OP": "*", + "PARAMS_BINDING": 1 + }, + "DECLS": ["INPLACE_OVERLAP"] + }, { "SHADER_NAME": "sub_f32", "REPLS": { "TYPE" : "f32", - "OP": "-" - }, - "DECLS": ["NOT_INPLACE"] - }, - { - "SHADER_NAME": "sub_f16", - "REPLS": { - "TYPE" : "f16", - "OP": "-" + "SRC1_BUF": "src1", + "DST_BUF": "dst", + "OP": "-", + "PARAMS_BINDING": 3 }, "DECLS": ["NOT_INPLACE"] }, @@ -85,31 +192,88 @@ "SHADER_NAME": "sub_f32_inplace", "REPLS": { "TYPE" : "f32", - "OP": "-" + "SRC1_BUF": "src1", + "DST_BUF": "src0", + "OP": "-", + "PARAMS_BINDING": 2 }, "DECLS": ["INPLACE"] }, + { + "SHADER_NAME": "sub_f32_overlap", + "REPLS": { + "TYPE" : "f32", + "SRC1_BUF": "src0", + "DST_BUF": "dst", + "OP": "-", + "PARAMS_BINDING": 2 + }, + "DECLS": ["OVERLAP"] + }, + { + "SHADER_NAME": "sub_f32_inplace_overlap", + "REPLS": { + "TYPE" : "f32", + "SRC1_BUF": "src0", + "DST_BUF": "src0", + "OP": "-", + "PARAMS_BINDING": 1 + }, + "DECLS": ["INPLACE_OVERLAP"] + }, + { + "SHADER_NAME": "sub_f16", + "REPLS": { + "TYPE" : "f16", + "SRC1_BUF": "src1", + "DST_BUF": "dst", + "OP": "-", + "PARAMS_BINDING": 3 + }, + "DECLS": ["NOT_INPLACE"] + }, { "SHADER_NAME": "sub_f16_inplace", "REPLS": { "TYPE" : "f16", - "OP": "-" + "SRC1_BUF": "src1", + "DST_BUF": "src0", + "OP": "-", + "PARAMS_BINDING": 2 }, "DECLS": ["INPLACE"] }, + { + "SHADER_NAME": "sub_f16_overlap", + "REPLS": { + "TYPE" : "f16", + "SRC1_BUF": "src0", + "DST_BUF": "dst", + "OP": "-", + "PARAMS_BINDING": 2 + }, + "DECLS": ["OVERLAP"] + }, + { + "SHADER_NAME": "sub_f16_inplace_overlap", + "REPLS": { + "TYPE" : "f16", + "SRC1_BUF": "src0", + "DST_BUF": "src0", + "OP": "-", + "PARAMS_BINDING": 1 + }, + "DECLS": ["INPLACE_OVERLAP"] + }, + { "SHADER_NAME": "div_f32", "REPLS": { "TYPE" : "f32", - "OP": "/" - }, - "DECLS": ["NOT_INPLACE"] - }, - { - "SHADER_NAME": "div_f16", - "REPLS": { - "TYPE" : "f16", - "OP": "/" + "SRC1_BUF": "src1", + "DST_BUF": "dst", + "OP": "/", + "PARAMS_BINDING": 3 }, "DECLS": ["NOT_INPLACE"] }, @@ -117,17 +281,78 @@ "SHADER_NAME": "div_f32_inplace", "REPLS": { "TYPE" : "f32", - "OP": "/" + "SRC1_BUF": "src1", + "DST_BUF": "src0", + "OP": "/", + "PARAMS_BINDING": 2 }, "DECLS": ["INPLACE"] }, + { + "SHADER_NAME": "div_f32_overlap", + "REPLS": { + "TYPE" : "f32", + "SRC1_BUF": "src0", + "DST_BUF": "dst", + "OP": "/", + "PARAMS_BINDING": 2 + }, + "DECLS": ["OVERLAP"] + }, + { + "SHADER_NAME": "div_f32_inplace_overlap", + "REPLS": { + "TYPE" : "f32", + "SRC1_BUF": "src0", + "DST_BUF": "src0", + "OP": "/", + "PARAMS_BINDING": 1 + }, + "DECLS": ["INPLACE_OVERLAP"] + }, + { + "SHADER_NAME": "div_f16", + "REPLS": { + "TYPE" : "f16", + "SRC1_BUF": "src1", + "DST_BUF": "dst", + "OP": "/", + "PARAMS_BINDING": 3 + }, + "DECLS": ["NOT_INPLACE"] + }, { "SHADER_NAME": "div_f16_inplace", "REPLS": { "TYPE" : "f16", - "OP": "/" + "SRC1_BUF": "src1", + "DST_BUF": "src0", + "OP": "/", + "PARAMS_BINDING": 2 }, "DECLS": ["INPLACE"] + }, + { + "SHADER_NAME": "div_f16_overlap", + "REPLS": { + "TYPE" : "f16", + "SRC1_BUF": "src0", + "DST_BUF": "dst", + "OP": "/", + "PARAMS_BINDING": 2 + }, + "DECLS": ["OVERLAP"] + }, + { + "SHADER_NAME": "div_f16_inplace_overlap", + "REPLS": { + "TYPE" : "f16", + "SRC1_BUF": "src0", + "DST_BUF": "src0", + "OP": "/", + "PARAMS_BINDING": 1 + }, + "DECLS": ["INPLACE_OVERLAP"] } ] @@ -137,43 +362,89 @@ #decl(NOT_INPLACE) -fn update(dst_i: u32, src0_i: u32, src1_i: u32) { - dst[dst_i] = src0[src0_i] {{OP}} src1[src1_i]; -} +@group(0) @binding(1) +var src1: array<{{TYPE}}>; @group(0) @binding(2) var dst: array<{{TYPE}}>; -@group(0) @binding(3) -var params: Params; - #enddecl(NOT_INPLACE) #decl(INPLACE) -fn update(dst_i: u32, src0_i: u32, src1_i: u32) { - src0[dst_i] = src0[src0_i] {{OP}} src1[src1_i]; -} - -@group(0) @binding(2) -var params: Params; +@group(0) @binding(1) +var src1: array<{{TYPE}}>; #enddecl(INPLACE) -#end(DECLS) +#decl(OVERLAP) +@group(0) @binding(1) +var dst: array<{{TYPE}}>; + +#enddecl(OVERLAP) + +#decl(INPLACE_OVERLAP) + +#enddecl(INPLACE_OVERLAP) + +#end(DECLS) #define(SHADER) enable f16; -#include "binary_head.tmpl" +struct Params { + ne: u32, + + // offsets in elements + offset_src0: u32, + offset_src1: u32, + offset_dst: u32, + + stride_src1_0: u32, + stride_src1_1: u32, + stride_src1_2: u32, + stride_src1_3: u32, + + a_ne0: u32, + a_ne1: u32, + a_ne2: u32, + + b_ne0: u32, + b_ne1: u32, + b_ne2: u32, + b_ne3: u32, +}; + +fn src1_index(_i: u32) -> u32 { + var i = _i; + let a_i3 = i / (params.a_ne2 * params.a_ne1 * params.a_ne0); + i = i % (params.a_ne2 * params.a_ne1 * params.a_ne0); + let a_i2 = i / (params.a_ne1 * params.a_ne0); + i = i % (params.a_ne1 * params.a_ne0); + let a_i1 = i / params.a_ne0; + let a_i0 = i % params.a_ne0; + + // handle repetition of b + // index loops back to the beginning and repeats after elements are exhausted = modulo + let b_i0 = a_i0 % params.b_ne0; + let b_i1 = a_i1 % params.b_ne1; + let b_i2 = a_i2 % params.b_ne2; + let b_i3 = a_i3 % params.b_ne3; + + // compute index for position in b's flat array + return b_i0 * params.stride_src1_0 + + b_i1 * params.stride_src1_1 + + b_i2 * params.stride_src1_2 + + b_i3 * params.stride_src1_3; +} @group(0) @binding(0) var src0: array<{{TYPE}}>; -@group(0) @binding(1) -var src1: array<{{TYPE}}>; +@group(0) @binding({{PARAMS_BINDING}}) +var params: Params; DECLS @@ -181,7 +452,7 @@ override wg_size: u32; @compute @workgroup_size(wg_size) fn main(@builtin(global_invocation_id) gid: vec3) { if (gid.x < params.ne) { - update(params.offset_dst + gid.x, params.offset_src0 + gid.x, params.offset_src1 + src1_index(gid.x)); + {{DST_BUF}}[params.offset_dst + gid.x] = src0[params.offset_src0 + gid.x] {{OP}} {{SRC1_BUF}}[params.offset_src1 + src1_index(gid.x)]; } } diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl b/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl deleted file mode 100644 index 4b254f468d..0000000000 --- a/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl +++ /dev/null @@ -1,45 +0,0 @@ -struct Params { - ne: u32, - - // offsets in elements - offset_src0: u32, - offset_src1: u32, - offset_dst: u32, - - stride_src1_0: u32, - stride_src1_1: u32, - stride_src1_2: u32, - stride_src1_3: u32, - - a_ne0: u32, - a_ne1: u32, - a_ne2: u32, - - b_ne0: u32, - b_ne1: u32, - b_ne2: u32, - b_ne3: u32, -}; - -fn src1_index(_i: u32) -> u32 { - var i = _i; - let a_i3 = i / (params.a_ne2 * params.a_ne1 * params.a_ne0); - i = i % (params.a_ne2 * params.a_ne1 * params.a_ne0); - let a_i2 = i / (params.a_ne1 * params.a_ne0); - i = i % (params.a_ne1 * params.a_ne0); - let a_i1 = i / params.a_ne0; - let a_i0 = i % params.a_ne0; - - // handle repetition of b - // index loops back to the beginning and repeats after elements are exhausted = modulo - let b_i0 = a_i0 % params.b_ne0; - let b_i1 = a_i1 % params.b_ne1; - let b_i2 = a_i2 % params.b_ne2; - let b_i3 = a_i3 % params.b_ne3; - - // compute index for position in b's flat array - return b_i0 * params.stride_src1_0 + - b_i1 * params.stride_src1_1 + - b_i2 * params.stride_src1_2 + - b_i3 * params.stride_src1_3; -} diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 04fa1b62d3..0d2cbc530f 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -4840,7 +4840,7 @@ struct test_moe_expert_reduce : public test_case { std::vector expert_views(n_expert_used); for (int64_t i = 0; i < n_expert_used; ++i) { - expert_views[i] = ggml_view_2d(ctx, weighted, n_embd, n_tokens, weighted->nb[2], i * weighted->nb[1]); + expert_views[i] = ggml_view_2d(ctx, weighted, n_embd, n_tokens, weighted->nb[1], i * weighted->nb[1]); std::string name = "expert_view_" + std::to_string(i); ggml_set_name(expert_views[i], name.c_str());