diff --git a/src/llama-context.cpp b/src/llama-context.cpp index a84283eb48..26a5cf9c3f 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -786,7 +786,7 @@ int llama_context::encode(const llama_batch & batch_inp) { const auto & hparams = model.hparams; const int64_t n_embd = hparams.n_embd; - const int32_t n_vocab = model.vocab.n_tokens(); + const int64_t n_vocab = model.vocab.n_tokens(); // note: during encode, we always pass the full sequence starting from pos = 0 if (!balloc->init(batch_inp, model.vocab, nullptr, n_embd, cparams.kv_unified ? LLAMA_MAX_SEQ : cparams.n_seq_max, true)) { @@ -959,7 +959,7 @@ int llama_context::decode(const llama_batch & batch_inp) { const auto & vocab = model.vocab; const auto & hparams = model.hparams; - const int32_t n_vocab = vocab.n_tokens(); + const int64_t n_vocab = vocab.n_tokens(); const int64_t n_embd = hparams.n_embd; // when computing embeddings, all tokens are output