mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-11-04 09:32:00 +00:00 
			
		
		
		
	kv_cache : move state read/write to llama_kv_cache
ggml-ci
This commit is contained in:
		@@ -908,143 +908,6 @@ struct llama_data_write {
 | 
				
			|||||||
            write(ctx->embd, embeddings_size * sizeof(float));
 | 
					            write(ctx->embd, embeddings_size * sizeof(float));
 | 
				
			||||||
        }
 | 
					        }
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					 | 
				
			||||||
    void write_kv_cache_meta(const llama_kv_cache & kv_self, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id = -1) {
 | 
					 | 
				
			||||||
        for (const auto & range : cell_ranges) {
 | 
					 | 
				
			||||||
            for (uint32_t i = range.first; i < range.second; ++i) {
 | 
					 | 
				
			||||||
                const auto & cell = kv_self.cells[i];
 | 
					 | 
				
			||||||
                const llama_pos pos      = cell.pos;
 | 
					 | 
				
			||||||
                const uint32_t  n_seq_id = seq_id == -1 ? cell.seq_id.size() : 0;
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                write(&pos,      sizeof(pos));
 | 
					 | 
				
			||||||
                write(&n_seq_id, sizeof(n_seq_id));
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                if (n_seq_id) {
 | 
					 | 
				
			||||||
                    for (auto seq_id : cell.seq_id) {
 | 
					 | 
				
			||||||
                        write(&seq_id, sizeof(seq_id));
 | 
					 | 
				
			||||||
                    }
 | 
					 | 
				
			||||||
                }
 | 
					 | 
				
			||||||
            }
 | 
					 | 
				
			||||||
        }
 | 
					 | 
				
			||||||
    }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    void write_kv_cache_data(const llama_kv_cache & kv, const llama_hparams & hparams, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) {
 | 
					 | 
				
			||||||
        const uint32_t v_trans = kv.v_trans ? 1 : 0;
 | 
					 | 
				
			||||||
        const uint32_t n_layer = hparams.n_layer;
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        write(&v_trans, sizeof(v_trans));
 | 
					 | 
				
			||||||
        write(&n_layer, sizeof(n_layer));
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        std::vector<uint8_t> tmp_buf;
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        // Iterate and write all the keys first, each row is a cell
 | 
					 | 
				
			||||||
        // Get whole range at a time
 | 
					 | 
				
			||||||
        for (uint32_t il = 0; il < n_layer; ++il) {
 | 
					 | 
				
			||||||
            const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            // Write key type
 | 
					 | 
				
			||||||
            const int32_t k_type_i = (int32_t)kv.k_l[il]->type;
 | 
					 | 
				
			||||||
            write(&k_type_i, sizeof(k_type_i));
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            // Write row size of key
 | 
					 | 
				
			||||||
            const uint64_t k_size_row = ggml_row_size(kv.k_l[il]->type, n_embd_k_gqa);
 | 
					 | 
				
			||||||
            write(&k_size_row, sizeof(k_size_row));
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            // Read each range of cells of k_size length each into tmp_buf and write out
 | 
					 | 
				
			||||||
            for (const auto & range : cell_ranges) {
 | 
					 | 
				
			||||||
                const size_t range_size = range.second - range.first;
 | 
					 | 
				
			||||||
                const size_t buf_size = range_size * k_size_row;
 | 
					 | 
				
			||||||
                write_tensor_data(kv.k_l[il], range.first * k_size_row, buf_size);
 | 
					 | 
				
			||||||
            }
 | 
					 | 
				
			||||||
        }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        if (!kv.v_trans) {
 | 
					 | 
				
			||||||
            for (uint32_t il = 0; il < n_layer; ++il) {
 | 
					 | 
				
			||||||
                const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                // Write value type
 | 
					 | 
				
			||||||
                const int32_t v_type_i = (int32_t)kv.v_l[il]->type;
 | 
					 | 
				
			||||||
                write(&v_type_i, sizeof(v_type_i));
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                // Write row size of value
 | 
					 | 
				
			||||||
                const uint64_t v_size_row = ggml_row_size(kv.v_l[il]->type, n_embd_v_gqa);
 | 
					 | 
				
			||||||
                write(&v_size_row, sizeof(v_size_row));
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                // Read each range of cells of v_size length each into tmp_buf and write out
 | 
					 | 
				
			||||||
                for (const auto & range : cell_ranges) {
 | 
					 | 
				
			||||||
                    const size_t range_size = range.second - range.first;
 | 
					 | 
				
			||||||
                    const size_t buf_size = range_size * v_size_row;
 | 
					 | 
				
			||||||
                    write_tensor_data(kv.v_l[il], range.first * v_size_row, buf_size);
 | 
					 | 
				
			||||||
                }
 | 
					 | 
				
			||||||
            }
 | 
					 | 
				
			||||||
        } else {
 | 
					 | 
				
			||||||
            // When v is transposed, we also need the element size and get the element ranges from each row
 | 
					 | 
				
			||||||
            const uint32_t kv_size = kv.size;
 | 
					 | 
				
			||||||
            for (uint32_t il = 0; il < n_layer; ++il) {
 | 
					 | 
				
			||||||
                const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                // Write value type
 | 
					 | 
				
			||||||
                const int32_t v_type_i = (int32_t)kv.v_l[il]->type;
 | 
					 | 
				
			||||||
                write(&v_type_i, sizeof(v_type_i));
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                // Write element size
 | 
					 | 
				
			||||||
                const uint32_t v_size_el = ggml_type_size(kv.v_l[il]->type);
 | 
					 | 
				
			||||||
                write(&v_size_el, sizeof(v_size_el));
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                // Write GQA embedding size
 | 
					 | 
				
			||||||
                write(&n_embd_v_gqa, sizeof(n_embd_v_gqa));
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                // For each row, we get the element values of each cell
 | 
					 | 
				
			||||||
                for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
 | 
					 | 
				
			||||||
                    // Read each range of cells of v_size_el length each into tmp_buf and write out
 | 
					 | 
				
			||||||
                    for (const auto & range : cell_ranges) {
 | 
					 | 
				
			||||||
                        const size_t range_size = range.second - range.first;
 | 
					 | 
				
			||||||
                        const size_t src_offset = (range.first + j * kv_size) * v_size_el;
 | 
					 | 
				
			||||||
                        const size_t buf_size = range_size * v_size_el;
 | 
					 | 
				
			||||||
                        write_tensor_data(kv.v_l[il], src_offset, buf_size);
 | 
					 | 
				
			||||||
                    }
 | 
					 | 
				
			||||||
                }
 | 
					 | 
				
			||||||
            }
 | 
					 | 
				
			||||||
        }
 | 
					 | 
				
			||||||
    }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    void write_kv_cache(const llama_kv_cache & kv, const llama_hparams & hparams, llama_seq_id seq_id = -1) {
 | 
					 | 
				
			||||||
        std::vector<std::pair<uint32_t, uint32_t>> cell_ranges; // ranges, from inclusive, to exclusive
 | 
					 | 
				
			||||||
        uint32_t cell_count = 0;
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        // Count the number of cells with the specified seq_id
 | 
					 | 
				
			||||||
        // Find all the ranges of cells with this seq id (or all, when -1)
 | 
					 | 
				
			||||||
        uint32_t cell_range_begin = kv.size;
 | 
					 | 
				
			||||||
        for (uint32_t i = 0; i < kv.size; ++i) {
 | 
					 | 
				
			||||||
            const auto & cell = kv.cells[i];
 | 
					 | 
				
			||||||
            if ((seq_id == -1 && !cell.is_empty()) || cell.has_seq_id(seq_id)) {
 | 
					 | 
				
			||||||
                ++cell_count;
 | 
					 | 
				
			||||||
                if (cell_range_begin == kv.size) {
 | 
					 | 
				
			||||||
                    cell_range_begin = i;
 | 
					 | 
				
			||||||
                }
 | 
					 | 
				
			||||||
            } else {
 | 
					 | 
				
			||||||
                if (cell_range_begin != kv.size) {
 | 
					 | 
				
			||||||
                    cell_ranges.emplace_back(cell_range_begin, i);
 | 
					 | 
				
			||||||
                    cell_range_begin = kv.size;
 | 
					 | 
				
			||||||
                }
 | 
					 | 
				
			||||||
            }
 | 
					 | 
				
			||||||
        }
 | 
					 | 
				
			||||||
        if (cell_range_begin != kv.size) {
 | 
					 | 
				
			||||||
            cell_ranges.emplace_back(cell_range_begin, kv.size);
 | 
					 | 
				
			||||||
        }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        // DEBUG CHECK: Sum of cell counts in ranges should equal the total cell count
 | 
					 | 
				
			||||||
        uint32_t cell_count_check = 0;
 | 
					 | 
				
			||||||
        for (const auto & range : cell_ranges) {
 | 
					 | 
				
			||||||
            cell_count_check += range.second - range.first;
 | 
					 | 
				
			||||||
        }
 | 
					 | 
				
			||||||
        GGML_ASSERT(cell_count == cell_count_check);
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        write(&cell_count, sizeof(cell_count));
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        write_kv_cache_meta(kv, cell_ranges, seq_id);
 | 
					 | 
				
			||||||
        write_kv_cache_data(kv, hparams, cell_ranges);
 | 
					 | 
				
			||||||
    }
 | 
					 | 
				
			||||||
};
 | 
					};
 | 
				
			||||||
 | 
					
 | 
				
			||||||
