context : pass embeddings tensor from encoder to decoder

ggml-ci
This commit is contained in:
Georgi Gerganov
2025-02-25 16:11:17 +02:00
parent e2b3294f2c
commit 4efe989886
2 changed files with 29 additions and 23 deletions

View File

@@ -4540,6 +4540,7 @@ size_t llama_context_recurrent::state_seq_read_data(llama_io_read_i & io, llama_
// llama_context_enc // llama_context_enc
// //
// TODO: avoid copy-paste of the entire encode() function
int llama_context_enc::encode(llama_batch & inp_batch) { int llama_context_enc::encode(llama_batch & inp_batch) {
if (inp_batch.n_tokens == 0) { if (inp_batch.n_tokens == 0) {
LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__); LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
@@ -4671,8 +4672,7 @@ int llama_context_enc::encode(llama_batch & inp_batch) {
// overlap with device computation. // overlap with device computation.
ggml_backend_sched_reset(sched.get()); ggml_backend_sched_reset(sched.get());
cross->n_outputs = n_tokens; cross->t_embd = t_embd;
cross->embd_enc = embd;
// remember the sequence ids used during the encoding - needed for cross attention later // remember the sequence ids used during the encoding - needed for cross attention later
cross->seq_ids_enc.resize(n_tokens); cross->seq_ids_enc.resize(n_tokens);
@@ -4692,9 +4692,7 @@ int llama_context_enc::encode(llama_batch & inp_batch) {
void llama_context_dec::reserve() { void llama_context_dec::reserve() {
// simulate full KV cache // simulate full KV cache
cross->n_outputs = cparams.n_ubatch; cross->t_embd = nullptr;
LLAMA_LOG_DEBUG("%s: n_outputs = %u\n", __func__, cross->n_outputs);
llama_context_kv_self::reserve(); llama_context_kv_self::reserve();
} }
@@ -4703,15 +4701,15 @@ void llama_context_dec::input_set(const llama_ubatch & ubatch) {
// call base functionality // call base functionality
llama_context_kv_self::input_set(ubatch); llama_context_kv_self::input_set(ubatch);
if (inp.cross_embd) { //if (inp.cross_embd && inp.cross_embd->op != GGML_OP_NONE) {
assert(inp.cross_embd->type == GGML_TYPE_F32); // assert(inp.cross_embd->type == GGML_TYPE_F32);
assert(ggml_nelements(inp.cross_embd) == cross->n_outputs*model.hparams.n_embd); // assert(ggml_nelements(inp.cross_embd) == cross->n_outputs*model.hparams.n_embd);
ggml_backend_tensor_set(inp.cross_embd, cross->embd_enc, 0, ggml_nbytes(inp.cross_embd)); // ggml_backend_tensor_set(inp.cross_embd, cross->embd_enc, 0, ggml_nbytes(inp.cross_embd));
} //}
if (inp.cross_kq_mask) { if (inp.cross_kq_mask) {
const int64_t n_output_enc = cross->n_outputs; const int64_t n_enc = inp.cross_kq_mask->ne[0];
const int64_t n_tokens = ubatch.n_tokens; const int64_t n_tokens = ubatch.n_tokens;
GGML_ASSERT(ggml_backend_buffer_is_host(inp.cross_kq_mask->buffer)); GGML_ASSERT(ggml_backend_buffer_is_host(inp.cross_kq_mask->buffer));
@@ -4721,7 +4719,7 @@ void llama_context_dec::input_set(const llama_ubatch & ubatch) {
for (int h = 0; h < 1; ++h) { for (int h = 0; h < 1; ++h) {
for (int j = 0; j < n_tokens; ++j) { for (int j = 0; j < n_tokens; ++j) {
for (int i = 0; i < n_output_enc; ++i) { for (int i = 0; i < n_enc; ++i) {
float f = -INFINITY; float f = -INFINITY;
for (int s = 0; s < ubatch.n_seq_id[j]; ++s) { for (int s = 0; s < ubatch.n_seq_id[j]; ++s) {
const llama_seq_id seq_id = ubatch.seq_id[j][s]; const llama_seq_id seq_id = ubatch.seq_id[j][s];
@@ -4729,13 +4727,13 @@ void llama_context_dec::input_set(const llama_ubatch & ubatch) {
f = 0.0f; f = 0.0f;
} }
} }
data[h*(n_output_enc*n_tokens) + j*n_output_enc + i] = f; data[h*(n_enc*n_tokens) + j*n_enc + i] = f;
} }
} }
for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) { for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
for (int j = 0; j < n_output_enc; ++j) { for (int j = 0; j < n_enc; ++j) {
data[h*(n_output_enc*n_tokens) + i*n_output_enc + j] = -INFINITY; data[h*(n_enc*n_tokens) + i*n_enc + j] = -INFINITY;
} }
} }
} }
@@ -4750,12 +4748,19 @@ ggml_cgraph * llama_context_dec::graph_init() {
ggml_tensor * llama_context_dec::build_inp_cross_embd( ggml_tensor * llama_context_dec::build_inp_cross_embd(
ggml_context * ctx0) { ggml_context * ctx0) {
// if we have the output embeddings from the encoder, use them directly
if (cross->t_embd) {
inp.cross_embd = ggml_view_tensor(ctx0, cross->t_embd);
return inp.cross_embd;
}
const auto & hparams = model.hparams; const auto & hparams = model.hparams;
const int64_t n_embd = hparams.n_embd;
const int32_t n_outputs_enc = cross->n_outputs; const auto n_embd = hparams.n_embd;
const auto n_enc = hparams.n_ctx_train;
inp.cross_embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_outputs_enc); inp.cross_embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_enc);
ggml_set_input(inp.cross_embd); ggml_set_input(inp.cross_embd);
return inp.cross_embd; return inp.cross_embd;
@@ -4768,9 +4773,9 @@ void llama_context_dec::build_attn_inp(
bool swa) { bool swa) {
llama_context_kv_self::build_attn_inp(ctx0, n_tokens, causal, swa); llama_context_kv_self::build_attn_inp(ctx0, n_tokens, causal, swa);
const int32_t n_outputs_enc = cross->n_outputs; const int32_t n_enc = cross->t_embd ? cross->t_embd->ne[1] : model.hparams.n_ctx_train;
inp.cross_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_outputs_enc, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)); inp.cross_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_enc, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
ggml_set_input(inp.cross_kq_mask); ggml_set_input(inp.cross_kq_mask);
inp.cross_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp.cross_kq_mask, GGML_TYPE_F16) : inp.cross_kq_mask; inp.cross_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp.cross_kq_mask, GGML_TYPE_F16) : inp.cross_kq_mask;

View File

@@ -748,11 +748,12 @@ private:
llama_kv_cache_recurrent kv_self; llama_kv_cache_recurrent kv_self;
}; };
// TODO: tmp - need something better // TODO: tmp - need something better to pass the data from the encoder to the decoder
struct llama_cross { struct llama_cross {
int32_t n_outputs; // the output embeddings from the encoder
float * embd_enc; ggml_tensor * t_embd = nullptr;
// needed to construct the cross-attention mask in the decoder
std::vector<std::set<llama_seq_id>> seq_ids_enc; std::vector<std::set<llama_seq_id>> seq_ids_enc;
}; };