mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-11-06 09:46:50 +00:00
model: Add support for CogVLM model (#15002)
* Added GGUF mappings for CogVLM model * Add tensor mapping for CogVLM visual encoder * Add CogVLM to conversion script, no vision part yet * Added CogVLM vision model to conversion script * Add graph for CogVLM CLIP model * Add graph for CogVLM * Fixes for CogVLM. Now compiles. * Model now runs * Fixes for cogvlm graph * Account for graph context change after rebase * Changes for whitespace * Changes in convert script according to comments * Switch CogVLM LLM graph to merged QKV tensor * Use rope_type variable instead of direct definition * Change CogVLM CLIP encoder to use SWIGLU * Switch CogVLM CLIP to use merged QKV * Apply rebase edits and remove ggml_cont call that is now unnecessary * clean up --------- Co-authored-by: Xuan Son Nguyen <son@huggingface.co>
This commit is contained in:
@@ -214,6 +214,8 @@ struct clip_layer {
|
||||
ggml_tensor * q_b = nullptr;
|
||||
ggml_tensor * v_w = nullptr;
|
||||
ggml_tensor * v_b = nullptr;
|
||||
ggml_tensor * qkv_w = nullptr;
|
||||
ggml_tensor * qkv_b = nullptr;
|
||||
|
||||
ggml_tensor * o_w = nullptr;
|
||||
ggml_tensor * o_b = nullptr;
|
||||
@@ -286,8 +288,6 @@ struct clip_model {
|
||||
// GLMV-Edge projection
|
||||
ggml_tensor * mm_model_adapter_conv_w = nullptr;
|
||||
ggml_tensor * mm_model_adapter_conv_b = nullptr;
|
||||
ggml_tensor * mm_glm_tok_boi = nullptr;
|
||||
ggml_tensor * mm_glm_tok_eoi = nullptr;
|
||||
|
||||
// MobileVLM projection
|
||||
ggml_tensor * mm_model_mlp_1_w = nullptr;
|
||||
@@ -359,6 +359,15 @@ struct clip_model {
|
||||
ggml_tensor * mm_norm_pre_w = nullptr;
|
||||
ggml_tensor * mm_norm_mid_w = nullptr;
|
||||
|
||||
// cogvlm
|
||||
ggml_tensor * mm_post_fc_norm_w = nullptr;
|
||||
ggml_tensor * mm_post_fc_norm_b = nullptr;
|
||||
ggml_tensor * mm_h_to_4h_w = nullptr;
|
||||
ggml_tensor * mm_gate_w = nullptr;
|
||||
ggml_tensor * mm_4h_to_h_w = nullptr;
|
||||
ggml_tensor * mm_boi = nullptr;
|
||||
ggml_tensor * mm_eoi = nullptr;
|
||||
|
||||
bool audio_has_avgpool() const {
|
||||
return proj_type == PROJECTOR_TYPE_QWEN2A
|
||||
|| proj_type == PROJECTOR_TYPE_VOXTRAL;
|
||||
@@ -1494,8 +1503,8 @@ struct clip_graph {
|
||||
// note: these embeddings are not present in text model, hence we cannot process them as text tokens
|
||||
// see: https://huggingface.co/THUDM/glm-edge-v-2b/blob/main/siglip.py#L53
|
||||
{
|
||||
embeddings = ggml_concat(ctx0, model.mm_glm_tok_boi, embeddings, 1); // BOI
|
||||
embeddings = ggml_concat(ctx0, embeddings, model.mm_glm_tok_eoi, 1); // EOI
|
||||
embeddings = ggml_concat(ctx0, model.mm_boi, embeddings, 1); // BOI
|
||||
embeddings = ggml_concat(ctx0, embeddings, model.mm_eoi, 1); // EOI
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1613,6 +1622,104 @@ struct clip_graph {
|
||||
return gf;
|
||||
}
|
||||
|
||||
// cogvlm vision encoder
|
||||
ggml_cgraph * build_cogvlm() {
|
||||
GGML_ASSERT(model.class_embedding != nullptr);
|
||||
GGML_ASSERT(model.position_embeddings != nullptr);
|
||||
|
||||
const int n_pos = n_patches + 1; // +1 for [CLS]
|
||||
|
||||
// build input and concatenate class embedding
|
||||
ggml_tensor * inp = build_inp();
|
||||
inp = ggml_concat(ctx0, inp, model.class_embedding, 1);
|
||||
|
||||
inp = ggml_add(ctx0, inp, model.position_embeddings);
|
||||
cb(inp, "inp_pos", -1);
|
||||
|
||||
ggml_tensor * inpL = inp;
|
||||
|
||||
for (int il = 0; il < n_layer; il++) {
|
||||
auto & layer = model.layers[il];
|
||||
ggml_tensor * cur = inpL;
|
||||
|
||||
cur = ggml_mul_mat(ctx0, layer.qkv_w, cur);
|
||||
|
||||
cur = ggml_add(ctx0, cur, layer.qkv_b);
|
||||
|
||||
ggml_tensor * Qcur = ggml_view_3d(ctx0, cur, d_head, n_head, n_pos, d_head*sizeof(float),
|
||||
cur->nb[1], 0);
|
||||
ggml_tensor * Kcur = ggml_view_3d(ctx0, cur, d_head, n_head, n_pos, d_head*sizeof(float),
|
||||
cur->nb[1], n_embd * sizeof(float));
|
||||
ggml_tensor * Vcur = ggml_view_3d(ctx0, cur, d_head, n_head, n_pos, d_head*sizeof(float),
|
||||
cur->nb[1], 2 * n_embd * sizeof(float));
|
||||
|
||||
cb(Qcur, "Qcur", il);
|
||||
cb(Kcur, "Kcur", il);
|
||||
cb(Vcur, "Vcur", il);
|
||||
|
||||
cur = build_attn(layer.o_w, layer.o_b,
|
||||
Qcur, Kcur, Vcur, nullptr, kq_scale, il);
|
||||
cb(cur, "attn_out", il);
|
||||
|
||||
cur = build_norm(cur, layer.ln_1_w, layer.ln_1_b, NORM_TYPE_NORMAL, eps, il);
|
||||
cb(cur, "attn_post_norm", il);
|
||||
|
||||
cur = ggml_add(ctx0, cur, inpL);
|
||||
inpL = cur;
|
||||
|
||||
cur = build_ffn(cur,
|
||||
layer.ff_up_w, layer.ff_up_b,
|
||||
layer.ff_gate_w, layer.ff_gate_b,
|
||||
layer.ff_down_w, layer.ff_down_b,
|
||||
hparams.ffn_op, il);
|
||||
|
||||
cb(cur, "ffn_out", il);
|
||||
|
||||
cur = build_norm(cur, layer.ln_2_w, layer.ln_2_b, NORM_TYPE_NORMAL, eps, il);
|
||||
cb(cur, "ffn_post_norm", il);
|
||||
|
||||
cur = ggml_add(ctx0, cur, inpL);
|
||||
cb(cur, "layer_out", il);
|
||||
inpL = cur;
|
||||
|
||||
}
|
||||
|
||||
// remove CLS token (like build_llama4 does)
|
||||
ggml_tensor * cur = ggml_view_2d(ctx0, inpL,
|
||||
n_embd, n_patches,
|
||||
ggml_row_size(inpL->type, n_embd), 0);
|
||||
|
||||
// Multiply with mm_model_proj
|
||||
cur = ggml_mul_mat(ctx0, model.mm_model_proj, cur);
|
||||
|
||||
// Apply layernorm, weight, bias
|
||||
cur = build_norm(cur, model.mm_post_fc_norm_w, model.mm_post_fc_norm_b, NORM_TYPE_NORMAL, 1e-5, -1);
|
||||
|
||||
// Apply GELU
|
||||
cur = ggml_gelu_inplace(ctx0, cur);
|
||||
|
||||
// Branch 1: multiply with mm_h_to_4h_w
|
||||
ggml_tensor * h_to_4h = ggml_mul_mat(ctx0, model.mm_h_to_4h_w, cur);
|
||||
|
||||
// Branch 2: multiply with mm_gate_w
|
||||
ggml_tensor * gate = ggml_mul_mat(ctx0, model.mm_gate_w, cur);
|
||||
|
||||
// Apply silu
|
||||
gate = ggml_swiglu_split(ctx0, gate, h_to_4h);
|
||||
|
||||
// Apply mm_4h_to_h_w
|
||||
cur = ggml_mul_mat(ctx0, model.mm_4h_to_h_w, gate);
|
||||
|
||||
// Concatenate with boi and eoi
|
||||
cur = ggml_concat(ctx0, model.mm_boi, cur, 1);
|
||||
cur = ggml_concat(ctx0, cur, model.mm_eoi, 1);
|
||||
|
||||
// build the graph
|
||||
ggml_build_forward_expand(gf, cur);
|
||||
|
||||
return gf;
|
||||
}
|
||||
|
||||
private:
|
||||
//
|
||||
// utility functions
|
||||
@@ -2126,6 +2233,10 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
|
||||
{
|
||||
res = graph.build_kimivl();
|
||||
} break;
|
||||
case PROJECTOR_TYPE_COGVLM:
|
||||
{
|
||||
res = graph.build_cogvlm();
|
||||
} break;
|
||||
default:
|
||||
{
|
||||
res = graph.build_llava();
|
||||
@@ -2532,10 +2643,11 @@ struct clip_model_loader {
|
||||
model.layers.resize(hparams.n_layer);
|
||||
for (int il = 0; il < hparams.n_layer; ++il) {
|
||||
auto & layer = model.layers[il];
|
||||
layer.k_w = get_tensor(string_format(TN_ATTN_K, prefix, il, "weight"));
|
||||
layer.q_w = get_tensor(string_format(TN_ATTN_Q, prefix, il, "weight"));
|
||||
layer.v_w = get_tensor(string_format(TN_ATTN_V, prefix, il, "weight"));
|
||||
layer.k_w = get_tensor(string_format(TN_ATTN_K, prefix, il, "weight"), false);
|
||||
layer.q_w = get_tensor(string_format(TN_ATTN_Q, prefix, il, "weight"), false);
|
||||
layer.v_w = get_tensor(string_format(TN_ATTN_V, prefix, il, "weight"), false);
|
||||
layer.o_w = get_tensor(string_format(TN_ATTN_OUTPUT, prefix, il, "weight"));
|
||||
layer.qkv_w = get_tensor(string_format(TN_ATTN_QKV, prefix, il, "weight"), false);
|
||||
layer.k_norm = get_tensor(string_format(TN_ATTN_K_NORM, prefix, il, "weight"), false);
|
||||
layer.q_norm = get_tensor(string_format(TN_ATTN_Q_NORM, prefix, il, "weight"), false);
|
||||
layer.ln_1_w = get_tensor(string_format(TN_LN_1, prefix, il, "weight"), false);
|
||||
@@ -2547,6 +2659,7 @@ struct clip_model_loader {
|
||||
layer.q_b = get_tensor(string_format(TN_ATTN_Q, prefix, il, "bias"), false);
|
||||
layer.v_b = get_tensor(string_format(TN_ATTN_V, prefix, il, "bias"), false);
|
||||
layer.o_b = get_tensor(string_format(TN_ATTN_OUTPUT, prefix, il, "bias"), false);
|
||||
layer.qkv_b = get_tensor(string_format(TN_ATTN_QKV, prefix, il, "bias"), false);
|
||||
layer.ln_1_b = get_tensor(string_format(TN_LN_1, prefix, il, "bias"), false);
|
||||
layer.ln_2_b = get_tensor(string_format(TN_LN_2, prefix, il, "bias"), false);
|
||||
|
||||
@@ -2682,8 +2795,8 @@ struct clip_model_loader {
|
||||
model.mm_model_mlp_1_w = get_tensor(string_format(TN_GLM_ADAPTER_D_H_2_4H, "weight"));
|
||||
model.mm_model_mlp_2_w = get_tensor(string_format(TN_GLM_ADAPTER_GATE, "weight"));
|
||||
model.mm_model_mlp_3_w = get_tensor(string_format(TN_GLM_ADAPTER_D_4H_2_H, "weight"));
|
||||
model.mm_glm_tok_boi = get_tensor(string_format(TN_TOK_GLM_BOI, "weight"));
|
||||
model.mm_glm_tok_eoi = get_tensor(string_format(TN_TOK_GLM_EOI, "weight"));
|
||||
model.mm_boi = get_tensor(string_format(TN_TOK_GLM_BOI, "weight"));
|
||||
model.mm_eoi = get_tensor(string_format(TN_TOK_GLM_EOI, "weight"));
|
||||
} break;
|
||||
case PROJECTOR_TYPE_QWEN2VL:
|
||||
case PROJECTOR_TYPE_QWEN25VL:
|
||||
@@ -2777,6 +2890,17 @@ struct clip_model_loader {
|
||||
model.mm_model_mlp_1_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 1, "weight"));
|
||||
model.mm_model_mlp_2_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 2, "weight"));
|
||||
} break;
|
||||
case PROJECTOR_TYPE_COGVLM:
|
||||
{
|
||||
model.mm_model_proj = get_tensor(TN_MM_PROJECTOR);
|
||||
model.mm_post_fc_norm_w = get_tensor(string_format(TN_MM_POST_FC_NORM, "weight"));
|
||||
model.mm_post_fc_norm_b = get_tensor(string_format(TN_MM_POST_FC_NORM, "bias"));
|
||||
model.mm_h_to_4h_w = get_tensor(string_format(TN_MM_H_TO_4H, "weight"));
|
||||
model.mm_gate_w = get_tensor(string_format(TN_MM_GATE, "weight"));
|
||||
model.mm_4h_to_h_w = get_tensor(string_format(TN_MM_4H_TO_H, "weight"));
|
||||
model.mm_boi = get_tensor(TN_TOK_BOI);
|
||||
model.mm_eoi = get_tensor(TN_TOK_EOI);
|
||||
} break;
|
||||
default:
|
||||
GGML_ASSERT(false && "unknown projector type");
|
||||
}
|
||||
@@ -3825,7 +3949,7 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im
|
||||
case PROJECTOR_TYPE_GLM_EDGE:
|
||||
{
|
||||
n_patches /= 4;
|
||||
if (ctx->model.mm_glm_tok_boi) {
|
||||
if (ctx->model.mm_boi) {
|
||||
n_patches += 2; // for BOI and EOI token embeddings
|
||||
}
|
||||
} break;
|
||||
@@ -3915,6 +4039,10 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im
|
||||
n_patches /= 2;
|
||||
}
|
||||
} break;
|
||||
case PROJECTOR_TYPE_COGVLM:
|
||||
{
|
||||
n_patches += 2; // for BOI and EOI token embeddings
|
||||
} break;
|
||||
default:
|
||||
GGML_ABORT("unsupported projector type");
|
||||
}
|
||||
@@ -4323,6 +4451,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
|
||||
case PROJECTOR_TYPE_ULTRAVOX:
|
||||
case PROJECTOR_TYPE_LFM2:
|
||||
case PROJECTOR_TYPE_VOXTRAL:
|
||||
case PROJECTOR_TYPE_COGVLM:
|
||||
{
|
||||
// do nothing
|
||||
} break;
|
||||
@@ -4427,6 +4556,8 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) {
|
||||
case PROJECTOR_TYPE_LFM2:
|
||||
case PROJECTOR_TYPE_KIMIVL:
|
||||
return ctx->model.mm_2_w->ne[1];
|
||||
case PROJECTOR_TYPE_COGVLM:
|
||||
return ctx->model.mm_4h_to_h_w->ne[1];
|
||||
default:
|
||||
GGML_ABORT("Unknown projector type");
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user