mirror of https://github.com/ggml-org/llama.cpp.git

	graph : reuse hybrid graphs
Author: Georgi Gerganov
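
llm_graph_input_mem_hybrid no longer delegates set_input() to its child inputs: it now writes the attention K/V indices and KQ mask through the attention context directly and fills the recurrent-state copy buffer itself. It also gains a can_reuse() override that checks whether the tensor shapes of an already-built graph still match the incoming ubatch and memory context, so hybrid-memory models (attention KV cache plus recurrent state) can reuse their compute graphs across ubatches instead of rebuilding them.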
src/llama-graph.cpp

@@ -458,8 +458,43 @@ void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
 }
 
 void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) {
-    inp_attn->set_input(ubatch);
-    inp_rs->set_input(ubatch);
+    mctx->get_attn()->set_input_k_idxs(inp_attn->self_k_idxs, ubatch);
+    mctx->get_attn()->set_input_v_idxs(inp_attn->self_v_idxs, ubatch);
+
+    mctx->get_attn()->set_input_kq_mask(inp_attn->self_kq_mask, ubatch, cparams.causal_attn);
+
+    const int64_t n_rs = mctx->get_recr()->get_n_rs();
+
+    if (inp_rs->s_copy) {
+        GGML_ASSERT(ggml_backend_buffer_is_host(inp_rs->s_copy->buffer));
+        int32_t * data = (int32_t *) inp_rs->s_copy->data;
+
+        // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
+        for (uint32_t i = 0; i < n_rs; ++i) {
+            data[i] = mctx->get_recr()->s_copy(i);
+        }
+    }
+}
+
+bool llm_graph_input_mem_hybrid::can_reuse(const llm_graph_params & params) {
+    const auto * mctx = static_cast<const llama_memory_hybrid_context *>(params.mctx);
+
+    this->mctx = mctx;
+
+    bool res = true;
+
+    res &= inp_attn->self_k_idxs->ne[0] == params.ubatch.n_tokens;
+  //res &= inp_attn->self_v_idxs->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there
+
+    res &= inp_attn->self_kq_mask->ne[0] == mctx->get_attn()->get_n_kv();
+    res &= inp_attn->self_kq_mask->ne[1] == GGML_PAD(params.ubatch.n_tokens, GGML_KQ_MASK_PAD);
+
+    res &= inp_rs->s_copy->ne[0] == mctx->get_recr()->get_n_rs();
+
+    res &= inp_rs->s_copy_main->ne[0]  == params.ubatch.n_seqs;
+    res &= inp_rs->s_copy_extra->ne[0] == mctx->get_recr()->get_n_rs() - params.ubatch.n_seqs;
+
+    return res;
 }
 
 //
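
The new can_reuse() only has an effect because of the pattern it plugs into: before a graph is rebuilt for a new ubatch, every registered input is asked whether the shapes baked into the previous graph still fit, and the cached graph is reused only if all of them agree. Below is a minimal standalone sketch of that pattern; graph_params, graph_input and graph_result are simplified stand-ins invented for illustration, not llama.cpp's actual llm_graph_params / llm_graph_input_i / graph-result types.

// Minimal sketch of the graph-reuse pattern (assumed names, not llama.cpp's API).
#include <memory>
#include <vector>

struct graph_params {
    int n_tokens; // tokens in the incoming ubatch
    int n_kv;     // current attention window size
};

struct graph_input {
    virtual ~graph_input() = default;
    virtual void set_input(const graph_params & params) = 0; // rewrite tensor data only
    virtual bool can_reuse(const graph_params & params) = 0; // shapes still compatible?
};

struct graph_result {
    std::vector<std::unique_ptr<graph_input>> inputs;

    // reuse is an AND over all inputs, mirroring the res &= ... chain above
    bool can_reuse(const graph_params & params) const {
        bool res = true;
        for (const auto & inp : inputs) {
            res &= inp->can_reuse(params);
        }
        return res;
    }

    // on the reuse path only set_input() runs: tensor data is refreshed,
    // the graph itself is not reallocated or rebuilt
    void update(const graph_params & params) const {
        for (const auto & inp : inputs) {
            inp->set_input(params);
        }
    }
};
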
@@ -1909,7 +1944,7 @@ llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
     auto inp_rs   = build_rs_inp_impl(ctx0, ubatch, mctx_cur->get_recr());
     auto inp_attn = build_attn_inp_kv_impl(ctx0, ubatch, hparams, cparams, mctx_cur->get_attn());
 
-    auto inp = std::make_unique<llm_graph_input_mem_hybrid>(std::move(inp_attn), std::move(inp_rs), mctx_cur);
+    auto inp = std::make_unique<llm_graph_input_mem_hybrid>(cparams, std::move(inp_attn), std::move(inp_rs), mctx_cur);
 
     return (llm_graph_input_mem_hybrid *) res->add_input(std::move(inp));
 }
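
Threading cparams through the constructor here is what makes cparams.causal_attn available inside set_input() above; the header change below stores the params by value, presumably so the input object does not depend on the lifetime of the context that built it.
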
src/llama-graph.h

@@ -364,22 +364,28 @@ public:
 class llm_graph_input_mem_hybrid : public llm_graph_input_i {
 public:
     llm_graph_input_mem_hybrid(
+            const llama_cparams & cparams,
             std::unique_ptr<llm_graph_input_attn_kv> inp_attn,
             std::unique_ptr<llm_graph_input_rs>      inp_rs,
             const llama_memory_hybrid_context *      mctx) :
         inp_attn(std::move(inp_attn)),
         inp_rs(std::move(inp_rs)),
+        cparams(cparams),
         mctx(mctx) { }
     virtual ~llm_graph_input_mem_hybrid() = default;
 
     void set_input(const llama_ubatch * ubatch) override;
 
+    bool can_reuse(const llm_graph_params & params) override;
+
     std::unique_ptr<llm_graph_input_attn_kv> inp_attn;
     std::unique_ptr<llm_graph_input_rs>      inp_rs;
 
     llm_graph_input_attn_kv * get_attn() const { return inp_attn.get(); }
     llm_graph_input_rs      * get_recr() const { return inp_rs.get(); }
 
+    const llama_cparams cparams;
+
     const llama_memory_hybrid_context * mctx;
 };
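
One detail of the shape checks is worth spelling out: can_reuse() compares the mask's token dimension against GGML_PAD(params.ubatch.n_tokens, GGML_KQ_MASK_PAD), not against n_tokens itself, because the mask is allocated with its token dimension rounded up to a multiple of the pad. Ubatches whose token counts round to the same padded value therefore see identical mask shapes. A standalone sketch of that rounding, using an assumed pad value of 32 purely for illustration:

// Why the mask check uses a padded token count. The pad value 32 is an
// assumed stand-in for GGML_KQ_MASK_PAD, used here only for illustration.
#include <cstdint>
#include <cstdio>

constexpr int64_t kq_mask_pad = 32; // assumed; llama.cpp uses GGML_KQ_MASK_PAD

constexpr int64_t pad(int64_t x, int64_t n) {
    // round x up to a multiple of n (ggml's GGML_PAD does the same via bit masking)
    return ((x + n - 1) / n) * n;
}

int main() {
    // a graph built for 5 tokens allocates a mask with ne[1] == pad(5, 32) == 32;
    // a later 7-token ubatch also pads to 32, so the shape check passes,
    // while a 40-token ubatch pads to 64 and forces a rebuild
    printf("%lld %lld %lld\n",
           (long long) pad( 5, kq_mask_pad),
           (long long) pad( 7, kq_mask_pad),
           (long long) pad(40, kq_mask_pad));
    return 0;
}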