struct llama_data_read {
 | 
					struct llama_data_read {
 | 
				
			||||||
@@ -1135,241 +998,6 @@ struct llama_data_read {
 | 
				
			|||||||
            read_to(ctx->embd, embeddings_size * sizeof(float));
 | 
					            read_to(ctx->embd, embeddings_size * sizeof(float));
 | 
				
			||||||
        }
 | 
					        }
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					 | 
				
			||||||
    bool read_kv_cache_meta(llama_kv_cache & kv, uint32_t cell_count, llama_seq_id dest_seq_id = -1) {
 | 
					 | 
				
			||||||
        if (dest_seq_id != -1) {
 | 
					 | 
				
			||||||
            // single sequence
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            kv.seq_rm(dest_seq_id, -1, -1);
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            llama_sbatch sbatch;
 | 
					 | 
				
			||||||
            llama_ubatch batch = sbatch.reserve_ubatch(cell_count, /* has_embd */ false);
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            batch.n_tokens = cell_count;
 | 
					 | 
				
			||||||
            batch.n_seq_tokens = cell_count;
 | 
					 | 
				
			||||||
            batch.n_seqs = 1;
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            for (uint32_t i = 0; i < cell_count; ++i) {
 | 
					 | 
				
			||||||
                llama_pos pos;
 | 
					 | 
				
			||||||
                uint32_t n_seq_id;
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                read_to(&pos,      sizeof(pos));
 | 
					 | 
				
			||||||
                read_to(&n_seq_id, sizeof(n_seq_id));
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                if (n_seq_id != 0) {
 | 
					 | 
				
			||||||
                    LLAMA_LOG_ERROR("%s: invalid seq_id-agnostic kv cell\n", __func__);
 | 
					 | 
				
			||||||
                    return false;
 | 
					 | 
				
			||||||
                }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                batch.pos[i] = pos;
 | 
					 | 
				
			||||||
            }
 | 
					 | 
				
			||||||
            batch.n_seq_id[0] = 1;
 | 
					 | 
				
			||||||
            batch.seq_id[0] = &dest_seq_id;
 | 
					 | 
				
			||||||
            if (!kv.find_slot(batch)) {
 | 
					 | 
				
			||||||
                LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__);
 | 
					 | 
				
			||||||
                return false;
 | 
					 | 
				
			||||||
            }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            // DEBUG CHECK: kv.head should be our first cell, kv.head + cell_count - 1 should be our last cell (verify seq_id and pos values)
 | 
					 | 
				
			||||||
            // Assume that this is one contiguous block of cells
 | 
					 | 
				
			||||||
            GGML_ASSERT(kv.head + cell_count <= kv.size);
 | 
					 | 
				
			||||||
            GGML_ASSERT(kv.cells[kv.head].pos == batch.pos[0]);
 | 
					 | 
				
			||||||
            GGML_ASSERT(kv.cells[kv.head + cell_count - 1].pos == batch.pos[cell_count - 1]);
 | 
					 | 
				
			||||||
            GGML_ASSERT(kv.cells[kv.head].has_seq_id(dest_seq_id));
 | 
					 | 
				
			||||||
            GGML_ASSERT(kv.cells[kv.head + cell_count - 1].has_seq_id(dest_seq_id));
 | 
					 | 
				
			||||||
        } else {
 | 
					 | 
				
			||||||
            // whole KV cache restore
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            if (cell_count > kv.size) {
 | 
					 | 
				
			||||||
                LLAMA_LOG_ERROR("%s: not enough cells in kv cache\n", __func__);
 | 
					 | 
				
			||||||
                return false;
 | 
					 | 
				
			||||||
            }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            kv.clear();
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            for (uint32_t i = 0; i < cell_count; ++i) {
 | 
					 | 
				
			||||||
                llama_kv_cell & cell = kv.cells[i];
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                llama_pos pos;
 | 
					 | 
				
			||||||
                uint32_t  n_seq_id;
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                read_to(&pos,      sizeof(pos));
 | 
					 | 
				
			||||||
                read_to(&n_seq_id, sizeof(n_seq_id));
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                cell.pos = pos;
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                for (uint32_t j = 0; j < n_seq_id; ++j) {
 | 
					 | 
				
			||||||
                    llama_seq_id seq_id;
 | 
					 | 
				
			||||||
                    read_to(&seq_id, sizeof(seq_id));
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                    // TODO: llama_kv_cache should have a notion of max sequences
 | 
					 | 
				
			||||||
                    //if (seq_id < 0 || (uint32_t) seq_id >= llama_n_seq_max(ctx)) {
 | 
					 | 
				
			||||||
                    if (seq_id < 0) {
 | 
					 | 
				
			||||||
                        //LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, %u)\n", __func__, seq_id, llama_n_seq_max(ctx));
 | 
					 | 
				
			||||||
                        LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, inf)\n", __func__, seq_id);
 | 
					 | 
				
			||||||
                        return false;
 | 
					 | 
				
			||||||
                    }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                    cell.seq_id.insert(seq_id);
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                    if (kv.recurrent) {
 | 
					 | 
				
			||||||
                        int32_t & tail = kv.cells[seq_id].tail;
 | 
					 | 
				
			||||||
                        if (tail != -1) {
 | 
					 | 
				
			||||||
                            LLAMA_LOG_ERROR("%s: duplicate tail for seq_id %d in cell %d and %d\n", __func__, seq_id, i, tail);
 | 
					 | 
				
			||||||
                            return false;
 | 
					 | 
				
			||||||
                        }
 | 
					 | 
				
			||||||
                        tail = i;
 | 
					 | 
				
			||||||
                    }
 | 
					 | 
				
			||||||
                }
 | 
					 | 
				
			||||||
            }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            kv.head = 0;
 | 
					 | 
				
