diff --git a/ggml/src/ggml-impl.h b/ggml/src/ggml-impl.h index e9201cdc68..ec37a25337 100644 --- a/ggml/src/ggml-impl.h +++ b/ggml/src/ggml-impl.h @@ -682,6 +682,7 @@ static inline bool ggml_can_fuse_subgraph(const struct ggml_cgraph * cgraph, #endif #ifdef __cplusplus +#include #include #include @@ -697,6 +698,21 @@ inline bool ggml_can_fuse_subgraph(const struct ggml_cgraph * cgraph, return ggml_can_fuse_subgraph(cgraph, start_idx, ops.size(), ops.begin(), outputs.begin(), outputs.size()); } +// Return true if the edges in the graph match expectations. +inline bool ggml_check_edges(const struct ggml_cgraph * cgraph, + int start_idx, + std::initializer_list> edges) { + for (const auto & edge : edges) { + int dst_node = edge[0]; + int src_idx = edge[1]; + int src_node = edge[2]; + if (cgraph->nodes[start_idx + dst_node]->src[src_idx] != cgraph->nodes[start_idx + src_node]) { + return false; + } + } + return true; +} + // expose GGUF internals for test code GGML_API size_t gguf_type_size(enum gguf_type type); GGML_API struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_params params); diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 3d10aa07b0..50e7922dc6 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -385,12 +385,76 @@ static constexpr uint32_t num_argsort_pipelines = 11; static constexpr uint32_t max_argsort_cols = 1 << (num_argsort_pipelines-1); static constexpr uint32_t num_topk_moe_pipelines = 10; -static constexpr std::array topk_moe_norm{ GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT, - GGML_OP_VIEW, GGML_OP_GET_ROWS, GGML_OP_RESHAPE, - GGML_OP_SUM_ROWS, GGML_OP_DIV, GGML_OP_RESHAPE }; -static constexpr std::array topk_moe { GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT, - GGML_OP_VIEW, GGML_OP_GET_ROWS }; +static constexpr std::initializer_list topk_moe_early_softmax_norm{ GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT, + GGML_OP_VIEW, GGML_OP_GET_ROWS, GGML_OP_RESHAPE, + GGML_OP_SUM_ROWS, GGML_OP_CLAMP, GGML_OP_DIV, + GGML_OP_RESHAPE }; +static constexpr std::initializer_list topk_moe_early_softmax { GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT, + GGML_OP_VIEW, GGML_OP_GET_ROWS }; +static constexpr std::initializer_list topk_moe_late_softmax { GGML_OP_ARGSORT, GGML_OP_VIEW, + GGML_OP_GET_ROWS, GGML_OP_RESHAPE, + GGML_OP_SOFT_MAX, GGML_OP_RESHAPE }; +//node #978 ( SOFT_MAX): ffn_moe_probs-15 ( 0K) [Vulka ] use=2: ffn_moe_logits-15 ( 0K) [Vulka ] +//node #979 ( RESHAPE): ffn_moe_probs-15 (re ( 0K) [Vulka ] use=1: ffn_moe_probs-15 ( 0K) [Vulka ] +//node #980 ( ARGSORT): ffn_moe_argsort-15 ( 0K) [Vulka ] use=1: ffn_moe_probs-15 ( 0K) [Vulka ] +//node #981 ( VIEW): ffn_moe_topk-15 ( 0K) [Vulka ] use=4: ffn_moe_argsort-15 ( 0K) [Vulka ] +//node #982 ( GET_ROWS): ffn_moe_weights-15 ( 0K) [Vulka ] use=1: ffn_moe_probs-15 (re ( 0K) [Vulka ] ffn_moe_topk-15 ( 0K) [Vulka ] +//node #983 ( RESHAPE): ffn_moe_weights-15 ( ( 0K) [Vulka ] use=2: ffn_moe_weights-15 ( 0K) [Vulka ] +//node #984 ( SUM_ROWS): ffn_moe_weights_sum- ( 0K) [Vulka ] use=1: ffn_moe_weights-15 ( ( 0K) [Vulka ] +//node #985 ( CLAMP): ffn_moe_weights_sum_ ( 0K) [Vulka ] use=1: ffn_moe_weights_sum- ( 0K) [Vulka ] +//node #986 ( DIV): ffn_moe_weights_norm ( 0K) [Vulka ] use=1: ffn_moe_weights-15 ( ( 0K) [Vulka ] ffn_moe_weights_sum_ ( 0K) [Vulka ] +//node #987 ( RESHAPE): ffn_moe_weights_norm ( 0K) [Vulka ] use=1: ffn_moe_weights_norm ( 0K) [Vulka ] +static constexpr std::initializer_list> topk_moe_early_softmax_norm_edges { + { 1, 0, 0 }, // reshape->src[0] == softmax + { 2, 0, 0 }, // argsort->src[0] == softmax + { 3, 0, 2 }, // view->src[0] == argsort + { 4, 0, 1 }, // get_rows->src[0] == reshape + { 4, 1, 3 }, // get_rows->src[1] == view + { 5, 0, 4 }, // reshape->src[0] == get_rows + { 6, 0, 5 }, // sum_rows->src[0] == reshape + { 7, 0, 6 }, // clamp->src[0] == sum_rows + { 8, 0, 5 }, // div->src[0] == reshape + { 8, 1, 7 }, // div->src[1] == clamp + { 9, 0, 8 }, // reshape->src[0] == div +}; + +// same as early_softmax_norm but ending after the get_rows +static constexpr std::initializer_list> topk_moe_early_softmax_edges { + { 1, 0, 0 }, // reshape->src[0] == softmax + { 2, 0, 0 }, // argsort->src[0] == softmax + { 3, 0, 2 }, // view->src[0] == argsort + { 4, 0, 1 }, // get_rows->src[0] == reshape + { 4, 1, 3 }, // get_rows->src[1] == view +}; + +//node #652 ( ARGSORT): ffn_moe_argsort-11 ( 0K) [Vulka ] use=1: ffn_moe_probs-11 ( 0K) [Vulka ] +//node #653 ( VIEW): ffn_moe_topk-11 ( 0K) [Vulka ] use=7: ffn_moe_argsort-11 ( 0K) [Vulka ] +//node #654 ( GET_ROWS): ffn_moe_weights-11 ( 0K) [Vulka ] use=1: ffn_moe_probs-11 (re ( 0K) [Vulka ] ffn_moe_topk-11 ( 0K) [Vulka ] +//node #655 ( RESHAPE): ffn_moe_weights-11 ( ( 0K) [Vulka ] use=1: ffn_moe_weights-11 ( 0K) [Vulka ] +//node #656 ( SOFT_MAX): node_656 ( 0K) [Vulka ] use=1: ffn_moe_weights-11 ( ( 0K) [Vulka ] +//node #657 ( RESHAPE): ffn_moe_weights_soft ( 0K) [Vulka ] use=1: node_656 ( 0K) [Vulka ] +static constexpr std::initializer_list> topk_moe_late_softmax_edges { + { 1, 0, 0 }, // view->src[0] == argsort + { 2, 1, 1 }, // get_rows->src[1] == view + { 3, 0, 2 }, // reshape->src[0] == get_rows + { 4, 0, 3 }, // soft_max->src[0] == reshape + { 5, 0, 4 }, // reshape->src[0] == soft_max +}; + +enum topk_moe_mode { + TOPK_MOE_EARLY_SOFTMAX, + TOPK_MOE_EARLY_SOFTMAX_NORM, + TOPK_MOE_LATE_SOFTMAX, + TOPK_MOE_COUNT, +}; + +static topk_moe_mode ggml_vk_num_additional_ops_to_topk_moe_mode(uint32_t num) { + topk_moe_mode mode = num == topk_moe_early_softmax_norm.size() - 1 ? TOPK_MOE_EARLY_SOFTMAX_NORM : + num == topk_moe_early_softmax.size() - 1 ? TOPK_MOE_EARLY_SOFTMAX : + TOPK_MOE_LATE_SOFTMAX; + return mode; +} struct vk_device_struct { std::recursive_mutex mutex; @@ -605,8 +669,7 @@ struct vk_device_struct { vk_pipeline pipeline_flash_attn_split_k_reduce; - // [2] is {!norm, norm} - vk_pipeline pipeline_topk_moe[num_topk_moe_pipelines][2]; + vk_pipeline pipeline_topk_moe[num_topk_moe_pipelines][TOPK_MOE_COUNT]; std::vector all_pipelines; @@ -954,6 +1017,8 @@ static_assert(sizeof(vk_op_multi_add_push_constants) <= 256); struct vk_op_topk_moe_push_constants { uint32_t n_rows; uint32_t n_expert_used; + float clamp_min; + float clamp_max; }; struct vk_op_add_id_push_constants { @@ -3804,8 +3869,9 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_cwhn_f16_f32, "conv2d_dw_cwhn_f16_f32", conv2d_dw_cwhn_f16_f32_len, conv2d_dw_cwhn_f16_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1); for (uint32_t i = 0; i < num_topk_moe_pipelines; ++i) { - ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][0], "topk_moe_f32_"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<pipeline_topk_moe[i][1], "topk_moe_f32_"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<pipeline_topk_moe[i][TOPK_MOE_EARLY_SOFTMAX], "topk_moe_f32_early_softmax_"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<pipeline_topk_moe[i][TOPK_MOE_EARLY_SOFTMAX_NORM], "topk_moe_f32_early_softmax_norm"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<pipeline_topk_moe[i][TOPK_MOE_LATE_SOFTMAX], "topk_moe_f32_late_softmax"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<num_additional_fused_ops) { uint32_t idx = (uint32_t)ceilf(log2f(float(dst->ne[0]))); GGML_ASSERT(idx < num_topk_moe_pipelines); - bool with_norm = ctx->num_additional_fused_ops == topk_moe_norm.size() - 1; - return ctx->device->pipeline_topk_moe[idx][with_norm]; + topk_moe_mode mode = ggml_vk_num_additional_ops_to_topk_moe_mode(ctx->num_additional_fused_ops); + return ctx->device->pipeline_topk_moe[idx][mode]; } if (src0->type == GGML_TYPE_F32 && (src1 == nullptr || src1->type == GGML_TYPE_F32) && dst->type == GGML_TYPE_F32) { @@ -8139,6 +8205,13 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const return nullptr; } case GGML_OP_ARGSORT: + if (ctx->num_additional_fused_ops) { + uint32_t idx = (uint32_t)ceilf(log2f(float(dst->ne[0]))); + GGML_ASSERT(idx < num_topk_moe_pipelines); + topk_moe_mode mode = ggml_vk_num_additional_ops_to_topk_moe_mode(ctx->num_additional_fused_ops); + return ctx->device->pipeline_topk_moe[idx][mode]; + } + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_I32) { uint32_t idx = (uint32_t)ceilf(log2f(float(dst->ne[0]))); return ctx->device->pipeline_argsort_f32[idx]; @@ -9678,10 +9751,12 @@ static void ggml_vk_soft_max_back(ggml_backend_vk_context * ctx, vk_context& sub static void ggml_vk_topk_moe(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_cgraph * cgraph, int node_idx, bool dryrun = false) { - bool with_norm = ctx->num_additional_fused_ops == topk_moe_norm.size() - 1; + topk_moe_mode mode = ggml_vk_num_additional_ops_to_topk_moe_mode(ctx->num_additional_fused_ops); ggml_tensor * logits = cgraph->nodes[node_idx + 0]->src[0]; - ggml_tensor * weights = with_norm ? cgraph->nodes[node_idx + 8] : cgraph->nodes[node_idx + 4]; - ggml_tensor * ids = cgraph->nodes[node_idx + 3]; + ggml_tensor * weights = (mode == TOPK_MOE_EARLY_SOFTMAX_NORM) ? cgraph->nodes[node_idx + 9] : + (mode == TOPK_MOE_EARLY_SOFTMAX) ? cgraph->nodes[node_idx + 4] : + cgraph->nodes[node_idx + 5]; + ggml_tensor * ids = (mode == TOPK_MOE_LATE_SOFTMAX) ? cgraph->nodes[node_idx + 1] : cgraph->nodes[node_idx + 3]; GGML_ASSERT(logits->type == GGML_TYPE_F32); GGML_ASSERT(weights->type == GGML_TYPE_F32); @@ -9740,9 +9815,14 @@ static void ggml_vk_topk_moe(ggml_backend_vk_context * ctx, vk_context& subctx, GGML_ASSERT(d_ids != nullptr); } - vk_op_topk_moe_push_constants pc; + vk_op_topk_moe_push_constants pc {}; pc.n_rows = n_rows; pc.n_expert_used = n_expert_used; + if (mode == TOPK_MOE_EARLY_SOFTMAX_NORM) { + ggml_tensor * clamp = cgraph->nodes[node_idx + 7]; + pc.clamp_min = ggml_get_op_params_f32(clamp, 0); + pc.clamp_max = ggml_get_op_params_f32(clamp, 1); + } GGML_ASSERT(n_expert_used <= n_experts); @@ -11337,7 +11417,13 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr } } } + +#define ENABLE_SYNC_LOGGING 0 + if (need_sync) { +#if ENABLE_SYNC_LOGGING + std::cerr << "sync" << std::endl; +#endif ctx->unsynced_nodes_written.clear(); ctx->unsynced_nodes_read.clear(); ggml_vk_sync_buffers(ctx, compute_ctx); @@ -11355,6 +11441,18 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr } } } +#if ENABLE_SYNC_LOGGING + if (!dryrun) { + for (int i = 0; i < ctx->num_additional_fused_ops + 1; ++i) { + auto *n = cgraph->nodes[node_idx + i]; + std::cerr << node_idx + i << " " << ggml_op_name(n->op) << " " << n->name; + if (n->op == GGML_OP_GLU) { + std::cerr << " " << ggml_glu_op_name(ggml_get_glu_op(n)) << " " << (n->src[1] ? "split" : "single") << " "; + } + std::cerr << std::endl; + } + } +#endif switch (node->op) { case GGML_OP_REPEAT: @@ -11533,7 +11631,11 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr break; case GGML_OP_ARGSORT: - ggml_vk_argsort(ctx, compute_ctx, src0, node, dryrun); + if (ctx->num_additional_fused_ops) { + ggml_vk_topk_moe(ctx, compute_ctx, cgraph, node_idx, dryrun); + } else { + ggml_vk_argsort(ctx, compute_ctx, src0, node, dryrun); + } break; case GGML_OP_SUM: @@ -12306,31 +12408,28 @@ static bool ggml_vk_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, st } static bool ggml_vk_can_fuse_topk_moe(ggml_backend_vk_context * ctx, const struct ggml_cgraph * cgraph, - int node_idx, bool with_norm) { + int node_idx, topk_moe_mode mode) { - if (with_norm) { - if (node_idx + (int)topk_moe_norm.size() > cgraph->n_nodes) { - return false; - } - for (size_t i = 0; i < topk_moe_norm.size(); ++i) { - if (cgraph->nodes[node_idx + i]->op != topk_moe_norm[i]) { - return false; - } - } - } else { - if (node_idx + (int)topk_moe.size() > cgraph->n_nodes) { - return false; - } - for (size_t i = 0; i < topk_moe.size(); ++i) { - if (cgraph->nodes[node_idx + i]->op != topk_moe[i]) { - return false; - } - } + const ggml_tensor * softmax; + const ggml_tensor * weights; + + switch (mode) { + case TOPK_MOE_EARLY_SOFTMAX_NORM: + softmax = cgraph->nodes[node_idx + 0]; + weights = cgraph->nodes[node_idx + 9]; + break; + case TOPK_MOE_EARLY_SOFTMAX: + softmax = cgraph->nodes[node_idx + 0]; + weights = cgraph->nodes[node_idx + 4]; + break; + case TOPK_MOE_LATE_SOFTMAX: + softmax = cgraph->nodes[node_idx + 4]; + weights = cgraph->nodes[node_idx + 5]; + break; + default: + return false; } - const ggml_tensor * softmax = cgraph->nodes[node_idx + 0]; - const ggml_tensor * weights = with_norm ? cgraph->nodes[node_idx + 8] : cgraph->nodes[node_idx + 4]; - const float * op_params = (const float *)softmax->op_params; float scale = op_params[0]; @@ -12355,60 +12454,6 @@ static bool ggml_vk_can_fuse_topk_moe(ggml_backend_vk_context * ctx, const struc return false; } - // Check that the nodes don't have any unexpected uses - const ggml_tensor * reshape1 = cgraph->nodes[node_idx + 1]; - const ggml_tensor * argsort = cgraph->nodes[node_idx + 2]; - const ggml_tensor * view = cgraph->nodes[node_idx + 3]; - const ggml_tensor * get_rows = cgraph->nodes[node_idx + 4]; - const ggml_tensor * reshape5 = with_norm ? cgraph->nodes[node_idx + 5] : nullptr; - const ggml_tensor * sum_rows = with_norm ? cgraph->nodes[node_idx + 6] : nullptr; - const ggml_tensor * div = with_norm ? cgraph->nodes[node_idx + 7] : nullptr; - const ggml_tensor * reshape8 = with_norm ? cgraph->nodes[node_idx + 8] : nullptr; - - // softmax is used by reshape and argsort - if (ggml_node_get_use_count(cgraph, node_idx) != 2 || - reshape1->src[0] != softmax || - argsort->src[0] != softmax) { - return false; - } - // reshape is used by get_rows - if (ggml_node_get_use_count(cgraph, node_idx + 1) != 1 || - get_rows->src[0] != reshape1) { - return false; - } - // argsort is used by view - if (ggml_node_get_use_count(cgraph, node_idx + 2) != 1 || - view->src[0] != argsort) { - return false; - } - // view is written (via argsort), we can skip checking it - - if (with_norm) { - // get_rows is used by reshape - if (ggml_node_get_use_count(cgraph, node_idx + 4) != 1 || - reshape5->src[0] != get_rows) { - return false; - } - - // reshape is used by sum_rows and div - if (ggml_node_get_use_count(cgraph, node_idx + 5) != 2 || - sum_rows->src[0] != reshape5 || - div->src[0] != reshape5) { - return false; - } - - // sum_rows is used by div - if (ggml_node_get_use_count(cgraph, node_idx + 6) != 1 || - div->src[1] != sum_rows) { - return false; - } - - // div/reshape are written - if (reshape8->src[0] != div) { - return false; - } - } - if (!ctx->device->subgroup_arithmetic || !ctx->device->subgroup_shuffle || !ctx->device->subgroup_require_full_support || @@ -12494,10 +12539,18 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg ctx->num_additional_fused_ops = num_adds - 1; } else if (ggml_vk_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) { ctx->num_additional_fused_ops = 1; - } else if (ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, true)) { - ctx->num_additional_fused_ops = topk_moe_norm.size() - 1; - } else if (ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, false)) { - ctx->num_additional_fused_ops = topk_moe.size() - 1; + } else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_early_softmax_norm, { i + 3, i + 9 }) && + ggml_check_edges(cgraph, i, topk_moe_early_softmax_norm_edges) && + ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_EARLY_SOFTMAX_NORM)) { + ctx->num_additional_fused_ops = topk_moe_early_softmax_norm.size() - 1; + } else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_early_softmax, { i + 3, i + 4 }) && + ggml_check_edges(cgraph, i, topk_moe_early_softmax_edges) && + ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_EARLY_SOFTMAX)) { + ctx->num_additional_fused_ops = topk_moe_early_softmax.size() - 1; + } else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_late_softmax, { i + 1, i + 5 }) && + ggml_check_edges(cgraph, i, topk_moe_late_softmax_edges) && + ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_LATE_SOFTMAX)) { + ctx->num_additional_fused_ops = topk_moe_late_softmax.size() - 1; } } ggml_vk_build_graph(ctx, cgraph, i, nullptr, 0, true, false, false, false); @@ -12595,10 +12648,18 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg ctx->num_additional_fused_ops = num_adds - 1; } else if (ggml_vk_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) { ctx->num_additional_fused_ops = 1; - } else if (ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, true)) { - ctx->num_additional_fused_ops = topk_moe_norm.size() - 1; - } else if (ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, false)) { - ctx->num_additional_fused_ops = topk_moe.size() - 1; + } else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_early_softmax_norm, { i + 3, i + 9 }) && + ggml_check_edges(cgraph, i, topk_moe_early_softmax_norm_edges) && + ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_EARLY_SOFTMAX_NORM)) { + ctx->num_additional_fused_ops = topk_moe_early_softmax_norm.size() - 1; + } else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_early_softmax, { i + 3, i + 4 }) && + ggml_check_edges(cgraph, i, topk_moe_early_softmax_edges) && + ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_EARLY_SOFTMAX)) { + ctx->num_additional_fused_ops = topk_moe_early_softmax.size() - 1; + } else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_late_softmax, { i + 1, i + 5 }) && + ggml_check_edges(cgraph, i, topk_moe_late_softmax_edges) && + ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_LATE_SOFTMAX)) { + ctx->num_additional_fused_ops = topk_moe_late_softmax.size() - 1; } } @@ -12730,25 +12791,44 @@ static void ggml_vk_graph_optimize(ggml_backend_t backend, struct ggml_cgraph * while (first_unused < graph->n_nodes) { std::vector current_set; - // Avoid reordering topk_moe_norm - if (first_unused + (int)topk_moe_norm.size() <= graph->n_nodes) { - bool is_topk_moe_norm = true; - for (size_t j = 0; j < topk_moe_norm.size(); ++j) { - if (graph->nodes[first_unused + j]->op != topk_moe_norm[j] || used[first_unused + j]) { - is_topk_moe_norm = false; + // Check for fusion patterns and avoid reordering them + auto const &match_pattern = [&](const std::initializer_list &pattern, int start) -> bool { + if (start + (int)pattern.size() <= graph->n_nodes) { + bool is_pattern = true; + for (size_t j = 0; j < pattern.size(); ++j) { + if (graph->nodes[start + j]->op != pattern.begin()[j] || used[start + j]) { + is_pattern = false; + } } + return is_pattern; } - if (is_topk_moe_norm) { - for (size_t j = 0; j < topk_moe_norm.size(); ++j) { + return false; + }; + + auto const &keep_pattern = [&](const std::initializer_list &pattern) -> bool { + if (match_pattern(pattern, first_unused)) { + for (size_t j = 0; j < pattern.size(); ++j) { new_order.push_back(graph->nodes[first_unused + j]); used[first_unused + j] = true; } while (first_unused < graph->n_nodes && used[first_unused]) { first_unused++; } - continue; + return true; } + return false; + }; + + if (keep_pattern(topk_moe_early_softmax_norm)) { + continue; } + if (keep_pattern(topk_moe_early_softmax)) { + continue; + } + if (keep_pattern(topk_moe_late_softmax)) { + continue; + } + // First, grab the next unused node. current_set.push_back(first_unused); @@ -12766,6 +12846,12 @@ static void ggml_vk_graph_optimize(ggml_backend_t backend, struct ggml_cgraph * if (is_empty(graph->nodes[j])) { continue; } + // Don't pull forward nodes from fusion patterns + if (match_pattern(topk_moe_early_softmax_norm, j) || + match_pattern(topk_moe_early_softmax, j) || + match_pattern(topk_moe_late_softmax, j)) { + continue; + } bool ok = true; for (int c = first_unused; c < j; ++c) { if (!used[c] && diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp b/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp index 9e56d5f8a3..bc1c278bf4 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp @@ -11,6 +11,8 @@ layout (push_constant) uniform parameter { uint n_rows; uint n_expert_used; + float clamp_min; + float clamp_max; }; layout(local_size_x_id = 0, local_size_y = 4, local_size_z = 1) in; @@ -18,6 +20,7 @@ layout(local_size_x_id = 0, local_size_y = 4, local_size_z = 1) in; layout(constant_id = 0) const uint WARP_SIZE = 32; layout(constant_id = 1) const uint n_experts = 512; layout(constant_id = 2) const bool with_norm = true; +layout(constant_id = 3) const bool late_softmax = false; const uint experts_per_thread = (n_experts > WARP_SIZE) ? n_experts / WARP_SIZE : 1; @@ -25,6 +28,52 @@ layout (binding = 0, std430) readonly buffer Logits {float logits[];}; layout (binding = 1, std430) writeonly buffer Weights {float weights[];}; layout (binding = 2, std430) writeonly buffer Ids {uint ids[];}; +const float INFINITY = 1.0 / 0.0; + +// Warp-local softmax used for both the pre-top-k logits and the post-top-k delayed path. +void softmax_warp_inplace(inout float vals[experts_per_thread], const uint limit, const uint lane, const bool use_limit) { + float max_val = -INFINITY; + + [[unroll]] + for (int i = 0; i < experts_per_thread; i++) { + const uint idx = lane + i * WARP_SIZE; + const bool is_active = !use_limit || (idx < limit); + if (is_active) { + max_val = max(max_val, vals[i]); + } + } + + max_val = subgroupMax(max_val); + + float sum = 0.f; + + [[unroll]] + for (int i = 0; i < experts_per_thread; i++) { + const uint idx = lane + i * WARP_SIZE; + const bool is_active = !use_limit || (idx < limit); + if (is_active) { + const float val = exp(vals[i] - max_val); + vals[i] = val; + sum += val; + } else { + vals[i] = 0.f; + } + } + + sum = subgroupAdd(sum); + + const float inv_sum = 1.0f / sum; + + [[unroll]] + for (int i = 0; i < experts_per_thread; i++) { + const uint idx = lane + i * WARP_SIZE; + const bool is_active = !use_limit || (idx < limit); + if (is_active) { + vals[i] *= inv_sum; + } + } +} + void main() { const uint row = gl_WorkGroupID.x * gl_WorkGroupSize.y + gl_LocalInvocationID.y; if (row >= n_rows) { @@ -35,43 +84,16 @@ void main() { const uint weights_offset = n_expert_used * row; const uint ids_offset = n_experts * row; - float logits_r[experts_per_thread]; - - const float INFINITY = 1.0 / 0.0; + float wt[experts_per_thread]; [[unroll]] for (uint i = 0; i < n_experts; i += WARP_SIZE) { - const uint expert = i + gl_LocalInvocationID.x; - logits_r[i / WARP_SIZE] = n_experts % WARP_SIZE == 0 || expert < n_experts ? logits[logits_offset + expert] : -INFINITY; + const uint expert = i + gl_LocalInvocationID.x; + wt[i / WARP_SIZE] = (n_experts % WARP_SIZE == 0 || expert < n_experts) ? logits[logits_offset + expert] : -INFINITY; } - float max_val = logits_r[0]; - - [[unroll]] - for (int i = 1; i < experts_per_thread; i++) { - const float val = logits_r[i]; - max_val = max(val, max_val); - } - - max_val = subgroupMax(max_val); - - float wt[experts_per_thread]; - float tmp = 0.f; - - [[unroll]] - for (int i = 0; i < experts_per_thread; i++) { - const float val = logits_r[i]; - wt[i] = exp(val - max_val); - tmp += wt[i]; - } - - tmp = subgroupAdd(tmp); - - const float inv_sum = 1.0f / tmp; - - [[unroll]] - for (int i = 0; i < experts_per_thread; i++) { - wt[i] = wt[i] * inv_sum; + if (!late_softmax) { + softmax_warp_inplace(wt, n_experts, gl_LocalInvocationID.x, false); } // at this point, each thread holds a portion of softmax, @@ -82,6 +104,11 @@ void main() { float output_weights[experts_per_thread]; + [[unroll]] + for (int i = 0; i < experts_per_thread; i++) { + output_weights[i] = 0.f; + } + for (int k = 0; k < n_expert_used; k++) { float max_val = wt[0]; uint max_expert = gl_LocalInvocationID.x; @@ -121,6 +148,7 @@ void main() { if (with_norm) { wt_sum = subgroupAdd(wt_sum); + wt_sum = clamp(wt_sum, clamp_min, clamp_max); const float inv_sum = 1.0f / wt_sum; [[unroll]] @@ -129,6 +157,10 @@ void main() { } } + if (late_softmax) { + softmax_warp_inplace(output_weights, n_expert_used, gl_LocalInvocationID.x, true); + } + [[unroll]] for (uint i = 0; i < experts_per_thread; ++i) { uint idx = i * WARP_SIZE + gl_LocalInvocationID.x;