mirror of https://github.com/ggml-org/llama.cpp.git (synced 2025-10-30 08:42:00 +00:00)
	Stream save llama context data to file instead of allocating entire buffer upfront (#2488)
* added stream saving of context data to file to avoid allocating unnecessary amounts of memory
* generalised copying state data to a file or buffer
* added comments explaining how copy_state_data works
* fixed trailing whitespace
* fixed save load state example
* updated save load state example to use the public function in llama.cpp
* restored the llama_copy_state_data API (reverting the earlier breaking change) and moved the new logic for copying llama state data to an internal function
* fixed function declaration order
* restored save load state example
* fixed whitespace
* removed unused llama-util.h include
* applied suggestions from code review

Co-authored-by: slaren <slarengh@gmail.com>
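The core of the change is a small writer abstraction: the state serialization code no longer advances a raw destination pointer, it calls write() on a context object that either copies into a caller-provided buffer or writes straight to a file. Below is a minimal, self-contained sketch of that pattern under illustrative names (state_writer, buffer_writer, file_writer and save_state are hypothetical, not part of llama.cpp); the actual types added by this commit follow in the llama-util.h diff.

// Sketch of the stream-vs-buffer writer pattern; all names are illustrative only.
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <vector>

struct state_writer {
    virtual void write(const void * src, size_t size) = 0;
    virtual ~state_writer() = default;
};

// destination 1: a caller-provided buffer (the behaviour of the existing public API)
struct buffer_writer : state_writer {
    uint8_t * ptr;
    explicit buffer_writer(uint8_t * p) : ptr(p) {}
    void write(const void * src, size_t size) override {
        memcpy(ptr, src, size);   // caller must guarantee the buffer is large enough
        ptr += size;
    }
};

// destination 2: a file, so no full-size intermediate buffer is ever allocated
struct file_writer : state_writer {
    FILE * f;
    explicit file_writer(FILE * fp) : f(fp) {}
    void write(const void * src, size_t size) override {
        fwrite(src, 1, size, f);  // bytes go to disk immediately
    }
};

// one serialization routine, unaware of where the bytes end up
static void save_state(const std::vector<float> & logits, state_writer & out) {
    const size_t n = logits.size();
    out.write(&n, sizeof(n));
    out.write(logits.data(), n * sizeof(float));
}

int main() {
    const std::vector<float> logits = {0.1f, 0.2f, 0.7f};

    // stream to a file
    if (FILE * f = fopen("state.bin", "wb")) {
        file_writer fw(f);
        save_state(logits, fw);
        fclose(f);
    }

    // or fill a preallocated buffer
    std::vector<uint8_t> buf(sizeof(size_t) + logits.size() * sizeof(float));
    buffer_writer bw(buf.data());
    save_state(logits, bw);
    return 0;
}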
llama-util.h (40 lines changed)
@@ -149,6 +149,46 @@ struct llama_file {
     }
 };
 
+// llama_context_data
+struct llama_data_context {
+    virtual void write(const void * src, size_t size) = 0;
+    virtual size_t get_size_written() = 0;
+    virtual ~llama_data_context() = default;
+};
+
+struct llama_data_buffer_context : llama_data_context {
+    uint8_t* ptr;
+    size_t size_written = 0;
+
+    llama_data_buffer_context(uint8_t * p) : ptr(p) {}
+
+    void write(const void * src, size_t size) override {
+        memcpy(ptr, src, size);
+        ptr += size;
+        size_written += size;
+    }
+
+    size_t get_size_written() override {
+        return size_written;
+    }
+};
+
+struct llama_data_file_context : llama_data_context {
+    llama_file* file;
+    size_t size_written = 0;
+
+    llama_data_file_context(llama_file * f) : file(f) {}
+
+    void write(const void * src, size_t size) override {
+        file->write_raw(src, size);
+        size_written += size;
+    }
+
+    size_t get_size_written() override {
+        return size_written;
+    }
+};
+
 #if defined(_WIN32)
 static std::string llama_format_win_err(DWORD err) {
     LPSTR buf;
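Both contexts expose only write() and get_size_written(), so the serialization routine added in llama.cpp below never needs to know its destination. A rough usage fragment, not taken from the commit: it assumes llama-util.h is included, ctx stands for an initialized llama_context, and llama_copy_state_data_internal from the llama.cpp diff below is visible.

// hypothetical fragment, for illustration only

// stream the state straight to disk
llama_file file("session_state.bin", "wb");
llama_data_file_context file_ctx(&file);
llama_copy_state_data_internal(ctx, &file_ctx);

// or copy it into a caller-provided buffer, as the public API does
std::vector<uint8_t> buf(llama_get_state_size(ctx));
llama_data_buffer_context buf_ctx(buf.data());
llama_copy_state_data_internal(ctx, &buf_ctx);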
							
								
								
									
llama.cpp (79 lines changed)
@@ -3743,10 +3743,20 @@ size_t llama_get_state_size(const struct llama_context * ctx) {
     return s_total;
 }
 
-// Copies the state to the specified destination address
-size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dst) {
-    uint8_t * out = dst;
-
+/** copy state data into either a buffer or file depending on the passed in context
+ *
+ * file context:
+ * llama_file file("/path", "wb");
+ * llama_data_file_context data_ctx(&file);
+ * llama_copy_state_data(ctx, &data_ctx);
+ *
+ * buffer context:
+ * std::vector<uint8_t> buf(max_size, 0);
+ * llama_data_buffer_context data_ctx(&buf.data());
+ * llama_copy_state_data(ctx, &data_ctx);
+ *
+*/
+void llama_copy_state_data_internal(struct llama_context * ctx, llama_data_context * data_ctx) {
     // copy rng
     {
         std::stringstream rng_ss;
@@ -3758,8 +3768,8 @@ size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dst) {
         memset(&rng_buf[0], 0, LLAMA_MAX_RNG_STATE);
         memcpy(&rng_buf[0], rng_ss.str().data(), rng_ss.str().size());
 
-        memcpy(out, &rng_size,   sizeof(rng_size));    out += sizeof(rng_size);
-        memcpy(out, &rng_buf[0], LLAMA_MAX_RNG_STATE); out += LLAMA_MAX_RNG_STATE;
+        data_ctx->write(&rng_size,   sizeof(rng_size));
+        data_ctx->write(&rng_buf[0], LLAMA_MAX_RNG_STATE);
     }
 
     // copy logits
@@ -3767,25 +3777,29 @@ size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dst) {
         const size_t logits_cap  = ctx->logits.capacity();
         const size_t logits_size = ctx->logits.size();
 
-        memcpy(out, &logits_cap,  sizeof(logits_cap));  out += sizeof(logits_cap);
-        memcpy(out, &logits_size, sizeof(logits_size)); out += sizeof(logits_size);
+        data_ctx->write(&logits_cap,  sizeof(logits_cap));
+        data_ctx->write(&logits_size, sizeof(logits_size));
 
         if (logits_size) {
-            memcpy(out, ctx->logits.data(), logits_size * sizeof(float));
+            data_ctx->write(ctx->logits.data(), logits_size * sizeof(float));
         }
 
-        out += logits_cap * sizeof(float);
+        // If there is a gap between the size and the capacity, write padding
+        size_t padding_size = (logits_cap - logits_size) * sizeof(float);
+        if (padding_size > 0) {
+            std::vector<uint8_t> padding(padding_size, 0); // Create a buffer filled with zeros
+            data_ctx->write(padding.data(), padding_size);
+        }
     }
 
     // copy embeddings
     {
         const size_t embedding_size = ctx->embedding.size();
 
-        memcpy(out, &embedding_size, sizeof(embedding_size)); out += sizeof(embedding_size);
+        data_ctx->write(&embedding_size, sizeof(embedding_size));
 
         if (embedding_size) {
-            memcpy(out, ctx->embedding.data(), embedding_size * sizeof(float));
-            out += embedding_size * sizeof(float);
+            data_ctx->write(ctx->embedding.data(), embedding_size * sizeof(float));
         }
     }
 
@@ -3800,8 +3814,8 @@ size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dst) {
         const size_t kv_size = kv_self.buf.size;
         const int    kv_ntok = llama_get_kv_cache_token_count(ctx);
 
-        memcpy(out, &kv_size, sizeof(kv_size)); out += sizeof(kv_size);
-        memcpy(out, &kv_ntok, sizeof(kv_ntok)); out += sizeof(kv_ntok);
+        data_ctx->write(&kv_size, sizeof(kv_size));
+        data_ctx->write(&kv_ntok, sizeof(kv_ntok));
 
         if (kv_size) {
             const size_t elt_size = ggml_element_size(kv_self.k);
@@ -3810,12 +3824,12 @@ size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dst) {
             ggml_cgraph gf{};
 
             ggml_tensor * kout3d = ggml_new_tensor_3d(cpy_ctx, kv_self.k->type, n_embd, kv_ntok, n_layer);
-            kout3d->data = out;
-            out += ggml_nbytes(kout3d);
+            std::vector<uint8_t> kout3d_data(ggml_nbytes(kout3d), 0);
+            kout3d->data = kout3d_data.data();
 
             ggml_tensor * vout3d = ggml_new_tensor_3d(cpy_ctx, kv_self.v->type, kv_ntok, n_embd, n_layer);
-            vout3d->data = out;
-            out += ggml_nbytes(vout3d);
+            std::vector<uint8_t> vout3d_data(ggml_nbytes(vout3d), 0);
+            vout3d->data = vout3d_data.data();
 
             ggml_tensor * k3d = ggml_view_3d(cpy_ctx, kv_self.k,
                 n_embd, kv_ntok, n_layer,
@@ -3830,15 +3844,20 @@ size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dst) {
             ggml_graph_compute_helper(ctx->work_buffer, &gf, /*n_threads*/ 1);
 
             ggml_free(cpy_ctx);
+
+            // our data is now in the kout3d_data and vout3d_data buffers
+            // write them to file
+            data_ctx->write(kout3d_data.data(), kout3d_data.size());
+            data_ctx->write(vout3d_data.data(), vout3d_data.size());
         }
     }
+}
 
-    const size_t written  = out - dst;
-    const size_t max_size = llama_get_state_size(ctx);
+size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dst) {
+    llama_data_buffer_context data_ctx(dst);
+    llama_copy_state_data_internal(ctx, &data_ctx);
 
-    LLAMA_ASSERT(written <= max_size);
-
-    return written;
+    return data_ctx.get_size_written();
 }
 
 // Sets the state reading from the specified source address
@@ -4023,15 +4042,9 @@ bool llama_save_session_file(struct llama_context * ctx, const char * path_sessi
     file.write_u32((uint32_t) n_token_count);
     file.write_raw(tokens, sizeof(llama_token) * n_token_count);
 
-    // save the context state
-    {
-        const size_t n_state_size_max = llama_get_state_size(ctx);
-
-        std::vector<uint8_t> state_data(n_state_size_max);
-        const size_t n_state_size_cur = llama_copy_state_data(ctx, state_data.data());
-
-        file.write_raw(state_data.data(), n_state_size_cur);
-    }
+    // save the context state using stream saving
+    llama_data_file_context data_ctx(&file);
+    llama_copy_state_data_internal(ctx, &data_ctx);
 
     return true;
 }
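For callers outside llama.cpp the contract is unchanged: llama_copy_state_data keeps its signature and still expects a destination buffer of at least llama_get_state_size(ctx) bytes, because the buffer context it wraps copies without bounds checking. A caller-side sketch of that contract (ctx is a placeholder for an initialized llama_context):

// hypothetical caller-side usage of the unchanged public API
std::vector<uint8_t> state(llama_get_state_size(ctx));            // upper bound on the state size
const size_t n_written = llama_copy_state_data(ctx, state.data());
state.resize(n_written);                                           // the actual state can be smaller than the bound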
Author: l3utterfly