mirror of https://github.com/ggml-org/llama.cpp.git (synced 2025-11-04 09:32:00 +00:00)
	llama : support RWKV v6 models (#8980)
* convert_hf_to_gguf: Add support for RWKV v6
* Add RWKV tokenization
* Fix build
* Do not use special tokens when matching in RWKV tokenizer
* Fix model loading
* Add (broken) placeholder graph builder for RWKV
* Add workaround for kv cache
* Add logits conversion to rwkv5
* Add rwkv5 layer norms
* Add time mix KVRG & correct merge mistake
* Add remaining time mix parameters
* Add time mix output loading
* Add placeholder llm_build_time_mix
* Fix build
* Load more tensors for rwkv v6
* Fix rwkv tokenizer
* ggml: Add unary operator Exp
* RWKV v6 graph building
* Add ``rescale_every_n_layers`` parameter
* Add ``wkv.head_size`` key for RWKV so it doesn't reuse Mamba ssm parameters
* Fix offloading layers to CUDA
* Fix parallel inferencing for RWKV
* Remove trailing whitespaces
* build_rwkv: Avoid using inplace operations
* convert_hf_to_gguf: rwkv: Avoid using ``eval``
* convert_hf_to_gguf: rwkv tokenizer: Don't escape sequences manually
* Update convert_hf_to_gguf.py
* ggml: Add backward computation for unary op ``exp``
* Update convert_hf_to_gguf.py
* Update convert_hf_to_gguf.py
* Use MODEL_ARCH.RWKV6 instead of MODEL_ARCH.RWKV
* build_rwkv6: Simplify graph
* llama: rwkv6: Detect model.type
* llama: rwkv6: Fix tensor loading for 7B/14B models
* llama: rwkv6: Fix group_norm assertion failure with Metal
* llama: rwkv6: Clean up
* llama: rwkv6: Add quantization tensor exclusion
* llama: rwkv6: Use the new advanced batch splits
* Update src/llama.cpp
* llama: rwkv6: Use ``ggml_norm`` instead of ``ggml_group_norm``
* llama: rwkv6: Apply code style and misc changes
* converter: Use class name ``Rwkv6Model``
* llama: rwkv6: Make use of key ``feed_forward_length``
* llama: rwkv6: Add kv ``time_mix_extra_dim`` and ``time_decay_extra_dim``
* converter: Match ``new_name`` instead of ``name`` for float32 explicit tensors
* llama: rwkv6: Keep ``time_mix_w1/w2`` as F32
* llama: rwkv6: Remove unused nodes
* llama: rwkv6: Apply code format changes
* llama: rwkv6: Add lora for some supported tensors (currently att.key/receptance/value/gate/output, ffn.receptance/key/value, as well as head.weight)
* rwkv : speed-up tokenization using trie
* minor : style + indentation
* llama: rwkv6: Avoid division by zero
* ggml: rwkv_wkv: Avoid copying the state

---------

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>
Co-authored-by: Layl Bongers <3094382+LaylBongers@users.noreply.github.com>
Co-authored-by: compilade <git@compilade.net>
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
@@ -58,17 +58,17 @@ struct naive_trie {
         auto res = children.find(c);
         if (res != children.end()) {
             return res->second.get_longest_prefix(key, len, offset + 1);
-        } else {
-            return std::make_pair(key, offset);
-        }
+        }
+
+        return std::make_pair(key, offset);
     }
-    struct naive_trie * traverse(const char c) {
+    const struct naive_trie * traverse(const char c) const {
         auto res = children.find(c);
         if (res != children.end()) {
             return &res->second;
-        } else {
-            return NULL;
-        }
+        }
+
+        return NULL;
     }
     std::map<char, struct naive_trie> children;
     bool has_value;
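The hunk above makes naive_trie::traverse usable through a const pointer, which is what the greedy longest-match walk in the RWKV tokenizer (added further down) relies on. Below is a minimal standalone sketch of that pattern, with simplified types and names; it is an illustration of the approach, not the llama.cpp naive_trie itself:

```cpp
#include <cstdint>
#include <cstdio>
#include <map>
#include <string>

// Minimal byte trie in the spirit of naive_trie: each node may carry a token
// id, and traverse() advances by exactly one byte.
struct toy_trie {
    bool has_value = false;
    int32_t value = 0;
    std::map<char, toy_trie> children;

    void insert(const std::string & key, int32_t id) {
        toy_trie * node = this;
        for (char c : key) {
            node = &node->children[c]; // creates the child node if missing
        }
        node->has_value = true;
        node->value = id;
    }

    const toy_trie * traverse(char c) const {
        auto it = children.find(c);
        return it != children.end() ? &it->second : nullptr;
    }
};

// Greedy longest match starting at `pos`: walk as far as the input allows and
// remember the deepest node that carried a token id.
static int32_t longest_match(const toy_trie & trie, const std::string & text, size_t pos, size_t & match_len) {
    int32_t best_id = -1;
    match_len = 0;
    const toy_trie * node = &trie;
    for (size_t i = pos; i < text.size(); ++i) {
        node = node->traverse(text[i]);
        if (node == nullptr) {
            break;
        }
        if (node->has_value) {
            best_id   = node->value;
            match_len = i - pos + 1;
        }
    }
    return best_id;
}

int main() {
    toy_trie trie;
    trie.insert("a",   1);
    trie.insert("ab",  2);
    trie.insert("abc", 3);

    size_t len = 0;
    const int32_t id = longest_match(trie, "abd", 0, len);
    printf("matched token %d over %zu byte(s)\n", id, len); // -> matched token 2 over 2 byte(s)
    return 0;
}
```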
@@ -843,7 +843,7 @@ struct llm_tokenizer_ugm {
             // traverse the token matcher trie to find a matching token
             bool single_codepoint_token_found = false;
             const struct best_tokenization & current_best = tokenization_results[input_offset];
-            struct naive_trie * node  = token_matcher.traverse(normalized[prefix_offset++]);
+            const struct naive_trie * node = token_matcher.traverse(normalized[prefix_offset++]);
 
             while (prefix_offset <= input_len && node != NULL) {
                 // check if we found valid token in prefix
@@ -1097,6 +1097,111 @@ private:
     struct naive_trie token_matcher;
 };
 
+//
+// RWKV tokenizer
+//
+
+static std::vector<uint8_t> llama_unescape_rwkv_token(const std::string & escaped) {
+    std::vector<uint8_t> output;
+    output.reserve(escaped.size());
+
+    // Parser state
+    bool escaping = false;
+    uint8_t hex_remaining = 0;
+    uint8_t hex_acc = 0;
+
+    // Step through characters, performing parsing
+    for (const char & c : escaped) {
+        // If we're parsing a hex code, interpret the next character
+        if (hex_remaining != 0) {
+            uint8_t value = (c >= 'a') ? (c - 'a' + 10) : (c - '0');
+            hex_acc = (hex_acc << 4) + value;
+
+            hex_remaining -= 1;
+            if (hex_remaining == 0) {
+                output.push_back(hex_acc);
+                hex_acc = 0;
+            }
+
+            continue;
+        }
+
+        // If we got an escape character, interpret it
+        if (escaping) {
+            if (c == 't') {
+                output.push_back('\t');
+            } else if (c == 'n') {
+                output.push_back('\n');
+            } else if (c == 'r') {
+                output.push_back('\r');
+            } else if (c == 'x') {
+                hex_remaining = 2;
+            } else {
+                output.push_back(c);
+            }
+
+            escaping = false;
+            continue;
+        }
+
+        if (c == '\\') {
+            escaping = true;
+            continue;
+        }
+
+        output.push_back(c);
+    }
+
+    return output;
+}
+
+struct llm_tokenizer_rwkv {
+    llm_tokenizer_rwkv(const llama_vocab & vocab): vocab(vocab) {
+        // RWKV supports arbitrary byte tokens, but the vocab struct only supports string tokens.
+        // For now, we decode the vocab here into the lookup we'll use for tokenization.
+
+        // build trie
+        for (unsigned int id = 0; id < vocab.id_to_token.size(); ++id) {
+            const auto & token = vocab.id_to_token[id];
+            const auto data = llama_unescape_rwkv_token(token.text);
+            token_matcher.insert((const char *) data.data(), data.size(), id);
+        }
+    }
+
+    void tokenize(const std::string & text, std::vector<llama_vocab::id> & output) {
+        uint32_t position = 0;
+
+        while (position < text.size()) {
+            const struct naive_trie * node = token_matcher.traverse(text[position]);
+            if (node == NULL) {
+                // no matching token found, add unknown token
+                output.push_back(vocab.special_unk_id);
+                position += 1;
+                continue;
+            }
+
+            // traverse the trie to find the longest matching token
+            uint32_t token_id = 0;
+            uint32_t token_length = 0;
+            while (node != NULL) {
+                if (node->has_value) {
+                    token_id = node->value;
+                    token_length = position + 1;
+                }
+                node = node->traverse(text[++position]);
+            }
+
+            // add the longest matching token
+            output.push_back(token_id);
+            position = token_length;
+        }
+    }
+
+    const llama_vocab & vocab;
+
+    struct naive_trie token_matcher;
+};
+
 //
 // (de-) tokenize
 //
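For reference, the RWKV world vocab stores byte-level tokens as escaped text, and the llama_unescape_rwkv_token function added above reverses that escaping: \t, \n and \r become the corresponding control bytes, \xhh encodes one raw byte as two lowercase hex digits, and any other escaped character (including \\) is taken literally. The snippet below is a compact standalone restatement of that same mapping with a few worked inputs; it is an illustration only, not the llama.cpp function:

```cpp
#include <cassert>
#include <cstdint>
#include <string>
#include <vector>

// Same escape rules as the tokenizer's unescape step, restated compactly:
// "\t" "\n" "\r" -> control bytes, "\xhh" -> one raw byte (lowercase hex),
// any other escaped character (including "\\") is taken literally.
static std::vector<uint8_t> unescape(const std::string & escaped) {
    std::vector<uint8_t> out;
    for (size_t i = 0; i < escaped.size(); ++i) {
        const char c = escaped[i];
        if (c != '\\') {
            out.push_back((uint8_t) c);
            continue;
        }
        const char e = escaped[++i];
        switch (e) {
            case 't': out.push_back('\t'); break;
            case 'n': out.push_back('\n'); break;
            case 'r': out.push_back('\r'); break;
            case 'x': {
                auto hex = [](char h) -> uint8_t {
                    return h >= 'a' ? h - 'a' + 10 : h - '0';
                };
                out.push_back((uint8_t) ((hex(escaped[i + 1]) << 4) | hex(escaped[i + 2])));
                i += 2;
                break;
            }
            default: out.push_back((uint8_t) e); break;
        }
    }
    return out;
}

int main() {
    // "\x0a" in the vocab text decodes to a single newline byte
    assert(unescape("\\x0a") == std::vector<uint8_t>({0x0a}));
    // literal bytes pass through, escapes are decoded in place
    assert(unescape("a\\tb")  == std::vector<uint8_t>({'a', '\t', 'b'}));
    // an escaped backslash is kept as a single backslash byte
    assert(unescape("\\\\")   == std::vector<uint8_t>({'\\'}));
    return 0;
}
```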
@@ -1401,6 +1506,23 @@ std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab & vocab,
                     output.push_back(vocab.special_eos_id);
                 }
             } break;
+        case LLAMA_VOCAB_TYPE_RWKV:
+            {
+                for (const auto & fragment : fragment_buffer) {
+                    if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
+                        auto raw_text = fragment.raw_text.substr(fragment.offset, fragment.length);
+
+#ifdef PRETOKENIZERDEBUG
+                        LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", raw_text.length(), fragment.offset, fragment.length, raw_text.c_str());
+#endif
+
+                        llm_tokenizer_rwkv tokenizer(vocab);
+                        tokenizer.tokenize(raw_text, output);
+                    } else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN)
+                        output.push_back(fragment.token);
+                    }
+                }
+            } break;
         case LLAMA_VOCAB_TYPE_NONE:
             GGML_ABORT("fatal error");
     }
@@ -1616,6 +1738,17 @@ int32_t llama_token_to_piece_impl(const struct llama_vocab & vocab, llama_token
                 }
                 break;
            }
+            case LLAMA_VOCAB_TYPE_RWKV: {
+                std::vector<uint8_t> result = llama_unescape_rwkv_token(token_text);
+
+                // If we don't have enough space, return an error
+                if (result.size() > (size_t)length) {
+                    return -(int)result.size();
+                }
+
+                memcpy(buf, result.data(), result.size());
+                return (int)result.size();
+            }
            default:
                GGML_ABORT("fatal error");
        }
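As in the other branches of llama_token_to_piece_impl, the RWKV case above reports an undersized output buffer by returning the negated byte count rather than truncating. The sketch below shows the caller-side grow-and-retry pattern this convention enables; fake_piece_fn is a hypothetical stand-in for the real call, not a llama.cpp API:

```cpp
#include <algorithm>
#include <cstdint>
#include <cstring>
#include <functional>
#include <string>
#include <vector>

// Grow-and-retry around a token_to_piece-style call: a negative return value
// means "buffer too small; this many bytes are required".
static std::string piece_to_string(const std::function<int32_t(char *, int32_t)> & piece_fn) {
    std::vector<char> buf(8); // deliberately small first guess
    int32_t n = piece_fn(buf.data(), (int32_t) buf.size());
    if (n < 0) {
        buf.resize((size_t) -n); // exact size reported by the callee
        n = piece_fn(buf.data(), (int32_t) buf.size());
    }
    return std::string(buf.data(), (size_t) std::max<int32_t>(n, 0));
}

int main() {
    // Hypothetical stand-in for a token_to_piece-style callee: it copies a
    // fixed piece, or returns -(needed size) when the buffer cannot hold it.
    const std::string piece = "hello, this is a long token piece";
    auto fake_piece_fn = [&](char * dst, int32_t length) -> int32_t {
        if ((int32_t) piece.size() > length) {
            return -(int32_t) piece.size();
        }
        std::memcpy(dst, piece.data(), piece.size());
        return (int32_t) piece.size();
    };

    return piece_to_string(fake_piece_fn) == piece ? 0 : 1;
}
```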