Implement overlap binary operators

This commit is contained in:
Reese Levine
2025-10-31 17:35:00 -07:00
parent da5296e1b5
commit ed710b36f5
4 changed files with 448 additions and 171 deletions

View File

@@ -252,10 +252,10 @@ struct webgpu_context_struct {
webgpu_pipeline get_rows_pipeline[30];
webgpu_pipeline get_rows_f32_no_vec_pipeline;
webgpu_pipeline cpy_pipeline[2][2]; // src type, dst type
webgpu_pipeline add_pipeline[2][2]; // type, inplace
webgpu_pipeline sub_pipeline[2][2]; // type, inplace
webgpu_pipeline mul_pipeline[2][2]; // type, inplace
webgpu_pipeline div_pipeline[2][2]; // type, inplace
webgpu_pipeline add_pipeline[2][2][2]; // type, inplace, overlap
webgpu_pipeline sub_pipeline[2][2][2]; // type, inplace, overlap
webgpu_pipeline mul_pipeline[2][2][2]; // type, inplace, overlap
webgpu_pipeline div_pipeline[2][2][2]; // type, inplace, overlap
webgpu_pipeline rms_norm_pipeline[2]; // inplace
webgpu_pipeline rope_pipeline[2][2][2]; // type, ff, inplace
webgpu_pipeline glu_pipeline[7][2][2]; // glu-op, type, split
@@ -677,9 +677,12 @@ static size_t ggml_webgpu_tensor_align_offset(webgpu_context & ctx, ggml_tensor
return offset & ~(ctx->limits.minStorageBufferOffsetAlignment - 1);
}
static size_t ggml_webgpu_tensor_align_binding_size(size_t size) {
return (size + WEBGPU_STORAGE_BUF_BINDING_MULT - 1) & ~(WEBGPU_STORAGE_BUF_BINDING_MULT - 1);
}
static size_t ggml_webgpu_tensor_binding_size(webgpu_context & ctx, ggml_tensor * t) {
return (ggml_nbytes(t) + ggml_webgpu_tensor_misalignment(ctx, t) + WEBGPU_STORAGE_BUF_BINDING_MULT - 1) &
~(WEBGPU_STORAGE_BUF_BINDING_MULT - 1);
return ggml_webgpu_tensor_align_binding_size(ggml_nbytes(t) + ggml_webgpu_tensor_misalignment(ctx, t));
}
// Used to determine if two tensors are the same for in-place operations
@@ -688,6 +691,12 @@ static bool ggml_webgpu_tensor_equal(ggml_tensor * a, ggml_tensor * b) {
(ggml_webgpu_tensor_offset(a) == ggml_webgpu_tensor_offset(b));
}
static bool ggml_webgpu_tensor_overlap(ggml_tensor * a, ggml_tensor * b) {
return (ggml_webgpu_tensor_buf(a).Get() == ggml_webgpu_tensor_buf(b).Get()) &&
ggml_webgpu_tensor_offset(a) < (ggml_webgpu_tensor_offset(b) + ggml_nbytes(b)) &&
ggml_webgpu_tensor_offset(b) < (ggml_webgpu_tensor_offset(a) + ggml_nbytes(a));
}
static webgpu_command ggml_webgpu_cpy(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
uint32_t ne = (uint32_t) ggml_nelements(dst);
@@ -870,16 +879,27 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx,
return ggml_backend_webgpu_build(ctx, ctx->mul_mat_pipeline[src0->type][src1->type], params, entries, wg_x);
}
template <size_t a, size_t b, size_t c>
static webgpu_command ggml_webgpu_binary_op(webgpu_context & ctx,
ggml_tensor * src0,
ggml_tensor * src1,
ggml_tensor * dst,
webgpu_pipeline & pipeline,
bool inplace) {
webgpu_pipeline (&pipelines)[a][b][c]) {
int inplace = ggml_webgpu_tensor_equal(src0, dst);
int overlap = ggml_webgpu_tensor_overlap(src0, src1);
webgpu_pipeline pipeline = pipelines[dst->type][inplace][overlap];
uint32_t src1_offset = ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type);
if (overlap) {
// when overlapped, bind a single buffer covering both src0 and src1
// TODO: Do other operations need this?
src1_offset = (uint32_t) ((ggml_webgpu_tensor_offset(src1) - ggml_webgpu_tensor_align_offset(ctx, src0)) /
ggml_type_size(src1->type));
}
std::vector<uint32_t> params = {
(uint32_t) ggml_nelements(dst),
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)),
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)),
src1_offset,
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
(uint32_t) (src1->nb[0] / ggml_type_size(src1->type)),
(uint32_t) (src1->nb[1] / ggml_type_size(src1->type)),
@@ -894,18 +914,29 @@ static webgpu_command ggml_webgpu_binary_op(webgpu_context & ctx,
(uint32_t) src1->ne[3],
};
size_t src0_binding_size = ggml_webgpu_tensor_binding_size(ctx, src0);
if (overlap) {
const uint64_t base_align = ggml_webgpu_tensor_align_offset(ctx, src0);
// assume end of src1 is >= end of src0
const uint64_t max_end = ggml_webgpu_tensor_offset(src1) + ggml_nbytes(src1);
src0_binding_size = ggml_webgpu_tensor_align_binding_size(max_end - base_align);
}
std::vector<wgpu::BindGroupEntry> entries = {
{ .binding = 0,
.buffer = ggml_webgpu_tensor_buf(src0),
.offset = ggml_webgpu_tensor_align_offset(ctx, src0),
.size = ggml_webgpu_tensor_binding_size(ctx, src0) },
{ .binding = 1,
.size = src0_binding_size }
};
uint32_t binding_num = 1;
if (!overlap) {
entries.push_back({ .binding = binding_num,
.buffer = ggml_webgpu_tensor_buf(src1),
.offset = ggml_webgpu_tensor_align_offset(ctx, src1),
.size = ggml_webgpu_tensor_binding_size(ctx, src1) }
};
.size = ggml_webgpu_tensor_binding_size(ctx, src1) });
binding_num++;
}
if (!inplace) {
entries.push_back({ .binding = 2,
entries.push_back({ .binding = binding_num,
.buffer = ggml_webgpu_tensor_buf(dst),
.offset = ggml_webgpu_tensor_align_offset(ctx, dst),
.size = ggml_webgpu_tensor_binding_size(ctx, dst) });
@@ -1232,25 +1263,13 @@ static std::optional<webgpu_command> ggml_webgpu_encode_node(webgpu_context ctx,
case GGML_OP_MUL_MAT:
return ggml_webgpu_mul_mat(ctx, src0, src1, node);
case GGML_OP_ADD:
{
int inplace = ggml_webgpu_tensor_equal(src0, node);
return ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->add_pipeline[node->type][inplace], inplace);
}
return ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->add_pipeline);
case GGML_OP_SUB:
{
int inplace = ggml_webgpu_tensor_equal(src0, node);
return ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->sub_pipeline[node->type][inplace], inplace);
}
return ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->sub_pipeline);
case GGML_OP_MUL:
{
int inplace = ggml_webgpu_tensor_equal(src0, node);
return ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->mul_pipeline[node->type][inplace], inplace);
}
return ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->mul_pipeline);
case GGML_OP_DIV:
{
int inplace = ggml_webgpu_tensor_equal(src0, node);
return ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->div_pipeline[node->type][inplace], inplace);
}
return ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->div_pipeline);
case GGML_OP_RMS_NORM:
return ggml_webgpu_rms_norm(ctx, src0, node);
case GGML_OP_ROPE:
@@ -1700,50 +1719,82 @@ static void ggml_webgpu_init_cpy_pipeline(webgpu_context & webgpu_ctx) {
static void ggml_webgpu_init_add_pipeline(webgpu_context & webgpu_ctx) {
std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(webgpu_ctx->max_wg_size_x);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_pipeline[GGML_TYPE_F32][0], wgsl_add_f32, "add_f32",
constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_pipeline[GGML_TYPE_F16][0], wgsl_add_f16, "add_f16",
constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_pipeline[GGML_TYPE_F32][1], wgsl_add_f32_inplace,
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_pipeline[GGML_TYPE_F32][0][0], wgsl_add_f32,
"add_f32", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_pipeline[GGML_TYPE_F16][0][0], wgsl_add_f16,
"add_f16", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_pipeline[GGML_TYPE_F32][1][0], wgsl_add_f32_inplace,
"add_f32_inplace", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_pipeline[GGML_TYPE_F16][1], wgsl_add_f16_inplace,
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_pipeline[GGML_TYPE_F16][1][0], wgsl_add_f16_inplace,
"add_f16_inplace", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_pipeline[GGML_TYPE_F32][0][1], wgsl_add_f32_overlap,
"add_f32_overlap", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_pipeline[GGML_TYPE_F32][1][1],
wgsl_add_f32_inplace_overlap, "add_f32_inplace_overlap", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_pipeline[GGML_TYPE_F16][0][1], wgsl_add_f16_overlap,
"add_f16_overlap", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_pipeline[GGML_TYPE_F16][1][1],
wgsl_add_f16_inplace_overlap, "add_f16_inplace_overlap", constants);
}
static void ggml_webgpu_init_sub_pipeline(webgpu_context & webgpu_ctx) {
std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(webgpu_ctx->max_wg_size_x);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->sub_pipeline[GGML_TYPE_F32][0], wgsl_sub_f32, "sub_f32",
constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->sub_pipeline[GGML_TYPE_F16][0], wgsl_sub_f16, "sub_f16",
constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->sub_pipeline[GGML_TYPE_F32][1], wgsl_sub_f32_inplace,
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->sub_pipeline[GGML_TYPE_F32][0][0], wgsl_sub_f32,
"sub_f32", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->sub_pipeline[GGML_TYPE_F16][0][0], wgsl_sub_f16,
"sub_f16", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->sub_pipeline[GGML_TYPE_F32][1][0], wgsl_sub_f32_inplace,
"sub_f32_inplace", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->sub_pipeline[GGML_TYPE_F16][1], wgsl_sub_f16_inplace,
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->sub_pipeline[GGML_TYPE_F16][1][0], wgsl_sub_f16_inplace,
"sub_f16_inplace", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->sub_pipeline[GGML_TYPE_F32][0][1], wgsl_sub_f32_overlap,
"sub_f32_overlap", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->sub_pipeline[GGML_TYPE_F32][1][1],
wgsl_sub_f32_inplace_overlap, "sub_f32_inplace_overlap", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->sub_pipeline[GGML_TYPE_F16][0][1], wgsl_sub_f16_overlap,
"sub_f16_overlap", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->sub_pipeline[GGML_TYPE_F16][1][1],
wgsl_sub_f16_inplace_overlap, "sub_f16_inplace_overlap", constants);
}
static void ggml_webgpu_init_mul_pipeline(webgpu_context & webgpu_ctx) {
std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(webgpu_ctx->max_wg_size_x);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_pipeline[GGML_TYPE_F32][0], wgsl_mul_f32, "mul_f32",
constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_pipeline[GGML_TYPE_F16][0], wgsl_mul_f16, "mul_f16",
constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_pipeline[GGML_TYPE_F32][1], wgsl_mul_f32_inplace,
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_pipeline[GGML_TYPE_F32][0][0], wgsl_mul_f32,
"mul_f32", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_pipeline[GGML_TYPE_F16][0][0], wgsl_mul_f16,
"mul_f16", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_pipeline[GGML_TYPE_F32][1][0], wgsl_mul_f32_inplace,
"mul_f32_inplace", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_pipeline[GGML_TYPE_F16][1], wgsl_mul_f16_inplace,
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_pipeline[GGML_TYPE_F16][1][0], wgsl_mul_f16_inplace,
"mul_f16_inplace", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_pipeline[GGML_TYPE_F32][0][1], wgsl_mul_f32_overlap,
"mul_f32_overlap", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_pipeline[GGML_TYPE_F32][1][1],
wgsl_mul_f32_inplace_overlap, "mul_f32_inplace_overlap", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_pipeline[GGML_TYPE_F16][0][1], wgsl_mul_f16_overlap,
"mul_f16_overlap", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_pipeline[GGML_TYPE_F16][1][1],
wgsl_mul_f16_inplace_overlap, "mul_f16_inplace_overlap", constants);
}
static void ggml_webgpu_init_div_pipeline(webgpu_context & webgpu_ctx) {
std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(webgpu_ctx->max_wg_size_x);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->div_pipeline[GGML_TYPE_F32][0], wgsl_div_f32, "div_f32",
constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->div_pipeline[GGML_TYPE_F16][0], wgsl_div_f16, "div_f16",
constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->div_pipeline[GGML_TYPE_F32][1], wgsl_div_f32_inplace,
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->div_pipeline[GGML_TYPE_F32][0][0], wgsl_div_f32,
"div_f32", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->div_pipeline[GGML_TYPE_F16][0][0], wgsl_div_f16,
"div_f16", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->div_pipeline[GGML_TYPE_F32][1][0], wgsl_div_f32_inplace,
"div_f32_inplace", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->div_pipeline[GGML_TYPE_F16][1], wgsl_div_f16_inplace,
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->div_pipeline[GGML_TYPE_F16][1][0], wgsl_div_f16_inplace,
"div_f16_inplace", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->div_pipeline[GGML_TYPE_F32][0][1], wgsl_div_f32_overlap,
"div_f32_overlap", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->div_pipeline[GGML_TYPE_F32][1][1],
wgsl_div_f32_inplace_overlap, "div_f32_inplace_overlap", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->div_pipeline[GGML_TYPE_F16][0][1], wgsl_div_f16_overlap,
"div_f16_overlap", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->div_pipeline[GGML_TYPE_F16][1][1],
wgsl_div_f16_inplace_overlap, "div_f16_inplace_overlap", constants);
}
static void ggml_webgpu_init_rms_norm_pipeline(webgpu_context & webgpu_ctx) {
@@ -2152,7 +2203,7 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t
// TODO: Don't enable for WASM builds, they won't have an effect anyways
// TODO: Maybe WebGPU needs a "fast" mode where you can request compilers skip adding checks like these,
// only for native performance?
const char * const deviceEnabledToggles[] = { "skip_validation", "disable_robustness", "disable_workgroup_init",
const char * const deviceEnabledToggles[] = { "disable_robustness", "disable_workgroup_init",
"disable_polyfills_on_integer_div_and_mod" };
const char * const deviceDisabledToggles[] = { "timestamp_quantization" };
wgpu::DawnTogglesDescriptor deviceTogglesDesc;

View File

@@ -5,15 +5,10 @@
"SHADER_NAME": "add_f32",
"REPLS": {
"TYPE" : "f32",
"OP": "+"
},
"DECLS": ["NOT_INPLACE"]
},
{
"SHADER_NAME": "add_f16",
"REPLS": {
"TYPE" : "f16",
"OP": "+"
"SRC1_BUF": "src1",
"DST_BUF": "dst",
"OP": "+",
"PARAMS_BINDING": 3
},
"DECLS": ["NOT_INPLACE"]
},
@@ -21,31 +16,87 @@
"SHADER_NAME": "add_f32_inplace",
"REPLS": {
"TYPE" : "f32",
"OP": "+"
"SRC1_BUF": "src1",
"DST_BUF": "src0",
"OP": "+",
"PARAMS_BINDING": 2
},
"DECLS": ["INPLACE"]
},
{
"SHADER_NAME": "add_f32_overlap",
"REPLS": {
"TYPE" : "f32",
"SRC1_BUF": "src0",
"DST_BUF": "dst",
"OP": "+",
"PARAMS_BINDING": 2
},
"DECLS": ["OVERLAP"]
},
{
"SHADER_NAME": "add_f32_inplace_overlap",
"REPLS": {
"TYPE" : "f32",
"SRC1_BUF": "src0",
"DST_BUF": "src0",
"OP": "+",
"PARAMS_BINDING": 1
},
"DECLS": ["INPLACE_OVERLAP"]
},
{
"SHADER_NAME": "add_f16",
"REPLS": {
"TYPE" : "f16",
"SRC1_BUF": "src1",
"DST_BUF": "dst",
"OP": "+",
"PARAMS_BINDING": 3
},
"DECLS": ["NOT_INPLACE"]
},
{
"SHADER_NAME": "add_f16_inplace",
"REPLS": {
"TYPE" : "f16",
"OP": "+"
"SRC1_BUF": "src1",
"DST_BUF": "src0",
"OP": "+",
"PARAMS_BINDING": 2
},
"DECLS": ["INPLACE"]
},
{
"SHADER_NAME": "add_f16_overlap",
"REPLS": {
"TYPE" : "f16",
"SRC1_BUF": "src0",
"DST_BUF": "dst",
"OP": "+",
"PARAMS_BINDING": 2
},
"DECLS": ["OVERLAP"]
},
{
"SHADER_NAME": "add_f16_inplace_overlap",
"REPLS": {
"TYPE" : "f16",
"SRC1_BUF": "src0",
"DST_BUF": "src0",
"OP": "+",
"PARAMS_BINDING": 1
},
"DECLS": ["INPLACE_OVERLAP"]
},
{
"SHADER_NAME": "mul_f32",
"REPLS": {
"TYPE" : "f32",
"OP": "*"
},
"DECLS": ["NOT_INPLACE"]
},
{
"SHADER_NAME": "mul_f16",
"REPLS": {
"TYPE" : "f16",
"OP": "*"
"SRC1_BUF": "src1",
"DST_BUF": "dst",
"OP": "*",
"PARAMS_BINDING": 3
},
"DECLS": ["NOT_INPLACE"]
},
@@ -53,31 +104,87 @@
"SHADER_NAME": "mul_f32_inplace",
"REPLS": {
"TYPE" : "f32",
"OP": "*"
"SRC1_BUF": "src1",
"DST_BUF": "src0",
"OP": "*",
"PARAMS_BINDING": 2
},
"DECLS": ["INPLACE"]
},
{
"SHADER_NAME": "mul_f32_overlap",
"REPLS": {
"TYPE" : "f32",
"SRC1_BUF": "src0",
"DST_BUF": "dst",
"OP": "*",
"PARAMS_BINDING": 2
},
"DECLS": ["OVERLAP"]
},
{
"SHADER_NAME": "mul_f32_inplace_overlap",
"REPLS": {
"TYPE" : "f32",
"SRC1_BUF": "src0",
"DST_BUF": "src0",
"OP": "*",
"PARAMS_BINDING": 1
},
"DECLS": ["INPLACE_OVERLAP"]
},
{
"SHADER_NAME": "mul_f16",
"REPLS": {
"TYPE" : "f16",
"SRC1_BUF": "src1",
"DST_BUF": "dst",
"OP": "*",
"PARAMS_BINDING": 3
},
"DECLS": ["NOT_INPLACE"]
},
{
"SHADER_NAME": "mul_f16_inplace",
"REPLS": {
"TYPE" : "f16",
"OP": "*"
"SRC1_BUF": "src1",
"DST_BUF": "src0",
"OP": "*",
"PARAMS_BINDING": 2
},
"DECLS": ["INPLACE"]
},
{
"SHADER_NAME": "mul_f16_overlap",
"REPLS": {
"TYPE" : "f16",
"SRC1_BUF": "src0",
"DST_BUF": "dst",
"OP": "*",
"PARAMS_BINDING": 2
},
"DECLS": ["OVERLAP"]
},
{
"SHADER_NAME": "mul_f16_inplace_overlap",
"REPLS": {
"TYPE" : "f16",
"SRC1_BUF": "src0",
"DST_BUF": "src0",
"OP": "*",
"PARAMS_BINDING": 1
},
"DECLS": ["INPLACE_OVERLAP"]
},
{
"SHADER_NAME": "sub_f32",
"REPLS": {
"TYPE" : "f32",
"OP": "-"
},
"DECLS": ["NOT_INPLACE"]
},
{
"SHADER_NAME": "sub_f16",
"REPLS": {
"TYPE" : "f16",
"OP": "-"
"SRC1_BUF": "src1",
"DST_BUF": "dst",
"OP": "-",
"PARAMS_BINDING": 3
},
"DECLS": ["NOT_INPLACE"]
},
@@ -85,31 +192,88 @@
"SHADER_NAME": "sub_f32_inplace",
"REPLS": {
"TYPE" : "f32",
"OP": "-"
"SRC1_BUF": "src1",
"DST_BUF": "src0",
"OP": "-",
"PARAMS_BINDING": 2
},
"DECLS": ["INPLACE"]
},
{
"SHADER_NAME": "sub_f32_overlap",
"REPLS": {
"TYPE" : "f32",
"SRC1_BUF": "src0",
"DST_BUF": "dst",
"OP": "-",
"PARAMS_BINDING": 2
},
"DECLS": ["OVERLAP"]
},
{
"SHADER_NAME": "sub_f32_inplace_overlap",
"REPLS": {
"TYPE" : "f32",
"SRC1_BUF": "src0",
"DST_BUF": "src0",
"OP": "-",
"PARAMS_BINDING": 1
},
"DECLS": ["INPLACE_OVERLAP"]
},
{
"SHADER_NAME": "sub_f16",
"REPLS": {
"TYPE" : "f16",
"SRC1_BUF": "src1",
"DST_BUF": "dst",
"OP": "-",
"PARAMS_BINDING": 3
},
"DECLS": ["NOT_INPLACE"]
},
{
"SHADER_NAME": "sub_f16_inplace",
"REPLS": {
"TYPE" : "f16",
"OP": "-"
"SRC1_BUF": "src1",
"DST_BUF": "src0",
"OP": "-",
"PARAMS_BINDING": 2
},
"DECLS": ["INPLACE"]
},
{
"SHADER_NAME": "sub_f16_overlap",
"REPLS": {
"TYPE" : "f16",
"SRC1_BUF": "src0",
"DST_BUF": "dst",
"OP": "-",
"PARAMS_BINDING": 2
},
"DECLS": ["OVERLAP"]
},
{
"SHADER_NAME": "sub_f16_inplace_overlap",
"REPLS": {
"TYPE" : "f16",
"SRC1_BUF": "src0",
"DST_BUF": "src0",
"OP": "-",
"PARAMS_BINDING": 1
},
"DECLS": ["INPLACE_OVERLAP"]
},
{
"SHADER_NAME": "div_f32",
"REPLS": {
"TYPE" : "f32",
"OP": "/"
},
"DECLS": ["NOT_INPLACE"]
},
{
"SHADER_NAME": "div_f16",
"REPLS": {
"TYPE" : "f16",
"OP": "/"
"SRC1_BUF": "src1",
"DST_BUF": "dst",
"OP": "/",
"PARAMS_BINDING": 3
},
"DECLS": ["NOT_INPLACE"]
},
@@ -117,17 +281,78 @@
"SHADER_NAME": "div_f32_inplace",
"REPLS": {
"TYPE" : "f32",
"OP": "/"
"SRC1_BUF": "src1",
"DST_BUF": "src0",
"OP": "/",
"PARAMS_BINDING": 2
},
"DECLS": ["INPLACE"]
},
{
"SHADER_NAME": "div_f32_overlap",
"REPLS": {
"TYPE" : "f32",
"SRC1_BUF": "src0",
"DST_BUF": "dst",
"OP": "/",
"PARAMS_BINDING": 2
},
"DECLS": ["OVERLAP"]
},
{
"SHADER_NAME": "div_f32_inplace_overlap",
"REPLS": {
"TYPE" : "f32",
"SRC1_BUF": "src0",
"DST_BUF": "src0",
"OP": "/",
"PARAMS_BINDING": 1
},
"DECLS": ["INPLACE_OVERLAP"]
},
{
"SHADER_NAME": "div_f16",
"REPLS": {
"TYPE" : "f16",
"SRC1_BUF": "src1",
"DST_BUF": "dst",
"OP": "/",
"PARAMS_BINDING": 3
},
"DECLS": ["NOT_INPLACE"]
},
{
"SHADER_NAME": "div_f16_inplace",
"REPLS": {
"TYPE" : "f16",
"OP": "/"
"SRC1_BUF": "src1",
"DST_BUF": "src0",
"OP": "/",
"PARAMS_BINDING": 2
},
"DECLS": ["INPLACE"]
},
{
"SHADER_NAME": "div_f16_overlap",
"REPLS": {
"TYPE" : "f16",
"SRC1_BUF": "src0",
"DST_BUF": "dst",
"OP": "/",
"PARAMS_BINDING": 2
},
"DECLS": ["OVERLAP"]
},
{
"SHADER_NAME": "div_f16_inplace_overlap",
"REPLS": {
"TYPE" : "f16",
"SRC1_BUF": "src0",
"DST_BUF": "src0",
"OP": "/",
"PARAMS_BINDING": 1
},
"DECLS": ["INPLACE_OVERLAP"]
}
]
@@ -137,43 +362,89 @@
#decl(NOT_INPLACE)
fn update(dst_i: u32, src0_i: u32, src1_i: u32) {
dst[dst_i] = src0[src0_i] {{OP}} src1[src1_i];
}
@group(0) @binding(1)
var<storage, read_write> src1: array<{{TYPE}}>;
@group(0) @binding(2)
var<storage, read_write> dst: array<{{TYPE}}>;
@group(0) @binding(3)
var<uniform> params: Params;
#enddecl(NOT_INPLACE)
#decl(INPLACE)
fn update(dst_i: u32, src0_i: u32, src1_i: u32) {
src0[dst_i] = src0[src0_i] {{OP}} src1[src1_i];
}
@group(0) @binding(2)
var<uniform> params: Params;
@group(0) @binding(1)
var<storage, read_write> src1: array<{{TYPE}}>;
#enddecl(INPLACE)
#end(DECLS)
#decl(OVERLAP)
@group(0) @binding(1)
var<storage, read_write> dst: array<{{TYPE}}>;
#enddecl(OVERLAP)
#decl(INPLACE_OVERLAP)
#enddecl(INPLACE_OVERLAP)
#end(DECLS)
#define(SHADER)
enable f16;
#include "binary_head.tmpl"
struct Params {
ne: u32,
// offsets in elements
offset_src0: u32,
offset_src1: u32,
offset_dst: u32,
stride_src1_0: u32,
stride_src1_1: u32,
stride_src1_2: u32,
stride_src1_3: u32,
a_ne0: u32,
a_ne1: u32,
a_ne2: u32,
b_ne0: u32,
b_ne1: u32,
b_ne2: u32,
b_ne3: u32,
};
fn src1_index(_i: u32) -> u32 {
var i = _i;
let a_i3 = i / (params.a_ne2 * params.a_ne1 * params.a_ne0);
i = i % (params.a_ne2 * params.a_ne1 * params.a_ne0);
let a_i2 = i / (params.a_ne1 * params.a_ne0);
i = i % (params.a_ne1 * params.a_ne0);
let a_i1 = i / params.a_ne0;
let a_i0 = i % params.a_ne0;
// handle repetition of b
// index loops back to the beginning and repeats after elements are exhausted = modulo
let b_i0 = a_i0 % params.b_ne0;
let b_i1 = a_i1 % params.b_ne1;
let b_i2 = a_i2 % params.b_ne2;
let b_i3 = a_i3 % params.b_ne3;
// compute index for position in b's flat array
return b_i0 * params.stride_src1_0 +
b_i1 * params.stride_src1_1 +
b_i2 * params.stride_src1_2 +
b_i3 * params.stride_src1_3;
}
@group(0) @binding(0)
var<storage, read_write> src0: array<{{TYPE}}>;
@group(0) @binding(1)
var<storage, read_write> src1: array<{{TYPE}}>;
@group(0) @binding({{PARAMS_BINDING}})
var<uniform> params: Params;
DECLS
@@ -181,7 +452,7 @@ override wg_size: u32;
@compute @workgroup_size(wg_size)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
if (gid.x < params.ne) {
update(params.offset_dst + gid.x, params.offset_src0 + gid.x, params.offset_src1 + src1_index(gid.x));
{{DST_BUF}}[params.offset_dst + gid.x] = src0[params.offset_src0 + gid.x] {{OP}} {{SRC1_BUF}}[params.offset_src1 + src1_index(gid.x)];
}
}

