	llama : fix embeddings (#5796)
* llama : fix embeddings ggml-ci
* llama : do not use KV cache for non-causal models ggml-ci
* embeddings : fix llama_batch_init arg
* llama : add pooling switch
* llama : distinguish token vs sequence embeddings ggml-ci
* llama : assert pooling tensor
* llama : simplify causal mask condition ggml-ci
* llama : assert input batch with pooling enabled
* readme : update API changes list
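The change reworks how embeddings are returned: with `cparams.embeddings` enabled, per-token embeddings are written to `lctx.embd` when `pooling_type == LLAMA_POOLING_TYPE_NONE`, while pooled (MEAN/CLS) embeddings are stored per sequence and read back through the new `llama_get_embeddings_seq()`. Below is a minimal, hypothetical usage sketch (not part of this commit) of that API; the helper name `embed_sequence`, the pre-tokenized input, and the omission of error handling are assumptions made for brevity.

```cpp
#include <vector>

#include "llama.h"

// Hypothetical helper (not from the commit): embed one tokenized prompt as sequence 0.
static std::vector<float> embed_sequence(llama_model * model, const std::vector<llama_token> & tokens) {
    llama_context_params cparams = llama_context_default_params();
    cparams.embeddings   = true;                    // renamed from `embedding` by this commit
    cparams.pooling_type = LLAMA_POOLING_TYPE_MEAN; // pooled -> one embedding per sequence

    llama_context * ctx = llama_new_context_with_model(model, cparams);

    // one sequence (id 0), all tokens submitted in a single batch
    llama_batch batch = llama_batch_init((int32_t) tokens.size(), 0, 1);
    for (size_t i = 0; i < tokens.size(); ++i) {
        batch.token   [batch.n_tokens]    = tokens[i];
        batch.pos     [batch.n_tokens]    = (llama_pos) i;
        batch.n_seq_id[batch.n_tokens]    = 1;
        batch.seq_id  [batch.n_tokens][0] = 0;
        batch.logits  [batch.n_tokens]    = true;
        batch.n_tokens++;
    }

    llama_decode(ctx, batch);

    // with pooling enabled, the result is keyed by sequence id (new in this commit)
    const int     n_embd = llama_n_embd(model);
    const float * embd   = llama_get_embeddings_seq(ctx, 0);

    std::vector<float> out(embd, embd + n_embd);

    llama_batch_free(batch);
    llama_free(ctx);

    return out;
}
```

With `LLAMA_POOLING_TYPE_NONE` the per-token vectors would instead be read with `llama_get_embeddings_ith()`, using `batch.logits` to mark which tokens should produce output.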
llama.cpp (357 lines changed)
							| @@ -1665,7 +1665,7 @@ struct llama_hparams { | ||||
| }; | ||||
|  | ||||
| struct llama_cparams { | ||||
|     uint32_t n_ctx;       // context size used during inference | ||||
|     uint32_t n_ctx;           // context size used during inference | ||||
|     uint32_t n_batch; | ||||
|     uint32_t n_threads;       // number of threads to use for generation | ||||
|     uint32_t n_threads_batch; // number of threads to use for batch processing | ||||
| @@ -1682,7 +1682,9 @@ struct llama_cparams { | ||||
|     float yarn_beta_slow; | ||||
|     float defrag_thold; | ||||
|  | ||||
|     bool embeddings; | ||||
|     bool offload_kqv; | ||||
|  | ||||
|     enum llama_pooling_type pooling_type; | ||||
|  | ||||
|     ggml_backend_sched_eval_callback cb_eval; | ||||
| @@ -1972,7 +1974,7 @@ struct llama_context { | ||||
|     int32_t n_p_eval = 0; // number of tokens in eval calls for the prompt (with batch size > 1) | ||||
|     int32_t n_eval   = 0; // number of eval calls | ||||
|  | ||||
|     // decode output (2-dimensional array: [n_tokens][n_vocab]) | ||||
|     // logits output (2-dimensional array: [n_tokens][n_vocab]) | ||||
|     std::vector<float> logits; | ||||
| #ifndef NDEBUG | ||||
|     // guard against access to unset logits | ||||
| @@ -1980,8 +1982,13 @@ struct llama_context { | ||||
| #endif | ||||
|     bool logits_all = false; | ||||
|  | ||||
|     // input embedding (1-dimensional array: [n_embd]) | ||||
|     std::vector<float> embedding; | ||||
|     // embeddings output (2-dimensional array: [n_tokens][n_embd]) | ||||
|     // populated only when pooling_type == LLAMA_POOLING_TYPE_NONE | ||||
|     std::vector<float> embd; | ||||
|  | ||||
|     // sequence embeddings output (map of [n_embd] vectors) | ||||
|     // populated only when pooling_type != LLAMA_POOLING_TYPE_NONE | ||||
|     std::map<llama_seq_id, std::vector<float>> embd_seq; | ||||
|  | ||||
|     // memory buffers used to evaluate the model | ||||
|     std::vector<uint8_t> buf_compute_meta; | ||||
| @@ -5092,6 +5099,7 @@ static struct ggml_tensor * llm_build_kv( | ||||
|     llm_build_kv_store(ctx, hparams, kv, graph, k_cur, v_cur, n_ctx, n_tokens, kv_head, cb, il); | ||||
|  | ||||
|     struct ggml_tensor * cur; | ||||
|  | ||||
|     cur  = llm_build_kqv(ctx, model, hparams, kv, graph, wo, wo_b, | ||||
|             q_cur, kq_mask, kq_pos, n_ctx, n_tokens, n_kv, kq_scale, cb, il); | ||||
|     cb(cur, "kqv_out", il); | ||||
| @@ -6085,6 +6093,7 @@ struct llm_build_context { | ||||
|  | ||||
|         const int64_t n_embd_head = hparams.n_embd_head_v; | ||||
|         const int64_t n_embd_gqa  = hparams.n_embd_v_gqa(); | ||||
|  | ||||
|         GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); | ||||
|  | ||||
|         struct ggml_tensor * cur; | ||||
| @@ -6092,9 +6101,10 @@ struct llm_build_context { | ||||
|  | ||||
|         // get input vectors with right size | ||||
|         const size_t stride1 = n_tokens * ggml_type_size(lctx.inp_tokens->type); | ||||
|         struct ggml_tensor * inp_pos = ggml_view_1d(ctx0, lctx.inp_pos, n_tokens, 0); | ||||
|  | ||||
|         struct ggml_tensor * inp_pos  = ggml_view_1d(ctx0, lctx.inp_pos,  n_tokens, 0); | ||||
|         struct ggml_tensor * inp_mean = ggml_view_2d(ctx0, lctx.inp_mean, n_tokens, n_tokens, stride1, 0); | ||||
|         struct ggml_tensor * inp_cls = ggml_view_1d(ctx0, lctx.inp_cls, n_tokens, 0); | ||||
|         struct ggml_tensor * inp_cls  = ggml_view_1d(ctx0, lctx.inp_cls,  n_tokens, 0); | ||||
|  | ||||
|         // construct input embeddings (token, type, position) | ||||
|         inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, lctx.inp_tokens, lctx.inp_embd, cb); | ||||
| @@ -6112,39 +6122,38 @@ struct llm_build_context { | ||||
|         cb(inpL, "inp_norm", -1); | ||||
|  | ||||
|         // KQ_mask (mask for 1 head, it will be broadcasted to all heads) | ||||
|         struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0); | ||||
|         cb(KQ_mask, "KQ_mask", -1); // [n_kv, n_tokens] | ||||
|         struct ggml_tensor * KQ_mask = ggml_cont(ctx0, ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_tokens, n_tokens, n_tokens*ggml_type_size(lctx.inp_KQ_mask->type), 0)); | ||||
|         cb(KQ_mask, "KQ_mask", -1); // [n_tokens, n_tokens] | ||||
|  | ||||
|         // iterate layers | ||||
|         for (int il = 0; il < n_layer; ++il) { | ||||
|             struct ggml_tensor * cur = inpL; | ||||
|  | ||||
|             struct ggml_tensor * Qcur; | ||||
|             struct ggml_tensor * Kcur; | ||||
|             struct ggml_tensor * Vcur; | ||||
|  | ||||
|             // self-attention | ||||
|             if (model.arch == LLM_ARCH_BERT) { | ||||
|                 struct ggml_tensor * Qcur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].wq, cur), model.layers[il].bq); | ||||
|                 Qcur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].wq, cur), model.layers[il].bq); | ||||
|                 cb(Qcur, "Qcur", il); | ||||
|  | ||||
|                 struct ggml_tensor * Kcur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].wk, cur), model.layers[il].bk); | ||||
|                 Kcur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].wk, cur), model.layers[il].bk); | ||||
|                 cb(Kcur, "Kcur", il); | ||||
|  | ||||
|                 struct ggml_tensor * Vcur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].wv, cur), model.layers[il].bv); | ||||
|                 Vcur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].wv, cur), model.layers[il].bv); | ||||
|                 cb(Vcur, "Vcur", il); | ||||
|  | ||||
|                 // seems like we just need to do this for Q? | ||||
|                 Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); | ||||
|  | ||||
|                 cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, | ||||
|                         model.layers[il].wo, model.layers[il].bo, | ||||
|                         Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); | ||||
|                 cb(cur, "kqv_out", il); | ||||
|                 Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens); | ||||
|                 Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); | ||||
|             } else { | ||||
|                 // compute Q and K and RoPE them | ||||
|                 cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, cur); | ||||
|                 cb(cur, "wqkv", il); | ||||
|  | ||||
|                 struct ggml_tensor * Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd,     n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd))); | ||||
|                 struct ggml_tensor * Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd))); | ||||
|                 struct ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa))); | ||||
|                 Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd,     n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd))); | ||||
|                 Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd))); | ||||
|                 Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa))); | ||||
|  | ||||
|                 cb(Qcur, "Qcur", il); | ||||
|                 cb(Kcur, "Kcur", il); | ||||
| @@ -6163,13 +6172,41 @@ struct llm_build_context { | ||||
|                     ext_factor, attn_factor, beta_fast, beta_slow | ||||
|                 ); | ||||
|                 cb(Kcur, "Kcur", il); | ||||
|  | ||||
|                 cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, | ||||
|                         model.layers[il].wo, model.layers[il].bo, | ||||
|                         Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); | ||||
|                 cb(cur, "kqv_out", il); | ||||
|             } | ||||
|  | ||||
|             struct ggml_tensor * q =                 ggml_permute(ctx0, Qcur, 0, 2, 1, 3); | ||||
|             struct ggml_tensor * k = ggml_cont(ctx0, ggml_permute(ctx0, Kcur, 0, 2, 1, 3)); | ||||
|  | ||||
|             struct ggml_tensor * kq = ggml_mul_mat(ctx0, k, q); | ||||
|             cb(kq, "kq", il); | ||||
|  | ||||
|             kq = ggml_soft_max_ext(ctx0, kq, KQ_mask, nullptr, 1.0f/sqrtf(float(n_embd_head)), hparams.f_max_alibi_bias); | ||||
|             cb(kq, "kq_soft_max_ext", il); | ||||
|  | ||||
|             struct ggml_tensor * v = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_reshape_2d(ctx0, Vcur, n_embd_gqa, n_tokens))); | ||||
|             cb(v, "v", il); | ||||
|  | ||||
|             struct ggml_tensor * kqv = ggml_mul_mat(ctx0, ggml_reshape_3d(ctx0, v, n_tokens, n_embd_head, n_head_kv), kq); | ||||
|             cb(kqv, "kqv", il); | ||||
|  | ||||
|             struct ggml_tensor * kqv_merged = ggml_permute(ctx0, kqv, 0, 2, 1, 3); | ||||
|             cb(kqv_merged, "kqv_merged", il); | ||||
|  | ||||
|             cur = ggml_cont_2d(ctx0, kqv_merged, n_embd_gqa, n_tokens); | ||||
|             cb(cur, "kqv_merged_cont", il); | ||||
|  | ||||
|             ggml_build_forward_expand(gf, cur); | ||||
|  | ||||
|             cur = ggml_mul_mat(ctx0, model.layers[il].wo, cur); | ||||
|             if (model.layers[il].bo) { | ||||
|                 cb(cur, "kqv_wo", il); | ||||
|             } | ||||
|  | ||||
|             if (model.layers[il].bo) { | ||||
|                 cur = ggml_add(ctx0, cur, model.layers[il].bo); | ||||
|             } | ||||
|             cb(cur, "kqv_out", il); | ||||
|  | ||||
|             // re-add the layer input | ||||
|             cur = ggml_add(ctx0, cur, inpL); | ||||
|  | ||||
| @@ -6209,16 +6246,29 @@ struct llm_build_context { | ||||
|  | ||||
|         // final output | ||||
|         cur = inpL; | ||||
|         cb(cur, "result_embd", -1); | ||||
|  | ||||
|         // pooling layer | ||||
|         if (pooling_type == LLAMA_POOLING_TYPE_MEAN) { | ||||
|             cur = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, cur)), inp_mean); | ||||
|         } else if (pooling_type == LLAMA_POOLING_TYPE_CLS) { | ||||
|             cur = ggml_get_rows(ctx0, cur, inp_cls); | ||||
|         } else { | ||||
|             GGML_ASSERT(pooling_type == LLAMA_POOLING_TYPE_NONE && "Invalid pooling type"); | ||||
|         switch (pooling_type) { | ||||
|             case LLAMA_POOLING_TYPE_NONE: | ||||
|                 { | ||||
|                     // nop | ||||
|                 } break; | ||||
|             case LLAMA_POOLING_TYPE_MEAN: | ||||
|                 { | ||||
|                     cur = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, cur)), inp_mean); | ||||
|                     cb(cur, "result_embd_pooled", -1); | ||||
|                 } break; | ||||
|             case LLAMA_POOLING_TYPE_CLS: | ||||
|                 { | ||||
|                     cur = ggml_get_rows(ctx0, cur, inp_cls); | ||||
|                     cb(cur, "result_embd_pooled", -1); | ||||
|                 } break; | ||||
|             case LLAMA_POOLING_TYPE_UNSPECIFIED: | ||||
|                 { | ||||
|                     GGML_ASSERT(false && "Invalid pooling type"); | ||||
|                 } break; | ||||
|         } | ||||
|         cb(cur, "result_embd", -1); | ||||
|  | ||||
|         ggml_build_forward_expand(gf, cur); | ||||
|  | ||||
| @@ -7980,7 +8030,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { | ||||
|         ggml_backend_tensor_set(lctx.inp_pos, batch.pos, 0, n_tokens*ggml_element_size(lctx.inp_pos)); | ||||
|     } | ||||
|  | ||||
|     { | ||||
|     if (hparams.causal_attn) { | ||||
|         const int64_t n_kv     = kv_self.n; | ||||
|         const int64_t n_tokens = batch.n_tokens; | ||||
|  | ||||
| @@ -7995,16 +8045,40 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { | ||||
|  | ||||
|                 for (int i = 0; i < n_kv; ++i) { | ||||
|                     float f; | ||||
|                     if (!lctx.kv_self.cells[i].has_seq_id(seq_id) || | ||||
|                         (hparams.causal_attn && lctx.kv_self.cells[i].pos > pos)) { | ||||
|                     if (!lctx.kv_self.cells[i].has_seq_id(seq_id) || lctx.kv_self.cells[i].pos > pos) { | ||||
|                         f = -INFINITY; | ||||
|                     } else { | ||||
|                         f = 0; | ||||
|                         f = 0.0f; | ||||
|                     } | ||||
|                     data[h*(n_kv*n_tokens) + j*n_kv + i] = f; | ||||
|                 } | ||||
|             } | ||||
|         } | ||||
|     } else { | ||||
|         // non-causal attention attends only the tokens within the batch (i.e. the KV cache is not used) | ||||
|         const int64_t n_tokens = batch.n_tokens; | ||||
|  | ||||
|         assert(ggml_backend_buffer_is_host(lctx.inp_KQ_mask->buffer)); | ||||
|  | ||||
|         float * data = (float *) lctx.inp_KQ_mask->data; | ||||
|  | ||||
|         for (int h = 0; h < 1; ++h) { | ||||
|             for (int j = 0; j < n_tokens; ++j) { | ||||
|                 const llama_seq_id seq_id = batch.seq_id[j][0]; | ||||
|  | ||||
|                 for (int i = 0; i < n_tokens; ++i) { | ||||
|                     float f = -INFINITY; | ||||
|                     for (int s = 0; s < batch.n_seq_id[i]; ++s) { | ||||
|                         if (batch.seq_id[i][s] == seq_id) { | ||||
|                             f = 0.0f; | ||||
|                             break; | ||||
|                         } | ||||
|                     } | ||||
|  | ||||
|                     data[h*(n_tokens*n_tokens) + j*n_tokens + i] = f; | ||||
|                 } | ||||
|             } | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     if (hparams.need_kq_pos) { | ||||
| @@ -8023,13 +8097,16 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { | ||||
|         const int64_t n_tokens = batch.n_tokens; | ||||
|  | ||||
|         GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_mean->buffer)); | ||||
|         float * data = (float *) lctx.inp_mean->data; | ||||
|  | ||||
|         float * data = (float *) lctx.inp_mean->data; | ||||
|         memset(lctx.inp_mean->data, 0, n_tokens * n_tokens * ggml_element_size(lctx.inp_mean)); | ||||
|  | ||||
|         std::vector<uint64_t> sum(n_tokens, 0); | ||||
|         for (int i = 0; i < n_tokens; ++i) { | ||||
|             const llama_seq_id seq_id = batch.seq_id[i][0]; | ||||
|  | ||||
|             GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == MEAN"); | ||||
|  | ||||
|             sum[seq_id] += 1; | ||||
|         } | ||||
|  | ||||
| @@ -8051,11 +8128,16 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { | ||||
|         const int64_t n_tokens = batch.n_tokens; | ||||
|  | ||||
|         GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_cls->buffer)); | ||||
|  | ||||
|         uint32_t * data = (uint32_t *) lctx.inp_cls->data; | ||||
|         memset(lctx.inp_cls->data, 0, n_tokens * ggml_element_size(lctx.inp_cls)); | ||||
|  | ||||
|         for (int i = 0; i < n_tokens; ++i) { | ||||
|             const llama_seq_id seq_id = batch.seq_id[i][0]; | ||||
|             const llama_pos pos = batch.pos[i]; | ||||
|             const llama_pos    pos    = batch.pos[i]; | ||||
|  | ||||
|             GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == CLS"); | ||||
|  | ||||
|             if (pos == 0) { | ||||
|                 data[seq_id] = i; | ||||
|             } | ||||
| @@ -8169,24 +8251,27 @@ static int llama_decode_internal( | ||||
|         batch.seq_id = seq_id_arr.data(); | ||||
|     } | ||||
|  | ||||
|     llama_kv_cache_update(&lctx); | ||||
|     // non-causal masks do not use the KV cache | ||||
|     if (hparams.causal_attn) { | ||||
|         llama_kv_cache_update(&lctx); | ||||
|  | ||||
|     // if we have enough unused cells before the current head -> | ||||
|     //   better to start searching from the beginning of the cache, hoping to fill it | ||||
|     if (kv_self.head > kv_self.used + 2*n_tokens) { | ||||
|         kv_self.head = 0; | ||||
|         // if we have enough unused cells before the current head -> | ||||
|         //   better to start searching from the beginning of the cache, hoping to fill it | ||||
|         if (kv_self.head > kv_self.used + 2*n_tokens) { | ||||
|             kv_self.head = 0; | ||||
|         } | ||||
|  | ||||
|         if (!llama_kv_cache_find_slot(kv_self, batch)) { | ||||
|             return 1; | ||||
|         } | ||||
|  | ||||
|         // a heuristic, to avoid attending the full cache if it is not yet utilized | ||||
|         // after enough generations, the benefit from this heuristic disappears | ||||
|         // if we start defragmenting the cache, the benefit from this will be more important | ||||
|         kv_self.n = std::min(cparams.n_ctx, std::max(32u, GGML_PAD(llama_kv_cache_cell_max(kv_self), 32))); | ||||
|         //kv_self.n = llama_kv_cache_cell_max(kv_self); | ||||
|     } | ||||
|  | ||||
|     if (!llama_kv_cache_find_slot(kv_self, batch)) { | ||||
|         return 1; | ||||
|     } | ||||
|  | ||||
|     // a heuristic, to avoid attending the full cache if it is not yet utilized | ||||
|     // after enough generations, the benefit from this heuristic disappears | ||||
|     // if we start defragmenting the cache, the benefit from this will be more important | ||||
|     kv_self.n = std::min(cparams.n_ctx, std::max(32u, GGML_PAD(llama_kv_cache_cell_max(kv_self), 32))); | ||||
|     //kv_self.n = llama_kv_cache_cell_max(kv_self); | ||||
|  | ||||
|     //printf("kv_self.n = %5d, kv_self.used = %5d, kv_self.head = %5d\n", kv_self.n, kv_self.used, kv_self.head); | ||||
|  | ||||
|     ggml_backend_sched_reset(lctx.sched); | ||||
| @@ -8195,20 +8280,26 @@ static int llama_decode_internal( | ||||
|     ggml_cgraph * gf = llama_build_graph(lctx, batch, false); | ||||
|  | ||||
|     // the output is always the last tensor in the graph | ||||
|     struct ggml_tensor * res        = gf->nodes[gf->n_nodes - 1]; | ||||
|     struct ggml_tensor * embeddings = gf->nodes[gf->n_nodes - 2]; | ||||
|     struct ggml_tensor * res  = gf->nodes[gf->n_nodes - 1]; | ||||
|     struct ggml_tensor * embd = gf->nodes[gf->n_nodes - 2]; | ||||
|  | ||||
|     if (strcmp(res->name, "result_output") == 0) { | ||||
|         // the embeddings could be the second to last tensor, or the third to last tensor | ||||
|         if (strcmp(embeddings->name, "result_norm") != 0) { | ||||
|             embeddings = gf->nodes[gf->n_nodes - 3]; | ||||
|             GGML_ASSERT(strcmp(embeddings->name, "result_norm") == 0); | ||||
|         } | ||||
|     } else if (strcmp(res->name, "result_embd") == 0) { | ||||
|         embeddings = res; | ||||
|         res = nullptr; | ||||
|     if (!hparams.causal_attn) { | ||||
|         res = nullptr; // do not extract logits for embedding models such as BERT | ||||
|  | ||||
|         // token or sequence embeddings | ||||
|         embd = gf->nodes[gf->n_nodes - 1]; | ||||
|  | ||||
|         GGML_ASSERT(strcmp(embd->name, "result_embd") == 0 || strcmp(embd->name, "result_embd_pooled") == 0); | ||||
|     } else { | ||||
|         GGML_ASSERT(false); | ||||
|         if (strcmp(res->name, "result_output") == 0) { | ||||
|             // the token embeddings could be the second to last tensor, or the third to last tensor | ||||
|             if (strcmp(embd->name, "result_norm") != 0) { | ||||
|                 embd = gf->nodes[gf->n_nodes - 3]; | ||||
|                 GGML_ASSERT(strcmp(embd->name, "result_norm") == 0); | ||||
|             } | ||||
|         } else { | ||||
|             GGML_ASSERT(false && "missing result_output tensor"); | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     // LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs); | ||||
| @@ -8275,46 +8366,82 @@ static int llama_decode_internal( | ||||
|         logits_out.clear(); | ||||
| #endif | ||||
|  | ||||
|         ggml_backend_t res_backend = ggml_backend_sched_get_node_backend(lctx.sched, res); | ||||
|         GGML_ASSERT(res_backend != nullptr); | ||||
|         ggml_backend_t backend_res = ggml_backend_sched_get_node_backend(lctx.sched, res); | ||||
|         GGML_ASSERT(backend_res != nullptr); | ||||
|  | ||||
|         if (batch.logits) { | ||||
|             logits_out.resize(n_vocab * n_tokens); | ||||
|             for (uint32_t i = 0; i < n_tokens; i++) { | ||||
|                 if (batch.logits[i] == 0) { | ||||
|                     continue; | ||||
|                 } | ||||
|                 ggml_backend_tensor_get_async(res_backend, res, logits_out.data() + (n_vocab*i), (n_vocab*i)*sizeof(float), n_vocab*sizeof(float)); | ||||
|                 ggml_backend_tensor_get_async(backend_res, res, logits_out.data() + (n_vocab*i), (n_vocab*i)*sizeof(float), n_vocab*sizeof(float)); | ||||
| #ifndef NDEBUG | ||||
|                 logits_valid[i] = true; | ||||
| #endif | ||||
|             } | ||||
|         } else if (lctx.logits_all) { | ||||
|             logits_out.resize(n_vocab * n_tokens); | ||||
|             ggml_backend_tensor_get_async(res_backend, res, logits_out.data(), 0, n_vocab*n_tokens*sizeof(float)); | ||||
|             ggml_backend_tensor_get_async(backend_res, res, logits_out.data(), 0, n_vocab*n_tokens*sizeof(float)); | ||||
| #ifndef NDEBUG | ||||
|             std::fill(logits_valid.begin(), logits_valid.end(), true); | ||||
| #endif | ||||
|         } else { | ||||
|             logits_out.resize(n_vocab); | ||||
|             ggml_backend_tensor_get_async(res_backend, res, logits_out.data(), (n_vocab*(n_tokens - 1))*sizeof(float), n_vocab*sizeof(float)); | ||||
|             ggml_backend_tensor_get_async(backend_res, res, logits_out.data(), (n_vocab*(n_tokens - 1))*sizeof(float), n_vocab*sizeof(float)); | ||||
| #ifndef NDEBUG | ||||
|             logits_valid[0] = true; | ||||
| #endif | ||||
|         } | ||||
|         ggml_backend_synchronize(res_backend); | ||||
|         ggml_backend_synchronize(backend_res); | ||||
|     } | ||||
|  | ||||
|     // extract embeddings | ||||
|     if (!lctx.embedding.empty()) { | ||||
|         auto & embedding_out = lctx.embedding; | ||||
|     if (cparams.embeddings && embd) { | ||||
|         ggml_backend_t backend_embd = ggml_backend_sched_get_node_backend(lctx.sched, embd); | ||||
|         GGML_ASSERT(backend_embd != nullptr); | ||||
|  | ||||
|         const int64_t embd_pos  = res ? n_embd * (n_tokens-1) : 0; | ||||
|         const int64_t embd_size = res ? n_embd : n_embd * n_tokens; | ||||
|         switch (cparams.pooling_type) { | ||||
|             case LLAMA_POOLING_TYPE_NONE: | ||||
|                 { | ||||
|                     // extract token embeddings | ||||
|                     auto & embd_out = lctx.embd; | ||||
|  | ||||
|         embedding_out.resize(embd_size); | ||||
|         ggml_backend_t embeddings_backend = ggml_backend_sched_get_node_backend(lctx.sched, embeddings); | ||||
|         ggml_backend_tensor_get_async(embeddings_backend, embeddings, embedding_out.data(), embd_pos*sizeof(float), embd_size*sizeof(float)); | ||||
|         ggml_backend_synchronize(embeddings_backend); | ||||
|                     if (batch.logits) { | ||||
|                         embd_out.resize(n_embd * n_tokens); | ||||
|                         for (uint32_t i = 0; i < n_tokens; i++) { | ||||
|                             if (batch.logits[i] == 0) { | ||||
|                                 continue; | ||||
|                             } | ||||
|  | ||||
|                             ggml_backend_tensor_get_async(backend_embd, embd, embd_out.data() + (n_embd*i), (n_embd*i)*sizeof(float), n_embd*sizeof(float)); | ||||
|                         } | ||||
|                     } | ||||
|                 } break; | ||||
|             case LLAMA_POOLING_TYPE_CLS: | ||||
|             case LLAMA_POOLING_TYPE_MEAN: | ||||
|                 { | ||||
|                     GGML_ASSERT(strcmp(embd->name, "result_embd_pooled") == 0); | ||||
|  | ||||
|                     // extract sequence embeddings | ||||
|                     auto & embd_seq_out = lctx.embd_seq; | ||||
|                     embd_seq_out.clear(); | ||||
|  | ||||
|                     for (uint32_t i = 0; i < n_tokens; i++) { | ||||
|                         const llama_seq_id seq_id = batch.seq_id[i][0]; | ||||
|                         if (embd_seq_out.find(seq_id) != embd_seq_out.end()) { | ||||
|                             continue; | ||||
|                         } | ||||
|                         embd_seq_out[seq_id].resize(n_embd); | ||||
|                         ggml_backend_tensor_get_async(backend_embd, embd, embd_seq_out[seq_id].data(), (n_embd*seq_id)*sizeof(float), n_embd*sizeof(float)); | ||||
|                     } | ||||
|                 } break; | ||||
|             case LLAMA_POOLING_TYPE_UNSPECIFIED: | ||||
|                 { | ||||
|                     GGML_ASSERT(false && "unknown pooling type"); | ||||
|                 } break; | ||||
|         } | ||||
|         ggml_backend_synchronize(backend_embd); | ||||
|     } | ||||
|  | ||||
|     // measure the performance only for the single-token evals | ||||
| @@ -8608,19 +8735,19 @@ static uint8_t llama_token_to_byte(const llama_vocab& vocab, llama_token id) { | ||||
|     GGML_ASSERT(llama_is_byte_token(vocab, id)); | ||||
|     const auto& token_data = vocab.id_to_token.at(id); | ||||
|     switch (llama_vocab_get_type(vocab)) { | ||||
|     case LLAMA_VOCAB_TYPE_SPM: { | ||||
|         auto buf = token_data.text.substr(3, 2); | ||||
|         return strtol(buf.c_str(), NULL, 16); | ||||
|     } | ||||
|     case LLAMA_VOCAB_TYPE_BPE: { | ||||
|         GGML_ASSERT(false); | ||||
|         return unicode_to_bytes_bpe(token_data.text); | ||||
|     } | ||||
|     case LLAMA_VOCAB_TYPE_WPM: { | ||||
|         GGML_ASSERT(false); | ||||
|     } | ||||
|     default: | ||||
|         GGML_ASSERT(false); | ||||
|         case LLAMA_VOCAB_TYPE_SPM: { | ||||
|             auto buf = token_data.text.substr(3, 2); | ||||
|             return strtol(buf.c_str(), NULL, 16); | ||||
|         } | ||||
|         case LLAMA_VOCAB_TYPE_BPE: { | ||||
|             GGML_ASSERT(false); | ||||
|             return unicode_to_bytes_bpe(token_data.text); | ||||
|         } | ||||
|         case LLAMA_VOCAB_TYPE_WPM: { | ||||
|             GGML_ASSERT(false); | ||||
|         } | ||||
|         default: | ||||
|             GGML_ASSERT(false); | ||||
|     } | ||||
| } | ||||
|  | ||||
| @@ -11864,7 +11991,7 @@ struct llama_context_params llama_context_default_params() { | ||||
|         /*.type_k                      =*/ GGML_TYPE_F16, | ||||
|         /*.type_v                      =*/ GGML_TYPE_F16, | ||||
|         /*.logits_all                  =*/ false, | ||||
|         /*.embedding                   =*/ false, | ||||
|         /*.embeddings                  =*/ false, | ||||
|         /*.offload_kqv                 =*/ true, | ||||
|         /*.abort_callback              =*/ nullptr, | ||||
|         /*.abort_callback_data         =*/ nullptr, | ||||
| @@ -12015,6 +12142,7 @@ struct llama_context * llama_new_context_with_model( | ||||
|     cparams.yarn_beta_fast   = params.yarn_beta_fast; | ||||
|     cparams.yarn_beta_slow   = params.yarn_beta_slow; | ||||
|     cparams.defrag_thold     = params.defrag_thold; | ||||
|     cparams.embeddings       = params.embeddings; | ||||
|     cparams.offload_kqv      = params.offload_kqv; | ||||
|     cparams.pooling_type     = params.pooling_type; | ||||
|  | ||||
| @@ -12192,8 +12320,8 @@ struct llama_context * llama_new_context_with_model( | ||||
|         // resized during inference, reserve maximum | ||||
|         ctx->logits.reserve(hparams.n_vocab*cparams.n_batch); | ||||
|  | ||||
|         if (params.embedding) { | ||||
|             ctx->embedding.resize(hparams.n_embd); | ||||
|         if (params.embeddings) { | ||||
|             ctx->embd.reserve(hparams.n_embd*cparams.n_batch); | ||||
|         } | ||||
|  | ||||
|         // graph inputs | ||||
| @@ -12628,7 +12756,7 @@ size_t llama_get_state_size(const struct llama_context * ctx) { | ||||
|     // assume worst case for logits although only currently set ones are serialized | ||||
|     const size_t s_logits          = ctx->logits.capacity() * sizeof(float); | ||||
|     const size_t s_embedding_size  = sizeof(size_t); | ||||
|     const size_t s_embedding       = ctx->embedding.size() * sizeof(float); | ||||
|     const size_t s_embedding       = ctx->embd.capacity() * sizeof(float); | ||||
|     const size_t s_kv_buf_size     = sizeof(size_t); | ||||
|     const size_t s_kv_head         = sizeof(uint32_t); | ||||
|     const size_t s_kv_size         = sizeof(uint32_t); | ||||
| @@ -12737,12 +12865,12 @@ static void llama_copy_state_data_internal(struct llama_context * ctx, llama_dat | ||||
|  | ||||
|     // copy embeddings | ||||
|     { | ||||
|         const size_t embedding_size = ctx->embedding.size(); | ||||
|         const size_t embeddings_size = ctx->embd.size(); | ||||
|  | ||||
|         data_ctx->write(&embedding_size, sizeof(embedding_size)); | ||||
|         data_ctx->write(&embeddings_size, sizeof(embeddings_size)); | ||||
|  | ||||
|         if (embedding_size) { | ||||
|             data_ctx->write(ctx->embedding.data(), embedding_size * sizeof(float)); | ||||
|         if (embeddings_size) { | ||||
|             data_ctx->write(ctx->embd.data(), embeddings_size * sizeof(float)); | ||||
|         } | ||||
|     } | ||||
|  | ||||
| @@ -12846,15 +12974,17 @@ size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) { | ||||
|  | ||||
|     // set embeddings | ||||
|     { | ||||
|         size_t embedding_size; | ||||
|         size_t embeddings_size; | ||||
|  | ||||
|         memcpy(&embedding_size, inp, sizeof(embedding_size)); inp += sizeof(embedding_size); | ||||
|         memcpy(&embeddings_size, inp, sizeof(embeddings_size)); inp += sizeof(embeddings_size); | ||||
|  | ||||
|         GGML_ASSERT(ctx->embedding.capacity() == embedding_size); | ||||
|         GGML_ASSERT(ctx->embd.capacity() == embeddings_size); | ||||
|  | ||||
|         if (embedding_size) { | ||||
|             memcpy(ctx->embedding.data(), inp, embedding_size * sizeof(float)); | ||||
|             inp += embedding_size * sizeof(float); | ||||
|         if (embeddings_size) { | ||||
|             ctx->embd.resize(embeddings_size); | ||||
|  | ||||
|             memcpy(ctx->embd.data(), inp, embeddings_size * sizeof(float)); | ||||
|             inp += embeddings_size * sizeof(float); | ||||
|         } | ||||
|     } | ||||
|  | ||||
| @@ -13104,11 +13234,20 @@ float * llama_get_logits_ith(struct llama_context * ctx, int32_t i) { | ||||
| } | ||||
|  | ||||
| float * llama_get_embeddings(struct llama_context * ctx) { | ||||
|     return ctx->embedding.data(); | ||||
|     return ctx->embd.data(); | ||||
| } | ||||
|  | ||||
| float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i) { | ||||
|     return ctx->embedding.data() + i*ctx->model.hparams.n_embd; | ||||
|     return ctx->embd.data() + i*ctx->model.hparams.n_embd; | ||||
| } | ||||
|  | ||||
| float * llama_get_embeddings_seq(struct llama_context * ctx, llama_seq_id seq_id) { | ||||
|     auto it = ctx->embd_seq.find(seq_id); | ||||
|     if (it == ctx->embd_seq.end()) { | ||||
|         return nullptr; | ||||
|     } | ||||
|  | ||||
|     return it->second.data(); | ||||
| } | ||||
|  | ||||
| const char * llama_token_get_text(const struct llama_model * model, llama_token token) { | ||||