Mirror of https://github.com/ggml-org/llama.cpp.git (synced 2025-10-31 08:51:55 +00:00)
	llama : support RWKV v6 models (#8980)
* convert_hf_to_gguf: Add support for RWKV v6
  Signed-off-by: Molly Sophia <mollysophia379@gmail.com>
* Add RWKV tokenization
* Fix build
  Signed-off-by: Molly Sophia <mollysophia379@gmail.com>
* Do not use special tokens when matching in RWKV tokenizer
* Fix model loading
* Add (broken) placeholder graph builder for RWKV
* Add workaround for kv cache
* Add logits conversion to rwkv5
* Add rwkv5 layer norms
* Add time mix KVRG & correct merge mistake
* Add remaining time mix parameters
* Add time mix output loading
* Add placeholder llm_build_time_mix
* Fix build
  Signed-off-by: Molly Sophia <mollysophia379@gmail.com>
* Load more tensors for rwkv v6
  Signed-off-by: Molly Sophia <mollysophia379@gmail.com>
* Fix rwkv tokenizer
  Signed-off-by: Molly Sophia <mollysophia379@gmail.com>
* ggml: Add unary operator Exp
  Signed-off-by: Molly Sophia <mollysophia379@gmail.com>
* RWKV v6 graph building
  Signed-off-by: Molly Sophia <mollysophia379@gmail.com>
* Add ``rescale_every_n_layers`` parameter
  Signed-off-by: Molly Sophia <mollysophia379@gmail.com>
* Add ``wkv.head_size`` key for RWKV so it doesn't reuse Mamba ssm parameters
  Signed-off-by: Molly Sophia <mollysophia379@gmail.com>
* Fix offloading layers to CUDA
  Signed-off-by: Molly Sophia <mollysophia379@gmail.com>
* Fix parallel inferencing for RWKV
  Signed-off-by: Molly Sophia <mollysophia379@gmail.com>
* Remove trailing whitespaces
  Signed-off-by: Molly Sophia <mollysophia379@gmail.com>
* build_rwkv: Avoid using inplace operations
  Signed-off-by: Molly Sophia <mollysophia379@gmail.com>
* convert_hf_to_gguf: rwkv: Avoid using ``eval``
  Signed-off-by: Molly Sophia <mollysophia379@gmail.com>
* convert_hf_to_gguf: rwkv tokenizer: Don't escape sequences manually
  Signed-off-by: Molly Sophia <mollysophia379@gmail.com>
* Update convert_hf_to_gguf.py
  Co-authored-by: compilade <git@compilade.net>
* ggml: Add backward computation for unary op ``exp``
  Signed-off-by: Molly Sophia <mollysophia379@gmail.com>
* Update convert_hf_to_gguf.py
  Co-authored-by: compilade <git@compilade.net>
* Update convert_hf_to_gguf.py
  Co-authored-by: compilade <git@compilade.net>
* Use MODEL_ARCH.RWKV6 instead of MODEL_ARCH.RWKV
  Signed-off-by: Molly Sophia <mollysophia379@gmail.com>
* build_rwkv6: Simplify graph
  Signed-off-by: Molly Sophia <mollysophia379@gmail.com>
* llama: rwkv6: Detect model.type
  Signed-off-by: Molly Sophia <mollysophia379@gmail.com>
* llama: rwkv6: Fix tensor loading for 7B/14B models
  Signed-off-by: Molly Sophia <mollysophia379@gmail.com>
* llama: rwkv6: Fix group_norm assertion failure with Metal
  Signed-off-by: Molly Sophia <mollysophia379@gmail.com>
* llama: rwkv6: Clean up
  Signed-off-by: Molly Sophia <mollysophia379@gmail.com>
* llama: rwkv6: Add quantization tensor exclusion
  Signed-off-by: Molly Sophia <mollysophia379@gmail.com>
* llama: rwkv6: Use the new advanced batch splits
  Signed-off-by: Molly Sophia <mollysophia379@gmail.com>
* Update src/llama.cpp
  Co-authored-by: compilade <git@compilade.net>
* llama: rwkv6: Use ``ggml_norm`` instead of ``ggml_group_norm``
  Co-authored-by: compilade <git@compilade.net>
* llama: rwkv6: Apply code style and misc changes
  Signed-off-by: Molly Sophia <mollysophia379@gmail.com>
* converter: Use class name ``Rwkv6Model``
  Signed-off-by: Molly Sophia <mollysophia379@gmail.com>
* llama: rwkv6: Make use of key ``feed_forward_length``
  Signed-off-by: Molly Sophia <mollysophia379@gmail.com>
* llama: rwkv6: Add kv ``time_mix_extra_dim`` and ``time_decay_extra_dim``
  Signed-off-by: Molly Sophia <mollysophia379@gmail.com>
* converter: Match ``new_name`` instead of ``name`` for float32 explicit tensors
  Signed-off-by: Molly Sophia <mollysophia379@gmail.com>
* llama: rwkv6: Keep ``time_mix_w1/w2`` as F32
  Signed-off-by: Molly Sophia <mollysophia379@gmail.com>
* llama: rwkv6: Remove unused nodes
  Signed-off-by: Molly Sophia <mollysophia379@gmail.com>
* llama: rwkv6: Apply code format changes
  Signed-off-by: Molly Sophia <mollysophia379@gmail.com>
* llama: rwkv6: Add lora for some supported tensors
  Currently att.key/receptance/value/gate/output, ffn.receptance/key/value, as well as head.weight
  Signed-off-by: Molly Sophia <mollysophia379@gmail.com>
* rwkv : speed-up tokenization using trie
* minor : style + indentation
* llama: rwkv6: Avoid division by zero
  Co-authored-by: compilade <git@compilade.net>
* ggml: rwkv_wkv: Avoid copying the state
  Signed-off-by: Molly Sophia <mollysophia379@gmail.com>

---------

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>
Co-authored-by: Layl Bongers <3094382+LaylBongers@users.noreply.github.com>
Co-authored-by: compilade <git@compilade.net>
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
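The commit stores RWKV's recurrent state in the existing KV-cache slots (the "workaround for kv cache" item above), sized from the new wkv.head_size metadata instead of Mamba's ssm parameters. As rough orientation, here is a minimal C++ sketch of what that state costs per sequence, following the n_embd_k_s()/n_embd_v_s() definitions in the diff below. The 1.6B shape (n_embd = 2048, head_size = 64, 24 layers) and the F32 storage are my assumptions for illustration, not values stated in the commit.

```cpp
// Sketch only (not from the commit): per-sequence recurrent state size for an
// RWKV v6 model, using the n_embd_k_s()/n_embd_v_s() formulas added below.
// The 1.6B shape (n_embd = 2048, wkv_head_size = 64, 24 layers) and F32
// storage are assumptions made purely for illustration.
#include <cstdint>
#include <cstdio>

int main() {
    const uint64_t n_embd        = 2048; // assumed RWKV v6 1.6B width
    const uint64_t wkv_head_size = 64;   // assumed value of the new wkv.head_size key
    const uint64_t n_layer       = 24;

    const uint64_t n_embd_k_s = 2 * n_embd;              // token-shift state (attn + ffn shift)
    const uint64_t n_embd_v_s = n_embd * wkv_head_size;  // wkv state

    const uint64_t bytes_per_seq = (n_embd_k_s + n_embd_v_s) * n_layer * sizeof(float);
    printf("recurrent state per sequence: %.1f MiB\n", bytes_per_seq / (1024.0 * 1024.0));
    return 0;
}
```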
		| @@ -58,17 +58,17 @@ struct naive_trie { | ||||
|         auto res = children.find(c); | ||||
|         if (res != children.end()) { | ||||
|             return res->second.get_longest_prefix(key, len, offset + 1); | ||||
|         } else { | ||||
|             return std::make_pair(key, offset); | ||||
|         } | ||||
|  | ||||
|         return std::make_pair(key, offset); | ||||
|     } | ||||
|     struct naive_trie * traverse(const char c) { | ||||
|     const struct naive_trie * traverse(const char c) const { | ||||
|         auto res = children.find(c); | ||||
|         if (res != children.end()) { | ||||
|             return &res->second; | ||||
|         } else { | ||||
|             return NULL; | ||||
|         } | ||||
|  | ||||
|         return NULL; | ||||
|     } | ||||
|     std::map<char, struct naive_trie> children; | ||||
|     bool has_value; | ||||
| @@ -843,7 +843,7 @@ struct llm_tokenizer_ugm { | ||||
|             // traverse the token matcher trie to find a matching token | ||||
|             bool single_codepoint_token_found = false; | ||||
|             const struct best_tokenization & current_best = tokenization_results[input_offset]; | ||||
|             struct naive_trie * node  = token_matcher.traverse(normalized[prefix_offset++]); | ||||
|             const struct naive_trie * node = token_matcher.traverse(normalized[prefix_offset++]); | ||||
|  | ||||
|             while (prefix_offset <= input_len && node != NULL) { | ||||
|                 // check if we found valid token in prefix | ||||
| @@ -1097,6 +1097,111 @@ private: | ||||
|     struct naive_trie token_matcher; | ||||
| }; | ||||
|  | ||||
| // | ||||
| // RWKV tokenizer | ||||
| // | ||||
|  | ||||
| static std::vector<uint8_t> llama_unescape_rwkv_token(const std::string & escaped) { | ||||
|     std::vector<uint8_t> output; | ||||
|     output.reserve(escaped.size()); | ||||
|  | ||||
|     // Parser state | ||||
|     bool escaping = false; | ||||
|     uint8_t hex_remaining = 0; | ||||
|     uint8_t hex_acc = 0; | ||||
|  | ||||
|     // Step through characters, performing parsing | ||||
|     for (const char & c : escaped) { | ||||
|         // If we're parsing a hex code, interpret the next character | ||||
|         if (hex_remaining != 0) { | ||||
|             uint8_t value = (c >= 'a') ? (c - 'a' + 10) : (c - '0'); | ||||
|             hex_acc = (hex_acc << 4) + value; | ||||
|  | ||||
|             hex_remaining -= 1; | ||||
|             if (hex_remaining == 0) { | ||||
|                 output.push_back(hex_acc); | ||||
|                 hex_acc = 0; | ||||
|             } | ||||
|  | ||||
|             continue; | ||||
|         } | ||||
|  | ||||
|         // If we got an escape character, interpret it | ||||
|         if (escaping) { | ||||
|             if (c == 't') { | ||||
|                 output.push_back('\t'); | ||||
|             } else if (c == 'n') { | ||||
|                 output.push_back('\n'); | ||||
|             } else if (c == 'r') { | ||||
|                 output.push_back('\r'); | ||||
|             } else if (c == 'x') { | ||||
|                 hex_remaining = 2; | ||||
|             } else { | ||||
|                 output.push_back(c); | ||||
|             } | ||||
|  | ||||
|             escaping = false; | ||||
|             continue; | ||||
|         } | ||||
|  | ||||
|         if (c == '\\') { | ||||
|             escaping = true; | ||||
|             continue; | ||||
|         } | ||||
|  | ||||
|         output.push_back(c); | ||||
|     } | ||||
|  | ||||
|     return output; | ||||
| } | ||||
|  | ||||
| struct llm_tokenizer_rwkv { | ||||
|     llm_tokenizer_rwkv(const llama_vocab & vocab): vocab(vocab) { | ||||
|         // RWKV supports arbitrary byte tokens, but the vocab struct only supports string tokens. | ||||
|         // For now, we decode the vocab here into the lookup we'll use for tokenization. | ||||
|  | ||||
|         // build trie | ||||
|         for (unsigned int id = 0; id < vocab.id_to_token.size(); ++id) { | ||||
|             const auto & token = vocab.id_to_token[id]; | ||||
|             const auto data = llama_unescape_rwkv_token(token.text); | ||||
|             token_matcher.insert((const char *) data.data(), data.size(), id); | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     void tokenize(const std::string & text, std::vector<llama_vocab::id> & output) { | ||||
|         uint32_t position = 0; | ||||
|  | ||||
|         while (position < text.size()) { | ||||
|             const struct naive_trie * node = token_matcher.traverse(text[position]); | ||||
|             if (node == NULL) { | ||||
|                 // no matching token found, add unknown token | ||||
|                 output.push_back(vocab.special_unk_id); | ||||
|                 position += 1; | ||||
|                 continue; | ||||
|             } | ||||
|  | ||||
|             // traverse the trie to find the longest matching token | ||||
|             uint32_t token_id = 0; | ||||
|             uint32_t token_length = 0; | ||||
|             while (node != NULL) { | ||||
|                 if (node->has_value) { | ||||
|                     token_id = node->value; | ||||
|                     token_length = position + 1; | ||||
|                 } | ||||
|                 node = node->traverse(text[++position]); | ||||
|             } | ||||
|  | ||||
|             // add the longest matching token | ||||
|             output.push_back(token_id); | ||||
|             position = token_length; | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     const llama_vocab & vocab; | ||||
|  | ||||
|     struct naive_trie token_matcher; | ||||
| }; | ||||
|  | ||||
| // | ||||
| // (de-) tokenize | ||||
| // | ||||
| @@ -1401,6 +1506,23 @@ std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab & vocab, | ||||
|                     output.push_back(vocab.special_eos_id); | ||||
|                 } | ||||
|             } break; | ||||
|         case LLAMA_VOCAB_TYPE_RWKV: | ||||
|             { | ||||
|                 for (const auto & fragment : fragment_buffer) { | ||||
|                     if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) { | ||||
|                         auto raw_text = fragment.raw_text.substr(fragment.offset, fragment.length); | ||||
|  | ||||
| #ifdef PRETOKENIZERDEBUG | ||||
|                         LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", raw_text.length(), fragment.offset, fragment.length, raw_text.c_str()); | ||||
| #endif | ||||
|  | ||||
|                         llm_tokenizer_rwkv tokenizer(vocab); | ||||
|                         tokenizer.tokenize(raw_text, output); | ||||
|                     } else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN) | ||||
|                         output.push_back(fragment.token); | ||||
|                     } | ||||
|                 } | ||||
|             } break; | ||||
|         case LLAMA_VOCAB_TYPE_NONE: | ||||
|             GGML_ABORT("fatal error"); | ||||
|     } | ||||
| @@ -1616,6 +1738,17 @@ int32_t llama_token_to_piece_impl(const struct llama_vocab & vocab, llama_token | ||||
|                 } | ||||
|                 break; | ||||
|             } | ||||
|             case LLAMA_VOCAB_TYPE_RWKV: { | ||||
|                 std::vector<uint8_t> result = llama_unescape_rwkv_token(token_text); | ||||
|  | ||||
|                 // If we don't have enough space, return an error | ||||
|                 if (result.size() > (size_t)length) { | ||||
|                     return -(int)result.size(); | ||||
|                 } | ||||
|  | ||||
|                 memcpy(buf, result.data(), result.size()); | ||||
|                 return (int)result.size(); | ||||
|             } | ||||
|             default: | ||||
|                 GGML_ABORT("fatal error"); | ||||
|         } | ||||
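The tokenizer added above builds a byte-level trie from the unescaped vocab strings and then tokenizes by greedy longest-prefix matching. Below is a small self-contained sketch of that idea; it is a simplification for illustration (toy vocab, no unknown-token id, no byte unescaping), not the code from the commit.

```cpp
#include <cstdint>
#include <cstdio>
#include <map>
#include <string>
#include <vector>

// Minimal byte-level trie with greedy longest-prefix tokenization, mirroring
// the idea of naive_trie + llm_tokenizer_rwkv above. Toy vocab, not the real
// RWKV world vocab.
struct trie_node {
    std::map<char, trie_node> children;
    int32_t token_id = -1; // -1 means "no token ends here"

    void insert(const std::string & s, size_t pos, int32_t id) {
        if (pos == s.size()) { token_id = id; return; }
        children[s[pos]].insert(s, pos + 1, id);
    }
};

std::vector<int32_t> tokenize_greedy(const trie_node & root, const std::string & text) {
    std::vector<int32_t> out;
    size_t pos = 0;
    while (pos < text.size()) {
        const trie_node * node = &root;
        int32_t best_id  = -1;      // longest match found so far
        size_t  best_end = pos + 1;
        size_t  i        = pos;
        while (i < text.size()) {
            auto it = node->children.find(text[i]);
            if (it == node->children.end()) break;
            node = &it->second;
            ++i;
            if (node->token_id >= 0) { best_id = node->token_id; best_end = i; }
        }
        if (best_id < 0) { ++pos; continue; } // the real code emits special_unk_id here
        out.push_back(best_id);
        pos = best_end;
    }
    return out;
}

int main() {
    trie_node root;
    const char * vocab[] = { "a", "ab", "abc", "b", "c" };
    for (int32_t id = 0; id < 5; ++id) root.insert(vocab[id], 0, id);

    for (int32_t id : tokenize_greedy(root, "abcb")) printf("%d ", id); // prints: 2 3
    printf("\n");
    return 0;
}
```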
src/llama.cpp: 533 changed lines
							| @@ -212,6 +212,7 @@ enum llm_arch { | ||||
|     LLM_ARCH_JAIS, | ||||
|     LLM_ARCH_NEMOTRON, | ||||
|     LLM_ARCH_EXAONE, | ||||
|     LLM_ARCH_RWKV6, | ||||
|     LLM_ARCH_UNKNOWN, | ||||
| }; | ||||
|  | ||||
| @@ -259,6 +260,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = { | ||||
|     { LLM_ARCH_JAIS,            "jais"         }, | ||||
|     { LLM_ARCH_NEMOTRON,        "nemotron"     }, | ||||
|     { LLM_ARCH_EXAONE,          "exaone"       }, | ||||
|     { LLM_ARCH_RWKV6,           "rwkv6"        }, | ||||
|     { LLM_ARCH_UNKNOWN,         "(unknown)"    }, | ||||
| }; | ||||
|  | ||||
| @@ -295,6 +297,9 @@ enum llm_kv { | ||||
|     LLM_KV_DECODER_START_TOKEN_ID, | ||||
|     LLM_KV_ATTN_LOGIT_SOFTCAPPING, | ||||
|     LLM_KV_FINAL_LOGIT_SOFTCAPPING, | ||||
|     LLM_KV_RESCALE_EVERY_N_LAYERS, | ||||
|     LLM_KV_TIME_MIX_EXTRA_DIM, | ||||
|     LLM_KV_TIME_DECAY_EXTRA_DIM, | ||||
|  | ||||
|     LLM_KV_ATTENTION_HEAD_COUNT, | ||||
|     LLM_KV_ATTENTION_HEAD_COUNT_KV, | ||||
| @@ -330,6 +335,8 @@ enum llm_kv { | ||||
|     LLM_KV_SSM_TIME_STEP_RANK, | ||||
|     LLM_KV_SSM_DT_B_C_RMS, | ||||
|  | ||||
|     LLM_KV_WKV_HEAD_SIZE, | ||||
|  | ||||
|     LLM_KV_TOKENIZER_MODEL, | ||||
|     LLM_KV_TOKENIZER_PRE, | ||||
|     LLM_KV_TOKENIZER_LIST, | ||||
| @@ -389,11 +396,14 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = { | ||||
|     { LLM_KV_EXPERT_USED_COUNT,                 "%s.expert_used_count"                 }, | ||||
|     { LLM_KV_EXPERT_SHARED_COUNT,               "%s.expert_shared_count"               }, | ||||
|     { LLM_KV_EXPERT_WEIGHTS_SCALE,              "%s.expert_weights_scale"              }, | ||||
|     { LLM_KV_POOLING_TYPE ,                     "%s.pooling_type"                      }, | ||||
|     { LLM_KV_POOLING_TYPE,                      "%s.pooling_type"                      }, | ||||
|     { LLM_KV_LOGIT_SCALE,                       "%s.logit_scale"                       }, | ||||
|     { LLM_KV_DECODER_START_TOKEN_ID,            "%s.decoder_start_token_id"            }, | ||||
|     { LLM_KV_ATTN_LOGIT_SOFTCAPPING,            "%s.attn_logit_softcapping"            }, | ||||
|     { LLM_KV_FINAL_LOGIT_SOFTCAPPING,           "%s.final_logit_softcapping"           }, | ||||
|     { LLM_KV_RESCALE_EVERY_N_LAYERS,            "%s.rescale_every_n_layers"            }, | ||||
|     { LLM_KV_TIME_MIX_EXTRA_DIM,                "%s.time_mix_extra_dim"                }, | ||||
|     { LLM_KV_TIME_DECAY_EXTRA_DIM,              "%s.time_decay_extra_dim"              }, | ||||
|  | ||||
|     { LLM_KV_ATTENTION_HEAD_COUNT,             "%s.attention.head_count"             }, | ||||
|     { LLM_KV_ATTENTION_HEAD_COUNT_KV,          "%s.attention.head_count_kv"          }, | ||||
| @@ -429,6 +439,8 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = { | ||||
|     { LLM_KV_SSM_TIME_STEP_RANK,            "%s.ssm.time_step_rank" }, | ||||
|     { LLM_KV_SSM_DT_B_C_RMS,                "%s.ssm.dt_b_c_rms" }, | ||||
|  | ||||
|     { LLM_KV_WKV_HEAD_SIZE,                 "%s.wkv.head_size" }, | ||||
|  | ||||
|     { LLM_KV_TOKENIZER_MODEL,                "tokenizer.ggml.model"                    }, | ||||
|     { LLM_KV_TOKENIZER_PRE,                  "tokenizer.ggml.pre"                      }, | ||||
|     { LLM_KV_TOKENIZER_LIST,                 "tokenizer.ggml.tokens"                   }, | ||||
| @@ -518,6 +530,29 @@ enum llm_tensor { | ||||
|     LLM_TENSOR_SSM_A, | ||||
|     LLM_TENSOR_SSM_D, | ||||
|     LLM_TENSOR_SSM_OUT, | ||||
|     LLM_TENSOR_TIME_MIX_W1, | ||||
|     LLM_TENSOR_TIME_MIX_W2, | ||||
|     LLM_TENSOR_TIME_MIX_LERP_X, | ||||
|     LLM_TENSOR_TIME_MIX_LERP_W, | ||||
|     LLM_TENSOR_TIME_MIX_LERP_K, | ||||
|     LLM_TENSOR_TIME_MIX_LERP_V, | ||||
|     LLM_TENSOR_TIME_MIX_LERP_R, | ||||
|     LLM_TENSOR_TIME_MIX_LERP_G, | ||||
|     LLM_TENSOR_TIME_MIX_FIRST, | ||||
|     LLM_TENSOR_TIME_MIX_DECAY, | ||||
|     LLM_TENSOR_TIME_MIX_DECAY_W1, | ||||
|     LLM_TENSOR_TIME_MIX_DECAY_W2, | ||||
|     LLM_TENSOR_TIME_MIX_KEY, | ||||
|     LLM_TENSOR_TIME_MIX_VALUE, | ||||
|     LLM_TENSOR_TIME_MIX_RECEPTANCE, | ||||
|     LLM_TENSOR_TIME_MIX_GATE, | ||||
|     LLM_TENSOR_TIME_MIX_LN, | ||||
|     LLM_TENSOR_TIME_MIX_OUTPUT, | ||||
|     LLM_TENSOR_CHANNEL_MIX_LERP_K, | ||||
|     LLM_TENSOR_CHANNEL_MIX_LERP_R, | ||||
|     LLM_TENSOR_CHANNEL_MIX_KEY, | ||||
|     LLM_TENSOR_CHANNEL_MIX_RECEPTANCE, | ||||
|     LLM_TENSOR_CHANNEL_MIX_VALUE, | ||||
|     LLM_TENSOR_ATTN_Q_A, | ||||
|     LLM_TENSOR_ATTN_Q_B, | ||||
|     LLM_TENSOR_ATTN_KV_A_MQA, | ||||
| @@ -1339,6 +1374,40 @@ static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NA | ||||
|             { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" }, | ||||
|         }, | ||||
|     }, | ||||
|     { | ||||
|         LLM_ARCH_RWKV6, | ||||
|         { | ||||
|             { LLM_TENSOR_TOKEN_EMBD,                "token_embd" }, | ||||
|             { LLM_TENSOR_TOKEN_EMBD_NORM,           "token_embd_norm" }, | ||||
|             { LLM_TENSOR_OUTPUT_NORM,               "output_norm" }, | ||||
|             { LLM_TENSOR_OUTPUT,                    "output" }, | ||||
|             { LLM_TENSOR_ATTN_NORM,                 "blk.%d.attn_norm" }, | ||||
|             { LLM_TENSOR_ATTN_NORM_2,               "blk.%d.attn_norm_2" }, | ||||
|             { LLM_TENSOR_TIME_MIX_W1,               "blk.%d.time_mix_w1" }, | ||||
|             { LLM_TENSOR_TIME_MIX_W2,               "blk.%d.time_mix_w2" }, | ||||
|             { LLM_TENSOR_TIME_MIX_LERP_X,           "blk.%d.time_mix_lerp_x" }, | ||||
|             { LLM_TENSOR_TIME_MIX_LERP_W,           "blk.%d.time_mix_lerp_w" }, | ||||
|             { LLM_TENSOR_TIME_MIX_LERP_K,           "blk.%d.time_mix_lerp_k" }, | ||||
|             { LLM_TENSOR_TIME_MIX_LERP_V,           "blk.%d.time_mix_lerp_v" }, | ||||
|             { LLM_TENSOR_TIME_MIX_LERP_R,           "blk.%d.time_mix_lerp_r" }, | ||||
|             { LLM_TENSOR_TIME_MIX_LERP_G,           "blk.%d.time_mix_lerp_g" }, | ||||
|             { LLM_TENSOR_TIME_MIX_FIRST,            "blk.%d.time_mix_first" }, | ||||
|             { LLM_TENSOR_TIME_MIX_DECAY,            "blk.%d.time_mix_decay" }, | ||||
|             { LLM_TENSOR_TIME_MIX_DECAY_W1,         "blk.%d.time_mix_decay_w1" }, | ||||
|             { LLM_TENSOR_TIME_MIX_DECAY_W2,         "blk.%d.time_mix_decay_w2" }, | ||||
|             { LLM_TENSOR_TIME_MIX_KEY,              "blk.%d.time_mix_key" }, | ||||
|             { LLM_TENSOR_TIME_MIX_VALUE,            "blk.%d.time_mix_value" }, | ||||
|             { LLM_TENSOR_TIME_MIX_RECEPTANCE,       "blk.%d.time_mix_receptance" }, | ||||
|             { LLM_TENSOR_TIME_MIX_GATE,             "blk.%d.time_mix_gate" }, | ||||
|             { LLM_TENSOR_TIME_MIX_LN,               "blk.%d.time_mix_ln" }, | ||||
|             { LLM_TENSOR_TIME_MIX_OUTPUT,           "blk.%d.time_mix_output" }, | ||||
|             { LLM_TENSOR_CHANNEL_MIX_LERP_K,        "blk.%d.channel_mix_lerp_k" }, | ||||
|             { LLM_TENSOR_CHANNEL_MIX_LERP_R,        "blk.%d.channel_mix_lerp_r" }, | ||||
|             { LLM_TENSOR_CHANNEL_MIX_KEY,           "blk.%d.channel_mix_key" }, | ||||
|             { LLM_TENSOR_CHANNEL_MIX_VALUE,         "blk.%d.channel_mix_value" }, | ||||
|             { LLM_TENSOR_CHANNEL_MIX_RECEPTANCE,    "blk.%d.channel_mix_receptance" }, | ||||
|         }, | ||||
|     }, | ||||
|     { | ||||
|         LLM_ARCH_UNKNOWN, | ||||
|         { | ||||
| @@ -2151,6 +2220,7 @@ enum e_model { | ||||
|     MODEL_1B, | ||||
|     MODEL_1_3B, | ||||
|     MODEL_1_4B, | ||||
|     MODEL_1_6B, | ||||
|     MODEL_2B, | ||||
|     MODEL_2_8B, | ||||
|     MODEL_3B, | ||||
| @@ -2228,6 +2298,12 @@ struct llama_hparams { | ||||
|     float f_attn_logit_softcapping = 50.0f; | ||||
|     float f_final_logit_softcapping = 30.0f; | ||||
|  | ||||
|     // for RWKV | ||||
|     uint32_t rescale_every_n_layers = 0; | ||||
|     uint32_t time_mix_extra_dim = 0; | ||||
|     uint32_t time_decay_extra_dim = 0; | ||||
|     uint32_t wkv_head_size = 0; | ||||
|  | ||||
|     float    rope_attn_factor = 1.0f; | ||||
|     float    rope_freq_base_train; | ||||
|     float    rope_freq_scale_train; | ||||
| @@ -2291,6 +2367,11 @@ struct llama_hparams { | ||||
|         if (this->ssm_dt_rank != other.ssm_dt_rank) return true; | ||||
|         if (this->ssm_dt_b_c_rms != other.ssm_dt_b_c_rms) return true; | ||||
|  | ||||
|         if (this->rescale_every_n_layers != other.rescale_every_n_layers) return true; | ||||
|         if (this->time_mix_extra_dim     != other.time_mix_extra_dim)     return true; | ||||
|         if (this->time_decay_extra_dim   != other.time_decay_extra_dim)   return true; | ||||
|         if (this->wkv_head_size          != other.wkv_head_size)          return true; | ||||
|  | ||||
|         if (this->dec_start_token_id != other.dec_start_token_id) return true; | ||||
|  | ||||
|         const float EPSILON = 1e-9f; | ||||
| @@ -2354,15 +2435,25 @@ struct llama_hparams { | ||||
|     } | ||||
|  | ||||
|     uint32_t n_embd_k_s() const { // dimension of the rolling state embeddings | ||||
|         // corresponds to Mamba's conv_states size | ||||
|         // TODO: maybe support other convolution strides than 1 | ||||
|         // NOTE: since the first column of the conv_state is shifted out each time, it's not actually needed | ||||
|         return (ssm_d_conv > 0 ? ssm_d_conv - 1 : 0) * ssm_d_inner; | ||||
|         // corresponds to Mamba's conv_states size or RWKV's token_shift states size | ||||
|         if (wkv_head_size != 0) { | ||||
|             // for RWKV models | ||||
|             return 2 * n_embd; | ||||
|         } else { | ||||
|             // TODO: maybe support other convolution strides than 1 | ||||
|             // NOTE: since the first column of the conv_state is shifted out each time, it's not actually needed | ||||
|             return (ssm_d_conv > 0 ? ssm_d_conv - 1 : 0) * ssm_d_inner; | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     uint32_t n_embd_v_s() const { // dimension of the recurrent state embeddings | ||||
|         // corresponds to Mamba's ssm_states size | ||||
|         return ssm_d_state * ssm_d_inner; | ||||
|         if (wkv_head_size != 0) { | ||||
|             // corresponds to RWKV's wkv_states size | ||||
|             return n_embd * wkv_head_size; | ||||
|         } else { | ||||
|             // corresponds to Mamba's ssm_states size | ||||
|             return ssm_d_state * ssm_d_inner; | ||||
|         } | ||||
|     } | ||||
| }; | ||||
|  | ||||
| @@ -2501,6 +2592,36 @@ struct llama_layer { | ||||
|     struct ggml_tensor * ssm_conv1d_b; | ||||
|     struct ggml_tensor * ssm_dt_b; | ||||
|  | ||||
|     // rwkv | ||||
|     struct ggml_tensor * time_mix_w1; | ||||
|     struct ggml_tensor * time_mix_w2; | ||||
|     struct ggml_tensor * time_mix_lerp_x; | ||||
|     struct ggml_tensor * time_mix_lerp_w; | ||||
|     struct ggml_tensor * time_mix_lerp_k; | ||||
|     struct ggml_tensor * time_mix_lerp_v; | ||||
|     struct ggml_tensor * time_mix_lerp_r; | ||||
|     struct ggml_tensor * time_mix_lerp_g; | ||||
|  | ||||
|     struct ggml_tensor * time_mix_first; | ||||
|     struct ggml_tensor * time_mix_decay; | ||||
|     struct ggml_tensor * time_mix_decay_w1; | ||||
|     struct ggml_tensor * time_mix_decay_w2; | ||||
|     struct ggml_tensor * time_mix_key; | ||||
|     struct ggml_tensor * time_mix_value; | ||||
|     struct ggml_tensor * time_mix_receptance; | ||||
|     struct ggml_tensor * time_mix_gate; | ||||
|  | ||||
|     struct ggml_tensor * time_mix_ln; | ||||
|     struct ggml_tensor * time_mix_ln_b; | ||||
|     struct ggml_tensor * time_mix_output; | ||||
|  | ||||
|     struct ggml_tensor * channel_mix_lerp_k; | ||||
|     struct ggml_tensor * channel_mix_lerp_r; | ||||
|  | ||||
|     struct ggml_tensor * channel_mix_key; | ||||
|     struct ggml_tensor * channel_mix_receptance; | ||||
|     struct ggml_tensor * channel_mix_value; | ||||
|  | ||||
|     // long rope factors | ||||
|     struct ggml_tensor * rope_long  = nullptr; | ||||
|     struct ggml_tensor * rope_short = nullptr; | ||||
| @@ -3426,7 +3547,7 @@ static bool llama_kv_cache_find_slot( | ||||
|     const uint32_t n_seq_tokens = batch.n_seq_tokens; | ||||
|  | ||||
|     if (cache.recurrent) { | ||||
|         // For recurrent state architectures (like Mamba), | ||||
|         // For recurrent state architectures (like Mamba or RWKV), | ||||
|         // each cache cell can store the state for a whole sequence. | ||||
|         // A slot should be always be contiguous. | ||||
|  | ||||
| @@ -3675,7 +3796,7 @@ static bool llama_kv_cache_seq_rm( | ||||
|     if (p0 < 0) p0 = 0; | ||||
|     if (p1 < 0) p1 = std::numeric_limits<llama_pos>::max(); | ||||
|  | ||||
|     // models like Mamba can't have a state partially erased | ||||
|     // models like Mamba or RWKV can't have a state partially erased | ||||
|     if (cache.recurrent) { | ||||
|         if (seq_id >= (int64_t) cache.size) { | ||||
|             // could be fatal | ||||
| @@ -3811,7 +3932,7 @@ static void llama_kv_cache_seq_add( | ||||
|     if (p0 == p1) return; | ||||
|  | ||||
|     if (cache.recurrent) { | ||||
|         // for Mamba-like models, only the pos needs to be shifted | ||||
|         // for Mamba-like or RWKV models, only the pos needs to be shifted | ||||
|         if (0 <= seq_id && seq_id < (int64_t) cache.size) { | ||||
|             const int32_t tail_id = cache.cells[seq_id].tail; | ||||
|             if (tail_id >= 0) { | ||||
| @@ -3860,7 +3981,7 @@ static void llama_kv_cache_seq_div( | ||||
|     if (p0 == p1) return; | ||||
|  | ||||
|     if (cache.recurrent) { | ||||
|         // for Mamba-like models, only the pos needs to be changed | ||||
|         // for Mamba-like or RWKV models, only the pos needs to be changed | ||||
|         if (0 <= seq_id && seq_id < (int64_t) cache.size) { | ||||
|             const int32_t tail_id = cache.cells[seq_id].tail; | ||||
|             if (tail_id >= 0) { | ||||
| @@ -5051,6 +5172,7 @@ static const char * llama_model_type_name(e_model type) { | ||||
|         case MODEL_1B:            return "1B"; | ||||
|         case MODEL_1_3B:          return "1.3B"; | ||||
|         case MODEL_1_4B:          return "1.4B"; | ||||
|         case MODEL_1_6B:          return "1.6B"; | ||||
|         case MODEL_2B:            return "2B"; | ||||
|         case MODEL_2_8B:          return "2.8B"; | ||||
|         case MODEL_3B:            return "3B"; | ||||
| @@ -5097,6 +5219,7 @@ static const char * llama_model_vocab_type_name(enum llama_vocab_type type){ | ||||
|         case LLAMA_VOCAB_TYPE_BPE:  return "BPE"; | ||||
|         case LLAMA_VOCAB_TYPE_WPM:  return "WPM"; | ||||
|         case LLAMA_VOCAB_TYPE_UGM:  return "UGM"; | ||||
|         case LLAMA_VOCAB_TYPE_RWKV: return "RWKV"; | ||||
|         default:                    return "unknown"; | ||||
|     } | ||||
| } | ||||
| @@ -5793,6 +5916,26 @@ static void llm_load_hparams( | ||||
|                     default: model.type = e_model::MODEL_UNKNOWN; | ||||
|                 } | ||||
|             } break; | ||||
|         case LLM_ARCH_RWKV6: | ||||
|             { | ||||
|                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); | ||||
|                 ml.get_key(LLM_KV_WKV_HEAD_SIZE, hparams.wkv_head_size); | ||||
|                 ml.get_key(LLM_KV_TIME_MIX_EXTRA_DIM, hparams.time_mix_extra_dim); | ||||
|                 ml.get_key(LLM_KV_TIME_DECAY_EXTRA_DIM, hparams.time_decay_extra_dim); | ||||
|                 ml.get_key(LLM_KV_RESCALE_EVERY_N_LAYERS, hparams.rescale_every_n_layers, false); | ||||
|  | ||||
|                 switch (hparams.n_layer) { | ||||
|                     case 24: model.type = e_model::MODEL_1_6B; break; | ||||
|                     case 32: | ||||
|                         switch (hparams.n_embd) { | ||||
|                             case 2560: model.type = e_model::MODEL_3B; break; | ||||
|                             case 4096: model.type = e_model::MODEL_7B; break; | ||||
|                             default: model.type = e_model::MODEL_UNKNOWN; | ||||
|                         } break; | ||||
|                     case 61: model.type = e_model::MODEL_14B; break; | ||||
|                     default: model.type = e_model::MODEL_UNKNOWN; | ||||
|                 } | ||||
|             } break; | ||||
|         default: (void)0; | ||||
|     } | ||||
|  | ||||
| @@ -5922,6 +6065,15 @@ static void llm_load_vocab( | ||||
|                 } | ||||
| #endif | ||||
|             } | ||||
|         } else if (tokenizer_model == "rwkv") { | ||||
|             vocab.type = LLAMA_VOCAB_TYPE_RWKV; | ||||
|  | ||||
|             // default special tokens | ||||
|             vocab.special_bos_id = -1; | ||||
|             vocab.special_eos_id = -1; | ||||
|             vocab.special_unk_id = -1; | ||||
|             vocab.special_sep_id = -1; | ||||
|             vocab.special_pad_id = -1; | ||||
|         } else { | ||||
|             throw std::runtime_error(format("unknown tokenizer: '%s'", tokenizer_model.c_str())); | ||||
|         } | ||||
| @@ -6053,6 +6205,12 @@ static void llm_load_vocab( | ||||
|             vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT; | ||||
|             vocab.tokenizer_add_bos = false; | ||||
|             vocab.tokenizer_add_eos = true; | ||||
|         } else if (vocab.type == LLAMA_VOCAB_TYPE_RWKV) { | ||||
|             vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT; | ||||
|             vocab.tokenizer_add_space_prefix = false; | ||||
|             vocab.tokenizer_clean_spaces = false; | ||||
|             vocab.tokenizer_add_bos = false; | ||||
|             vocab.tokenizer_add_eos = false; | ||||
|         } else { | ||||
|             vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT; | ||||
|         } | ||||
| @@ -6157,6 +6315,10 @@ static void llm_load_vocab( | ||||
|         } | ||||
|     } else if (vocab.type == LLAMA_VOCAB_TYPE_WPM) { | ||||
|         vocab.linefeed_id = vocab.special_pad_id; | ||||
|     } else if (vocab.type == LLAMA_VOCAB_TYPE_RWKV) { | ||||
|         const std::vector<int> ids = llama_tokenize_internal(vocab, "\n", false); | ||||
|         GGML_ASSERT(!ids.empty() && "model vocab missing newline token"); | ||||
|         vocab.linefeed_id = ids[0]; | ||||
|     } else { | ||||
|         const std::vector<int> ids = llama_tokenize_internal(vocab, "\xC4\x8A", false); // U+010A | ||||
|         GGML_ASSERT(!ids.empty() && "model vocab missing newline token"); | ||||
| @@ -8203,6 +8365,68 @@ static bool llm_load_tensors( | ||||
|                         layer.ffn_up   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}); | ||||
|                     } | ||||
|                 } break; | ||||
|             case LLM_ARCH_RWKV6: | ||||
|                 { | ||||
|                     model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); | ||||
|  | ||||
|                     // Block 0, LN0 | ||||
|                     model.tok_norm = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}); | ||||
|                     model.tok_norm_b = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"), {n_embd}); | ||||
|  | ||||
|                     // output | ||||
|                     model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); | ||||
|                     model.output_norm_b = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}); | ||||
|                     model.output = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}); | ||||
|  | ||||
|                     const int time_mix_extra_dim = hparams.time_mix_extra_dim; | ||||
|                     const int time_decay_extra_dim = hparams.time_decay_extra_dim; | ||||
|                     const int head_size = hparams.wkv_head_size; | ||||
|                     const int attn_hidden_size = n_embd; | ||||
|                     const int ffn_size = hparams.n_ff_arr[0]; | ||||
|  | ||||
|                     for (int i = 0; i < n_layer; ++i) { | ||||
|                         ggml_context * ctx_layer = ctx_for_layer(i); | ||||
|  | ||||
|                         auto & layer = model.layers[i]; | ||||
|  | ||||
|                         layer.attn_norm   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); | ||||
|                         layer.attn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "bias", i),   {n_embd}); | ||||
|  | ||||
|                         layer.attn_norm_2   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM_2, "weight", i), {n_embd}); | ||||
|                         layer.attn_norm_2_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM_2, "bias", i),   {n_embd}); | ||||
|  | ||||
|                         layer.time_mix_w1 = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_TIME_MIX_W1, "weight", i), {n_embd, time_mix_extra_dim * 5}); | ||||
|                         layer.time_mix_w2 = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_TIME_MIX_W2, "weight", i), {time_mix_extra_dim, n_embd, 5}); | ||||
|  | ||||
|                         layer.time_mix_lerp_x = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_TIME_MIX_LERP_X, "weight", i), {n_embd, 1, 1}); | ||||
|                         layer.time_mix_lerp_w = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_TIME_MIX_LERP_W, "weight", i), {n_embd, 1, 1}); | ||||
|                         layer.time_mix_lerp_k = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_TIME_MIX_LERP_K, "weight", i), {n_embd, 1, 1}); | ||||
|                         layer.time_mix_lerp_v = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_TIME_MIX_LERP_V, "weight", i), {n_embd, 1, 1}); | ||||
|                         layer.time_mix_lerp_r = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_TIME_MIX_LERP_R, "weight", i), {n_embd, 1, 1}); | ||||
|                         layer.time_mix_lerp_g = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_TIME_MIX_LERP_G, "weight", i), {n_embd, 1, 1}); | ||||
|  | ||||
|                         layer.time_mix_first = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_TIME_MIX_FIRST, "weight", i), {head_size, n_embd / head_size}); | ||||
|                         layer.time_mix_decay = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_TIME_MIX_DECAY, "weight", i), {n_embd}); | ||||
|                         layer.time_mix_decay_w1 = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_TIME_MIX_DECAY_W1, "weight", i), {n_embd, time_decay_extra_dim}); | ||||
|                         layer.time_mix_decay_w2 = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_TIME_MIX_DECAY_W2, "weight", i), {time_decay_extra_dim, attn_hidden_size}); | ||||
|                         layer.time_mix_key = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_TIME_MIX_KEY, "weight", i), {attn_hidden_size, n_embd}); | ||||
|                         layer.time_mix_value = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_TIME_MIX_VALUE, "weight", i), {attn_hidden_size, n_embd}); | ||||
|                         layer.time_mix_receptance = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_TIME_MIX_RECEPTANCE, "weight", i), {attn_hidden_size, n_embd}); | ||||
|                         layer.time_mix_gate = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_TIME_MIX_GATE, "weight", i), {attn_hidden_size, n_embd}); | ||||
|  | ||||
|                         layer.time_mix_ln = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_TIME_MIX_LN, "weight", i), {n_embd}); | ||||
|                         layer.time_mix_ln_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_TIME_MIX_LN, "bias", i), {n_embd}); | ||||
|                         layer.time_mix_output = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_TIME_MIX_OUTPUT, "weight", i), {n_embd, attn_hidden_size}); | ||||
|  | ||||
|                         layer.channel_mix_lerp_k = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_CHANNEL_MIX_LERP_K, "weight", i), {n_embd, 1, 1}); | ||||
|                         layer.channel_mix_lerp_r = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_CHANNEL_MIX_LERP_R, "weight", i), {n_embd, 1, 1}); | ||||
|  | ||||
|                         layer.channel_mix_key = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_CHANNEL_MIX_KEY, "weight", i), {n_embd, ffn_size}); | ||||
|                         layer.channel_mix_value = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_CHANNEL_MIX_VALUE, "weight", i), {ffn_size, n_embd}); | ||||
|                         layer.channel_mix_receptance = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_CHANNEL_MIX_RECEPTANCE, "weight", i), {n_embd, n_embd}); | ||||
|                     } | ||||
|  | ||||
|                 } break; | ||||
|             default: | ||||
|                 throw std::runtime_error("unknown architecture"); | ||||
|         } | ||||
| @@ -9162,6 +9386,171 @@ static struct ggml_tensor * llm_build_mamba( | ||||
|     return cur; | ||||
| } | ||||
|  | ||||
| static struct ggml_tensor * llm_build_rwkv6_time_mix( | ||||
|         struct llama_context & lctx, | ||||
|         struct ggml_context * ctx, | ||||
|         const struct llama_layer * layer, | ||||
|         struct ggml_tensor * cur, | ||||
|         struct ggml_tensor * x_prev, | ||||
|         struct ggml_tensor ** wkv_state) { | ||||
|     size_t n_embed      = cur->ne[0]; | ||||
|     size_t n_seq_tokens = cur->ne[1]; | ||||
|     size_t n_seqs       = cur->ne[2]; | ||||
|  | ||||
|     size_t head_size  = layer->time_mix_first->ne[0]; | ||||
|     size_t head_count = layer->time_mix_first->ne[1]; | ||||
|  | ||||
|     size_t n_tokens = n_seqs * n_seq_tokens; | ||||
|  | ||||
|     struct ggml_tensor * sx = ggml_sub(ctx, x_prev, cur); | ||||
|  | ||||
|     sx  = ggml_reshape_2d(ctx, sx,  n_embed, n_tokens); | ||||
|     cur = ggml_reshape_2d(ctx, cur, n_embed, n_tokens); | ||||
|  | ||||
|     struct ggml_tensor * xxx = ggml_add(ctx, ggml_mul(ctx, sx, layer->time_mix_lerp_x), cur); | ||||
|  | ||||
|     xxx = ggml_reshape_4d( | ||||
|         ctx, | ||||
|         ggml_tanh( | ||||
|             ctx, | ||||
|             ggml_mul_mat(ctx, layer->time_mix_w1, xxx) | ||||
|         ), | ||||
|         layer->time_mix_w1->ne[1] / 5, 1, 5, n_tokens | ||||
|     ); | ||||
|  | ||||
|     xxx = ggml_cont(ctx, ggml_permute(ctx, xxx, 0, 1, 3, 2)); | ||||
|  | ||||
|     xxx = ggml_mul_mat( | ||||
|         ctx, | ||||
|         ggml_reshape_4d( | ||||
|             ctx, | ||||
|             layer->time_mix_w2, | ||||
|             layer->time_mix_w2->ne[0], layer->time_mix_w2->ne[1], 1, 5 | ||||
|         ), | ||||
|         xxx | ||||
|     ); | ||||
|  | ||||
|     struct ggml_tensor *mw = ggml_view_2d(ctx, xxx, n_embed, n_tokens, xxx->nb[1], 0); | ||||
|     struct ggml_tensor *mk = ggml_view_2d(ctx, xxx, n_embed, n_tokens, xxx->nb[1], n_embed * n_tokens * sizeof(float)); | ||||
|     struct ggml_tensor *mv = ggml_view_2d(ctx, xxx, n_embed, n_tokens, xxx->nb[1], n_embed * n_tokens * 2 * sizeof(float)); | ||||
|     struct ggml_tensor *mr = ggml_view_2d(ctx, xxx, n_embed, n_tokens, xxx->nb[1], n_embed * n_tokens * 3 * sizeof(float)); | ||||
|     struct ggml_tensor *mg = ggml_view_2d(ctx, xxx, n_embed, n_tokens, xxx->nb[1], n_embed * n_tokens * 4 * sizeof(float)); | ||||
|  | ||||
|     struct ggml_tensor * xw = ggml_add( | ||||
|         ctx, | ||||
|         ggml_mul( | ||||
|             ctx, | ||||
|             ggml_add(ctx, mw, layer->time_mix_lerp_w), | ||||
|             sx | ||||
|         ), | ||||
|         cur | ||||
|     ); | ||||
|  | ||||
|     struct ggml_tensor * xk = ggml_add( | ||||
|         ctx, | ||||
|         ggml_mul( | ||||
|             ctx, | ||||
|             ggml_add(ctx, mk, layer->time_mix_lerp_k), | ||||
|             sx | ||||
|         ), | ||||
|         cur | ||||
|     ); | ||||
|  | ||||
|     struct ggml_tensor * xv = ggml_add( | ||||
|         ctx, | ||||
|         ggml_mul( | ||||
|             ctx, | ||||
|             ggml_add(ctx, mv, layer->time_mix_lerp_v), | ||||
|             sx | ||||
|         ), | ||||
|         cur | ||||
|     ); | ||||
|  | ||||
|     struct ggml_tensor * xr = ggml_add( | ||||
|         ctx, | ||||
|         ggml_mul( | ||||
|             ctx, | ||||
|             ggml_add(ctx, mr, layer->time_mix_lerp_r), | ||||
|             sx | ||||
|         ), | ||||
|         cur | ||||
|     ); | ||||
|  | ||||
|     struct ggml_tensor * xg = ggml_add( | ||||
|         ctx, | ||||
|         ggml_mul( | ||||
|             ctx, | ||||
|             ggml_add(ctx, mg, layer->time_mix_lerp_g), | ||||
|             sx | ||||
|         ), | ||||
|         cur | ||||
|     ); | ||||
|  | ||||
|     struct ggml_tensor * r = ggml_reshape_4d(ctx, llm_build_lora_mm(lctx, ctx, layer->time_mix_receptance, xr), head_size, 1,         head_count, n_tokens); | ||||
|     struct ggml_tensor * k = ggml_reshape_4d(ctx, llm_build_lora_mm(lctx, ctx, layer->time_mix_key,        xk), 1,         head_size, head_count, n_tokens); | ||||
|     struct ggml_tensor * v = ggml_reshape_4d(ctx, llm_build_lora_mm(lctx, ctx, layer->time_mix_value,      xv), head_size, 1,         head_count, n_tokens); | ||||
|     struct ggml_tensor * g = ggml_silu( | ||||
|         ctx, | ||||
|         llm_build_lora_mm(lctx, ctx, layer->time_mix_gate, xg) | ||||
|     ); | ||||
|  | ||||
|     struct ggml_tensor * w = ggml_mul_mat( | ||||
|         ctx, | ||||
|         layer->time_mix_decay_w2, | ||||
|         ggml_tanh( | ||||
|             ctx, | ||||
|             ggml_mul_mat(ctx, layer->time_mix_decay_w1, xw) | ||||
|         ) | ||||
|     ); | ||||
|  | ||||
|     w = ggml_add(ctx, w, ggml_reshape_1d(ctx, layer->time_mix_decay, n_embed)); | ||||
|     w = ggml_exp(ctx, ggml_neg(ctx, ggml_exp(ctx, w))); | ||||
|     w = ggml_reshape_4d(ctx, w, 1, head_size, head_count, n_tokens); | ||||
|  | ||||
|     k = ggml_transpose(ctx, k); | ||||
|     v = ggml_transpose(ctx, v); | ||||
|     r = ggml_transpose(ctx, r); | ||||
|  | ||||
|     struct ggml_tensor * wkv_output = ggml_rwkv_wkv(ctx, k, v, r, layer->time_mix_first, w, *wkv_state); | ||||
|     cur = ggml_view_1d(ctx, wkv_output, n_embed * n_tokens, 0); | ||||
|     *wkv_state = ggml_view_1d(ctx, wkv_output, n_embed * head_size * n_seqs, n_embed * n_tokens * sizeof(float)); | ||||
|  | ||||
|     // group norm with head_count groups | ||||
|     cur = ggml_reshape_3d(ctx, cur, n_embed / head_count, head_count, n_tokens); | ||||
|     cur = ggml_norm(ctx, cur, 64e-5f); | ||||
|  | ||||
|     // Convert back to regular vectors. | ||||
|     cur = ggml_reshape_2d(ctx, cur, n_embed, n_tokens); | ||||
|     cur = ggml_add(ctx, ggml_mul(ctx, cur, layer->time_mix_ln), layer->time_mix_ln_b); | ||||
|  | ||||
|     cur = ggml_mul(ctx, cur, g); | ||||
|     cur = llm_build_lora_mm(lctx, ctx, layer->time_mix_output, cur); | ||||
|  | ||||
|     return ggml_reshape_3d(ctx, cur, n_embed, n_seq_tokens, n_seqs); | ||||
| } | ||||
|  | ||||
| static struct ggml_tensor * llm_build_rwkv6_channel_mix( | ||||
|         struct llama_context & lctx, | ||||
|         struct ggml_context * ctx, | ||||
|         const struct llama_layer * layer, | ||||
|         struct ggml_tensor * cur, | ||||
|         struct ggml_tensor * x_prev) { | ||||
|     struct ggml_tensor * sx = ggml_sub(ctx, x_prev, cur); | ||||
|     struct ggml_tensor * xk = ggml_add(ctx, ggml_mul(ctx, sx, layer->channel_mix_lerp_k), cur); | ||||
|     struct ggml_tensor * xr = ggml_add(ctx, ggml_mul(ctx, sx, layer->channel_mix_lerp_r), cur); | ||||
|  | ||||
|     struct ggml_tensor * r = ggml_sigmoid(ctx, llm_build_lora_mm(lctx, ctx, layer->channel_mix_receptance, xr)); | ||||
|     struct ggml_tensor * k = ggml_sqr( | ||||
|         ctx, | ||||
|         ggml_relu( | ||||
|             ctx, | ||||
|             llm_build_lora_mm(lctx, ctx, layer->channel_mix_key, xk) | ||||
|         ) | ||||
|     ); | ||||
|  | ||||
|     return ggml_mul(ctx, r, llm_build_lora_mm(lctx, ctx, layer->channel_mix_value, k)); | ||||
| } | ||||
|  | ||||
| struct llm_build_context { | ||||
|     const llama_model    & model; | ||||
|           llama_context  & lctx; | ||||
| @@ -14683,6 +15072,117 @@ struct llm_build_context { | ||||
|  | ||||
|         return gf; | ||||
|     } | ||||
|  | ||||
|     ggml_cgraph * build_rwkv6() { | ||||
|         ggml_cgraph *gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false); | ||||
|  | ||||
|         // Token shift state dimensions should be 2 * n_emb | ||||
|         GGML_ASSERT(n_embd == hparams.n_embd_k_s() / 2); | ||||
|  | ||||
|         const int64_t n_seqs = batch.n_seqs; | ||||
|         const int64_t n_seq_tokens = batch.n_seq_tokens; | ||||
|         const int64_t n_tokens = batch.n_tokens; | ||||
|         GGML_ASSERT(n_seqs != 0); | ||||
|         GGML_ASSERT(batch.equal_seqs); | ||||
|         GGML_ASSERT(n_tokens == n_seq_tokens * n_seqs); | ||||
|  | ||||
|         struct ggml_tensor * cur; | ||||
|         struct ggml_tensor * inpL; | ||||
|         struct ggml_tensor * state_copy = build_inp_s_copy(); | ||||
|         struct ggml_tensor * state_mask = build_inp_s_mask(); | ||||
|  | ||||
|         inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb); | ||||
|         inpL = llm_build_norm(ctx0, inpL, hparams, model.tok_norm, model.tok_norm_b, LLM_NORM, cb, -1); | ||||
|  | ||||
|         for (int il = 0; il < n_layer; ++il) { | ||||
|             const llama_layer * layer = &model.layers[il]; | ||||
|  | ||||
|             // (ab)using the KV cache to store the states | ||||
|             struct ggml_tensor * token_shift = llm_build_copy_mask_state(ctx0, | ||||
|                     gf, kv_self.k_l[il], state_copy, state_mask, | ||||
|                     hparams.n_embd_k_s(), kv_self.size, kv_head, n_kv, n_seqs); | ||||
|             struct ggml_tensor * wkv_states = llm_build_copy_mask_state(ctx0, | ||||
|                     gf, kv_self.v_l[il], state_copy, state_mask, | ||||
|                     hparams.n_embd_v_s(), kv_self.size, kv_head, n_kv, n_seqs); | ||||
|  | ||||
|             cur = ggml_reshape_3d(ctx0, inpL, n_embd, n_seq_tokens, n_seqs); | ||||
|             token_shift = ggml_reshape_3d(ctx0, token_shift, n_embd, 2, n_seqs); | ||||
|  | ||||
|             struct ggml_tensor * att_shift = ggml_view_3d(ctx0, token_shift, n_embd, 1, n_seqs, token_shift->nb[1], token_shift->nb[2], 0); | ||||
|             struct ggml_tensor * ffn_shift = ggml_view_3d(ctx0, token_shift, n_embd, 1, n_seqs, token_shift->nb[1], token_shift->nb[2], n_embd * ggml_element_size(token_shift)); | ||||
|  | ||||
|             struct ggml_tensor * x_norm_att = llm_build_norm(ctx0, cur, hparams, layer->attn_norm, layer->attn_norm_b, LLM_NORM, cb, il); | ||||
|             struct ggml_tensor * x_prev = ggml_concat( | ||||
|                 ctx0, | ||||
|                 att_shift, | ||||
|                 ggml_view_3d(ctx0, x_norm_att, n_embd, n_seq_tokens - 1, n_seqs, x_norm_att->nb[1], x_norm_att->nb[2], 0), | ||||
|                 1 | ||||
|             ); | ||||
|  | ||||
|             cur = ggml_add(ctx0, cur, llm_build_rwkv6_time_mix(lctx, ctx0, layer, x_norm_att, x_prev, &wkv_states)); | ||||
|             ggml_build_forward_expand(gf, cur); | ||||
|             ggml_build_forward_expand( | ||||
|                 gf, | ||||
|                 ggml_cpy( | ||||
|                     ctx0, | ||||
|                     wkv_states, | ||||
|                     ggml_view_1d( | ||||
|                         ctx0, | ||||
|                         kv_self.v_l[il], | ||||
|                         hparams.n_embd_v_s() * n_seqs, | ||||
|                         hparams.n_embd_v_s() * kv_head * ggml_element_size(kv_self.v_l[il]) | ||||
|                     ) | ||||
|                 ) | ||||
|             ); | ||||
|  | ||||
|             struct ggml_tensor * x_norm_ffn = llm_build_norm(ctx0, cur, hparams, layer->attn_norm_2, layer->attn_norm_2_b, LLM_NORM, cb, il); | ||||
|             x_prev = ggml_concat( | ||||
|                 ctx0, | ||||
|                 ffn_shift, | ||||
|                 ggml_view_3d(ctx0, x_norm_ffn, n_embd, n_seq_tokens - 1, n_seqs, x_norm_ffn->nb[1], x_norm_ffn->nb[2], 0), | ||||
|                 1 | ||||
|             ); | ||||
|             cur = ggml_add(ctx0, cur, llm_build_rwkv6_channel_mix(lctx, ctx0, layer, x_norm_ffn, x_prev)); | ||||
|             ggml_build_forward_expand(gf, cur); | ||||
|  | ||||
|             struct ggml_tensor * last_norm_att = ggml_view_3d(ctx0, x_norm_att, n_embd, 1, n_seqs, x_norm_att->nb[1], x_norm_att->nb[2], (n_seq_tokens-1)*n_embd*ggml_element_size(x_norm_att)); | ||||
|             struct ggml_tensor * last_norm_ffn = ggml_view_3d(ctx0, x_norm_ffn, n_embd, 1, n_seqs, x_norm_ffn->nb[1], x_norm_ffn->nb[2], (n_seq_tokens-1)*n_embd*ggml_element_size(x_norm_ffn)); | ||||
|  | ||||
|             token_shift = ggml_concat(ctx0, last_norm_att, last_norm_ffn, 1); | ||||
|  | ||||
|             ggml_build_forward_expand( | ||||
|                 gf, | ||||
|                 ggml_cpy( | ||||
|                     ctx0, | ||||
|                     ggml_view_1d(ctx0, token_shift, n_embd * n_seqs * 2, 0), | ||||
|                     ggml_view_1d(ctx0, kv_self.k_l[il], hparams.n_embd_k_s() * n_seqs, hparams.n_embd_k_s() * kv_head * ggml_element_size(kv_self.k_l[il])) | ||||
|                 ) | ||||
|             ); | ||||
|  | ||||
|             if (hparams.rescale_every_n_layers != 0 && (il + 1) % hparams.rescale_every_n_layers == 0) { | ||||
|                 cur = ggml_scale(ctx0, cur, 0.5F); | ||||
|             } | ||||
|  | ||||
|             cur = lctx.cvec.apply_to(ctx0, cur, il); | ||||
|             cb(cur, "l_out", il); | ||||
|  | ||||
|             // input for next layer | ||||
|             inpL = cur; | ||||
|         } | ||||
|  | ||||
|         cur = inpL; | ||||
|         struct ggml_tensor * inp_out_ids = build_inp_out_ids(); | ||||
|         cur = ggml_reshape_2d(ctx0, cur, n_embd, n_tokens); | ||||
|         cur = ggml_get_rows(ctx0, cur, inp_out_ids); | ||||
|  | ||||
|         cur = llm_build_norm(ctx0, cur, hparams, model.output_norm, model.output_norm_b, LLM_NORM, cb, -1); | ||||
|         cur = llm_build_lora_mm(lctx, ctx0, model.output, cur); | ||||
|  | ||||
|         cb(cur, "result_output", -1); | ||||
|         ggml_build_forward_expand(gf, cur); | ||||
|  | ||||
|         return gf; | ||||
|     } | ||||
| }; | ||||
|  | ||||
| static struct ggml_cgraph * llama_build_graph_defrag(llama_context & lctx, const std::vector<uint32_t> & ids) { | ||||
| @@ -14929,6 +15429,10 @@ static struct ggml_cgraph * llama_build_graph( | ||||
|             { | ||||
|                 result = llm.build_exaone(); | ||||
|             } break; | ||||
|         case LLM_ARCH_RWKV6: | ||||
|             { | ||||
|                 result = llm.build_rwkv6(); | ||||
|             } break; | ||||
|         default: | ||||
|             GGML_ABORT("fatal error"); | ||||
|     } | ||||
| @@ -16973,6 +17477,11 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s | ||||
|         // NOTE: can't use LLM_TN here because the layer number is not known | ||||
|         quantize &= name.find("ssm_conv1d.weight") == std::string::npos; | ||||
|  | ||||
|         // do not quantize RWKV's time_mix_first tensors | ||||
|         quantize &= name.find("time_mix_first.weight") == std::string::npos; | ||||
|         quantize &= name.find("time_mix_w1.weight") == std::string::npos; | ||||
|         quantize &= name.find("time_mix_w2.weight") == std::string::npos; | ||||
|  | ||||
|         // do not quantize relative position bias (T5) | ||||
|         quantize &= name.find("attn_rel_b.weight") == std::string::npos; | ||||
|  | ||||
| @@ -17977,6 +18486,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) { | ||||
|         case LLM_ARCH_T5: | ||||
|         case LLM_ARCH_T5ENCODER: | ||||
|         case LLM_ARCH_JAIS: | ||||
|         case LLM_ARCH_RWKV6: | ||||
|             return LLAMA_ROPE_TYPE_NONE; | ||||
|  | ||||
|         // use what we call a normal RoPE, operating on pairs of consecutive head values | ||||
| @@ -18145,6 +18655,7 @@ llama_token llama_model_decoder_start_token(const struct llama_model * model) { | ||||
| bool llama_model_is_recurrent(const struct llama_model * model) { | ||||
|     switch (model->arch) { | ||||
|         case LLM_ARCH_MAMBA:  return true; | ||||
|         case LLM_ARCH_RWKV6:  return true; | ||||
|         default:              return false; | ||||
|     } | ||||
| } | ||||
|   | ||||
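For reference, the graph built by llm_build_rwkv6_time_mix above feeds r/k/v, the per-channel decay w = exp(-exp(decay)) and the time_mix_first "bonus" into ggml_rwkv_wkv together with the carried wkv state. To my reading, that op computes the per-head RWKV v6 recurrence sketched below; this is an illustrative scalar version for a single head and sequence, with a hypothetical function name, not the actual ggml kernel.

```cpp
#include <cstddef>
#include <vector>

// Scalar sketch (my reading, not taken from the commit) of the per-head RWKV v6
// WKV recurrence that ggml_rwkv_wkv appears to implement:
//   out[t][j] = sum_i r[t][i] * (first[i] * k[t][i] * v[t][j] + S[i][j])
//   S[i][j]   = w[t][i] * S[i][j] + k[t][i] * v[t][j]
// N = head_size; S is the N x N state carried across tokens (the wkv_states above);
// w is already the decayed form exp(-exp(.)) in (0, 1), matching the graph code.
void wkv6_head(std::size_t n_tokens, std::size_t N,
               const std::vector<float> & r,     // n_tokens * N receptance
               const std::vector<float> & k,     // n_tokens * N keys
               const std::vector<float> & v,     // n_tokens * N values
               const std::vector<float> & w,     // n_tokens * N per-channel decay
               const std::vector<float> & first, // N "bonus" applied to the current token
               std::vector<float> & state,       // N * N, updated in place
               std::vector<float> & out) {       // n_tokens * N
    for (std::size_t t = 0; t < n_tokens; ++t) {
        for (std::size_t j = 0; j < N; ++j) {
            float acc = 0.0f;
            for (std::size_t i = 0; i < N; ++i) {
                const float kv = k[t*N + i] * v[t*N + j];
                acc += r[t*N + i] * (first[i] * kv + state[i*N + j]); // read old state
                state[i*N + j] = w[t*N + i] * state[i*N + j] + kv;    // then decay and update
            }
            out[t*N + j] = acc;
        }
    }
}
```

In the actual graph the same update runs for all heads and sequences at once, and the resulting state slice is copied back into kv_self.v_l[il] by the ggml_cpy that follows the time-mix call in build_rwkv6 above.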