Source: mirror of https://github.com/ggml-org/llama.cpp.git (synced 2025-10-31 08:51:55 +00:00) · C++ · 138 lines · 4.2 KiB
#pragma once

#include "llama-batch.h"
#include "llama-graph.h"
#include "llama-kv-cache.h"
#include "llama-memory.h"
#include "llama-memory-recurrent.h"

#include <memory>
#include <vector>

//
// llama_memory_hybrid
//

// utilizes instances of llama_memory_recurrent and llama_kv_cache to
//   support models where each layer may be either attention-based or recurrent

class llama_memory_hybrid : public llama_memory_i {
public:
    llama_memory_hybrid(
        const llama_model & model,
                            /* attn */
                ggml_type   type_k,
                ggml_type   type_v,
                     bool   v_trans,
                 uint32_t   kv_size,
                 uint32_t   n_pad,
                 uint32_t   n_swa,
           llama_swa_type   swa_type,
                            /* recurrent */
                ggml_type   type_r,
                ggml_type   type_s,
                 uint32_t   rs_size,
                            /* common */
                 uint32_t   n_seq_max,
                     bool   offload,
                     bool   unified,
                            /* layer filters */
    const layer_filter_cb & filter_attn = nullptr,
    const layer_filter_cb & filter_recr = nullptr);

    ~llama_memory_hybrid() = default;

    //
    // llama_memory_i
    //

    // split the input batch into a set of ubatches and verify that they can fit into the cache
    llama_memory_context_ptr init_batch(
            llama_batch_allocr & balloc,
            uint32_t n_ubatch,
            bool embd_all) override;

    // simulate a full cache, used for allocating worst-case compute buffers
    llama_memory_context_ptr init_full() override;

    // prepare for any pending memory updates (shifts, etc.)
    llama_memory_context_ptr init_update(llama_context * lctx, bool optimize) override;

    bool get_can_shift() const override;

    void clear(bool data) override;

    bool seq_rm  (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1) override;
    void seq_cp  (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
    void seq_keep(llama_seq_id seq_id)                                                          override;
    void seq_add (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1, llama_pos shift) override;
    void seq_div (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1, int d) override;

    llama_pos seq_pos_min(llama_seq_id seq_id) const override;
    llama_pos seq_pos_max(llama_seq_id seq_id) const override;

    // state write/load

    void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) const override;
    void state_read (llama_io_read_i  & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0)       override;

    //
    // llama_memory_hybrid specific API
    //

    llama_kv_cache * get_mem_attn() const;
    llama_memory_recurrent * get_mem_recr() const;

private:
    const llama_hparams & hparams;

    const std::unique_ptr<llama_kv_cache> mem_attn;
    const std::unique_ptr<llama_memory_recurrent> mem_recr;
};
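The two const std::unique_ptr members above show the composition at work: each llama_memory_i operation fans out to both sub-memories. Below is a minimal sketch of that forwarding pattern; the method bodies are illustrative assumptions, not the verbatim definitions, which live in llama-memory-hybrid.cpp and may order or guard the calls differently.

// illustrative sketch -- assumed forwarding pattern, not the actual implementation
bool llama_memory_hybrid::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
    // try the recurrent state first: recurrent memories can reject
    // partial-range removals, and failing early leaves the attention
    // cache untouched
    if (!mem_recr->seq_rm(seq_id, p0, p1)) {
        return false;
    }
    return mem_attn->seq_rm(seq_id, p0, p1);
}

void llama_memory_hybrid::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
    // copies are unconditional and mirrored into both sub-memories
    mem_attn->seq_cp(seq_id_src, seq_id_dst, p0, p1);
    mem_recr->seq_cp(seq_id_src, seq_id_dst, p0, p1);
}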

class llama_memory_hybrid_context : public llama_memory_context_i {
public:
    using slot_info_vec_t = llama_kv_cache::slot_info_vec_t;

    // init failure
    explicit llama_memory_hybrid_context(llama_memory_status status);

    // init full
    explicit llama_memory_hybrid_context(llama_memory_hybrid * mem);

    // init update
    explicit llama_memory_hybrid_context(
        llama_memory_hybrid * mem,
              llama_context * lctx,
                       bool   optimize);

    // init success
    llama_memory_hybrid_context(
              llama_memory_hybrid * mem,
                  slot_info_vec_t   sinfos_attn,
        std::vector<llama_ubatch>   ubatches);

    ~llama_memory_hybrid_context() = default;

    bool next()  override;
    bool apply() override;

    llama_memory_status  get_status() const override;
    const llama_ubatch & get_ubatch() const override;

    //
    // llama_memory_hybrid_context
    //

    const llama_kv_cache_context * get_attn() const;
    const llama_memory_recurrent_context * get_recr() const;

private:
    // the index of the next ubatch to process
    size_t i_next = 0;

    std::vector<llama_ubatch> ubatches;

    const llama_memory_context_ptr ctx_attn;
    const llama_memory_context_ptr ctx_recr;

    const llama_memory_status status;
};
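For orientation, here is how a caller might drive the context through the llama_memory_context_i interface declared above. This is a hedged sketch: mem, balloc, and n_ubatch are assumed to exist at the call site, the graph step is elided, and the actual decode loop in llama.cpp lives in llama-context.cpp.

// sketch of a consumer loop over the ubatches prepared by init_batch();
// `mem`, `balloc`, and `n_ubatch` are assumed to be set up by the caller
llama_memory_context_ptr mctx = mem->init_batch(balloc, n_ubatch, /*embd_all=*/false);

if (!mctx || mctx->get_status() != LLAMA_MEMORY_STATUS_SUCCESS) {
    return; // splitting the batch or reserving slots failed
}

do {
    if (!mctx->apply()) { // apply this step's slot bookkeeping to both sub-caches
        break;
    }

    const llama_ubatch & ubatch = mctx->get_ubatch(); // the ubatch selected by i_next

    // ... build and evaluate the compute graph for `ubatch` ...
} while (mctx->next()); // advance i_next; returns false once all ubatches are consumed

Inside the graph-building step, the get_attn() and get_recr() accessors expose the per-sub-memory contexts, so attention layers and recurrent layers can each reach the state they need.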