mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-11-08 10:07:01 +00:00
context : move build_rope_factors to base class
ggml-ci
This commit is contained in:
@@ -57,6 +57,10 @@ uint32_t llama_context::n_ctx() const {
|
|||||||
return cparams.n_ctx;
|
return cparams.n_ctx;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
uint32_t llama_context::n_ctx_per_seq() const {
|
||||||
|
return cparams.n_ctx / cparams.n_seq_max;
|
||||||
|
}
|
||||||
|
|
||||||
uint32_t llama_context::n_batch() const {
|
uint32_t llama_context::n_batch() const {
|
||||||
return cparams.n_batch;
|
return cparams.n_batch;
|
||||||
}
|
}
|
||||||
@@ -202,6 +206,86 @@ llama_perf_context_data llama_context::perf_get_data() const {
|
|||||||
return data;
|
return data;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ggml_tensor * llama_context::build_cvec(
|
||||||
|
ggml_context * ctx0,
|
||||||
|
ggml_tensor * cur,
|
||||||
|
int il) {
|
||||||
|
return cvec.apply_to(ctx0, cur, il);
|
||||||
|
}
|
||||||
|
|
||||||
|
ggml_tensor * llama_context::build_lora_mm(
|
||||||
|
ggml_context * ctx0,
|
||||||
|
ggml_tensor * w,
|
||||||
|
ggml_tensor * cur) {
|
||||||
|
struct ggml_tensor * res = ggml_mul_mat(ctx0, w, cur);
|
||||||
|
|
||||||
|
for (const auto & lora : loras) {
|
||||||
|
struct llama_adapter_lora_weight * lw = lora.first->get_weight(w);
|
||||||
|
if (lw == nullptr) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
const float adapter_scale = lora.second;
|
||||||
|
const float scale = lw->get_scale(lora.first->alpha, adapter_scale);
|
||||||
|
|
||||||
|
struct ggml_tensor * ab_cur = ggml_mul_mat(
|
||||||
|
ctx0, lw->b,
|
||||||
|
ggml_mul_mat(ctx0, lw->a, cur)
|
||||||
|
);
|
||||||
|
|
||||||
|
ab_cur = ggml_scale(ctx0, ab_cur, scale);
|
||||||
|
res = ggml_add(ctx0, res, ab_cur);
|
||||||
|
}
|
||||||
|
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
|
ggml_tensor * llama_context::build_lora_mm_id(
|
||||||
|
ggml_context * ctx0,
|
||||||
|
ggml_tensor * w,
|
||||||
|
ggml_tensor * cur,
|
||||||
|
ggml_tensor * ids) {
|
||||||
|
struct ggml_tensor * res = ggml_mul_mat_id(ctx0, w, cur, ids);
|
||||||
|
for (const auto & lora : loras) {
|
||||||
|
struct llama_adapter_lora_weight * lw = lora.first->get_weight(w);
|
||||||
|
if (lw == nullptr) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
const float alpha = lora.first->alpha;
|
||||||
|
const float rank = (float) lw->b->ne[0];
|
||||||
|
const float scale = alpha ? lora.second * alpha / rank : lora.second;
|
||||||
|
|
||||||
|
struct ggml_tensor * ab_cur = ggml_mul_mat_id(
|
||||||
|
ctx0, lw->b,
|
||||||
|
ggml_mul_mat_id(ctx0, lw->a, cur, ids),
|
||||||
|
ids
|
||||||
|
);
|
||||||
|
|
||||||
|
ab_cur = ggml_scale(ctx0, ab_cur, scale);
|
||||||
|
res = ggml_add(ctx0, res, ab_cur);
|
||||||
|
}
|
||||||
|
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
|
ggml_tensor * llama_context::build_rope_factors(int il) {
|
||||||
|
const auto & hparams = model.hparams;
|
||||||
|
|
||||||
|
// choose long/short freq factors based on the context size
|
||||||
|
const auto n_ctx_per_seq = cparams.n_ctx / cparams.n_seq_max;
|
||||||
|
|
||||||
|
if (model.layers[il].rope_freqs != nullptr) {
|
||||||
|
return model.layers[il].rope_freqs;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (n_ctx_per_seq > hparams.n_ctx_orig_yarn) {
|
||||||
|
return model.layers[il].rope_long;
|
||||||
|
}
|
||||||
|
|
||||||
|
return model.layers[il].rope_short;
|
||||||
|
}
|
||||||
|
|
||||||
void llama_context::perf_reset() {
|
void llama_context::perf_reset() {
|
||||||
t_start_us = ggml_time_us();
|
t_start_us = ggml_time_us();
|
||||||
t_eval_us = n_eval = 0;
|
t_eval_us = n_eval = 0;
|
||||||
@@ -1825,69 +1909,6 @@ size_t llama_context_unified::reserve_outputs(size_t n_outputs) {
|
|||||||
return n_outputs_max;
|
return n_outputs_max;
|
||||||
}
|
}
|
||||||
|
|
||||||
ggml_tensor * llama_context::build_cvec(
|
|
||||||
ggml_context * ctx0,
|
|
||||||
ggml_tensor * cur,
|
|
||||||
int il) {
|
|
||||||
return cvec.apply_to(ctx0, cur, il);
|
|
||||||
}
|
|
||||||
|
|
||||||
ggml_tensor * llama_context::build_lora_mm(
|
|
||||||
ggml_context * ctx0,
|
|
||||||
ggml_tensor * w,
|
|
||||||
ggml_tensor * cur) {
|
|
||||||
struct ggml_tensor * res = ggml_mul_mat(ctx0, w, cur);
|
|
||||||
|
|
||||||
for (const auto & lora : loras) {
|
|
||||||
struct llama_adapter_lora_weight * lw = lora.first->get_weight(w);
|
|
||||||
if (lw == nullptr) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
const float adapter_scale = lora.second;
|
|
||||||
const float scale = lw->get_scale(lora.first->alpha, adapter_scale);
|
|
||||||
|
|
||||||
struct ggml_tensor * ab_cur = ggml_mul_mat(
|
|
||||||
ctx0, lw->b,
|
|
||||||
ggml_mul_mat(ctx0, lw->a, cur)
|
|
||||||
);
|
|
||||||
|
|
||||||
ab_cur = ggml_scale(ctx0, ab_cur, scale);
|
|
||||||
res = ggml_add(ctx0, res, ab_cur);
|
|
||||||
}
|
|
||||||
|
|
||||||
return res;
|
|
||||||
}
|
|
||||||
|
|
||||||
ggml_tensor * llama_context::build_lora_mm_id(
|
|
||||||
ggml_context * ctx0,
|
|
||||||
ggml_tensor * w,
|
|
||||||
ggml_tensor * cur,
|
|
||||||
ggml_tensor * ids) {
|
|
||||||
struct ggml_tensor * res = ggml_mul_mat_id(ctx0, w, cur, ids);
|
|
||||||
for (const auto & lora : loras) {
|
|
||||||
struct llama_adapter_lora_weight * lw = lora.first->get_weight(w);
|
|
||||||
if (lw == nullptr) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
const float alpha = lora.first->alpha;
|
|
||||||
const float rank = (float) lw->b->ne[0];
|
|
||||||
const float scale = alpha ? lora.second * alpha / rank : lora.second;
|
|
||||||
|
|
||||||
struct ggml_tensor * ab_cur = ggml_mul_mat_id(
|
|
||||||
ctx0, lw->b,
|
|
||||||
ggml_mul_mat_id(ctx0, lw->a, cur, ids),
|
|
||||||
ids
|
|
||||||
);
|
|
||||||
|
|
||||||
ab_cur = ggml_scale(ctx0, ab_cur, scale);
|
|
||||||
res = ggml_add(ctx0, res, ab_cur);
|
|
||||||
}
|
|
||||||
|
|
||||||
return res;
|
|
||||||
}
|
|
||||||
|
|
||||||
void llama_context_unified::kv_self_update() {
|
void llama_context_unified::kv_self_update() {
|
||||||
auto & kv = kv_self;
|
auto & kv = kv_self;
|
||||||
|
|
||||||
@@ -2189,23 +2210,6 @@ ggml_tensor * llama_context_unified::build_soft_max_ext(
|
|||||||
return ggml_soft_max_ext(ctx0, kq, inp_KQ_mask_cnv, kq_scale, hparams.f_max_alibi_bias);
|
return ggml_soft_max_ext(ctx0, kq, inp_KQ_mask_cnv, kq_scale, hparams.f_max_alibi_bias);
|
||||||
}
|
}
|
||||||
|
|
||||||
ggml_tensor * llama_context_unified::get_rope_factors(int il) {
|
|
||||||
const auto & hparams = model.hparams;
|
|
||||||
|
|
||||||
// choose long/short freq factors based on the context size
|
|
||||||
const auto n_ctx_pre_seq = cparams.n_ctx / cparams.n_seq_max;
|
|
||||||
|
|
||||||
if (model.layers[il].rope_freqs != nullptr) {
|
|
||||||
return model.layers[il].rope_freqs;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (n_ctx_pre_seq > hparams.n_ctx_orig_yarn) {
|
|
||||||
return model.layers[il].rope_long;
|
|
||||||
}
|
|
||||||
|
|
||||||
return model.layers[il].rope_short;
|
|
||||||
}
|
|
||||||
|
|
||||||
ggml_tensor * llama_context_unified::build_inp_embd(
|
ggml_tensor * llama_context_unified::build_inp_embd(
|
||||||
ggml_context * ctx0,
|
ggml_context * ctx0,
|
||||||
ggml_tensor * tok_embd,
|
ggml_tensor * tok_embd,
|
||||||
@@ -2327,7 +2331,7 @@ void llama_context_unified::build_k_shift(
|
|||||||
const int64_t n_head_kv = hparams.n_head_kv(il);
|
const int64_t n_head_kv = hparams.n_head_kv(il);
|
||||||
const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
|
const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
|
||||||
|
|
||||||
struct ggml_tensor * rope_factors = get_rope_factors(il);
|
struct ggml_tensor * rope_factors = build_rope_factors(il);
|
||||||
|
|
||||||
struct ggml_tensor * k =
|
struct ggml_tensor * k =
|
||||||
ggml_view_3d(ctx0, kv_self.k_l[il],
|
ggml_view_3d(ctx0, kv_self.k_l[il],
|
||||||
|
|||||||
@@ -24,6 +24,7 @@ struct llama_context {
|
|||||||
const llama_cparams & get_cparams() const;
|
const llama_cparams & get_cparams() const;
|
||||||
|
|
||||||
virtual uint32_t n_ctx() const;
|
virtual uint32_t n_ctx() const;
|
||||||
|
virtual uint32_t n_ctx_per_seq() const;
|
||||||
virtual uint32_t n_batch() const;
|
virtual uint32_t n_batch() const;
|
||||||
virtual uint32_t n_ubatch() const;
|
virtual uint32_t n_ubatch() const;
|
||||||
virtual uint32_t n_seq_max() const = 0;
|
virtual uint32_t n_seq_max() const = 0;
|
||||||
@@ -126,6 +127,8 @@ struct llama_context {
|
|||||||
ggml_tensor * cur, // struct ggml_tensor * b
|
ggml_tensor * cur, // struct ggml_tensor * b
|
||||||
ggml_tensor * ids);
|
ggml_tensor * ids);
|
||||||
|
|
||||||
|
virtual ggml_tensor * build_rope_factors(int il);
|
||||||
|
|
||||||
// graph build API (context-specific)
|
// graph build API (context-specific)
|
||||||
|
|
||||||
virtual ggml_tensor * build_inp_embd(
|
virtual ggml_tensor * build_inp_embd(
|
||||||
@@ -182,8 +185,6 @@ struct llama_context {
|
|||||||
ggml_tensor * kq,
|
ggml_tensor * kq,
|
||||||
float kq_scale) = 0;
|
float kq_scale) = 0;
|
||||||
|
|
||||||
virtual ggml_tensor * get_rope_factors(int il) = 0;
|
|
||||||
|
|
||||||
virtual void build_k_shift(
|
virtual void build_k_shift(
|
||||||
ggml_context * ctx0,
|
ggml_context * ctx0,
|
||||||
ggml_cgraph * graph) = 0;
|
ggml_cgraph * graph) = 0;
|
||||||
@@ -342,7 +343,7 @@ class llama_context_unified : public llama_context {
|
|||||||
public:
|
public:
|
||||||
struct batch_manager;
|
struct batch_manager;
|
||||||
|
|
||||||
// TODO: tmp until llama-model starts implementing the graph build function
|
// TODO: tmp until llama_model starts implementing the graph build function
|
||||||
typedef std::function<ggml_cgraph *(llama_context &, const llama_ubatch &, bool worst_case)> build_graph_callback;
|
typedef std::function<ggml_cgraph *(llama_context &, const llama_ubatch &, bool worst_case)> build_graph_callback;
|
||||||
|
|
||||||
llama_context_unified(
|
llama_context_unified(
|
||||||
@@ -496,8 +497,6 @@ public:
|
|||||||
ggml_tensor * kq,
|
ggml_tensor * kq,
|
||||||
float kq_scale) override;
|
float kq_scale) override;
|
||||||
|
|
||||||
virtual ggml_tensor * get_rope_factors(int il) override;
|
|
||||||
|
|
||||||
virtual void build_k_shift(
|
virtual void build_k_shift(
|
||||||
ggml_context * ctx0,
|
ggml_context * ctx0,
|
||||||
ggml_cgraph * graph) override;
|
ggml_cgraph * graph) override;
|
||||||
|
|||||||
@@ -685,7 +685,7 @@ struct llm_build_context {
|
|||||||
// self-attention
|
// self-attention
|
||||||
{
|
{
|
||||||
// rope freq factors for llama3; may return nullptr for llama2 and other models
|
// rope freq factors for llama3; may return nullptr for llama2 and other models
|
||||||
struct ggml_tensor * rope_factors = lctx.get_rope_factors(il);
|
struct ggml_tensor * rope_factors = lctx.build_rope_factors(il);
|
||||||
|
|
||||||
// compute Q and K and RoPE them
|
// compute Q and K and RoPE them
|
||||||
struct ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
|
struct ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
|
||||||
@@ -857,7 +857,7 @@ struct llm_build_context {
|
|||||||
} else if (n_head > 0) {
|
} else if (n_head > 0) {
|
||||||
// self-attention
|
// self-attention
|
||||||
// rope freq factors for llama3; may return nullptr for llama2 and other models
|
// rope freq factors for llama3; may return nullptr for llama2 and other models
|
||||||
struct ggml_tensor * rope_factors = lctx.get_rope_factors(il);
|
struct ggml_tensor * rope_factors = lctx.build_rope_factors(il);
|
||||||
|
|
||||||
// compute Q and K and RoPE them
|
// compute Q and K and RoPE them
|
||||||
struct ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
|
struct ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
|
||||||
@@ -2999,7 +2999,7 @@ struct llm_build_context {
|
|||||||
// self-attention
|
// self-attention
|
||||||
{
|
{
|
||||||
// rope freq factors for 128k context
|
// rope freq factors for 128k context
|
||||||
struct ggml_tensor * rope_factors = lctx.get_rope_factors(il);
|
struct ggml_tensor * rope_factors = lctx.build_rope_factors(il);
|
||||||
|
|
||||||
struct ggml_tensor* attn_norm_output = build_norm(inpL,
|
struct ggml_tensor* attn_norm_output = build_norm(inpL,
|
||||||
model.layers[il].attn_norm,
|
model.layers[il].attn_norm,
|
||||||
@@ -3706,7 +3706,7 @@ struct llm_build_context {
|
|||||||
for (int il = 0; il < n_layer; ++il) {
|
for (int il = 0; il < n_layer; ++il) {
|
||||||
struct ggml_tensor * inpSA = inpL;
|
struct ggml_tensor * inpSA = inpL;
|
||||||
|
|
||||||
struct ggml_tensor * rope_factors = lctx.get_rope_factors(il);
|
struct ggml_tensor * rope_factors = lctx.build_rope_factors(il);
|
||||||
// norm
|
// norm
|
||||||
cur = build_norm(inpL,
|
cur = build_norm(inpL,
|
||||||
model.layers[il].attn_norm, NULL,
|
model.layers[il].attn_norm, NULL,
|
||||||
@@ -4480,7 +4480,7 @@ struct llm_build_context {
|
|||||||
// self-attention
|
// self-attention
|
||||||
{
|
{
|
||||||
// rope freq factors for 128k context
|
// rope freq factors for 128k context
|
||||||
struct ggml_tensor * rope_factors = lctx.get_rope_factors(il);
|
struct ggml_tensor * rope_factors = lctx.build_rope_factors(il);
|
||||||
|
|
||||||
// compute Q and K and RoPE them
|
// compute Q and K and RoPE them
|
||||||
struct ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
|
struct ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
|
||||||
@@ -5373,7 +5373,7 @@ struct llm_build_context {
|
|||||||
// self-attention
|
// self-attention
|
||||||
{
|
{
|
||||||
// rope freq factors for llama3; may return nullptr for llama2 and other models
|
// rope freq factors for llama3; may return nullptr for llama2 and other models
|
||||||
struct ggml_tensor * rope_factors = lctx.get_rope_factors(il);
|
struct ggml_tensor * rope_factors = lctx.build_rope_factors(il);
|
||||||
|
|
||||||
// compute Q and K and RoPE them
|
// compute Q and K and RoPE them
|
||||||
struct ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
|
struct ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
|
||||||
@@ -6572,7 +6572,7 @@ struct llm_build_context {
|
|||||||
// self-attention
|
// self-attention
|
||||||
{
|
{
|
||||||
// rope freq factors for llama3; may return nullptr for llama2 and other models
|
// rope freq factors for llama3; may return nullptr for llama2 and other models
|
||||||
struct ggml_tensor * rope_factors = lctx.get_rope_factors(il);
|
struct ggml_tensor * rope_factors = lctx.build_rope_factors(il);
|
||||||
|
|
||||||
// compute Q and K and RoPE them
|
// compute Q and K and RoPE them
|
||||||
struct ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
|
struct ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
|
||||||
|
|||||||
Reference in New Issue
Block a user