mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-10-27 08:21:30 +00:00
nvidia nemotron nano v2 (nemotronh) (#15507)
* feat: Add NEMOTRONH to python arch enum https://github.com/ggml-org/llama.cpp/issues/nemotron-nano-15409 Branch: gabe-l-hart/nvidia-nemotron-nano-15409 Signed-off-by: Gabe Goodhart <ghart@us.ibm.com> * feat: Add NEMOTRONH to c++ arch enum https://github.com/ggml-org/llama.cpp/issues/nemotron-nano-15409 Branch: gabe-l-hart/nvidia-nemotron-nano-15409 Signed-off-by: Gabe Goodhart <ghart@us.ibm.com> * feat: Add NEMOTRONH to llama-arch layer map https://github.com/ggml-org/llama.cpp/issues/nemotron-nano-15409 Branch: gabe-l-hart/nvidia-nemotron-nano-15409 Signed-off-by: Gabe Goodhart <ghart@us.ibm.com> * feat: First pass at conversion for nemotronh https://github.com/ggml-org/llama.cpp/issues/nemotron-nano-15409 Branch: gabe-l-hart/nvidia-nemotron-nano-15409 Signed-off-by: Gabe Goodhart <ghart@us.ibm.com> * feat: Add a verbose log for each tensor loaded This is really helpful for diagnosing mismatches between the expected and received tensors https://github.com/ggml-org/llama.cpp/issues/nemotron-nano-15409 Branch: gabe-l-hart/nvidia-nemotron-nano-15409 Signed-off-by: Gabe Goodhart <ghart@us.ibm.com> * feat: First (broken) pass at nemotronh model architecture It generates tokens, just not valid ones! https://github.com/ggml-org/llama.cpp/issues/nemotron-nano-15409 Branch: gabe-l-hart/nvidia-nemotron-nano-15409 Signed-off-by: Gabe Goodhart <ghart@us.ibm.com> * fix: Explicitly enable add_bos_token during conversion The `tokenizer.json`/`tokenizer_config.json` in the model are a bit contradictory. In the config, add_bos_token is set to False, but the tokenizer model itself has a post_processor that adds the BOS token via type: TemplateProcessing https://github.com/ggml-org/llama.cpp/issues/nemotron-nano-15409 Branch: gabe-l-hart/nvidia-nemotron-nano-15409 Signed-off-by: Gabe Goodhart <ghart@us.ibm.com> * fix: Use relu2 (LLM_FFN_RELU_SQR) for activation in FFN layers https://github.com/ggml-org/llama.cpp/issues/nemotron-nano-15409 Branch: gabe-l-hart/nvidia-nemotron-nano-15409 Signed-off-by: Gabe Goodhart <ghart@us.ibm.com> * fix: Only allocate attention cache for attention layers (not non-recurrent) https://github.com/ggml-org/llama.cpp/issues/nemotron-nano-15409 Branch: gabe-l-hart/nvidia-nemotron-nano-15409 Signed-off-by: Gabe Goodhart <ghart@us.ibm.com> * fix: Move residual add to after every block https://github.com/ggml-org/llama.cpp/issues/nemotron-nano-15409 Branch: gabe-l-hart/nvidia-nemotron-nano-15409 Signed-off-by: Gabe Goodhart <ghart@us.ibm.com> * fix: Use the correct norm tensor for the MLP blocks https://github.com/ggml-org/llama.cpp/issues/nemotron-nano-15409 Branch: gabe-l-hart/nvidia-nemotron-nano-15409 Signed-off-by: Gabe Goodhart <ghart@us.ibm.com> * Nemotron-H: MLP gate cleanup (pass NULL for unused gate) This model does not use a gate in MLP blocks; pass NULLs for gate tensors to make intent clear and avoid unused-pointer noise. * SSM: respect ssm_dt_rank for dt_dim when provided Use GGUF-provided time_step_rank (ssm_dt_rank) to set dt_dim when > 0; fallback to max(64, n_embd/16). * fix: plamo2 - revert dt_dim to default (remove ssm_dt_rank usage) * Rename nemotronh to nemotron_h for consistency - Update architecture name from NEMOTRONH to NEMOTRON_H in constants.py - Change architecture string from 'nemotronh' to 'nemotron_h' in all files - Update enum LLM_ARCH_NEMOTRONH to LLM_ARCH_NEMOTRON_H - Update class name llm_build_nemotronh to llm_build_nemotron_h - Consistent naming with underscore convention (nemotron_h vs nemotronh) * feat: Support conversion for older NemotronH models https://github.com/ggml-org/llama.cpp/issues/nemotron-nano-15409 Branch: gabe-l-hart/nvidia-nemotron-nano-15409 Signed-off-by: Gabe Goodhart <ghart@us.ibm.com> --------- Signed-off-by: Gabe Goodhart <ghart@us.ibm.com> Co-authored-by: Maicon Domingues <dominguesm@outlook.com> Co-authored-by: weatherman <fxdstudios@gmail.com>
This commit is contained in:
@@ -7546,9 +7546,13 @@ class GraniteHybridModel(Mamba2Model, GraniteMoeModel):
|
||||
]
|
||||
|
||||
# n_group and d_inner are used during reshape_tensors for mamba2
|
||||
self.d_model = self.find_hparam(["hidden_size", "d_model"])
|
||||
self.n_group = self.find_hparam(["n_groups"])
|
||||
self.d_inner = self.find_hparam(["expand"]) * self.d_model
|
||||
# NOTE: Explicitly include hparam prefix prefix for d_model to
|
||||
# disambiguate with top-level head_dim
|
||||
# NOTE 2: If needed for future models, this can be isolated in a method
|
||||
# to separate the prefix setting and teh keys used
|
||||
self.d_model = self.find_hparam([f"{self.hparam_prefixes[0]}_head_dim", "hidden_size", "d_model"])
|
||||
self.n_group = self.find_hparam(["n_groups", "num_groups"])
|
||||
self.d_inner = self.find_hparam(["expand", "num_heads"]) * self.d_model
|
||||
|
||||
def get_attn_layers(self):
|
||||
# Explicit list of layer type names
|
||||
@@ -7609,12 +7613,12 @@ class GraniteHybridModel(Mamba2Model, GraniteMoeModel):
|
||||
|
||||
## Mamba mixer params ##
|
||||
self.gguf_writer.add_ssm_conv_kernel(self.find_hparam(["conv_kernel", "d_conv"]))
|
||||
self.gguf_writer.add_ssm_state_size(self.find_hparam(["state_size", "d_state"]))
|
||||
self.gguf_writer.add_ssm_state_size(self.find_hparam(["state_size", "d_state", "state_dim", "ssm_state_size"]))
|
||||
self.gguf_writer.add_ssm_group_count(self.n_group)
|
||||
self.gguf_writer.add_ssm_inner_size(self.d_inner)
|
||||
# NOTE: The mamba_dt_rank is _not_ the right field for how this is used
|
||||
# in llama.cpp
|
||||
self.gguf_writer.add_ssm_time_step_rank(self.find_hparam(["n_heads"]))
|
||||
self.gguf_writer.add_ssm_time_step_rank(self.find_hparam(["n_heads", "num_heads"]))
|
||||
|
||||
## Attention params ##
|
||||
head_count_kv = self.find_hparam(["num_key_value_heads", "n_head_kv"])
|
||||
@@ -7641,6 +7645,55 @@ class GraniteHybridModel(Mamba2Model, GraniteMoeModel):
|
||||
Mamba2Model.set_vocab(self)
|
||||
|
||||
|
||||
@ModelBase.register("NemotronHForCausalLM")
|
||||
class NemotronHModel(GraniteHybridModel):
|
||||
"""Hybrid mamba2/attention model from NVIDIA"""
|
||||
model_arch = gguf.MODEL_ARCH.NEMOTRON_H
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
# Save the top-level head_dim for later
|
||||
self.head_dim = self.hparams.get("head_dim", self.hparams.get("attention_head_dim"))
|
||||
assert self.head_dim is not None, "Could not find the attention head dim in config"
|
||||
|
||||
# Don't use expand to calculate d_inner
|
||||
self.d_inner = self.find_hparam(["num_heads"]) * self.d_model
|
||||
|
||||
# Update the ssm / attn / mlp layers
|
||||
# M: Mamba2, *: Attention, -: MLP
|
||||
hybrid_override_pattern = self.hparams["hybrid_override_pattern"]
|
||||
self._ssm_layers = [i for i, val in enumerate(hybrid_override_pattern) if val == "M"]
|
||||
self._mlp_layers = [i for i, val in enumerate(hybrid_override_pattern) if val == "-"]
|
||||
|
||||
def get_attn_layers(self):
|
||||
hybrid_override_pattern = self.hparams["hybrid_override_pattern"]
|
||||
assert len(hybrid_override_pattern) == self.block_count, "Mismatch between hybrid override and num_hidden_layers!"
|
||||
return [i for i, val in enumerate(hybrid_override_pattern) if val == "*"]
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
super().set_gguf_parameters()
|
||||
|
||||
self.gguf_writer.add_key_length(self.head_dim)
|
||||
self.gguf_writer.add_value_length(self.head_dim)
|
||||
|
||||
# Set feed_forward_length
|
||||
# NOTE: This will trigger an override warning. This is preferrable to
|
||||
# duplicating all the parent logic
|
||||
n_ff = self.find_hparam(["intermediate_size", "n_inner", "hidden_dim"])
|
||||
self.gguf_writer.add_feed_forward_length([
|
||||
n_ff if i in self._mlp_layers else 0 for i in range(self.block_count)
|
||||
])
|
||||
|
||||
def set_vocab(self):
|
||||
super().set_vocab()
|
||||
|
||||
# The tokenizer _does_ add a BOS token (via post_processor type
|
||||
# TemplateProcessing) but does not set add_bos_token to true in the
|
||||
# config, so we need to explicitly override it here.
|
||||
self.gguf_writer.add_add_bos_token(True)
|
||||
|
||||
|
||||
@ModelBase.register("BailingMoeForCausalLM")
|
||||
class BailingMoeModel(TextModel):
|
||||
model_arch = gguf.MODEL_ARCH.BAILINGMOE
|
||||
|
||||
Reference in New Issue
Block a user