mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-10-30 08:42:00 +00:00 
			
		
		
		
	llama : add api for getting/setting the complete state: rng, logits, embedding and kv_cache (#1105)
* reserve correct size for logits * add functions to get and set the whole llama state: including rng, logits, embedding and kv_cache * remove unused variables * remove trailing whitespace * fix comment
This commit is contained in:
		
							
								
								
									
										122
									
								
								llama.cpp
									
									
									
									
									
								
							
							
						
						
									
										122
									
								
								llama.cpp
									
									
									
									
									
								
							| @@ -27,6 +27,7 @@ | ||||
| #include <thread> | ||||
| #include <atomic> | ||||
| #include <mutex> | ||||
| #include <sstream> | ||||
|  | ||||
| #define LLAMA_USE_SCRATCH | ||||
| #define LLAMA_MAX_SCRATCH_BUFFERS 16 | ||||
| @@ -1787,7 +1788,7 @@ struct llama_context * llama_init_from_file( | ||||
|         if (params.logits_all) { | ||||
|             ctx->logits.reserve(hparams.n_ctx*hparams.n_vocab); | ||||
|         } else { | ||||
|             ctx->logits.reserve(hparams.n_ctx); | ||||
|             ctx->logits.reserve(hparams.n_vocab); | ||||
|         } | ||||
|  | ||||
|         if (params.embedding){ | ||||
| @@ -2252,3 +2253,122 @@ const char * llama_print_system_info(void) { | ||||
| std::vector<std::pair<std::string, struct ggml_tensor *>>& llama_internal_get_tensor_map(struct llama_context * ctx) { | ||||
|     return ctx->model.tensors_by_name; | ||||
| } | ||||
|  | ||||
| // Returns the size of the state | ||||
| size_t llama_get_state_size(struct llama_context * ctx) { | ||||
|     const size_t s_bool = sizeof(int32_t); | ||||
|     // we don't know size of rng until we actually serialize it. so reserve more than enough memory for its serialized state. | ||||
|     // for reference, std::mt19937(1337) serializes to 6701 bytes. | ||||
|     const size_t s_rng_size = sizeof(size_t); | ||||
|     const size_t s_rng = 64*1024; | ||||
|     const size_t s_logits_capacity = sizeof(size_t); | ||||
|     const size_t s_logits_size = sizeof(size_t); | ||||
|     const size_t s_logits = ctx->logits.capacity() * sizeof(float); | ||||
|     const size_t s_embedding_size = sizeof(size_t); | ||||
|     const size_t s_embedding = ctx->embedding.size() * sizeof(float); | ||||
|     const size_t s_kv_size = sizeof(size_t); | ||||
|     const size_t s_kv_ntok = sizeof(int); | ||||
|     const size_t s_kv = llama_get_kv_cache_size(ctx); | ||||
|     const size_t s_total = ( | ||||
|         + s_rng_size | ||||
|         + s_rng | ||||
|         + s_logits_capacity | ||||
|         + s_logits_size | ||||
|         + s_logits | ||||
|         + s_embedding_size | ||||
|         + s_embedding | ||||
|         + s_kv_size | ||||
|         + s_kv_ntok | ||||
|         + s_kv | ||||
|     ); | ||||
|     return s_total; | ||||
| } | ||||
|  | ||||
| // Copies the state to the specified destination address | ||||
| size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dest) { | ||||
|     std::stringstream rng_ss; | ||||
|     rng_ss << ctx->rng; | ||||
|     const size_t rng_size = rng_ss.str().size(); | ||||
|     char rng_buf[64*1024]; | ||||
|     memset(&rng_buf[0], 0, 64*1024); | ||||
|     memcpy(&rng_buf[0], rng_ss.str().data(), rng_ss.str().size()); | ||||
|     const size_t logits_capacity = ctx->logits.capacity(); | ||||
|     const size_t logits_size = ctx->logits.size(); | ||||
|     const size_t embedding_size = ctx->embedding.size(); | ||||
|     const size_t kv_size = llama_get_kv_cache_size(ctx); | ||||
|     const int kv_ntok = llama_get_kv_cache_token_count(ctx); | ||||
|  | ||||
|     uint8_t * out = dest; | ||||
|     memcpy(out, &rng_size, sizeof(size_t)); out += sizeof(size_t); | ||||
|     memcpy(out, &rng_buf[0], 64*1024); out += 64*1024; | ||||
|     memcpy(out, &logits_capacity, sizeof(size_t)); out += sizeof(size_t); | ||||
|     memcpy(out, &logits_size, sizeof(size_t)); out += sizeof(size_t); | ||||
|     if (logits_size) { | ||||
|         memcpy(out, ctx->logits.data(), logits_size * sizeof(float)); | ||||
|     } | ||||
|     out += logits_capacity * sizeof(float); | ||||
|     memcpy(out, &embedding_size, sizeof(size_t)); out += sizeof(size_t); | ||||
|     if (embedding_size) { | ||||
|         memcpy(out, ctx->embedding.data(), embedding_size * sizeof(float)); out += embedding_size * sizeof(float); | ||||
|     } | ||||
|     memcpy(out, &kv_size, sizeof(size_t)); out += sizeof(size_t); | ||||
|     memcpy(out, &kv_ntok, sizeof(int)); out += sizeof(int); | ||||
|     if (kv_size) { | ||||
|         memcpy(out, llama_get_kv_cache(ctx), kv_size); out += kv_size; | ||||
|     } | ||||
|     const size_t written = out - dest; | ||||
|     const size_t expected = llama_get_state_size(ctx); | ||||
|     LLAMA_ASSERT(written == expected); | ||||
|     return written; | ||||
| } | ||||
|  | ||||
| // Sets the state reading from the specified source address | ||||
| size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) { | ||||
|     size_t rng_size; | ||||
|     char rng_buf[64*1024]; | ||||
|     std::stringstream rng_ss; | ||||
|  | ||||
|     const uint8_t * in = src; | ||||
|     memcpy(&rng_size, in, sizeof(size_t)); in += sizeof(size_t); | ||||
|     memcpy(&rng_buf[0], in, 64*1024); in += 64*1024; | ||||
|     rng_ss.str(std::string(&rng_buf[0], rng_size)); | ||||
|     rng_ss >> ctx->rng; | ||||
|     LLAMA_ASSERT(rng_ss.fail() == false); | ||||
|  | ||||
|     size_t logits_capacity; | ||||
|     size_t logits_size; | ||||
|     size_t embedding_size; | ||||
|     size_t kv_size; | ||||
|     int kv_ntok; | ||||
|  | ||||
|     memcpy(&logits_capacity, in, sizeof(size_t)); in += sizeof(size_t); | ||||
|     memcpy(&logits_size, in, sizeof(size_t)); in += sizeof(size_t); | ||||
|     LLAMA_ASSERT(ctx->logits.capacity() == logits_capacity); | ||||
|     if (logits_size) { | ||||
|         ctx->logits.resize(logits_size); | ||||
|         memcpy(ctx->logits.data(), in, logits_size * sizeof(float)); | ||||
|     } | ||||
|     in += logits_capacity * sizeof(float); | ||||
|     memcpy(&embedding_size, in, sizeof(size_t)); in += sizeof(size_t); | ||||
|     LLAMA_ASSERT(ctx->embedding.capacity() == embedding_size); | ||||
|     if (embedding_size) { | ||||
|         memcpy(ctx->embedding.data(), in, embedding_size * sizeof(float)); | ||||
|         in += embedding_size * sizeof(float); | ||||
|     } | ||||
|     memcpy(&kv_size, in, sizeof(size_t)); in += sizeof(size_t); | ||||
|     memcpy(&kv_ntok, in, sizeof(int)); in += sizeof(int); | ||||
|     if (kv_size) { | ||||
|         LLAMA_ASSERT(ctx->model.kv_self.buf.size == kv_size); | ||||
|         void * k_data = ctx->model.kv_self.k->data; // remember data pointers | ||||
|         void * v_data = ctx->model.kv_self.v->data; // because their value is stored in buf and overwritten by memcpy | ||||
|         memcpy(ctx->model.kv_self.buf.addr, in, kv_size); | ||||
|         ctx->model.kv_self.k->data = k_data; // restore correct data pointers | ||||
|         ctx->model.kv_self.v->data = v_data; | ||||
|         in += kv_size; | ||||
|     } | ||||
|     ctx->model.kv_self.n = kv_ntok; | ||||
|     const size_t nread = in - src; | ||||
|     const size_t expected = llama_get_state_size(ctx); | ||||
|     LLAMA_ASSERT(nread == expected); | ||||
|     return nread; | ||||
| } | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 xaedes
					xaedes