	llama : minimize size used for state save/load (#4820)
* examples : save-load-state: save only required state

* llama : only reserve n_vocab * n_batch at most for logits

  llama_decode asserts that only n_batch tokens are passed each call, and n_ctx is expected to be bigger than n_batch.

* llama : always reserve n_vocab * n_batch for logits

  llama_context de-serialization breaks if the contexts have differing capacity for logits, and llama_decode will at maximum resize to n_vocab * n_batch.

* llama : only save and restore used logits

  For batch sizes of 512 this reduces the save state in the best case by around 62 MB, which can be a lot if planning to save on each message to allow regenerating messages.

* llama : use ostringstream and istringstream for save and load

* llama : serialize rng into minimum amount of space required

* llama : break session version due to serialization changes
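For orientation, the buffer whose layout changes here is the one produced and consumed by the session-state API in llama.h (the same calls exercised by examples/save-load-state). Below is a minimal sketch of that round trip; it assumes an already-initialized llama_context named ctx and omits all error handling:

// Minimal sketch: round-trip context state through a byte buffer.
// Assumes an already-initialized llama_context * ctx; error handling omitted.
#include <cstdint>
#include <vector>
#include "llama.h"

static std::vector<uint8_t> save_state(struct llama_context * ctx) {
    // llama_get_state_size() is an upper bound; after this change it assumes
    // the full n_vocab * n_batch logits capacity as the worst case.
    std::vector<uint8_t> buf(llama_get_state_size(ctx));

    // llama_copy_state_data() now writes only the logits actually in use plus a
    // length-prefixed RNG string, so it may fill far fewer bytes than buf.size().
    const size_t n_written = llama_copy_state_data(ctx, buf.data());
    buf.resize(n_written);
    return buf;
}

static void load_state(struct llama_context * ctx, std::vector<uint8_t> & buf) {
    llama_set_state_data(ctx, buf.data());
}

With the old format the copied size was effectively constant (the reserved logits capacity plus a fixed LLAMA_MAX_RNG_STATE block); after this commit it shrinks to roughly the bytes that are actually populated, which is what makes per-message snapshots practical.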
1 changed file: llama.cpp (53 lines changed)
@@ -9379,12 +9379,8 @@ struct llama_context * llama_new_context_with_model(
                 ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f));
         }
 
-        // resized during inference
-        if (params.logits_all) {
-            ctx->logits.reserve(cparams.n_ctx*hparams.n_vocab);
-        } else {
-            ctx->logits.reserve(hparams.n_vocab);
-        }
+        // resized during inference, reserve maximum
+        ctx->logits.reserve(hparams.n_vocab*cparams.n_batch);
 
         if (params.embedding){
             ctx->embedding.resize(hparams.n_embd);
@@ -9731,8 +9727,8 @@ size_t llama_get_state_size(const struct llama_context * ctx) {
     // for reference, std::mt19937(1337) serializes to 6701 bytes.
     const size_t s_rng_size        = sizeof(size_t);
     const size_t s_rng             = LLAMA_MAX_RNG_STATE;
-    const size_t s_logits_capacity = sizeof(size_t);
     const size_t s_logits_size     = sizeof(size_t);
+    // assume worst case for logits although only currently set ones are serialized
     const size_t s_logits          = ctx->logits.capacity() * sizeof(float);
     const size_t s_embedding_size  = sizeof(size_t);
     const size_t s_embedding       = ctx->embedding.size() * sizeof(float);
@@ -9743,7 +9739,6 @@ size_t llama_get_state_size(const struct llama_context * ctx) {
     const size_t s_total = (
         + s_rng_size
         + s_rng
-        + s_logits_capacity
         + s_logits_size
         + s_logits
         + s_embedding_size
@@ -9812,37 +9807,27 @@ struct llama_data_file_context : llama_data_context {
 static void llama_copy_state_data_internal(struct llama_context * ctx, llama_data_context * data_ctx) {
     // copy rng
     {
-        std::stringstream rng_ss;
+        std::ostringstream rng_ss;
         rng_ss << ctx->rng;
 
-        const size_t rng_size = rng_ss.str().size();
-        char rng_buf[LLAMA_MAX_RNG_STATE];
+        const std::string & rng_str = rng_ss.str();
+        const size_t        rng_size = rng_str.size();
 
-        memset(&rng_buf[0], 0, LLAMA_MAX_RNG_STATE);
-        memcpy(&rng_buf[0], rng_ss.str().data(), rng_ss.str().size());
+        GGML_ASSERT(rng_size <= LLAMA_MAX_RNG_STATE);
 
-        data_ctx->write(&rng_size,   sizeof(rng_size));
-        data_ctx->write(&rng_buf[0], LLAMA_MAX_RNG_STATE);
+        data_ctx->write(&rng_size,      sizeof(rng_size));
+        data_ctx->write(rng_str.data(), rng_size);
     }
 
     // copy logits
     {
-        const size_t logits_cap  = ctx->logits.capacity();
         const size_t logits_size = ctx->logits.size();
 
-        data_ctx->write(&logits_cap,  sizeof(logits_cap));
         data_ctx->write(&logits_size, sizeof(logits_size));
 
         if (logits_size) {
             data_ctx->write(ctx->logits.data(), logits_size * sizeof(float));
         }
-
-        // If there is a gap between the size and the capacity, write padding
-        size_t padding_size = (logits_cap - logits_size) * sizeof(float);
-        if (padding_size > 0) {
-            std::vector<uint8_t> padding(padding_size, 0); // Create a buffer filled with zeros
-            data_ctx->write(padding.data(), padding_size);
-        }
     }
 
     // copy embeddings
@@ -9925,13 +9910,13 @@ size_t llama_set_state_data(struct llama_context * ctx, uint8_t * src) {
     // set rng
     {
         size_t rng_size;
-        char   rng_buf[LLAMA_MAX_RNG_STATE];
+        memcpy(&rng_size, inp, sizeof(rng_size)); inp += sizeof(rng_size);
 
-        memcpy(&rng_size,   inp, sizeof(rng_size));    inp += sizeof(rng_size);
-        memcpy(&rng_buf[0], inp, LLAMA_MAX_RNG_STATE); inp += LLAMA_MAX_RNG_STATE;
+        GGML_ASSERT(rng_size <= LLAMA_MAX_RNG_STATE);
 
-        std::stringstream rng_ss;
-        rng_ss.str(std::string(&rng_buf[0], rng_size));
+        std::string rng_str((char *)inp, rng_size); inp += rng_size;
+
+        std::istringstream rng_ss(rng_str);
         rng_ss >> ctx->rng;
 
         GGML_ASSERT(!rng_ss.fail());
@@ -9939,20 +9924,18 @@ size_t llama_set_state_data(struct llama_context * ctx, uint8_t * src) {
 
     // set logits
     {
-        size_t logits_cap;
         size_t logits_size;
 
-        memcpy(&logits_cap,  inp, sizeof(logits_cap));  inp += sizeof(logits_cap);
         memcpy(&logits_size, inp, sizeof(logits_size)); inp += sizeof(logits_size);
 
-        GGML_ASSERT(ctx->logits.capacity() == logits_cap);
+        GGML_ASSERT(ctx->logits.capacity() >= logits_size);
 
         if (logits_size) {
             ctx->logits.resize(logits_size);
-            memcpy(ctx->logits.data(), inp, logits_size * sizeof(float));
-        }
 
-        inp += logits_cap * sizeof(float);
+            memcpy(ctx->logits.data(), inp, logits_size * sizeof(float));
+            inp += logits_size * sizeof(float);
+        }
     }
 
     // set embeddings
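The "around 62 MB" best-case figure in the commit message is the logits padding the old format always wrote: assuming a typical 32000-entry vocabulary and n_batch = 512, the reserved capacity is 32000 * 512 * 4 bytes ≈ 62.5 MiB, and it used to be serialized in full even when only a single token's logits were populated. The RNG change applies the same length-prefix idea; the following standalone sketch mirrors the ostringstream/istringstream pattern from the diff (the helper names and the size constant are illustrative stand-ins, not llama.cpp's actual symbols):

// Sketch of length-prefixed RNG serialization: <size_t length><length bytes>,
// instead of a fixed LLAMA_MAX_RNG_STATE-sized blob. Names are illustrative.
#include <cassert>
#include <cstdint>
#include <cstring>
#include <random>
#include <sstream>
#include <string>
#include <vector>

static const size_t MAX_RNG_STATE_SKETCH = 64*1024; // stand-in bound, analogous to LLAMA_MAX_RNG_STATE

static void write_rng(std::vector<uint8_t> & out, const std::mt19937 & rng) {
    std::ostringstream ss;
    ss << rng;                              // mersenne_twister_engine streams as text
    const std::string & str = ss.str();     // ~6701 bytes for std::mt19937(1337)
    const size_t size = str.size();
    assert(size <= MAX_RNG_STATE_SKETCH);

    out.insert(out.end(), (const uint8_t *) &size, (const uint8_t *) &size + sizeof(size));
    out.insert(out.end(), str.begin(), str.end());
}

static size_t read_rng(const uint8_t * inp, std::mt19937 & rng) {
    size_t size = 0;
    std::memcpy(&size, inp, sizeof(size));

    std::istringstream ss(std::string((const char *) inp + sizeof(size), size));
    ss >> rng;
    assert(!ss.fail());

    return sizeof(size) + size;             // bytes consumed from the buffer
}

Because the record is now variable length, sessions written by older builds cannot be read back byte-for-byte, which is why the last bullet of the commit message bumps the session version.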
Author: David Friehs