hparams : add n_embd_inp() to support extended embed (#16928)

* add n_embd_full to support extended embed

* don't change output

* rename to n_embd_inp

* restore n_embd where applicable
This commit is contained in:
Sigbjørn Skjæret
2025-11-07 19:27:58 +01:00
committed by GitHub
parent 16bcc1259d
commit 9008027aa3
9 changed files with 29 additions and 28 deletions

View File

@@ -827,7 +827,7 @@ int llama_context::encode(const llama_batch & batch_inp) {
const auto & hparams = model.hparams;
const int64_t n_embd = hparams.n_embd;
const int64_t n_embd = hparams.n_embd_inp();
const int64_t n_vocab = model.vocab.n_tokens();
// note: during encode, we always pass the full sequence starting from pos = 0
@@ -996,7 +996,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
const auto & hparams = model.hparams;
const int64_t n_vocab = vocab.n_tokens();
const int64_t n_embd = hparams.n_embd;
const int64_t n_embd = hparams.n_embd_inp();
// when computing embeddings, all tokens are output
const bool output_all = cparams.embeddings;
@@ -2154,7 +2154,7 @@ void llama_context::opt_epoch_iter(
batch.logits [pos_batch] = true;
}
if (!balloc->init(batch, model.vocab, nullptr, model.hparams.n_embd, cparams.kv_unified ? LLAMA_MAX_SEQ : cparams.n_seq_max, true)) {
if (!balloc->init(batch, model.vocab, nullptr, model.hparams.n_embd_inp(), cparams.kv_unified ? LLAMA_MAX_SEQ : cparams.n_seq_max, true)) {
LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
return;
}