mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-10-30 08:42:00 +00:00 
			
		
		
		
	llama : save and restore kv cache for single seq id (#6341)
* llama : save and restore kv cache for single seq id * remove trailing whitespace * respond error in case there's no space in the kv cache * add kv seq save restore to test case * add --slot-save-path arg to enable save restore and restrict save location * Returning 0 for some cases, instead of asserting. * cleanup error cases * rename sequence state functions * rename state get set functions * add previous function names back in with DEPRECATED notice * update doc * adjust endpoints to preferred style * fix restoring zero cell count * handle seq rm return value * unused param * keep in the size check * fix return types * add server test case for slot save restore * cleanup * add cake * cleanup style * add special * removing a whole sequence never fails * move sequence state file functionality from server to llama to match session api and add version tags * catch exceptions on save as well * error log messages * check types for stricter restore * update server doc * readme : update API changes date * strict filename validation * move include, reject bom as well * also reject empty filename * reject whitespace and trailing dot --------- Co-authored-by: Martin Evans <martindevans@gmail.com> Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
This commit is contained in:
		
							
								
								
									
										463
									
								
								llama.cpp
									
									
									
									
									
								
							
							
						
						
									
										463
									
								
								llama.cpp
									
									
									
									
									
								
							| @@ -14907,9 +14907,33 @@ void llama_kv_cache_update(struct llama_context * ctx) { | ||||
|     llama_kv_cache_update_internal(*ctx); | ||||
| } | ||||
|  | ||||
| // deprecated | ||||
| size_t llama_get_state_size(const struct llama_context * ctx) { | ||||
|     return llama_state_get_size(ctx); | ||||
| } | ||||
|  | ||||
| // deprecated | ||||
| size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dst) { | ||||
|     return llama_state_get_data(ctx, dst); | ||||
| } | ||||
|  | ||||
| // deprecated | ||||
| size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) { | ||||
|     return llama_state_set_data(ctx, src); | ||||
| } | ||||
|  | ||||
| // deprecated | ||||
| bool llama_load_session_file(struct llama_context * ctx, const char * path_session, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) { | ||||
|     return llama_state_load_file(ctx, path_session, tokens_out, n_token_capacity, n_token_count_out); | ||||
| } | ||||
|  | ||||
| // deprecated | ||||
| bool llama_save_session_file(struct llama_context * ctx, const char * path_session, const llama_token * tokens, size_t n_token_count) { | ||||
|     return llama_state_save_file(ctx, path_session, tokens, n_token_count); | ||||
| } | ||||
|  | ||||
| // Returns the *maximum* size of the state | ||||
| size_t llama_get_state_size(const struct llama_context * ctx) { | ||||
| size_t llama_state_get_size(const struct llama_context * ctx) { | ||||
|     const auto & cparams = ctx->cparams; | ||||
|     const auto & hparams = ctx->model.hparams; | ||||
|  | ||||
| @@ -14997,15 +15021,15 @@ struct llama_data_file_context : llama_data_context { | ||||
|  * file context: | ||||
|  * llama_file file("/path", "wb"); | ||||
|  * llama_data_file_context data_ctx(&file); | ||||
|  * llama_copy_state_data(ctx, &data_ctx); | ||||
|  * llama_state_get_data(ctx, &data_ctx); | ||||
|  * | ||||
|  * buffer context: | ||||
|  * std::vector<uint8_t> buf(max_size, 0); | ||||
|  * llama_data_buffer_context data_ctx(&buf.data()); | ||||
|  * llama_copy_state_data(ctx, &data_ctx); | ||||
|  * llama_state_get_data(ctx, &data_ctx); | ||||
|  * | ||||
| */ | ||||
| static void llama_copy_state_data_internal(struct llama_context * ctx, llama_data_context * data_ctx) { | ||||
| static void llama_state_get_data_internal(struct llama_context * ctx, llama_data_context * data_ctx) { | ||||
|     // copy rng | ||||
|     { | ||||
|         std::ostringstream rng_ss; | ||||
| @@ -15149,15 +15173,15 @@ static void llama_copy_state_data_internal(struct llama_context * ctx, llama_dat | ||||
|     } | ||||
| } | ||||
|  | ||||
| size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dst) { | ||||
| size_t llama_state_get_data(struct llama_context * ctx, uint8_t * dst) { | ||||
|     llama_data_buffer_context data_ctx(dst); | ||||
|     llama_copy_state_data_internal(ctx, &data_ctx); | ||||
|     llama_state_get_data_internal(ctx, &data_ctx); | ||||
|  | ||||
|     return data_ctx.get_size_written(); | ||||
| } | ||||
|  | ||||
| // Sets the state reading from the specified source address | ||||
| size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) { | ||||
| size_t llama_state_set_data(struct llama_context * ctx, const uint8_t * src) { | ||||
|     const uint8_t * inp = src; | ||||
|  | ||||
|     // set rng | ||||
| @@ -15309,14 +15333,14 @@ size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) { | ||||
|     } | ||||
|  | ||||
|     const size_t nread    = inp - src; | ||||
|     const size_t max_size = llama_get_state_size(ctx); | ||||
|     const size_t max_size = llama_state_get_size(ctx); | ||||
|  | ||||
|     GGML_ASSERT(nread <= max_size); | ||||
|  | ||||
|     return nread; | ||||
| } | ||||
|  | ||||
| static bool llama_load_session_file_internal(struct llama_context * ctx, const char * path_session, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) { | ||||
| static bool llama_state_load_file_internal(struct llama_context * ctx, const char * path_session, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) { | ||||
|     llama_file file(path_session, "rb"); | ||||
|  | ||||
|     // sanity checks | ||||
| @@ -15354,7 +15378,7 @@ static bool llama_load_session_file_internal(struct llama_context * ctx, const c | ||||
|     // restore the context state | ||||
|     { | ||||
|         const size_t n_state_size_cur = file.size - file.tell(); | ||||
|         const size_t n_state_size_max = llama_get_state_size(ctx); | ||||
|         const size_t n_state_size_max = llama_state_get_size(ctx); | ||||
|  | ||||
|         if (n_state_size_cur > n_state_size_max) { | ||||
|             LLAMA_LOG_ERROR("%s : the state size in session file is too big! max %zu, got %zu\n", __func__, n_state_size_max, n_state_size_cur); | ||||
| @@ -15364,22 +15388,22 @@ static bool llama_load_session_file_internal(struct llama_context * ctx, const c | ||||
|         std::vector<uint8_t> state_data(n_state_size_max); | ||||
|         file.read_raw(state_data.data(), n_state_size_cur); | ||||
|  | ||||
|         llama_set_state_data(ctx, state_data.data()); | ||||
|         llama_state_set_data(ctx, state_data.data()); | ||||
|     } | ||||
|  | ||||
|     return true; | ||||
| } | ||||
|  | ||||
| bool llama_load_session_file(struct llama_context * ctx, const char * path_session, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) { | ||||
| bool llama_state_load_file(struct llama_context * ctx, const char * path_session, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) { | ||||
|     try { | ||||
|         return llama_load_session_file_internal(ctx, path_session, tokens_out, n_token_capacity, n_token_count_out); | ||||
|         return llama_state_load_file_internal(ctx, path_session, tokens_out, n_token_capacity, n_token_count_out); | ||||
|     } catch (const std::exception & err) { | ||||
|         LLAMA_LOG_ERROR("error loading session file: %s\n", err.what()); | ||||
|         return false; | ||||
|     } | ||||
| } | ||||
|  | ||||
| bool llama_save_session_file(struct llama_context * ctx, const char * path_session, const llama_token * tokens, size_t n_token_count) { | ||||
| static bool llama_state_save_file_internal(struct llama_context * ctx, const char * path_session, const llama_token * tokens, size_t n_token_count) { | ||||
|     llama_file file(path_session, "wb"); | ||||
|  | ||||
|     file.write_u32(LLAMA_SESSION_MAGIC); | ||||
| @@ -15393,11 +15417,420 @@ bool llama_save_session_file(struct llama_context * ctx, const char * path_sessi | ||||
|  | ||||
|     // save the context state using stream saving | ||||
|     llama_data_file_context data_ctx(&file); | ||||
|     llama_copy_state_data_internal(ctx, &data_ctx); | ||||
|     llama_state_get_data_internal(ctx, &data_ctx); | ||||
|  | ||||
|     return true; | ||||
| } | ||||
|  | ||||
| bool llama_state_save_file(struct llama_context * ctx, const char * path_session, const llama_token * tokens, size_t n_token_count) { | ||||
|     try { | ||||
|         return llama_state_save_file_internal(ctx, path_session, tokens, n_token_count); | ||||
|     } catch (const std::exception & err) { | ||||
|         LLAMA_LOG_ERROR("error saving session file: %s\n", err.what()); | ||||
|         return false; | ||||
|     } | ||||
| } | ||||
|  | ||||
| size_t llama_state_seq_get_size(struct llama_context* ctx, llama_seq_id seq_id) { | ||||
|     // save the size of size_t as a uint32_t for safety check | ||||
|     const size_t size_t_size_size = sizeof(uint32_t); | ||||
|  | ||||
|     // other values | ||||
|     const size_t s_cell_count_size = sizeof(uint32_t); | ||||
|     const size_t s_layer_count_size = sizeof(uint32_t); | ||||
|     const size_t n_embd_v_gqa_size = sizeof(uint32_t); | ||||
|  | ||||
|     size_t s_cell_count = 0; | ||||
|     size_t s_cell_data_size = 0; | ||||
|     const auto & kv_self = ctx->kv_self; | ||||
|     const auto & hparams = ctx->model.hparams; | ||||
|  | ||||
|     const uint32_t n_layer = hparams.n_layer; | ||||
|     const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa() + hparams.n_embd_k_s(); | ||||
|     const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa() + hparams.n_embd_v_s(); | ||||
|  | ||||
|     for (uint32_t i = 0; i < kv_self.size; ++i) { | ||||
|         const auto & cell = kv_self.cells[i]; | ||||
|         if (cell.seq_id.count(seq_id) > 0) { | ||||
|             ++s_cell_count; | ||||
|             s_cell_data_size += sizeof(llama_pos); | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     for (int il = 0; il < (int)n_layer; ++il) { | ||||
|         // types of keys and values | ||||
|         s_cell_data_size += sizeof(int32_t) * 2; | ||||
|         // k_size_row and v_size_el values of layer | ||||
|         s_cell_data_size += sizeof(size_t) * 2; | ||||
|  | ||||
|         // keys | ||||
|         const size_t k_size_row = ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa); | ||||
|         s_cell_data_size += k_size_row * s_cell_count; | ||||
|  | ||||
|         // values (transposed) | ||||
|         const size_t v_size_el = ggml_type_size(kv_self.v_l[il]->type); | ||||
|         s_cell_data_size += v_size_el * s_cell_count * n_embd_v_gqa; | ||||
|     } | ||||
|  | ||||
|     const size_t s_total = ( | ||||
|         size_t_size_size + | ||||
|         s_cell_count_size + | ||||
|         s_layer_count_size + | ||||
|         n_embd_v_gqa_size + | ||||
|         s_cell_data_size | ||||
|         ); | ||||
|  | ||||
|     return s_total; | ||||
| } | ||||
|  | ||||
| static size_t llama_state_seq_get_data_internal(struct llama_context * ctx, llama_data_context & data_ctx, llama_seq_id seq_id) { | ||||
|     const auto & kv_self = ctx->kv_self; | ||||
|     GGML_ASSERT(!kv_self.recurrent); // not implemented | ||||
|  | ||||
|     // Save the size of size_t as a uint32_t for safety check | ||||
|     const uint32_t size_t_size = sizeof(size_t); | ||||
|     data_ctx.write(&size_t_size, sizeof(size_t_size)); | ||||
|  | ||||
|     std::vector<std::pair<uint32_t, uint32_t>> cell_ranges; // ranges, from inclusive, to exclusive | ||||
|     uint32_t cell_count = 0; | ||||
|  | ||||
|     // Count the number of cells with the specified seq_id | ||||
|     // Find all the ranges of cells with this seq id | ||||
|     { | ||||
|         uint32_t cell_range_begin = kv_self.size; | ||||
|         for (uint32_t i = 0; i < kv_self.size; ++i) { | ||||
|             const auto & cell = kv_self.cells[i]; | ||||
|             if (cell.has_seq_id(seq_id)) { | ||||
|                 ++cell_count; | ||||
|                 if (cell_range_begin == kv_self.size) { | ||||
|                     cell_range_begin = i; | ||||
|                 } | ||||
|             } | ||||
|             else { | ||||
|                 if (cell_range_begin != kv_self.size) { | ||||
|                     cell_ranges.push_back({ cell_range_begin, i }); | ||||
|                     cell_range_begin = kv_self.size; | ||||
|                 } | ||||
|             } | ||||
|         } | ||||
|         if (cell_range_begin != kv_self.size) { | ||||
|             cell_ranges.push_back({ cell_range_begin, kv_self.size }); | ||||
|         } | ||||
|  | ||||
|         // DEBUG CHECK: Sum of cell counts in ranges should equal the total cell count | ||||
|         uint32_t cell_count_check = 0; | ||||
|         for (const auto & range : cell_ranges) { | ||||
|             cell_count_check += range.second - range.first; | ||||
|         } | ||||
|         GGML_ASSERT(cell_count == cell_count_check); | ||||
|     } | ||||
|  | ||||
|     // Write the cell count | ||||
|     data_ctx.write(&cell_count, sizeof(cell_count)); | ||||
|  | ||||
|     const auto & hparams = ctx->model.hparams; | ||||
|     const uint32_t n_layer = hparams.n_layer; | ||||
|     const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa() + hparams.n_embd_k_s(); | ||||
|     const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa() + hparams.n_embd_v_s(); | ||||
|  | ||||
|     // Write the layer count | ||||
|     data_ctx.write(&n_layer, sizeof(n_layer)); | ||||
|  | ||||
|     // Write n_embd_v_gqa | ||||
|     data_ctx.write(&n_embd_v_gqa, sizeof(n_embd_v_gqa)); | ||||
|  | ||||
|     // Iterate the ranges and write all the pos (this is the token position in the prompt) | ||||
|     for (const auto & range : cell_ranges) { | ||||
|         for (uint32_t i = range.first; i < range.second; ++i) { | ||||
|             const auto & cell = kv_self.cells[i]; | ||||
|             data_ctx.write(&cell.pos, sizeof(cell.pos)); | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     // Iterate and write all the keys first, each row is a cell | ||||
|     // Get whole range at a time | ||||
|     std::vector<uint8_t> tmp_buf; | ||||
|     for (int il = 0; il < (int)n_layer; ++il) { | ||||
|         // Write key type | ||||
|         const int32_t k_type_i = (int32_t)kv_self.k_l[il]->type; | ||||
|         data_ctx.write(&k_type_i, sizeof(k_type_i)); | ||||
|  | ||||
|         // Write row size of key | ||||
|         const size_t k_size_row = ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa); | ||||
|         data_ctx.write(&k_size_row, sizeof(k_size_row)); | ||||
|  | ||||
|         // Read each range of cells of k_size length each into tmp_buf and write out | ||||
|         for (const auto & range : cell_ranges) { | ||||
|             const size_t range_size = range.second - range.first; | ||||
|             tmp_buf.resize(range_size * k_size_row); | ||||
|             ggml_backend_tensor_get(kv_self.k_l[il], tmp_buf.data(), range.first * k_size_row, range_size * k_size_row); | ||||
|             data_ctx.write(tmp_buf.data(), tmp_buf.size()); | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     // For the values, they are transposed, so we also need the element size and get the element ranges from each row | ||||
|     const uint32_t kv_size = kv_self.size; | ||||
|     for (int il = 0; il < (int)n_layer; ++il) { | ||||
|         // Write value type | ||||
|         const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type; | ||||
|         data_ctx.write(&v_type_i, sizeof(v_type_i)); | ||||
|  | ||||
|         // Write element size | ||||
|         const size_t v_size_el = ggml_type_size(kv_self.v_l[il]->type); | ||||
|         data_ctx.write(&v_size_el, sizeof(v_size_el)); | ||||
|  | ||||
|         // For each row, we get the element values of each cell | ||||
|         for (uint32_t j = 0; j < n_embd_v_gqa; ++j) { | ||||
|             // Read each range of cells of v_size_el length each into tmp_buf and write out | ||||
|             for (const auto & range : cell_ranges) { | ||||
|                 const size_t range_size = range.second - range.first; | ||||
|                 const size_t src_offset = (range.first + j * kv_size) * v_size_el; | ||||
|                 tmp_buf.resize(range_size * v_size_el); | ||||
|                 ggml_backend_tensor_get(kv_self.v_l[il], tmp_buf.data(), src_offset, tmp_buf.size()); | ||||
|                 data_ctx.write(tmp_buf.data(), tmp_buf.size()); | ||||
|             } | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     return data_ctx.get_size_written(); | ||||
| } | ||||
|  | ||||
| size_t llama_state_seq_get_data(struct llama_context* ctx, uint8_t* dst, llama_seq_id seq_id) { | ||||
|     llama_data_buffer_context data_ctx(dst); | ||||
|     return llama_state_seq_get_data_internal(ctx, data_ctx, seq_id); | ||||
| } | ||||
|  | ||||
| size_t llama_state_seq_set_data(struct llama_context * ctx, const uint8_t * src, llama_seq_id dest_seq_id) { | ||||
|     auto & kv_self = ctx->kv_self; | ||||
|     GGML_ASSERT(!kv_self.recurrent); // not implemented | ||||
|  | ||||
|     // Wipe the slot | ||||
|     llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1); | ||||
|  | ||||
|     const uint8_t * inp = src; | ||||
|  | ||||
|     // Read size of size_t | ||||
|     uint32_t size_t_size; | ||||
|     memcpy(&size_t_size, inp, sizeof(size_t_size)); | ||||
|     inp += sizeof(size_t_size); | ||||
|     if (size_t_size != sizeof(size_t)) { | ||||
|         LLAMA_LOG_ERROR("%s: size_t size mismatch\n", __func__); | ||||
|         return 0; | ||||
|     } | ||||
|  | ||||
|     // Read the cell count | ||||
|     uint32_t cell_count; | ||||
|     memcpy(&cell_count, inp, sizeof(cell_count)); | ||||
|     inp += sizeof(cell_count); | ||||
|  | ||||
|     // Read the layer count | ||||
|     uint32_t n_layer_ref; | ||||
|     memcpy(&n_layer_ref, inp, sizeof(n_layer_ref)); | ||||
|     inp += sizeof(n_layer_ref); | ||||
|  | ||||
|     // Read n_embd_v_gqa | ||||
|     uint32_t n_embd_v_gqa_ref; | ||||
|     memcpy(&n_embd_v_gqa_ref, inp, sizeof(n_embd_v_gqa_ref)); | ||||
|     inp += sizeof(n_embd_v_gqa_ref); | ||||
|  | ||||
|     // Sanity check model compatibility | ||||
|     const auto & hparams = ctx->model.hparams; | ||||
|     const uint32_t n_layer = hparams.n_layer; | ||||
|     const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa() + hparams.n_embd_k_s(); | ||||
|     const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa() + hparams.n_embd_v_s(); | ||||
|     if (n_layer != n_layer_ref) { | ||||
|         LLAMA_LOG_ERROR("%s: mismatched n_layer (%d != %d)\n", __func__, n_layer, n_layer_ref); | ||||
|         return 0; | ||||
|     } | ||||
|     if (n_embd_v_gqa != n_embd_v_gqa_ref) { | ||||
|         LLAMA_LOG_ERROR("%s: mismatched n_embd_v_gqa (%d != %d)\n", __func__, n_embd_v_gqa, n_embd_v_gqa_ref); | ||||
|         return 0; | ||||
|     } | ||||
|  | ||||
|     // Allocate the new cells for the slot | ||||
|     if (cell_count) { | ||||
|         llama_batch batch = llama_batch_init(cell_count, 0, 1); | ||||
|         batch.n_tokens = cell_count; | ||||
|         for (uint32_t i = 0; i < cell_count; ++i) { | ||||
|             llama_pos pos; | ||||
|             memcpy(&pos, inp, sizeof(pos)); | ||||
|             inp += sizeof(pos); | ||||
|  | ||||
|             batch.pos[i] = pos; | ||||
|             batch.n_seq_id[i] = 1; | ||||
|             batch.seq_id[i][0] = dest_seq_id; | ||||
|         } | ||||
|         if (!llama_kv_cache_find_slot(kv_self, batch)) { | ||||
|             llama_batch_free(batch); | ||||
|             LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__); | ||||
|             return 0; | ||||
|         } | ||||
|  | ||||
|         // DEBUG CHECK: kv_self.head should be our first cell, kv_self.head + cell_count - 1 should be our last cell (verify seq_id and pos values) | ||||
|         // Assume that this is one contiguous block of cells | ||||
|         GGML_ASSERT(kv_self.head + cell_count <= kv_self.size); | ||||
|         GGML_ASSERT(kv_self.cells[kv_self.head].pos == batch.pos[0]); | ||||
|         GGML_ASSERT(kv_self.cells[kv_self.head + cell_count - 1].pos == batch.pos[cell_count - 1]); | ||||
|         GGML_ASSERT(kv_self.cells[kv_self.head].has_seq_id(dest_seq_id)); | ||||
|         GGML_ASSERT(kv_self.cells[kv_self.head + cell_count - 1].has_seq_id(dest_seq_id)); | ||||
|  | ||||
|         // Cleanup | ||||
|         llama_batch_free(batch); | ||||
|     } | ||||
|  | ||||
|     const uint32_t kv_size = kv_self.size; | ||||
|     const uint32_t kv_head = kv_self.head; | ||||
|  | ||||
|     // For each layer, read the keys for each cell, one row is one cell, read as one contiguous blo | ||||
|     for (int il = 0; il < (int)n_layer; ++il) { | ||||
|         // Read type of key | ||||
|         int32_t k_type_i_ref; | ||||
|         memcpy(&k_type_i_ref, inp, sizeof(k_type_i_ref)); | ||||
|         inp += sizeof(k_type_i_ref); | ||||
|         const int32_t k_type_i = (int32_t)kv_self.k_l[il]->type; | ||||
|         if (k_type_i != k_type_i_ref) { | ||||
|             llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1); | ||||
|             LLAMA_LOG_ERROR("%s: mismatched key type (%d != %d, layer %d)\n", __func__, k_type_i, k_type_i_ref, il); | ||||
|             return 0; | ||||
|         } | ||||
|  | ||||
|         // Read row size of key | ||||
|         size_t k_size_row_ref; | ||||
|         memcpy(&k_size_row_ref, inp, sizeof(k_size_row_ref)); | ||||
|         inp += sizeof(k_size_row_ref); | ||||
|         const size_t k_size_row = ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa); | ||||
|         if (k_size_row != k_size_row_ref) { | ||||
|             llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1); | ||||
|             LLAMA_LOG_ERROR("%s: mismatched key row size (%zu != %zu, layer %d)\n", __func__, k_size_row, k_size_row_ref, il); | ||||
|             return 0; | ||||
|         } | ||||
|  | ||||
|         if (cell_count) { | ||||
|             // Read and set the keys for the whole cell range | ||||
|             ggml_backend_tensor_set(kv_self.k_l[il], inp, kv_head * k_size_row, cell_count * k_size_row); | ||||
|             inp += cell_count * k_size_row; | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     // For each layer, read the values for each cell (transposed) | ||||
|     for (int il = 0; il < (int)n_layer; ++il) { | ||||
|         // Read type of value | ||||
|         int32_t v_type_i_ref; | ||||
|         memcpy(&v_type_i_ref, inp, sizeof(v_type_i_ref)); | ||||
|         inp += sizeof(v_type_i_ref); | ||||
|         const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type; | ||||
|         if (v_type_i != v_type_i_ref) { | ||||
|             llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1); | ||||
|             LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il); | ||||
|             return 0; | ||||
|         } | ||||
|  | ||||
|         // Read element size of value | ||||
|         size_t v_size_el_ref; | ||||
|         memcpy(&v_size_el_ref, inp, sizeof(v_size_el_ref)); | ||||
|         inp += sizeof(v_size_el_ref); | ||||
|         const size_t v_size_el = ggml_type_size(kv_self.v_l[il]->type); | ||||
|         if (v_size_el != v_size_el_ref) { | ||||
|             llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1); | ||||
|             LLAMA_LOG_ERROR("%s: mismatched value element size (%zu != %zu, layer %d)\n", __func__, v_size_el, v_size_el_ref, il); | ||||
|             return 0; | ||||
|         } | ||||
|  | ||||
|         if (cell_count) { | ||||
|             // For each row in the transposed matrix, read the values for the whole cell range | ||||
|             for (uint32_t j = 0; j < n_embd_v_gqa; ++j) { | ||||
|                 const size_t dst_offset = (kv_head + j * kv_size) * v_size_el; | ||||
|                 ggml_backend_tensor_set(kv_self.v_l[il], inp, dst_offset, cell_count * v_size_el); | ||||
|                 inp += cell_count * v_size_el; | ||||
|             } | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     const size_t nread = inp - src; | ||||
|     return nread; | ||||
| } | ||||
|  | ||||
| static size_t llama_state_seq_save_file_internal(struct llama_context * ctx, const char * filepath, llama_seq_id seq_id, const llama_token * tokens, size_t n_token_count) { | ||||
|     llama_file file(filepath, "wb"); | ||||
|  | ||||
|     file.write_u32(LLAMA_STATE_SEQ_MAGIC); | ||||
|     file.write_u32(LLAMA_STATE_SEQ_VERSION); | ||||
|  | ||||
|     // save the prompt | ||||
|     file.write_u32((uint32_t)n_token_count); | ||||
|     file.write_raw(tokens, sizeof(llama_token) * n_token_count); | ||||
|  | ||||
|     // save the context state using stream saving | ||||
|     llama_data_file_context data_ctx(&file); | ||||
|     llama_state_seq_get_data_internal(ctx, data_ctx, seq_id); | ||||
|  | ||||
|     const size_t res = file.tell(); | ||||
|     GGML_ASSERT(res == sizeof(uint32_t) * 3 + sizeof(llama_token) * n_token_count + data_ctx.get_size_written()); | ||||
|     return res; | ||||
| } | ||||
|  | ||||
| static size_t llama_state_seq_load_file_internal(struct llama_context * ctx, const char * filepath, llama_seq_id dest_seq_id, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) { | ||||
|     llama_file file(filepath, "rb"); | ||||
|  | ||||
|     // version checks | ||||
|     { | ||||
|         const uint32_t magic   = file.read_u32(); | ||||
|         const uint32_t version = file.read_u32(); | ||||
|  | ||||
|         if (magic != LLAMA_STATE_SEQ_MAGIC || version != LLAMA_STATE_SEQ_VERSION) { | ||||
|             LLAMA_LOG_ERROR("%s: unknown (magic, version) for sequence state file: %08x, %08x\n", __func__, magic, version); | ||||
|             return 0; | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     // load the prompt | ||||
|     { | ||||
|         const uint32_t n_token_count = file.read_u32(); | ||||
|  | ||||
|         if (n_token_count > n_token_capacity) { | ||||
|             LLAMA_LOG_ERROR("%s: token count in sequence state file exceeded capacity! %u > %zu\n", __func__, n_token_count, n_token_capacity); | ||||
|             return 0; | ||||
|         } | ||||
|  | ||||
|         file.read_raw(tokens_out, sizeof(llama_token) * n_token_count); | ||||
|         *n_token_count_out = n_token_count; | ||||
|     } | ||||
|  | ||||
|     // restore the context state | ||||
|     { | ||||
|         const size_t state_size = file.size - file.tell(); | ||||
|         std::vector<uint8_t> state_data(state_size); | ||||
|         file.read_raw(state_data.data(), state_size); | ||||
|         const size_t nread = llama_state_seq_set_data(ctx, state_data.data(), dest_seq_id); | ||||
|         if (!nread) { | ||||
|             LLAMA_LOG_ERROR("%s: failed to restore sequence state\n", __func__); | ||||
|             return 0; | ||||
|         } | ||||
|         GGML_ASSERT(nread <= state_size); | ||||
|         GGML_ASSERT(nread + sizeof(uint32_t) * 3 + sizeof(llama_token) * *n_token_count_out == file.tell()); | ||||
|     } | ||||
|  | ||||
|     return file.tell(); | ||||
| } | ||||
|  | ||||
| size_t llama_state_seq_save_file(struct llama_context * ctx, const char * filepath, llama_seq_id seq_id, const llama_token * tokens, size_t n_token_count) { | ||||
|     try { | ||||
|         return llama_state_seq_save_file_internal(ctx, filepath, seq_id, tokens, n_token_count); | ||||
|     } catch (const std::exception & err) { | ||||
|         LLAMA_LOG_ERROR("error saving sequence state file: %s\n", err.what()); | ||||
|         return 0; | ||||
|     } | ||||
| } | ||||
|  | ||||
| size_t llama_state_seq_load_file(struct llama_context * ctx, const char * filepath, llama_seq_id dest_seq_id, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) { | ||||
|     try { | ||||
|         return llama_state_seq_load_file_internal(ctx, filepath, dest_seq_id, tokens_out, n_token_capacity, n_token_count_out); | ||||
|     } catch (const std::exception & err) { | ||||
|         LLAMA_LOG_ERROR("error loading sequence state file: %s\n", err.what()); | ||||
|         return 0; | ||||
|     } | ||||
| } | ||||
|  | ||||
| void llama_set_n_threads(struct llama_context * ctx, uint32_t n_threads, uint32_t n_threads_batch) { | ||||
|     ctx->cparams.n_threads       = n_threads; | ||||
|     ctx->cparams.n_threads_batch = n_threads_batch; | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Jan Boon
					Jan Boon