mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-11-06 09:46:50 +00:00
vulkan: Update topk_moe fusion to handle gpt's late softmax (#16656)
* vulkan: Update topk_moe fusion to handle gpt's late softmax Based on #16649. * Add ggml_check_edges * Add sync logging to show fusion effects * handle clamp added in #16655 * Update ggml/src/ggml-impl.h Co-authored-by: Diego Devesa <slarengh@gmail.com>
This commit is contained in:
@@ -682,6 +682,7 @@ static inline bool ggml_can_fuse_subgraph(const struct ggml_cgraph * cgraph,
|
|||||||
#endif
|
#endif
|
||||||
|
|
||||||
#ifdef __cplusplus
|
#ifdef __cplusplus
|
||||||
|
#include <array>
|
||||||
#include <initializer_list>
|
#include <initializer_list>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
@@ -697,6 +698,21 @@ inline bool ggml_can_fuse_subgraph(const struct ggml_cgraph * cgraph,
|
|||||||
return ggml_can_fuse_subgraph(cgraph, start_idx, ops.size(), ops.begin(), outputs.begin(), outputs.size());
|
return ggml_can_fuse_subgraph(cgraph, start_idx, ops.size(), ops.begin(), outputs.begin(), outputs.size());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Return true if the edges in the graph match expectations.
|
||||||
|
inline bool ggml_check_edges(const struct ggml_cgraph * cgraph,
|
||||||
|
int start_idx,
|
||||||
|
std::initializer_list<std::array<int, 3>> edges) {
|
||||||
|
for (const auto & edge : edges) {
|
||||||
|
int dst_node = edge[0];
|
||||||
|
int src_idx = edge[1];
|
||||||
|
int src_node = edge[2];
|
||||||
|
if (cgraph->nodes[start_idx + dst_node]->src[src_idx] != cgraph->nodes[start_idx + src_node]) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
// expose GGUF internals for test code
|
// expose GGUF internals for test code
|
||||||
GGML_API size_t gguf_type_size(enum gguf_type type);
|
GGML_API size_t gguf_type_size(enum gguf_type type);
|
||||||
GGML_API struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_params params);
|
GGML_API struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_params params);
|
||||||
|
|||||||
@@ -385,12 +385,76 @@ static constexpr uint32_t num_argsort_pipelines = 11;
|
|||||||
static constexpr uint32_t max_argsort_cols = 1 << (num_argsort_pipelines-1);
|
static constexpr uint32_t max_argsort_cols = 1 << (num_argsort_pipelines-1);
|
||||||
static constexpr uint32_t num_topk_moe_pipelines = 10;
|
static constexpr uint32_t num_topk_moe_pipelines = 10;
|
||||||
|
|
||||||
static constexpr std::array topk_moe_norm{ GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT,
|
static constexpr std::initializer_list<ggml_op> topk_moe_early_softmax_norm{ GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT,
|
||||||
GGML_OP_VIEW, GGML_OP_GET_ROWS, GGML_OP_RESHAPE,
|
GGML_OP_VIEW, GGML_OP_GET_ROWS, GGML_OP_RESHAPE,
|
||||||
GGML_OP_SUM_ROWS, GGML_OP_DIV, GGML_OP_RESHAPE };
|
GGML_OP_SUM_ROWS, GGML_OP_CLAMP, GGML_OP_DIV,
|
||||||
static constexpr std::array topk_moe { GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT,
|
GGML_OP_RESHAPE };
|
||||||
GGML_OP_VIEW, GGML_OP_GET_ROWS };
|
static constexpr std::initializer_list<ggml_op> topk_moe_early_softmax { GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT,
|
||||||
|
GGML_OP_VIEW, GGML_OP_GET_ROWS };
|
||||||
|
static constexpr std::initializer_list<ggml_op> topk_moe_late_softmax { GGML_OP_ARGSORT, GGML_OP_VIEW,
|
||||||
|
GGML_OP_GET_ROWS, GGML_OP_RESHAPE,
|
||||||
|
GGML_OP_SOFT_MAX, GGML_OP_RESHAPE };
|
||||||
|
|
||||||
|
//node #978 ( SOFT_MAX): ffn_moe_probs-15 ( 0K) [Vulka ] use=2: ffn_moe_logits-15 ( 0K) [Vulka ]
|
||||||
|
//node #979 ( RESHAPE): ffn_moe_probs-15 (re ( 0K) [Vulka ] use=1: ffn_moe_probs-15 ( 0K) [Vulka ]
|
||||||
|
//node #980 ( ARGSORT): ffn_moe_argsort-15 ( 0K) [Vulka ] use=1: ffn_moe_probs-15 ( 0K) [Vulka ]
|
||||||
|
//node #981 ( VIEW): ffn_moe_topk-15 ( 0K) [Vulka ] use=4: ffn_moe_argsort-15 ( 0K) [Vulka ]
|
||||||
|
//node #982 ( GET_ROWS): ffn_moe_weights-15 ( 0K) [Vulka ] use=1: ffn_moe_probs-15 (re ( 0K) [Vulka ] ffn_moe_topk-15 ( 0K) [Vulka ]
|
||||||
|
//node #983 ( RESHAPE): ffn_moe_weights-15 ( ( 0K) [Vulka ] use=2: ffn_moe_weights-15 ( 0K) [Vulka ]
|
||||||
|
//node #984 ( SUM_ROWS): ffn_moe_weights_sum- ( 0K) [Vulka ] use=1: ffn_moe_weights-15 ( ( 0K) [Vulka ]
|
||||||
|
//node #985 ( CLAMP): ffn_moe_weights_sum_ ( 0K) [Vulka ] use=1: ffn_moe_weights_sum- ( 0K) [Vulka ]
|
||||||
|
//node #986 ( DIV): ffn_moe_weights_norm ( 0K) [Vulka ] use=1: ffn_moe_weights-15 ( ( 0K) [Vulka ] ffn_moe_weights_sum_ ( 0K) [Vulka ]
|
||||||
|
//node #987 ( RESHAPE): ffn_moe_weights_norm ( 0K) [Vulka ] use=1: ffn_moe_weights_norm ( 0K) [Vulka ]
|
||||||
|
static constexpr std::initializer_list<std::array<int, 3>> topk_moe_early_softmax_norm_edges {
|
||||||
|
{ 1, 0, 0 }, // reshape->src[0] == softmax
|
||||||
|
{ 2, 0, 0 }, // argsort->src[0] == softmax
|
||||||
|
{ 3, 0, 2 }, // view->src[0] == argsort
|
||||||
|
{ 4, 0, 1 }, // get_rows->src[0] == reshape
|
||||||
|
{ 4, 1, 3 }, // get_rows->src[1] == view
|
||||||
|
{ 5, 0, 4 }, // reshape->src[0] == get_rows
|
||||||
|
{ 6, 0, 5 }, // sum_rows->src[0] == reshape
|
||||||
|
{ 7, 0, 6 }, // clamp->src[0] == sum_rows
|
||||||
|
{ 8, 0, 5 }, // div->src[0] == reshape
|
||||||
|
{ 8, 1, 7 }, // div->src[1] == clamp
|
||||||
|
{ 9, 0, 8 }, // reshape->src[0] == div
|
||||||
|
};
|
||||||
|
|
||||||
|
// same as early_softmax_norm but ending after the get_rows
|
||||||
|
static constexpr std::initializer_list<std::array<int, 3>> topk_moe_early_softmax_edges {
|
||||||
|
{ 1, 0, 0 }, // reshape->src[0] == softmax
|
||||||
|
{ 2, 0, 0 }, // argsort->src[0] == softmax
|
||||||
|
{ 3, 0, 2 }, // view->src[0] == argsort
|
||||||
|
{ 4, 0, 1 }, // get_rows->src[0] == reshape
|
||||||
|
{ 4, 1, 3 }, // get_rows->src[1] == view
|
||||||
|
};
|
||||||
|
|
||||||
|
//node #652 ( ARGSORT): ffn_moe_argsort-11 ( 0K) [Vulka ] use=1: ffn_moe_probs-11 ( 0K) [Vulka ]
|
||||||
|
//node #653 ( VIEW): ffn_moe_topk-11 ( 0K) [Vulka ] use=7: ffn_moe_argsort-11 ( 0K) [Vulka ]
|
||||||
|
//node #654 ( GET_ROWS): ffn_moe_weights-11 ( 0K) [Vulka ] use=1: ffn_moe_probs-11 (re ( 0K) [Vulka ] ffn_moe_topk-11 ( 0K) [Vulka ]
|
||||||
|
//node #655 ( RESHAPE): ffn_moe_weights-11 ( ( 0K) [Vulka ] use=1: ffn_moe_weights-11 ( 0K) [Vulka ]
|
||||||
|
//node #656 ( SOFT_MAX): node_656 ( 0K) [Vulka ] use=1: ffn_moe_weights-11 ( ( 0K) [Vulka ]
|
||||||
|
//node #657 ( RESHAPE): ffn_moe_weights_soft ( 0K) [Vulka ] use=1: node_656 ( 0K) [Vulka ]
|
||||||
|
static constexpr std::initializer_list<std::array<int, 3>> topk_moe_late_softmax_edges {
|
||||||
|
{ 1, 0, 0 }, // view->src[0] == argsort
|
||||||
|
{ 2, 1, 1 }, // get_rows->src[1] == view
|
||||||
|
{ 3, 0, 2 }, // reshape->src[0] == get_rows
|
||||||
|
{ 4, 0, 3 }, // soft_max->src[0] == reshape
|
||||||
|
{ 5, 0, 4 }, // reshape->src[0] == soft_max
|
||||||
|
};
|
||||||
|
|
||||||
|
enum topk_moe_mode {
|
||||||
|
TOPK_MOE_EARLY_SOFTMAX,
|
||||||
|
TOPK_MOE_EARLY_SOFTMAX_NORM,
|
||||||
|
TOPK_MOE_LATE_SOFTMAX,
|
||||||
|
TOPK_MOE_COUNT,
|
||||||
|
};
|
||||||
|
|
||||||
|
static topk_moe_mode ggml_vk_num_additional_ops_to_topk_moe_mode(uint32_t num) {
|
||||||
|
topk_moe_mode mode = num == topk_moe_early_softmax_norm.size() - 1 ? TOPK_MOE_EARLY_SOFTMAX_NORM :
|
||||||
|
num == topk_moe_early_softmax.size() - 1 ? TOPK_MOE_EARLY_SOFTMAX :
|
||||||
|
TOPK_MOE_LATE_SOFTMAX;
|
||||||
|
return mode;
|
||||||
|
}
|
||||||
|
|
||||||
struct vk_device_struct {
|
struct vk_device_struct {
|
||||||
std::recursive_mutex mutex;
|
std::recursive_mutex mutex;
|
||||||
@@ -605,8 +669,7 @@ struct vk_device_struct {
|
|||||||
|
|
||||||
vk_pipeline pipeline_flash_attn_split_k_reduce;
|
vk_pipeline pipeline_flash_attn_split_k_reduce;
|
||||||
|
|
||||||
// [2] is {!norm, norm}
|
vk_pipeline pipeline_topk_moe[num_topk_moe_pipelines][TOPK_MOE_COUNT];
|
||||||
vk_pipeline pipeline_topk_moe[num_topk_moe_pipelines][2];
|
|
||||||
|
|
||||||
std::vector<vk_pipeline_ref> all_pipelines;
|
std::vector<vk_pipeline_ref> all_pipelines;
|
||||||
|
|
||||||
@@ -954,6 +1017,8 @@ static_assert(sizeof(vk_op_multi_add_push_constants) <= 256);
|
|||||||
struct vk_op_topk_moe_push_constants {
|
struct vk_op_topk_moe_push_constants {
|
||||||
uint32_t n_rows;
|
uint32_t n_rows;
|
||||||
uint32_t n_expert_used;
|
uint32_t n_expert_used;
|
||||||
|
float clamp_min;
|
||||||
|
float clamp_max;
|
||||||
};
|
};
|
||||||
|
|
||||||
struct vk_op_add_id_push_constants {
|
struct vk_op_add_id_push_constants {
|
||||||
@@ -3804,8 +3869,9 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|||||||
ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_cwhn_f16_f32, "conv2d_dw_cwhn_f16_f32", conv2d_dw_cwhn_f16_f32_len, conv2d_dw_cwhn_f16_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1);
|
ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_cwhn_f16_f32, "conv2d_dw_cwhn_f16_f32", conv2d_dw_cwhn_f16_f32_len, conv2d_dw_cwhn_f16_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1);
|
||||||
|
|
||||||
for (uint32_t i = 0; i < num_topk_moe_pipelines; ++i) {
|
for (uint32_t i = 0; i < num_topk_moe_pipelines; ++i) {
|
||||||
ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][0], "topk_moe_f32_"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<<i, 0}, 1, true, true);
|
ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][TOPK_MOE_EARLY_SOFTMAX], "topk_moe_f32_early_softmax_"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<<i, 0, 0}, 1, true, true);
|
||||||
ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][1], "topk_moe_f32_"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<<i, 1}, 1, true, true);
|
ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][TOPK_MOE_EARLY_SOFTMAX_NORM], "topk_moe_f32_early_softmax_norm"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<<i, 1, 0}, 1, true, true);
|
||||||
|
ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][TOPK_MOE_LATE_SOFTMAX], "topk_moe_f32_late_softmax"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<<i, 0, 1}, 1, true, true);
|
||||||
}
|
}
|
||||||
|
|
||||||
for (auto &c : compiles) {
|
for (auto &c : compiles) {
|
||||||
@@ -8083,8 +8149,8 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
|
|||||||
if (ctx->num_additional_fused_ops) {
|
if (ctx->num_additional_fused_ops) {
|
||||||
uint32_t idx = (uint32_t)ceilf(log2f(float(dst->ne[0])));
|
uint32_t idx = (uint32_t)ceilf(log2f(float(dst->ne[0])));
|
||||||
GGML_ASSERT(idx < num_topk_moe_pipelines);
|
GGML_ASSERT(idx < num_topk_moe_pipelines);
|
||||||
bool with_norm = ctx->num_additional_fused_ops == topk_moe_norm.size() - 1;
|
topk_moe_mode mode = ggml_vk_num_additional_ops_to_topk_moe_mode(ctx->num_additional_fused_ops);
|
||||||
return ctx->device->pipeline_topk_moe[idx][with_norm];
|
return ctx->device->pipeline_topk_moe[idx][mode];
|
||||||
}
|
}
|
||||||
|
|
||||||
if (src0->type == GGML_TYPE_F32 && (src1 == nullptr || src1->type == GGML_TYPE_F32) && dst->type == GGML_TYPE_F32) {
|
if (src0->type == GGML_TYPE_F32 && (src1 == nullptr || src1->type == GGML_TYPE_F32) && dst->type == GGML_TYPE_F32) {
|
||||||
@@ -8139,6 +8205,13 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
|
|||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
case GGML_OP_ARGSORT:
|
case GGML_OP_ARGSORT:
|
||||||
|
if (ctx->num_additional_fused_ops) {
|
||||||
|
uint32_t idx = (uint32_t)ceilf(log2f(float(dst->ne[0])));
|
||||||
|
GGML_ASSERT(idx < num_topk_moe_pipelines);
|
||||||
|
topk_moe_mode mode = ggml_vk_num_additional_ops_to_topk_moe_mode(ctx->num_additional_fused_ops);
|
||||||
|
return ctx->device->pipeline_topk_moe[idx][mode];
|
||||||
|
}
|
||||||
|
|
||||||
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_I32) {
|
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_I32) {
|
||||||
uint32_t idx = (uint32_t)ceilf(log2f(float(dst->ne[0])));
|
uint32_t idx = (uint32_t)ceilf(log2f(float(dst->ne[0])));
|
||||||
return ctx->device->pipeline_argsort_f32[idx];
|
return ctx->device->pipeline_argsort_f32[idx];
|
||||||
@@ -9678,10 +9751,12 @@ static void ggml_vk_soft_max_back(ggml_backend_vk_context * ctx, vk_context& sub
|
|||||||
|
|
||||||
static void ggml_vk_topk_moe(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_cgraph * cgraph, int node_idx, bool dryrun = false) {
|
static void ggml_vk_topk_moe(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_cgraph * cgraph, int node_idx, bool dryrun = false) {
|
||||||
|
|
||||||
bool with_norm = ctx->num_additional_fused_ops == topk_moe_norm.size() - 1;
|
topk_moe_mode mode = ggml_vk_num_additional_ops_to_topk_moe_mode(ctx->num_additional_fused_ops);
|
||||||
ggml_tensor * logits = cgraph->nodes[node_idx + 0]->src[0];
|
ggml_tensor * logits = cgraph->nodes[node_idx + 0]->src[0];
|
||||||
ggml_tensor * weights = with_norm ? cgraph->nodes[node_idx + 8] : cgraph->nodes[node_idx + 4];
|
ggml_tensor * weights = (mode == TOPK_MOE_EARLY_SOFTMAX_NORM) ? cgraph->nodes[node_idx + 9] :
|
||||||
ggml_tensor * ids = cgraph->nodes[node_idx + 3];
|
(mode == TOPK_MOE_EARLY_SOFTMAX) ? cgraph->nodes[node_idx + 4] :
|
||||||
|
cgraph->nodes[node_idx + 5];
|
||||||
|
ggml_tensor * ids = (mode == TOPK_MOE_LATE_SOFTMAX) ? cgraph->nodes[node_idx + 1] : cgraph->nodes[node_idx + 3];
|
||||||
|
|
||||||
GGML_ASSERT(logits->type == GGML_TYPE_F32);
|
GGML_ASSERT(logits->type == GGML_TYPE_F32);
|
||||||
GGML_ASSERT(weights->type == GGML_TYPE_F32);
|
GGML_ASSERT(weights->type == GGML_TYPE_F32);
|
||||||
@@ -9740,9 +9815,14 @@ static void ggml_vk_topk_moe(ggml_backend_vk_context * ctx, vk_context& subctx,
|
|||||||
GGML_ASSERT(d_ids != nullptr);
|
GGML_ASSERT(d_ids != nullptr);
|
||||||
}
|
}
|
||||||
|
|
||||||
vk_op_topk_moe_push_constants pc;
|
vk_op_topk_moe_push_constants pc {};
|
||||||
pc.n_rows = n_rows;
|
pc.n_rows = n_rows;
|
||||||
pc.n_expert_used = n_expert_used;
|
pc.n_expert_used = n_expert_used;
|
||||||
|
if (mode == TOPK_MOE_EARLY_SOFTMAX_NORM) {
|
||||||
|
ggml_tensor * clamp = cgraph->nodes[node_idx + 7];
|
||||||
|
pc.clamp_min = ggml_get_op_params_f32(clamp, 0);
|
||||||
|
pc.clamp_max = ggml_get_op_params_f32(clamp, 1);
|
||||||
|
}
|
||||||
|
|
||||||
GGML_ASSERT(n_expert_used <= n_experts);
|
GGML_ASSERT(n_expert_used <= n_experts);
|
||||||
|
|
||||||
@@ -11337,7 +11417,13 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#define ENABLE_SYNC_LOGGING 0
|
||||||
|
|
||||||
if (need_sync) {
|
if (need_sync) {
|
||||||
|
#if ENABLE_SYNC_LOGGING
|
||||||
|
std::cerr << "sync" << std::endl;
|
||||||
|
#endif
|
||||||
ctx->unsynced_nodes_written.clear();
|
ctx->unsynced_nodes_written.clear();
|
||||||
ctx->unsynced_nodes_read.clear();
|
ctx->unsynced_nodes_read.clear();
|
||||||
ggml_vk_sync_buffers(ctx, compute_ctx);
|
ggml_vk_sync_buffers(ctx, compute_ctx);
|
||||||
@@ -11355,6 +11441,18 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
#if ENABLE_SYNC_LOGGING
|
||||||
|
if (!dryrun) {
|
||||||
|
for (int i = 0; i < ctx->num_additional_fused_ops + 1; ++i) {
|
||||||
|
auto *n = cgraph->nodes[node_idx + i];
|
||||||
|
std::cerr << node_idx + i << " " << ggml_op_name(n->op) << " " << n->name;
|
||||||
|
if (n->op == GGML_OP_GLU) {
|
||||||
|
std::cerr << " " << ggml_glu_op_name(ggml_get_glu_op(n)) << " " << (n->src[1] ? "split" : "single") << " ";
|
||||||
|
}
|
||||||
|
std::cerr << std::endl;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
switch (node->op) {
|
switch (node->op) {
|
||||||
case GGML_OP_REPEAT:
|
case GGML_OP_REPEAT:
|
||||||
@@ -11533,7 +11631,11 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
|
|||||||
|
|
||||||
break;
|
break;
|
||||||
case GGML_OP_ARGSORT:
|
case GGML_OP_ARGSORT:
|
||||||
ggml_vk_argsort(ctx, compute_ctx, src0, node, dryrun);
|
if (ctx->num_additional_fused_ops) {
|
||||||
|
ggml_vk_topk_moe(ctx, compute_ctx, cgraph, node_idx, dryrun);
|
||||||
|
} else {
|
||||||
|
ggml_vk_argsort(ctx, compute_ctx, src0, node, dryrun);
|
||||||
|
}
|
||||||
|
|
||||||
break;
|
break;
|
||||||
case GGML_OP_SUM:
|
case GGML_OP_SUM:
|
||||||
@@ -12306,31 +12408,28 @@ static bool ggml_vk_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, st
|
|||||||
}
|
}
|
||||||
|
|
||||||
static bool ggml_vk_can_fuse_topk_moe(ggml_backend_vk_context * ctx, const struct ggml_cgraph * cgraph,
|
static bool ggml_vk_can_fuse_topk_moe(ggml_backend_vk_context * ctx, const struct ggml_cgraph * cgraph,
|
||||||
int node_idx, bool with_norm) {
|
int node_idx, topk_moe_mode mode) {
|
||||||
|
|
||||||
if (with_norm) {
|
const ggml_tensor * softmax;
|
||||||
if (node_idx + (int)topk_moe_norm.size() > cgraph->n_nodes) {
|
const ggml_tensor * weights;
|
||||||
return false;
|
|
||||||
}
|
switch (mode) {
|
||||||
for (size_t i = 0; i < topk_moe_norm.size(); ++i) {
|
case TOPK_MOE_EARLY_SOFTMAX_NORM:
|
||||||
if (cgraph->nodes[node_idx + i]->op != topk_moe_norm[i]) {
|
softmax = cgraph->nodes[node_idx + 0];
|
||||||
return false;
|
weights = cgraph->nodes[node_idx + 9];
|
||||||
}
|
break;
|
||||||
}
|
case TOPK_MOE_EARLY_SOFTMAX:
|
||||||
} else {
|
softmax = cgraph->nodes[node_idx + 0];
|
||||||
if (node_idx + (int)topk_moe.size() > cgraph->n_nodes) {
|
weights = cgraph->nodes[node_idx + 4];
|
||||||
return false;
|
break;
|
||||||
}
|
case TOPK_MOE_LATE_SOFTMAX:
|
||||||
for (size_t i = 0; i < topk_moe.size(); ++i) {
|
softmax = cgraph->nodes[node_idx + 4];
|
||||||
if (cgraph->nodes[node_idx + i]->op != topk_moe[i]) {
|
weights = cgraph->nodes[node_idx + 5];
|
||||||
return false;
|
break;
|
||||||
}
|
default:
|
||||||
}
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
const ggml_tensor * softmax = cgraph->nodes[node_idx + 0];
|
|
||||||
const ggml_tensor * weights = with_norm ? cgraph->nodes[node_idx + 8] : cgraph->nodes[node_idx + 4];
|
|
||||||
|
|
||||||
const float * op_params = (const float *)softmax->op_params;
|
const float * op_params = (const float *)softmax->op_params;
|
||||||
|
|
||||||
float scale = op_params[0];
|
float scale = op_params[0];
|
||||||
@@ -12355,60 +12454,6 @@ static bool ggml_vk_can_fuse_topk_moe(ggml_backend_vk_context * ctx, const struc
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check that the nodes don't have any unexpected uses
|
|
||||||
const ggml_tensor * reshape1 = cgraph->nodes[node_idx + 1];
|
|
||||||
const ggml_tensor * argsort = cgraph->nodes[node_idx + 2];
|
|
||||||
const ggml_tensor * view = cgraph->nodes[node_idx + 3];
|
|
||||||
const ggml_tensor * get_rows = cgraph->nodes[node_idx + 4];
|
|
||||||
const ggml_tensor * reshape5 = with_norm ? cgraph->nodes[node_idx + 5] : nullptr;
|
|
||||||
const ggml_tensor * sum_rows = with_norm ? cgraph->nodes[node_idx + 6] : nullptr;
|
|
||||||
const ggml_tensor * div = with_norm ? cgraph->nodes[node_idx + 7] : nullptr;
|
|
||||||
const ggml_tensor * reshape8 = with_norm ? cgraph->nodes[node_idx + 8] : nullptr;
|
|
||||||
|
|
||||||
// softmax is used by reshape and argsort
|
|
||||||
if (ggml_node_get_use_count(cgraph, node_idx) != 2 ||
|
|
||||||
reshape1->src[0] != softmax ||
|
|
||||||
argsort->src[0] != softmax) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
// reshape is used by get_rows
|
|
||||||
if (ggml_node_get_use_count(cgraph, node_idx + 1) != 1 ||
|
|
||||||
get_rows->src[0] != reshape1) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
// argsort is used by view
|
|
||||||
if (ggml_node_get_use_count(cgraph, node_idx + 2) != 1 ||
|
|
||||||
view->src[0] != argsort) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
// view is written (via argsort), we can skip checking it
|
|
||||||
|
|
||||||
if (with_norm) {
|
|
||||||
// get_rows is used by reshape
|
|
||||||
if (ggml_node_get_use_count(cgraph, node_idx + 4) != 1 ||
|
|
||||||
reshape5->src[0] != get_rows) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
// reshape is used by sum_rows and div
|
|
||||||
if (ggml_node_get_use_count(cgraph, node_idx + 5) != 2 ||
|
|
||||||
sum_rows->src[0] != reshape5 ||
|
|
||||||
div->src[0] != reshape5) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
// sum_rows is used by div
|
|
||||||
if (ggml_node_get_use_count(cgraph, node_idx + 6) != 1 ||
|
|
||||||
div->src[1] != sum_rows) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
// div/reshape are written
|
|
||||||
if (reshape8->src[0] != div) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!ctx->device->subgroup_arithmetic ||
|
if (!ctx->device->subgroup_arithmetic ||
|
||||||
!ctx->device->subgroup_shuffle ||
|
!ctx->device->subgroup_shuffle ||
|
||||||
!ctx->device->subgroup_require_full_support ||
|
!ctx->device->subgroup_require_full_support ||
|
||||||
@@ -12494,10 +12539,18 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
|
|||||||
ctx->num_additional_fused_ops = num_adds - 1;
|
ctx->num_additional_fused_ops = num_adds - 1;
|
||||||
} else if (ggml_vk_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
|
} else if (ggml_vk_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
|
||||||
ctx->num_additional_fused_ops = 1;
|
ctx->num_additional_fused_ops = 1;
|
||||||
} else if (ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, true)) {
|
} else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_early_softmax_norm, { i + 3, i + 9 }) &&
|
||||||
ctx->num_additional_fused_ops = topk_moe_norm.size() - 1;
|
ggml_check_edges(cgraph, i, topk_moe_early_softmax_norm_edges) &&
|
||||||
} else if (ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, false)) {
|
ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_EARLY_SOFTMAX_NORM)) {
|
||||||
ctx->num_additional_fused_ops = topk_moe.size() - 1;
|
ctx->num_additional_fused_ops = topk_moe_early_softmax_norm.size() - 1;
|
||||||
|
} else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_early_softmax, { i + 3, i + 4 }) &&
|
||||||
|
ggml_check_edges(cgraph, i, topk_moe_early_softmax_edges) &&
|
||||||
|
ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_EARLY_SOFTMAX)) {
|
||||||
|
ctx->num_additional_fused_ops = topk_moe_early_softmax.size() - 1;
|
||||||
|
} else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_late_softmax, { i + 1, i + 5 }) &&
|
||||||
|
ggml_check_edges(cgraph, i, topk_moe_late_softmax_edges) &&
|
||||||
|
ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_LATE_SOFTMAX)) {
|
||||||
|
ctx->num_additional_fused_ops = topk_moe_late_softmax.size() - 1;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
ggml_vk_build_graph(ctx, cgraph, i, nullptr, 0, true, false, false, false);
|
ggml_vk_build_graph(ctx, cgraph, i, nullptr, 0, true, false, false, false);
|
||||||
@@ -12595,10 +12648,18 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
|
|||||||
ctx->num_additional_fused_ops = num_adds - 1;
|
ctx->num_additional_fused_ops = num_adds - 1;
|
||||||
} else if (ggml_vk_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
|
} else if (ggml_vk_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
|
||||||
ctx->num_additional_fused_ops = 1;
|
ctx->num_additional_fused_ops = 1;
|
||||||
} else if (ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, true)) {
|
} else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_early_softmax_norm, { i + 3, i + 9 }) &&
|
||||||
ctx->num_additional_fused_ops = topk_moe_norm.size() - 1;
|
ggml_check_edges(cgraph, i, topk_moe_early_softmax_norm_edges) &&
|
||||||
} else if (ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, false)) {
|
ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_EARLY_SOFTMAX_NORM)) {
|
||||||
ctx->num_additional_fused_ops = topk_moe.size() - 1;
|
ctx->num_additional_fused_ops = topk_moe_early_softmax_norm.size() - 1;
|
||||||
|
} else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_early_softmax, { i + 3, i + 4 }) &&
|
||||||
|
ggml_check_edges(cgraph, i, topk_moe_early_softmax_edges) &&
|
||||||
|
ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_EARLY_SOFTMAX)) {
|
||||||
|
ctx->num_additional_fused_ops = topk_moe_early_softmax.size() - 1;
|
||||||
|
} else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_late_softmax, { i + 1, i + 5 }) &&
|
||||||
|
ggml_check_edges(cgraph, i, topk_moe_late_softmax_edges) &&
|
||||||
|
ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_LATE_SOFTMAX)) {
|
||||||
|
ctx->num_additional_fused_ops = topk_moe_late_softmax.size() - 1;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -12730,25 +12791,44 @@ static void ggml_vk_graph_optimize(ggml_backend_t backend, struct ggml_cgraph *
|
|||||||
while (first_unused < graph->n_nodes) {
|
while (first_unused < graph->n_nodes) {
|
||||||
std::vector<int> current_set;
|
std::vector<int> current_set;
|
||||||
|
|
||||||
// Avoid reordering topk_moe_norm
|
// Check for fusion patterns and avoid reordering them
|
||||||
if (first_unused + (int)topk_moe_norm.size() <= graph->n_nodes) {
|
auto const &match_pattern = [&](const std::initializer_list<ggml_op> &pattern, int start) -> bool {
|
||||||
bool is_topk_moe_norm = true;
|
if (start + (int)pattern.size() <= graph->n_nodes) {
|
||||||
for (size_t j = 0; j < topk_moe_norm.size(); ++j) {
|
bool is_pattern = true;
|
||||||
if (graph->nodes[first_unused + j]->op != topk_moe_norm[j] || used[first_unused + j]) {
|
for (size_t j = 0; j < pattern.size(); ++j) {
|
||||||
is_topk_moe_norm = false;
|
if (graph->nodes[start + j]->op != pattern.begin()[j] || used[start + j]) {
|
||||||
|
is_pattern = false;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
return is_pattern;
|
||||||
}
|
}
|
||||||
if (is_topk_moe_norm) {
|
return false;
|
||||||
for (size_t j = 0; j < topk_moe_norm.size(); ++j) {
|
};
|
||||||
|
|
||||||
|
auto const &keep_pattern = [&](const std::initializer_list<ggml_op> &pattern) -> bool {
|
||||||
|
if (match_pattern(pattern, first_unused)) {
|
||||||
|
for (size_t j = 0; j < pattern.size(); ++j) {
|
||||||
new_order.push_back(graph->nodes[first_unused + j]);
|
new_order.push_back(graph->nodes[first_unused + j]);
|
||||||
used[first_unused + j] = true;
|
used[first_unused + j] = true;
|
||||||
}
|
}
|
||||||
while (first_unused < graph->n_nodes && used[first_unused]) {
|
while (first_unused < graph->n_nodes && used[first_unused]) {
|
||||||
first_unused++;
|
first_unused++;
|
||||||
}
|
}
|
||||||
continue;
|
return true;
|
||||||
}
|
}
|
||||||
|
return false;
|
||||||
|
};
|
||||||
|
|
||||||
|
if (keep_pattern(topk_moe_early_softmax_norm)) {
|
||||||
|
continue;
|
||||||
}
|
}
|
||||||
|
if (keep_pattern(topk_moe_early_softmax)) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (keep_pattern(topk_moe_late_softmax)) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
// First, grab the next unused node.
|
// First, grab the next unused node.
|
||||||
current_set.push_back(first_unused);
|
current_set.push_back(first_unused);
|
||||||
|
|
||||||
@@ -12766,6 +12846,12 @@ static void ggml_vk_graph_optimize(ggml_backend_t backend, struct ggml_cgraph *
|
|||||||
if (is_empty(graph->nodes[j])) {
|
if (is_empty(graph->nodes[j])) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
// Don't pull forward nodes from fusion patterns
|
||||||
|
if (match_pattern(topk_moe_early_softmax_norm, j) ||
|
||||||
|
match_pattern(topk_moe_early_softmax, j) ||
|
||||||
|
match_pattern(topk_moe_late_softmax, j)) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
bool ok = true;
|
bool ok = true;
|
||||||
for (int c = first_unused; c < j; ++c) {
|
for (int c = first_unused; c < j; ++c) {
|
||||||
if (!used[c] &&
|
if (!used[c] &&
|
||||||
|
|||||||
@@ -11,6 +11,8 @@ layout (push_constant) uniform parameter
|
|||||||
{
|
{
|
||||||
uint n_rows;
|
uint n_rows;
|
||||||
uint n_expert_used;
|
uint n_expert_used;
|
||||||
|
float clamp_min;
|
||||||
|
float clamp_max;
|
||||||
};
|
};
|
||||||
|
|
||||||
layout(local_size_x_id = 0, local_size_y = 4, local_size_z = 1) in;
|
layout(local_size_x_id = 0, local_size_y = 4, local_size_z = 1) in;
|
||||||
@@ -18,6 +20,7 @@ layout(local_size_x_id = 0, local_size_y = 4, local_size_z = 1) in;
|
|||||||
layout(constant_id = 0) const uint WARP_SIZE = 32;
|
layout(constant_id = 0) const uint WARP_SIZE = 32;
|
||||||
layout(constant_id = 1) const uint n_experts = 512;
|
layout(constant_id = 1) const uint n_experts = 512;
|
||||||
layout(constant_id = 2) const bool with_norm = true;
|
layout(constant_id = 2) const bool with_norm = true;
|
||||||
|
layout(constant_id = 3) const bool late_softmax = false;
|
||||||
|
|
||||||
const uint experts_per_thread = (n_experts > WARP_SIZE) ? n_experts / WARP_SIZE : 1;
|
const uint experts_per_thread = (n_experts > WARP_SIZE) ? n_experts / WARP_SIZE : 1;
|
||||||
|
|
||||||
@@ -25,6 +28,52 @@ layout (binding = 0, std430) readonly buffer Logits {float logits[];};
|
|||||||
layout (binding = 1, std430) writeonly buffer Weights {float weights[];};
|
layout (binding = 1, std430) writeonly buffer Weights {float weights[];};
|
||||||
layout (binding = 2, std430) writeonly buffer Ids {uint ids[];};
|
layout (binding = 2, std430) writeonly buffer Ids {uint ids[];};
|
||||||
|
|
||||||
|
const float INFINITY = 1.0 / 0.0;
|
||||||
|
|
||||||
|
// Warp-local softmax used for both the pre-top-k logits and the post-top-k delayed path.
|
||||||
|
void softmax_warp_inplace(inout float vals[experts_per_thread], const uint limit, const uint lane, const bool use_limit) {
|
||||||
|
float max_val = -INFINITY;
|
||||||
|
|
||||||
|
[[unroll]]
|
||||||
|
for (int i = 0; i < experts_per_thread; i++) {
|
||||||
|
const uint idx = lane + i * WARP_SIZE;
|
||||||
|
const bool is_active = !use_limit || (idx < limit);
|
||||||
|
if (is_active) {
|
||||||
|
max_val = max(max_val, vals[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
max_val = subgroupMax(max_val);
|
||||||
|
|
||||||
|
float sum = 0.f;
|
||||||
|
|
||||||
|
[[unroll]]
|
||||||
|
for (int i = 0; i < experts_per_thread; i++) {
|
||||||
|
const uint idx = lane + i * WARP_SIZE;
|
||||||
|
const bool is_active = !use_limit || (idx < limit);
|
||||||
|
if (is_active) {
|
||||||
|
const float val = exp(vals[i] - max_val);
|
||||||
|
vals[i] = val;
|
||||||
|
sum += val;
|
||||||
|
} else {
|
||||||
|
vals[i] = 0.f;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
sum = subgroupAdd(sum);
|
||||||
|
|
||||||
|
const float inv_sum = 1.0f / sum;
|
||||||
|
|
||||||
|
[[unroll]]
|
||||||
|
for (int i = 0; i < experts_per_thread; i++) {
|
||||||
|
const uint idx = lane + i * WARP_SIZE;
|
||||||
|
const bool is_active = !use_limit || (idx < limit);
|
||||||
|
if (is_active) {
|
||||||
|
vals[i] *= inv_sum;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
void main() {
|
void main() {
|
||||||
const uint row = gl_WorkGroupID.x * gl_WorkGroupSize.y + gl_LocalInvocationID.y;
|
const uint row = gl_WorkGroupID.x * gl_WorkGroupSize.y + gl_LocalInvocationID.y;
|
||||||
if (row >= n_rows) {
|
if (row >= n_rows) {
|
||||||
@@ -35,43 +84,16 @@ void main() {
|
|||||||
const uint weights_offset = n_expert_used * row;
|
const uint weights_offset = n_expert_used * row;
|
||||||
const uint ids_offset = n_experts * row;
|
const uint ids_offset = n_experts * row;
|
||||||
|
|
||||||
float logits_r[experts_per_thread];
|
float wt[experts_per_thread];
|
||||||
|
|
||||||
const float INFINITY = 1.0 / 0.0;
|
|
||||||
|
|
||||||
[[unroll]]
|
[[unroll]]
|
||||||
for (uint i = 0; i < n_experts; i += WARP_SIZE) {
|
for (uint i = 0; i < n_experts; i += WARP_SIZE) {
|
||||||
const uint expert = i + gl_LocalInvocationID.x;
|
const uint expert = i + gl_LocalInvocationID.x;
|
||||||
logits_r[i / WARP_SIZE] = n_experts % WARP_SIZE == 0 || expert < n_experts ? logits[logits_offset + expert] : -INFINITY;
|
wt[i / WARP_SIZE] = (n_experts % WARP_SIZE == 0 || expert < n_experts) ? logits[logits_offset + expert] : -INFINITY;
|
||||||
}
|
}
|
||||||
|
|
||||||
float max_val = logits_r[0];
|
if (!late_softmax) {
|
||||||
|
softmax_warp_inplace(wt, n_experts, gl_LocalInvocationID.x, false);
|
||||||
[[unroll]]
|
|
||||||
for (int i = 1; i < experts_per_thread; i++) {
|
|
||||||
const float val = logits_r[i];
|
|
||||||
max_val = max(val, max_val);
|
|
||||||
}
|
|
||||||
|
|
||||||
max_val = subgroupMax(max_val);
|
|
||||||
|
|
||||||
float wt[experts_per_thread];
|
|
||||||
float tmp = 0.f;
|
|
||||||
|
|
||||||
[[unroll]]
|
|
||||||
for (int i = 0; i < experts_per_thread; i++) {
|
|
||||||
const float val = logits_r[i];
|
|
||||||
wt[i] = exp(val - max_val);
|
|
||||||
tmp += wt[i];
|
|
||||||
}
|
|
||||||
|
|
||||||
tmp = subgroupAdd(tmp);
|
|
||||||
|
|
||||||
const float inv_sum = 1.0f / tmp;
|
|
||||||
|
|
||||||
[[unroll]]
|
|
||||||
for (int i = 0; i < experts_per_thread; i++) {
|
|
||||||
wt[i] = wt[i] * inv_sum;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// at this point, each thread holds a portion of softmax,
|
// at this point, each thread holds a portion of softmax,
|
||||||
@@ -82,6 +104,11 @@ void main() {
|
|||||||
|
|
||||||
float output_weights[experts_per_thread];
|
float output_weights[experts_per_thread];
|
||||||
|
|
||||||
|
[[unroll]]
|
||||||
|
for (int i = 0; i < experts_per_thread; i++) {
|
||||||
|
output_weights[i] = 0.f;
|
||||||
|
}
|
||||||
|
|
||||||
for (int k = 0; k < n_expert_used; k++) {
|
for (int k = 0; k < n_expert_used; k++) {
|
||||||
float max_val = wt[0];
|
float max_val = wt[0];
|
||||||
uint max_expert = gl_LocalInvocationID.x;
|
uint max_expert = gl_LocalInvocationID.x;
|
||||||
@@ -121,6 +148,7 @@ void main() {
|
|||||||
|
|
||||||
if (with_norm) {
|
if (with_norm) {
|
||||||
wt_sum = subgroupAdd(wt_sum);
|
wt_sum = subgroupAdd(wt_sum);
|
||||||
|
wt_sum = clamp(wt_sum, clamp_min, clamp_max);
|
||||||
const float inv_sum = 1.0f / wt_sum;
|
const float inv_sum = 1.0f / wt_sum;
|
||||||
|
|
||||||
[[unroll]]
|
[[unroll]]
|
||||||
@@ -129,6 +157,10 @@ void main() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (late_softmax) {
|
||||||
|
softmax_warp_inplace(output_weights, n_expert_used, gl_LocalInvocationID.x, true);
|
||||||
|
}
|
||||||
|
|
||||||
[[unroll]]
|
[[unroll]]
|
||||||
for (uint i = 0; i < experts_per_thread; ++i) {
|
for (uint i = 0; i < experts_per_thread; ++i) {
|
||||||
uint idx = i * WARP_SIZE + gl_LocalInvocationID.x;
|
uint idx = i * WARP_SIZE + gl_LocalInvocationID.x;
|
||||||
|
|||||||
Reference in New Issue
Block a user