mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-10-27 08:21:30 +00:00
test-model-random : add Mamba2
This commit is contained in:
@@ -628,6 +628,64 @@ struct model_variant {
|
||||
}
|
||||
}
|
||||
break;
|
||||
case LLM_ARCH_MAMBA2:
|
||||
{
|
||||
variants.push_back(model_variant(arch, "Mamba2"));
|
||||
model_variant & cur = variants.back();
|
||||
|
||||
n_embd = 64;
|
||||
|
||||
const uint32_t d_inner = 2 * n_embd;
|
||||
const uint32_t d_conv = 4;
|
||||
const uint32_t d_state = 128;
|
||||
const uint32_t n_group = 2;
|
||||
const uint32_t head_dim = 64;
|
||||
const uint32_t n_head = d_inner / head_dim;
|
||||
const int64_t d_in_proj = 2*d_inner + 2*n_group*d_state + n_head;
|
||||
|
||||
const auto init_A = [](std::mt19937 & rng) {
|
||||
return -std::uniform_real_distribution<float>(1, 16)(rng);
|
||||
};
|
||||
|
||||
cur.add_kv(LLM_KV_CONTEXT_LENGTH, (uint32_t) 1024 * 1024);
|
||||
cur.add_kv(LLM_KV_EMBEDDING_LENGTH, n_embd);
|
||||
cur.add_kv(LLM_KV_FEED_FORWARD_LENGTH, (uint32_t) 0);
|
||||
cur.add_kv(LLM_KV_ATTENTION_HEAD_COUNT, (uint32_t) 0);
|
||||
cur.add_kv(LLM_KV_BLOCK_COUNT, n_layer);
|
||||
cur.add_kv(LLM_KV_SSM_CONV_KERNEL, d_conv);
|
||||
cur.add_kv(LLM_KV_SSM_INNER_SIZE, d_inner);
|
||||
cur.add_kv(LLM_KV_SSM_STATE_SIZE, d_state);
|
||||
cur.add_kv(LLM_KV_SSM_TIME_STEP_RANK, n_head);
|
||||
cur.add_kv(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, 1e-5f);
|
||||
cur.add_kv(LLM_KV_SSM_GROUP_COUNT, n_group);
|
||||
|
||||
add_tokenizer(cur, n_vocab);
|
||||
|
||||
cur.add_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab });
|
||||
cur.add_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd });
|
||||
cur.add_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), { n_embd, n_vocab });
|
||||
|
||||
for (uint32_t i = 0; i < n_layer; ++i) {
|
||||
cur.add_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
|
||||
|
||||
cur.add_tensor(tn(LLM_TENSOR_SSM_IN, "weight", i), {n_embd, d_in_proj});
|
||||
|
||||
cur.add_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", i), {d_conv, d_inner + 2*n_group*d_state});
|
||||
cur.add_tensor(tn(LLM_TENSOR_SSM_CONV1D, "bias", i), {d_inner + 2*n_group*d_state}, init_bias);
|
||||
|
||||
cur.add_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), {n_head}, init_bias);
|
||||
|
||||
// no "weight" suffix for these
|
||||
cur.add_tensor(tn(LLM_TENSOR_SSM_A, i), {1, n_head}, init_A);
|
||||
cur.add_tensor(tn(LLM_TENSOR_SSM_D, i), {1, n_head}, []() { return 1.0f; });
|
||||
|
||||
cur.add_tensor(tn(LLM_TENSOR_SSM_NORM, "weight", i), {d_inner / n_group, n_group});
|
||||
|
||||
// out_proj
|
||||
cur.add_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), {d_inner, n_embd});
|
||||
}
|
||||
}
|
||||
break;
|
||||
case LLM_ARCH_XVERSE:
|
||||
case LLM_ARCH_COMMAND_R:
|
||||
case LLM_ARCH_COHERE2:
|
||||
@@ -760,6 +818,7 @@ struct model_variant {
|
||||
case LLM_ARCH_BAILINGMOE:
|
||||
case LLM_ARCH_DOTS1:
|
||||
case LLM_ARCH_ARCEE:
|
||||
case LLM_ARCH_ERNIE4_5:
|
||||
case LLM_ARCH_UNKNOWN:
|
||||
break;
|
||||
}
|
||||
@@ -1042,6 +1101,8 @@ int main(int argc, char ** argv) {
|
||||
|
||||
std::vector<reference_logits> ref_outputs;
|
||||
|
||||
fprintf(stdout, "Generating reference outputs for '%s', n_seq_max=%i...\n", variant.name.c_str(), n_seq_max);
|
||||
|
||||
{
|
||||
llama_context_params ref_params = llama_context_default_params();
|
||||
ref_params.n_batch = n_seq_len;
|
||||
@@ -1092,6 +1153,10 @@ int main(int argc, char ** argv) {
|
||||
|
||||
float max_err = 0.0f;
|
||||
|
||||
fprintf(stdout,
|
||||
"Comparing output for '%s', with shuffle=%i, n_seq_max=%i, n_ctx=%i, n_ubatch=%i: ",
|
||||
variant.name.c_str(), shuffle, n_seq_max, n_ctx, n_ubatch);
|
||||
|
||||
// start filling the batch with prompts
|
||||
while (std::any_of(seq_id_n_past.begin(), seq_id_n_past.end(),
|
||||
[](llama_pos p) { return p < n_seq_len; })) {
|
||||
@@ -1140,9 +1205,6 @@ int main(int argc, char ** argv) {
|
||||
}
|
||||
}
|
||||
|
||||
fprintf(stdout,
|
||||
"Comparing output for '%s', with shuffle=%i, n_seq_max=%i, n_ctx=%i, n_ubatch=%i: ",
|
||||
variant.name.c_str(), shuffle, n_seq_max, n_ctx, n_ubatch);
|
||||
if (std::all_of(valid.begin(), valid.end(), [](bool v) { return v; })) {
|
||||
fprintf(stdout, "\033[1;32mOK\033[0m (max err: %.2g)\n", max_err);
|
||||
} else {
|
||||
|
||||
Reference in New Issue
Block a user