			||||||
            kv.used = cell_count;
 | 
					 | 
				
			||||||
        }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        if (kv.recurrent) {
 | 
					 | 
				
			||||||
            for (uint32_t i = 0; i < cell_count; ++i) {
 | 
					 | 
				
			||||||
                uint32_t cell_id = kv.head + i;
 | 
					 | 
				
			||||||
                // make sure the recurrent states will keep their restored state
 | 
					 | 
				
			||||||
                kv.cells[cell_id].src = cell_id;
 | 
					 | 
				
			||||||
            }
 | 
					 | 
				
			||||||
        }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        return true;
 | 
					 | 
				
			||||||
    }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    bool read_kv_cache_data(llama_kv_cache & kv, const llama_hparams & hparams, uint32_t cell_count) {
 | 
					 | 
				
			||||||
        uint32_t v_trans;
 | 
					 | 
				
			||||||
        uint32_t n_layer;
 | 
					 | 
				
			||||||
        read_to(&v_trans, sizeof(v_trans));
 | 
					 | 
				
			||||||
        read_to(&n_layer, sizeof(n_layer));
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        if (n_layer != hparams.n_layer) {
 | 
					 | 
				
			||||||
            LLAMA_LOG_ERROR("%s: mismatched layer count (%u instead of %u)\n", __func__, n_layer, hparams.n_layer);
 | 
					 | 
				
			||||||
            return false;
 | 
					 | 
				
			||||||
        }
 | 
					 | 
				
			||||||
        if (cell_count > kv.size) {
 | 
					 | 
				
			||||||
            LLAMA_LOG_ERROR("%s: not enough cells in kv cache to restore state (%u > %u)\n", __func__, cell_count, kv.size);
 | 
					 | 
				
			||||||
            return false;
 | 
					 | 
				
			||||||
        }
 | 
					 | 
				
			||||||
        if (kv.v_trans != (bool) v_trans) {
 | 
					 | 
				
			||||||
            LLAMA_LOG_ERROR("%s: incompatible V transposition\n", __func__);
 | 
					 | 
				
			||||||
            return false;
 | 
					 | 
				
			||||||
        }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        // For each layer, read the keys for each cell, one row is one cell, read as one contiguous block
 | 
					 | 
				
			||||||
        for (uint32_t il = 0; il < n_layer; ++il) {
 | 
					 | 
				
			||||||
            const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            // Read type of key
 | 
					 | 
				
			||||||
            int32_t k_type_i_ref;
 | 
					 | 
				
			||||||
            read_to(&k_type_i_ref, sizeof(k_type_i_ref));
 | 
					 | 
				
			||||||
            const int32_t k_type_i = (int32_t)kv.k_l[il]->type;
 | 
					 | 
				
			||||||
            if (k_type_i != k_type_i_ref) {
 | 
					 | 
				
			||||||
                LLAMA_LOG_ERROR("%s: mismatched key type (%d != %d, layer %d)\n", __func__, k_type_i, k_type_i_ref, il);
 | 
					 | 
				
			||||||
                return false;
 | 
					 | 
				
			||||||
            }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            // Read row size of key
 | 
					 | 
				
			||||||
            uint64_t k_size_row_ref;
 | 
					 | 
				
			||||||
            read_to(&k_size_row_ref, sizeof(k_size_row_ref));
 | 
					 | 
				
			||||||
            const size_t k_size_row = ggml_row_size(kv.k_l[il]->type, n_embd_k_gqa);
 | 
					 | 
				
			||||||
            if (k_size_row != k_size_row_ref) {
 | 
					 | 
				
			||||||
                LLAMA_LOG_ERROR("%s: mismatched key row size (%zu != %zu, layer %d)\n", __func__, k_size_row, (size_t) k_size_row_ref, il);
 | 
					 | 
				
			||||||
                return false;
 | 
					 | 
				
			||||||
            }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            if (cell_count) {
 | 
					 | 
				
			||||||
                // Read and set the keys for the whole cell range
 | 
					 | 
				
			||||||
                ggml_backend_tensor_set(kv.k_l[il], read(cell_count * k_size_row), kv.head * k_size_row, cell_count * k_size_row);
 | 
					 | 
				
			||||||
            }
 | 
					 | 
				
			||||||
        }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        if (!kv.v_trans) {
 | 
					 | 
				
			||||||
            for (uint32_t il = 0; il < n_layer; ++il) {
 | 
					 | 
				
			||||||
                const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                // Read type of value
 | 
					 | 
				
			||||||
                int32_t v_type_i_ref;
 | 
					 | 
				
			||||||
                read_to(&v_type_i_ref, sizeof(v_type_i_ref));
 | 
					 | 
				
			||||||
                const int32_t v_type_i = (int32_t)kv.v_l[il]->type;
 | 
					 | 
				
			||||||
                if (v_type_i != v_type_i_ref) {
 | 
					 | 
				
			||||||
                    LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il);
 | 
					 | 
				
			||||||
                    return false;
 | 
					 | 
				
			||||||
                }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                // Read row size of value
 | 
					 | 
				
			||||||
                uint64_t v_size_row_ref;
 | 
					 | 
				
			||||||
                read_to(&v_size_row_ref, sizeof(v_size_row_ref));
 | 
					 | 
				
			||||||
                const size_t v_size_row = ggml_row_size(kv.v_l[il]->type, n_embd_v_gqa);
 | 
					 | 
				
			||||||
                if (v_size_row != v_size_row_ref) {
 | 
					 | 
				
			||||||
                    LLAMA_LOG_ERROR("%s: mismatched value row size (%zu != %zu, layer %d)\n", __func__, v_size_row, (size_t) v_size_row_ref, il);
 | 
					 | 
				
			||||||
                    return false;
 | 
					 | 
				
			||||||
                }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                if (cell_count) {
 | 
					 | 
				
			||||||
                    // Read and set the values for the whole cell range
 | 
					 | 
				
			||||||
                    ggml_backend_tensor_set(kv.v_l[il], read(cell_count * v_size_row), kv.head * v_size_row, cell_count * v_size_row);
 | 
					 | 
				
			||||||
                }
 | 
					 | 
				
			||||||
            }
 | 
					 | 
				
			||||||
        } else {
 | 
					 | 
				
			||||||
            // For each layer, read the values for each cell (transposed)
 | 
					 | 
				
			||||||
            for (uint32_t il = 0; il < n_layer; ++il) {
 | 
					 | 
				
			||||||
                const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                // Read type of value
 | 
					 | 
				
			||||||
                int32_t v_type_i_ref;
 | 
					 | 
				
			||||||
                read_to(&v_type_i_ref, sizeof(v_type_i_ref));
 | 
					 | 
				
			||||||
                const int32_t v_type_i = (int32_t)kv.v_l[il]->type;
 | 
					 | 
				
			||||||
                if (v_type_i != v_type_i_ref) {
 | 
					 | 
				
			||||||
                    LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il);
 | 
					 | 
				
			||||||
                    return false;
 | 
					 | 
				
			||||||
                }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                // Read element size of value
 | 
					 | 
				
			||||||
                uint32_t v_size_el_ref;
 | 
					 | 
				
			||||||
                read_to(&v_size_el_ref, sizeof(v_size_el_ref));
 | 
					 | 
				
			||||||
                const size_t v_size_el = ggml_type_size(kv.v_l[il]->type);
 | 
					 | 
				
			||||||
                if (v_size_el != v_size_el_ref) {
 | 
					 | 
				
			||||||
                    LLAMA_LOG_ERROR("%s: mismatched value element size (%zu != %zu, layer %d)\n", __func__, v_size_el, (size_t) v_size_el_ref, il);
 | 
					 | 
				
			||||||
                    return false;
 | 
					 | 
				
			||||||
                }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                // Read GQA embedding size
 | 
					 | 
				
			||||||
                uint32_t n_embd_v_gqa_ref;
 | 
					 | 
				
			||||||
                read_to(&n_embd_v_gqa_ref, sizeof(n_embd_v_gqa_ref));
 | 
					 | 
				
			||||||
                if (n_embd_v_gqa != n_embd_v_gqa_ref) {
 | 
					 | 
				
			||||||
                    LLAMA_LOG_ERROR("%s: mismatched GQA embedding size (%u != %u, layer %d)\n", __func__, n_embd_v_gqa, n_embd_v_gqa_ref, il);
 | 
					 | 
				
			||||||
                    return false;
 | 
					 | 
				
			||||||
                }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                if (cell_count) {
 | 
					 | 
				
			||||||
                    // For each row in the transposed matrix, read the values for the whole cell range
 | 
					 | 
				
			||||||
                    for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
 | 
					 | 
				
			||||||
                        const size_t dst_offset = (kv.head + j * kv.size) * v_size_el;
 | 
					 | 
				
			||||||
                        ggml_backend_tensor_set(kv.v_l[il], read(cell_count * v_size_el), dst_offset, cell_count * v_size_el);
 | 
					 | 
				
			||||||
                    }
 | 
					 | 
				
			||||||
                }
 | 
					 | 
				
			||||||
            }
 | 
					 | 
				
			||||||
        }
 | 
					 | 
				
			||||||
        return true;
 | 
					 | 
				
			||||||
    }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    void read_kv_cache(llama_kv_cache & kv, const llama_hparams & hparams, llama_seq_id seq_id = -1) {
 | 
					 | 
				
			||||||
        uint32_t cell_count;
 | 
					 | 
				
			||||||
        read_to(&cell_count, sizeof(cell_count));
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        bool res = read_kv_cache_meta(kv, cell_count, seq_id) && read_kv_cache_data(kv, hparams, cell_count);
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        if (!res) {
 | 
					 | 
				
			||||||
            if (seq_id == -1) {
 | 
					 | 
				
			||||||
                kv.clear();
 | 
					 | 
				
			||||||
            } else {
 | 
					 | 
				
			||||||
                kv.seq_rm(seq_id, -1, -1);
 | 
					 | 
				
			||||||
            }
 | 
					 | 
				
			||||||
            throw std::runtime_error("failed to restore kv cache");
 | 
					 | 
				
			||||||
        }
 | 
					 | 
				
			||||||
    }
 | 
					 | 
				
			||||||
};
 | 
					};
 | 
				
			||||||
 | 
					
 | 
				
			||||||
