From f2e187c7f26cdf0f4a9284db8ecbb2d74ffb7213 Mon Sep 17 00:00:00 2001 From: neha-ha <137219201+neha-ha@users.noreply.github.com> Date: Mon, 27 Oct 2025 14:48:39 -0500 Subject: [PATCH] Minor set_rows optimization (#4) * updated optimization, fixed errors * non vectorized version now dispatches one thread per element * Simplify * Change logic for set_rows pipelines --------- Co-authored-by: Neha Abbas Co-authored-by: Neha Abbas Co-authored-by: Reese Levine --- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 25 +++++++--- .../{set_rows.wgsl => set_rows.tmpl.wgsl} | 46 ++++++++++++++++--- 2 files changed, 58 insertions(+), 13 deletions(-) rename ggml/src/ggml-webgpu/wgsl-shaders/{set_rows.wgsl => set_rows.tmpl.wgsl} (68%) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index b4558a9e3f..353c7729bd 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -248,7 +248,7 @@ struct webgpu_context_struct { webgpu_pipeline memset_pipeline; webgpu_pipeline mul_mat_pipeline[30][2]; - webgpu_pipeline set_rows_pipeline; + webgpu_pipeline set_rows_pipeline[1][2]; // dst->type, vectorized (0 for vectorized, 1 for non vectorized) webgpu_pipeline get_rows_pipeline[30]; webgpu_pipeline get_rows_f32_no_vec_pipeline; webgpu_pipeline cpy_pipeline[2][2]; // src type, dst type @@ -766,10 +766,21 @@ static std::optional ggml_webgpu_set_rows(webgpu_context & ctx, { .binding = 3, .buffer = error_bufs.dev_buf, .offset = 0, .size = error_bufs.dev_buf.GetSize() } }; - size_t max_wg_size = ctx->max_wg_size_x; - uint32_t wg_x = (src->ne[1] * src->ne[2] * src->ne[3] + max_wg_size - 1) / max_wg_size; + size_t max_wg_size = ctx->max_wg_size_x; - return ggml_backend_webgpu_build(ctx, ctx->set_rows_pipeline, params, entries, wg_x, error_bufs); + int vectorized = src->ne[0] % 4 == 0; + webgpu_pipeline pipeline = ctx->set_rows_pipeline[0][vectorized]; + // if not evenly divisble by 4, use the non-vectorized version + uint32_t threads; + if (vectorized) { + threads = (src->ne[1] * src->ne[2] * src->ne[3]) * (src->ne[0] / 4); + } else { + threads = src->ne[0] * src->ne[1] * src->ne[2] * src->ne[3]; + } + + uint32_t wg_x = (threads + max_wg_size - 1) / max_wg_size; + + return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, error_bufs); } static webgpu_command ggml_webgpu_get_rows(webgpu_context & ctx, @@ -1620,8 +1631,10 @@ static void ggml_webgpu_init_mul_mat_pipeline(webgpu_context & webgpu_ctx) { } static void ggml_webgpu_init_set_rows_pipeline(webgpu_context & webgpu_ctx) { - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->set_rows_pipeline, wgsl_set_rows, "set_rows", - ggml_webgpu_wg_size_entry(webgpu_ctx->max_wg_size_x)); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->set_rows_pipeline[0][0], wgsl_set_rows_f16, + "set_rows_f16", ggml_webgpu_wg_size_entry(webgpu_ctx->max_wg_size_x)); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->set_rows_pipeline[0][1], wgsl_set_rows_f16_vec, + "set_rows_f16_vec", ggml_webgpu_wg_size_entry(webgpu_ctx->max_wg_size_x)); } static void ggml_webgpu_init_get_rows_pipeline(webgpu_context & webgpu_ctx) { diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.tmpl.wgsl similarity index 68% rename from ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl rename to ggml/src/ggml-webgpu/wgsl-shaders/set_rows.tmpl.wgsl index 3567713dc2..4a6d819d3b 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.tmpl.wgsl @@ -1,13 +1,38 @@ +#define(VARIANTS) + +[ + { + "SHADER_SUFFIX": "f16_vec", + "REPLS": { + "TYPE" : "vec4", + "DST_TYPE": "vec4", + "VEC_SIZE": 4 + } + }, + { + "SHADER_SUFFIX": "f16", + "REPLS": { + "TYPE" : "f32", + "DST_TYPE": "f16", + "VEC_SIZE": 1 + } + } +] + +#end(VARIANTS) + +#define(SHADER) + enable f16; @group(0) @binding(0) -var src: array; +var src: array<{{TYPE}}>; @group(0) @binding(1) var idx: array; @group(0) @binding(2) -var dst: array; +var dst: array<{{DST_TYPE}}>; @group(0) @binding(3) var error: atomic; @@ -47,10 +72,14 @@ var params: Params; override wg_size: u32; @compute @workgroup_size(wg_size) fn main(@builtin(global_invocation_id) gid: vec3) { - if (gid.x >= params.n_rows * params.ne2 * params.ne3) { + if (gid.x >= (params.ne3 * params.ne2 * params.n_rows * params.ne0) / {{VEC_SIZE}}) { return; } - var i = gid.x; + + // getting the row from gid + let elems_per_row = params.ne0 / {{VEC_SIZE}}; + var i = gid.x / elems_per_row; + let i_src3 = i / (params.ne2 * params.n_rows); i = i % (params.ne2 * params.n_rows); @@ -75,7 +104,10 @@ fn main(@builtin(global_invocation_id) gid: vec3) { let i_dst_row = params.offset_dst + idx_high_val * params.stride_dst1 + i_src2 * params.stride_dst2 + i_src3 * params.stride_dst3; let i_src_row = params.offset_src + i_src1 * params.stride_src1 + i_src2 * params.stride_src2 + i_src3 * params.stride_src3; - for (var i: u32 = 0; i < params.ne0; i++) { - dst[i_dst_row + i] = f16(src[i_src_row + i]); - } + // starts at what element of that row? + let col_idx = (gid.x % elems_per_row); + dst[i_dst_row/{{VEC_SIZE}} + col_idx] = {{DST_TYPE}}(src[i_src_row/{{VEC_SIZE}} + col_idx]); } + +#end(SHADER) +