mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-11-07 09:57:00 +00:00
context : move encode/decode to llama-context.cpp
This commit is contained in:
@@ -3980,6 +3980,31 @@ size_t llama_state_seq_load_file(struct llama_context * ctx, const char * filepa
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
///
|
||||||
|
|
||||||
|
int32_t llama_encode(
|
||||||
|
struct llama_context * ctx,
|
||||||
|
struct llama_batch batch) {
|
||||||
|
const int ret = ctx->encode(batch);
|
||||||
|
if (ret != 0) {
|
||||||
|
LLAMA_LOG_ERROR("%s: failed to encode, ret = %d\n", __func__, ret);
|
||||||
|
}
|
||||||
|
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
|
||||||
|
int32_t llama_decode(
|
||||||
|
struct llama_context * ctx,
|
||||||
|
struct llama_batch batch) {
|
||||||
|
const int ret = ctx->decode(batch);
|
||||||
|
if (ret != 0) {
|
||||||
|
LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret);
|
||||||
|
}
|
||||||
|
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
const std::vector<std::pair<std::string, struct ggml_tensor *>> & llama_internal_get_tensor_map(
|
const std::vector<std::pair<std::string, struct ggml_tensor *>> & llama_internal_get_tensor_map(
|
||||||
struct llama_context * ctx
|
struct llama_context * ctx
|
||||||
) {
|
) {
|
||||||
|
|||||||
@@ -45,7 +45,30 @@ struct llama_context {
|
|||||||
|
|
||||||
virtual ggml_context_ptr init();
|
virtual ggml_context_ptr init();
|
||||||
|
|
||||||
|
// decode a batch of tokens by evaluating the transformer
|
||||||
|
// in case of unsuccessful decoding (error or warning),
|
||||||
|
// the kv_cache state will be returned to its original state
|
||||||
|
// (for non-recurrent models) or cleaned (for recurrent models)
|
||||||
|
//
|
||||||
|
// - lctx: llama context
|
||||||
|
// - inp_batch: batch to evaluate
|
||||||
|
//
|
||||||
|
// return 0 on success
|
||||||
|
// return positive int on warning
|
||||||
|
// return negative int on error
|
||||||
|
//
|
||||||
virtual int decode(llama_batch & inp_batch) = 0;
|
virtual int decode(llama_batch & inp_batch) = 0;
|
||||||
|
|
||||||
|
|
||||||
|
// encode a batch of tokens by evaluating the encoder part of the transformer
|
||||||
|
//
|
||||||
|
// - lctx: llama context
|
||||||
|
// - batch: batch to evaluate
|
||||||
|
//
|
||||||
|
// return 0 on success
|
||||||
|
// return positive int on warning
|
||||||
|
// return negative int on error
|
||||||
|
//
|
||||||
virtual int encode(llama_batch & inp_batch) = 0;
|
virtual int encode(llama_batch & inp_batch) = 0;
|
||||||
|
|
||||||
// graph build API (generic)
|
// graph build API (generic)
|
||||||
|
|||||||
@@ -7401,39 +7401,6 @@ static struct ggml_cgraph * llama_build_graph(
|
|||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
// decode a batch of tokens by evaluating the transformer
|
|
||||||
// in case of unsuccessful decoding (error or warning),
|
|
||||||
// the kv_cache state will be returned to its original state
|
|
||||||
// (for non-recurrent models) or cleaned (for recurrent models)
|
|
||||||
//
|
|
||||||
// - lctx: llama context
|
|
||||||
// - inp_batch: batch to evaluate
|
|
||||||
//
|
|
||||||
// return 0 on success
|
|
||||||
// return positive int on warning
|
|
||||||
// return negative int on error
|
|
||||||
//
|
|
||||||
static int llama_decode_impl(
|
|
||||||
llama_context & lctx,
|
|
||||||
llama_batch inp_batch) {
|
|
||||||
return lctx.decode(inp_batch);
|
|
||||||
}
|
|
||||||
|
|
||||||
// encode a batch of tokens by evaluating the encoder part of the transformer
|
|
||||||
//
|
|
||||||
// - lctx: llama context
|
|
||||||
// - batch: batch to evaluate
|
|
||||||
//
|
|
||||||
// return 0 on success
|
|
||||||
// return positive int on warning
|
|
||||||
// return negative int on error
|
|
||||||
//
|
|
||||||
static int llama_encode_impl(
|
|
||||||
llama_context & lctx,
|
|
||||||
llama_batch inp_batch) {
|
|
||||||
return lctx.encode(inp_batch);
|
|
||||||
}
|
|
||||||
|
|
||||||
//
|
//
|
||||||
// interface implementation
|
// interface implementation
|
||||||
//
|
//
|
||||||
@@ -7759,30 +7726,6 @@ struct llama_context * llama_new_context_with_model(
|
|||||||
return llama_init_from_model(model, params);
|
return llama_init_from_model(model, params);
|
||||||
}
|
}
|
||||||
|
|
||||||
///
|
|
||||||
|
|
||||||
int32_t llama_encode(
|
|
||||||
struct llama_context * ctx,
|
|
||||||
struct llama_batch batch) {
|
|
||||||
const int ret = llama_encode_impl(*ctx, batch);
|
|
||||||
if (ret != 0) {
|
|
||||||
LLAMA_LOG_ERROR("%s: failed to encode, ret = %d\n", __func__, ret);
|
|
||||||
}
|
|
||||||
|
|
||||||
return ret;
|
|
||||||
}
|
|
||||||
|
|
||||||
int32_t llama_decode(
|
|
||||||
struct llama_context * ctx,
|
|
||||||
struct llama_batch batch) {
|
|
||||||
const int ret = llama_decode_impl(*ctx, batch);
|
|
||||||
if (ret != 0) {
|
|
||||||
LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret);
|
|
||||||
}
|
|
||||||
|
|
||||||
return ret;
|
|
||||||
}
|
|
||||||
|
|
||||||
//
|
//
|
||||||
// chat templates
|
// chat templates
|
||||||
//
|
//
|
||||||
|
|||||||
Reference in New Issue
Block a user