ggml webgpu: add support for soft_max, optimize rms_norm (#16357)

* Add inplace softmax

* Move rms_norm to split row approach

* Update debug for supports_op

* clean up debug statements

* Update tests/test-backend-ops.cpp

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
This commit is contained in:
Reese Levine
2025-10-02 11:00:31 -07:00
committed by GitHub
parent 34fcc5a4ac
commit ef07a40906
6 changed files with 566 additions and 48 deletions

View File

@@ -1630,6 +1630,13 @@ extern "C" {
float scale,
float max_bias);
GGML_API struct ggml_tensor * ggml_soft_max_ext_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * mask,
float scale,
float max_bias);
GGML_API void ggml_soft_max_add_sinks(
struct ggml_tensor * a,
struct ggml_tensor * sinks);

View File

@@ -28,6 +28,7 @@
/* Constants */
#define WEBGPU_COMMAND_SUBMIT_BATCH_SIZE 16
#define WEBGPU_WAIT_ANY_BATCH_SIZE 64
#define WEBGPU_MUL_MAT_WG_SIZE 64
#define WEBGPU_NUM_PARAM_BUFS 100
#define WEBGPU_PARAMS_BUF_SIZE_BYTES 128 // enough for 32 parameters
@@ -35,6 +36,9 @@
#define WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES 4
#define WEBGPU_STORAGE_BUF_BINDING_MULT 4 // a storage buffer binding size must be a multiple of 4
// For operations which process a row in parallel, this seems like a reasonable default
#define WEBGPU_ROW_SPLIT_WG_SIZE 64
/* End Constants */
// This is a "fake" base pointer, since WebGPU buffers do not have pointers to their locations.
@@ -130,15 +134,16 @@ struct webgpu_context_struct {
wgpu::ComputePipeline set_rows_pipeline;
wgpu::ComputePipeline get_rows_pipeline[30];
wgpu::ComputePipeline get_rows_f32_no_vec_pipeline;
wgpu::ComputePipeline cpy_pipeline[2][2]; // src type, dst type
wgpu::ComputePipeline add_pipeline[2][2]; // type, inplace
wgpu::ComputePipeline sub_pipeline[2][2]; // type, inplace
wgpu::ComputePipeline mul_pipeline[2][2]; // type, inplace
wgpu::ComputePipeline div_pipeline[2][2]; // type, inplace
wgpu::ComputePipeline rms_norm_pipeline[2]; // inplace
wgpu::ComputePipeline rope_pipeline[2][2][2]; // type, ff, inplace
wgpu::ComputePipeline glu_pipeline[7][2][2]; // glu-op, type, split
wgpu::ComputePipeline scale_pipeline[2]; // inplace
wgpu::ComputePipeline cpy_pipeline[2][2]; // src type, dst type
wgpu::ComputePipeline add_pipeline[2][2]; // type, inplace
wgpu::ComputePipeline sub_pipeline[2][2]; // type, inplace
wgpu::ComputePipeline mul_pipeline[2][2]; // type, inplace
wgpu::ComputePipeline div_pipeline[2][2]; // type, inplace
wgpu::ComputePipeline rms_norm_pipeline[2]; // inplace
wgpu::ComputePipeline rope_pipeline[2][2][2]; // type, ff, inplace
wgpu::ComputePipeline glu_pipeline[7][2][2]; // glu-op, type, split
wgpu::ComputePipeline scale_pipeline[2]; // inplace
wgpu::ComputePipeline soft_max_pipeline[3][2][2]; // (no_mask, f32_mask, f16_mask), has_sink, inplace
size_t memset_bytes_per_thread;
@@ -256,8 +261,12 @@ static void ggml_backend_webgpu_wait_on_submission(webgpu_context & ctx) {
}),
UINT64_MAX);
} else {
// existing callbacks, wait on them
ctx->instance.WaitAny(ctx->callback_futures.size(), ctx->callback_futures.data(), UINT64_MAX);
// WebGPU implementations may limit the number of futures that can be waited on at once,
// so wait in batches (64 is what Dawn supports).
for (size_t i = 0; i < ctx->callback_futures.size(); i += WEBGPU_WAIT_ANY_BATCH_SIZE) {
size_t end = std::min(i + WEBGPU_WAIT_ANY_BATCH_SIZE, ctx->callback_futures.size());
ctx->instance.WaitAny(end - i, ctx->callback_futures.data() + i, UINT64_MAX);
}
ctx->callback_futures.clear();
}
}
@@ -726,9 +735,7 @@ static void ggml_webgpu_rms_norm(webgpu_context & ctx, ggml_tensor * src, ggml_t
.size = ggml_webgpu_tensor_binding_size(ctx, dst) });
}
size_t max_wg_size = ctx->max_wg_size_x;
uint32_t wg_x = (src->ne[1] * src->ne[2] * src->ne[3] + max_wg_size - 1) / max_wg_size;
ggml_backend_webgpu_build_and_enqueue(ctx, ctx->rms_norm_pipeline[inplace], params, entries, wg_x,
ggml_backend_webgpu_build_and_enqueue(ctx, ctx->rms_norm_pipeline[inplace], params, entries, ggml_nrows(src),
ggml_op_name(dst->op));
}
@@ -912,6 +919,79 @@ static void ggml_webgpu_scale(webgpu_context & ctx, ggml_tensor * src, ggml_tens
ggml_op_name(dst->op));
}
static void ggml_webgpu_soft_max(webgpu_context & ctx,
ggml_tensor * src0,
ggml_tensor * src1,
ggml_tensor * src2,
ggml_tensor * dst) {
const int inplace = ggml_webgpu_tensor_equal(src0, dst);
const int mask_type = (src1 != nullptr) ? src1->type : 2; // use 2 for no mask here
const int has_sink = (src2 != nullptr);
float max_bias;
memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float));
float n_head_log2 = float(1u << (uint32_t) floor(log2(src0->ne[2])));
float m0 = powf(2.0f, -(max_bias) / n_head_log2);
float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
std::vector<uint32_t> params = {
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)),
mask_type < 2 ? (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)) : 0,
has_sink ? (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src2) / ggml_type_size(src2->type)) : 0,
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
(uint32_t) (src0->nb[1] / ggml_type_size(src0->type)),
(uint32_t) (src0->nb[2] / ggml_type_size(src0->type)),
(uint32_t) (src0->nb[3] / ggml_type_size(src0->type)),
mask_type < 2 ? (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)) : 0,
mask_type < 2 ? (uint32_t) (src1->nb[2] / ggml_type_size(src1->type)) : 0,
mask_type < 2 ? (uint32_t) (src1->nb[3] / ggml_type_size(src1->type)) : 0,
(uint32_t) (dst->nb[1] / ggml_type_size(dst->type)),
(uint32_t) (dst->nb[2] / ggml_type_size(dst->type)),
(uint32_t) (dst->nb[3] / ggml_type_size(dst->type)),
(uint32_t) ggml_nelements(dst),
(uint32_t) src0->ne[0],
(uint32_t) src0->ne[1],
(uint32_t) src0->ne[2],
mask_type < 2 ? (uint32_t) src1->ne[2] : 0,
mask_type < 2 ? (uint32_t) src1->ne[3] : 0,
*(uint32_t *) dst->op_params, // scale
*(uint32_t *) &max_bias,
*(uint32_t *) &n_head_log2,
*(uint32_t *) &m0,
*(uint32_t *) &m1
};
std::vector<wgpu::BindGroupEntry> entries = {
{ .binding = 0,
.buffer = ggml_webgpu_tensor_buf(src0),
.offset = ggml_webgpu_tensor_align_offset(ctx, src0),
.size = ggml_webgpu_tensor_binding_size(ctx, src0) }
};
uint32_t binding_num = 1;
if (mask_type < 2) {
entries.push_back({ .binding = binding_num,
.buffer = ggml_webgpu_tensor_buf(src1),
.offset = ggml_webgpu_tensor_align_offset(ctx, src1),
.size = ggml_webgpu_tensor_binding_size(ctx, src1) });
binding_num++;
}
if (has_sink) {
entries.push_back({ .binding = binding_num,
.buffer = ggml_webgpu_tensor_buf(src2),
.offset = ggml_webgpu_tensor_align_offset(ctx, src2),
.size = ggml_webgpu_tensor_binding_size(ctx, src2) });
binding_num++;
}
if (!inplace) {
entries.push_back({ .binding = binding_num,
.buffer = ggml_webgpu_tensor_buf(dst),
.offset = ggml_webgpu_tensor_align_offset(ctx, dst),
.size = ggml_webgpu_tensor_binding_size(ctx, dst) });
}
ggml_backend_webgpu_build_and_enqueue(ctx, ctx->soft_max_pipeline[mask_type][has_sink][inplace], params, entries,
ggml_nrows(dst), ggml_op_name(dst->op));
}
// Returns true if node has enqueued work into the queue, false otherwise
static bool ggml_webgpu_encode_node(webgpu_context ctx, ggml_tensor * node) {
if (ggml_is_empty(node)) {
@@ -1237,11 +1317,11 @@ static ggml_guid_t ggml_backend_webgpu_guid(void) {
return reinterpret_cast<ggml_guid_t>((void *) guid_str);
}
// The max workgroup size is a common constant
static std::vector<wgpu::ConstantEntry> ggml_webgpu_max_wg_size_entry(webgpu_context & webgpu_ctx) {
// Workgroup size is a common constant
static std::vector<wgpu::ConstantEntry> ggml_webgpu_wg_size_entry(uint32_t wg_size) {
std::vector<wgpu::ConstantEntry> constants(1);
constants[0].key = "wg_size";
constants[0].value = webgpu_ctx->max_wg_size_x;
constants[0].value = wg_size;
return constants;
}
@@ -1309,11 +1389,11 @@ static void ggml_webgpu_init_mul_mat_pipeline(webgpu_context & webgpu_ctx) {
static void ggml_webgpu_init_set_rows_pipeline(webgpu_context & webgpu_ctx) {
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->set_rows_pipeline, wgsl_set_rows, "set_rows",
ggml_webgpu_max_wg_size_entry(webgpu_ctx));
ggml_webgpu_wg_size_entry(webgpu_ctx->max_wg_size_x));
}
static void ggml_webgpu_init_get_rows_pipeline(webgpu_context & webgpu_ctx) {
std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_max_wg_size_entry(webgpu_ctx);
std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(webgpu_ctx->max_wg_size_x);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_F32], wgsl_get_rows_f32_vec,
"get_rows_f32_vec", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_f32_no_vec_pipeline, wgsl_get_rows_f32,
@@ -1363,7 +1443,7 @@ static void ggml_webgpu_init_get_rows_pipeline(webgpu_context & webgpu_ctx) {
}
static void ggml_webgpu_init_cpy_pipeline(webgpu_context & webgpu_ctx) {
std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_max_wg_size_entry(webgpu_ctx);
std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(webgpu_ctx->max_wg_size_x);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->cpy_pipeline[GGML_TYPE_F32][GGML_TYPE_F32],
wgsl_cpy_f32_f32, "cpy_f32_f32", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->cpy_pipeline[GGML_TYPE_F32][GGML_TYPE_F16],
@@ -1375,7 +1455,7 @@ static void ggml_webgpu_init_cpy_pipeline(webgpu_context & webgpu_ctx) {
}
static void ggml_webgpu_init_add_pipeline(webgpu_context & webgpu_ctx) {
std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_max_wg_size_entry(webgpu_ctx);
std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(webgpu_ctx->max_wg_size_x);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_pipeline[GGML_TYPE_F32][0], wgsl_add_f32, "add_f32",
constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_pipeline[GGML_TYPE_F16][0], wgsl_add_f16, "add_f16",
@@ -1387,7 +1467,7 @@ static void ggml_webgpu_init_add_pipeline(webgpu_context & webgpu_ctx) {
}
static void ggml_webgpu_init_sub_pipeline(webgpu_context & webgpu_ctx) {
std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_max_wg_size_entry(webgpu_ctx);
std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(webgpu_ctx->max_wg_size_x);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->sub_pipeline[GGML_TYPE_F32][0], wgsl_sub_f32, "sub_f32",
constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->sub_pipeline[GGML_TYPE_F16][0], wgsl_sub_f16, "sub_f16",
@@ -1399,7 +1479,7 @@ static void ggml_webgpu_init_sub_pipeline(webgpu_context & webgpu_ctx) {
}
static void ggml_webgpu_init_mul_pipeline(webgpu_context & webgpu_ctx) {
std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_max_wg_size_entry(webgpu_ctx);
std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(webgpu_ctx->max_wg_size_x);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_pipeline[GGML_TYPE_F32][0], wgsl_mul_f32, "mul_f32",
constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_pipeline[GGML_TYPE_F16][0], wgsl_mul_f16, "mul_f16",
@@ -1411,7 +1491,7 @@ static void ggml_webgpu_init_mul_pipeline(webgpu_context & webgpu_ctx) {
}
static void ggml_webgpu_init_div_pipeline(webgpu_context & webgpu_ctx) {
std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_max_wg_size_entry(webgpu_ctx);
std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(webgpu_ctx->max_wg_size_x);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->div_pipeline[GGML_TYPE_F32][0], wgsl_div_f32, "div_f32",
constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->div_pipeline[GGML_TYPE_F16][0], wgsl_div_f16, "div_f16",
@@ -1423,7 +1503,7 @@ static void ggml_webgpu_init_div_pipeline(webgpu_context & webgpu_ctx) {
}
static void ggml_webgpu_init_rms_norm_pipeline(webgpu_context & webgpu_ctx) {
std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_max_wg_size_entry(webgpu_ctx);
std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_ROW_SPLIT_WG_SIZE);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->rms_norm_pipeline[0], wgsl_rms_norm, "rms_norm",
constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->rms_norm_pipeline[1], wgsl_rms_norm_inplace,
@@ -1431,7 +1511,7 @@ static void ggml_webgpu_init_rms_norm_pipeline(webgpu_context & webgpu_ctx) {
}
static void ggml_webgpu_init_rope_pipeline(webgpu_context & webgpu_ctx) {
std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_max_wg_size_entry(webgpu_ctx);
std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(webgpu_ctx->max_wg_size_x);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->rope_pipeline[GGML_TYPE_F32][0][0], wgsl_rope_f32,
"rope_f32", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->rope_pipeline[GGML_TYPE_F32][0][1],
@@ -1451,7 +1531,7 @@ static void ggml_webgpu_init_rope_pipeline(webgpu_context & webgpu_ctx) {
}
static void ggml_webgpu_init_glu_pipeline(webgpu_context & webgpu_ctx) {
std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_max_wg_size_entry(webgpu_ctx);
std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(webgpu_ctx->max_wg_size_x);
// reglu
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_REGLU][GGML_TYPE_F32][0],
wgsl_reglu_f32, "reglu_f32", constants);
@@ -1505,13 +1585,43 @@ static void ggml_webgpu_init_glu_pipeline(webgpu_context & webgpu_ctx) {
}
static void ggml_webgpu_init_scale_pipeline(webgpu_context & webgpu_ctx) {
std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_max_wg_size_entry(webgpu_ctx);
std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(webgpu_ctx->max_wg_size_x);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->scale_pipeline[0], wgsl_scale_f32, "scale_f32",
constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->scale_pipeline[1], wgsl_scale_f32_inplace,
"scale_f32_inplace", constants);
}
static void ggml_webgpu_init_soft_max_pipeline(webgpu_context & webgpu_ctx) {
std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_ROW_SPLIT_WG_SIZE);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->soft_max_pipeline[2][0][0], wgsl_soft_max_f32,
"soft_max_f32", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->soft_max_pipeline[2][0][1], wgsl_soft_max_f32_inplace,
"soft_max_f32_inplace", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->soft_max_pipeline[2][1][0], wgsl_soft_max_f32_sink,
"soft_max_f32_sink", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->soft_max_pipeline[2][1][1],
wgsl_soft_max_f32_sink_inplace, "soft_max_f32_sink_inplace", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->soft_max_pipeline[0][0][0], wgsl_soft_max_f32_mask_f32,
"soft_max_f32_mask_f32", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->soft_max_pipeline[0][0][1],
wgsl_soft_max_f32_mask_f32_inplace, "soft_max_f32_mask_f32_inplace", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->soft_max_pipeline[1][0][0], wgsl_soft_max_f32_mask_f16,
"soft_max_f32_mask_f16", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->soft_max_pipeline[1][0][1],
wgsl_soft_max_f32_mask_f16_inplace, "soft_max_f32_mask_f16_inplace", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->soft_max_pipeline[0][1][0],
wgsl_soft_max_f32_mask_f32_sink, "soft_max_f32_mask_f32_sink", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->soft_max_pipeline[0][1][1],
wgsl_soft_max_f32_mask_f32_sink_inplace, "soft_max_f32_mask_f32_sink_inplace",
constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->soft_max_pipeline[1][1][0],
wgsl_soft_max_f32_mask_f16_sink, "soft_max_f32_mask_f16_sink", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->soft_max_pipeline[1][1][1],
wgsl_soft_max_f32_mask_f16_sink_inplace, "soft_max_f32_mask_f16_sink_inplace",
constants);
}
static ggml_backend_t ggml_backend_webgpu_device_init(ggml_backend_dev_t dev, const char * params) {
GGML_UNUSED(params);
@@ -1593,6 +1703,7 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
ggml_tensor * src0 = op->src[0];
ggml_tensor * src1 = op->src[1];
ggml_tensor * src2 = op->src[2];
// on smaller devices (or CI), tensors may be larger than the max storage buffer size
if (ggml_nbytes(op) > webgpu_ctx->limits.maxStorageBufferBindingSize ||
@@ -1623,7 +1734,7 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
break;
case GGML_OP_SET_ROWS:
supports_op = (op->type == GGML_TYPE_F16 && op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_I64);
supports_op = (op->type == GGML_TYPE_F16 && src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_I64);
break;
case GGML_OP_GET_ROWS:
if (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_I32 ||
@@ -1698,13 +1809,25 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
default:
break;
}
#ifdef GGML_WEBGPU_DEBUG
if (!supports_op) {
WEBGPU_LOG_DEBUG("not supported: " << ggml_op_name(op->op) << " with types dst: " << ggml_type_name(op->type)
<< ", src0: " << (op->src[0] ? ggml_type_name(op->src[0]->type) : "null")
<< ", src1: " << (op->src[1] ? ggml_type_name(op->src[1]->type) : "null"));
if (ggml_nbytes(op) > webgpu_ctx->limits.maxStorageBufferBindingSize ||
(src0 != nullptr && ggml_nbytes(src0) > webgpu_ctx->limits.maxStorageBufferBindingSize) ||
(src1 != nullptr && ggml_nbytes(src1) > webgpu_ctx->limits.maxStorageBufferBindingSize) ||
(src2 != nullptr && ggml_nbytes(src2) > webgpu_ctx->limits.maxStorageBufferBindingSize)) {
supports_op = false;
WEBGPU_LOG_DEBUG("ggml_webgpu op not supported due to size: ");
}
if (!supports_op) {
WEBGPU_LOG_DEBUG("ggml_webgpu op not supported: "
<< ggml_op_name(op->op) << " with types dst: " << ggml_type_name(op->type)
<< ", src0: " << (op->src[0] ? ggml_type_name(op->src[0]->type) : "null")
<< ", src1: " << (op->src[1] ? ggml_type_name(op->src[1]->type) : "null"));
} else {
WEBGPU_LOG_DEBUG("ggml_webgpu op supported: "
<< ggml_op_name(op->op) << " with types dst: " << ggml_type_name(op->type)
<< ", src0: " << (op->src[0] ? ggml_type_name(op->src[0]->type) : "null")
<< ", src1: " << (op->src[1] ? ggml_type_name(op->src[1]->type) : "null"));
}
#endif
return supports_op;
}

View File

@@ -71,14 +71,14 @@ var<storage, read_write> src: array<f32>;
DECLS
override wg_size: u32;
var<workgroup> scratch: array<f32, wg_size>;
@compute @workgroup_size(wg_size)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
if (gid.x >= params.ne1 * params.ne2 * params.ne3) {
return;
}
fn main(@builtin(workgroup_id) wid: vec3<u32>,
@builtin(local_invocation_id) lid: vec3<u32>) {
// one thread per row
var i = gid.x;
var i = wid.x;
let i3 = i / (params.ne2 * params.ne1);
i = i % (params.ne2 * params.ne1);
let i2 = i / params.ne1;
@@ -86,13 +86,38 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
let i_src_row = params.offset_src + i3 * params.stride_src3 + i2 * params.stride_src2 + i1 * params.stride_src1;
let i_dst_row = params.offset_src + i3 * params.stride_dst3 + i2 * params.stride_dst2 + i1 * params.stride_dst1;
let elems = (params.ne0 + wg_size - 1) / wg_size;
var sum = 0.0f;
for (var j: u32 = 0; j < params.ne0; j++) {
sum += src[i_src_row + j] * src[i_src_row + j];
var col = lid.x;
for (var j: u32 = 0; j < elems; j++) {
if (col >= params.ne0) {
break;
}
sum += pow(src[i_src_row + col], 2.0);
col += wg_size;
}
scratch[lid.x] = sum;
workgroupBarrier();
var offset = wg_size / 2;
while (offset > 0) {
if (lid.x < offset) {
scratch[lid.x] += scratch[lid.x + offset];
}
offset = offset / 2;
workgroupBarrier();
}
sum = scratch[0];
let scale = 1.0/sqrt(sum/f32(params.ne0) + params.eps);
for (var j: u32 = 0; j < params.ne0; j++) {
update(i_src_row + j, i_dst_row + j, scale);
col = lid.x;
for (var j: u32 = 0; j < elems; j++) {
if (col >= params.ne0) {
break;
}
update(i_src_row + col, i_dst_row + col, scale);
col += wg_size;
}
}
#end(SHADER)

View File

@@ -0,0 +1,344 @@
#define(VARIANTS)
[
{
"SHADER_NAME": "soft_max_f32",
"DECLS": ["BASE_BINDINGS", "NOT_INPLACE", "NO_MASK", "NO_SINK"]
},
{
"SHADER_NAME": "soft_max_f32_inplace",
"DECLS": ["BASE_BINDINGS_INPLACE", "INPLACE", "NO_MASK", "NO_SINK"]
},
{
"SHADER_NAME": "soft_max_f32_sink",
"DECLS": ["SINK_BINDINGS", "NOT_INPLACE", "NO_MASK", "SINK"]
},
{
"SHADER_NAME": "soft_max_f32_sink_inplace",
"DECLS": ["SINK_BINDINGS_INPLACE", "INPLACE", "NO_MASK", "SINK"]
},
{
"SHADER_NAME": "soft_max_f32_mask_f32",
"REPLS": {
"MASK_TYPE" : "f32",
},
"DECLS": ["MASK_BINDINGS", "NOT_INPLACE", "MASK", "NO_SINK"]
},
{
"SHADER_NAME": "soft_max_f32_mask_f32_inplace",
"REPLS": {
"MASK_TYPE" : "f32",
},
"DECLS": ["MASK_BINDINGS_INPLACE", "INPLACE", "MASK", "NO_SINK"]
},
{
"SHADER_NAME": "soft_max_f32_mask_f16",
"REPLS": {
"MASK_TYPE" : "f16",
},
"DECLS": ["MASK_BINDINGS", "NOT_INPLACE", "MASK", "NO_SINK"]
},
{
"SHADER_NAME": "soft_max_f32_mask_f16_inplace",
"REPLS": {
"MASK_TYPE" : "f16",
},
"DECLS": ["MASK_BINDINGS_INPLACE", "INPLACE", "MASK", "NO_SINK"]
},
{
"SHADER_NAME": "soft_max_f32_mask_f32_sink",
"REPLS": {
"MASK_TYPE" : "f32",
},
"DECLS": ["MASK_SINK_BINDINGS", "NOT_INPLACE", "MASK", "SINK"]
},
{
"SHADER_NAME": "soft_max_f32_mask_f32_sink_inplace",
"REPLS": {
"MASK_TYPE" : "f32",
},
"DECLS": ["MASK_SINK_BINDINGS_INPLACE", "INPLACE", "MASK", "SINK"]
},
{
"SHADER_NAME": "soft_max_f32_mask_f16_sink",
"REPLS": {
"MASK_TYPE" : "f16",
},
"DECLS": ["MASK_SINK_BINDINGS", "NOT_INPLACE", "MASK", "SINK"]
},
{
"SHADER_NAME": "soft_max_f32_mask_f16_sink_inplace",
"REPLS": {
"MASK_TYPE" : "f16",
},
"DECLS": ["MASK_SINK_BINDINGS_INPLACE", "INPLACE", "MASK", "SINK"]
}
]
#end(VARIANTS)
#define(DECLS)
#decl(BASE_BINDINGS)
@group(0) @binding(1)
var<storage, read_write> dst: array<f32>;
@group(0) @binding(2)
var<uniform> params: Params;
#enddecl(BASE_BINDINGS)
#decl(BASE_BINDINGS_INPLACE)
@group(0) @binding(1)
var<uniform> params: Params;
#enddecl(BASE_BINDINGS_INPLACE)
#decl(SINK_BINDINGS)
@group(0) @binding(1)
var<storage, read_write> sinks: array<f32>;
@group(0) @binding(2)
var<storage, read_write> dst: array<f32>;
@group(0) @binding(3)
var<uniform> params: Params;
#enddecl(SINK_BINDINGS)
#decl(SINK_BINDINGS_INPLACE)
@group(0) @binding(1)
var<storage, read_write> sinks: array<f32>;
@group(0) @binding(2)
var<uniform> params: Params;
#enddecl(SINK_BINDINGS_INPLACE)
#decl(MASK_BINDINGS)
@group(0) @binding(1)
var<storage, read_write> mask: array<{{MASK_TYPE}}>;
@group(0) @binding(2)
var<storage, read_write> dst: array<f32>;
@group(0) @binding(3)
var<uniform> params: Params;
#enddecl(MASK_BINDINGS)
#decl(MASK_BINDINGS_INPLACE)
@group(0) @binding(1)
var<storage, read_write> mask: array<{{MASK_TYPE}}>;
@group(0) @binding(2)
var<uniform> params: Params;
#enddecl(MASK_BINDINGS_INPLACE)
#decl(MASK_SINK_BINDINGS)
@group(0) @binding(1)
var<storage, read_write> mask: array<{{MASK_TYPE}}>;
@group(0) @binding(2)
var<storage, read_write> sinks: array<f32>;
@group(0) @binding(3)
var<storage, read_write> dst: array<f32>;
@group(0) @binding(4)
var<uniform> params: Params;
#enddecl(MASK_SINK_BINDINGS)
#decl(MASK_SINK_BINDINGS_INPLACE)
@group(0) @binding(1)
var<storage, read_write> mask: array<{{MASK_TYPE}}>;
@group(0) @binding(2)
var<storage, read_write> sinks: array<f32>;
@group(0) @binding(3)
var<uniform> params: Params;
#enddecl(MASK_SINK_BINDINGS_INPLACE)
#decl(NOT_INPLACE)
fn inter_value(i: u32) -> f32 {
return dst[i];
}
fn update(i: u32, val: f32) {
dst[i] = val;
}
#enddecl(NOT_INPLACE)
#decl(INPLACE)
fn inter_value(i: u32) -> f32 {
return src[i];
}
fn update(i: u32, val: f32) {
src[i] = val;
}
#enddecl(INPLACE)
#decl(NO_MASK)
fn mask_val(i: u32) -> f32 {
return 0.0;
}
#enddecl(NO_MASK)
#decl(MASK)
fn mask_val(i: u32) -> f32 {
return f32(mask[i]);
}
#enddecl(MASK)
#decl(NO_SINK)
fn lower_max_bound(i2: u32) -> f32 {
return -1e30;
}
fn add_sinks(val: f32, i2: u32, max_val: f32) -> f32 {
return val;
}
#enddecl(NO_SINK)
#decl(SINK)
fn lower_max_bound(i2: u32) -> f32 {
return sinks[params.offset_sinks + i2];
}
fn add_sinks(val: f32, i2: u32, max_val: f32) -> f32 {
return val + exp(sinks[params.offset_sinks + i2] - max_val);
}
#enddecl(SINK)
#end(DECLS)
#define(SHADER)
enable f16;
struct Params {
offset_src0: u32,
offset_src1: u32,
offset_sinks: u32,
offset_dst: u32,
// Strides (in elements)
stride_src01: u32,
stride_src02: u32,
stride_src03: u32,
stride_src11: u32,
stride_src12: u32,
stride_src13: u32,
stride_dst1: u32,
stride_dst2: u32,
stride_dst3: u32,
// shape of src0/dst
ne: u32,
ne0: u32,
ne1: u32,
ne2: u32,
// shape of src1
ne12: u32,
ne13: u32,
scale: f32,
max_bias: f32,
n_head_log2: f32,
m0: f32,
m1: f32,
};
@group(0) @binding(0)
var<storage, read_write> src: array<f32>;
DECLS
const CACHE_SIZE: u32 = 16;
override wg_size: u32;
var<workgroup> scratch: array<f32, wg_size>;
@compute @workgroup_size(wg_size)
fn main(@builtin(workgroup_id) wid: vec3<u32>,
@builtin(local_invocation_id) lid: vec3<u32>) {
var i = wid.x;
let i3 = i / (params.ne2 * params.ne1);
i = i % (params.ne2 * params.ne1);
let i2 = i / params.ne1;
let i1 = i % params.ne1;
let i_src0_row = params.offset_src0 + i3 * params.stride_src03 + i2 * params.stride_src02 + i1 * params.stride_src01;
let i_src1_row = params.offset_src1 + (i3 % params.ne13) * params.stride_src13 + (i2 % params.ne12) * params.stride_src12 + i1 * params.stride_src11;
let i_dst_row = params.offset_dst + i3 * params.stride_dst3 + i2 * params.stride_dst2 + i1 * params.stride_dst1;
let elems = (params.ne0 + wg_size - 1) / wg_size;
let head = f32(i2);
let slope = select(1, select(pow(params.m1, 2 * (head - params.n_head_log2) + 1), pow(params.m0, head + 1), head < params.n_head_log2), params.max_bias > 0);
var cache: array<f32, CACHE_SIZE>;
var max_val = lower_max_bound(i2);
var col = lid.x;
for (var j: u32 = 0; j < elems; j++) {
if (col >= params.ne0) {
break;
}
let val = src[i_src0_row + col] * params.scale + slope * mask_val(i_src1_row + col);
max_val = max(max_val, val);
if (col < CACHE_SIZE) {
cache[col] = val;
}
col += wg_size;
}
scratch[lid.x] = max_val;
workgroupBarrier();
var offset = wg_size / 2;
while (offset > 0) {
if (lid.x < offset) {
scratch[lid.x] = max(scratch[lid.x], scratch[lid.x + offset]);
}
offset = offset / 2;
workgroupBarrier();
}
let row_max = scratch[0];
var sum = 0.0f;
col = lid.x;
for (var j: u32 = 0; j < elems; j++) {
if (col >= params.ne0) {
break;
}
let val = select(src[i_src0_row + col] * params.scale + slope * mask_val(i_src1_row + col),
cache[col], col < CACHE_SIZE);
let ex = exp(val - row_max);
sum += ex;
if (col < CACHE_SIZE) {
cache[col] = ex;
} else {
update(i_dst_row + col, ex);
}
col += wg_size;
}
scratch[lid.x] = sum;
workgroupBarrier();
offset = wg_size / 2;
while (offset > 0) {
if (lid.x < offset) {
scratch[lid.x] += scratch[lid.x + offset];
}
offset = offset / 2;
workgroupBarrier();
}
let row_sum = add_sinks(scratch[0], i2, row_max);
let sum_recip = 1.0 / row_sum;
col = lid.x;
for (var j: u32 = 0; j < elems; j++) {
if (col >= params.ne0) {
break;
}
update(i_dst_row + col, select(inter_value(i_dst_row + col), cache[col], col < CACHE_SIZE) * sum_recip);
col += wg_size;
}
}
#end(SHADER)

View File

@@ -3852,6 +3852,15 @@ struct ggml_tensor * ggml_soft_max_ext(
return ggml_soft_max_impl(ctx, a, mask, scale, max_bias, false);
}
struct ggml_tensor * ggml_soft_max_ext_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * mask,
float scale,
float max_bias) {
return ggml_soft_max_impl(ctx, a, mask, scale, max_bias, true);
}
void ggml_soft_max_add_sinks(
struct ggml_tensor * a,
struct ggml_tensor * sinks) {