Mirror of https://github.com/ggml-org/llama.cpp.git
ggml webgpu: add support for soft_max, optimize rms_norm (#16357)
* Add inplace softmax
* Move rms_norm to split row approach
* Update debug for supports_op
* clean up debug statements
* Update tests/test-backend-ops.cpp

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
ggml/include/ggml.h
@@ -1630,6 +1630,13 @@ extern "C" {
                             float                 scale,
                             float                 max_bias);

+    GGML_API struct ggml_tensor * ggml_soft_max_ext_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * mask,
+            float                 scale,
+            float                 max_bias);
+
     GGML_API void ggml_soft_max_add_sinks(
             struct ggml_tensor * a,
             struct ggml_tensor * sinks);
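For context, a minimal sketch (not part of this patch) of how the new inplace entry point slots into the existing ggml API. The sizes and the surrounding setup are made up, and only graph construction is shown; no backend is attached or run.

    // C++ sketch using the public ggml API; tensor shapes are arbitrary.
    #include "ggml.h"

    int main() {
        ggml_init_params ip = {};
        ip.mem_size = 16 * 1024 * 1024;   // scratch arena for tensors and the graph
        ggml_context * ctx = ggml_init(ip);

        // 32 rows of 128 logits each.
        ggml_tensor * logits = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 128, 32);

        // Inplace variant added by this patch: the result aliases the storage of
        // `logits` instead of allocating a separate destination tensor.
        ggml_tensor * probs = ggml_soft_max_ext_inplace(ctx, logits, /*mask=*/NULL,
                                                        /*scale=*/0.125f, /*max_bias=*/0.0f);

        ggml_cgraph * gf = ggml_new_graph(ctx);
        ggml_build_forward_expand(gf, probs);

        ggml_free(ctx);
        return 0;
    }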
ggml/src/ggml-webgpu/ggml-webgpu.cpp
@@ -28,6 +28,7 @@
 /* Constants */

 #define WEBGPU_COMMAND_SUBMIT_BATCH_SIZE 16
+#define WEBGPU_WAIT_ANY_BATCH_SIZE 64
 #define WEBGPU_MUL_MAT_WG_SIZE 64
 #define WEBGPU_NUM_PARAM_BUFS 100
 #define WEBGPU_PARAMS_BUF_SIZE_BYTES 128  // enough for 32 parameters
@@ -35,6 +36,9 @@
 #define WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES 4
 #define WEBGPU_STORAGE_BUF_BINDING_MULT 4  // a storage buffer binding size must be a multiple of 4

+// For operations which process a row in parallel, this seems like a reasonable default
+#define WEBGPU_ROW_SPLIT_WG_SIZE 64
+
 /* End Constants */

 // This is a "fake" base pointer, since WebGPU buffers do not have pointers to their locations.
@@ -130,15 +134,16 @@ struct webgpu_context_struct {
     wgpu::ComputePipeline set_rows_pipeline;
     wgpu::ComputePipeline get_rows_pipeline[30];
     wgpu::ComputePipeline get_rows_f32_no_vec_pipeline;
     wgpu::ComputePipeline cpy_pipeline[2][2];      // src type, dst type
     wgpu::ComputePipeline add_pipeline[2][2];      // type, inplace
     wgpu::ComputePipeline sub_pipeline[2][2];      // type, inplace
     wgpu::ComputePipeline mul_pipeline[2][2];      // type, inplace
     wgpu::ComputePipeline div_pipeline[2][2];      // type, inplace
     wgpu::ComputePipeline rms_norm_pipeline[2];    // inplace
     wgpu::ComputePipeline rope_pipeline[2][2][2];  // type, ff, inplace
     wgpu::ComputePipeline glu_pipeline[7][2][2];   // glu-op, type, split
     wgpu::ComputePipeline scale_pipeline[2];       // inplace
+    wgpu::ComputePipeline soft_max_pipeline[3][2][2];  // (no_mask, f32_mask, f16_mask), has_sink, inplace

     size_t memset_bytes_per_thread;

@@ -256,8 +261,12 @@ static void ggml_backend_webgpu_wait_on_submission(webgpu_context & ctx) {
             }),
             UINT64_MAX);
     } else {
-        // existing callbacks, wait on them
-        ctx->instance.WaitAny(ctx->callback_futures.size(), ctx->callback_futures.data(), UINT64_MAX);
+        // WebGPU implementations may limit the number of futures that can be waited on at once,
+        // so wait in batches (64 is what Dawn supports).
+        for (size_t i = 0; i < ctx->callback_futures.size(); i += WEBGPU_WAIT_ANY_BATCH_SIZE) {
+            size_t end = std::min(i + WEBGPU_WAIT_ANY_BATCH_SIZE, ctx->callback_futures.size());
+            ctx->instance.WaitAny(end - i, ctx->callback_futures.data() + i, UINT64_MAX);
+        }
         ctx->callback_futures.clear();
     }
 }
@@ -726,9 +735,7 @@ static void ggml_webgpu_rms_norm(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
                             .size    = ggml_webgpu_tensor_binding_size(ctx, dst) });
     }

-    size_t   max_wg_size = ctx->max_wg_size_x;
-    uint32_t wg_x        = (src->ne[1] * src->ne[2] * src->ne[3] + max_wg_size - 1) / max_wg_size;
-    ggml_backend_webgpu_build_and_enqueue(ctx, ctx->rms_norm_pipeline[inplace], params, entries, wg_x,
+    ggml_backend_webgpu_build_and_enqueue(ctx, ctx->rms_norm_pipeline[inplace], params, entries, ggml_nrows(src),
                                           ggml_op_name(dst->op));
 }

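This dispatch change is the core of the rms_norm optimization: previously one thread handled a whole row and the workgroup count was a ceiling division over the number of rows; now each row gets its own workgroup of WEBGPU_ROW_SPLIT_WG_SIZE (64) threads that cooperate on it. A small sketch of the arithmetic, with made-up shapes:

    // Sketch of the dispatch-size arithmetic; the tensor shape is illustrative.
    #include <cstdint>
    #include <cstdio>

    int main() {
        const uint32_t ne1 = 32, ne2 = 8, ne3 = 1;   // 256 rows, each of ne0 elements
        const uint32_t nrows = ne1 * ne2 * ne3;
        const uint32_t max_wg_size = 256;            // hypothetical device limit

        // Old: one thread per row, rounded up to whole workgroups.
        const uint32_t wg_x_old = (nrows + max_wg_size - 1) / max_wg_size;   // 1
        // New: one workgroup per row (ggml_nrows(src)); its 64 threads split the row.
        const uint32_t wg_x_new = nrows;                                     // 256

        std::printf("old dispatch: %u workgroups, new dispatch: %u workgroups\n",
                    wg_x_old, wg_x_new);
        return 0;
    }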
@@ -912,6 +919,79 @@ static void ggml_webgpu_scale(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
                                           ggml_op_name(dst->op));
 }

+static void ggml_webgpu_soft_max(webgpu_context & ctx,
+                                 ggml_tensor *    src0,
+                                 ggml_tensor *    src1,
+                                 ggml_tensor *    src2,
+                                 ggml_tensor *    dst) {
+    const int inplace   = ggml_webgpu_tensor_equal(src0, dst);
+    const int mask_type = (src1 != nullptr) ? src1->type : 2;  // use 2 for no mask here
+    const int has_sink  = (src2 != nullptr);
+    float     max_bias;
+    memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float));
+    float n_head_log2 = float(1u << (uint32_t) floor(log2(src0->ne[2])));
+    float m0          = powf(2.0f, -(max_bias) / n_head_log2);
+    float m1          = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
+
+    std::vector<uint32_t> params = {
+        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)),
+        mask_type < 2 ? (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)) : 0,
+        has_sink ? (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src2) / ggml_type_size(src2->type)) : 0,
+        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
+        (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)),
+        (uint32_t) (src0->nb[2] / ggml_type_size(src0->type)),
+        (uint32_t) (src0->nb[3] / ggml_type_size(src0->type)),
+        mask_type < 2 ? (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)) : 0,
+        mask_type < 2 ? (uint32_t) (src1->nb[2] / ggml_type_size(src1->type)) : 0,
+        mask_type < 2 ? (uint32_t) (src1->nb[3] / ggml_type_size(src1->type)) : 0,
+        (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)),
+        (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)),
+        (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)),
+        (uint32_t) ggml_nelements(dst),
+        (uint32_t) src0->ne[0],
+        (uint32_t) src0->ne[1],
+        (uint32_t) src0->ne[2],
+        mask_type < 2 ? (uint32_t) src1->ne[2] : 0,
+        mask_type < 2 ? (uint32_t) src1->ne[3] : 0,
+        *(uint32_t *) dst->op_params,  // scale
+        *(uint32_t *) &max_bias,
+        *(uint32_t *) &n_head_log2,
+        *(uint32_t *) &m0,
+        *(uint32_t *) &m1
+    };
+
+    std::vector<wgpu::BindGroupEntry> entries = {
+        { .binding = 0,
+          .buffer  = ggml_webgpu_tensor_buf(src0),
+          .offset  = ggml_webgpu_tensor_align_offset(ctx, src0),
+          .size    = ggml_webgpu_tensor_binding_size(ctx, src0) }
+    };
+    uint32_t binding_num = 1;
+    if (mask_type < 2) {
+        entries.push_back({ .binding = binding_num,
+                            .buffer  = ggml_webgpu_tensor_buf(src1),
+                            .offset  = ggml_webgpu_tensor_align_offset(ctx, src1),
+                            .size    = ggml_webgpu_tensor_binding_size(ctx, src1) });
+        binding_num++;
+    }
+    if (has_sink) {
+        entries.push_back({ .binding = binding_num,
+                            .buffer  = ggml_webgpu_tensor_buf(src2),
+                            .offset  = ggml_webgpu_tensor_align_offset(ctx, src2),
+                            .size    = ggml_webgpu_tensor_binding_size(ctx, src2) });
+        binding_num++;
+    }
+    if (!inplace) {
+        entries.push_back({ .binding = binding_num,
+                            .buffer  = ggml_webgpu_tensor_buf(dst),
+                            .offset  = ggml_webgpu_tensor_align_offset(ctx, dst),
+                            .size    = ggml_webgpu_tensor_binding_size(ctx, dst) });
+    }
+
+    ggml_backend_webgpu_build_and_enqueue(ctx, ctx->soft_max_pipeline[mask_type][has_sink][inplace], params, entries,
+                                          ggml_nrows(dst), ggml_op_name(dst->op));
+}
+
 // Returns true if node has enqueued work into the queue, false otherwise
 static bool ggml_webgpu_encode_node(webgpu_context ctx, ggml_tensor * node) {
     if (ggml_is_empty(node)) {
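The host code above packs max_bias, n_head_log2, m0 and m1 into the push-constant block so the shader can derive the per-head ALiBi slope without extra buffers. A standalone sketch of that slope computation (the values are illustrative, not taken from the patch):

    #include <cmath>
    #include <cstdio>

    int main() {
        const float max_bias = 8.0f;   // what the host reads from dst->op_params[1]
        const int   n_head   = 20;     // src0->ne[2]

        // Same derivation as the host dispatch code above.
        const float n_head_log2 = (float) (1u << (unsigned) std::floor(std::log2((float) n_head)));  // 16
        const float m0 = std::pow(2.0f, -max_bias / n_head_log2);
        const float m1 = std::pow(2.0f, -(max_bias / 2.0f) / n_head_log2);

        for (int head = 0; head < n_head; head++) {
            float slope = 1.0f;   // max_bias == 0 disables ALiBi entirely
            if (max_bias > 0.0f) {
                // Mirrors the select() in the shader: m0^(head+1) for the first
                // n_head_log2 heads, m1^(2*(head - n_head_log2) + 1) for the rest.
                slope = ((float) head < n_head_log2)
                            ? std::pow(m0, (float) head + 1.0f)
                            : std::pow(m1, 2.0f * ((float) head - n_head_log2) + 1.0f);
            }
            std::printf("head %2d: slope %.5f\n", head, slope);
        }
        return 0;
    }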
@@ -1237,11 +1317,11 @@ static ggml_guid_t ggml_backend_webgpu_guid(void) {
     return reinterpret_cast<ggml_guid_t>((void *) guid_str);
 }

-// The max workgroup size is a common constant
-static std::vector<wgpu::ConstantEntry> ggml_webgpu_max_wg_size_entry(webgpu_context & webgpu_ctx) {
+// Workgroup size is a common constant
+static std::vector<wgpu::ConstantEntry> ggml_webgpu_wg_size_entry(uint32_t wg_size) {
     std::vector<wgpu::ConstantEntry> constants(1);
     constants[0].key   = "wg_size";
-    constants[0].value = webgpu_ctx->max_wg_size_x;
+    constants[0].value = wg_size;
     return constants;
 }

@@ -1309,11 +1389,11 @@ static void ggml_webgpu_init_mul_mat_pipeline(webgpu_context & webgpu_ctx) {

 static void ggml_webgpu_init_set_rows_pipeline(webgpu_context & webgpu_ctx) {
     ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->set_rows_pipeline, wgsl_set_rows, "set_rows",
-                                ggml_webgpu_max_wg_size_entry(webgpu_ctx));
+                                ggml_webgpu_wg_size_entry(webgpu_ctx->max_wg_size_x));
 }

 static void ggml_webgpu_init_get_rows_pipeline(webgpu_context & webgpu_ctx) {
-    std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_max_wg_size_entry(webgpu_ctx);
+    std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(webgpu_ctx->max_wg_size_x);
     ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_F32], wgsl_get_rows_f32_vec,
                                 "get_rows_f32_vec", constants);
     ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_f32_no_vec_pipeline, wgsl_get_rows_f32,
@@ -1363,7 +1443,7 @@ static void ggml_webgpu_init_get_rows_pipeline(webgpu_context & webgpu_ctx) {
 }

 static void ggml_webgpu_init_cpy_pipeline(webgpu_context & webgpu_ctx) {
-    std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_max_wg_size_entry(webgpu_ctx);
+    std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(webgpu_ctx->max_wg_size_x);
     ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->cpy_pipeline[GGML_TYPE_F32][GGML_TYPE_F32],
                                 wgsl_cpy_f32_f32, "cpy_f32_f32", constants);
     ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->cpy_pipeline[GGML_TYPE_F32][GGML_TYPE_F16],
@@ -1375,7 +1455,7 @@ static void ggml_webgpu_init_cpy_pipeline(webgpu_context & webgpu_ctx) {
 }

 static void ggml_webgpu_init_add_pipeline(webgpu_context & webgpu_ctx) {
-    std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_max_wg_size_entry(webgpu_ctx);
+    std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(webgpu_ctx->max_wg_size_x);
     ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_pipeline[GGML_TYPE_F32][0], wgsl_add_f32, "add_f32",
                                 constants);
     ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_pipeline[GGML_TYPE_F16][0], wgsl_add_f16, "add_f16",
@@ -1387,7 +1467,7 @@ static void ggml_webgpu_init_add_pipeline(webgpu_context & webgpu_ctx) {
 }

 static void ggml_webgpu_init_sub_pipeline(webgpu_context & webgpu_ctx) {
-    std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_max_wg_size_entry(webgpu_ctx);
+    std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(webgpu_ctx->max_wg_size_x);
     ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->sub_pipeline[GGML_TYPE_F32][0], wgsl_sub_f32, "sub_f32",
                                 constants);
     ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->sub_pipeline[GGML_TYPE_F16][0], wgsl_sub_f16, "sub_f16",
@@ -1399,7 +1479,7 @@ static void ggml_webgpu_init_sub_pipeline(webgpu_context & webgpu_ctx) {
 }

 static void ggml_webgpu_init_mul_pipeline(webgpu_context & webgpu_ctx) {
-    std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_max_wg_size_entry(webgpu_ctx);
+    std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(webgpu_ctx->max_wg_size_x);
     ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_pipeline[GGML_TYPE_F32][0], wgsl_mul_f32, "mul_f32",
                                 constants);
     ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_pipeline[GGML_TYPE_F16][0], wgsl_mul_f16, "mul_f16",
@@ -1411,7 +1491,7 @@ static void ggml_webgpu_init_mul_pipeline(webgpu_context & webgpu_ctx) {
 }

 static void ggml_webgpu_init_div_pipeline(webgpu_context & webgpu_ctx) {
-    std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_max_wg_size_entry(webgpu_ctx);
+    std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(webgpu_ctx->max_wg_size_x);
     ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->div_pipeline[GGML_TYPE_F32][0], wgsl_div_f32, "div_f32",
                                 constants);
     ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->div_pipeline[GGML_TYPE_F16][0], wgsl_div_f16, "div_f16",
@@ -1423,7 +1503,7 @@ static void ggml_webgpu_init_div_pipeline(webgpu_context & webgpu_ctx) {
 }

 static void ggml_webgpu_init_rms_norm_pipeline(webgpu_context & webgpu_ctx) {
-    std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_max_wg_size_entry(webgpu_ctx);
+    std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_ROW_SPLIT_WG_SIZE);
     ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->rms_norm_pipeline[0], wgsl_rms_norm, "rms_norm",
                                 constants);
     ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->rms_norm_pipeline[1], wgsl_rms_norm_inplace,
@@ -1431,7 +1511,7 @@ static void ggml_webgpu_init_rms_norm_pipeline(webgpu_context & webgpu_ctx) {
 }

 static void ggml_webgpu_init_rope_pipeline(webgpu_context & webgpu_ctx) {
-    std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_max_wg_size_entry(webgpu_ctx);
+    std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(webgpu_ctx->max_wg_size_x);
     ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->rope_pipeline[GGML_TYPE_F32][0][0], wgsl_rope_f32,
                                 "rope_f32", constants);
     ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->rope_pipeline[GGML_TYPE_F32][0][1],
@@ -1451,7 +1531,7 @@ static void ggml_webgpu_init_rope_pipeline(webgpu_context & webgpu_ctx) {
 }

 static void ggml_webgpu_init_glu_pipeline(webgpu_context & webgpu_ctx) {
-    std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_max_wg_size_entry(webgpu_ctx);
+    std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(webgpu_ctx->max_wg_size_x);
     // reglu
     ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_REGLU][GGML_TYPE_F32][0],
                                 wgsl_reglu_f32, "reglu_f32", constants);
@@ -1505,13 +1585,43 @@ static void ggml_webgpu_init_glu_pipeline(webgpu_context & webgpu_ctx) {
 }

 static void ggml_webgpu_init_scale_pipeline(webgpu_context & webgpu_ctx) {
-    std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_max_wg_size_entry(webgpu_ctx);
+    std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(webgpu_ctx->max_wg_size_x);
     ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->scale_pipeline[0], wgsl_scale_f32, "scale_f32",
                                 constants);
     ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->scale_pipeline[1], wgsl_scale_f32_inplace,
                                 "scale_f32_inplace", constants);
 }

+static void ggml_webgpu_init_soft_max_pipeline(webgpu_context & webgpu_ctx) {
+    std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_ROW_SPLIT_WG_SIZE);
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->soft_max_pipeline[2][0][0], wgsl_soft_max_f32,
+                                "soft_max_f32", constants);
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->soft_max_pipeline[2][0][1], wgsl_soft_max_f32_inplace,
+                                "soft_max_f32_inplace", constants);
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->soft_max_pipeline[2][1][0], wgsl_soft_max_f32_sink,
+                                "soft_max_f32_sink", constants);
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->soft_max_pipeline[2][1][1],
+                                wgsl_soft_max_f32_sink_inplace, "soft_max_f32_sink_inplace", constants);
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->soft_max_pipeline[0][0][0], wgsl_soft_max_f32_mask_f32,
+                                "soft_max_f32_mask_f32", constants);
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->soft_max_pipeline[0][0][1],
+                                wgsl_soft_max_f32_mask_f32_inplace, "soft_max_f32_mask_f32_inplace", constants);
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->soft_max_pipeline[1][0][0], wgsl_soft_max_f32_mask_f16,
+                                "soft_max_f32_mask_f16", constants);
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->soft_max_pipeline[1][0][1],
+                                wgsl_soft_max_f32_mask_f16_inplace, "soft_max_f32_mask_f16_inplace", constants);
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->soft_max_pipeline[0][1][0],
+                                wgsl_soft_max_f32_mask_f32_sink, "soft_max_f32_mask_f32_sink", constants);
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->soft_max_pipeline[0][1][1],
+                                wgsl_soft_max_f32_mask_f32_sink_inplace, "soft_max_f32_mask_f32_sink_inplace",
+                                constants);
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->soft_max_pipeline[1][1][0],
+                                wgsl_soft_max_f32_mask_f16_sink, "soft_max_f32_mask_f16_sink", constants);
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->soft_max_pipeline[1][1][1],
+                                wgsl_soft_max_f32_mask_f16_sink_inplace, "soft_max_f32_mask_f16_sink_inplace",
+                                constants);
+}
+
 static ggml_backend_t ggml_backend_webgpu_device_init(ggml_backend_dev_t dev, const char * params) {
     GGML_UNUSED(params);

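The soft_max_pipeline array initialized above is indexed as [mask_type][has_sink][inplace]. Because GGML_TYPE_F32 is 0 and GGML_TYPE_F16 is 1, the mask tensor's type can be used directly as the first index, with 2 reserved for "no mask", which is what ggml_webgpu_soft_max does. A hedged sketch of that selection (the helper and struct names are illustrative, not new API):

    // Sketch only: mirrors the index computation in ggml_webgpu_soft_max.
    struct soft_max_variant { int mask_type; int has_sink; int inplace; };

    static soft_max_variant select_soft_max_variant(bool has_mask, bool mask_is_f16,
                                                    bool has_sink, bool inplace) {
        soft_max_variant v;
        v.mask_type = has_mask ? (mask_is_f16 ? 1 : 0) : 2;  // f32 mask, f16 mask, or no mask
        v.has_sink  = has_sink ? 1 : 0;
        v.inplace   = inplace  ? 1 : 0;
        return v;  // used as soft_max_pipeline[v.mask_type][v.has_sink][v.inplace]
    }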
@@ -1593,6 +1703,7 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) {

     ggml_tensor * src0 = op->src[0];
     ggml_tensor * src1 = op->src[1];
+    ggml_tensor * src2 = op->src[2];

     // on smaller devices (or CI), tensors may be larger than the max storage buffer size
     if (ggml_nbytes(op) > webgpu_ctx->limits.maxStorageBufferBindingSize ||
@@ -1623,7 +1734,7 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
                          (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
             break;
         case GGML_OP_SET_ROWS:
-            supports_op = (op->type == GGML_TYPE_F16 && op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_I64);
+            supports_op = (op->type == GGML_TYPE_F16 && src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_I64);
             break;
         case GGML_OP_GET_ROWS:
             if (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_I32 ||
@@ -1698,13 +1809,25 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
         default:
             break;
     }
 #ifdef GGML_WEBGPU_DEBUG
-    if (!supports_op) {
-        WEBGPU_LOG_DEBUG("not supported: " << ggml_op_name(op->op) << " with types dst: " << ggml_type_name(op->type)
-                         << ", src0: " << (op->src[0] ? ggml_type_name(op->src[0]->type) : "null")
-                         << ", src1: " << (op->src[1] ? ggml_type_name(op->src[1]->type) : "null"));
+    if (ggml_nbytes(op) > webgpu_ctx->limits.maxStorageBufferBindingSize ||
+        (src0 != nullptr && ggml_nbytes(src0) > webgpu_ctx->limits.maxStorageBufferBindingSize) ||
+        (src1 != nullptr && ggml_nbytes(src1) > webgpu_ctx->limits.maxStorageBufferBindingSize) ||
+        (src2 != nullptr && ggml_nbytes(src2) > webgpu_ctx->limits.maxStorageBufferBindingSize)) {
+        supports_op = false;
+        WEBGPU_LOG_DEBUG("ggml_webgpu op not supported due to size: ");
+    }
+
+    if (!supports_op) {
+        WEBGPU_LOG_DEBUG("ggml_webgpu op not supported: "
+                         << ggml_op_name(op->op) << " with types dst: " << ggml_type_name(op->type)
+                         << ", src0: " << (op->src[0] ? ggml_type_name(op->src[0]->type) : "null")
+                         << ", src1: " << (op->src[1] ? ggml_type_name(op->src[1]->type) : "null"));
+    } else {
+        WEBGPU_LOG_DEBUG("ggml_webgpu op supported: "
+                         << ggml_op_name(op->op) << " with types dst: " << ggml_type_name(op->type)
+                         << ", src0: " << (op->src[0] ? ggml_type_name(op->src[0]->type) : "null")
+                         << ", src1: " << (op->src[1] ? ggml_type_name(op->src[1]->type) : "null"));
     }
 #endif
     return supports_op;
 }

ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.tmpl.wgsl
@@ -71,14 +71,14 @@ var<storage, read_write> src: array<f32>;
 DECLS

 override wg_size: u32;
+var<workgroup> scratch: array<f32, wg_size>;

 @compute @workgroup_size(wg_size)
-fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
-    if (gid.x >= params.ne1 * params.ne2 * params.ne3) {
-        return;
-    }
+fn main(@builtin(workgroup_id) wid: vec3<u32>,
+        @builtin(local_invocation_id) lid: vec3<u32>) {
+
     // one thread per row
-    var i = gid.x;
+    var i = wid.x;
     let i3 = i / (params.ne2 * params.ne1);
     i = i % (params.ne2 * params.ne1);
     let i2 = i / params.ne1;
@@ -86,13 +86,38 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
     let i_src_row = params.offset_src + i3 * params.stride_src3 + i2 * params.stride_src2 + i1 * params.stride_src1;
     let i_dst_row = params.offset_src + i3 * params.stride_dst3 + i2 * params.stride_dst2 + i1 * params.stride_dst1;

+    let elems = (params.ne0 + wg_size - 1) / wg_size;
+
     var sum = 0.0f;
-    for (var j: u32 = 0; j < params.ne0; j++) {
-        sum += src[i_src_row + j] * src[i_src_row + j];
+    var col = lid.x;
+    for (var j: u32 = 0; j < elems; j++) {
+        if (col >= params.ne0) {
+            break;
+        }
+        sum += pow(src[i_src_row + col], 2.0);
+        col += wg_size;
     }

+    scratch[lid.x] = sum;
+    workgroupBarrier();
+    var offset = wg_size / 2;
+    while (offset > 0) {
+        if (lid.x < offset) {
+            scratch[lid.x] += scratch[lid.x + offset];
+        }
+        offset = offset / 2;
+        workgroupBarrier();
+    }
+    sum = scratch[0];
+
     let scale = 1.0/sqrt(sum/f32(params.ne0) + params.eps);
-    for (var j: u32 = 0; j < params.ne0; j++) {
-        update(i_src_row + j, i_dst_row + j, scale);
+    col = lid.x;
+    for (var j: u32 = 0; j < elems; j++) {
+        if (col >= params.ne0) {
+            break;
+        }
+        update(i_src_row + col, i_dst_row + col, scale);
+        col += wg_size;
     }
 }
 #end(SHADER)
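To make the split-row shader above easier to follow, here is a CPU sketch (not part of the patch) of what one workgroup does for one row: each of the wg_size lanes accumulates a strided partial sum of squares, the partials are tree-reduced, and then the lanes scale their columns.

    #include <cmath>
    #include <vector>

    // CPU model of the per-row work: strided partial sums + tree reduction + scaling.
    void rms_norm_row(std::vector<float> & row, float eps, int wg_size = 64) {
        std::vector<float> scratch(wg_size, 0.0f);

        // "Thread" t handles columns t, t + wg_size, t + 2*wg_size, ...
        for (int t = 0; t < wg_size; t++) {
            for (size_t col = t; col < row.size(); col += wg_size) {
                scratch[t] += row[col] * row[col];
            }
        }
        // Tree reduction, matching the workgroupBarrier()/halving loop in the shader.
        for (int offset = wg_size / 2; offset > 0; offset /= 2) {
            for (int t = 0; t < offset; t++) {
                scratch[t] += scratch[t + offset];
            }
        }
        const float scale = 1.0f / std::sqrt(scratch[0] / (float) row.size() + eps);
        for (float & x : row) {
            x *= scale;
        }
    }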
ggml/src/ggml-webgpu/wgsl-shaders/soft_max.tmpl.wgsl (new file, 344 lines)
@@ -0,0 +1,344 @@
#define(VARIANTS)
[
  {
    "SHADER_NAME": "soft_max_f32",
    "DECLS": ["BASE_BINDINGS", "NOT_INPLACE", "NO_MASK", "NO_SINK"]
  },
  {
    "SHADER_NAME": "soft_max_f32_inplace",
    "DECLS": ["BASE_BINDINGS_INPLACE", "INPLACE", "NO_MASK", "NO_SINK"]
  },
  {
    "SHADER_NAME": "soft_max_f32_sink",
    "DECLS": ["SINK_BINDINGS", "NOT_INPLACE", "NO_MASK", "SINK"]
  },
  {
    "SHADER_NAME": "soft_max_f32_sink_inplace",
    "DECLS": ["SINK_BINDINGS_INPLACE", "INPLACE", "NO_MASK", "SINK"]
  },
  {
    "SHADER_NAME": "soft_max_f32_mask_f32",
    "REPLS": {
      "MASK_TYPE" : "f32",
    },
    "DECLS": ["MASK_BINDINGS", "NOT_INPLACE", "MASK", "NO_SINK"]
  },
  {
    "SHADER_NAME": "soft_max_f32_mask_f32_inplace",
    "REPLS": {
      "MASK_TYPE" : "f32",
    },
    "DECLS": ["MASK_BINDINGS_INPLACE", "INPLACE", "MASK", "NO_SINK"]
  },
  {
    "SHADER_NAME": "soft_max_f32_mask_f16",
    "REPLS": {
      "MASK_TYPE" : "f16",
    },
    "DECLS": ["MASK_BINDINGS", "NOT_INPLACE", "MASK", "NO_SINK"]
  },
  {
    "SHADER_NAME": "soft_max_f32_mask_f16_inplace",
    "REPLS": {
      "MASK_TYPE" : "f16",
    },
    "DECLS": ["MASK_BINDINGS_INPLACE", "INPLACE", "MASK", "NO_SINK"]
  },
  {
    "SHADER_NAME": "soft_max_f32_mask_f32_sink",
    "REPLS": {
      "MASK_TYPE" : "f32",
    },
    "DECLS": ["MASK_SINK_BINDINGS", "NOT_INPLACE", "MASK", "SINK"]
  },
  {
    "SHADER_NAME": "soft_max_f32_mask_f32_sink_inplace",
    "REPLS": {
      "MASK_TYPE" : "f32",
    },
    "DECLS": ["MASK_SINK_BINDINGS_INPLACE", "INPLACE", "MASK", "SINK"]
  },
  {
    "SHADER_NAME": "soft_max_f32_mask_f16_sink",
    "REPLS": {
      "MASK_TYPE" : "f16",
    },
    "DECLS": ["MASK_SINK_BINDINGS", "NOT_INPLACE", "MASK", "SINK"]
  },
  {
    "SHADER_NAME": "soft_max_f32_mask_f16_sink_inplace",
    "REPLS": {
      "MASK_TYPE" : "f16",
    },
    "DECLS": ["MASK_SINK_BINDINGS_INPLACE", "INPLACE", "MASK", "SINK"]
  }
]
#end(VARIANTS)

#define(DECLS)

#decl(BASE_BINDINGS)
@group(0) @binding(1)
var<storage, read_write> dst: array<f32>;

@group(0) @binding(2)
var<uniform> params: Params;
#enddecl(BASE_BINDINGS)

#decl(BASE_BINDINGS_INPLACE)
@group(0) @binding(1)
var<uniform> params: Params;
#enddecl(BASE_BINDINGS_INPLACE)

#decl(SINK_BINDINGS)
@group(0) @binding(1)
var<storage, read_write> sinks: array<f32>;

@group(0) @binding(2)
var<storage, read_write> dst: array<f32>;

@group(0) @binding(3)
var<uniform> params: Params;
#enddecl(SINK_BINDINGS)

#decl(SINK_BINDINGS_INPLACE)
@group(0) @binding(1)
var<storage, read_write> sinks: array<f32>;

@group(0) @binding(2)
var<uniform> params: Params;
#enddecl(SINK_BINDINGS_INPLACE)

#decl(MASK_BINDINGS)
@group(0) @binding(1)
var<storage, read_write> mask: array<{{MASK_TYPE}}>;

@group(0) @binding(2)
var<storage, read_write> dst: array<f32>;

@group(0) @binding(3)
var<uniform> params: Params;
#enddecl(MASK_BINDINGS)

#decl(MASK_BINDINGS_INPLACE)
@group(0) @binding(1)
var<storage, read_write> mask: array<{{MASK_TYPE}}>;

@group(0) @binding(2)
var<uniform> params: Params;
#enddecl(MASK_BINDINGS_INPLACE)

#decl(MASK_SINK_BINDINGS)
@group(0) @binding(1)
var<storage, read_write> mask: array<{{MASK_TYPE}}>;

@group(0) @binding(2)
var<storage, read_write> sinks: array<f32>;

@group(0) @binding(3)
var<storage, read_write> dst: array<f32>;

@group(0) @binding(4)
var<uniform> params: Params;
#enddecl(MASK_SINK_BINDINGS)

#decl(MASK_SINK_BINDINGS_INPLACE)
@group(0) @binding(1)
var<storage, read_write> mask: array<{{MASK_TYPE}}>;

@group(0) @binding(2)
var<storage, read_write> sinks: array<f32>;

@group(0) @binding(3)
var<uniform> params: Params;
#enddecl(MASK_SINK_BINDINGS_INPLACE)

#decl(NOT_INPLACE)
fn inter_value(i: u32) -> f32 {
    return dst[i];
}

fn update(i: u32, val: f32) {
    dst[i] = val;
}
#enddecl(NOT_INPLACE)

#decl(INPLACE)
fn inter_value(i: u32) -> f32 {
    return src[i];
}

fn update(i: u32, val: f32) {
    src[i] = val;
}
#enddecl(INPLACE)

#decl(NO_MASK)
fn mask_val(i: u32) -> f32 {
    return 0.0;
}
#enddecl(NO_MASK)

#decl(MASK)
fn mask_val(i: u32) -> f32 {
    return f32(mask[i]);
}
#enddecl(MASK)

#decl(NO_SINK)
fn lower_max_bound(i2: u32) -> f32 {
    return -1e30;
}

fn add_sinks(val: f32, i2: u32, max_val: f32) -> f32 {
    return val;
}
#enddecl(NO_SINK)

#decl(SINK)
fn lower_max_bound(i2: u32) -> f32 {
    return sinks[params.offset_sinks + i2];
}

fn add_sinks(val: f32, i2: u32, max_val: f32) -> f32 {
    return val + exp(sinks[params.offset_sinks + i2] - max_val);
}
#enddecl(SINK)

#end(DECLS)

#define(SHADER)
enable f16;

struct Params {
    offset_src0: u32,
    offset_src1: u32,
    offset_sinks: u32,
    offset_dst: u32,

    // Strides (in elements)
    stride_src01: u32,
    stride_src02: u32,
    stride_src03: u32,

    stride_src11: u32,
    stride_src12: u32,
    stride_src13: u32,

    stride_dst1: u32,
    stride_dst2: u32,
    stride_dst3: u32,

    // shape of src0/dst
    ne: u32,
    ne0: u32,
    ne1: u32,
    ne2: u32,

    // shape of src1
    ne12: u32,
    ne13: u32,

    scale: f32,
    max_bias: f32,
    n_head_log2: f32,
    m0: f32,
    m1: f32,
};

@group(0) @binding(0)
var<storage, read_write> src: array<f32>;

DECLS

const CACHE_SIZE: u32 = 16;

override wg_size: u32;
var<workgroup> scratch: array<f32, wg_size>;

@compute @workgroup_size(wg_size)
fn main(@builtin(workgroup_id) wid: vec3<u32>,
        @builtin(local_invocation_id) lid: vec3<u32>) {

    var i = wid.x;
    let i3 = i / (params.ne2 * params.ne1);
    i = i % (params.ne2 * params.ne1);
    let i2 = i / params.ne1;
    let i1 = i % params.ne1;
    let i_src0_row = params.offset_src0 + i3 * params.stride_src03 + i2 * params.stride_src02 + i1 * params.stride_src01;
    let i_src1_row = params.offset_src1 + (i3 % params.ne13) * params.stride_src13 + (i2 % params.ne12) * params.stride_src12 + i1 * params.stride_src11;
    let i_dst_row = params.offset_dst + i3 * params.stride_dst3 + i2 * params.stride_dst2 + i1 * params.stride_dst1;
    let elems = (params.ne0 + wg_size - 1) / wg_size;

    let head = f32(i2);
    let slope = select(1, select(pow(params.m1, 2 * (head - params.n_head_log2) + 1), pow(params.m0, head + 1), head < params.n_head_log2), params.max_bias > 0);

    var cache: array<f32, CACHE_SIZE>;

    var max_val = lower_max_bound(i2);
    var col = lid.x;
    for (var j: u32 = 0; j < elems; j++) {
        if (col >= params.ne0) {
            break;
        }
        let val = src[i_src0_row + col] * params.scale + slope * mask_val(i_src1_row + col);
        max_val = max(max_val, val);
        if (col < CACHE_SIZE) {
            cache[col] = val;
        }
        col += wg_size;
    }

    scratch[lid.x] = max_val;
    workgroupBarrier();
    var offset = wg_size / 2;
    while (offset > 0) {
        if (lid.x < offset) {
            scratch[lid.x] = max(scratch[lid.x], scratch[lid.x + offset]);
        }
        offset = offset / 2;
        workgroupBarrier();
    }
    let row_max = scratch[0];

    var sum = 0.0f;
    col = lid.x;
    for (var j: u32 = 0; j < elems; j++) {
        if (col >= params.ne0) {
            break;
        }
        let val = select(src[i_src0_row + col] * params.scale + slope * mask_val(i_src1_row + col),
                         cache[col], col < CACHE_SIZE);
        let ex = exp(val - row_max);
        sum += ex;
        if (col < CACHE_SIZE) {
            cache[col] = ex;
        } else {
            update(i_dst_row + col, ex);
        }
        col += wg_size;
    }

    scratch[lid.x] = sum;
    workgroupBarrier();
    offset = wg_size / 2;
    while (offset > 0) {
        if (lid.x < offset) {
            scratch[lid.x] += scratch[lid.x + offset];
        }
        offset = offset / 2;
        workgroupBarrier();
    }
    let row_sum = add_sinks(scratch[0], i2, row_max);

    let sum_recip = 1.0 / row_sum;
    col = lid.x;
    for (var j: u32 = 0; j < elems; j++) {
        if (col >= params.ne0) {
            break;
        }
        update(i_dst_row + col, select(inter_value(i_dst_row + col), cache[col], col < CACHE_SIZE) * sum_recip);
        col += wg_size;
    }
}
#end(SHADER)
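A CPU reference (again a sketch, not from the patch) for what the shader computes per row, with the NO_MASK and NO_SINK variants represented by null pointers:

    #include <algorithm>
    #include <cmath>
    #include <vector>

    // Apply scale and the slope-weighted mask, subtract the row max (bounded below by
    // the sink value when present), exponentiate, fold the sink into the denominator,
    // and normalize.
    void soft_max_row(std::vector<float> & dst, const std::vector<float> & src,
                      const float * mask, float scale, float slope, const float * sink) {
        const size_t n = src.size();
        dst.resize(n);

        float max_val = sink ? *sink : -1e30f;           // lower_max_bound()
        for (size_t i = 0; i < n; i++) {
            const float m = mask ? mask[i] : 0.0f;       // mask_val()
            dst[i] = src[i] * scale + slope * m;
            max_val = std::max(max_val, dst[i]);
        }

        float sum = 0.0f;
        for (size_t i = 0; i < n; i++) {
            dst[i] = std::exp(dst[i] - max_val);
            sum += dst[i];
        }
        if (sink) {
            sum += std::exp(*sink - max_val);            // add_sinks()
        }

        const float sum_recip = 1.0f / sum;
        for (size_t i = 0; i < n; i++) {
            dst[i] *= sum_recip;
        }
    }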
ggml/src/ggml.c
@@ -3852,6 +3852,15 @@ struct ggml_tensor * ggml_soft_max_ext(
     return ggml_soft_max_impl(ctx, a, mask, scale, max_bias, false);
 }

+struct ggml_tensor * ggml_soft_max_ext_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * mask,
+        float                 scale,
+        float                 max_bias) {
+    return ggml_soft_max_impl(ctx, a, mask, scale, max_bias, true);
+}
+
 void ggml_soft_max_add_sinks(
         struct ggml_tensor * a,
         struct ggml_tensor * sinks) {