Mirror of https://github.com/ggml-org/llama.cpp.git (synced 2025-10-31 08:51:55 +00:00)

	llama : merge build_moe_ffn_from_probs function into build_moe_ffn (#14968)
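The removed build_moe_ffn_from_probs (shown in full in the third hunk below) routed tokens from precomputed probabilities: pick the top-k experts per token, gather their scores, then turn them into mixing weights with either a softmax or a sigmoid followed by renormalization. That same scheme is now reached through build_moe_ffn by passing probs_in. Below is a standalone sketch of that routing step in plain C++, illustrative only: it is not part of the patch, route_token and gating_func are hypothetical names used for the example, and the optional exp_probs_b selection bias is omitted.

// Standalone sketch (plain C++, no ggml) of the probability-driven routing:
// top-k expert selection per token, then softmax or sigmoid + renormalization
// of the selected scores. The real code builds these steps as graph nodes
// (ggml_top_k, ggml_get_rows, ggml_soft_max / ggml_sigmoid, ggml_div).
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <numeric>
#include <utility>
#include <vector>

enum class gating_func { softmax, sigmoid };

// probs: routing scores for one token, one entry per expert.
// Returns (expert index, mixing weight) pairs for the n_used selected experts.
std::vector<std::pair<int, float>> route_token(const std::vector<float> & probs,
                                               int n_used, gating_func gating) {
    // top-k selection: indices of the n_used largest scores
    std::vector<int> order(probs.size());
    std::iota(order.begin(), order.end(), 0);
    std::partial_sort(order.begin(), order.begin() + n_used, order.end(),
                      [&](int a, int b) { return probs[a] > probs[b]; });

    std::vector<std::pair<int, float>> selected;
    for (int i = 0; i < n_used; ++i) {
        selected.emplace_back(order[i], probs[order[i]]);
    }

    if (gating == gating_func::softmax) {
        // softmax over the selected scores
        float max_v = selected[0].second;
        float sum = 0.0f;
        for (auto & [idx, w] : selected) { w = std::exp(w - max_v); sum += w; }
        for (auto & [idx, w] : selected) { w /= sum; }
    } else {
        // sigmoid, then normalize so the selected weights sum to 1
        float sum = 0.0f;
        for (auto & [idx, w] : selected) { w = 1.0f / (1.0f + std::exp(-w)); sum += w; }
        for (auto & [idx, w] : selected) { w /= sum; }
    }
    return selected;
}

int main() {
    std::vector<float> probs = {0.1f, 2.0f, -0.5f, 1.2f}; // 4 experts, 1 token
    for (auto [idx, w] : route_token(probs, 2, gating_func::sigmoid)) {
        std::printf("expert %d weight %.3f\n", idx, w);
    }
}

The hunks below show the corresponding graph-level changes.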
src/llama-graph.cpp
@@ -785,13 +785,20 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
                 bool   scale_w,
                float   w_scale,
          llama_expert_gating_func_type gating_op,
-                 int   il) const {
+                 int   il,
+         ggml_tensor * probs_in) const {
     const int64_t n_embd   = cur->ne[0];
     const int64_t n_tokens = cur->ne[1];
     const bool weight_before_ffn = arch == LLM_ARCH_LLAMA4; // for llama4, we apply the sigmoid-ed weights before the FFN
 
-    ggml_tensor * logits = build_lora_mm(gate_inp, cur); // [n_expert, n_tokens]
-    cb(logits, "ffn_moe_logits", il);
+    ggml_tensor * logits = nullptr;
+
+    if (probs_in == nullptr) {
+        logits = build_lora_mm(gate_inp, cur); // [n_expert, n_tokens]
+        cb(logits, "ffn_moe_logits", il);
+    } else {
+        logits = probs_in;
+    }
 
     ggml_tensor * probs = nullptr;
     switch (gating_op) {
@@ -884,6 +891,14 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
                 cur = ggml_gelu(ctx0, cur);
                 cb(cur, "ffn_moe_gelu", il);
             } break;
+        case LLM_FFN_RELU:
+            if (gate_exps) {
+                cur = ggml_reglu_split(ctx0, cur, up);
+                cb(cur, "ffn_moe_reglu", il);
+            } else {
+                cur = ggml_relu(ctx0, cur);
+                cb(cur, "ffn_moe_relu", il);
+            } break;
         default:
             GGML_ABORT("fatal error");
     }
@@ -927,100 +942,6 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
     return moe_out;
 }
 
-ggml_tensor * llm_graph_context::build_moe_ffn_from_probs(
-         ggml_tensor * cur,
-         ggml_tensor * probs,
-         ggml_tensor * up_exps,
-         ggml_tensor * gate_exps,
-         ggml_tensor * down_exps,
-         ggml_tensor * exp_probs_b,
-             int64_t   n_expert,
-             int64_t   n_expert_used,
-             llama_expert_gating_func_type gating_op,
-                 int   il) const {
-    const int64_t n_embd   = cur->ne[0];
-    const int64_t n_tokens = cur->ne[1];
-
-    // add experts selection bias - introduced in DeepSeek V3
-    // leave probs unbiased as it's later used to get expert weights
-    ggml_tensor * selection_probs = probs;
-    if (exp_probs_b != nullptr) {
-        selection_probs = ggml_add(ctx0, probs, exp_probs_b);
-        cb(selection_probs, "ffn_moe_probs_biased", il);
-    }
-
-    // select experts
-    ggml_tensor * selected_experts = ggml_top_k(ctx0, selection_probs, n_expert_used); // [n_expert_used, n_tokens]
-    cb(selected_experts->src[0], "ffn_moe_argsort", il);
-    cb(selected_experts, "ffn_moe_topk", il);
-
-    ggml_tensor * weights = ggml_get_rows(ctx0,
-            ggml_reshape_3d(ctx0, probs, 1, n_expert, n_tokens), selected_experts); // [1, n_expert_used, n_tokens]
-    cb(weights, "ffn_moe_weights", il);
-
-    weights = ggml_reshape_2d(ctx0, weights, n_expert_used, n_tokens);
-    if (gating_op == LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX) {
-        weights = ggml_soft_max(ctx0, weights);
-    } else {
-        weights = ggml_sigmoid(ctx0, weights);
-        ggml_tensor * weights_sum = ggml_sum_rows(ctx0, weights); // [1, n_tokens]
-        cb(weights_sum, "ffn_moe_weights_sum", il);
-
-        weights = ggml_div(ctx0, weights, weights_sum); // [n_expert_used, n_tokens]
-        cb(weights, "ffn_moe_weights_norm", il);
-    }
-
-    weights = ggml_reshape_3d(ctx0, weights, 1, n_expert_used, n_tokens);
-
-    cur = ggml_reshape_3d(ctx0, cur, n_embd, 1, n_tokens);
-
-    ggml_tensor * up = build_lora_mm_id(up_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
-    cb(up, "ffn_moe_up", il);
-
-    ggml_tensor * experts = nullptr;
-    cur = build_lora_mm_id(gate_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
-    cb(cur, "ffn_moe_gate", il);
-
-    cur = ggml_reglu_split(ctx0, cur, up);
-    cb(cur, "ffn_moe_reglu", il);
-
-    experts = build_lora_mm_id(down_exps, cur, selected_experts); // [n_embd, n_expert_used, n_tokens]
-    cb(experts, "ffn_moe_down", il);
-
-    experts = ggml_mul(ctx0, experts, weights);
-    cb(cur, "ffn_moe_weighted", il);
-
-    ggml_tensor * cur_experts[LLAMA_MAX_EXPERTS] = { nullptr };
-
-    assert(n_expert_used > 0);
-
-    // order the views before the adds
-    for (uint32_t i = 0; i < hparams.n_expert_used; ++i) {
-        cur_experts[i] = ggml_view_2d(ctx0, experts, n_embd, n_tokens, experts->nb[2], i*experts->nb[1]);
-
-        ggml_build_forward_expand(gf, cur_experts[i]);
-    }
-
-    // aggregate experts
-    // note: here we explicitly use hparams.n_expert_used instead of n_expert_used
-    //       to avoid potentially a large number of add nodes during warmup
-    //       ref: https://github.com/ggml-org/llama.cpp/pull/14753
-    ggml_tensor * moe_out = cur_experts[0];
-
-    for (uint32_t i = 1; i < hparams.n_expert_used; ++i) {
-        moe_out = ggml_add(ctx0, moe_out, cur_experts[i]);
-    }
-
-    if (n_expert_used == 1) {
-        // avoid returning a non-contiguous tensor
-        moe_out = ggml_cont(ctx0, moe_out);
-    }
-
-    cb(moe_out, "ffn_moe_out", il);
-
-    return moe_out;
-}
-
 // input embeddings with optional lora
 ggml_tensor * llm_graph_context::build_inp_embd(ggml_tensor * tok_embd) const {
     const int64_t n_embd = hparams.n_embd;
|   | ||||
src/llama-graph.h
@@ -631,19 +631,8 @@ struct llm_graph_context {
                     bool   scale_w,
                    float   w_scale,
             llama_expert_gating_func_type gating_op,
-                     int   il) const;
-
-    ggml_tensor * build_moe_ffn_from_probs(
-             ggml_tensor * cur,
-             ggml_tensor * probs,
-             ggml_tensor * up_exps,
-             ggml_tensor * gate_exps,
-             ggml_tensor * down_exps,
-             ggml_tensor * exp_probs_b,
-                 int64_t   n_expert,
-                 int64_t   n_expert_used,
-            llama_expert_gating_func_type gating_op,
-                     int   il) const;
+                     int   il,
+             ggml_tensor * probs_in = nullptr) const;
 
     //
     // inputs
|   | ||||
src/llama-model.cpp
@@ -17320,10 +17320,18 @@ struct llm_build_smallthinker : public llm_graph_context{
             cur = build_norm(ffn_inp, model.layers[il].ffn_norm, NULL, LLM_NORM_RMS, il);
             cb(cur, "ffn_norm", il);
 
-            ggml_tensor * ffn_out = build_moe_ffn_from_probs(cur, probs, model.layers[il].ffn_up_exps,
-                                                model.layers[il].ffn_gate_exps, model.layers[il].ffn_down_exps,
-                                                nullptr, n_expert, n_expert_used,
-                                                static_cast<llama_expert_gating_func_type>(hparams.expert_gating_func), il);
+            ggml_tensor * ffn_out =
+                build_moe_ffn(cur,
+                        nullptr,
+                        model.layers[il].ffn_up_exps,
+                        model.layers[il].ffn_gate_exps,
+                        model.layers[il].ffn_down_exps,
+                        nullptr,
+                        n_expert, n_expert_used,
+                        LLM_FFN_RELU, true,
+                        false, 0.0,
+                        static_cast<llama_expert_gating_func_type>(hparams.expert_gating_func),
+                        il, probs);
 
             cb(ffn_out, "ffn_out", il);
             cur = ffn_out;