	nvidia nemotron nano v2 (nemotronh) (#15507)
* feat: Add NEMOTRONH to python arch enum
* feat: Add NEMOTRONH to c++ arch enum
* feat: Add NEMOTRONH to llama-arch layer map
* feat: First pass at conversion for nemotronh
* feat: Add a verbose log for each tensor loaded. This is really helpful for diagnosing mismatches between the expected and received tensors.
* feat: First (broken) pass at nemotronh model architecture. It generates tokens, just not valid ones!
* fix: Explicitly enable add_bos_token during conversion. The `tokenizer.json`/`tokenizer_config.json` in the model are a bit contradictory: in the config, add_bos_token is set to False, but the tokenizer model itself has a post_processor that adds the BOS token via type: TemplateProcessing.
* fix: Use relu2 (LLM_FFN_RELU_SQR) for activation in FFN layers (sketched below)
* fix: Only allocate attention cache for attention layers (not non-recurrent)
* fix: Move residual add to after every block
* fix: Use the correct norm tensor for the MLP blocks
* Nemotron-H: MLP gate cleanup (pass NULL for unused gate). This model does not use a gate in MLP blocks; pass NULLs for gate tensors to make intent clear and avoid unused-pointer noise.
* SSM: respect ssm_dt_rank for dt_dim when provided. Use the GGUF-provided time_step_rank (ssm_dt_rank) to set dt_dim when > 0; otherwise fall back to max(64, n_embd/16) (see the sketch after this list).
* fix: plamo2 - revert dt_dim to default (remove ssm_dt_rank usage)
* Rename nemotronh to nemotron_h for consistency:
  - Update architecture name from NEMOTRONH to NEMOTRON_H in constants.py
  - Change architecture string from 'nemotronh' to 'nemotron_h' in all files
  - Update enum LLM_ARCH_NEMOTRONH to LLM_ARCH_NEMOTRON_H
  - Update class name llm_build_nemotronh to llm_build_nemotron_h
  - Consistent naming with underscore convention (nemotron_h vs nemotronh)
* feat: Support conversion for older NemotronH models

Issue: https://github.com/ggml-org/llama.cpp/issues/nemotron-nano-15409
Branch: gabe-l-hart/nvidia-nemotron-nano-15409

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
Co-authored-by: Maicon Domingues <dominguesm@outlook.com>
Co-authored-by: weatherman <fxdstudios@gmail.com>
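Two of the items above are easy to pin down as code. This is a minimal standalone sketch, not the actual llama.cpp implementation; the helper names (`pick_dt_dim`, `relu_sqr`) and scalar signatures are illustrative assumptions. It shows the dt_dim selection (prefer the GGUF-provided ssm_dt_rank when positive, otherwise max(64, n_embd/16)) and relu2 (LLM_FFN_RELU_SQR), i.e. the square of a standard ReLU.

```cpp
#include <algorithm>
#include <cstdint>

// Sketch of the dt_dim selection described above: use the GGUF-provided
// time_step_rank (ssm_dt_rank) when it is set, otherwise max(64, n_embd/16).
// Names are illustrative, not the llama.cpp internals.
static uint32_t pick_dt_dim(uint32_t ssm_dt_rank, uint32_t n_embd) {
    return ssm_dt_rank > 0 ? ssm_dt_rank : std::max<uint32_t>(64, n_embd / 16);
}

// relu2 as used for the FFN activation (LLM_FFN_RELU_SQR): square of ReLU.
static float relu_sqr(float x) {
    const float r = x > 0.0f ? x : 0.0f;
    return r * r;
}
```

In the real graph build these operate on ggml tensors rather than scalars; the scalar form is only meant to show the intended math. The diff below covers the llama-arch registration side of the change: the new "nemotron_h" arch name string, the per-layer tensor-name map, and the hybrid-architecture check.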
@@ -69,6 +69,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_T5ENCODER,        "t5encoder"        },
     { LLM_ARCH_JAIS,             "jais"             },
     { LLM_ARCH_NEMOTRON,         "nemotron"         },
+    { LLM_ARCH_NEMOTRON_H,       "nemotron_h"       },
     { LLM_ARCH_EXAONE,           "exaone"           },
     { LLM_ARCH_EXAONE4,          "exaone4"          },
     { LLM_ARCH_RWKV6,            "rwkv6"            },
@@ -1550,6 +1551,31 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
             { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
         },
     },
+    {
+        LLM_ARCH_NEMOTRON_H,
+        {
+            { LLM_TENSOR_TOKEN_EMBD,     "token_embd" },
+            { LLM_TENSOR_OUTPUT_NORM,    "output_norm" },
+            { LLM_TENSOR_OUTPUT,         "output" },
+            { LLM_TENSOR_ATTN_NORM,      "blk.%d.attn_norm" },
+            // mamba(2) ssm layers
+            { LLM_TENSOR_SSM_IN,         "blk.%d.ssm_in" },
+            { LLM_TENSOR_SSM_CONV1D,     "blk.%d.ssm_conv1d" },
+            { LLM_TENSOR_SSM_DT,         "blk.%d.ssm_dt" },
+            { LLM_TENSOR_SSM_A,          "blk.%d.ssm_a" },
+            { LLM_TENSOR_SSM_D,          "blk.%d.ssm_d" },
+            { LLM_TENSOR_SSM_NORM,       "blk.%d.ssm_norm" },
+            { LLM_TENSOR_SSM_OUT,        "blk.%d.ssm_out" },
+            // attention layers
+            { LLM_TENSOR_ATTN_Q,         "blk.%d.attn_q" },
+            { LLM_TENSOR_ATTN_K,         "blk.%d.attn_k" },
+            { LLM_TENSOR_ATTN_V,         "blk.%d.attn_v" },
+            { LLM_TENSOR_ATTN_OUT,       "blk.%d.attn_output" },
+            // dense FFN
+            { LLM_TENSOR_FFN_DOWN,       "blk.%d.ffn_down" },
+            { LLM_TENSOR_FFN_UP,         "blk.%d.ffn_up" },
+        },
+    },
     {
         LLM_ARCH_EXAONE,
         {
@@ -2355,6 +2381,7 @@ bool llm_arch_is_hybrid(const llm_arch & arch) {
         case LLM_ARCH_PLAMO2:
         case LLM_ARCH_GRANITE_HYBRID:
         case LLM_ARCH_LFM2:
+        case LLM_ARCH_NEMOTRON_H:
             return true;
         default:
             return false;
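For context on the tensor-name map added above: the "blk.%d.*" entries are printf-style templates, with the block (layer) index substituted in when tensors are resolved, so layer 3's Mamba input projection is named "blk.3.ssm_in". A minimal sketch of that expansion, assuming an illustrative helper (`tensor_name_for_layer` is not a llama.cpp symbol):

```cpp
#include <cstdio>
#include <string>

// Illustrative helper: expand a "blk.%d.*" name template for one layer index.
// llama.cpp performs this substitution internally when loading tensors;
// the helper name and buffer size here are assumptions for the sketch.
static std::string tensor_name_for_layer(const char * tmpl, int il) {
    char buf[96];
    std::snprintf(buf, sizeof(buf), tmpl, il);
    return std::string(buf);
}

int main() {
    // "blk.%d.ssm_in" -> "blk.3.ssm_in" for the fourth block (an SSM layer)
    std::printf("%s\n", tensor_name_for_layer("blk.%d.ssm_in", 3).c_str());
    // attention layers in the hybrid stack use the attn_* templates instead
    std::printf("%s\n", tensor_name_for_layer("blk.%d.attn_q", 3).c_str());
    return 0;
}
```

The llm_arch_is_hybrid change registers nemotron_h as a hybrid architecture alongside PLAMO2, GRANITE_HYBRID, and LFM2, which matches the mix of Mamba(2) SSM layers, attention layers, and dense FFN blocks listed in the tensor-name map.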