Minor set_rows optimization (#4)

* updated optimization, fixed errors

* non vectorized version now dispatches one thread per element

* Simplify

* Change logic for set_rows pipelines

---------

Co-authored-by: Neha Abbas <nehaabbas@macbookpro.lan>
Co-authored-by: Neha Abbas <nehaabbas@ReeseLevines-MacBook-Pro.local>
Co-authored-by: Reese Levine <reeselevine1@gmail.com>
This commit is contained in:
neha-ha
2025-10-27 14:48:39 -05:00
committed by GitHub
parent b566811913
commit f2e187c7f2
2 changed files with 58 additions and 13 deletions

View File

@@ -248,7 +248,7 @@ struct webgpu_context_struct {
webgpu_pipeline memset_pipeline;
webgpu_pipeline mul_mat_pipeline[30][2];
webgpu_pipeline set_rows_pipeline;
webgpu_pipeline set_rows_pipeline[1][2]; // dst->type, vectorized (0 for vectorized, 1 for non vectorized)
webgpu_pipeline get_rows_pipeline[30];
webgpu_pipeline get_rows_f32_no_vec_pipeline;
webgpu_pipeline cpy_pipeline[2][2]; // src type, dst type
@@ -766,10 +766,21 @@ static std::optional<webgpu_command> ggml_webgpu_set_rows(webgpu_context & ctx,
{ .binding = 3, .buffer = error_bufs.dev_buf, .offset = 0, .size = error_bufs.dev_buf.GetSize() }
};
size_t max_wg_size = ctx->max_wg_size_x;
uint32_t wg_x = (src->ne[1] * src->ne[2] * src->ne[3] + max_wg_size - 1) / max_wg_size;
size_t max_wg_size = ctx->max_wg_size_x;
return ggml_backend_webgpu_build(ctx, ctx->set_rows_pipeline, params, entries, wg_x, error_bufs);
int vectorized = src->ne[0] % 4 == 0;
webgpu_pipeline pipeline = ctx->set_rows_pipeline[0][vectorized];
// if not evenly divisble by 4, use the non-vectorized version
uint32_t threads;
if (vectorized) {
threads = (src->ne[1] * src->ne[2] * src->ne[3]) * (src->ne[0] / 4);
} else {
threads = src->ne[0] * src->ne[1] * src->ne[2] * src->ne[3];
}
uint32_t wg_x = (threads + max_wg_size - 1) / max_wg_size;
return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, error_bufs);
}
static webgpu_command ggml_webgpu_get_rows(webgpu_context & ctx,
@@ -1620,8 +1631,10 @@ static void ggml_webgpu_init_mul_mat_pipeline(webgpu_context & webgpu_ctx) {
}
static void ggml_webgpu_init_set_rows_pipeline(webgpu_context & webgpu_ctx) {
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->set_rows_pipeline, wgsl_set_rows, "set_rows",
ggml_webgpu_wg_size_entry(webgpu_ctx->max_wg_size_x));
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->set_rows_pipeline[0][0], wgsl_set_rows_f16,
"set_rows_f16", ggml_webgpu_wg_size_entry(webgpu_ctx->max_wg_size_x));
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->set_rows_pipeline[0][1], wgsl_set_rows_f16_vec,
"set_rows_f16_vec", ggml_webgpu_wg_size_entry(webgpu_ctx->max_wg_size_x));
}
static void ggml_webgpu_init_get_rows_pipeline(webgpu_context & webgpu_ctx) {

View File

@@ -1,13 +1,38 @@
#define(VARIANTS)
[
{
"SHADER_SUFFIX": "f16_vec",
"REPLS": {
"TYPE" : "vec4<f32>",
"DST_TYPE": "vec4<f16>",
"VEC_SIZE": 4
}
},
{
"SHADER_SUFFIX": "f16",
"REPLS": {
"TYPE" : "f32",
"DST_TYPE": "f16",
"VEC_SIZE": 1
}
}
]
#end(VARIANTS)
#define(SHADER)
enable f16;
@group(0) @binding(0)
var<storage, read_write> src: array<f32>;
var<storage, read_write> src: array<{{TYPE}}>;
@group(0) @binding(1)
var<storage, read_write> idx: array<u32>;
@group(0) @binding(2)
var<storage, read_write> dst: array<f16>;
var<storage, read_write> dst: array<{{DST_TYPE}}>;
@group(0) @binding(3)
var<storage, read_write> error: atomic<u32>;
@@ -47,10 +72,14 @@ var<uniform> params: Params;
override wg_size: u32;
@compute @workgroup_size(wg_size)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
if (gid.x >= params.n_rows * params.ne2 * params.ne3) {
if (gid.x >= (params.ne3 * params.ne2 * params.n_rows * params.ne0) / {{VEC_SIZE}}) {
return;
}
var i = gid.x;
// getting the row from gid
let elems_per_row = params.ne0 / {{VEC_SIZE}};
var i = gid.x / elems_per_row;
let i_src3 = i / (params.ne2 * params.n_rows);
i = i % (params.ne2 * params.n_rows);
@@ -75,7 +104,10 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
let i_dst_row = params.offset_dst + idx_high_val * params.stride_dst1 + i_src2 * params.stride_dst2 + i_src3 * params.stride_dst3;
let i_src_row = params.offset_src + i_src1 * params.stride_src1 + i_src2 * params.stride_src2 + i_src3 * params.stride_src3;
for (var i: u32 = 0; i < params.ne0; i++) {
dst[i_dst_row + i] = f16(src[i_src_row + i]);
}
// starts at what element of that row?
let col_idx = (gid.x % elems_per_row);
dst[i_dst_row/{{VEC_SIZE}} + col_idx] = {{DST_TYPE}}(src[i_src_row/{{VEC_SIZE}} + col_idx]);
}
#end(SHADER)