mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-10-31 08:51:55 +00:00 
			
		
		
		
	llama : fix session saving/loading (#3400)
* llama : fix session saving/loading * llama : temp fix for clearing "future" tokens from the KV cache * llama : fix handling of "future" tokens when loading sessions * llama : fix comments for llama_kv_cache API
This commit is contained in:
		
							
								
								
									
										134
									
								
								llama.cpp
									
									
									
									
									
								
							
							
						
						
									
										134
									
								
								llama.cpp
									
									
									
									
									
								
							| @@ -1283,8 +1283,8 @@ static bool llama_kv_cache_init( | ||||
| // find an empty slot of size "n_tokens" in the cache | ||||
| // updates the cache head | ||||
| static bool llama_kv_cache_find_slot( | ||||
|              struct llama_kv_cache & cache, | ||||
|           const struct llama_batch & batch) { | ||||
|            struct llama_kv_cache & cache, | ||||
|         const struct llama_batch & batch) { | ||||
|     const uint32_t n_ctx    = cache.size; | ||||
|     const uint32_t n_tokens = batch.n_tokens; | ||||
|  | ||||
| @@ -1352,10 +1352,13 @@ static void llama_kv_cache_tokens_rm(struct llama_kv_cache & cache, int32_t c0, | ||||
| } | ||||
|  | ||||
| static void llama_kv_cache_seq_rm( | ||||
|              struct llama_kv_cache & cache, | ||||
|                       llama_seq_id   seq_id, | ||||
|                          llama_pos   p0, | ||||
|                          llama_pos   p1) { | ||||
|         struct llama_kv_cache & cache, | ||||
|                  llama_seq_id   seq_id, | ||||
|                     llama_pos   p0, | ||||
|                     llama_pos   p1) { | ||||
|     if (p0 < 0) p0 = 0; | ||||
|     if (p1 < 0) p1 = std::numeric_limits<llama_pos>::max(); | ||||
|  | ||||
|     for (uint32_t i = 0; i < cache.size; ++i) { | ||||
|         if (cache.cells[i].has_seq_id(seq_id) && cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) { | ||||
|             cache.cells[i].seq_id.erase(seq_id); | ||||
| @@ -1367,11 +1370,14 @@ static void llama_kv_cache_seq_rm( | ||||
| } | ||||
|  | ||||
| static void llama_kv_cache_seq_cp( | ||||
|              struct llama_kv_cache & cache, | ||||
|                       llama_seq_id   seq_id_src, | ||||
|                       llama_seq_id   seq_id_dst, | ||||
|                          llama_pos   p0, | ||||
|                          llama_pos   p1) { | ||||
|         struct llama_kv_cache & cache, | ||||
|                  llama_seq_id   seq_id_src, | ||||
|                  llama_seq_id   seq_id_dst, | ||||
|                     llama_pos   p0, | ||||
|                     llama_pos   p1) { | ||||
|     if (p0 < 0) p0 = 0; | ||||
|     if (p1 < 0) p1 = std::numeric_limits<llama_pos>::max(); | ||||
|  | ||||
|     for (uint32_t i = 0; i < cache.size; ++i) { | ||||
|         if (cache.cells[i].has_seq_id(seq_id_src) && cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) { | ||||
|             cache.cells[i].seq_id.insert(seq_id_dst); | ||||
| @@ -1389,11 +1395,14 @@ static void llama_kv_cache_seq_keep(struct llama_kv_cache & cache, llama_seq_id | ||||
| } | ||||
|  | ||||
| static void llama_kv_cache_seq_shift( | ||||
|              struct llama_kv_cache & cache, | ||||
|                       llama_seq_id   seq_id, | ||||
|                          llama_pos   p0, | ||||
|                          llama_pos   p1, | ||||
|                          llama_pos   delta) { | ||||
|         struct llama_kv_cache & cache, | ||||
|                  llama_seq_id   seq_id, | ||||
|                     llama_pos   p0, | ||||
|                     llama_pos   p1, | ||||
|                     llama_pos   delta) { | ||||
|     if (p0 < 0) p0 = 0; | ||||
|     if (p1 < 0) p1 = std::numeric_limits<llama_pos>::max(); | ||||
|  | ||||
|     for (uint32_t i = 0; i < cache.size; ++i) { | ||||
|         if (cache.cells[i].has_seq_id(seq_id) && cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) { | ||||
|             cache.cells[i].pos += delta; | ||||
| @@ -7209,16 +7218,6 @@ struct llama_data_file_context : llama_data_context { | ||||
|  * | ||||
| */ | ||||
| static void llama_copy_state_data_internal(struct llama_context * ctx, llama_data_context * data_ctx) { | ||||
|     // TODO: does not support multi-sequence states | ||||
|     { | ||||
|         const auto & kv_self = ctx->kv_self; | ||||
|         for (uint32_t i = 0; i < kv_self.head; ++i) { | ||||
|             GGML_ASSERT(kv_self.cells[i].pos == (int32_t) i); | ||||
|             GGML_ASSERT(kv_self.cells[i].seq_id.size() == 1); | ||||
|             GGML_ASSERT(kv_self.cells[i].has_seq_id(0)); | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     // copy rng | ||||
|     { | ||||
|         std::stringstream rng_ss; | ||||
| @@ -7271,36 +7270,38 @@ static void llama_copy_state_data_internal(struct llama_context * ctx, llama_dat | ||||
|         const auto & hparams = ctx->model.hparams; | ||||
|         const auto & cparams = ctx->cparams; | ||||
|  | ||||
|         const int    n_layer = hparams.n_layer; | ||||
|         const int    n_embd  = hparams.n_embd_gqa(); | ||||
|         const int    n_ctx   = cparams.n_ctx; | ||||
|         const auto   n_layer = hparams.n_layer; | ||||
|         const auto   n_embd  = hparams.n_embd_gqa(); | ||||
|         const auto   n_ctx   = cparams.n_ctx; | ||||
|  | ||||
|         const size_t kv_size = kv_self.buf.size; | ||||
|         const int    kv_ntok = kv_self.head; | ||||
|         const size_t   kv_buf_size = kv_self.buf.size; | ||||
|         const uint32_t kv_head     = kv_self.head; | ||||
|         const uint32_t kv_size     = kv_self.size; | ||||
|  | ||||
|         data_ctx->write(&kv_size, sizeof(kv_size)); | ||||
|         data_ctx->write(&kv_ntok, sizeof(kv_ntok)); | ||||
|         data_ctx->write(&kv_buf_size, sizeof(kv_buf_size)); | ||||
|         data_ctx->write(&kv_head,     sizeof(kv_head)); | ||||
|         data_ctx->write(&kv_size,     sizeof(kv_size)); | ||||
|  | ||||
|         if (kv_size) { | ||||
|         if (kv_buf_size) { | ||||
|             const size_t elt_size = ggml_element_size(kv_self.k); | ||||
|  | ||||
|             ggml_context * cpy_ctx = ggml_init({ 4096, NULL, /* no_alloc */ true }); | ||||
|             ggml_cgraph gf{}; | ||||
|  | ||||
|             ggml_tensor * kout3d = ggml_new_tensor_3d(cpy_ctx, kv_self.k->type, n_embd, kv_ntok, n_layer); | ||||
|             ggml_tensor * kout3d = ggml_new_tensor_3d(cpy_ctx, kv_self.k->type, n_embd, kv_head, n_layer); | ||||
|             std::vector<uint8_t> kout3d_data(ggml_nbytes(kout3d), 0); | ||||
|             kout3d->data = kout3d_data.data(); | ||||
|  | ||||
|             ggml_tensor * vout3d = ggml_new_tensor_3d(cpy_ctx, kv_self.v->type, kv_ntok, n_embd, n_layer); | ||||
|             ggml_tensor * vout3d = ggml_new_tensor_3d(cpy_ctx, kv_self.v->type, kv_head, n_embd, n_layer); | ||||
|             std::vector<uint8_t> vout3d_data(ggml_nbytes(vout3d), 0); | ||||
|             vout3d->data = vout3d_data.data(); | ||||
|  | ||||
|             ggml_tensor * k3d = ggml_view_3d(cpy_ctx, kv_self.k, | ||||
|                 n_embd, kv_ntok, n_layer, | ||||
|                 n_embd, kv_head, n_layer, | ||||
|                 elt_size*n_embd, elt_size*n_embd*n_ctx, 0); | ||||
|  | ||||
|             ggml_tensor * v3d = ggml_view_3d(cpy_ctx, kv_self.v, | ||||
|                 kv_ntok, n_embd, n_layer, | ||||
|                 kv_head, n_embd, n_layer, | ||||
|                 elt_size*n_ctx, elt_size*n_ctx*n_embd, 0); | ||||
|  | ||||
|             ggml_build_forward_expand(&gf, ggml_cpy(cpy_ctx, k3d, kout3d)); | ||||
| @@ -7314,6 +7315,20 @@ static void llama_copy_state_data_internal(struct llama_context * ctx, llama_dat | ||||
|             data_ctx->write(kout3d_data.data(), kout3d_data.size()); | ||||
|             data_ctx->write(vout3d_data.data(), vout3d_data.size()); | ||||
|         } | ||||
|  | ||||
|         for (uint32_t i = 0; i < kv_size; ++i) { | ||||
|             const auto & cell = kv_self.cells[i]; | ||||
|  | ||||
|             const llama_pos pos         = cell.pos; | ||||
|             const size_t    seq_id_size = cell.seq_id.size(); | ||||
|  | ||||
|             data_ctx->write(&pos,         sizeof(pos)); | ||||
|             data_ctx->write(&seq_id_size, sizeof(seq_id_size)); | ||||
|  | ||||
|             for (auto seq_id : cell.seq_id) { | ||||
|                 data_ctx->write(&seq_id, sizeof(seq_id)); | ||||
|             } | ||||
|         } | ||||
|     } | ||||
| } | ||||
|  | ||||
| @@ -7385,34 +7400,36 @@ size_t llama_set_state_data(struct llama_context * ctx, uint8_t * src) { | ||||
|         const int    n_embd  = hparams.n_embd_gqa(); | ||||
|         const int    n_ctx   = cparams.n_ctx; | ||||
|  | ||||
|         size_t kv_size; | ||||
|         int kv_ntok; | ||||
|         size_t   kv_buf_size; | ||||
|         uint32_t kv_head; | ||||
|         uint32_t kv_size; | ||||
|  | ||||
|         memcpy(&kv_size, inp, sizeof(kv_size)); inp += sizeof(kv_size); | ||||
|         memcpy(&kv_ntok, inp, sizeof(kv_ntok)); inp += sizeof(kv_ntok); | ||||
|         memcpy(&kv_buf_size, inp, sizeof(kv_buf_size)); inp += sizeof(kv_buf_size); | ||||
|         memcpy(&kv_head,     inp, sizeof(kv_head));     inp += sizeof(kv_head); | ||||
|         memcpy(&kv_size,     inp, sizeof(kv_size));     inp += sizeof(kv_size); | ||||
|  | ||||
|         if (kv_size) { | ||||
|             GGML_ASSERT(kv_self.buf.size == kv_size); | ||||
|         if (kv_buf_size) { | ||||
|             GGML_ASSERT(kv_self.buf.size == kv_buf_size); | ||||
|  | ||||
|             const size_t elt_size = ggml_element_size(kv_self.k); | ||||
|  | ||||
|             ggml_context * cpy_ctx = ggml_init({ 4096, NULL, /* no_alloc */ true }); | ||||
|             ggml_cgraph gf{}; | ||||
|  | ||||
|             ggml_tensor * kin3d = ggml_new_tensor_3d(cpy_ctx, kv_self.k->type, n_embd, kv_ntok, n_layer); | ||||
|             ggml_tensor * kin3d = ggml_new_tensor_3d(cpy_ctx, kv_self.k->type, n_embd, kv_head, n_layer); | ||||
|             kin3d->data = (void *) inp; | ||||
|             inp += ggml_nbytes(kin3d); | ||||
|  | ||||
|             ggml_tensor * vin3d = ggml_new_tensor_3d(cpy_ctx, kv_self.v->type, kv_ntok, n_embd, n_layer); | ||||
|             ggml_tensor * vin3d = ggml_new_tensor_3d(cpy_ctx, kv_self.v->type, kv_head, n_embd, n_layer); | ||||
|             vin3d->data = (void *) inp; | ||||
|             inp += ggml_nbytes(vin3d); | ||||
|  | ||||
|             ggml_tensor * k3d = ggml_view_3d(cpy_ctx, kv_self.k, | ||||
|                 n_embd, kv_ntok, n_layer, | ||||
|                 n_embd, kv_head, n_layer, | ||||
|                 elt_size*n_embd, elt_size*n_embd*n_ctx, 0); | ||||
|  | ||||
|             ggml_tensor * v3d = ggml_view_3d(cpy_ctx, kv_self.v, | ||||
|                 kv_ntok, n_embd, n_layer, | ||||
|                 kv_head, n_embd, n_layer, | ||||
|                 elt_size*n_ctx, elt_size*n_ctx*n_embd, 0); | ||||
|  | ||||
|             ggml_build_forward_expand(&gf, ggml_cpy(cpy_ctx, kin3d, k3d)); | ||||
| @@ -7422,8 +7439,27 @@ size_t llama_set_state_data(struct llama_context * ctx, uint8_t * src) { | ||||
|             ggml_free(cpy_ctx); | ||||
|         } | ||||
|  | ||||
|         ctx->kv_self.head = kv_ntok; | ||||
|         ctx->kv_self.head = kv_head; | ||||
|         ctx->kv_self.size = kv_size; | ||||
|  | ||||
|         ctx->kv_self.cells.resize(kv_size); | ||||
|  | ||||
|         for (uint32_t i = 0; i < kv_size; ++i) { | ||||
|             llama_pos pos; | ||||
|             size_t    seq_id_size; | ||||
|  | ||||
|             memcpy(&pos,         inp, sizeof(pos));         inp += sizeof(pos); | ||||
|             memcpy(&seq_id_size, inp, sizeof(seq_id_size)); inp += sizeof(seq_id_size); | ||||
|  | ||||
|             ctx->kv_self.cells[i].pos = pos; | ||||
|  | ||||
|             llama_seq_id seq_id; | ||||
|  | ||||
|             for (size_t j = 0; j < seq_id_size; ++j) { | ||||
|                 memcpy(&seq_id, inp, sizeof(seq_id)); inp += sizeof(seq_id); | ||||
|                 ctx->kv_self.cells[i].seq_id.insert(seq_id); | ||||
|             } | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     const size_t nread    = inp - src; | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Georgi Gerganov
					Georgi Gerganov