mirror of https://github.com/ggml-org/llama.cpp.git

	llama : support Jamba
		| @@ -2300,7 +2300,7 @@ class MambaModel(Model): | ||||
|         self.gguf_writer.add_embedding_length(d_model) | ||||
|         self.gguf_writer.add_feed_forward_length(0) # unused, but seemingly required when loading | ||||
|         self.gguf_writer.add_head_count(0) # unused, but seemingly required when loading | ||||
|         self.gguf_writer.add_block_count(self.hparams["n_layer"]) | ||||
|         self.gguf_writer.add_block_count(self.block_count) | ||||
|         self.gguf_writer.add_ssm_conv_kernel(d_conv) | ||||
|         self.gguf_writer.add_ssm_inner_size(d_inner) | ||||
|         self.gguf_writer.add_ssm_state_size(d_state) | ||||
| @@ -2346,6 +2346,107 @@ class MambaModel(Model): | ||||
|         ) | ||||
|  | ||||
|  | ||||
| @Model.register("JambaForCausalLM") | ||||
| class JambaModel(Model): | ||||
|     model_arch = gguf.MODEL_ARCH.JAMBA | ||||
|  | ||||
|     def get_vocab_base_pre(self, tokenizer) -> str: | ||||
|         del tokenizer  # unused | ||||
|  | ||||
|         return "gpt-2" | ||||
|  | ||||
|     def set_gguf_parameters(self): | ||||
|         d_model = self.find_hparam(["hidden_size", "mamba_d_model"]) | ||||
|         d_conv  = self.find_hparam(["mamba_d_conv"],  optional=True) or 4 | ||||
|         d_inner = self.hparams["mamba_expand"] * d_model | ||||
|         d_state = self.find_hparam(["mamba_d_state"], optional=True) or 16 | ||||
|         # ceiling division | ||||
|         # ref: https://stackoverflow.com/a/17511341/22827863 | ||||
|         # ref: https://github.com/state-spaces/mamba/blob/ce59daea3a090d011d6476c6e5b97f6d58ddad8b/mamba_ssm/modules/mamba_simple.py#L58 | ||||
|         dt_rank      = self.find_hparam(["mamba_dt_rank"], optional=True) or -(d_model // -16) | ||||
|         rms_norm_eps = self.find_hparam(["layer_norm_epsilon", "rms_norm_eps"], optional=True) or 1e-6 | ||||
|         n_kv_head = self.hparams["num_key_value_heads"] | ||||
|         attn_offset = self.hparams["attn_layer_offset"] | ||||
|         attn_period = self.hparams["attn_layer_period"] | ||||
|         n_kv_vec = [0 for _ in range(attn_offset)] + [ | ||||
|             n_kv_head if (i - attn_offset) % attn_period == 0 else 0 for i in range(attn_offset, self.block_count) | ||||
|         ] | ||||
|  | ||||
|         self.gguf_writer.add_name(self.dir_model.name) | ||||
|         self.gguf_writer.add_block_count(self.block_count) | ||||
|         self.gguf_writer.add_context_length(self.hparams["max_position_embeddings"]) | ||||
|         self.gguf_writer.add_embedding_length(d_model) | ||||
|         self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"]) | ||||
|         self.gguf_writer.add_head_count(self.hparams["num_attention_heads"]) | ||||
|         self.gguf_writer.add_head_count_kv(n_kv_vec) | ||||
|         self.gguf_writer.add_ssm_conv_kernel(d_conv) | ||||
|         self.gguf_writer.add_ssm_inner_size(d_inner) | ||||
|         self.gguf_writer.add_ssm_state_size(d_state) | ||||
|         self.gguf_writer.add_ssm_time_step_rank(dt_rank) | ||||
|         self.gguf_writer.add_layer_norm_rms_eps(rms_norm_eps) | ||||
|         self.gguf_writer.add_expert_count(self.hparams["num_experts"]) | ||||
|         self.gguf_writer.add_expert_used_count(self.hparams["num_experts_per_tok"]) | ||||
|         self.gguf_writer.add_file_type(self.ftype) | ||||
|  | ||||
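The per-layer KV head vector built above is how a single GGUF key encodes Jamba's interleaved layout: attention layers store the model's num_key_value_heads, Mamba layers store 0, and the loader later keys off that zero/non-zero distinction (see the n_head_kv == 0 branch in llm_load_tensors further down). A minimal sketch of the same computation, using illustrative Jamba-v0.1-style values rather than anything read from a real config:

    # illustrative hyperparameters, not read from a checkpoint
    attn_offset = 4    # hparams["attn_layer_offset"]
    attn_period = 8    # hparams["attn_layer_period"]
    n_kv_head   = 8    # hparams["num_key_value_heads"]
    block_count = 32   # number of layers

    n_kv_vec = [0 for _ in range(attn_offset)] + [
        n_kv_head if (i - attn_offset) % attn_period == 0 else 0
        for i in range(attn_offset, block_count)
    ]
    # -> [0, 0, 0, 0, 8, 0, 0, 0, 0, 0, 0, 0, 8, 0, ...]
    #    layers 4, 12, 20 and 28 are attention layers; every 0 marks a Mamba layer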
|     _experts: list[dict[str, Tensor]] | None = None | ||||
|  | ||||
|     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: | ||||
|  | ||||
|         # process the experts separately | ||||
|         if ".feed_forward.experts." in name: | ||||
|             n_experts = self.hparams["num_experts"] | ||||
|  | ||||
|             assert bid is not None | ||||
|  | ||||
|             if self._experts is None: | ||||
|                 self._experts = [{} for _ in range(self.block_count)] | ||||
|  | ||||
|             self._experts[bid][name] = data_torch | ||||
|  | ||||
|             if len(self._experts[bid]) >= n_experts * 3: | ||||
|  | ||||
|                 # merge the experts into a single 3d tensor | ||||
|                 for wid in ["down_proj", "gate_proj", "up_proj"]: | ||||
|                     datas: list[Tensor] = [] | ||||
|  | ||||
|                     for xid in range(n_experts): | ||||
|                         ename = f"model.layers.{bid}.feed_forward.experts.{xid}.{wid}.weight" | ||||
|                         datas.append(self._experts[bid][ename]) | ||||
|                         del self._experts[bid][ename] | ||||
|  | ||||
|                     data_torch = torch.stack(datas, dim=0) | ||||
|  | ||||
|                     # using the same merged name as qwen2moe | ||||
|                     merged_name = f"model.layers.{bid}.mlp.experts.{wid}.weight" | ||||
|  | ||||
|                     new_name = self.map_tensor_name(merged_name) | ||||
|  | ||||
|                     yield new_name, data_torch | ||||
|             return | ||||
|  | ||||
|         new_name = self.map_tensor_name(name) | ||||
|  | ||||
|         if name.endswith(".A_log"): | ||||
|             logger.debug("A_log --> A ==> " + new_name) | ||||
|             data_torch = -torch.exp(data_torch) | ||||
|  | ||||
|         yield new_name, data_torch | ||||
|  | ||||
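Expert tensors are buffered per block until all num_experts * 3 projections have been seen, then each of down_proj / gate_proj / up_proj is stacked into one 3-D tensor under the qwen2moe-style merged name, so the existing FFN_*_EXP mappings apply unchanged. The A_log handling above mirrors plain Mamba: the tensor is rewritten once at conversion time to A = -exp(A_log) so the runtime scan uses A directly. A small sketch of the stacking, with toy shapes that are illustrative only:

    import torch

    n_experts, n_ff, n_embd = 4, 32, 16   # toy sizes, not the real model's
    experts = [torch.randn(n_ff, n_embd) for _ in range(n_experts)]  # e.g. one down_proj.weight per expert
    merged = torch.stack(experts, dim=0)
    print(merged.shape)   # torch.Size([4, 32, 16]) -> one {n_expert, n_ff, n_embd} tensor per projection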
|     # same as Mamba | ||||
|     def extra_f32_tensors(self, name: str, new_name: str, bid: int | None, n_dims: int) -> bool: | ||||
|         del n_dims  # unused | ||||
|  | ||||
|         return bid is not None and new_name in ( | ||||
|             self.format_tensor_name(n, bid, ".weight" if name.endswith(".weight") else "") for n in [ | ||||
|                 gguf.MODEL_TENSOR.SSM_CONV1D, | ||||
|                 gguf.MODEL_TENSOR.SSM_X, | ||||
|                 gguf.MODEL_TENSOR.SSM_DT, | ||||
|                 gguf.MODEL_TENSOR.SSM_A, | ||||
|                 gguf.MODEL_TENSOR.SSM_D, | ||||
|             ] | ||||
|         ) | ||||
|  | ||||
|  | ||||
| @Model.register("CohereForCausalLM") | ||||
| class CommandR2Model(Model): | ||||
|     model_arch = gguf.MODEL_ARCH.COMMAND_R | ||||
|   | ||||
| @@ -135,6 +135,7 @@ class MODEL_ARCH(IntEnum): | ||||
|     GEMMA      = auto() | ||||
|     STARCODER2 = auto() | ||||
|     MAMBA      = auto() | ||||
|     JAMBA      = auto() | ||||
|     XVERSE     = auto() | ||||
|     COMMAND_R  = auto() | ||||
|     DBRX       = auto() | ||||
| @@ -180,7 +181,10 @@ class MODEL_TENSOR(IntEnum): | ||||
|     SSM_CONV1D         = auto() | ||||
|     SSM_X              = auto() | ||||
|     SSM_DT             = auto() | ||||
|     SSM_DT_NORM        = auto() | ||||
|     SSM_A              = auto() | ||||
|     SSM_B_NORM         = auto() | ||||
|     SSM_C_NORM         = auto() | ||||
|     SSM_D              = auto() | ||||
|     SSM_OUT            = auto() | ||||
|  | ||||
| @@ -214,6 +218,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = { | ||||
|     MODEL_ARCH.GEMMA:          "gemma", | ||||
|     MODEL_ARCH.STARCODER2:     "starcoder2", | ||||
|     MODEL_ARCH.MAMBA:          "mamba", | ||||
|     MODEL_ARCH.JAMBA:          "jamba", | ||||
|     MODEL_ARCH.XVERSE:         "xverse", | ||||
|     MODEL_ARCH.COMMAND_R:      "command-r", | ||||
|     MODEL_ARCH.DBRX:           "dbrx", | ||||
| @@ -259,7 +264,10 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = { | ||||
|     MODEL_TENSOR.SSM_CONV1D:         "blk.{bid}.ssm_conv1d", | ||||
|     MODEL_TENSOR.SSM_X:              "blk.{bid}.ssm_x", | ||||
|     MODEL_TENSOR.SSM_DT:             "blk.{bid}.ssm_dt", | ||||
|     MODEL_TENSOR.SSM_DT_NORM:        "blk.{bid}.ssm_dt_norm", | ||||
|     MODEL_TENSOR.SSM_A:              "blk.{bid}.ssm_a", | ||||
|     MODEL_TENSOR.SSM_B_NORM:         "blk.{bid}.ssm_b_norm", | ||||
|     MODEL_TENSOR.SSM_C_NORM:         "blk.{bid}.ssm_c_norm", | ||||
|     MODEL_TENSOR.SSM_D:              "blk.{bid}.ssm_d", | ||||
|     MODEL_TENSOR.SSM_OUT:            "blk.{bid}.ssm_out", | ||||
| } | ||||
| @@ -678,6 +686,34 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { | ||||
|         MODEL_TENSOR.SSM_D, | ||||
|         MODEL_TENSOR.SSM_OUT, | ||||
|     ], | ||||
|     MODEL_ARCH.JAMBA: [ | ||||
|         MODEL_TENSOR.TOKEN_EMBD, | ||||
|         MODEL_TENSOR.OUTPUT_NORM, | ||||
|         MODEL_TENSOR.OUTPUT, | ||||
|         MODEL_TENSOR.ATTN_NORM, | ||||
|         MODEL_TENSOR.ATTN_Q, | ||||
|         MODEL_TENSOR.ATTN_K, | ||||
|         MODEL_TENSOR.ATTN_V, | ||||
|         MODEL_TENSOR.ATTN_OUT, | ||||
|         MODEL_TENSOR.SSM_IN, | ||||
|         MODEL_TENSOR.SSM_CONV1D, | ||||
|         MODEL_TENSOR.SSM_X, | ||||
|         MODEL_TENSOR.SSM_DT, | ||||
|         MODEL_TENSOR.SSM_DT_NORM, | ||||
|         MODEL_TENSOR.SSM_A, | ||||
|         MODEL_TENSOR.SSM_B_NORM, | ||||
|         MODEL_TENSOR.SSM_C_NORM, | ||||
|         MODEL_TENSOR.SSM_D, | ||||
|         MODEL_TENSOR.SSM_OUT, | ||||
|         MODEL_TENSOR.FFN_GATE_INP, | ||||
|         MODEL_TENSOR.FFN_NORM, | ||||
|         MODEL_TENSOR.FFN_GATE, | ||||
|         MODEL_TENSOR.FFN_DOWN, | ||||
|         MODEL_TENSOR.FFN_UP, | ||||
|         MODEL_TENSOR.FFN_GATE_EXP, | ||||
|         MODEL_TENSOR.FFN_DOWN_EXP, | ||||
|         MODEL_TENSOR.FFN_UP_EXP, | ||||
|     ], | ||||
|     MODEL_ARCH.XVERSE: [ | ||||
|         MODEL_TENSOR.TOKEN_EMBD, | ||||
|         MODEL_TENSOR.OUTPUT_NORM, | ||||
|   | ||||
| @@ -385,8 +385,11 @@ class GGUFWriter: | ||||
|     def add_head_count(self, count: int) -> None: | ||||
|         self.add_uint32(Keys.Attention.HEAD_COUNT.format(arch=self.arch), count) | ||||
|  | ||||
|     def add_head_count_kv(self, count: int) -> None: | ||||
|         self.add_uint32(Keys.Attention.HEAD_COUNT_KV.format(arch=self.arch), count) | ||||
|     def add_head_count_kv(self, count: int | Sequence[int]) -> None: | ||||
|         if isinstance(count, int): | ||||
|             self.add_uint32(Keys.Attention.HEAD_COUNT_KV.format(arch=self.arch), count) | ||||
|         else: | ||||
|             self.add_array(Keys.Attention.HEAD_COUNT_KV.format(arch=self.arch), count) | ||||
|  | ||||
|     def add_key_length(self, length: int) -> None: | ||||
|         self.add_uint32(Keys.Attention.KEY_LENGTH.format(arch=self.arch), length) | ||||
|   | ||||
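add_head_count_kv now accepts either the usual scalar or a per-layer sequence: an int is still written as a single uint32, while a sequence becomes a GGUF array. A hedged usage sketch (the writer setup and values are illustrative):

    from gguf import GGUFWriter

    writer = GGUFWriter("out.gguf", "jamba")
    # per-layer KV head counts: non-zero entries are attention layers, zeros are Mamba layers
    writer.add_head_count_kv([0, 0, 0, 0, 8, 0, 0, 0])
    # a plain int keeps the previous behaviour:
    # writer.add_head_count_kv(8)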
| @@ -206,6 +206,7 @@ class TensorNameMap: | ||||
|             "h.{bid}.ln_2",                                                  # gpt2 | ||||
|             "model.layers.{bid}.ffn_norm",                                   # internlm2 | ||||
|             "transformer.decoder_layer.{bid}.rms_norm_2",                    # Grok | ||||
|             "model.layers.{bid}.pre_ff_layernorm",                           # jamba | ||||
|         ), | ||||
|  | ||||
|         MODEL_TENSOR.FFN_GATE_INP: ( | ||||
| @@ -214,6 +215,7 @@ class TensorNameMap: | ||||
|             "model.layers.{bid}.mlp.gate",                # qwen2moe | ||||
|             "transformer.decoder_layer.{bid}.router",     # Grok | ||||
|             "transformer.blocks.{bid}.ffn.router.layer",  # dbrx | ||||
|             "model.layers.{bid}.feed_forward.router",     # jamba | ||||
|         ), | ||||
|  | ||||
|         MODEL_TENSOR.FFN_GATE_INP_SHEXP: ( | ||||
| @@ -244,6 +246,7 @@ class TensorNameMap: | ||||
|             "encoder.layers.{bid}.mlp.fc11",                          # nomic-bert | ||||
|             "model.layers.{bid}.mlp.c_fc",                            # starcoder2 | ||||
|             "encoder.layer.{bid}.mlp.gated_layers_v",                 # jina-bert-v2 | ||||
|             "model.layers.{bid}.feed_forward.up_proj",                # jamba | ||||
|         ), | ||||
|  | ||||
|         MODEL_TENSOR.FFN_UP_EXP: ( | ||||
| @@ -272,6 +275,7 @@ class TensorNameMap: | ||||
|             "encoder.layers.{bid}.mlp.fc12",              # nomic-bert | ||||
|             "encoder.layer.{bid}.mlp.gated_layers_w",     # jina-bert-v2 | ||||
|             "transformer.h.{bid}.mlp.linear_1",           # refact | ||||
|             "model.layers.{bid}.feed_forward.gate_proj",  # jamba | ||||
|         ), | ||||
|  | ||||
|         MODEL_TENSOR.FFN_GATE_EXP: ( | ||||
| @@ -306,6 +310,7 @@ class TensorNameMap: | ||||
|             "encoder.layers.{bid}.mlp.fc2",                           # nomic-bert | ||||
|             "model.layers.{bid}.mlp.c_proj",                          # starcoder2 | ||||
|             "encoder.layer.{bid}.mlp.wo",                             # jina-bert-v2 | ||||
|             "model.layers.{bid}.feed_forward.down_proj",              # jamba | ||||
|         ), | ||||
|  | ||||
|         MODEL_TENSOR.FFN_DOWN_EXP: ( | ||||
| @@ -347,38 +352,57 @@ class TensorNameMap: | ||||
|         ), | ||||
|  | ||||
|         MODEL_TENSOR.SSM_IN: ( | ||||
|             "model.layers.{bid}.in_proj", | ||||
|             "backbone.layers.{bid}.mixer.in_proj", | ||||
|             "model.layers.{bid}.in_proj",           # mamba-hf | ||||
|             "backbone.layers.{bid}.mixer.in_proj",  # mamba | ||||
|             "model.layers.{bid}.mamba.in_proj",     # jamba | ||||
|         ), | ||||
|  | ||||
|         MODEL_TENSOR.SSM_CONV1D: ( | ||||
|             "model.layers.{bid}.conv1d", | ||||
|             "backbone.layers.{bid}.mixer.conv1d", | ||||
|             "model.layers.{bid}.conv1d",           # mamba-hf | ||||
|             "backbone.layers.{bid}.mixer.conv1d",  # mamba | ||||
|             "model.layers.{bid}.mamba.conv1d",     # jamba | ||||
|         ), | ||||
|  | ||||
|         MODEL_TENSOR.SSM_X: ( | ||||
|             "model.layers.{bid}.x_proj", | ||||
|             "backbone.layers.{bid}.mixer.x_proj", | ||||
|             "model.layers.{bid}.x_proj",           # mamba-hf | ||||
|             "backbone.layers.{bid}.mixer.x_proj",  # mamba | ||||
|             "model.layers.{bid}.mamba.x_proj",     # jamba | ||||
|         ), | ||||
|  | ||||
|         MODEL_TENSOR.SSM_DT: ( | ||||
|             "model.layers.{bid}.dt_proj", | ||||
|             "backbone.layers.{bid}.mixer.dt_proj", | ||||
|             "model.layers.{bid}.dt_proj",           # mamba-hf | ||||
|             "backbone.layers.{bid}.mixer.dt_proj",  # mamba | ||||
|             "model.layers.{bid}.mamba.dt_proj",     # jamba | ||||
|         ), | ||||
|  | ||||
|         MODEL_TENSOR.SSM_DT_NORM: ( | ||||
|             "model.layers.{bid}.mamba.dt_layernorm",  # jamba | ||||
|         ), | ||||
|  | ||||
|         MODEL_TENSOR.SSM_A: ( | ||||
|             "model.layers.{bid}.A_log", | ||||
|             "backbone.layers.{bid}.mixer.A_log", | ||||
|             "model.layers.{bid}.A_log",           # mamba-hf | ||||
|             "backbone.layers.{bid}.mixer.A_log",  # mamba | ||||
|             "model.layers.{bid}.mamba.A_log",     # jamba | ||||
|         ), | ||||
|  | ||||
|         MODEL_TENSOR.SSM_B_NORM: ( | ||||
|             "model.layers.{bid}.mamba.b_layernorm",  # jamba | ||||
|         ), | ||||
|  | ||||
|         MODEL_TENSOR.SSM_C_NORM: ( | ||||
|             "model.layers.{bid}.mamba.c_layernorm",  # jamba | ||||
|         ), | ||||
|  | ||||
|         MODEL_TENSOR.SSM_D: ( | ||||
|             "model.layers.{bid}.D", | ||||
|             "backbone.layers.{bid}.mixer.D", | ||||
|             "model.layers.{bid}.D",           # mamba-hf | ||||
|             "backbone.layers.{bid}.mixer.D",  # mamba | ||||
|             "model.layers.{bid}.mamba.D",     # jamba | ||||
|         ), | ||||
|  | ||||
|         MODEL_TENSOR.SSM_OUT: ( | ||||
|             "model.layers.{bid}.out_proj", | ||||
|             "backbone.layers.{bid}.mixer.out_proj", | ||||
|             "model.layers.{bid}.out_proj",           # mamba-hf | ||||
|             "backbone.layers.{bid}.mixer.out_proj",  # mamba | ||||
|             "model.layers.{bid}.mamba.out_proj",     # jamba | ||||
|         ), | ||||
|     } | ||||
|  | ||||
|   | ||||
							
								
								
									
llama.cpp
							| @@ -221,6 +221,7 @@ enum llm_arch { | ||||
|     LLM_ARCH_GEMMA, | ||||
|     LLM_ARCH_STARCODER2, | ||||
|     LLM_ARCH_MAMBA, | ||||
|     LLM_ARCH_JAMBA, | ||||
|     LLM_ARCH_XVERSE, | ||||
|     LLM_ARCH_COMMAND_R, | ||||
|     LLM_ARCH_DBRX, | ||||
| @@ -257,6 +258,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = { | ||||
|     { LLM_ARCH_GEMMA,           "gemma"        }, | ||||
|     { LLM_ARCH_STARCODER2,      "starcoder2"   }, | ||||
|     { LLM_ARCH_MAMBA,           "mamba"        }, | ||||
|     { LLM_ARCH_JAMBA,           "jamba"        }, | ||||
|     { LLM_ARCH_XVERSE,          "xverse"       }, | ||||
|     { LLM_ARCH_COMMAND_R,       "command-r"    }, | ||||
|     { LLM_ARCH_DBRX,            "dbrx"         }, | ||||
| @@ -472,7 +474,10 @@ enum llm_tensor { | ||||
|     LLM_TENSOR_SSM_CONV1D, | ||||
|     LLM_TENSOR_SSM_X, | ||||
|     LLM_TENSOR_SSM_DT, | ||||
|     LLM_TENSOR_SSM_DT_NORM, | ||||
|     LLM_TENSOR_SSM_A, | ||||
|     LLM_TENSOR_SSM_B_NORM, | ||||
|     LLM_TENSOR_SSM_C_NORM, | ||||
|     LLM_TENSOR_SSM_D, | ||||
|     LLM_TENSOR_SSM_OUT, | ||||
| }; | ||||
| @@ -970,6 +975,37 @@ static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NA | ||||
|             { LLM_TENSOR_SSM_OUT,         "blk.%d.ssm_out" }, | ||||
|         }, | ||||
|     }, | ||||
|     { | ||||
|         LLM_ARCH_JAMBA, | ||||
|         { | ||||
|             { LLM_TENSOR_TOKEN_EMBD,      "token_embd" }, | ||||
|             { LLM_TENSOR_OUTPUT_NORM,     "output_norm" }, | ||||
|             { LLM_TENSOR_OUTPUT,          "output" }, | ||||
|             { LLM_TENSOR_ATTN_NORM,       "blk.%d.attn_norm" }, | ||||
|             { LLM_TENSOR_SSM_IN,          "blk.%d.ssm_in" }, | ||||
|             { LLM_TENSOR_SSM_CONV1D,      "blk.%d.ssm_conv1d" }, | ||||
|             { LLM_TENSOR_SSM_X,           "blk.%d.ssm_x" }, | ||||
|             { LLM_TENSOR_SSM_DT,          "blk.%d.ssm_dt" }, | ||||
|             { LLM_TENSOR_SSM_DT_NORM,     "blk.%d.ssm_dt_norm" }, | ||||
|             { LLM_TENSOR_SSM_A,           "blk.%d.ssm_a" }, | ||||
|             { LLM_TENSOR_SSM_B_NORM,      "blk.%d.ssm_b_norm" }, | ||||
|             { LLM_TENSOR_SSM_C_NORM,      "blk.%d.ssm_c_norm" }, | ||||
|             { LLM_TENSOR_SSM_D,           "blk.%d.ssm_d" }, | ||||
|             { LLM_TENSOR_SSM_OUT,         "blk.%d.ssm_out" }, | ||||
|             { LLM_TENSOR_ATTN_Q,          "blk.%d.attn_q" }, | ||||
|             { LLM_TENSOR_ATTN_K,          "blk.%d.attn_k" }, | ||||
|             { LLM_TENSOR_ATTN_V,          "blk.%d.attn_v" }, | ||||
|             { LLM_TENSOR_ATTN_OUT,        "blk.%d.attn_output" }, | ||||
|             { LLM_TENSOR_FFN_GATE_INP,    "blk.%d.ffn_gate_inp" }, | ||||
|             { LLM_TENSOR_FFN_NORM,        "blk.%d.ffn_norm" }, | ||||
|             { LLM_TENSOR_FFN_GATE,        "blk.%d.ffn_gate" }, | ||||
|             { LLM_TENSOR_FFN_DOWN,        "blk.%d.ffn_down" }, | ||||
|             { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" }, | ||||
|             { LLM_TENSOR_FFN_GATE_EXPS,   "blk.%d.ffn_gate_exps" }, | ||||
|             { LLM_TENSOR_FFN_DOWN_EXPS,   "blk.%d.ffn_down_exps" }, | ||||
|             { LLM_TENSOR_FFN_UP_EXPS,     "blk.%d.ffn_up_exps" }, | ||||
|         }, | ||||
|     }, | ||||
|     { | ||||
|         LLM_ARCH_XVERSE, | ||||
|         { | ||||
| @@ -1905,6 +1941,9 @@ struct llama_layer { | ||||
|     struct ggml_tensor * attn_k_norm_b; | ||||
|     struct ggml_tensor * attn_out_norm; | ||||
|     struct ggml_tensor * attn_out_norm_b; | ||||
|     struct ggml_tensor * ssm_dt_norm; | ||||
|     struct ggml_tensor * ssm_b_norm; | ||||
|     struct ggml_tensor * ssm_c_norm; | ||||
|  | ||||
|     // attention | ||||
|     struct ggml_tensor * wq; | ||||
| @@ -5150,6 +5189,22 @@ static void llm_load_hparams( | ||||
|                     default: model.type = e_model::MODEL_UNKNOWN; | ||||
|                 } | ||||
|             } break; | ||||
|         case LLM_ARCH_JAMBA: | ||||
|             { | ||||
|                 ml.get_key(LLM_KV_SSM_CONV_KERNEL,    hparams.ssm_d_conv); | ||||
|                 ml.get_key(LLM_KV_SSM_INNER_SIZE,     hparams.ssm_d_inner); | ||||
|                 ml.get_key(LLM_KV_SSM_STATE_SIZE,     hparams.ssm_d_state); | ||||
|                 ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank); | ||||
|  | ||||
|                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); | ||||
|  | ||||
|                 switch (hparams.n_layer) { | ||||
|                     // TODO: Jamba layers are a bit heterogeneous, so naming this is hard. | ||||
|                     case 12: // 900M  8x???M | ||||
|                     case 32: // 51B  16x?B | ||||
|                     default: model.type = e_model::MODEL_UNKNOWN; | ||||
|                 } | ||||
|             } break; | ||||
|         case LLM_ARCH_XVERSE: | ||||
|             { | ||||
|                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); | ||||
| @@ -6854,6 +6909,118 @@ static bool llm_load_tensors( | ||||
|                         layer.ssm_out = ml.create_tensor(ctx_split, tn(LLM_TENSOR_SSM_OUT, "weight", i), {d_inner, n_embd}); | ||||
|                     } | ||||
|                 } break; | ||||
|             case LLM_ARCH_JAMBA: | ||||
|                 { | ||||
|                     const int64_t d_conv  = hparams.ssm_d_conv; | ||||
|                     const int64_t d_inner = hparams.ssm_d_inner; | ||||
|                     const int64_t d_state = hparams.ssm_d_state; | ||||
|                     const int64_t dt_rank = hparams.ssm_dt_rank; | ||||
|  | ||||
|                     // only an expansion factor of 2 is supported for now | ||||
|                     GGML_ASSERT(2 * n_embd == d_inner); | ||||
|                     GGML_ASSERT((int64_t) hparams.n_head_kv_vec.size() == n_layer); | ||||
|  | ||||
|                     model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); | ||||
|  | ||||
|                     // output | ||||
|                     { | ||||
|                         model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); | ||||
|  | ||||
|                         model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED); | ||||
|                         // if output is NULL, init from the input tok embed, duplicated to allow offloading | ||||
|                         if (model.output == NULL) { | ||||
|                             model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); | ||||
|                         } | ||||
|                     } | ||||
|  | ||||
|                     for (int i = 0; i < n_layer; ++i) { | ||||
|                         const int64_t n_head_kv = hparams.n_head_kv_vec[i]; | ||||
|                         const int64_t n_embd_gqa = hparams.n_embd_v_gqa(i); | ||||
|  | ||||
|                         ggml_context * ctx_layer = ctx_for_layer(i); | ||||
|                         ggml_context * ctx_split = ctx_for_layer_split(i); | ||||
|  | ||||
|                         auto & layer = model.layers[i]; | ||||
|  | ||||
|                         // norm | ||||
|                         layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); | ||||
|  | ||||
|                         if (n_head_kv == 0) { | ||||
|                             // Mamba layer | ||||
|                             layer.ssm_in = ml.create_tensor(ctx_split, tn(LLM_TENSOR_SSM_IN, "weight", i), {n_embd, 2*d_inner}); | ||||
|  | ||||
|                             layer.ssm_conv1d = ml.create_tensor(ctx_split, tn(LLM_TENSOR_SSM_CONV1D, "weight", i), {d_conv, d_inner}); | ||||
|                             layer.ssm_conv1d_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_SSM_CONV1D, "bias", i), {d_inner}); | ||||
|  | ||||
|                             layer.ssm_x = ml.create_tensor(ctx_split, tn(LLM_TENSOR_SSM_X, "weight", i), {d_inner, dt_rank + 2*d_state}); | ||||
|  | ||||
|                             layer.ssm_dt_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_SSM_DT_NORM, "weight", i), {dt_rank}); | ||||
|  | ||||
|                             layer.ssm_dt = ml.create_tensor(ctx_split, tn(LLM_TENSOR_SSM_DT, "weight", i), {dt_rank, d_inner}); | ||||
|                             layer.ssm_dt_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_SSM_DT, "bias", i), {d_inner}); | ||||
|  | ||||
|                             layer.ssm_b_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_SSM_B_NORM, "weight", i), {d_state}); | ||||
|                             layer.ssm_c_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_SSM_C_NORM, "weight", i), {d_state}); | ||||
|  | ||||
|                             // no "weight" suffix for these | ||||
|                             layer.ssm_a = ml.create_tensor(ctx_split, tn(LLM_TENSOR_SSM_A, i), {d_state, d_inner}); | ||||
|                             layer.ssm_d = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_SSM_D, i), {d_inner}); | ||||
|  | ||||
|                             // out_proj | ||||
|                             layer.ssm_out = ml.create_tensor(ctx_split, tn(LLM_TENSOR_SSM_OUT, "weight", i), {d_inner, n_embd}); | ||||
|  | ||||
|                             layer.wq = nullptr; | ||||
|                             layer.wk = nullptr; | ||||
|                             layer.wv = nullptr; | ||||
|                             layer.wo = nullptr; | ||||
|  | ||||
|                         } else { | ||||
|                             // Attention layers | ||||
|  | ||||
|                             layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd}); | ||||
|                             layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_gqa}); | ||||
|                             layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_gqa}); | ||||
|                             layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}); | ||||
|  | ||||
|                             layer.ssm_in       = nullptr; | ||||
|                             layer.ssm_conv1d   = nullptr; | ||||
|                             layer.ssm_conv1d_b = nullptr; | ||||
|                             layer.ssm_x        = nullptr; | ||||
|                             layer.ssm_dt_norm  = nullptr; | ||||
|                             layer.ssm_dt       = nullptr; | ||||
|                             layer.ssm_dt_b     = nullptr; | ||||
|                             layer.ssm_b_norm   = nullptr; | ||||
|                             layer.ssm_c_norm   = nullptr; | ||||
|                             layer.ssm_a        = nullptr; | ||||
|                             layer.ssm_d        = nullptr; | ||||
|                             layer.ssm_out      = nullptr; | ||||
|                         } | ||||
|  | ||||
|                         layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}); | ||||
|  | ||||
|                         layer.ffn_gate_inp = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, llama_model_loader::TENSOR_NOT_REQUIRED); | ||||
|  | ||||
|                         if (layer.ffn_gate_inp) { | ||||
|                             // MoE | ||||
|                             layer.ffn_gate_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}); | ||||
|                             layer.ffn_down_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff, n_embd, n_expert}); | ||||
|                             layer.ffn_up_exps   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_EXPS,   "weight", i), {n_embd, n_ff, n_expert}); | ||||
|  | ||||
|                             layer.ffn_gate = nullptr; | ||||
|                             layer.ffn_down = nullptr; | ||||
|                             layer.ffn_up   = nullptr; | ||||
|                         } else { | ||||
|                             // FFN (no MoE) | ||||
|                             layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}); | ||||
|                             layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}); | ||||
|                             layer.ffn_up   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd, n_ff}); | ||||
|  | ||||
|                             layer.ffn_gate_exps = nullptr; | ||||
|                             layer.ffn_down_exps = nullptr; | ||||
|                             layer.ffn_up_exps   = nullptr; | ||||
|                         } | ||||
|                     } | ||||
|                 } break; | ||||
|             case LLM_ARCH_XVERSE: | ||||
|                 { | ||||
|                     model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); | ||||
| @@ -7632,6 +7799,132 @@ static struct ggml_tensor * llm_build_kv( | ||||
|     return cur; | ||||
| } | ||||
|  | ||||
| // TODO: split | ||||
| static struct ggml_tensor * llm_build_mamba( | ||||
|         struct ggml_context * ctx, | ||||
|           const llama_model & model, | ||||
|         const llama_hparams & hparams, | ||||
|        const llama_rs_cache & rs, | ||||
|          struct ggml_cgraph * graph, | ||||
|          struct ggml_tensor * cur, | ||||
|          struct ggml_tensor * state_copy, | ||||
|          struct ggml_tensor * state_mask, | ||||
|          struct ggml_tensor * state_seq, | ||||
|          struct ggml_tensor * w_dt_norm, | ||||
|          struct ggml_tensor * w_b_norm, | ||||
|          struct ggml_tensor * w_c_norm, | ||||
|                     int32_t   n_tokens, | ||||
|                     int32_t   rs_head, | ||||
|                     int32_t   n_rs, | ||||
|          const llm_build_cb & cb, | ||||
|                     int       il) { | ||||
|     const int64_t d_conv  = hparams.ssm_d_conv; | ||||
|     const int64_t d_inner = hparams.ssm_d_inner; | ||||
|     const int64_t d_state = hparams.ssm_d_state; | ||||
|     const int64_t dt_rank = hparams.ssm_dt_rank; | ||||
|  | ||||
|     struct ggml_tensor * conv_states = ggml_reshape_2d(ctx, rs.r_l[il], hparams.n_embd_r(il), rs.size); | ||||
|     struct ggml_tensor * ssm_states  = ggml_reshape_2d(ctx, rs.s_l[il], hparams.n_embd_s(il), rs.size); | ||||
|  | ||||
|     // copy states | ||||
|     { | ||||
|         // TODO: use some sort of read-only head and n to pass smaller tensors to ggml_get_rows | ||||
|         // NOTE: assuming the copy destinations are ALL contained in the current batch | ||||
|         // this shrinks the tensors' ne[1] to n_rs | ||||
|         conv_states = ggml_get_rows(ctx, conv_states, state_copy); | ||||
|         ssm_states  = ggml_get_rows(ctx,  ssm_states, state_copy); | ||||
|     } | ||||
|  | ||||
|     // clear states of sequences which are starting at the beginning of this batch | ||||
|     { | ||||
|         conv_states = ggml_mul(ctx, conv_states, state_mask); | ||||
|         ssm_states  = ggml_mul(ctx,  ssm_states, state_mask); | ||||
|     } | ||||
|  | ||||
|     conv_states = ggml_reshape_3d(ctx, conv_states, d_conv - 1, d_inner, n_rs); | ||||
|     ssm_states  = ggml_reshape_3d(ctx,  ssm_states,    d_state, d_inner, n_rs); | ||||
|  | ||||
|     // {n_embd, 2*d_inner} * {n_embd, n_tokens} => {2*d_inner, n_tokens} | ||||
|     struct ggml_tensor * xz = ggml_mul_mat(ctx, model.layers[il].ssm_in, cur); | ||||
|     // split the above in two | ||||
|     // => {d_inner, n_tokens} | ||||
|     struct ggml_tensor * x = ggml_view_2d(ctx, xz, d_inner, xz->ne[1], xz->nb[1], 0); | ||||
|     struct ggml_tensor * z = ggml_view_2d(ctx, xz, d_inner, xz->ne[1], xz->nb[1], ggml_element_size(xz)*d_inner); | ||||
|  | ||||
|     // conv | ||||
|     { | ||||
|         // Custom operator which is needed only to ease simultaneous sequence processing. | ||||
|         // For a single sequence, the equivalent is to concatenate the columns of conv_states and x, | ||||
|         // then make a self-overlapping view of that over d_conv columns at each stride in the 3rd dimension, | ||||
|         // then element-wise multiply that with the conv1d weight, | ||||
|         // then sum the elements of each row, | ||||
|         // (the last two steps are a dot product over rows (also doable with mul_mat)) | ||||
|         // then permute away the ne[0] dimension, | ||||
|         // and then you're left with the resulting x tensor. | ||||
|         // The new conv_states is the last (d_conv - 1) columns | ||||
|         // of the last 3rd dimensional "layer" of the self-overlapping view. | ||||
|         // For simultaneous sequences, it's more complicated. | ||||
|         struct ggml_tensor * x_conv = ggml_ssm_conv(ctx, conv_states, x, model.layers[il].ssm_conv1d, state_seq); | ||||
|  | ||||
|         // store last (d_conv - 1) columns of the conv_state part of x_conv back into the KV cache | ||||
|         ggml_build_forward_expand(graph, | ||||
|             ggml_cpy(ctx, | ||||
|                 ggml_view_2d(ctx, x_conv, d_conv - 1, d_inner*n_rs, d_conv*ggml_element_size(x_conv), (1+d_inner*n_tokens)*ggml_element_size(x_conv)), | ||||
|                 ggml_view_1d(ctx, rs.r_l[il], (d_conv - 1)*(d_inner)*(n_rs), rs_head*(d_conv - 1)*(d_inner)*ggml_element_size(x_conv)))); | ||||
|  | ||||
|         // extract x from x_conv | ||||
|         x = ggml_view_2d(ctx, x_conv, d_inner, n_tokens, d_inner*ggml_element_size(x_conv), 0); | ||||
|  | ||||
|         // bias | ||||
|         x = ggml_add(ctx, x, model.layers[il].ssm_conv1d_b); | ||||
|  | ||||
|         x = ggml_silu(ctx, x); | ||||
|     } | ||||
|  | ||||
|     // ssm | ||||
|     { | ||||
|         // {d_inner, dt_rank + 2*d_state} * {d_inner, n_tokens} => {dt_rank + 2*d_state, n_tokens} | ||||
|         struct ggml_tensor * x_db = ggml_mul_mat(ctx, model.layers[il].ssm_x, x); | ||||
|         // split | ||||
|         struct ggml_tensor * dt = ggml_view_2d(ctx, x_db, dt_rank, n_tokens, x_db->nb[1], 0); | ||||
|         struct ggml_tensor * B  = ggml_view_2d(ctx, x_db, d_state, n_tokens, x_db->nb[1], ggml_element_size(x_db)*dt_rank); | ||||
|         struct ggml_tensor * C  = ggml_view_2d(ctx, x_db, d_state, n_tokens, x_db->nb[1], ggml_element_size(x_db)*(dt_rank+d_state)); | ||||
|  | ||||
|         if (w_dt_norm) { dt = llm_build_norm(ctx, dt, hparams, w_dt_norm, NULL, LLM_NORM_RMS, cb, il); } | ||||
|         if (w_b_norm)  { B  = llm_build_norm(ctx, B,  hparams, w_b_norm,  NULL, LLM_NORM_RMS, cb, il); } | ||||
|         if (w_c_norm)  { C  = llm_build_norm(ctx, C,  hparams, w_c_norm,  NULL, LLM_NORM_RMS, cb, il); } | ||||
|  | ||||
|         // {dt_rank, d_inner} * {dt_rank, n_tokens} => {d_inner, n_tokens} | ||||
|         dt = ggml_mul_mat(ctx, model.layers[il].ssm_dt, dt); | ||||
|         dt = ggml_add(ctx, dt, model.layers[il].ssm_dt_b); | ||||
|  | ||||
|         // Custom operator to optimize the parallel associative scan | ||||
|         // as described in Annex D of the Mamba paper. | ||||
|         // => {d_inner, n_tokens} and {d_state, d_inner, n_rs} combined, | ||||
|         // because only a single tensor can be returned. | ||||
|         struct ggml_tensor * y_ssm_states = ggml_ssm_scan(ctx, ssm_states, x, dt, model.layers[il].ssm_a, B, C, state_seq); | ||||
|  | ||||
|         // store last states (the second part of y_ssm_states) | ||||
|         ggml_build_forward_expand(graph, | ||||
|             ggml_cpy(ctx, | ||||
|                 ggml_view_1d(ctx, y_ssm_states, d_state*d_inner*n_rs, d_inner*n_tokens*ggml_element_size(y_ssm_states)), | ||||
|                 ggml_view_1d(ctx, rs.s_l[il], d_state*d_inner*n_rs, rs_head*d_state*d_inner*ggml_element_size(ssm_states)))); | ||||
|  | ||||
|         struct ggml_tensor * y = ggml_view_2d(ctx, y_ssm_states, d_inner, n_tokens, d_inner*ggml_element_size(y_ssm_states), 0); | ||||
|  | ||||
|         // TODO: skip computing output for unused tokens | ||||
|  | ||||
|         // {d_inner, n_tokens} * {d_inner} => {d_inner, n_tokens} | ||||
|         y = ggml_add(ctx, y, ggml_mul(ctx, x, model.layers[il].ssm_d)); | ||||
|         y = ggml_mul(ctx, y, ggml_silu(ctx, z)); | ||||
|  | ||||
|         // {d_inner, n_embd} * {d_inner, n_tokens} => {n_embd, n_tokens} | ||||
|         cur = ggml_mul_mat(ctx, model.layers[il].ssm_out, y); | ||||
|     } | ||||
|  | ||||
|     return cur; | ||||
| } | ||||
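llm_build_mamba factors the Mamba block out of build_mamba so build_jamba can reuse it, with optional RMS norms for dt, B and C (Jamba provides them; plain Mamba passes NULL at the call site below). As a reading aid for the ssm section above, here is a minimal single-token NumPy reference of the selective-scan recurrence; this is a sketch of the math only, not of the ggml_ssm_scan kernel, which handles whole batches and multiple sequences:

    import numpy as np

    d_inner, d_state = 8, 16                 # toy sizes
    rng = np.random.default_rng(0)

    h  = np.zeros((d_inner, d_state))        # one sequence's ssm state
    x  = rng.standard_normal(d_inner)        # one token's x column
    dt = rng.random(d_inner)                 # discretization step, assumed already positive (softplus omitted)
    A  = -np.exp(rng.standard_normal((d_inner, d_state)))  # ssm_a, i.e. -exp(A_log) from conversion, so entries are negative
    B  = rng.standard_normal(d_state)        # input-dependent B for this token
    C  = rng.standard_normal(d_state)        # input-dependent C for this token

    h = h * np.exp(dt[:, None] * A) + (dt[:, None] * B[None, :]) * x[:, None]  # state update
    y = h @ C                                # {d_inner} output for this token
    # outside the scan (see above): y = y + D*x, then y *= silu(z), then the ssm_out projection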
|  | ||||
| struct llm_build_context { | ||||
|     const llama_model    & model; | ||||
|           llama_context  & lctx; | ||||
| @@ -11024,13 +11317,6 @@ struct llm_build_context { | ||||
|     struct ggml_cgraph * build_mamba() { | ||||
|         struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false); | ||||
|  | ||||
|         const int64_t d_model = n_embd; | ||||
|         const int64_t d_conv  = hparams.ssm_d_conv; | ||||
|         const int64_t d_inner = hparams.ssm_d_inner; | ||||
|         GGML_ASSERT(2 * d_model == d_inner); | ||||
|         const int64_t d_state = hparams.ssm_d_state; | ||||
|         const int64_t dt_rank = hparams.ssm_dt_rank; | ||||
|  | ||||
|         struct ggml_tensor * cur; | ||||
|         struct ggml_tensor * inpL; | ||||
|  | ||||
| @@ -11042,112 +11328,21 @@ struct llm_build_context { | ||||
|         struct ggml_tensor * state_seq  = build_inp_s_seq(); | ||||
|  | ||||
|         for (int il = 0; il < n_layer; ++il) { | ||||
|             struct ggml_tensor * conv_states = ggml_reshape_2d(ctx0, rs_self.r_l[il], hparams.n_embd_r(il), rs_self.size); | ||||
|             struct ggml_tensor * ssm_states  = ggml_reshape_2d(ctx0, rs_self.s_l[il], hparams.n_embd_s(il), rs_self.size); | ||||
|  | ||||
|             // copy states | ||||
|             { | ||||
|                 // TODO: use some sort of read-only head and n to pass smaller tensors to ggml_get_rows | ||||
|                 // NOTE: assuming the copy destinations are ALL contained in the current batch | ||||
|                 // this shrinks the tensors's ne[1] to n_rs | ||||
|                 conv_states = ggml_get_rows(ctx0, conv_states, state_copy); | ||||
|                 ssm_states  = ggml_get_rows(ctx0,  ssm_states, state_copy); | ||||
|             } | ||||
|  | ||||
|             // clear states of sequences which are starting at the beginning of this batch | ||||
|             { | ||||
|                 conv_states = ggml_mul(ctx0, conv_states, state_mask); | ||||
|                 ssm_states  = ggml_mul(ctx0,  ssm_states, state_mask); | ||||
|             } | ||||
|  | ||||
|             conv_states = ggml_reshape_3d(ctx0, conv_states, d_conv - 1, d_inner, n_rs); | ||||
|             ssm_states  = ggml_reshape_3d(ctx0,  ssm_states,    d_state, d_inner, n_rs); | ||||
|  | ||||
|             // norm | ||||
|             cur = llm_build_norm(ctx0, inpL, hparams, | ||||
|                     model.layers[il].attn_norm, NULL, | ||||
|                     LLM_NORM_RMS, cb, il); | ||||
|             cb(cur, "attn_norm", il); | ||||
|  | ||||
|             // {n_embd, 2*d_inner} * {n_embd, n_tokens} => {2*d_inner, n_tokens} | ||||
|             struct ggml_tensor * xz = ggml_mul_mat(ctx0, model.layers[il].ssm_in, cur); | ||||
|             // split the above in two | ||||
|             // => {d_inner, n_tokens} | ||||
|             struct ggml_tensor * x = ggml_view_2d(ctx0, xz, d_inner, xz->ne[1], xz->nb[1], 0); | ||||
|             struct ggml_tensor * z = ggml_view_2d(ctx0, xz, d_inner, xz->ne[1], xz->nb[1], ggml_element_size(xz)*d_inner); | ||||
|             cur = llm_build_mamba(ctx0, model, hparams, rs_self, gf, cur, | ||||
|                     state_copy, state_mask, state_seq, NULL, NULL, NULL, | ||||
|                     n_tokens, rs_head, n_rs, cb, il); | ||||
|  | ||||
|             // conv | ||||
|             { | ||||
|                 // Custom operator which is needed only to ease simultaneous sequence processing. | ||||
|                 // For a single sequence, the equivalent is to concatenate the columns of conv_states and x, | ||||
|                 // then make a self-overlapping view of that over d_conv columns at each stride in the 3rd dimension, | ||||
|                 // then element-wise multiply that with the conv1d weigth, | ||||
|                 // then sum the elements of each row, | ||||
|                 // (the last two steps are a dot product over rows (also doable with mul_mat)) | ||||
|                 // then permute away the ne[0] dimension, | ||||
|                 // and then you're left with the resulting x tensor. | ||||
|                 // The new conv_states is the last (d_conv - 1) columns | ||||
|                 // of the last 3rd dimensional "layer" of the self-overlapping view. | ||||
|                 // For simultaneous sequences, it's more complicated. | ||||
|                 struct ggml_tensor * x_conv = ggml_ssm_conv(ctx0, conv_states, x, model.layers[il].ssm_conv1d, state_seq); | ||||
|  | ||||
|                 // store last (d_conv - 1) columns of the conv_state part of x_conv back into the KV cache | ||||
|                 ggml_build_forward_expand(gf, | ||||
|                     ggml_cpy(ctx0, | ||||
|                         ggml_view_2d(ctx0, x_conv, d_conv - 1, d_inner*n_rs, d_conv*ggml_element_size(x_conv), (1+d_inner*n_tokens)*ggml_element_size(x_conv)), | ||||
|                         ggml_view_1d(ctx0, rs_self.r_l[il], (d_conv - 1)*(d_inner)*(n_rs), rs_head*(d_conv - 1)*(d_inner)*ggml_element_size(x_conv)))); | ||||
|  | ||||
|                 // extract x from x_conv | ||||
|                 x = ggml_view_2d(ctx0, x_conv, d_inner, n_tokens, d_inner*ggml_element_size(x_conv), 0); | ||||
|  | ||||
|                 // bias | ||||
|                 x = ggml_add(ctx0, x, model.layers[il].ssm_conv1d_b); | ||||
|  | ||||
|                 x = ggml_silu(ctx0, x); | ||||
|             } | ||||
|  | ||||
|             // ssm | ||||
|             { | ||||
|                 // {d_inner, dt_rank + 2*d_state} * {d_inner, n_tokens} => {dt_rank + 2*d_state, n_tokens} | ||||
|                 struct ggml_tensor * x_db = ggml_mul_mat(ctx0, model.layers[il].ssm_x, x); | ||||
|                 // split | ||||
|                 struct ggml_tensor * dt = ggml_view_2d(ctx0, x_db, dt_rank, n_tokens, x_db->nb[1], 0); | ||||
|                 struct ggml_tensor * B  = ggml_view_2d(ctx0, x_db, d_state, n_tokens, x_db->nb[1], ggml_element_size(x_db)*dt_rank); | ||||
|                 struct ggml_tensor * C  = ggml_view_2d(ctx0, x_db, d_state, n_tokens, x_db->nb[1], ggml_element_size(x_db)*(dt_rank+d_state)); | ||||
|  | ||||
|                 // {dt_rank, d_inner} * {dt_rank, n_tokens} => {d_inner, n_tokens} | ||||
|                 dt = ggml_mul_mat(ctx0, model.layers[il].ssm_dt, dt); | ||||
|                 dt = ggml_add(ctx0, dt, model.layers[il].ssm_dt_b); | ||||
|  | ||||
|                 // Custom operator to optimize the parallel associative scan | ||||
|                 // as described in the Annex D of the Mamba paper. | ||||
|                 // => {d_inner, n_tokens} and {d_state, d_inner, n_rs} combined, | ||||
|                 // because only a single tensor can be returned. | ||||
|                 struct ggml_tensor * y_ssm_states = ggml_ssm_scan(ctx0, ssm_states, x, dt, model.layers[il].ssm_a, B, C, state_seq); | ||||
|  | ||||
|                 // store last states (the second part of y_ssm_states) | ||||
|                 ggml_build_forward_expand(gf, | ||||
|                     ggml_cpy(ctx0, | ||||
|                         ggml_view_1d(ctx0, y_ssm_states, d_state*d_inner*n_rs, d_inner*n_tokens*ggml_element_size(y_ssm_states)), | ||||
|                         ggml_view_1d(ctx0, rs_self.s_l[il], d_state*d_inner*n_rs, rs_head*d_state*d_inner*ggml_element_size(ssm_states)))); | ||||
|  | ||||
|                 struct ggml_tensor * y = ggml_view_2d(ctx0, y_ssm_states, d_inner, n_tokens, d_inner*ggml_element_size(y_ssm_states), 0); | ||||
|  | ||||
|                 if (il == n_layer - 1) { | ||||
|                     // skip computing output for unused tokens | ||||
|                     struct ggml_tensor * inp_out_ids = build_inp_out_ids(); | ||||
|                     x    = ggml_get_rows(ctx0,    x, inp_out_ids); | ||||
|                     y    = ggml_get_rows(ctx0,    y, inp_out_ids); | ||||
|                     z    = ggml_get_rows(ctx0,    z, inp_out_ids); | ||||
|                     inpL = ggml_get_rows(ctx0, inpL, inp_out_ids); | ||||
|                 } | ||||
|  | ||||
|                 // {d_inner, n_tokens} * {d_inner} => {d_inner, n_tokens} | ||||
|                 y = ggml_add(ctx0, y, ggml_mul(ctx0, x, model.layers[il].ssm_d)); | ||||
|                 y = ggml_mul(ctx0, y, ggml_silu(ctx0, z)); | ||||
|  | ||||
|                 // {d_inner, n_embd} * {d_inner, n_tokens} => {n_embd, n_tokens} | ||||
|                 cur = ggml_mul_mat(ctx0, model.layers[il].ssm_out, y); | ||||
|             if (il == n_layer - 1) { | ||||
|                 // skip computing output for unused tokens | ||||
|                 struct ggml_tensor * inp_out_ids = build_inp_out_ids(); | ||||
|                 cur  = ggml_get_rows(ctx0,  cur, inp_out_ids); | ||||
|                 inpL = ggml_get_rows(ctx0, inpL, inp_out_ids); | ||||
|             } | ||||
|  | ||||
|             // residual | ||||
| @@ -11173,6 +11368,125 @@ struct llm_build_context { | ||||
|         return gf; | ||||
|     } | ||||
|  | ||||
|     struct ggml_cgraph * build_jamba() { | ||||
|  | ||||
|         struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false); | ||||
|  | ||||
|         const int64_t n_embd_head = hparams.n_embd_head_v; | ||||
|  | ||||
|         struct ggml_tensor * cur; | ||||
|         struct ggml_tensor * inpL; | ||||
|  | ||||
|         // {n_embd, n_tokens} | ||||
|         inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb); | ||||
|  | ||||
|         struct ggml_tensor * state_copy = build_inp_s_copy(); | ||||
|         struct ggml_tensor * state_mask = build_inp_s_mask(); | ||||
|         struct ggml_tensor * state_seq  = build_inp_s_seq(); | ||||
|  | ||||
|         // KQ_mask (mask for 1 head, it will be broadcasted to all heads) | ||||
|         struct ggml_tensor * KQ_mask = build_inp_KQ_mask(); | ||||
|  | ||||
|         for (int il = 0; il < n_layer; ++il) { | ||||
|             const int64_t n_head_kv = hparams.n_head_kv_l(il); | ||||
|  | ||||
|             cur = llm_build_norm(ctx0, inpL, hparams, | ||||
|                     model.layers[il].attn_norm, NULL, | ||||
|                     LLM_NORM_RMS, cb, il); | ||||
|             cb(cur, "attn_norm", il); | ||||
|  | ||||
|             if (n_head_kv == 0) { | ||||
|                 // Mamba | ||||
|                 cur = llm_build_mamba(ctx0, model, hparams, rs_self, gf, cur, | ||||
|                         state_copy, state_mask, state_seq, | ||||
|                         model.layers[il].ssm_dt_norm, model.layers[il].ssm_b_norm, model.layers[il].ssm_c_norm, | ||||
|                         n_tokens, rs_head, n_rs, cb, il); | ||||
|             } else { | ||||
|                 // Attention | ||||
|  | ||||
|                 struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur); | ||||
|                 struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur); | ||||
|                 struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur); | ||||
|  | ||||
|                 cb(Qcur, "Qcur", il); | ||||
|                 cb(Kcur, "Kcur", il); | ||||
|                 cb(Vcur, "Vcur", il); | ||||
|  | ||||
|                 Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens); | ||||
|                 Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); | ||||
|  | ||||
|                 cb(Qcur, "Qcur", il); | ||||
|                 cb(Kcur, "Kcur", il); | ||||
|  | ||||
|                 // No RoPE :) | ||||
|  | ||||
|                 cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, | ||||
|                         model.layers[il].wo, NULL, | ||||
|                         Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); | ||||
|             } | ||||
|  | ||||
|             if (il == n_layer - 1) { | ||||
|                 // skip computing output for unused tokens | ||||
|                 struct ggml_tensor * inp_out_ids = build_inp_out_ids(); | ||||
|                 cur  = ggml_get_rows(ctx0,  cur, inp_out_ids); | ||||
|                 inpL = ggml_get_rows(ctx0, inpL, inp_out_ids); | ||||
|             } | ||||
|  | ||||
|             // residual | ||||
|             struct ggml_tensor * ffn_inp = ggml_add(ctx0, inpL, cur); | ||||
|             cb(cur, "ffn_inp", il); | ||||
|  | ||||
|             cur = llm_build_norm(ctx0, ffn_inp, hparams, | ||||
|                     model.layers[il].ffn_norm, NULL, | ||||
|                     LLM_NORM_RMS, cb, il); | ||||
|             cb(cur, "ffn_norm", il); | ||||
|  | ||||
|             // feed-forward network | ||||
|             if (model.layers[il].ffn_gate_inp == nullptr) { | ||||
|                 // FFN | ||||
|                 cur = llm_build_ffn(ctx0, cur, | ||||
|                         model.layers[il].ffn_up,   NULL, | ||||
|                         model.layers[il].ffn_gate, NULL, | ||||
|                         model.layers[il].ffn_down, NULL, | ||||
|                         NULL, | ||||
|                         LLM_FFN_SILU, LLM_FFN_PAR, cb, il); | ||||
|                 cb(cur, "ffn_out", il); | ||||
|             } else { | ||||
|                 // MoE branch | ||||
|                 cur = llm_build_moe_ffn(ctx0, cur, | ||||
|                         model.layers[il].ffn_gate_inp, | ||||
|                         model.layers[il].ffn_up_exps, | ||||
|                         model.layers[il].ffn_gate_exps, | ||||
|                         model.layers[il].ffn_down_exps, | ||||
|                         n_expert, n_expert_used, | ||||
|                         LLM_FFN_SILU, false, | ||||
|                         cb, il); | ||||
|                 cb(cur, "ffn_moe_out", il); | ||||
|             } | ||||
|  | ||||
|             // residual | ||||
|             cur = ggml_add(ctx0, ffn_inp, cur); | ||||
|             cb(cur, "l_out", il); | ||||
|  | ||||
|             // input for next layer | ||||
|             inpL = cur; | ||||
|         } | ||||
|  | ||||
|         // final rmsnorm | ||||
|         cur = llm_build_norm(ctx0, inpL, hparams, | ||||
|                 model.output_norm, NULL, | ||||
|                 LLM_NORM_RMS, cb, -1); | ||||
|         cb(cur, "result_norm", -1); | ||||
|  | ||||
|         // lm_head | ||||
|         cur = ggml_mul_mat(ctx0, model.output, cur); | ||||
|         cb(cur, "result_output", -1); | ||||
|  | ||||
|         ggml_build_forward_expand(gf, cur); | ||||
|  | ||||
|         return gf; | ||||
|     } | ||||
|  | ||||
|     struct ggml_cgraph * build_command_r() { | ||||
|  | ||||
|         struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false); | ||||
| @@ -11630,6 +11944,10 @@ static struct ggml_cgraph * llama_build_graph( | ||||
|             { | ||||
|                 result = llm.build_mamba(); | ||||
|             } break; | ||||
|         case LLM_ARCH_JAMBA: | ||||
|             { | ||||
|                 result = llm.build_jamba(); | ||||
|             } break; | ||||
|         case LLM_ARCH_XVERSE: | ||||
|             { | ||||
|                 result = llm.build_xverse(); | ||||
| @@ -16644,6 +16962,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) { | ||||
|         case LLM_ARCH_REFACT: | ||||
|         case LLM_ARCH_BLOOM: | ||||
|         case LLM_ARCH_MAMBA: | ||||
|         case LLM_ARCH_JAMBA: | ||||
|         case LLM_ARCH_JINA_BERT_V2: | ||||
|             return LLAMA_ROPE_TYPE_NONE; | ||||
|  | ||||
|   | ||||