mirror of https://github.com/ggml-org/llama.cpp.git (synced 2025-11-03 09:22:01 +00:00)
* kv-cache : prepare K/V buffers for separation ggml-ci
* batched-bench : fix oob write ggml-ci
* llama : add "virtual sequences" ggml-ci
* llama : use "stream" vs "virtual sequence" ggml-ci
* graph : fix stream splitting when KV cache is not used ggml-ci
* kv-cache : add multi-stream save/load support ggml-ci
* llama : add "--attn-streams" flag ggml-ci
* kv-cache : fix handling when find_slot fails ggml-ci
* kv-cache : restore find_slot impl ggml-ci
* kv-cache : add comments
* kv-cache : add bounds checks for sequence id ggml-ci
* cont : add n_seq_max to batch allocr ggml-ci
* kv-cache : perform stream copies lazily after llama_synchronize ggml-ci
* kv-cache : avoid throwing exceptions across the C boundary ggml-ci
* CUDA: 4D FlashAttention support (#14628)
  * CUDA: 4D FlashAttention support
  * CUDA: fix WMMA FA kernel
* llama : rename attn_streams -> kv_unified ggml-ci
* common : rename kv_split -> kv_unified ggml-ci

Co-authored-by: Johannes Gäßler <johannesg@5d6.de>
134 lines · 4.0 KiB · C++
#pragma once

#include "llama-kv-cache-unified.h"

#include <vector>

//
// llama_kv_cache_unified_iswa
//

// utilizes two instances of llama_kv_cache_unified
//   the first instance is for the non-SWA layers of the model and the second instance is for the SWA layers
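//
// routing sketch (illustrative only; `il` is a layer index and hparams.is_swa(il)
// is assumed to mark the SWA layers, as in llama-hparams.h - the actual per-layer
// routing happens in the graph-building code, not in this header):
//
//     const auto * kv = hparams.is_swa(il) ? get_swa() : get_base();
//     // attention for layer il then reads from / writes to this cache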

class llama_kv_cache_unified_iswa : public llama_memory_i {
public:
    llama_kv_cache_unified_iswa(
            const llama_model & model,
                    ggml_type   type_k,
                    ggml_type   type_v,
                         bool   v_trans,
                         bool   offload,
                         bool   swa_full,
                         bool   unified,
                     uint32_t   kv_size,
                     uint32_t   n_seq_max,
                     uint32_t   n_ubatch,
                     uint32_t   n_pad);

    ~llama_kv_cache_unified_iswa() = default;

    //
    // llama_memory_i
    //

    llama_memory_context_ptr init_batch(
            llama_batch_allocr & balloc,
            uint32_t n_ubatch,
            bool embd_all) override;

    llama_memory_context_ptr init_full() override;

    llama_memory_context_ptr init_update(llama_context * lctx, bool optimize) override;

    bool get_can_shift() const override;

    void clear(bool data) override;

    bool seq_rm  (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1) override;
    void seq_cp  (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
    void seq_keep(llama_seq_id seq_id)                                                          override;
    void seq_add (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1, llama_pos shift) override;
    void seq_div (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1, int d) override;

    llama_pos seq_pos_min(llama_seq_id seq_id) const override;
    llama_pos seq_pos_max(llama_seq_id seq_id) const override;

    // state write/load

    void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
    void state_read (llama_io_read_i  & io, llama_seq_id seq_id = -1)       override;

    //
    // llama_kv_cache_unified_iswa specific API
    //

    llama_kv_cache_unified * get_base() const;
    llama_kv_cache_unified * get_swa () const;

private:
    const llama_hparams & hparams;

    const bool unified;

    std::unique_ptr<llama_kv_cache_unified> kv_base;
    std::unique_ptr<llama_kv_cache_unified> kv_swa;
};
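
// note: the llama_memory_i operations above are expected to fan out to both
// internal caches; a plausible delegation sketch (illustrative only - the
// actual implementation lives in llama-kv-cache-unified-iswa.cpp):
//
//     bool llama_kv_cache_unified_iswa::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
//         bool res = true;
//         res = res & kv_base->seq_rm(seq_id, p0, p1);
//         res = res & kv_swa ->seq_rm(seq_id, p0, p1);
//         return res;
//     }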

class llama_kv_cache_unified_iswa_context : public llama_memory_context_i {
public:
    using slot_info_vec_t = llama_kv_cache_unified::slot_info_vec_t;

    // used for errors
    llama_kv_cache_unified_iswa_context(llama_memory_status status);

    // used to create a full-cache context
    llama_kv_cache_unified_iswa_context(
            llama_kv_cache_unified_iswa * kv);

    // used to create an update context
    llama_kv_cache_unified_iswa_context(
            llama_kv_cache_unified_iswa * kv,
            llama_context * lctx,
            bool optimize);

    // used to create a batch processing context from a batch
    llama_kv_cache_unified_iswa_context(
            llama_kv_cache_unified_iswa * kv,
            slot_info_vec_t sinfos_base,
            slot_info_vec_t sinfos_swa,
            std::vector<llama_ubatch> ubatches);

    virtual ~llama_kv_cache_unified_iswa_context();

    //
    // llama_memory_context_i
    //

    bool next()  override;
    bool apply() override;

    llama_memory_status  get_status() const override;
    const llama_ubatch & get_ubatch() const override;

    //
    // llama_kv_cache_unified_iswa_context specific API
    //

    const llama_kv_cache_unified_context * get_base() const;
    const llama_kv_cache_unified_context * get_swa()  const;

private:
    //llama_kv_cache_unified_iswa * kv;

    // the index of the next ubatch to process
    size_t i_next = 0;

    std::vector<llama_ubatch> ubatches;

    const llama_memory_context_ptr ctx_base;
    const llama_memory_context_ptr ctx_swa;

    const llama_memory_status status;
};
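
// usage sketch (illustrative only - `mem` stands for the model's llama_memory_i
// instance and `balloc` for a prepared llama_batch_allocr; the driving loop in
// the actual llama_context implementation may differ in detail):
//
//     llama_memory_context_ptr mctx = mem->init_batch(balloc, n_ubatch, /*embd_all=*/false);
//     if (mctx && mctx->get_status() == LLAMA_MEMORY_STATUS_SUCCESS) {
//         do {
//             mctx->apply(); // place the current ubatch into both the base and the SWA cache
//             const llama_ubatch & ub = mctx->get_ubatch();
//             // ... build and compute the graph for ub ...
//         } while (mctx->next()); // advance to the next ubatch, if any
//     }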