mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-10-31 08:51:55 +00:00 
			
		
		
		
	llama : refactor get / set state + remove redundant kv cache API (#1143)
This commit is contained in:
		
							
								
								
									
										323
									
								
								llama.cpp
									
									
									
									
									
								
							
							
						
						
									
										323
									
								
								llama.cpp
									
									
									
									
									
								
							| @@ -2072,35 +2072,191 @@ int llama_apply_lora_from_file(struct llama_context * ctx, const char * path_lor | |||||||
|     } |     } | ||||||
| } | } | ||||||
|  |  | ||||||
| // Returns the KV cache that will contain the context for the |  | ||||||
| // ongoing prediction with the model. |  | ||||||
| const uint8_t * llama_get_kv_cache(struct llama_context * ctx) { |  | ||||||
|     return ctx->model.kv_self.buf.addr; |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // Returns the size of the KV cache |  | ||||||
| size_t llama_get_kv_cache_size(struct llama_context * ctx) { |  | ||||||
|     return ctx->model.kv_self.buf.size; |  | ||||||
| } |  | ||||||
|  |  | ||||||
| int llama_get_kv_cache_token_count(struct llama_context * ctx) { | int llama_get_kv_cache_token_count(struct llama_context * ctx) { | ||||||
|     return ctx->model.kv_self.n; |     return ctx->model.kv_self.n; | ||||||
| } | } | ||||||
|  |  | ||||||
| // Sets the KV cache containing the current context for the model | #define LLAMA_MAX_RNG_STATE 64*1024 | ||||||
| void llama_set_kv_cache( |  | ||||||
|         struct llama_context * ctx, | // Returns the size of the state | ||||||
|                const uint8_t * kv_cache, | size_t llama_get_state_size(struct llama_context * ctx) { | ||||||
|                       size_t   n_size, |     // we don't know size of rng until we actually serialize it. so reserve more than enough memory for its serialized state. | ||||||
|                          int   n_token_count) { |     // for reference, std::mt19937(1337) serializes to 6701 bytes. | ||||||
|     // Make sure we have the same kv cache setup |     const size_t s_rng_size        = sizeof(size_t); | ||||||
|     LLAMA_ASSERT(ctx->model.kv_self.buf.size == n_size); |     const size_t s_rng             = LLAMA_MAX_RNG_STATE; | ||||||
|     void * k_data = ctx->model.kv_self.k->data; // remember data pointers |     const size_t s_logits_capacity = sizeof(size_t); | ||||||
|     void * v_data = ctx->model.kv_self.v->data; // because their value is stored in buf and overwritten by memcpy |     const size_t s_logits_size     = sizeof(size_t); | ||||||
|     memcpy(ctx->model.kv_self.buf.addr, kv_cache, n_size); |     const size_t s_logits          = ctx->logits.capacity() * sizeof(float); | ||||||
|     ctx->model.kv_self.k->data = k_data; // restore correct data pointers |     const size_t s_embedding_size  = sizeof(size_t); | ||||||
|     ctx->model.kv_self.v->data = v_data; |     const size_t s_embedding       = ctx->embedding.size() * sizeof(float); | ||||||
|     ctx->model.kv_self.n = n_token_count; |     const size_t s_kv_size         = sizeof(size_t); | ||||||
|  |     const size_t s_kv_ntok         = sizeof(int); | ||||||
|  |     const size_t s_kv              = ctx->model.kv_self.buf.size; | ||||||
|  |  | ||||||
|  |     const size_t s_total = ( | ||||||
|  |         + s_rng_size | ||||||
|  |         + s_rng | ||||||
|  |         + s_logits_capacity | ||||||
|  |         + s_logits_size | ||||||
|  |         + s_logits | ||||||
|  |         + s_embedding_size | ||||||
|  |         + s_embedding | ||||||
|  |         + s_kv_size | ||||||
|  |         + s_kv_ntok | ||||||
|  |         + s_kv | ||||||
|  |     ); | ||||||
|  |  | ||||||
|  |     return s_total; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Copies the state to the specified destination address | ||||||
|  | size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dest) { | ||||||
|  |     uint8_t * out = dest; | ||||||
|  |  | ||||||
|  |     // copy rng | ||||||
|  |     { | ||||||
|  |         std::stringstream rng_ss; | ||||||
|  |         rng_ss << ctx->rng; | ||||||
|  |  | ||||||
|  |         const size_t rng_size = rng_ss.str().size(); | ||||||
|  |         char rng_buf[LLAMA_MAX_RNG_STATE]; | ||||||
|  |  | ||||||
|  |         memset(&rng_buf[0], 0, LLAMA_MAX_RNG_STATE); | ||||||
|  |         memcpy(&rng_buf[0], rng_ss.str().data(), rng_ss.str().size()); | ||||||
|  |  | ||||||
|  |         memcpy(out, &rng_size,   sizeof(rng_size));    out += sizeof(rng_size); | ||||||
|  |         memcpy(out, &rng_buf[0], LLAMA_MAX_RNG_STATE); out += LLAMA_MAX_RNG_STATE; | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     // copy logits | ||||||
|  |     { | ||||||
|  |         const size_t logits_cap  = ctx->logits.capacity(); | ||||||
|  |         const size_t logits_size = ctx->logits.size(); | ||||||
|  |  | ||||||
|  |         memcpy(out, &logits_cap,  sizeof(logits_cap));  out += sizeof(logits_cap); | ||||||
|  |         memcpy(out, &logits_size, sizeof(logits_size)); out += sizeof(logits_size); | ||||||
|  |  | ||||||
|  |         if (logits_size) { | ||||||
|  |             memcpy(out, ctx->logits.data(), logits_size * sizeof(float)); | ||||||
|  |         } | ||||||
|  |  | ||||||
|  |         out += logits_cap * sizeof(float); | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     // copy embeddings | ||||||
|  |     { | ||||||
|  |         const size_t embedding_size = ctx->embedding.size(); | ||||||
|  |  | ||||||
|  |         memcpy(out, &embedding_size, sizeof(embedding_size)); out += sizeof(embedding_size); | ||||||
|  |  | ||||||
|  |         if (embedding_size) { | ||||||
|  |             memcpy(out, ctx->embedding.data(), embedding_size * sizeof(float)); | ||||||
|  |             out += embedding_size * sizeof(float); | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     // copy kv cache | ||||||
|  |     { | ||||||
|  |         const size_t kv_size = ctx->model.kv_self.buf.size; | ||||||
|  |         const int    kv_ntok = llama_get_kv_cache_token_count(ctx); | ||||||
|  |  | ||||||
|  |         memcpy(out, &kv_size, sizeof(kv_size)); out += sizeof(kv_size); | ||||||
|  |         memcpy(out, &kv_ntok, sizeof(kv_ntok)); out += sizeof(kv_ntok); | ||||||
|  |  | ||||||
|  |         if (kv_size) { | ||||||
|  |             memcpy(out, ctx->model.kv_self.buf.addr, kv_size); out += kv_size; | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     const size_t written  = out - dest; | ||||||
|  |     const size_t expected = llama_get_state_size(ctx); | ||||||
|  |  | ||||||
|  |     LLAMA_ASSERT(written == expected); | ||||||
|  |  | ||||||
|  |     return written; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Sets the state reading from the specified source address | ||||||
|  | size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) { | ||||||
|  |     const uint8_t * in = src; | ||||||
|  |  | ||||||
|  |     // set rng | ||||||
|  |     { | ||||||
|  |         size_t rng_size; | ||||||
|  |         char   rng_buf[LLAMA_MAX_RNG_STATE]; | ||||||
|  |  | ||||||
|  |         memcpy(&rng_size,   in, sizeof(rng_size));    in += sizeof(rng_size); | ||||||
|  |         memcpy(&rng_buf[0], in, LLAMA_MAX_RNG_STATE); in += LLAMA_MAX_RNG_STATE; | ||||||
|  |  | ||||||
|  |         std::stringstream rng_ss; | ||||||
|  |         rng_ss.str(std::string(&rng_buf[0], rng_size)); | ||||||
|  |         rng_ss >> ctx->rng; | ||||||
|  |  | ||||||
|  |         LLAMA_ASSERT(rng_ss.fail() == false); | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     // set logits | ||||||
|  |     { | ||||||
|  |         size_t logits_cap; | ||||||
|  |         size_t logits_size; | ||||||
|  |  | ||||||
|  |         memcpy(&logits_cap,  in, sizeof(logits_cap));  in += sizeof(logits_cap); | ||||||
|  |         memcpy(&logits_size, in, sizeof(logits_size)); in += sizeof(logits_size); | ||||||
|  |  | ||||||
|  |         LLAMA_ASSERT(ctx->logits.capacity() == logits_cap); | ||||||
|  |  | ||||||
|  |         if (logits_size) { | ||||||
|  |             ctx->logits.resize(logits_size); | ||||||
|  |             memcpy(ctx->logits.data(), in, logits_size * sizeof(float)); | ||||||
|  |         } | ||||||
|  |  | ||||||
|  |         in += logits_cap * sizeof(float); | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     // set embeddings | ||||||
|  |     { | ||||||
|  |         size_t embedding_size; | ||||||
|  |  | ||||||
|  |         memcpy(&embedding_size, in, sizeof(embedding_size)); in += sizeof(embedding_size); | ||||||
|  |  | ||||||
|  |         LLAMA_ASSERT(ctx->embedding.capacity() == embedding_size); | ||||||
|  |  | ||||||
|  |         if (embedding_size) { | ||||||
|  |             memcpy(ctx->embedding.data(), in, embedding_size * sizeof(float)); | ||||||
|  |             in += embedding_size * sizeof(float); | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     // set kv cache | ||||||
|  |     { | ||||||
|  |         size_t kv_size; | ||||||
|  |         int kv_ntok; | ||||||
|  |  | ||||||
|  |         memcpy(&kv_size, in, sizeof(kv_size)); in += sizeof(kv_size); | ||||||
|  |         memcpy(&kv_ntok, in, sizeof(kv_ntok)); in += sizeof(kv_ntok); | ||||||
|  |  | ||||||
|  |         if (kv_size) { | ||||||
|  |             LLAMA_ASSERT(ctx->model.kv_self.buf.size == kv_size); | ||||||
|  |  | ||||||
|  |             void * k_data = ctx->model.kv_self.k->data; // remember data pointers | ||||||
|  |             void * v_data = ctx->model.kv_self.v->data; // because their value is stored in buf and overwritten by memcpy | ||||||
|  |  | ||||||
|  |             memcpy(ctx->model.kv_self.buf.addr, in, kv_size); in += kv_size; | ||||||
|  |  | ||||||
|  |             ctx->model.kv_self.k->data = k_data; // restore correct data pointers | ||||||
|  |             ctx->model.kv_self.v->data = v_data; | ||||||
|  |  | ||||||
|  |         } | ||||||
|  |  | ||||||
|  |         ctx->model.kv_self.n = kv_ntok; | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     const size_t nread    = in - src; | ||||||
|  |     const size_t expected = llama_get_state_size(ctx); | ||||||
|  |  | ||||||
|  |     LLAMA_ASSERT(nread == expected); | ||||||
|  |  | ||||||
|  |     return nread; | ||||||
| } | } | ||||||
|  |  | ||||||
| int llama_eval( | int llama_eval( | ||||||
| @@ -2256,120 +2412,3 @@ std::vector<std::pair<std::string, struct ggml_tensor *>>& llama_internal_get_te | |||||||
|     return ctx->model.tensors_by_name; |     return ctx->model.tensors_by_name; | ||||||
| } | } | ||||||
|  |  | ||||||
| // Returns the size of the state |  | ||||||
| size_t llama_get_state_size(struct llama_context * ctx) { |  | ||||||
|     // we don't know size of rng until we actually serialize it. so reserve more than enough memory for its serialized state. |  | ||||||
|     // for reference, std::mt19937(1337) serializes to 6701 bytes. |  | ||||||
|     const size_t s_rng_size = sizeof(size_t); |  | ||||||
|     const size_t s_rng = 64*1024; |  | ||||||
|     const size_t s_logits_capacity = sizeof(size_t); |  | ||||||
|     const size_t s_logits_size = sizeof(size_t); |  | ||||||
|     const size_t s_logits = ctx->logits.capacity() * sizeof(float); |  | ||||||
|     const size_t s_embedding_size = sizeof(size_t); |  | ||||||
|     const size_t s_embedding = ctx->embedding.size() * sizeof(float); |  | ||||||
|     const size_t s_kv_size = sizeof(size_t); |  | ||||||
|     const size_t s_kv_ntok = sizeof(int); |  | ||||||
|     const size_t s_kv = llama_get_kv_cache_size(ctx); |  | ||||||
|     const size_t s_total = ( |  | ||||||
|         + s_rng_size |  | ||||||
|         + s_rng |  | ||||||
|         + s_logits_capacity |  | ||||||
|         + s_logits_size |  | ||||||
|         + s_logits |  | ||||||
|         + s_embedding_size |  | ||||||
|         + s_embedding |  | ||||||
|         + s_kv_size |  | ||||||
|         + s_kv_ntok |  | ||||||
|         + s_kv |  | ||||||
|     ); |  | ||||||
|     return s_total; |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // Copies the state to the specified destination address |  | ||||||
| size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dest) { |  | ||||||
|     std::stringstream rng_ss; |  | ||||||
|     rng_ss << ctx->rng; |  | ||||||
|     const size_t rng_size = rng_ss.str().size(); |  | ||||||
|     char rng_buf[64*1024]; |  | ||||||
|     memset(&rng_buf[0], 0, 64*1024); |  | ||||||
|     memcpy(&rng_buf[0], rng_ss.str().data(), rng_ss.str().size()); |  | ||||||
|     const size_t logits_capacity = ctx->logits.capacity(); |  | ||||||
|     const size_t logits_size = ctx->logits.size(); |  | ||||||
|     const size_t embedding_size = ctx->embedding.size(); |  | ||||||
|     const size_t kv_size = llama_get_kv_cache_size(ctx); |  | ||||||
|     const int kv_ntok = llama_get_kv_cache_token_count(ctx); |  | ||||||
|  |  | ||||||
|     uint8_t * out = dest; |  | ||||||
|     memcpy(out, &rng_size, sizeof(size_t)); out += sizeof(size_t); |  | ||||||
|     memcpy(out, &rng_buf[0], 64*1024); out += 64*1024; |  | ||||||
|     memcpy(out, &logits_capacity, sizeof(size_t)); out += sizeof(size_t); |  | ||||||
|     memcpy(out, &logits_size, sizeof(size_t)); out += sizeof(size_t); |  | ||||||
|     if (logits_size) { |  | ||||||
|         memcpy(out, ctx->logits.data(), logits_size * sizeof(float)); |  | ||||||
|     } |  | ||||||
|     out += logits_capacity * sizeof(float); |  | ||||||
|     memcpy(out, &embedding_size, sizeof(size_t)); out += sizeof(size_t); |  | ||||||
|     if (embedding_size) { |  | ||||||
|         memcpy(out, ctx->embedding.data(), embedding_size * sizeof(float)); out += embedding_size * sizeof(float); |  | ||||||
|     } |  | ||||||
|     memcpy(out, &kv_size, sizeof(size_t)); out += sizeof(size_t); |  | ||||||
|     memcpy(out, &kv_ntok, sizeof(int)); out += sizeof(int); |  | ||||||
|     if (kv_size) { |  | ||||||
|         memcpy(out, llama_get_kv_cache(ctx), kv_size); out += kv_size; |  | ||||||
|     } |  | ||||||
|     const size_t written = out - dest; |  | ||||||
|     const size_t expected = llama_get_state_size(ctx); |  | ||||||
|     LLAMA_ASSERT(written == expected); |  | ||||||
|     return written; |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // Sets the state reading from the specified source address |  | ||||||
| size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) { |  | ||||||
|     size_t rng_size; |  | ||||||
|     char rng_buf[64*1024]; |  | ||||||
|     std::stringstream rng_ss; |  | ||||||
|  |  | ||||||
|     const uint8_t * in = src; |  | ||||||
|     memcpy(&rng_size, in, sizeof(size_t)); in += sizeof(size_t); |  | ||||||
|     memcpy(&rng_buf[0], in, 64*1024); in += 64*1024; |  | ||||||
|     rng_ss.str(std::string(&rng_buf[0], rng_size)); |  | ||||||
|     rng_ss >> ctx->rng; |  | ||||||
|     LLAMA_ASSERT(rng_ss.fail() == false); |  | ||||||
|  |  | ||||||
|     size_t logits_capacity; |  | ||||||
|     size_t logits_size; |  | ||||||
|     size_t embedding_size; |  | ||||||
|     size_t kv_size; |  | ||||||
|     int kv_ntok; |  | ||||||
|  |  | ||||||
|     memcpy(&logits_capacity, in, sizeof(size_t)); in += sizeof(size_t); |  | ||||||
|     memcpy(&logits_size, in, sizeof(size_t)); in += sizeof(size_t); |  | ||||||
|     LLAMA_ASSERT(ctx->logits.capacity() == logits_capacity); |  | ||||||
|     if (logits_size) { |  | ||||||
|         ctx->logits.resize(logits_size); |  | ||||||
|         memcpy(ctx->logits.data(), in, logits_size * sizeof(float)); |  | ||||||
|     } |  | ||||||
|     in += logits_capacity * sizeof(float); |  | ||||||
|     memcpy(&embedding_size, in, sizeof(size_t)); in += sizeof(size_t); |  | ||||||
|     LLAMA_ASSERT(ctx->embedding.capacity() == embedding_size); |  | ||||||
|     if (embedding_size) { |  | ||||||
|         memcpy(ctx->embedding.data(), in, embedding_size * sizeof(float)); |  | ||||||
|         in += embedding_size * sizeof(float); |  | ||||||
|     } |  | ||||||
|     memcpy(&kv_size, in, sizeof(size_t)); in += sizeof(size_t); |  | ||||||
|     memcpy(&kv_ntok, in, sizeof(int)); in += sizeof(int); |  | ||||||
|     if (kv_size) { |  | ||||||
|         LLAMA_ASSERT(ctx->model.kv_self.buf.size == kv_size); |  | ||||||
|         void * k_data = ctx->model.kv_self.k->data; // remember data pointers |  | ||||||
|         void * v_data = ctx->model.kv_self.v->data; // because their value is stored in buf and overwritten by memcpy |  | ||||||
|         memcpy(ctx->model.kv_self.buf.addr, in, kv_size); |  | ||||||
|         ctx->model.kv_self.k->data = k_data; // restore correct data pointers |  | ||||||
|         ctx->model.kv_self.v->data = v_data; |  | ||||||
|         in += kv_size; |  | ||||||
|     } |  | ||||||
|     ctx->model.kv_self.n = kv_ntok; |  | ||||||
|     const size_t nread = in - src; |  | ||||||
|     const size_t expected = llama_get_state_size(ctx); |  | ||||||
|     LLAMA_ASSERT(nread == expected); |  | ||||||
|     return nread; |  | ||||||
| } |  | ||||||
|   | |||||||
							
								
								
									
										14
									
								
								llama.h
									
									
									
									
									
								
							
							
						
						
									
										14
									
								
								llama.h
									
									
									
									
									
								
							| @@ -112,23 +112,9 @@ extern "C" { | |||||||
|                       const char * path_base_model, |                       const char * path_base_model, | ||||||
|                              int   n_threads); |                              int   n_threads); | ||||||
|  |  | ||||||
|     // Returns the KV cache that will contain the context for the |  | ||||||
|     // ongoing prediction with the model. |  | ||||||
|     LLAMA_API const uint8_t * llama_get_kv_cache(struct llama_context * ctx); |  | ||||||
|  |  | ||||||
|     // Returns the size of the KV cache |  | ||||||
|     LLAMA_API size_t llama_get_kv_cache_size(struct llama_context * ctx); |  | ||||||
|  |  | ||||||
|     // Returns the number of tokens in the KV cache |     // Returns the number of tokens in the KV cache | ||||||
|     LLAMA_API int llama_get_kv_cache_token_count(struct llama_context * ctx); |     LLAMA_API int llama_get_kv_cache_token_count(struct llama_context * ctx); | ||||||
|  |  | ||||||
|     // Sets the KV cache containing the current context for the model |  | ||||||
|     LLAMA_API void llama_set_kv_cache( |  | ||||||
|             struct llama_context * ctx, |  | ||||||
|                    const uint8_t * kv_cache, |  | ||||||
|                           size_t   n_size, |  | ||||||
|                              int   n_token_count); |  | ||||||
|  |  | ||||||
|     // Returns the size in bytes of the state (rng, logits, embedding and kv_cache) |     // Returns the size in bytes of the state (rng, logits, embedding and kv_cache) | ||||||
|     LLAMA_API size_t llama_get_state_size(struct llama_context * ctx); |     LLAMA_API size_t llama_get_state_size(struct llama_context * ctx); | ||||||
|  |  | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user
	 Georgi Gerganov
					Georgi Gerganov