GGML WebGPU: Support for ADD, MUL, RMS_NORM, GET_ROWS operators (#16018)

* Add paramater buffer pool, batching of submissions, refactor command building/submission

* Add header for linux builds

* Free staged parameter buffers at once

* Format with clang-format

* Fix thread-safe implementation

* Use device implicit synchronization

* Update workflow to use custom release

* Remove testing branch workflow

* some f32 tests passing

* Disable set_rows until it's implemented

* f32 add all tests passing

* Begin work on set_rows

* Work on set rows

* Add error buffers for reporting unsupported SET_ROWS indices

* Remove extra comments

* Add templated addition, clean up code

* Get addition and multiplication working

* Implement rms_norm

* Add get_rows implementation

* Add new get_rows files

* Refactor use of wg size entry

* Fix compilation

* Try manually unrolled q4_0 quant

* Revert "Try manually unrolled q4_0 quant"

This reverts commit 77f8b96515.

* Move to constant max wg size

* Check for tensor size in supports_op

* Vectorize f32 and change default workgroup size

* Move f32 get_rows from < 4 to % 4 != 0

* fix linter errors

* Add in-place tests

---------

Co-authored-by: Neha Abbas <nehaabbas@ReeseLevines-MacBook-Pro.local>
This commit is contained in:
Reese Levine
2025-09-17 13:09:40 -07:00
committed by GitHub
parent 0320ac5264
commit d304f459d8
14 changed files with 2673 additions and 1141 deletions

View File

@@ -116,6 +116,10 @@ struct webgpu_context_struct {
wgpu::Queue queue;
wgpu::Limits limits;
// Separate this out from limits since on some Metal systems, the limit returned by
// querying the limits is higher than the actual allowed maximum.
uint32_t max_wg_size_x;
std::recursive_mutex mutex;
webgpu_buf_pool param_buf_pool;
@@ -124,7 +128,15 @@ struct webgpu_context_struct {
wgpu::ComputePipeline memset_pipeline;
wgpu::ComputePipeline mul_mat_pipeline[30][2];
wgpu::ComputePipeline set_rows_pipeline;
wgpu::ComputePipeline get_rows_pipeline[30];
wgpu::ComputePipeline get_rows_f32_no_vec_pipeline;
wgpu::ComputePipeline cpy_pipeline;
wgpu::ComputePipeline add_pipeline[2];
wgpu::ComputePipeline add_ip_pipeline[2];
wgpu::ComputePipeline mul_pipeline[2];
wgpu::ComputePipeline mul_ip_pipeline[2];
wgpu::ComputePipeline rms_norm_pipeline;
wgpu::ComputePipeline rms_norm_ip_pipeline;
size_t memset_bytes_per_thread;
@@ -232,14 +244,15 @@ static void ggml_backend_webgpu_wait_on_submission(webgpu_context & ctx) {
std::lock_guard<std::recursive_mutex> lock(ctx->mutex);
if (ctx->callback_futures.empty()) {
// no existing callbacks, wait on queue submission
ctx->instance.WaitAny(ctx->queue.OnSubmittedWorkDone(
wgpu::CallbackMode::AllowSpontaneous,
[](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) {
if (status != wgpu::QueueWorkDoneStatus::Success) {
GGML_LOG_ERROR("ggml_webgpu: Failed to submit commands: %s\n", std::string(message).c_str());
}
}),
UINT64_MAX);
ctx->instance.WaitAny(
ctx->queue.OnSubmittedWorkDone(wgpu::CallbackMode::AllowSpontaneous,
[](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) {
if (status != wgpu::QueueWorkDoneStatus::Success) {
GGML_LOG_ERROR("ggml_webgpu: Failed to submit commands: %s\n",
std::string(message).c_str());
}
}),
UINT64_MAX);
} else {
// existing callbacks, wait on them
ctx->instance.WaitAny(ctx->callback_futures.size(), ctx->callback_futures.data(), UINT64_MAX);
@@ -286,10 +299,7 @@ static void ggml_backend_webgpu_submit_queue(webgpu_context & ctx) {
// Check for errrors in SET_ROWS operations
for (auto & error_bufs : staged_set_row_error_bufs) {
wgpu::Future f = error_bufs.host_buf.MapAsync(
wgpu::MapMode::Read,
0,
error_bufs.host_buf.GetSize(),
wgpu::CallbackMode::AllowSpontaneous,
wgpu::MapMode::Read, 0, error_bufs.host_buf.GetSize(), wgpu::CallbackMode::AllowSpontaneous,
[ctx, error_bufs](wgpu::MapAsyncStatus status, wgpu::StringView message) {
if (status != wgpu::MapAsyncStatus::Success) {
GGML_LOG_ERROR("ggml_webgpu: Failed to map error buffer: %s\n", std::string(message).c_str());
@@ -311,10 +321,7 @@ static void ggml_backend_webgpu_map_buffer(webgpu_context & ctx,
wgpu::MapMode mode,
size_t offset,
size_t size) {
ctx->instance.WaitAny(buffer.MapAsync(mode,
offset,
size,
wgpu::CallbackMode::AllowSpontaneous,
ctx->instance.WaitAny(buffer.MapAsync(mode, offset, size, wgpu::CallbackMode::AllowSpontaneous,
[](wgpu::MapAsyncStatus status, wgpu::StringView message) {
if (status != wgpu::MapAsyncStatus::Success) {
GGML_LOG_ERROR("ggml_webgpu: Failed to map buffer: %s\n",
@@ -351,7 +358,8 @@ static void ggml_backend_webgpu_build_and_enqueue(webgpu_context &
std::vector<uint32_t> params,
std::vector<wgpu::BindGroupEntry> bind_group_entries,
uint32_t wg_x,
bool submit_and_wait = false) {
const char * bind_group_label = nullptr,
bool submit_and_wait = false) {
webgpu_pool_bufs params_bufs = ctx->param_buf_pool.alloc_bufs();
ggml_backend_webgpu_map_buffer(ctx, params_bufs.host_buf, wgpu::MapMode::Write, 0, params_bufs.host_buf.GetSize());
@@ -372,6 +380,9 @@ static void ggml_backend_webgpu_build_and_enqueue(webgpu_context &
bind_group_desc.layout = pipeline.GetBindGroupLayout(0);
bind_group_desc.entryCount = bind_group_entries.size();
bind_group_desc.entries = bind_group_entries.data();
if (bind_group_label) {
bind_group_desc.label = bind_group_label;
}
wgpu::BindGroup bind_group = ctx->device.CreateBindGroup(&bind_group_desc);
wgpu::CommandEncoder encoder = ctx->device.CreateCommandEncoder();
@@ -415,9 +426,9 @@ static void ggml_backend_webgpu_buffer_memset(webgpu_context & ctx,
std::vector<wgpu::BindGroupEntry> entries = {
{ .binding = 0, .buffer = buf, .offset = 0, .size = buf.GetSize() }
};
size_t bytes_per_wg = ctx->limits.maxComputeWorkgroupSizeX * ctx->memset_bytes_per_thread;
size_t bytes_per_wg = ctx->max_wg_size_x * ctx->memset_bytes_per_thread;
uint32_t wg_x = ((size + 3) + bytes_per_wg - 1) / bytes_per_wg;
ggml_backend_webgpu_build_and_enqueue(ctx, ctx->memset_pipeline, params, entries, wg_x, true);
ggml_backend_webgpu_build_and_enqueue(ctx, ctx->memset_pipeline, params, entries, wg_x, "MEMSET", true);
}
/** End WebGPU Actions */
@@ -461,26 +472,26 @@ static size_t ggml_webgpu_tensor_binding_size(webgpu_context & ctx, ggml_tensor
~(WEBGPU_STORAGE_BUF_BINDING_MULT - 1);
}
// Used to determine if two tensors are the same for in-place operations
static bool ggml_webgpu_tensor_equal(ggml_tensor * a, ggml_tensor * b) {
return (ggml_webgpu_tensor_buf(a).Get() == ggml_webgpu_tensor_buf(b).Get()) &&
(ggml_webgpu_tensor_offset(a) == ggml_webgpu_tensor_offset(b));
}
static void ggml_webgpu_cpy(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
uint32_t ne = (uint32_t) ggml_nelements(dst);
std::vector<uint32_t> params = { ne,
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)),
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
// Convert byte-strides to element-strides
(uint32_t) (src->nb[0] / ggml_type_size(src->type)),
(uint32_t) (src->nb[1] / ggml_type_size(src->type)),
(uint32_t) (src->nb[2] / ggml_type_size(src->type)),
(uint32_t) (src->nb[3] / ggml_type_size(src->type)),
(uint32_t) (dst->nb[0] / ggml_type_size(dst->type)),
(uint32_t) (dst->nb[1] / ggml_type_size(dst->type)),
(uint32_t) (dst->nb[2] / ggml_type_size(dst->type)),
(uint32_t) (dst->nb[3] / ggml_type_size(dst->type)),
// Logical shape — same for both tensors even if permuted
(uint32_t) src->ne[0],
(uint32_t) src->ne[1],
(uint32_t) src->ne[2],
(uint32_t) src->ne[3] };
std::vector<uint32_t> params = {
ne, (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)),
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
// Convert byte-strides to element-strides
(uint32_t) (src->nb[0] / ggml_type_size(src->type)), (uint32_t) (src->nb[1] / ggml_type_size(src->type)),
(uint32_t) (src->nb[2] / ggml_type_size(src->type)), (uint32_t) (src->nb[3] / ggml_type_size(src->type)),
(uint32_t) (dst->nb[0] / ggml_type_size(dst->type)), (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)),
(uint32_t) (dst->nb[2] / ggml_type_size(dst->type)), (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)),
// Logical shape — same for both tensors even if permuted
(uint32_t) src->ne[0], (uint32_t) src->ne[1], (uint32_t) src->ne[2], (uint32_t) src->ne[3]
};
std::vector<wgpu::BindGroupEntry> entries = {
{ .binding = 0,
@@ -493,9 +504,9 @@ static void ggml_webgpu_cpy(webgpu_context & ctx, ggml_tensor * src, ggml_tensor
.size = ggml_webgpu_tensor_binding_size(ctx, dst) }
};
size_t max_wg_size = ctx->limits.maxComputeWorkgroupSizeX;
size_t max_wg_size = ctx->max_wg_size_x;
uint32_t wg_x = (ne + max_wg_size - 1) / max_wg_size;
ggml_backend_webgpu_build_and_enqueue(ctx, ctx->cpy_pipeline, params, entries, wg_x);
ggml_backend_webgpu_build_and_enqueue(ctx, ctx->cpy_pipeline, params, entries, wg_x, ggml_op_name(dst->op));
}
static void ggml_webgpu_set_rows(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * idx, ggml_tensor * dst) {
@@ -509,27 +520,21 @@ static void ggml_webgpu_set_rows(webgpu_context & ctx, ggml_tensor * src, ggml_t
error_bufs.host_buf.Unmap();
}
std::vector<uint32_t> params = { (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)),
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, idx) / ggml_type_size(idx->type)),
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
// Convert byte-strides to element-strides
(uint32_t) (src->nb[1] / ggml_type_size(src->type)),
(uint32_t) (src->nb[2] / ggml_type_size(src->type)),
(uint32_t) (src->nb[3] / ggml_type_size(src->type)),
(uint32_t) (idx->nb[0] / ggml_type_size(idx->type)),
(uint32_t) (idx->nb[1] / ggml_type_size(idx->type)),
(uint32_t) (idx->nb[2] / ggml_type_size(idx->type)),
(uint32_t) (dst->nb[1] / ggml_type_size(dst->type)),
(uint32_t) (dst->nb[2] / ggml_type_size(dst->type)),
(uint32_t) (dst->nb[3] / ggml_type_size(dst->type)),
// Shape of src
(uint32_t) src->ne[0],
(uint32_t) src->ne[1],
(uint32_t) src->ne[2],
(uint32_t) src->ne[3],
// Shape of idx
(uint32_t) (idx->ne[1]),
(uint32_t) (idx->ne[2]) };
std::vector<uint32_t> params = {
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)),
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, idx) / ggml_type_size(idx->type)),
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
// Convert byte-strides to element-strides
(uint32_t) (src->nb[1] / ggml_type_size(src->type)), (uint32_t) (src->nb[2] / ggml_type_size(src->type)),
(uint32_t) (src->nb[3] / ggml_type_size(src->type)), (uint32_t) (idx->nb[0] / ggml_type_size(idx->type)),
(uint32_t) (idx->nb[1] / ggml_type_size(idx->type)), (uint32_t) (idx->nb[2] / ggml_type_size(idx->type)),
(uint32_t) (dst->nb[1] / ggml_type_size(dst->type)), (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)),
(uint32_t) (dst->nb[3] / ggml_type_size(dst->type)),
// Shape of src
(uint32_t) src->ne[0], (uint32_t) src->ne[1], (uint32_t) src->ne[2], (uint32_t) src->ne[3],
// Shape of idx
(uint32_t) (idx->ne[1]), (uint32_t) (idx->ne[2])
};
std::vector<wgpu::BindGroupEntry> entries = {
{ .binding = 0,
@@ -547,13 +552,55 @@ static void ggml_webgpu_set_rows(webgpu_context & ctx, ggml_tensor * src, ggml_t
{ .binding = 3, .buffer = error_bufs.dev_buf, .offset = 0, .size = error_bufs.dev_buf.GetSize() }
};
size_t max_wg_size = ctx->limits.maxComputeWorkgroupSizeX;
size_t max_wg_size = ctx->max_wg_size_x;
uint32_t wg_x = (src->ne[1] * src->ne[2] * src->ne[3] + max_wg_size - 1) / max_wg_size;
std::lock_guard<std::recursive_mutex> lock(ctx->mutex);
ctx->staged_set_row_error_bufs.push_back(error_bufs);
ggml_backend_webgpu_build_and_enqueue(ctx, ctx->set_rows_pipeline, params, entries, wg_x);
ggml_backend_webgpu_build_and_enqueue(ctx, ctx->set_rows_pipeline, params, entries, wg_x, ggml_op_name(dst->op));
}
static void ggml_webgpu_get_rows(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * idx, ggml_tensor * dst) {
std::vector<uint32_t> params = {
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)),
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, idx) / ggml_type_size(idx->type)),
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
// Convert byte-strides to element-strides
(uint32_t) (src->nb[1] / ggml_type_size(src->type)), (uint32_t) (src->nb[2] / ggml_type_size(src->type)),
(uint32_t) (src->nb[3] / ggml_type_size(src->type)), (uint32_t) (idx->nb[0] / ggml_type_size(idx->type)),
(uint32_t) (idx->nb[1] / ggml_type_size(idx->type)), (uint32_t) (idx->nb[2] / ggml_type_size(idx->type)),
(uint32_t) (dst->nb[1] / ggml_type_size(dst->type)), (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)),
(uint32_t) (dst->nb[3] / ggml_type_size(dst->type)),
// Shape of dst
(uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3],
// Shape of idx
(uint32_t) (idx->ne[1]), (uint32_t) (idx->ne[2])
};
std::vector<wgpu::BindGroupEntry> entries = {
{ .binding = 0,
.buffer = ggml_webgpu_tensor_buf(src),
.offset = ggml_webgpu_tensor_align_offset(ctx, src),
.size = ggml_webgpu_tensor_binding_size(ctx, src) },
{ .binding = 1,
.buffer = ggml_webgpu_tensor_buf(idx),
.offset = ggml_webgpu_tensor_align_offset(ctx, idx),
.size = ggml_webgpu_tensor_binding_size(ctx, idx) },
{ .binding = 2,
.buffer = ggml_webgpu_tensor_buf(dst),
.offset = ggml_webgpu_tensor_align_offset(ctx, dst),
.size = ggml_webgpu_tensor_binding_size(ctx, dst) }
};
size_t max_wg_size = ctx->max_wg_size_x;
uint32_t wg_x = (dst->ne[1] * dst->ne[2] * dst->ne[3] + max_wg_size - 1) / max_wg_size;
wgpu::ComputePipeline pipeline = ctx->get_rows_pipeline[src->type];
if (src->type == GGML_TYPE_F32 && dst->ne[0] % 4 != 0) {
pipeline = ctx->get_rows_f32_no_vec_pipeline;
}
ggml_backend_webgpu_build_and_enqueue(ctx, pipeline, params, entries, wg_x, ggml_op_name(dst->op));
}
static void ggml_webgpu_mul_mat(webgpu_context & ctx, ggml_tensor * src0, ggml_tensor * src1, ggml_tensor * dst) {
@@ -593,7 +640,104 @@ static void ggml_webgpu_mul_mat(webgpu_context & ctx, ggml_tensor * src0, ggml_t
uint32_t wg_x =
(dst->ne[0] * dst->ne[1] * dst->ne[2] * dst->ne[3] + WEBGPU_MUL_MAT_WG_SIZE - 1) / WEBGPU_MUL_MAT_WG_SIZE;
ggml_backend_webgpu_build_and_enqueue(ctx, ctx->mul_mat_pipeline[src0->type][src1->type], params, entries, wg_x);
ggml_backend_webgpu_build_and_enqueue(ctx, ctx->mul_mat_pipeline[src0->type][src1->type], params, entries, wg_x,
ggml_op_name(dst->op));
}
static void ggml_webgpu_binary_op(webgpu_context & ctx,
ggml_tensor * src0,
ggml_tensor * src1,
ggml_tensor * dst,
wgpu::ComputePipeline & pipeline,
bool in_place) {
std::vector<uint32_t> params = {
(uint32_t) ggml_nelements(dst),
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)),
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)),
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
(uint32_t) (src1->nb[0] / ggml_type_size(src1->type)),
(uint32_t) (src1->nb[1] / ggml_type_size(src1->type)),
(uint32_t) (src1->nb[2] / ggml_type_size(src1->type)),
(uint32_t) (src1->nb[3] / ggml_type_size(src1->type)),
(uint32_t) src0->ne[0],
(uint32_t) src0->ne[1],
(uint32_t) src0->ne[2],
(uint32_t) src1->ne[0],
(uint32_t) src1->ne[1],
(uint32_t) src1->ne[2],
(uint32_t) src1->ne[3],
};
std::vector<wgpu::BindGroupEntry> entries = {
{ .binding = 0,
.buffer = ggml_webgpu_tensor_buf(src0),
.offset = ggml_webgpu_tensor_align_offset(ctx, src0),
.size = ggml_webgpu_tensor_binding_size(ctx, src0) },
{ .binding = 1,
.buffer = ggml_webgpu_tensor_buf(src1),
.offset = ggml_webgpu_tensor_align_offset(ctx, src1),
.size = ggml_webgpu_tensor_binding_size(ctx, src1) }
};
if (!in_place) {
entries.push_back({ .binding = 2,
.buffer = ggml_webgpu_tensor_buf(dst),
.offset = ggml_webgpu_tensor_align_offset(ctx, dst),
.size = ggml_webgpu_tensor_binding_size(ctx, dst) });
}
size_t max_wg_size = ctx->max_wg_size_x;
uint32_t wg_x = (ggml_nelements(dst) + max_wg_size - 1) / max_wg_size;
ggml_backend_webgpu_build_and_enqueue(ctx, pipeline, params, entries, wg_x, ggml_op_name(dst->op));
}
static void ggml_webgpu_rms_norm(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
bool in_place = ggml_webgpu_tensor_equal(src, dst);
uint32_t eps;
memcpy(&eps, dst->op_params, sizeof(float));
std::vector<uint32_t> params = {
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)),
};
if (!in_place) {
params.push_back((uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)));
}
params.push_back((uint32_t) (src->nb[1] / ggml_type_size(src->type)));
params.push_back((uint32_t) (src->nb[2] / ggml_type_size(src->type)));
params.push_back((uint32_t) (src->nb[3] / ggml_type_size(src->type)));
if (!in_place) {
params.push_back((uint32_t) (dst->nb[1] / ggml_type_size(dst->type)));
params.push_back((uint32_t) (dst->nb[2] / ggml_type_size(dst->type)));
params.push_back((uint32_t) (dst->nb[3] / ggml_type_size(dst->type)));
}
params.push_back((uint32_t) src->ne[0]);
params.push_back((uint32_t) src->ne[1]);
params.push_back((uint32_t) src->ne[2]);
params.push_back((uint32_t) src->ne[3]);
params.push_back(eps); // epsilon, will be bitcast to float in shader
std::vector<wgpu::BindGroupEntry> entries = {
{ .binding = 0,
.buffer = ggml_webgpu_tensor_buf(src),
.offset = ggml_webgpu_tensor_align_offset(ctx, src),
.size = ggml_webgpu_tensor_binding_size(ctx, src) }
};
if (!in_place) {
entries.push_back({ .binding = 1,
.buffer = ggml_webgpu_tensor_buf(dst),
.offset = ggml_webgpu_tensor_align_offset(ctx, dst),
.size = ggml_webgpu_tensor_binding_size(ctx, dst) });
}
wgpu::ComputePipeline pipeline;
if (in_place) {
pipeline = ctx->rms_norm_ip_pipeline;
} else {
pipeline = ctx->rms_norm_pipeline;
}
size_t max_wg_size = ctx->max_wg_size_x;
uint32_t wg_x = (src->ne[1] * src->ne[2] * src->ne[3] + max_wg_size - 1) / max_wg_size;
ggml_backend_webgpu_build_and_enqueue(ctx, pipeline, params, entries, wg_x, ggml_op_name(dst->op));
}
// Returns true if node has enqueued work into the queue, false otherwise
@@ -615,20 +759,34 @@ static bool ggml_webgpu_encode_node(webgpu_context ctx, ggml_tensor * node) {
case GGML_OP_RESHAPE:
return false;
case GGML_OP_CPY:
{
ggml_webgpu_cpy(ctx, src0, node);
break;
}
ggml_webgpu_cpy(ctx, src0, node);
break;
case GGML_OP_SET_ROWS:
{
ggml_webgpu_set_rows(ctx, src0, src1, node);
break;
}
ggml_webgpu_set_rows(ctx, src0, src1, node);
break;
case GGML_OP_GET_ROWS:
ggml_webgpu_get_rows(ctx, src0, src1, node);
break;
case GGML_OP_MUL_MAT:
{
ggml_webgpu_mul_mat(ctx, src0, src1, node);
break;
ggml_webgpu_mul_mat(ctx, src0, src1, node);
break;
case GGML_OP_ADD:
if (ggml_webgpu_tensor_equal(src0, node)) {
ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->add_ip_pipeline[node->type], true);
} else {
ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->add_pipeline[node->type], false);
}
break;
case GGML_OP_MUL:
if (ggml_webgpu_tensor_equal(src0, node)) {
ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->mul_ip_pipeline[node->type], true);
} else {
ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->mul_pipeline[node->type], false);
}
break;
case GGML_OP_RMS_NORM:
ggml_webgpu_rms_norm(ctx, src0, node);
break;
default:
return false;
}
@@ -731,8 +889,8 @@ static void ggml_backend_webgpu_buffer_set_tensor(ggml_backend_buffer_t buffer,
((uint8_t *) &val32)[i] = ((const uint8_t *) data)[size - remaining_size + i];
}
// memset the remaining bytes
ggml_backend_webgpu_buffer_memset(
webgpu_ctx, buf_ctx->buffer, val32, total_offset + (size - remaining_size), remaining_size);
ggml_backend_webgpu_buffer_memset(webgpu_ctx, buf_ctx->buffer, val32, total_offset + (size - remaining_size),
remaining_size);
} else {
// wait for WriteBuffer to complete
ggml_backend_webgpu_wait_on_submission(webgpu_ctx);
@@ -766,11 +924,8 @@ static void ggml_backend_webgpu_buffer_get_tensor(ggml_backend_buffer_t buffer,
if (webgpu_ctx->get_tensor_staging_buf) {
webgpu_ctx->get_tensor_staging_buf.Destroy();
}
ggml_webgpu_create_buffer(device,
webgpu_ctx->get_tensor_staging_buf,
final_size,
wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead,
"get_tensor_staging_buf");
ggml_webgpu_create_buffer(device, webgpu_ctx->get_tensor_staging_buf, final_size,
wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead, "get_tensor_staging_buf");
}
// Copy the data from the buffer to the staging buffer
@@ -824,8 +979,7 @@ static ggml_backend_buffer_t ggml_backend_webgpu_buffer_type_alloc_buffer(ggml_b
ggml_backend_webgpu_device_context * ctx = static_cast<ggml_backend_webgpu_device_context *>(buft->device->context);
wgpu::Buffer buf;
ggml_webgpu_create_buffer(ctx->webgpu_ctx->device,
buf,
ggml_webgpu_create_buffer(ctx->webgpu_ctx->device, buf,
(size + WEBGPU_STORAGE_BUF_BINDING_MULT - 1) & ~(WEBGPU_STORAGE_BUF_BINDING_MULT - 1),
wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::CopyDst,
"allocated_buffer");
@@ -890,9 +1044,17 @@ static ggml_guid_t ggml_backend_webgpu_guid(void) {
return reinterpret_cast<ggml_guid_t>((void *) guid_str);
}
// The max workgroup size is a common constant
static std::vector<wgpu::ConstantEntry> ggml_webgpu_max_wg_size_entry(webgpu_context & webgpu_ctx) {
std::vector<wgpu::ConstantEntry> constants(1);
constants[0].key = "wg_size";
constants[0].value = webgpu_ctx->max_wg_size_x;
return constants;
}
static void ggml_webgpu_init_memset_pipeline(webgpu_context & webgpu_ctx) {
// we use the maximum workgroup size for the memset pipeline
size_t max_wg_size = webgpu_ctx->limits.maxComputeWorkgroupSizeX;
size_t max_wg_size = webgpu_ctx->max_wg_size_x;
size_t max_threads = max_wg_size * webgpu_ctx->limits.maxComputeWorkgroupsPerDimension;
// Size the bytes_per_thread so that the largest buffer size can be handled
webgpu_ctx->memset_bytes_per_thread =
@@ -906,109 +1068,142 @@ static void ggml_webgpu_init_memset_pipeline(webgpu_context & webgpu_ctx) {
}
static void ggml_webgpu_init_mul_mat_pipeline(webgpu_context & webgpu_ctx) {
ggml_webgpu_create_pipeline(webgpu_ctx->device,
webgpu_ctx->mul_mat_pipeline[GGML_TYPE_F32][GGML_TYPE_F32],
wgsl_mul_mat_f32_f32,
"mul_mat_f32_f32");
ggml_webgpu_create_pipeline(webgpu_ctx->device,
webgpu_ctx->mul_mat_pipeline[GGML_TYPE_F16][GGML_TYPE_F16],
wgsl_mul_mat_f16_f16,
"mul_mat_f16_f16");
ggml_webgpu_create_pipeline(webgpu_ctx->device,
webgpu_ctx->mul_mat_pipeline[GGML_TYPE_F16][GGML_TYPE_F32],
wgsl_mul_mat_f16_f32,
"mul_mat_f16_f32");
ggml_webgpu_create_pipeline(webgpu_ctx->device,
webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q4_0][GGML_TYPE_F32],
wgsl_mul_mat_q4_0_f32,
"mul_mat_q4_0_f32");
ggml_webgpu_create_pipeline(webgpu_ctx->device,
webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q4_1][GGML_TYPE_F32],
wgsl_mul_mat_q4_1_f32,
"mul_mat_q4_1_f32");
ggml_webgpu_create_pipeline(webgpu_ctx->device,
webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q5_0][GGML_TYPE_F32],
wgsl_mul_mat_q5_0_f32,
"mul_mat_q5_0_f32");
ggml_webgpu_create_pipeline(webgpu_ctx->device,
webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q5_1][GGML_TYPE_F32],
wgsl_mul_mat_q5_1_f32,
"mul_mat_q5_1_f32");
ggml_webgpu_create_pipeline(webgpu_ctx->device,
webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q8_0][GGML_TYPE_F32],
wgsl_mul_mat_q8_0_f32,
"mul_mat_q8_0_f32");
ggml_webgpu_create_pipeline(webgpu_ctx->device,
webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q2_K][GGML_TYPE_F32],
wgsl_mul_mat_q2_k_f32,
"mul_mat_q2_k_f32");
ggml_webgpu_create_pipeline(webgpu_ctx->device,
webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q3_K][GGML_TYPE_F32],
wgsl_mul_mat_q3_k_f32,
"mul_mat_q3_k_f32");
ggml_webgpu_create_pipeline(webgpu_ctx->device,
webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q4_K][GGML_TYPE_F32],
wgsl_mul_mat_q4_k_f32,
"mul_mat_q4_k_f32");
ggml_webgpu_create_pipeline(webgpu_ctx->device,
webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q5_K][GGML_TYPE_F32],
wgsl_mul_mat_q5_k_f32,
"mul_mat_q5_k_f32");
ggml_webgpu_create_pipeline(webgpu_ctx->device,
webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q6_K][GGML_TYPE_F32],
wgsl_mul_mat_q6_k_f32,
"mul_mat_q6_k_f32");
ggml_webgpu_create_pipeline(webgpu_ctx->device,
webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ2_XXS][GGML_TYPE_F32],
wgsl_mul_mat_iq2_xxs_f32,
"mul_mat_iq2_xxs_f32");
ggml_webgpu_create_pipeline(webgpu_ctx->device,
webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ2_XS][GGML_TYPE_F32],
wgsl_mul_mat_iq2_xs_f32,
"mul_mat_iq2_xs_f32");
ggml_webgpu_create_pipeline(webgpu_ctx->device,
webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ2_S][GGML_TYPE_F32],
wgsl_mul_mat_iq2_s_f32,
"mul_mat_iq2_s_f32");
ggml_webgpu_create_pipeline(webgpu_ctx->device,
webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ3_XXS][GGML_TYPE_F32],
wgsl_mul_mat_iq3_xxs_f32,
"mul_mat_iq3_xxs_f32");
ggml_webgpu_create_pipeline(webgpu_ctx->device,
webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ3_S][GGML_TYPE_F32],
wgsl_mul_mat_iq3_s_f32,
"mul_mat_iq3_s_f32");
ggml_webgpu_create_pipeline(webgpu_ctx->device,
webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ1_S][GGML_TYPE_F32],
wgsl_mul_mat_iq1_s_f32,
"mul_mat_iq1_s_f32");
ggml_webgpu_create_pipeline(webgpu_ctx->device,
webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ1_M][GGML_TYPE_F32],
wgsl_mul_mat_iq1_m_f32,
"mul_mat_iq1_m_f32");
ggml_webgpu_create_pipeline(webgpu_ctx->device,
webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ4_NL][GGML_TYPE_F32],
wgsl_mul_mat_iq4_nl_f32,
"mul_mat_iq4_nl_f32");
ggml_webgpu_create_pipeline(webgpu_ctx->device,
webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ4_XS][GGML_TYPE_F32],
wgsl_mul_mat_iq4_xs_f32,
"mul_mat_iq4_xs_f32");
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_F32][GGML_TYPE_F32],
wgsl_mul_mat_f32_f32, "mul_mat_f32_f32");
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_F16][GGML_TYPE_F16],
wgsl_mul_mat_f16_f16, "mul_mat_f16_f16");
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_F16][GGML_TYPE_F32],
wgsl_mul_mat_f16_f32, "mul_mat_f16_f32");
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q4_0][GGML_TYPE_F32],
wgsl_mul_mat_q4_0_f32, "mul_mat_q4_0_f32");
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q4_1][GGML_TYPE_F32],
wgsl_mul_mat_q4_1_f32, "mul_mat_q4_1_f32");
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q5_0][GGML_TYPE_F32],
wgsl_mul_mat_q5_0_f32, "mul_mat_q5_0_f32");
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q5_1][GGML_TYPE_F32],
wgsl_mul_mat_q5_1_f32, "mul_mat_q5_1_f32");
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q8_0][GGML_TYPE_F32],
wgsl_mul_mat_q8_0_f32, "mul_mat_q8_0_f32");
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q2_K][GGML_TYPE_F32],
wgsl_mul_mat_q2_k_f32, "mul_mat_q2_k_f32");
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q3_K][GGML_TYPE_F32],
wgsl_mul_mat_q3_k_f32, "mul_mat_q3_k_f32");
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q4_K][GGML_TYPE_F32],
wgsl_mul_mat_q4_k_f32, "mul_mat_q4_k_f32");
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q5_K][GGML_TYPE_F32],
wgsl_mul_mat_q5_k_f32, "mul_mat_q5_k_f32");
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q6_K][GGML_TYPE_F32],
wgsl_mul_mat_q6_k_f32, "mul_mat_q6_k_f32");
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ2_XXS][GGML_TYPE_F32],
wgsl_mul_mat_iq2_xxs_f32, "mul_mat_iq2_xxs_f32");
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ2_XS][GGML_TYPE_F32],
wgsl_mul_mat_iq2_xs_f32, "mul_mat_iq2_xs_f32");
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ2_S][GGML_TYPE_F32],
wgsl_mul_mat_iq2_s_f32, "mul_mat_iq2_s_f32");
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ3_XXS][GGML_TYPE_F32],
wgsl_mul_mat_iq3_xxs_f32, "mul_mat_iq3_xxs_f32");
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ3_S][GGML_TYPE_F32],
wgsl_mul_mat_iq3_s_f32, "mul_mat_iq3_s_f32");
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ1_S][GGML_TYPE_F32],
wgsl_mul_mat_iq1_s_f32, "mul_mat_iq1_s_f32");
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ1_M][GGML_TYPE_F32],
wgsl_mul_mat_iq1_m_f32, "mul_mat_iq1_m_f32");
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ4_NL][GGML_TYPE_F32],
wgsl_mul_mat_iq4_nl_f32, "mul_mat_iq4_nl_f32");
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ4_XS][GGML_TYPE_F32],
wgsl_mul_mat_iq4_xs_f32, "mul_mat_iq4_xs_f32");
}
static void ggml_webgpu_init_set_rows_pipeline(webgpu_context & webgpu_ctx) {
std::vector<wgpu::ConstantEntry> constants(1);
constants[0].key = "wg_size";
constants[0].value = webgpu_ctx->limits.maxComputeWorkgroupSizeX;
ggml_webgpu_create_pipeline(
webgpu_ctx->device, webgpu_ctx->set_rows_pipeline, wgsl_set_rows, "set_rows", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->set_rows_pipeline, wgsl_set_rows, "set_rows",
ggml_webgpu_max_wg_size_entry(webgpu_ctx));
}
static void ggml_webgpu_init_get_rows_pipeline(webgpu_context & webgpu_ctx) {
std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_max_wg_size_entry(webgpu_ctx);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_F32], wgsl_get_rows_f32_vec,
"get_rows_f32_vec", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_f32_no_vec_pipeline, wgsl_get_rows_f32,
"get_rows_f32", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_F16], wgsl_get_rows_f16,
"get_rows_f16", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_I32], wgsl_get_rows_i32,
"get_rows_i32", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_Q4_0], wgsl_get_rows_q4_0,
"get_rows_q4_0", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_Q4_1], wgsl_get_rows_q4_1,
"get_rows_q4_1", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_Q5_0], wgsl_get_rows_q5_0,
"get_rows_q5_0", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_Q5_1], wgsl_get_rows_q5_1,
"get_rows_q5_1", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_Q8_0], wgsl_get_rows_q8_0,
"get_rows_q8_0", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_Q2_K], wgsl_get_rows_q2_k,
"get_rows_q2_k", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_Q3_K], wgsl_get_rows_q3_k,
"get_rows_q3_k", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_Q4_K], wgsl_get_rows_q4_k,
"get_rows_q4_k", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_Q5_K], wgsl_get_rows_q5_k,
"get_rows_q5_k", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_Q6_K], wgsl_get_rows_q6_k,
"get_rows_q6_k", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_IQ2_XXS],
wgsl_get_rows_iq2_xxs, "get_rows_iq2_xxs", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_IQ2_XS],
wgsl_get_rows_iq2_xs, "get_rows_iq2_xs", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_IQ2_S], wgsl_get_rows_iq2_s,
"get_rows_iq2_s", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_IQ3_XXS],
wgsl_get_rows_iq3_xxs, "get_rows_iq3_xxs", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_IQ3_S], wgsl_get_rows_iq3_s,
"get_rows_iq3_s", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_IQ1_S], wgsl_get_rows_iq1_s,
"get_rows_iq1_s", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_IQ1_M], wgsl_get_rows_iq1_m,
"get_rows_iq1_m", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_IQ4_NL],
wgsl_get_rows_iq4_nl, "get_rows_iq4_nl", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_IQ4_XS],
wgsl_get_rows_iq4_xs, "get_rows_iq4_xs", constants);
}
static void ggml_webgpu_init_cpy_pipeline(webgpu_context & webgpu_ctx) {
std::vector<wgpu::ConstantEntry> constants(1);
constants[0].key = "wg_size";
constants[0].value = webgpu_ctx->limits.maxComputeWorkgroupSizeX;
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->cpy_pipeline, wgsl_cpy, "cpy", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->cpy_pipeline, wgsl_cpy, "cpy",
ggml_webgpu_max_wg_size_entry(webgpu_ctx));
}
static void ggml_webgpu_init_add_pipeline(webgpu_context & webgpu_ctx) {
std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_max_wg_size_entry(webgpu_ctx);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_pipeline[GGML_TYPE_F32], wgsl_add_f32, "add_f32",
constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_pipeline[GGML_TYPE_F16], wgsl_add_f16, "add_f16",
constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_ip_pipeline[GGML_TYPE_F32], wgsl_add_in_place_f32,
"add_in_place_f32", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_ip_pipeline[GGML_TYPE_F16], wgsl_add_in_place_f16,
"add_in_place_f16", constants);
}
static void ggml_webgpu_init_mul_pipeline(webgpu_context & webgpu_ctx) {
std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_max_wg_size_entry(webgpu_ctx);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_pipeline[GGML_TYPE_F32], wgsl_mul_f32, "mul_f32",
constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_pipeline[GGML_TYPE_F16], wgsl_mul_f16, "mul_f16",
constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_ip_pipeline[GGML_TYPE_F32], wgsl_mul_in_place_f32,
"mul_in_place_f32", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_ip_pipeline[GGML_TYPE_F16], wgsl_mul_in_place_f16,
"mul_in_place_f16", constants);
}
static void ggml_webgpu_init_rms_norm_pipeline(webgpu_context & webgpu_ctx) {
std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_max_wg_size_entry(webgpu_ctx);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->rms_norm_pipeline, wgsl_rms_norm, "rms_norm",
constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->rms_norm_ip_pipeline, wgsl_rms_norm_in_place,
"rms_norm_in_place", constants);
}
static ggml_backend_t ggml_backend_webgpu_device_init(ggml_backend_dev_t dev, const char * params) {
@@ -1058,24 +1253,77 @@ static bool ggml_backend_webgpu_device_supports_buft(ggml_backend_dev_t dev, ggm
return buft->iface.get_name == ggml_backend_webgpu_buffer_type_get_name;
}
static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
GGML_UNUSED(dev);
static bool ggml_webgpu_supported_qtype(ggml_type type) {
switch (type) {
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q5_0:
case GGML_TYPE_Q5_1:
case GGML_TYPE_Q8_0:
case GGML_TYPE_Q2_K:
case GGML_TYPE_Q3_K:
case GGML_TYPE_Q4_K:
case GGML_TYPE_Q5_K:
case GGML_TYPE_Q6_K:
case GGML_TYPE_IQ2_XXS:
case GGML_TYPE_IQ2_XS:
case GGML_TYPE_IQ2_S:
case GGML_TYPE_IQ3_XXS:
case GGML_TYPE_IQ3_S:
case GGML_TYPE_IQ1_S:
case GGML_TYPE_IQ1_M:
case GGML_TYPE_IQ4_NL:
case GGML_TYPE_IQ4_XS:
return true;
default:
return false;
}
}
static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
ggml_backend_webgpu_device_context * ctx = static_cast<ggml_backend_webgpu_device_context *>(dev->context);
webgpu_context webgpu_ctx = ctx->webgpu_ctx;
ggml_tensor * src0 = op->src[0];
ggml_tensor * src1 = op->src[1];
// on smaller devices (or CI), tensors may be larger than the max storage buffer size
if (ggml_nbytes(op) > webgpu_ctx->limits.maxStorageBufferBindingSize ||
(src0 != nullptr && ggml_nbytes(src0) > webgpu_ctx->limits.maxStorageBufferBindingSize) ||
(src1 != nullptr && ggml_nbytes(src1) > webgpu_ctx->limits.maxStorageBufferBindingSize)) {
return false;
}
bool supports_op = false;
switch (op->op) {
case GGML_OP_NONE:
case GGML_OP_VIEW:
case GGML_OP_PERMUTE:
case GGML_OP_TRANSPOSE:
case GGML_OP_RESHAPE:
return true;
supports_op = true;
break;
case GGML_OP_ADD:
case GGML_OP_MUL:
supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (op->src[0]->type == op->type) &&
(op->src[1]->type == op->type);
break;
case GGML_OP_CPY:
case GGML_OP_SET_ROWS:
return op->type == GGML_TYPE_F16 && op->src[0]->type == GGML_TYPE_F32;
supports_op = (op->type == GGML_TYPE_F16 && op->src[0]->type == GGML_TYPE_F32);
break;
case GGML_OP_GET_ROWS:
if (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16 ||
op->src[0]->type == GGML_TYPE_I32 || ggml_webgpu_supported_qtype(op->src[0]->type)) {
supports_op = (op->type == GGML_TYPE_F32);
}
break;
case GGML_OP_MUL_MAT:
{
switch (op->src[1]->type) {
case GGML_TYPE_F16:
return op->src[0]->type == GGML_TYPE_F16;
supports_op = (op->src[0]->type == GGML_TYPE_F16);
break;
case GGML_TYPE_F32:
switch (op->src[0]->type) {
case GGML_TYPE_F32:
@@ -1099,17 +1347,30 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
case GGML_TYPE_IQ1_M:
case GGML_TYPE_IQ4_NL:
case GGML_TYPE_IQ4_XS:
return true;
supports_op = true;
break;
default:
return false;
break;
}
default:
return false;
break;
}
break;
}
case GGML_OP_RMS_NORM:
supports_op = op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32;
break;
default:
return false;
break;
}
#ifdef GGML_WEBGPU_DEBUG
if (!supports_op) {
WEBGPU_LOG_DEBUG("not supported: " << ggml_op_name(op->op) << " with types dst: " << ggml_type_name(op->type)
<< ", src0: " << (op->src[0] ? ggml_type_name(op->src[0]->type) : "null")
<< ", src1: " << (op->src[1] ? ggml_type_name(op->src[1]->type) : "null"));
}
#endif
return supports_op;
}
static struct ggml_backend_device_i ggml_backend_webgpu_device_i = {
@@ -1155,18 +1416,20 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t
webgpu_context ctx = reg_ctx->webgpu_ctx;
wgpu::RequestAdapterOptions options = {};
ctx->instance.WaitAny(
ctx->instance.RequestAdapter(&options, wgpu::CallbackMode::AllowSpontaneous,
[&ctx](wgpu::RequestAdapterStatus status, wgpu::Adapter adapter, const char * message) {
if (status != wgpu::RequestAdapterStatus::Success) {
GGML_LOG_ERROR("ggml_webgpu: Failed to get an adapter: %s\n", message);
return;
}
ctx->adapter = std::move(adapter);
}), UINT64_MAX);
ctx->instance.WaitAny(ctx->instance.RequestAdapter(
&options, wgpu::CallbackMode::AllowSpontaneous,
[&ctx](wgpu::RequestAdapterStatus status, wgpu::Adapter adapter, const char * message) {
if (status != wgpu::RequestAdapterStatus::Success) {
GGML_LOG_ERROR("ggml_webgpu: Failed to get an adapter: %s\n", message);
return;
}
ctx->adapter = std::move(adapter);
}),
UINT64_MAX);
GGML_ASSERT(ctx->adapter != nullptr);
ctx->adapter.GetLimits(&ctx->limits);
ctx->max_wg_size_x = 288; // default value
wgpu::AdapterInfo info{};
ctx->adapter.GetInfo(&info);
@@ -1182,21 +1445,21 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t
wgpu::CallbackMode::AllowSpontaneous,
[](const wgpu::Device & device, wgpu::DeviceLostReason reason, wgpu::StringView message) {
GGML_UNUSED(device);
GGML_LOG_ERROR(
"ggml_webgpu: Device lost! Reason: %d, Message: %s\n", static_cast<int>(reason), std::string(message).c_str());
GGML_LOG_ERROR("ggml_webgpu: Device lost! Reason: %d, Message: %s\n", static_cast<int>(reason),
std::string(message).c_str());
});
dev_desc.SetUncapturedErrorCallback(
[](const wgpu::Device & device, wgpu::ErrorType reason, wgpu::StringView message) {
GGML_UNUSED(device);
GGML_LOG_ERROR(
"ggml_webgpu: Device error! Reason: %d, Message: %s\n", static_cast<int>(reason), std::string(message).c_str());
GGML_LOG_ERROR("ggml_webgpu: Device error! Reason: %d, Message: %s\n", static_cast<int>(reason),
std::string(message).c_str());
});
ctx->instance.WaitAny(ctx->adapter.RequestDevice(
&dev_desc,
wgpu::CallbackMode::AllowSpontaneous,
&dev_desc, wgpu::CallbackMode::AllowSpontaneous,
[ctx](wgpu::RequestDeviceStatus status, wgpu::Device device, wgpu::StringView message) {
if (status != wgpu::RequestDeviceStatus::Success) {
GGML_LOG_ERROR("ggml_webgpu: Failed to get a device: %s\n", std::string(message).c_str());
GGML_LOG_ERROR("ggml_webgpu: Failed to get a device: %s\n",
std::string(message).c_str());
return;
}
ctx->device = std::move(device);
@@ -1208,34 +1471,28 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t
ctx->queue = ctx->device.GetQueue();
// Create buffer pool for shader parameters
ctx->param_buf_pool.init(ctx->device,
WEBGPU_NUM_PARAM_BUFS,
WEBGPU_PARAMS_BUF_SIZE_BYTES,
ctx->param_buf_pool.init(ctx->device, WEBGPU_NUM_PARAM_BUFS, WEBGPU_PARAMS_BUF_SIZE_BYTES,
wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::Uniform,
wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::MapWrite);
ctx->set_rows_error_buf_pool.init(ctx->device,
WEBGPU_NUM_SET_ROWS_ERROR_BUFS,
WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES,
ctx->set_rows_error_buf_pool.init(ctx->device, WEBGPU_NUM_SET_ROWS_ERROR_BUFS, WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES,
wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::Storage,
wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead);
ggml_webgpu_init_memset_pipeline(ctx);
ggml_webgpu_init_mul_mat_pipeline(ctx);
ggml_webgpu_init_set_rows_pipeline(ctx);
ggml_webgpu_init_get_rows_pipeline(ctx);
ggml_webgpu_init_cpy_pipeline(ctx);
ggml_webgpu_init_add_pipeline(ctx);
ggml_webgpu_init_mul_pipeline(ctx);
ggml_webgpu_init_rms_norm_pipeline(ctx);
#ifdef GGML_WEBGPU_DEBUG
// Initialize debug buffers
ggml_webgpu_create_buffer(ctx->device,
ctx->debug_host_buf,
WEBGPU_DEBUG_BUF_ELEMS * sizeof(uint32_t),
wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead,
"debug_host_buf");
ggml_webgpu_create_buffer(ctx->device,
ctx->debug_dev_buf,
WEBGPU_DEBUG_BUF_ELEMS * sizeof(uint32_t),
wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc,
"debug_dev_buf");
ggml_webgpu_create_buffer(ctx->device, ctx->debug_host_buf, WEBGPU_DEBUG_BUF_ELEMS * sizeof(uint32_t),
wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead, "debug_host_buf");
ggml_webgpu_create_buffer(ctx->device, ctx->debug_dev_buf, WEBGPU_DEBUG_BUF_ELEMS * sizeof(uint32_t),
wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc, "debug_dev_buf");
#endif
static ggml_backend_webgpu_device_context device_ctx;
@@ -1246,12 +1503,8 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t
GGML_LOG_INFO(
"ggml_webgpu: adapter_info: vendor_id: %u | vendor: %s | architecture: %s | device_id: %u | name: %s | "
"device_desc: %s\n",
info.vendorID,
std::string(info.vendor).c_str(),
std::string(info.architecture).c_str(),
info.deviceID,
std::string(info.device).c_str(),
std::string(info.description).c_str());
info.vendorID, std::string(info.vendor).c_str(), std::string(info.architecture).c_str(), info.deviceID,
std::string(info.device).c_str(), std::string(info.description).c_str());
// See GGML Backend Device Interface section
static ggml_backend_device device = {