	server : improve context checkpoint logic (#16440)

@@ -861,9 +861,12 @@ void llama_memory_recurrent::state_write_data(llama_io_write_i & io, const std::
 bool llama_memory_recurrent::state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id) {
     if (dest_seq_id != -1) {
         // single sequence

         seq_rm(dest_seq_id, -1, -1);

+        if (cell_count == 0) {
+            return true;
+        }
+
         llama_batch_allocr balloc(hparams.n_pos_per_embd());

         llama_ubatch ubatch = balloc.ubatch_reserve(cell_count, 1);
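
The hunk above lets llama_memory_recurrent::state_read_meta return early when the serialized sequence contains no cells: the destination sequence is still cleared via seq_rm, but no zero-sized ubatch is reserved. Below is a minimal sketch (not part of the commit) of the save/restore round-trip that this path serves, using the public llama_state_seq_* API and assuming an already loaded llama_context:

    // Sketch only: round-trip one sequence's state through a byte buffer.
    // With the early return above, restoring a buffer that describes zero
    // cells clears the destination sequence and reports success.
    #include <cstdint>
    #include <vector>
    #include "llama.h"

    static bool roundtrip_seq_state(llama_context * ctx, llama_seq_id seq_id) {
        // query how many bytes are needed to serialize this sequence's state
        const size_t n_bytes = llama_state_seq_get_size(ctx, seq_id);

        std::vector<uint8_t> buf(n_bytes);

        // copy the sequence state into the buffer
        llama_state_seq_get_data(ctx, buf.data(), buf.size(), seq_id);

        // restore it into the same sequence id; a zero return signals failure
        return llama_state_seq_set_data(ctx, buf.data(), buf.size(), seq_id) != 0;
    }
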
@@ -3676,6 +3676,20 @@ struct server_context {
                         alora_disabled_id = enabled_loras[0];
                     }

+                    bool do_checkpoint = params_base.n_ctx_checkpoints > 0;
+
+                    // make a checkpoint of the parts of the memory that cannot be rolled back.
+                    // checkpoints are created only if:
+                    // - the model uses SWA and we are not using `swa_full`
+                    // - the model architecture is marked as recurrent or hybrid
+                    //
+                    // TODO: try to make this conditional on the context or the memory module, instead of the model type
+                    do_checkpoint = do_checkpoint && (
+                            llama_model_is_recurrent(model) ||
+                            llama_model_is_hybrid(model) ||
+                            (llama_model_n_swa(model) > 0 && !params_base.swa_full)
+                            );
+
                     // add prompt tokens for processing in the current batch
                     while (slot.n_past < slot.n_prompt_tokens && batch.n_tokens < n_batch) {
                         // get next token to process
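
In words: the new do_checkpoint flag is set only when checkpoints are configured (n_ctx_checkpoints > 0) and the model keeps state that cannot be rolled back token by token, i.e. recurrent or hybrid architectures, or SWA attention without swa_full. A hypothetical standalone predicate with the same logic (the commit computes this inline in server_context) could look like:

    // Hypothetical helper mirroring the gating above; not part of the commit.
    static bool want_checkpoints(const llama_model * model, int32_t n_ctx_checkpoints, bool swa_full) {
        if (n_ctx_checkpoints <= 0) {
            return false;
        }
        // recurrent/hybrid state and a sliding attention window cannot be
        // rolled back, so checkpoints are the only way to later reuse a prefix
        return llama_model_is_recurrent(model) ||
               llama_model_is_hybrid(model)    ||
               (llama_model_n_swa(model) > 0 && !swa_full);
    }
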
@@ -3700,6 +3714,11 @@ struct server_context {

                         slot.n_prompt_tokens_processed++;
                         slot.n_past++;
+
+                        // process the last few tokens of the prompt separately in order to allow for a checkpoint to be created.
+                        if (do_checkpoint && slot.n_prompt_tokens - slot.n_past == 64) {
+                            break;
+                        }
                     }

                     // SLT_INF(slot, "new cache_tokens: %s\n", slot.cache_tokens.str().c_str());
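
The break above stops prompt processing 64 tokens before the end of the prompt, so the checkpoint created after this batch captures the state just before the final tokens, and those remaining tokens are decoded in the next batch. A rough illustration of where the split lands, with made-up numbers (this is not the server's code):

    // Illustration only: split point for a hypothetical prompt.
    #include <algorithm>
    #include <cstdio>

    int main() {
        const int n_prompt_tokens = 300; // hypothetical prompt length
        const int n_batch         = 512; // hypothetical batch capacity
        const int margin          = 64;  // tokens left for the follow-up batch

        // the loop stops at the batch limit or `margin` tokens before the
        // end of the prompt, whichever comes first
        const int n_first = std::min(n_batch, n_prompt_tokens - margin);

        std::printf("decode %d tokens, checkpoint, then decode the last %d\n",
                    n_first, n_prompt_tokens - n_first); // 236, then 64
        return 0;
    }
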
@@ -3730,6 +3749,39 @@ struct server_context {
                         slot.i_batch   = batch.n_tokens - 1;

                         SLT_INF(slot, "prompt done, n_past = %d, n_tokens = %d\n", slot.n_past, batch.n_tokens);
+
+                        const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id);
+                        const auto pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx), slot.id);
+
+                        // no need for empty or small checkpoints
+                        do_checkpoint = do_checkpoint && (pos_min >= 0 && pos_max >= 64);
+
+                        // no need to create checkpoints that are too close together
+                        do_checkpoint = do_checkpoint && (slot.ctx_checkpoints.empty() || pos_max > slot.ctx_checkpoints.back().pos_max + 64);
+
+                        if (do_checkpoint) {
+                            while (slot.ctx_checkpoints.size() >= (size_t) params_base.n_ctx_checkpoints) {
+                                // make room for the new checkpoint, if needed
+                                const auto & cur = slot.ctx_checkpoints.front();
+                                SLT_WRN(slot, "erasing old context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n",
+                                        cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024);
+
+                                slot.ctx_checkpoints.erase(slot.ctx_checkpoints.begin());
+                            }
+
+                            const size_t checkpoint_size = llama_state_seq_get_size_ext(ctx, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
+
+                            auto & cur = slot.ctx_checkpoints.emplace_back(ctx_checkpoint{
+                                /*.pos_min = */ pos_min,
+                                /*.pos_max = */ pos_max,
+                                /*.data    = */ std::vector<uint8_t>(checkpoint_size),
+                            });
+
+                            llama_state_seq_get_data_ext(ctx, cur.data.data(), checkpoint_size, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
+
+                            SLT_WRN(slot, "saved context checkpoint %d of %d (pos_min = %d, pos_max = %d, size = %.3f MiB)\n",
+                                    (int) slot.ctx_checkpoints.size(), params_base.n_ctx_checkpoints, cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024);
+                        }
                     }
                 }

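Checkpoint creation now happens right after the prompt batch is prepared, guarded by the position checks above: the sequence must span at least 64 positions and the previous checkpoint must lie at least 64 positions behind the new one. The sketch below pulls the bookkeeping into a standalone helper; the seq_checkpoint struct and the function name are illustrative (the server keeps this state in the slot's ctx_checkpoints), while the llama_* calls and the LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY flag are the ones used in the diff:

    // Sketch only: FIFO checkpoint bookkeeping for a single sequence.
    #include <cstdint>
    #include <deque>
    #include <vector>
    #include "llama.h"

    struct seq_checkpoint {
        llama_pos pos_min;
        llama_pos pos_max;
        std::vector<uint8_t> data; // serialized non-rollbackable state
    };

    static void save_checkpoint(llama_context * ctx, llama_seq_id seq_id,
                                std::deque<seq_checkpoint> & checkpoints, size_t n_max) {
        // evict the oldest checkpoints to make room for the new one
        while (n_max > 0 && checkpoints.size() >= n_max) {
            checkpoints.pop_front();
        }

        const llama_pos pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx), seq_id);
        const llama_pos pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx), seq_id);

        // serialize only the part of the sequence state that cannot be rolled back
        const size_t size = llama_state_seq_get_size_ext(ctx, seq_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);

        seq_checkpoint cp = { pos_min, pos_max, std::vector<uint8_t>(size) };
        llama_state_seq_get_data_ext(ctx, cp.data.data(), size, seq_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);

        checkpoints.push_back(std::move(cp));
    }
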
@@ -3853,40 +3905,6 @@ struct server_context {

                     // prompt evaluated for next-token prediction
                     slot.state = SLOT_STATE_GENERATING;
-
-                    // make a checkpoint of the parts of the memory that cannot be rolled back.
-                    // checkpoints are created only if:
-                    // - the model uses SWA and we are not using `swa_full`
-                    // - the model architecture is marked as recurrent or hybrid
-                    //
-                    // TODO: try to make this conditional on the context or the memory module, instead of the model type
-                    const bool do_checkpoint =
-                        (llama_model_is_recurrent(model) || llama_model_is_hybrid(model)) ||
-                        (llama_model_n_swa(model) > 0 && !params_base.swa_full);
-
-                    if (do_checkpoint && params_base.n_ctx_checkpoints > 0) {
-                        while (slot.ctx_checkpoints.size() >= (size_t) params_base.n_ctx_checkpoints) {
-                            // make room for the new checkpoint, if needed
-                            const auto & cur = slot.ctx_checkpoints.front();
-                            SLT_WRN(slot, "erasing old context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n",
-                                    cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024);
-
-                            slot.ctx_checkpoints.erase(slot.ctx_checkpoints.begin());
-                        }
-
-                        const size_t checkpoint_size = llama_state_seq_get_size_ext(ctx, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
-
-                        auto & cur = slot.ctx_checkpoints.emplace_back(ctx_checkpoint{
-                            /*.pos_min = */ llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id),
-                            /*.pos_max = */ llama_memory_seq_pos_max(llama_get_memory(ctx), slot.id),
-                            /*.data    = */ std::vector<uint8_t>(checkpoint_size),
-                        });
-
-                        llama_state_seq_get_data_ext(ctx, cur.data.data(), checkpoint_size, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
-
-                        SLT_WRN(slot, "saved context checkpoint %d of %d (pos_min = %d, pos_max = %d, size = %.3f MiB)\n",
-                                (int) slot.ctx_checkpoints.size(), params_base.n_ctx_checkpoints, cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024);
-                    }
                 } else if (slot.state != SLOT_STATE_GENERATING) {
                     continue; // continue loop of slots
                 }