mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-10-31 08:51:55 +00:00 
			
		
		
		
	llama : avoid redundant state copy for Mamba 1 and 2
This commit is contained in:
		| @@ -1530,27 +1530,58 @@ struct test_ssm_scan : public test_case { | ||||
|  | ||||
|     const int64_t d_state; | ||||
|     const int64_t d_inner; | ||||
|     const int64_t n_head; | ||||
|     const int64_t n_group; | ||||
|     const int64_t n_seq_tokens; | ||||
|     const int64_t n_seqs; | ||||
|  | ||||
|     std::string vars() override { | ||||
|         return VARS_TO_STR5(type, d_state, d_inner, n_seq_tokens, n_seqs); | ||||
|         return VARS_TO_STR7(type, d_state, d_inner, n_head, n_group, n_seq_tokens, n_seqs); | ||||
|     } | ||||
|  | ||||
|     test_ssm_scan(ggml_type type = GGML_TYPE_F32, | ||||
|             int64_t d_state = 32, int64_t d_inner = 32, int64_t n_seq_tokens = 32, int64_t n_seqs = 32) | ||||
|         : type(type), d_state(d_state), d_inner(d_inner), n_seq_tokens(n_seq_tokens), n_seqs(n_seqs) {} | ||||
|             int64_t d_state = 32, | ||||
|             int64_t d_inner = 1, // non-zero for Mamba-2 | ||||
|             int64_t n_head  = 32, | ||||
|             int64_t n_group = 1, | ||||
|             int64_t n_seq_tokens = 32, | ||||
|             int64_t n_seqs = 32) | ||||
|         : type(type), d_state(d_state), d_inner(d_inner), n_head(n_head), n_group(n_group), n_seq_tokens(n_seq_tokens), n_seqs(n_seqs) {} | ||||
|  | ||||
|     ggml_tensor * build_graph(ggml_context * ctx) override { | ||||
|         ggml_tensor * s   = ggml_new_tensor(ctx, type, 4, std::vector<int64_t>{ d_state, d_inner,      n_seqs, 1 }.data()); | ||||
|         ggml_tensor * x   = ggml_new_tensor(ctx, type, 4, std::vector<int64_t>{ d_inner, n_seq_tokens, n_seqs, 1 }.data()); | ||||
|         ggml_tensor * dt  = ggml_new_tensor(ctx, type, 4, std::vector<int64_t>{ d_inner, n_seq_tokens, n_seqs, 1 }.data()); | ||||
|         ggml_tensor * A   = ggml_new_tensor(ctx, type, 4, std::vector<int64_t>{ d_state, d_inner,      1     , 1 }.data()); | ||||
|         ggml_tensor * B   = ggml_new_tensor(ctx, type, 4, std::vector<int64_t>{ d_state, n_seq_tokens, n_seqs, 1 }.data()); | ||||
|         ggml_tensor * C   = ggml_new_tensor(ctx, type, 4, std::vector<int64_t>{ d_state, n_seq_tokens, n_seqs, 1 }.data()); | ||||
|         ggml_tensor * out = ggml_ssm_scan(ctx, s, x, dt, A, B, C); | ||||
|         ggml_tensor * s   = ggml_new_tensor_4d(ctx, type, d_state, d_inner,      n_head,       n_seqs); | ||||
|         ggml_tensor * x   = ggml_new_tensor_4d(ctx, type, d_inner, n_head,       n_seq_tokens, n_seqs); | ||||
|         ggml_tensor * dt  = ggml_new_tensor_3d(ctx, type, n_head,  n_seq_tokens, n_seqs); | ||||
|         ggml_tensor * A   = ggml_new_tensor_2d(ctx, type, (d_inner > 1) ? 1 : d_state, n_head); | ||||
|         ggml_tensor * B   = ggml_new_tensor_4d(ctx, type, d_state, n_group,      n_seq_tokens, n_seqs); | ||||
|         ggml_tensor * C   = ggml_new_tensor_4d(ctx, type, d_state, n_group,      n_seq_tokens, n_seqs); | ||||
|         ggml_tensor * D   = ggml_new_tensor_1d(ctx, type, n_head); | ||||
|         ggml_tensor * ids = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n_seqs); | ||||
|         ggml_tensor * out = ggml_ssm_scan(ctx, s, x, dt, A, B, C, D, ids); | ||||
|         return out; | ||||
|     } | ||||
|  | ||||
|     // similar to test_mul_mat_id | ||||
|     void initialize_tensors(ggml_context * ctx) override { | ||||
|         std::random_device rd; | ||||
|         std::default_random_engine rng(rd()); | ||||
|         for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) { | ||||
|             if (t->type == GGML_TYPE_I32) { | ||||
|                 if (ggml_is_view_op(t->op)) { continue; } | ||||
|                 // ids | ||||
|                 for (int64_t r = 0; r < ggml_nrows(t); r++) { | ||||
|                     std::vector<int32_t> data(t->ne[0]); | ||||
|                     for (int i = 0; i < t->ne[0]; i++) { | ||||
|                         data[i] = i; | ||||
|                     } | ||||
|                     std::shuffle(data.begin(), data.end(), rng); | ||||
|                     ggml_backend_tensor_set(t, data.data(), r * t->nb[1], t->ne[0] * sizeof(int32_t)); | ||||
|                 } | ||||
|             } else { | ||||
|                 init_tensor_uniform(t); | ||||
|             } | ||||
|         } | ||||
|     } | ||||
| }; | ||||
|  | ||||
| // GGML_OP_MUL_MAT | ||||
| @@ -3255,7 +3286,8 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op | ||||
|     test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {8, 1536, 1, 1}, {4, 1536, 1, 1})); | ||||
|     test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {4, 1536, 4, 1}, {4, 1536, 1, 1})); | ||||
|  | ||||
|     test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32, 16, 1024, 32, 4)); | ||||
|     test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32, 16, 1, 1024, 1, 32, 4)); // Mamba-1 | ||||
|     test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32, 128, 32, 32, 2, 32, 4)); // Mamba-2 | ||||
|  | ||||
| #if 1 | ||||
|     for (ggml_type type_a : base_types) { | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Francis Couture-Harpin
					Francis Couture-Harpin