View File

@@ -1,45 +0,0 @@
struct Params {
ne: u32,
// offsets in elements
offset_src0: u32,
offset_src1: u32,
offset_dst: u32,
stride_src1_0: u32,
stride_src1_1: u32,
stride_src1_2: u32,
stride_src1_3: u32,
a_ne0: u32,
a_ne1: u32,
a_ne2: u32,
b_ne0: u32,
b_ne1: u32,
b_ne2: u32,
b_ne3: u32,
};
fn src1_index(_i: u32) -> u32 {
var i = _i;
let a_i3 = i / (params.a_ne2 * params.a_ne1 * params.a_ne0);
i = i % (params.a_ne2 * params.a_ne1 * params.a_ne0);
let a_i2 = i / (params.a_ne1 * params.a_ne0);
i = i % (params.a_ne1 * params.a_ne0);
let a_i1 = i / params.a_ne0;
let a_i0 = i % params.a_ne0;
// handle repetition of b
// index loops back to the beginning and repeats after elements are exhausted = modulo
let b_i0 = a_i0 % params.b_ne0;
let b_i1 = a_i1 % params.b_ne1;
let b_i2 = a_i2 % params.b_ne2;
let b_i3 = a_i3 % params.b_ne3;
// compute index for position in b's flat array
return b_i0 * params.stride_src1_0 +
b_i1 * params.stride_src1_1 +
b_i2 * params.stride_src1_2 +
b_i3 * params.stride_src1_3;
}

View File

@@ -4840,7 +4840,7 @@ struct test_moe_expert_reduce : public test_case {
std::vector<ggml_tensor *> expert_views(n_expert_used);
for (int64_t i = 0; i < n_expert_used; ++i) {
expert_views[i] = ggml_view_2d(ctx, weighted, n_embd, n_tokens, weighted->nb[2], i * weighted->nb[1]);
expert_views[i] = ggml_view_2d(ctx, weighted, n_embd, n_tokens, weighted->nb[1], i * weighted->nb[1]);
std::string name = "expert_view_" + std::to_string(i);
ggml_set_name(expert_views[i], name.c_str());