Mirror of https://github.com/ggml-org/llama.cpp.git (synced 2025-10-31 08:51:55 +00:00)

	test-model-random : add Mamba2
@@ -628,6 +628,64 @@ struct model_variant {
                     }
                 }
                 break;
+            case LLM_ARCH_MAMBA2:
+                {
+                    variants.push_back(model_variant(arch, "Mamba2"));
+                    model_variant & cur = variants.back();
+
+                    n_embd = 64;
+
+                    const uint32_t d_inner = 2 * n_embd;
+                    const uint32_t d_conv = 4;
+                    const uint32_t d_state = 128;
+                    const uint32_t n_group = 2;
+                    const uint32_t head_dim = 64;
+                    const uint32_t n_head = d_inner / head_dim;
+                    const int64_t d_in_proj = 2*d_inner + 2*n_group*d_state + n_head;
+
+                    const auto init_A = [](std::mt19937 & rng) {
+                        return -std::uniform_real_distribution<float>(1, 16)(rng);
+                    };
+
+                    cur.add_kv(LLM_KV_CONTEXT_LENGTH, (uint32_t) 1024 * 1024);
+                    cur.add_kv(LLM_KV_EMBEDDING_LENGTH, n_embd);
+                    cur.add_kv(LLM_KV_FEED_FORWARD_LENGTH, (uint32_t) 0);
+                    cur.add_kv(LLM_KV_ATTENTION_HEAD_COUNT, (uint32_t) 0);
+                    cur.add_kv(LLM_KV_BLOCK_COUNT, n_layer);
+                    cur.add_kv(LLM_KV_SSM_CONV_KERNEL, d_conv);
+                    cur.add_kv(LLM_KV_SSM_INNER_SIZE, d_inner);
+                    cur.add_kv(LLM_KV_SSM_STATE_SIZE, d_state);
+                    cur.add_kv(LLM_KV_SSM_TIME_STEP_RANK, n_head);
+                    cur.add_kv(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, 1e-5f);
+                    cur.add_kv(LLM_KV_SSM_GROUP_COUNT, n_group);
+
+                    add_tokenizer(cur, n_vocab);
+
+                    cur.add_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab });
+                    cur.add_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd });
+                    cur.add_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), { n_embd, n_vocab });
+
+                    for (uint32_t i = 0; i < n_layer; ++i) {
+                        cur.add_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
+
+                        cur.add_tensor(tn(LLM_TENSOR_SSM_IN, "weight", i), {n_embd, d_in_proj});
+
+                        cur.add_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", i), {d_conv, d_inner + 2*n_group*d_state});
+                        cur.add_tensor(tn(LLM_TENSOR_SSM_CONV1D, "bias", i), {d_inner + 2*n_group*d_state}, init_bias);
+
+                        cur.add_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), {n_head}, init_bias);
+
+                        // no "weight" suffix for these
+                        cur.add_tensor(tn(LLM_TENSOR_SSM_A, i), {1, n_head}, init_A);
+                        cur.add_tensor(tn(LLM_TENSOR_SSM_D, i), {1, n_head}, [](std::mt19937 &) { return 1.0f; });
+
+                        cur.add_tensor(tn(LLM_TENSOR_SSM_NORM, "weight", i), {d_inner / n_group, n_group});
+
+                        // out_proj
+                        cur.add_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), {d_inner, n_embd});
+                    }
+                }
+                break;
             case LLM_ARCH_XVERSE:
             case LLM_ARCH_COMMAND_R:
             case LLM_ARCH_COHERE2:
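For orientation (an aside, not part of the commit): d_in_proj above follows the usual Mamba2 packing of the fused input projection, which produces x and z (2*d_inner rows), B and C (2*n_group*d_state), and the per-head dt (n_head). A minimal standalone sketch of the sizes this yields with the constants chosen here:

    #include <cstdint>
    #include <cstdio>

    // Sketch only: recomputes the tensor sizes implied by the constants in
    // the Mamba2 case above (assuming the standard Mamba2 zxBCdt layout).
    int main() {
        const uint32_t n_embd   = 64;
        const uint32_t d_inner  = 2 * n_embd;         // 128
        const uint32_t d_state  = 128;
        const uint32_t n_group  = 2;
        const uint32_t head_dim = 64;
        const uint32_t n_head   = d_inner / head_dim; // 2

        // x and z, B and C, dt -> 256 + 512 + 2 = 770
        const int64_t d_in_proj = 2*d_inner + 2*n_group*d_state + n_head;

        // the conv1d spans x, B and C, but not z or dt -> 128 + 512 = 640
        const uint32_t conv_channels = d_inner + 2*n_group*d_state;

        std::printf("d_in_proj = %lld, conv channels = %u, n_head = %u\n",
                    (long long) d_in_proj, conv_channels, n_head);
        return 0;
    }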
@@ -760,6 +818,7 @@ struct model_variant {
             case LLM_ARCH_BAILINGMOE:
             case LLM_ARCH_DOTS1:
             case LLM_ARCH_ARCEE:
+            case LLM_ARCH_ERNIE4_5:
             case LLM_ARCH_UNKNOWN:
                 break;
         }
@@ -1042,6 +1101,8 @@ int main(int argc, char ** argv) {
 
                 std::vector<reference_logits> ref_outputs;
 
+                fprintf(stdout, "Generating reference outputs for '%s', n_seq_max=%i...\n", variant.name.c_str(), n_seq_max);
+
                 {
                     llama_context_params ref_params = llama_context_default_params();
                     ref_params.n_batch = n_seq_len;
@@ -1092,6 +1153,10 @@ int main(int argc, char ** argv) {
 
                         float max_err = 0.0f;
 
+                        fprintf(stdout,
+                                "Comparing output for '%s', with shuffle=%i, n_seq_max=%i, n_ctx=%i, n_ubatch=%i: ",
+                                variant.name.c_str(), shuffle, n_seq_max, n_ctx, n_ubatch);
+
                         // start filling the batch with prompts
                         while (std::any_of(seq_id_n_past.begin(), seq_id_n_past.end(),
                                            [](llama_pos p) { return p < n_seq_len; })) {
@@ -1140,9 +1205,6 @@ int main(int argc, char ** argv) {
                             }
                         }
 
-                        fprintf(stdout,
-                                "Comparing output for '%s', with shuffle=%i, n_seq_max=%i, n_ctx=%i, n_ubatch=%i: ",
-                                variant.name.c_str(), shuffle, n_seq_max, n_ctx, n_ubatch);
                         if (std::all_of(valid.begin(), valid.end(), [](bool v) { return v; })) {
                             fprintf(stdout, "\033[1;32mOK\033[0m (max err: %.2g)\n", max_err);
                         } else {
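Taken together, the last two hunks move the "Comparing output ..." banner from after the decode loop to before it, so when a configuration hangs or aborts mid-loop the log still names the variant and batch settings being exercised. A minimal sketch of the pattern (hypothetical run_case helper, not from the commit):

    #include <cstdio>

    // Hypothetical stand-in for the real decode-and-compare loop.
    static bool run_case(int n_ubatch) { return n_ubatch > 0; }

    // Announce the configuration up front, then report the verdict, so a
    // hang or crash inside run_case() still leaves the banner in the log.
    static void check(const char * name, int n_ubatch) {
        std::fprintf(stdout, "Comparing output for '%s', n_ubatch=%i: ", name, n_ubatch);
        std::fflush(stdout); // flush so the banner is visible before the long-running work
        if (run_case(n_ubatch)) {
            std::fprintf(stdout, "\033[1;32mOK\033[0m\n");
        } else {
            std::fprintf(stdout, "\033[1;31mFAIL\033[0m\n");
        }
    }

    int main() {
        check("Mamba2", 1);
        return 0;
    }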
Francis Couture-Harpin