struct llama_data_write_dummy : llama_data_write {
 | 
					struct llama_data_write_dummy : llama_data_write {
 | 
				
			||||||
@@ -1518,7 +1146,18 @@ static size_t llama_state_get_data_internal(struct llama_context * ctx, llama_da
 | 
				
			|||||||
    data_ctx.write_logits(ctx);
 | 
					    data_ctx.write_logits(ctx);
 | 
				
			||||||
    data_ctx.write_embeddings(ctx);
 | 
					    data_ctx.write_embeddings(ctx);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    data_ctx.write_kv_cache(ctx->kv_self, ctx->model.hparams);
 | 
					    llama_kv_cache::io io = {
 | 
				
			||||||
 | 
					        /* .write =*/ [&](const void * src, size_t size) {
 | 
				
			||||||
 | 
					            data_ctx.write(src, size);
 | 
				
			||||||
 | 
					        },
 | 
				
			||||||
 | 
					        /* .write_tensor_data =*/ [&](const struct ggml_tensor * tensor, size_t offset, size_t size) {
 | 
				
			||||||
 | 
					            data_ctx.write_tensor_data(tensor, offset, size);
 | 
				
			||||||
 | 
					        },
 | 
				
			||||||
 | 
					        /* .read    =*/ nullptr,
 | 
				
			||||||
 | 
					        /* .read_to =*/ nullptr,
 | 
				
			||||||
 | 
					    };
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    ctx->kv_self.state_write(io, ctx->model.hparams);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    return data_ctx.get_size_written();
 | 
					    return data_ctx.get_size_written();
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
@@ -1555,7 +1194,18 @@ static size_t llama_state_set_data_internal(struct llama_context * ctx, llama_da
 | 
				
			|||||||
    data_ctx.read_logits(ctx);
 | 
					    data_ctx.read_logits(ctx);
 | 
				
			||||||
    data_ctx.read_embeddings(ctx);
 | 
					    data_ctx.read_embeddings(ctx);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    data_ctx.read_kv_cache(ctx->kv_self, ctx->model.hparams);
 | 
					    llama_kv_cache::io io = {
 | 
				
			||||||
 | 
					        /* .write =*/ nullptr,
 | 
				
			||||||
 | 
					        /* .write_tensor_data =*/ nullptr,
 | 
				
			||||||
 | 
					        /* .read =*/ [&](size_t size) {
 | 
				
			||||||
 | 
					            return data_ctx.read(size);
 | 
				
			||||||
 | 
					        },
 | 
				
			||||||
 | 
					        /* .read_to =*/ [&](void * dst, size_t size) {
 | 
				
			||||||
 | 
					            data_ctx.read_to(dst, size);
 | 
				
			||||||
 | 
					        },
 | 
				
			||||||
 | 
					    };
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    ctx->kv_self.state_read(io, ctx->model.hparams);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    return data_ctx.get_size_read();
 | 
					    return data_ctx.get_size_read();
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
@@ -1651,7 +1301,18 @@ bool llama_state_save_file(struct llama_context * ctx, const char * path_session
 | 
				
			|||||||
static size_t llama_state_seq_get_data_internal(struct llama_context * ctx, llama_data_write & data_ctx, llama_seq_id seq_id) {
 | 
					static size_t llama_state_seq_get_data_internal(struct llama_context * ctx, llama_data_write & data_ctx, llama_seq_id seq_id) {
 | 
				
			||||||
    llama_synchronize(ctx);
 | 
					    llama_synchronize(ctx);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    data_ctx.write_kv_cache(ctx->kv_self, ctx->model.hparams, seq_id);
 | 
					    llama_kv_cache::io io = {
 | 
				
			||||||
 | 
					        /* .write =*/ [&](const void * src, size_t size) {
 | 
				
			||||||
 | 
					            data_ctx.write(src, size);
 | 
				
			||||||
 | 
					        },
 | 
				
			||||||
 | 
					        /* .write_tensor_data =*/ [&](const struct ggml_tensor * tensor, size_t offset, size_t size) {
 | 
				
			||||||
 | 
					            data_ctx.write_tensor_data(tensor, offset, size);
 | 
				
			||||||
 | 
					        },
 | 
				
			||||||
 | 
					        /* .read =*/    nullptr,
 | 
				
			||||||
 | 
					        /* .read_to =*/ nullptr,
 | 
				
			||||||
 | 
					    };
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    ctx->kv_self.state_write(io, ctx->model.hparams, seq_id);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    return data_ctx.get_size_written();
 | 
					    return data_ctx.get_size_written();
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
@@ -1674,7 +1335,18 @@ size_t llama_state_seq_get_data(struct llama_context * ctx, uint8_t * dst, size_
 | 
				
			|||||||
static size_t llama_state_seq_set_data_internal(struct llama_context * ctx, llama_data_read & data_ctx, llama_seq_id dest_seq_id) {
 | 
					static size_t llama_state_seq_set_data_internal(struct llama_context * ctx, llama_data_read & data_ctx, llama_seq_id dest_seq_id) {
 | 
				
			||||||
    llama_synchronize(ctx);
 | 
					    llama_synchronize(ctx);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    data_ctx.read_kv_cache(ctx->kv_self, ctx->model.hparams, dest_seq_id);
 | 
					    llama_kv_cache::io io = {
 | 
				
			||||||
 | 
					        /* .write =*/ nullptr,
 | 
				
			||||||
 | 
					        /* .write_tensor_data =*/ nullptr,
 | 
				
			||||||
 | 
					        /* .read =*/ [&](size_t size) {
 | 
				
			||||||
 | 
					            return data_ctx.read(size);
 | 
				
			||||||
 | 
					        },
 | 
				
			||||||
 | 
					        /* .read_to =*/ [&](void * dst, size_t size) {
 | 
				
			||||||
 | 
					            data_ctx.read_to(dst, size);
 | 
				
			||||||
 | 
					        },
 | 
				
			||||||
 | 
					    };
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    ctx->kv_self.state_read(io, ctx->model.hparams, dest_seq_id);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    return data_ctx.get_size_read();
 | 
					    return data_ctx.get_size_read();
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -8,6 +8,7 @@
 | 
				
			|||||||
#include <algorithm>
 | 
					#include <algorithm>
 | 
				
			||||||
#include <limits>
 | 
					#include <limits>
 | 
				
			||||||
#include <map>
 | 
					#include <map>
 | 
				
			||||||
 | 
					#include <stdexcept>
 | 
				
			||||||
 | 
					
 | 
				
			||||||
static const llama_kv_cache_slot_info llama_kv_cache_slot_info_failed{false};
 | 
					static const llama_kv_cache_slot_info llama_kv_cache_slot_info_failed{false};
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -696,6 +697,383 @@ size_t llama_kv_cache::size_v_bytes() const {
 | 
				
			|||||||
    return size_v_bytes;
 | 
					    return size_v_bytes;
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					void llama_kv_cache::state_write(const io & io, const llama_hparams & hparams, llama_seq_id seq_id) const {
 | 
				
			||||||
 | 
					    std::vector<std::pair<uint32_t, uint32_t>> cell_ranges; // ranges, from inclusive, to exclusive
 | 
				
			||||||
 | 
					    uint32_t cell_count = 0;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    // Count the number of cells with the specified seq_id
 | 
				
			||||||
 | 
					    // Find all the ranges of cells with this seq id (or all, when -1)
 | 
				
			||||||
 | 
					    uint32_t cell_range_begin = size;
 | 
				
			||||||
 | 
					    for (uint32_t i = 0; i < size; ++i) {
 | 
				
			||||||
 | 
					        const auto & cell = cells[i];
 | 
				
			||||||
 | 
					        if ((seq_id == -1 && !cell.is_empty()) || cell.has_seq_id(seq_id)) {
 | 
				
			||||||
 | 
					            ++cell_count;
 | 
				
			||||||
 | 
					            if (cell_range_begin == size) {
 | 
				
			||||||
 | 
					                cell_range_begin = i;
 | 
				
			||||||
 | 
					            }
 | 
				
			||||||
 | 
					        } else {
 | 
				
			||||||
 | 
					            if (cell_range_begin != size) {
 | 
				
			||||||
 | 
					                cell_ranges.emplace_back(cell_range_begin, i);
 | 
				
			||||||
 | 
					                cell_range_begin = size;
 | 
				
			||||||
 | 
					            }
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					    if (cell_range_begin != size) {
 | 
				
			||||||
 | 
					        cell_ranges.emplace_back(cell_range_begin, size);
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    // DEBUG CHECK: Sum of cell counts in ranges should equal the total cell count
 | 
				
			||||||
 | 
					    uint32_t cell_count_check = 0;
 | 
				
			||||||
 | 
					    for (const auto & range : cell_ranges) {
 | 
				
			||||||
 | 
					        cell_count_check += range.second - range.first;
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					    GGML_ASSERT(cell_count == cell_count_check);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    io.write(&cell_count, sizeof(cell_count));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    state_write_meta(io, cell_ranges, seq_id);
 | 
				
			||||||
 | 
					    state_write_data(io, cell_ranges, hparams);
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					void llama_kv_cache::state_read(const io & io, const llama_hparams & hparams, llama_seq_id seq_id) {
 | 
				
			||||||
 | 
					    uint32_t cell_count;
 | 
				
			||||||
 | 
					    io.read_to(&cell_count, sizeof(cell_count));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    bool res = true;
 | 
				
			||||||
 | 
					    res = res && state_read_meta(io, cell_count, seq_id);
 | 
				
			||||||
 | 
					    res = res && state_read_data(io, hparams, cell_count);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if (!res) {
 | 
				
			||||||
 | 
					        if (seq_id == -1) {
 | 
				
			||||||
 | 
					            clear();
 | 
				
			||||||
 | 
					        } else {
 | 
				
			||||||
 | 
					            seq_rm(seq_id, -1, -1);
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					        throw std::runtime_error("failed to restore kv cache");
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					void llama_kv_cache::state_write_meta(const io & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id) const {
 | 
				
			||||||
 | 
					    for (const auto & range : cell_ranges) {
 | 
				
			||||||
 | 
					        for (uint32_t i = range.first; i < range.second; ++i) {
 | 
				
			||||||
 | 
					            const auto & cell = cells[i];
 | 
				
			||||||
 | 
					            const llama_pos pos      = cell.pos;
 | 
				
			||||||
 | 
					            const uint32_t  n_seq_id = seq_id == -1 ? cell.seq_id.size() : 0;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            io.write(&pos,      sizeof(pos));
 | 
				
			||||||
 | 
					            io.write(&n_seq_id, sizeof(n_seq_id));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            if (n_seq_id) {
 | 
				
			||||||
 | 
					                for (auto seq_id : cell.seq_id) {
 | 
				
			||||||
 | 
					                    io.write(&seq_id, sizeof(seq_id));
 | 
				
			||||||
 | 
					                }
 | 
				
			||||||
 | 
					            }
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					void llama_kv_cache::state_write_data(const io & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, const llama_hparams & hparams) const {
 | 
				
			||||||
 | 
					    const uint32_t v_trans = this->v_trans ? 1 : 0;
 | 
				
			||||||
 | 
					    const uint32_t n_layer = hparams.n_layer;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    io.write(&v_trans, sizeof(v_trans));
 | 
				
			||||||
 | 
					    io.write(&n_layer, sizeof(n_layer));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    std::vector<uint8_t> tmp_buf;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    // Iterate and write all the keys first, each row is a cell
 | 
				
			||||||
 | 
					    // Get whole range at a time
 | 
				
			||||||
 | 
					    for (uint32_t il = 0; il < n_layer; ++il) {
 | 
				
			||||||
 | 
					        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        // Write key type
 | 
				
			||||||
 | 
					        const int32_t k_type_i = (int32_t)k_l[il]->type;
 | 
				
			||||||
 | 
					        io.write(&k_type_i, sizeof(k_type_i));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        // Write row size of key
 | 
				
			||||||
 | 
					        const uint64_t k_size_row = ggml_row_size(k_l[il]->type, n_embd_k_gqa);
 | 
				
			||||||
 | 
					        io.write(&k_size_row, sizeof(k_size_row));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        // Read each range of cells of k_size length each into tmp_buf and write out
 | 
				
			||||||
 | 
					        for (const auto & range : cell_ranges) {
 | 
				
			||||||
 | 
					            const size_t range_size = range.second - range.first;
 | 
				
			||||||
 | 
					            const size_t buf_size = range_size * k_size_row;
 | 
				
			||||||
 | 
					            io.write_tensor_data(k_l[il], range.first * k_size_row, buf_size);
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if (!v_trans) {
 | 
				
			||||||
 | 
					        for (uint32_t il = 0; il < n_layer; ++il) {
 | 
				
			||||||
 | 
					            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            // Write value type
 | 
				
			||||||
 | 
					            const int32_t v_type_i = (int32_t)v_l[il]->type;
 | 
				
			||||||
 | 
					            io.write(&v_type_i, sizeof(v_type_i));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            // Write row size of value
 | 
				
			||||||
 | 
					            const uint64_t v_size_row = ggml_row_size(v_l[il]->type, n_embd_v_gqa);
 | 
				
			||||||
 | 
					            io.write(&v_size_row, sizeof(v_size_row));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            // Read each range of cells of v_size length each into tmp_buf and write out
 | 
				
			||||||
 | 
					            for (const auto & range : cell_ranges) {
 | 
				
			||||||
 | 
					                const size_t range_size = range.second - range.first;
 | 
				
			||||||
 | 
					                const size_t buf_size = range_size * v_size_row;
 | 
				
			||||||
 | 
					                io.write_tensor_data(v_l[il], range.first * v_size_row, buf_size);
 | 
				
			||||||
 | 
					            }
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					    } else {
 | 
				
			||||||
 | 
					        // When v is transposed, we also need the element size and get the element ranges from each row
 | 
				
			||||||
 | 
					        const uint32_t kv_size = size;
 | 
				
			||||||
 | 
					        for (uint32_t il = 0; il < n_layer; ++il) {
 | 
				
			||||||
 | 
					            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            // Write value type
 | 
				
			||||||
 | 
					            const int32_t v_type_i = (int32_t)v_l[il]->type;
 | 
				
			||||||
 | 
					            io.write(&v_type_i, sizeof(v_type_i));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            // Write element size
 | 
				
			||||||
 | 
					            const uint32_t v_size_el = ggml_type_size(v_l[il]->type);
 | 
				
			||||||
 | 
					            io.write(&v_size_el, sizeof(v_size_el));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            // Write GQA embedding size
 | 
				
			||||||
 | 
					            io.write(&n_embd_v_gqa, sizeof(n_embd_v_gqa));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            // For each row, we get the element values of each cell
 | 
				
			||||||
 | 
					            for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
 | 
				
			||||||
 | 
					                // Read each range of cells of v_size_el length each into tmp_buf and write out
 | 
				
			||||||
 | 
					                for (const auto & range : cell_ranges) {
 | 
				
			||||||
 | 
					                    const size_t range_size = range.second - range.first;
 | 
				
			||||||
 | 
					                    const size_t src_offset = (range.first + j * kv_size) * v_size_el;
 | 
				
			||||||
 | 
					                    const size_t buf_size = range_size * v_size_el;
 | 
				
			||||||
 | 
					                    io.write_tensor_data(v_l[il], src_offset, buf_size);
 | 
				
			||||||
 | 
					                }
 | 
				
			||||||
 | 
					            }
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					bool llama_kv_cache::state_read_meta(const io & io, uint32_t cell_count, llama_seq_id dest_seq_id) {
 | 
				
			||||||
 | 
					    if (dest_seq_id != -1) {
 | 
				
			||||||
 | 
					        // single sequence
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        seq_rm(dest_seq_id, -1, -1);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        llama_sbatch sbatch;
 | 
				
			||||||
 | 
					        llama_ubatch batch = sbatch.reserve_ubatch(cell_count, /* has_embd */ false);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        batch.n_tokens = cell_count;
 | 
				
			||||||
 | 
					        batch.n_seq_tokens = cell_count;
 | 
				
			||||||
 | 
					        batch.n_seqs = 1;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        for (uint32_t i = 0; i < cell_count; ++i) {
 | 
				
			||||||
 | 
					            llama_pos pos;
 | 
				
			||||||
 | 
					            uint32_t n_seq_id;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            io.read_to(&pos,      sizeof(pos));
 | 
				
			||||||
 | 
					            io.read_to(&n_seq_id, sizeof(n_seq_id));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            if (n_seq_id != 0) {
 | 
				
			||||||
 | 
					                LLAMA_LOG_ERROR("%s: invalid seq_id-agnostic kv cell\n", __func__);
 | 
				
			||||||
 | 
					                return false;
 | 
				
			||||||
 | 
					            }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            batch.pos[i] = pos;
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					        batch.n_seq_id[0] = 1;
 | 
				
			||||||
 | 
					        batch.seq_id[0] = &dest_seq_id;
 | 
				
			||||||
 | 
					        if (!find_slot(batch)) {
 | 
				
			||||||
 | 
					            LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__);
 | 
				
			||||||
 | 
					            return false;
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        // DEBUG CHECK: kv.head should be our first cell, kv.head + cell_count - 1 should be our last cell (verify seq_id and pos values)
 | 
				
			||||||
 | 
					        // Assume that this is one contiguous block of cells
 | 
				
			||||||
 | 
					        GGML_ASSERT(head + cell_count <= size);
 | 
				
			||||||
 | 
					        GGML_ASSERT(cells[head].pos == batch.pos[0]);
 | 
				
			||||||
 | 
					        GGML_ASSERT(cells[head + cell_count - 1].pos == batch.pos[cell_count - 1]);
 | 
				
			||||||
 | 
					        GGML_ASSERT(cells[head].has_seq_id(dest_seq_id));
 | 
				
			||||||
 | 
					        GGML_ASSERT(cells[head + cell_count - 1].has_seq_id(dest_seq_id));
 | 
				
			||||||
 | 
					    } else {
 | 
				
			||||||
 | 
					        // whole KV cache restore
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if (cell_count > size) {
 | 
				
			||||||
 | 
					            LLAMA_LOG_ERROR("%s: not enough cells in kv cache\n", __func__);
 | 
				
			||||||
 | 
					            return false;
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        clear();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        for (uint32_t i = 0; i < cell_count; ++i) {
 | 
				
			||||||
 | 
					            llama_kv_cell & cell = cells[i];
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            llama_pos pos;
 | 
				
			||||||
 | 
					            uint32_t  n_seq_id;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            io.read_to(&pos,      sizeof(pos));
 | 
				
			||||||
 | 
					            io.read_to(&n_seq_id, sizeof(n_seq_id));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            cell.pos = pos;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            for (uint32_t j = 0; j < n_seq_id; ++j) {
 | 
				
			||||||
 | 
					                llama_seq_id seq_id;
 | 
				
			||||||
 | 
					                io.read_to(&seq_id, sizeof(seq_id));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                // TODO: llama_kv_cache should have a notion of max sequences
 | 
				
			||||||
 | 
					                //if (seq_id < 0 || (uint32_t) seq_id >= llama_n_seq_max(ctx)) {
 | 
				
			||||||
 | 
					                if (seq_id < 0) {
 | 
				
			||||||
 | 
					                    //LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, %u)\n", __func__, seq_id, llama_n_seq_max(ctx));
 | 
				
			||||||
 | 
					                    LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, inf)\n", __func__, seq_id);
 | 
				
			||||||
 | 
					                    return false;
 | 
				
			||||||
 | 
					                }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                cell.seq_id.insert(seq_id);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                if (recurrent) {
 | 
				
			||||||
 | 
					                    int32_t & tail = cells[seq_id].tail;
 | 
				
			||||||
 | 
					                    if (tail != -1) {
 | 
				
			||||||
 | 
					                        LLAMA_LOG_ERROR("%s: duplicate tail for seq_id %d in cell %d and %d\n", __func__, seq_id, i, tail);
 | 
				
			||||||
 | 
					                        return false;
 | 
				
			||||||
 | 
					                    }
 | 
				
			||||||
 | 
					                    tail = i;
 | 
				
			||||||
 | 
					                }
 | 
				
			||||||
 | 
					            }
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        head = 0;
 | 
				
			||||||
 | 
					        used = cell_count;
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if (recurrent) {
 | 
				
			||||||
 | 
					        for (uint32_t i = 0; i < cell_count; ++i) {
 | 
				
			||||||
 | 
					            uint32_t cell_id = head + i;
 | 
				
			||||||
 | 
					            // make sure the recurrent states will keep their restored state
 | 
				
			||||||
 | 
					            cells[cell_id].src = cell_id;
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    return true;
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					bool llama_kv_cache::state_read_data(const io & io, const llama_hparams & hparams, uint32_t cell_count) {
 | 
				
			||||||
 | 
					    uint32_t v_trans;
 | 
				
			||||||
 | 
					    uint32_t n_layer;
 | 
				
			||||||
 | 
					    io.read_to(&v_trans, sizeof(v_trans));
 | 
				
			||||||
 | 
					    io.read_to(&n_layer, sizeof(n_layer));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if (n_layer != hparams.n_layer) {
 | 
				
			||||||
 | 
					        LLAMA_LOG_ERROR("%s: mismatched layer count (%u instead of %u)\n", __func__, n_layer, hparams.n_layer);
 | 
				
			||||||
 | 
					        return false;
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					    if (cell_count > size) {
 | 
				
			||||||
 | 
					        LLAMA_LOG_ERROR("%s: not enough cells in kv cache to restore state (%u > %u)\n", __func__, cell_count, size);
 | 
				
			||||||
 | 
					        return false;
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					    if (v_trans != (bool) v_trans) {
 | 
				
			||||||
 | 
					        LLAMA_LOG_ERROR("%s: incompatible V transposition\n", __func__);
 | 
				
			||||||
 | 
					        return false;
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    // For each layer, read the keys for each cell, one row is one cell, read as one contiguous block
 | 
				
			||||||
 | 
					    for (uint32_t il = 0; il < n_layer; ++il) {
 | 
				
			||||||
 | 
					        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        // Read type of key
 | 
				
			||||||
 | 
					        int32_t k_type_i_ref;
 | 
				
			||||||
 | 
					        io.read_to(&k_type_i_ref, sizeof(k_type_i_ref));
 | 
				
			||||||
 | 
					        const int32_t k_type_i = (int32_t) k_l[il]->type;
 | 
				
			||||||
 | 
					        if (k_type_i != k_type_i_ref) {
 | 
				
			||||||
 | 
					            LLAMA_LOG_ERROR("%s: mismatched key type (%d != %d, layer %d)\n", __func__, k_type_i, k_type_i_ref, il);
 | 
				
			||||||
 | 
					            return false;
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        // Read row size of key
 | 
				
			||||||
 | 
					        uint64_t k_size_row_ref;
 | 
				
			||||||
 | 
					        io.read_to(&k_size_row_ref, sizeof(k_size_row_ref));
 | 
				
			||||||
 | 
					        const size_t k_size_row = ggml_row_size(k_l[il]->type, n_embd_k_gqa);
 | 
				
			||||||
 | 
					        if (k_size_row != k_size_row_ref) {
 | 
				
			||||||
 | 
					            LLAMA_LOG_ERROR("%s: mismatched key row size (%zu != %zu, layer %d)\n", __func__, k_size_row, (size_t) k_size_row_ref, il);
 | 
				
			||||||
 | 
					            return false;
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if (cell_count) {
 | 
				
			||||||
 | 
					            // Read and set the keys for the whole cell range
 | 
				
			||||||
 | 
					            ggml_backend_tensor_set(k_l[il], io.read(cell_count * k_size_row), head * k_size_row, cell_count * k_size_row);
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if (!v_trans) {
 | 
				
			||||||
 | 
					        for (uint32_t il = 0; il < n_layer; ++il) {
 | 
				
			||||||
 | 
					            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            // Read type of value
 | 
				
			||||||
 | 
					            int32_t v_type_i_ref;
 | 
				
			||||||
 | 
					            io.read_to(&v_type_i_ref, sizeof(v_type_i_ref));
 | 
				
			||||||
 | 
					            const int32_t v_type_i = (int32_t)v_l[il]->type;
 | 
				
			||||||
 | 
					            if (v_type_i != v_type_i_ref) {
 | 
				
			||||||
 | 
					                LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il);
 | 
				
			||||||
 | 
					                return false;
 | 
				
			||||||
 | 
					            }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            // Read row size of value
 | 
				
			||||||
 | 
					            uint64_t v_size_row_ref;
 | 
				
			||||||
 | 
					            io.read_to(&v_size_row_ref, sizeof(v_size_row_ref));
 | 
				
			||||||
 | 
					            const size_t v_size_row = ggml_row_size(v_l[il]->type, n_embd_v_gqa);
 | 
				
			||||||
 | 
					            if (v_size_row != v_size_row_ref) {
 | 
				
			||||||
 | 
					                LLAMA_LOG_ERROR("%s: mismatched value row size (%zu != %zu, layer %d)\n", __func__, v_size_row, (size_t) v_size_row_ref, il);
 | 
				
			||||||
 | 
					                return false;
 | 
				
			||||||
 | 
					            }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            if (cell_count) {
 | 
				
			||||||
 | 
					                // Read and set the values for the whole cell range
 | 
				
			||||||
 | 
					                ggml_backend_tensor_set(v_l[il], io.read(cell_count * v_size_row), head * v_size_row, cell_count * v_size_row);
 | 
				
			||||||
 | 
					            }
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					    } else {
 | 
				
			||||||
 | 
					        // For each layer, read the values for each cell (transposed)
 | 
				
			||||||
 | 
					        for (uint32_t il = 0; il < n_layer; ++il) {
 | 
				
			||||||
 | 
					            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            // Read type of value
 | 
				
			||||||
 | 
					            int32_t v_type_i_ref;
 | 
				
			||||||
 | 
					            io.read_to(&v_type_i_ref, sizeof(v_type_i_ref));
 | 
				
			||||||
 | 
					            const int32_t v_type_i = (int32_t)v_l[il]->type;
 | 
				
			||||||
 | 
					            if (v_type_i != v_type_i_ref) {
 | 
				
			||||||
 | 
					                LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il);
 | 
				
			||||||
 | 
					                return false;
 | 
				
			||||||
 | 
					            }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            // Read element size of value
 | 
				
			||||||
 | 
					            uint32_t v_size_el_ref;
 | 
				
			||||||
 | 
					            io.read_to(&v_size_el_ref, sizeof(v_size_el_ref));
 | 
				
			||||||
 | 
					            const size_t v_size_el = ggml_type_size(v_l[il]->type);
 | 
				
			||||||
 | 
					            if (v_size_el != v_size_el_ref) {
 | 
				
			||||||
 | 
					                LLAMA_LOG_ERROR("%s: mismatched value element size (%zu != %zu, layer %d)\n", __func__, v_size_el, (size_t) v_size_el_ref, il);
 | 
				
			||||||
 | 
					                return false;
 | 
				
			||||||
 | 
					            }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            // Read GQA embedding size
 | 
				
			||||||
 | 
					            uint32_t n_embd_v_gqa_ref;
 | 
				
			||||||
 | 
					            io.read_to(&n_embd_v_gqa_ref, sizeof(n_embd_v_gqa_ref));
 | 
				
			||||||
 | 
					            if (n_embd_v_gqa != n_embd_v_gqa_ref) {
 | 
				
			||||||
 | 
					                LLAMA_LOG_ERROR("%s: mismatched GQA embedding size (%u != %u, layer %d)\n", __func__, n_embd_v_gqa, n_embd_v_gqa_ref, il);
 | 
				
			||||||
 | 
					                return false;
 | 
				
			||||||
 | 
					            }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            if (cell_count) {
 | 
				
			||||||
 | 
					                // For each row in the transposed matrix, read the values for the whole cell range
 | 
				
			||||||
 | 
					                for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
 | 
				
			||||||
 | 
					                    const size_t dst_offset = (head + j * size) * v_size_el;
 | 
				
			||||||
 | 
					                    ggml_backend_tensor_set(v_l[il], io.read(cell_count * v_size_el), dst_offset, cell_count * v_size_el);
 | 
				
			||||||
 | 
					                }
 | 
				
			||||||
 | 
					            }
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    return true;
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					/////////////
 | 
				
			||||||
 | 
					
 | 
				
			||||||
void llama_kv_cache_clear(llama_kv_cache * kv) {
 | 
					void llama_kv_cache_clear(llama_kv_cache * kv) {
 | 
				
			||||||
    kv->clear();
 | 
					    kv->clear();
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -6,8 +6,10 @@
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
#include <set>
 | 
					#include <set>
 | 
				
			||||||
#include <vector>
 | 
					#include <vector>
 | 
				
			||||||
 | 
					#include <functional>
 | 
				
			||||||
 | 
					
 | 
				
			||||||
struct llama_cparams;
 | 
					struct llama_cparams;
 | 
				
			||||||
 | 
					struct llama_hparams;
 | 
				
			||||||
struct llama_ubatch;
 | 
					struct llama_ubatch;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
struct llama_kv_cell {
 | 
					struct llama_kv_cell {
 | 
				
			||||||
@@ -45,6 +47,7 @@ struct llama_kv_cache_slot_info {
 | 
				
			|||||||
// ring-buffer of cached KV data
 | 
					// ring-buffer of cached KV data
 | 
				
			||||||
// TODO: pimpl
 | 
					// TODO: pimpl
 | 
				
			||||||
// TODO: add notion of max sequences
 | 
					// TODO: add notion of max sequences
 | 
				
			||||||
 | 
					// TODO: add llama_hparams &
 | 
				
			||||||
struct llama_kv_cache {
 | 
					struct llama_kv_cache {
 | 
				
			||||||
    bool has_shift = false;
 | 
					    bool has_shift = false;
 | 
				
			||||||
    bool do_defrag = false;
 | 
					    bool do_defrag = false;
 | 
				
			||||||
@@ -111,12 +114,29 @@ struct llama_kv_cache {
 | 
				
			|||||||
    size_t size_k_bytes() const;
 | 
					    size_t size_k_bytes() const;
 | 
				
			||||||
    size_t size_v_bytes() const;
 | 
					    size_t size_v_bytes() const;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    struct io {
 | 
				
			||||||
 | 
					        std::function<void(const void * src, size_t size)> write;
 | 
				
			||||||
 | 
					        std::function<void(const struct ggml_tensor * tensor, size_t offset, size_t size)> write_tensor_data;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        std::function<const uint8_t * (size_t size)> read;
 | 
				
			||||||
 | 
					        std::function<void(void * dst, size_t size)> read_to;
 | 
				
			||||||
 | 
					    };
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    void state_write(const io & io, const llama_hparams & hparams, llama_seq_id seq_id = -1) const;
 | 
				
			||||||
 | 
					    void state_read (const io & io, const llama_hparams & hparams, llama_seq_id seq_id = -1);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
private:
 | 
					private:
 | 
				
			||||||
    ggml_type type_k = GGML_TYPE_F16;
 | 
					    ggml_type type_k = GGML_TYPE_F16;
 | 
				
			||||||
    ggml_type type_v = GGML_TYPE_F16;
 | 
					    ggml_type type_v = GGML_TYPE_F16;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    std::vector<ggml_context_ptr> ctxs;
 | 
					    std::vector<ggml_context_ptr> ctxs;
 | 
				
			||||||
    std::vector<ggml_backend_buffer_ptr> bufs;
 | 
					    std::vector<ggml_backend_buffer_ptr> bufs;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    void state_write_meta(const io & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id = -1) const;
 | 
				
			||||||
 | 
					    void state_write_data(const io & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, const llama_hparams & hparams) const;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    bool state_read_meta(const io & io, uint32_t cell_count, llama_seq_id dest_seq_id = -1);
 | 
				
			||||||
 | 
					    bool state_read_data(const io & io, const llama_hparams & hparams, uint32_t cell_count);
 | 
				
			||||||
};
 | 
					};
 | 
				
			||||||
 | 
					
 | 
				
			||||||
//
 | 
					//
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user