mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-10-31 08:51:55 +00:00 
			
		
		
		
	llama : support for Llama-3_1-Nemotron-51B (#10669)
* conflict resolution * move comments after bracket to its own line
This commit is contained in:
		
							
								
								
									
										267
									
								
								src/llama.cpp
									
									
									
									
									
								
							
							
						
						
									
										267
									
								
								src/llama.cpp
									
									
									
									
									
								
							| @@ -146,6 +146,7 @@ static std::string format(const char * fmt, ...) { | ||||
|  | ||||
| enum llm_arch { | ||||
|     LLM_ARCH_LLAMA, | ||||
|     LLM_ARCH_DECI, | ||||
|     LLM_ARCH_FALCON, | ||||
|     LLM_ARCH_BAICHUAN, | ||||
|     LLM_ARCH_GROK, | ||||
| @@ -203,6 +204,7 @@ enum llm_arch { | ||||
|  | ||||
| static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = { | ||||
|     { LLM_ARCH_LLAMA,            "llama"            }, | ||||
|     { LLM_ARCH_DECI,             "deci"            }, | ||||
|     { LLM_ARCH_FALCON,           "falcon"           }, | ||||
|     { LLM_ARCH_GROK,             "grok"             }, | ||||
|     { LLM_ARCH_GPT2,             "gpt2"             }, | ||||
| @@ -674,6 +676,32 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N | ||||
|             { LLM_TENSOR_FFN_UP_EXPS,     "blk.%d.ffn_up_exps" }, | ||||
|         }, | ||||
|     }, | ||||
|     { | ||||
|         LLM_ARCH_DECI, | ||||
|         { | ||||
|             { LLM_TENSOR_TOKEN_EMBD,      "token_embd" }, | ||||
|             { LLM_TENSOR_OUTPUT_NORM,     "output_norm" }, | ||||
|             { LLM_TENSOR_OUTPUT,          "output" }, | ||||
|             { LLM_TENSOR_ROPE_FREQS,      "rope_freqs" }, | ||||
|             { LLM_TENSOR_ATTN_NORM,       "blk.%d.attn_norm" }, | ||||
|             { LLM_TENSOR_ATTN_Q,          "blk.%d.attn_q" }, | ||||
|             { LLM_TENSOR_ATTN_K,          "blk.%d.attn_k" }, | ||||
|             { LLM_TENSOR_ATTN_V,          "blk.%d.attn_v" }, | ||||
|             { LLM_TENSOR_ATTN_OUT,        "blk.%d.attn_output" }, | ||||
|             { LLM_TENSOR_ATTN_ROT_EMBD,   "blk.%d.attn_rot_embd" }, | ||||
|             { LLM_TENSOR_FFN_GATE_INP,    "blk.%d.ffn_gate_inp" }, | ||||
|             { LLM_TENSOR_FFN_NORM,        "blk.%d.ffn_norm" }, | ||||
|             { LLM_TENSOR_FFN_GATE,        "blk.%d.ffn_gate" }, | ||||
|             { LLM_TENSOR_FFN_DOWN,        "blk.%d.ffn_down" }, | ||||
|             { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" }, | ||||
|             { LLM_TENSOR_FFN_GATE_EXP,    "blk.%d.ffn_gate.%d" }, | ||||
|             { LLM_TENSOR_FFN_DOWN_EXP,    "blk.%d.ffn_down.%d" }, | ||||
|             { LLM_TENSOR_FFN_UP_EXP,      "blk.%d.ffn_up.%d" }, | ||||
|             { LLM_TENSOR_FFN_GATE_EXPS,   "blk.%d.ffn_gate_exps" }, | ||||
|             { LLM_TENSOR_FFN_DOWN_EXPS,   "blk.%d.ffn_down_exps" }, | ||||
|             { LLM_TENSOR_FFN_UP_EXPS,     "blk.%d.ffn_up_exps" }, | ||||
|         }, | ||||
|     }, | ||||
|     { | ||||
|         LLM_ARCH_BAICHUAN, | ||||
|         { | ||||
| @@ -5694,7 +5722,7 @@ static void llm_load_hparams( | ||||
|  | ||||
|         ml.get_key(LLM_KV_ROPE_DIMENSION_COUNT, hparams.n_rot, false); | ||||
|  | ||||
|         if (model.arch == LLM_ARCH_LLAMA || model.arch == LLM_ARCH_FALCON) { | ||||
|         if (model.arch == LLM_ARCH_LLAMA || model.arch == LLM_ARCH_DECI || model.arch == LLM_ARCH_FALCON) { | ||||
|             if (hparams.n_rot != hparams.n_embd_head_k) { | ||||
|                 throw std::runtime_error(format("invalid n_rot: %u, expected %u", hparams.n_rot, hparams.n_embd_head_k)); | ||||
|             } | ||||
| @@ -5734,6 +5762,15 @@ static void llm_load_hparams( | ||||
|                     } | ||||
|                 } | ||||
|             } break; | ||||
|         case LLM_ARCH_DECI: | ||||
|             { | ||||
|                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); | ||||
|                 switch (hparams.n_layer) { | ||||
|                     case 32: model.type = e_model::MODEL_7B; break; | ||||
|                     case 80: model.type = e_model::MODEL_70B; break; | ||||
|                     default: model.type = e_model::MODEL_UNKNOWN; | ||||
|                 } | ||||
|             } break; | ||||
|         case LLM_ARCH_MINICPM: | ||||
|             { | ||||
|                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); | ||||
| @@ -7939,6 +7976,68 @@ static bool llm_load_tensors( | ||||
|                         } | ||||
|                     } | ||||
|                 } break; | ||||
|             case LLM_ARCH_DECI: | ||||
|                 { | ||||
|                     model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); | ||||
|  | ||||
|                     // output | ||||
|                     model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); | ||||
|                     model.output      = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED); | ||||
|  | ||||
|                     // if output is NULL, init from the input tok embed | ||||
|                     if (model.output == NULL) { | ||||
|                         model.output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); | ||||
|                     } | ||||
|  | ||||
|                     for (int i = 0; i < n_layer; ++i) { | ||||
|                         auto & layer = model.layers[i]; | ||||
|                         const int64_t n_embd_k_gqa  = hparams.n_embd_k_gqa(i); | ||||
|                         const int64_t n_embd_v_gqa  = hparams.n_embd_v_gqa(i); | ||||
|                         const int64_t n_embd_gqa    = hparams.n_embd_v_gqa(i); | ||||
|                         const int64_t n_ff          = hparams.n_ff(i); | ||||
|                         const int64_t n_head        = hparams.n_head(i); | ||||
|                         const int64_t n_head_kv     = hparams.n_head_kv(i); | ||||
|  | ||||
|                         if (n_head_kv == 0 && n_head > 0) { | ||||
|                             // linear attention for DeciLMCausalModel | ||||
|                             layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); | ||||
|                             layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); | ||||
|                         } | ||||
|                         else if (n_head_kv > 0) { | ||||
|                             layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); | ||||
|  | ||||
|                             layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd_head_k * n_head}, 0); | ||||
|                             layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_k_gqa}, 0); | ||||
|                             layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_v_gqa}, 0); | ||||
|                             layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); | ||||
|                         } | ||||
|  | ||||
|                         // optional bias tensors | ||||
|                         layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "bias", i), {n_embd},     llama_model_loader::TENSOR_NOT_REQUIRED); | ||||
|                         layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "bias", i), {n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED); | ||||
|                         layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "bias", i), {n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED); | ||||
|                         layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd},     llama_model_loader::TENSOR_NOT_REQUIRED); | ||||
|  | ||||
|                         layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); | ||||
|  | ||||
|                         if (hparams.rope_scaling_type_train == LLAMA_ROPE_SCALING_TYPE_LONGROPE) { | ||||
|                             layer.rope_long  = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_LONG,  "weight", i), {n_rot/2}, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0)); | ||||
|                             layer.rope_short = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight", i), {n_rot/2}, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0)); | ||||
|                         } | ||||
|                         else { | ||||
|                             layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0)); | ||||
|                         } | ||||
|  | ||||
|                         layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff}, 0); | ||||
|                         layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd}, 0); | ||||
|                         layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0); | ||||
|  | ||||
|                         // optional MLP bias | ||||
|                         layer.ffn_gate_b = create_tensor(tn(LLM_TENSOR_FFN_GATE, "bias", i), {n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED); | ||||
|                         layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); | ||||
|                         layer.ffn_up_b   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "bias", i), {n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED); | ||||
|                     } | ||||
|                 } break; | ||||
|             case LLM_ARCH_MINICPM3: | ||||
|                 { | ||||
|                     const int64_t n_embd_head_qk_rope = hparams.n_rot; | ||||
| @@ -11308,6 +11407,167 @@ struct llm_build_context { | ||||
|         return gf; | ||||
|     } | ||||
|  | ||||
|     struct ggml_cgraph * build_deci() { | ||||
|         struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false); | ||||
|  | ||||
|         // mutable variable, needed during the last layer of the computation to skip unused tokens | ||||
|         int32_t n_tokens = this->n_tokens; | ||||
|  | ||||
|         const int64_t n_embd_head = hparams.n_embd_head_v; | ||||
|         GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); | ||||
|         GGML_ASSERT(n_embd_head == hparams.n_rot); | ||||
|  | ||||
|         struct ggml_tensor * cur; | ||||
|         struct ggml_tensor * inpL; | ||||
|  | ||||
|         inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb); | ||||
|  | ||||
|         // inp_pos - contains the positions | ||||
|         struct ggml_tensor * inp_pos = build_inp_pos(); | ||||
|  | ||||
|         // KQ_mask (mask for 1 head, it will be broadcasted to all heads) | ||||
|         struct ggml_tensor * KQ_mask = build_inp_KQ_mask(); | ||||
|  | ||||
|         const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale; | ||||
|         for (int il = 0; il < n_layer; ++il) { | ||||
|             struct ggml_tensor * inpSA = inpL; | ||||
|             const int64_t n_head_kv = hparams.n_head_kv(il); | ||||
|             const int64_t n_head    = hparams.n_head(il); | ||||
|  | ||||
|             if (n_head == 0) { | ||||
|                 // attention-free layer of Llama-3_1-Nemotron-51B | ||||
|                 cur = inpL; | ||||
|             } else { | ||||
|                 // norm | ||||
|                 cur = llm_build_norm(ctx0, inpL, hparams, | ||||
|                         model.layers[il].attn_norm, NULL, | ||||
|                         LLM_NORM_RMS, cb, il); | ||||
|                 cb(cur, "attn_norm", il); | ||||
|             } | ||||
|  | ||||
|             if (n_head > 0 && n_head_kv == 0) { | ||||
|                 // "linear attention" of Llama-3_1-Nemotron-51B | ||||
|                 cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wo, cur); | ||||
|                 cb(cur, "wo", il); | ||||
|             } else if (n_head > 0) { | ||||
|                 // self-attention | ||||
|                 // rope freq factors for llama3; may return nullptr for llama2 and other models | ||||
|                 struct ggml_tensor * rope_factors = build_rope_factors(il); | ||||
|  | ||||
|                 // compute Q and K and RoPE them | ||||
|                 struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur); | ||||
|                 cb(Qcur, "Qcur", il); | ||||
|                 if (model.layers[il].bq) { | ||||
|                     Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); | ||||
|                     cb(Qcur, "Qcur", il); | ||||
|                 } | ||||
|  | ||||
|                 struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur); | ||||
|                 cb(Kcur, "Kcur", il); | ||||
|                 if (model.layers[il].bk) { | ||||
|                     Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); | ||||
|                     cb(Kcur, "Kcur", il); | ||||
|                 } | ||||
|  | ||||
|                 struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur); | ||||
|                 cb(Vcur, "Vcur", il); | ||||
|                 if (model.layers[il].bv) { | ||||
|                     Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); | ||||
|                     cb(Vcur, "Vcur", il); | ||||
|                 } | ||||
|  | ||||
|                 Qcur = ggml_rope_ext( | ||||
|                     ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, rope_factors, | ||||
|                     n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, | ||||
|                     ext_factor, attn_factor, beta_fast, beta_slow | ||||
|                 ); | ||||
|                 cb(Qcur, "Qcur", il); | ||||
|  | ||||
|                 Kcur = ggml_rope_ext( | ||||
|                     ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, rope_factors, | ||||
|                     n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, | ||||
|                     ext_factor, attn_factor, beta_fast, beta_slow | ||||
|                 ); | ||||
|                 cb(Kcur, "Kcur", il); | ||||
|  | ||||
|                 cur = llm_build_kv(ctx0, lctx, kv_self, gf, | ||||
|                         model.layers[il].wo, model.layers[il].bo, | ||||
|                         Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, kq_scale, cb, il); | ||||
|             } | ||||
|  | ||||
|             if (il == n_layer - 1) { | ||||
|                 // skip computing output for unused tokens | ||||
|                 struct ggml_tensor * inp_out_ids = build_inp_out_ids(); | ||||
|                 n_tokens = n_outputs; | ||||
|                 cur   = ggml_get_rows(ctx0,   cur, inp_out_ids); | ||||
|                 inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); | ||||
|             } | ||||
|  | ||||
|             // For Granite architecture | ||||
|             if (hparams.f_residual_scale) { | ||||
|                 cur = ggml_scale(ctx0, cur, hparams.f_residual_scale); | ||||
|             } | ||||
|  | ||||
|             // modified to support attention-free layer of Llama-3_1-Nemotron-51B | ||||
|             struct ggml_tensor * ffn_inp = cur; | ||||
|             if (n_head > 0) { | ||||
|                 ffn_inp = ggml_add(ctx0, cur, inpSA); | ||||
|                 cb(ffn_inp, "ffn_inp", il); | ||||
|             } | ||||
|  | ||||
|             // feed-forward network | ||||
|             if (model.layers[il].ffn_gate_inp == nullptr) { | ||||
|                 cur = llm_build_norm(ctx0, ffn_inp, hparams, | ||||
|                         model.layers[il].ffn_norm, NULL, | ||||
|                         LLM_NORM_RMS, cb, il); | ||||
|                 cb(cur, "ffn_norm", il); | ||||
|  | ||||
|                 cur = llm_build_ffn(ctx0, lctx, cur, | ||||
|                         model.layers[il].ffn_up,   model.layers[il].ffn_up_b,   NULL, | ||||
|                         model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, NULL, | ||||
|                         model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, | ||||
|                         NULL, | ||||
|                         LLM_FFN_SILU, LLM_FFN_PAR, cb, il); | ||||
|                 cb(cur, "ffn_out", il); | ||||
|             } | ||||
|  | ||||
|             // For Granite architecture | ||||
|             if (hparams.f_residual_scale) { | ||||
|                 cur = ggml_scale(ctx0, cur, hparams.f_residual_scale); | ||||
|             } | ||||
|  | ||||
|             cur = ggml_add(ctx0, cur, ffn_inp); | ||||
|             cb(cur, "ffn_out", il); | ||||
|  | ||||
|             cur = lctx.cvec.apply_to(ctx0, cur, il); | ||||
|             cb(cur, "l_out", il); | ||||
|  | ||||
|             // input for next layer | ||||
|             inpL = cur; | ||||
|         } | ||||
|  | ||||
|         cur = inpL; | ||||
|  | ||||
|         cur = llm_build_norm(ctx0, cur, hparams, | ||||
|                 model.output_norm, NULL, | ||||
|                 LLM_NORM_RMS, cb, -1); | ||||
|         cb(cur, "result_norm", -1); | ||||
|  | ||||
|         // lm_head | ||||
|         cur = llm_build_lora_mm(lctx, ctx0, model.output, cur); | ||||
|  | ||||
|         // For Granite architecture | ||||
|         if (hparams.f_logit_scale) { | ||||
|             cur = ggml_scale(ctx0, cur, 1.0f / hparams.f_logit_scale); | ||||
|         } | ||||
|  | ||||
|         cb(cur, "result_output", -1); | ||||
|  | ||||
|         ggml_build_forward_expand(gf, cur); | ||||
|  | ||||
|         return gf; | ||||
|     } | ||||
|  | ||||
|     struct ggml_cgraph * build_baichuan() { | ||||
|         struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false); | ||||
|  | ||||
| @@ -17422,6 +17682,10 @@ static struct ggml_cgraph * llama_build_graph( | ||||
|             { | ||||
|                 result = llm.build_llama(); | ||||
|             } break; | ||||
|         case LLM_ARCH_DECI: | ||||
|             { | ||||
|                 result = llm.build_deci(); | ||||
|             } break; | ||||
|         case LLM_ARCH_BAICHUAN: | ||||
|             { | ||||
|                 result = llm.build_baichuan(); | ||||
| @@ -20797,6 +21061,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) { | ||||
|  | ||||
|         // use what we call a normal RoPE, operating on pairs of consecutive head values | ||||
|         case LLM_ARCH_LLAMA: | ||||
|         case LLM_ARCH_DECI: | ||||
|         case LLM_ARCH_BAICHUAN: | ||||
|         case LLM_ARCH_STARCODER: | ||||
|         case LLM_ARCH_PLAMO: | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 ymcki
					ymcki