context : pass embeddings tensor from encoder to decoder
ggml-ci
@@ -4540,6 +4540,7 @@ size_t llama_context_recurrent::state_seq_read_data(llama_io_read_i & io, llama_
 // llama_context_enc
 //

+// TODO: avoid copy-paste of the entire encode() function
 int llama_context_enc::encode(llama_batch & inp_batch) {
     if (inp_batch.n_tokens == 0) {
         LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
@@ -4671,8 +4672,7 @@ int llama_context_enc::encode(llama_batch & inp_batch) {
     // overlap with device computation.
     ggml_backend_sched_reset(sched.get());

-    cross->n_outputs = n_tokens;
-    cross->embd_enc = embd;
+    cross->t_embd = t_embd;

     // remember the sequence ids used during the encoding - needed for cross attention later
     cross->seq_ids_enc.resize(n_tokens);
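The handoff between the two contexts is now a tensor pointer rather than a host-side float buffer, so the embeddings can stay in the backend. A minimal self-contained sketch of that flow follows; the stub structs stand in for the real ggml/llama types, and only the field name t_embd comes from this commit.

#include <cstdint>
#include <cstdio>

// Stub types standing in for ggml_tensor and llama_cross.
struct ggml_tensor_stub { int64_t ne[2]; float * data; };

struct llama_cross_stub {
    ggml_tensor_stub * t_embd = nullptr; // output embeddings from the encoder
};

int main() {
    float embd[4*3] = {0};                    // n_embd = 4, n_enc = 3
    ggml_tensor_stub t_embd = {{4, 3}, embd}; // pretend encoder output

    llama_cross_stub cross;
    cross.t_embd = &t_embd; // encoder side: stash the pointer, copy nothing

    // decoder side: the tensor is consumed where it lives; compare the old
    // path, which went through a raw float * embd_enc on the host
    if (cross.t_embd) {
        printf("n_embd = %lld, n_enc = %lld\n",
                (long long) cross.t_embd->ne[0],
                (long long) cross.t_embd->ne[1]);
    }
    return 0;
}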
@@ -4692,9 +4692,7 @@ int llama_context_enc::encode(llama_batch & inp_batch) {

 void llama_context_dec::reserve() {
     // simulate full KV cache
-    cross->n_outputs = cparams.n_ubatch;
-
-    LLAMA_LOG_DEBUG("%s: n_outputs = %u\n", __func__, cross->n_outputs);
+    cross->t_embd = nullptr;

     llama_context_kv_self::reserve();
 }
@@ -4703,15 +4701,15 @@ void llama_context_dec::input_set(const llama_ubatch & ubatch) {
     // call base functionality
     llama_context_kv_self::input_set(ubatch);

-    if (inp.cross_embd) {
-        assert(inp.cross_embd->type == GGML_TYPE_F32);
-        assert(ggml_nelements(inp.cross_embd) == cross->n_outputs*model.hparams.n_embd);
+    //if (inp.cross_embd && inp.cross_embd->op != GGML_OP_NONE) {
+    //    assert(inp.cross_embd->type == GGML_TYPE_F32);
+    //    assert(ggml_nelements(inp.cross_embd) == cross->n_outputs*model.hparams.n_embd);

-        ggml_backend_tensor_set(inp.cross_embd, cross->embd_enc, 0, ggml_nbytes(inp.cross_embd));
-    }
+    //    ggml_backend_tensor_set(inp.cross_embd, cross->embd_enc, 0, ggml_nbytes(inp.cross_embd));
+    //}

     if (inp.cross_kq_mask) {
-        const int64_t n_output_enc = cross->n_outputs;
+        const int64_t n_enc = inp.cross_kq_mask->ne[0];
         const int64_t n_tokens = ubatch.n_tokens;

         GGML_ASSERT(ggml_backend_buffer_is_host(inp.cross_kq_mask->buffer));
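Note that the old host copy via ggml_backend_tensor_set() is only commented out here, not deleted: the decoder now consumes the encoder tensor in place (see build_inp_cross_embd below), so input_set() has nothing left to upload. Likewise, the number of encoder positions is no longer tracked in cross->n_outputs but recovered from the mask tensor's own first dimension, inp.cross_kq_mask->ne[0].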
@@ -4721,7 +4719,7 @@ void llama_context_dec::input_set(const llama_ubatch & ubatch) {

         for (int h = 0; h < 1; ++h) {
             for (int j = 0; j < n_tokens; ++j) {
-                for (int i = 0; i < n_output_enc; ++i) {
+                for (int i = 0; i < n_enc; ++i) {
                     float f = -INFINITY;
                     for (int s = 0; s < ubatch.n_seq_id[j]; ++s) {
                         const llama_seq_id seq_id = ubatch.seq_id[j][s];
@@ -4729,13 +4727,13 @@ void llama_context_dec::input_set(const llama_ubatch & ubatch) {
                             f = 0.0f;
                         }
                     }
-                    data[h*(n_output_enc*n_tokens) + j*n_output_enc + i] = f;
+                    data[h*(n_enc*n_tokens) + j*n_enc + i] = f;
                 }
             }

             for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
-                for (int j = 0; j < n_output_enc; ++j) {
-                    data[h*(n_output_enc*n_tokens) + i*n_output_enc + j] = -INFINITY;
+                for (int j = 0; j < n_enc; ++j) {
+                    data[h*(n_enc*n_tokens) + i*n_enc + j] = -INFINITY;
                 }
             }
         }
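The loops above fill a flat buffer of width n_enc where entry (j, i) is 0.0f if decoder token j shares a sequence id with encoder position i, and -INFINITY otherwise. Below is a standalone sketch of that layout with made-up sequence ids; the rows padded out to GGML_KQ_MASK_PAD are omitted for brevity.

#include <cmath>
#include <cstdio>
#include <vector>

int main() {
    const int n_enc    = 4; // encoder positions
    const int n_tokens = 2; // decoder tokens in the ubatch

    // hypothetical sequence ids, one per decoder token / encoder position
    const int seq_dec[n_tokens] = {0, 1};
    const int seq_enc[n_enc]    = {0, 0, 1, 1};

    std::vector<float> data(n_enc*n_tokens);
    for (int j = 0; j < n_tokens; ++j) {
        for (int i = 0; i < n_enc; ++i) {
            // same indexing as data[h*(n_enc*n_tokens) + j*n_enc + i], h = 0
            data[j*n_enc + i] = (seq_dec[j] == seq_enc[i]) ? 0.0f : -INFINITY;
        }
    }

    for (int j = 0; j < n_tokens; ++j) {
        for (int i = 0; i < n_enc; ++i) {
            printf("%5.1f ", data[j*n_enc + i]);
        }
        printf("\n");
    }
    return 0;
}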
@@ -4750,12 +4748,19 @@ ggml_cgraph * llama_context_dec::graph_init() {

 ggml_tensor * llama_context_dec::build_inp_cross_embd(
         ggml_context * ctx0) {
+    // if we have the output embeddings from the encoder, use them directly
+    if (cross->t_embd) {
+        inp.cross_embd = ggml_view_tensor(ctx0, cross->t_embd);
+
+        return inp.cross_embd;
+    }
+
     const auto & hparams = model.hparams;
-    const int64_t n_embd = hparams.n_embd;

-    const int32_t n_outputs_enc = cross->n_outputs;
+    const auto n_embd = hparams.n_embd;
+    const auto n_enc = hparams.n_ctx_train;

-    inp.cross_embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_outputs_enc);
+    inp.cross_embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_enc);
     ggml_set_input(inp.cross_embd);

     return inp.cross_embd;
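When the encoder output is available, ggml_view_tensor() hands the decoder graph a tensor that shares the encoder tensor's data, so nothing is copied; otherwise a fresh F32 graph input is allocated for the worst case of n_ctx_train positions. Here is a self-contained sketch of the two paths against the plain ggml API; the dimensions are made up, and since this runs outside llama.cpp, ctx0 is created locally.

#include "ggml.h"

#include <cstdio>

int main() {
    ggml_init_params params = {
        /*.mem_size   =*/ 16*1024*1024,
        /*.mem_buffer =*/ nullptr,
        /*.no_alloc   =*/ false,
    };
    ggml_context * ctx0 = ggml_init(params);

    const int64_t n_embd = 8;
    const int64_t n_enc  = 4;

    // pretend this tensor came out of the encoder graph
    ggml_tensor * t_embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_enc);

    ggml_tensor * cross_embd = nullptr;
    if (t_embd) {
        // fast path: the view shares t_embd's data, nothing is copied
        cross_embd = ggml_view_tensor(ctx0, t_embd);
    } else {
        // fallback: allocate a fresh graph input for the worst case
        cross_embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_enc);
        ggml_set_input(cross_embd);
    }

    printf("view shares data: %s\n", cross_embd->data == t_embd->data ? "yes" : "no");

    ggml_free(ctx0);
    return 0;
}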
@@ -4768,9 +4773,9 @@ void llama_context_dec::build_attn_inp(
         bool swa) {
     llama_context_kv_self::build_attn_inp(ctx0, n_tokens, causal, swa);

-    const int32_t n_outputs_enc = cross->n_outputs;
+    const int32_t n_enc = cross->t_embd ? cross->t_embd->ne[1] : model.hparams.n_ctx_train;

-    inp.cross_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_outputs_enc, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
+    inp.cross_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_enc, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
     ggml_set_input(inp.cross_kq_mask);

     inp.cross_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp.cross_kq_mask, GGML_TYPE_F16) : inp.cross_kq_mask;
@@ -748,11 +748,12 @@ private:
     llama_kv_cache_recurrent kv_self;
 };

-// TODO: tmp - need something better
+// TODO: tmp - need something better to pass the data from the encoder to the decoder
 struct llama_cross {
-    int32_t n_outputs;
-    float * embd_enc;
+    // the output embeddings from the encoder
+    ggml_tensor * t_embd = nullptr;

+    // needed to construct the cross-attention mask in the decoder
     std::vector<std::set<llama_seq_id>> seq_ids_enc;
 };

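seq_ids_enc is what ties the two halves together: encode() records which sequence ids touched each encoder position, and the decoder's input_set() checks membership when filling the cross KQ mask. A small self-contained illustration of that check, with made-up data and a stand-in typedef for llama_seq_id:

#include <cstdint>
#include <cstdio>
#include <set>
#include <vector>

typedef int32_t llama_seq_id; // stand-in for the real typedef

int main() {
    // encoder side: record which sequences touched each encoder position
    std::vector<std::set<llama_seq_id>> seq_ids_enc(3); // n_enc = 3
    seq_ids_enc[0].insert(0);
    seq_ids_enc[1].insert(0);
    seq_ids_enc[2].insert(1);

    // decoder side: a token of sequence 0 may only attend to the encoder
    // positions where sequence 0 was present (f = 0.0f vs -INFINITY above)
    const llama_seq_id seq_id = 0;
    for (size_t i = 0; i < seq_ids_enc.size(); ++i) {
        const bool allowed = seq_ids_enc[i].count(seq_id) > 0;
        printf("enc pos %zu: %s\n", i, allowed ? "attend" : "masked");
    }
    return 0;
}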