model: Add support for CogVLM model (#15002)

* Added GGUF mappings for CogVLM model

* Add tensor mapping for CogVLM visual encoder

* Add CogVLM to conversion script, no vision part yet

* Added CogVLM vision model to conversion script

* Add graph for CogVLM CLIP model

* Add graph for CogVLM

* Fixes for CogVLM. Now compiles.

* Model now runs

* Fixes for cogvlm graph

* Account for graph context change after rebase

* Changes for whitespace

* Changes in convert script according to comments

* Switch CogVLM LLM graph to merged QKV tensor

* Use rope_type variable instead of direct definition

* Change CogVLM CLIP encoder to use SWIGLU

* Switch CogVLM CLIP to use merged QKV

* Apply rebase edits and remove ggml_cont call that is now unnecessary

* clean up

---------

Co-authored-by: Xuan Son Nguyen <son@huggingface.co>
This commit is contained in:
Tianyue-Zhao
2025-10-30 07:18:50 -04:00
committed by GitHub
parent 229bf68628
commit bacddc049a
9 changed files with 501 additions and 26 deletions

View File

@@ -214,6 +214,8 @@ struct clip_layer {
ggml_tensor * q_b = nullptr;
ggml_tensor * v_w = nullptr;
ggml_tensor * v_b = nullptr;
ggml_tensor * qkv_w = nullptr;
ggml_tensor * qkv_b = nullptr;
ggml_tensor * o_w = nullptr;
ggml_tensor * o_b = nullptr;
@@ -286,8 +288,6 @@ struct clip_model {
// GLMV-Edge projection
ggml_tensor * mm_model_adapter_conv_w = nullptr;
ggml_tensor * mm_model_adapter_conv_b = nullptr;
ggml_tensor * mm_glm_tok_boi = nullptr;
ggml_tensor * mm_glm_tok_eoi = nullptr;
// MobileVLM projection
ggml_tensor * mm_model_mlp_1_w = nullptr;
@@ -359,6 +359,15 @@ struct clip_model {
ggml_tensor * mm_norm_pre_w = nullptr;
ggml_tensor * mm_norm_mid_w = nullptr;
// cogvlm
ggml_tensor * mm_post_fc_norm_w = nullptr;
ggml_tensor * mm_post_fc_norm_b = nullptr;
ggml_tensor * mm_h_to_4h_w = nullptr;
ggml_tensor * mm_gate_w = nullptr;
ggml_tensor * mm_4h_to_h_w = nullptr;
ggml_tensor * mm_boi = nullptr;
ggml_tensor * mm_eoi = nullptr;
bool audio_has_avgpool() const {
return proj_type == PROJECTOR_TYPE_QWEN2A
|| proj_type == PROJECTOR_TYPE_VOXTRAL;
@@ -1494,8 +1503,8 @@ struct clip_graph {
// note: these embeddings are not present in text model, hence we cannot process them as text tokens
// see: https://huggingface.co/THUDM/glm-edge-v-2b/blob/main/siglip.py#L53
{
embeddings = ggml_concat(ctx0, model.mm_glm_tok_boi, embeddings, 1); // BOI
embeddings = ggml_concat(ctx0, embeddings, model.mm_glm_tok_eoi, 1); // EOI
embeddings = ggml_concat(ctx0, model.mm_boi, embeddings, 1); // BOI
embeddings = ggml_concat(ctx0, embeddings, model.mm_eoi, 1); // EOI
}
}
@@ -1613,6 +1622,104 @@ struct clip_graph {
return gf;
}
// cogvlm vision encoder
ggml_cgraph * build_cogvlm() {
GGML_ASSERT(model.class_embedding != nullptr);
GGML_ASSERT(model.position_embeddings != nullptr);
const int n_pos = n_patches + 1; // +1 for [CLS]
// build input and concatenate class embedding
ggml_tensor * inp = build_inp();
inp = ggml_concat(ctx0, inp, model.class_embedding, 1);
inp = ggml_add(ctx0, inp, model.position_embeddings);
cb(inp, "inp_pos", -1);
ggml_tensor * inpL = inp;
for (int il = 0; il < n_layer; il++) {
auto & layer = model.layers[il];
ggml_tensor * cur = inpL;
cur = ggml_mul_mat(ctx0, layer.qkv_w, cur);
cur = ggml_add(ctx0, cur, layer.qkv_b);
ggml_tensor * Qcur = ggml_view_3d(ctx0, cur, d_head, n_head, n_pos, d_head*sizeof(float),
cur->nb[1], 0);
ggml_tensor * Kcur = ggml_view_3d(ctx0, cur, d_head, n_head, n_pos, d_head*sizeof(float),
cur->nb[1], n_embd * sizeof(float));
ggml_tensor * Vcur = ggml_view_3d(ctx0, cur, d_head, n_head, n_pos, d_head*sizeof(float),
cur->nb[1], 2 * n_embd * sizeof(float));
cb(Qcur, "Qcur", il);
cb(Kcur, "Kcur", il);
cb(Vcur, "Vcur", il);
cur = build_attn(layer.o_w, layer.o_b,
Qcur, Kcur, Vcur, nullptr, kq_scale, il);
cb(cur, "attn_out", il);
cur = build_norm(cur, layer.ln_1_w, layer.ln_1_b, NORM_TYPE_NORMAL, eps, il);
cb(cur, "attn_post_norm", il);
cur = ggml_add(ctx0, cur, inpL);
inpL = cur;
cur = build_ffn(cur,
layer.ff_up_w, layer.ff_up_b,
layer.ff_gate_w, layer.ff_gate_b,
layer.ff_down_w, layer.ff_down_b,
hparams.ffn_op, il);
cb(cur, "ffn_out", il);
cur = build_norm(cur, layer.ln_2_w, layer.ln_2_b, NORM_TYPE_NORMAL, eps, il);
cb(cur, "ffn_post_norm", il);
cur = ggml_add(ctx0, cur, inpL);
cb(cur, "layer_out", il);
inpL = cur;
}
// remove CLS token (like build_llama4 does)
ggml_tensor * cur = ggml_view_2d(ctx0, inpL,
n_embd, n_patches,
ggml_row_size(inpL->type, n_embd), 0);
// Multiply with mm_model_proj
cur = ggml_mul_mat(ctx0, model.mm_model_proj, cur);
// Apply layernorm, weight, bias
cur = build_norm(cur, model.mm_post_fc_norm_w, model.mm_post_fc_norm_b, NORM_TYPE_NORMAL, 1e-5, -1);
// Apply GELU
cur = ggml_gelu_inplace(ctx0, cur);
// Branch 1: multiply with mm_h_to_4h_w
ggml_tensor * h_to_4h = ggml_mul_mat(ctx0, model.mm_h_to_4h_w, cur);
// Branch 2: multiply with mm_gate_w
ggml_tensor * gate = ggml_mul_mat(ctx0, model.mm_gate_w, cur);
// Apply silu
gate = ggml_swiglu_split(ctx0, gate, h_to_4h);
// Apply mm_4h_to_h_w
cur = ggml_mul_mat(ctx0, model.mm_4h_to_h_w, gate);
// Concatenate with boi and eoi
cur = ggml_concat(ctx0, model.mm_boi, cur, 1);
cur = ggml_concat(ctx0, cur, model.mm_eoi, 1);
// build the graph
ggml_build_forward_expand(gf, cur);
return gf;
}
private:
//
// utility functions
@@ -2126,6 +2233,10 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
{
res = graph.build_kimivl();
} break;
case PROJECTOR_TYPE_COGVLM:
{
res = graph.build_cogvlm();
} break;
default:
{
res = graph.build_llava();
@@ -2532,10 +2643,11 @@ struct clip_model_loader {
model.layers.resize(hparams.n_layer);
for (int il = 0; il < hparams.n_layer; ++il) {
auto & layer = model.layers[il];
layer.k_w = get_tensor(string_format(TN_ATTN_K, prefix, il, "weight"));
layer.q_w = get_tensor(string_format(TN_ATTN_Q, prefix, il, "weight"));
layer.v_w = get_tensor(string_format(TN_ATTN_V, prefix, il, "weight"));
layer.k_w = get_tensor(string_format(TN_ATTN_K, prefix, il, "weight"), false);
layer.q_w = get_tensor(string_format(TN_ATTN_Q, prefix, il, "weight"), false);
layer.v_w = get_tensor(string_format(TN_ATTN_V, prefix, il, "weight"), false);
layer.o_w = get_tensor(string_format(TN_ATTN_OUTPUT, prefix, il, "weight"));
layer.qkv_w = get_tensor(string_format(TN_ATTN_QKV, prefix, il, "weight"), false);
layer.k_norm = get_tensor(string_format(TN_ATTN_K_NORM, prefix, il, "weight"), false);
layer.q_norm = get_tensor(string_format(TN_ATTN_Q_NORM, prefix, il, "weight"), false);
layer.ln_1_w = get_tensor(string_format(TN_LN_1, prefix, il, "weight"), false);
@@ -2547,6 +2659,7 @@ struct clip_model_loader {
layer.q_b = get_tensor(string_format(TN_ATTN_Q, prefix, il, "bias"), false);
layer.v_b = get_tensor(string_format(TN_ATTN_V, prefix, il, "bias"), false);
layer.o_b = get_tensor(string_format(TN_ATTN_OUTPUT, prefix, il, "bias"), false);
layer.qkv_b = get_tensor(string_format(TN_ATTN_QKV, prefix, il, "bias"), false);
layer.ln_1_b = get_tensor(string_format(TN_LN_1, prefix, il, "bias"), false);
layer.ln_2_b = get_tensor(string_format(TN_LN_2, prefix, il, "bias"), false);
@@ -2682,8 +2795,8 @@ struct clip_model_loader {
model.mm_model_mlp_1_w = get_tensor(string_format(TN_GLM_ADAPTER_D_H_2_4H, "weight"));
model.mm_model_mlp_2_w = get_tensor(string_format(TN_GLM_ADAPTER_GATE, "weight"));
model.mm_model_mlp_3_w = get_tensor(string_format(TN_GLM_ADAPTER_D_4H_2_H, "weight"));
model.mm_glm_tok_boi = get_tensor(string_format(TN_TOK_GLM_BOI, "weight"));
model.mm_glm_tok_eoi = get_tensor(string_format(TN_TOK_GLM_EOI, "weight"));
model.mm_boi = get_tensor(string_format(TN_TOK_GLM_BOI, "weight"));
model.mm_eoi = get_tensor(string_format(TN_TOK_GLM_EOI, "weight"));
} break;
case PROJECTOR_TYPE_QWEN2VL:
case PROJECTOR_TYPE_QWEN25VL:
@@ -2777,6 +2890,17 @@ struct clip_model_loader {
model.mm_model_mlp_1_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 1, "weight"));
model.mm_model_mlp_2_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 2, "weight"));
} break;
case PROJECTOR_TYPE_COGVLM:
{
model.mm_model_proj = get_tensor(TN_MM_PROJECTOR);
model.mm_post_fc_norm_w = get_tensor(string_format(TN_MM_POST_FC_NORM, "weight"));
model.mm_post_fc_norm_b = get_tensor(string_format(TN_MM_POST_FC_NORM, "bias"));
model.mm_h_to_4h_w = get_tensor(string_format(TN_MM_H_TO_4H, "weight"));
model.mm_gate_w = get_tensor(string_format(TN_MM_GATE, "weight"));
model.mm_4h_to_h_w = get_tensor(string_format(TN_MM_4H_TO_H, "weight"));
model.mm_boi = get_tensor(TN_TOK_BOI);
model.mm_eoi = get_tensor(TN_TOK_EOI);
} break;
default:
GGML_ASSERT(false && "unknown projector type");
}
@@ -3825,7 +3949,7 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im
case PROJECTOR_TYPE_GLM_EDGE:
{
n_patches /= 4;
if (ctx->model.mm_glm_tok_boi) {
if (ctx->model.mm_boi) {
n_patches += 2; // for BOI and EOI token embeddings
}
} break;
@@ -3915,6 +4039,10 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im
n_patches /= 2;
}
} break;
case PROJECTOR_TYPE_COGVLM:
{
n_patches += 2; // for BOI and EOI token embeddings
} break;
default:
GGML_ABORT("unsupported projector type");
}
@@ -4323,6 +4451,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
case PROJECTOR_TYPE_ULTRAVOX:
case PROJECTOR_TYPE_LFM2:
case PROJECTOR_TYPE_VOXTRAL:
case PROJECTOR_TYPE_COGVLM:
{
// do nothing
} break;
@@ -4427,6 +4556,8 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) {
case PROJECTOR_TYPE_LFM2:
case PROJECTOR_TYPE_KIMIVL:
return ctx->model.mm_2_w->ne[1];
case PROJECTOR_TYPE_COGVLM:
return ctx->model.mm_4h_to_h_w->ne[1];
default:
GGML_ABORT("Unknown projector type");
}