optimize deepstack structure

This commit is contained in:
JJJYmmm
2025-10-28 21:58:46 +08:00
parent 0443a098f3
commit 3271877207
5 changed files with 73 additions and 49 deletions

View File

@@ -4046,7 +4046,9 @@ class Qwen3VLVisionModel(MmprojModel):
self.hparams_vision["num_attention_heads"] = self.hparams_vision.get("num_heads") self.hparams_vision["num_attention_heads"] = self.hparams_vision.get("num_heads")
self.hparams_vision["num_hidden_layers"] = self.hparams_vision.get("depth") self.hparams_vision["num_hidden_layers"] = self.hparams_vision.get("depth")
self.deepstack_layers: list[int] = list(self.hparams_vision.get("deepstack_visual_indexes", [])) self.is_deepstack_layers = [False] * int(self.hparams_vision["num_hidden_layers"] or 0)
for idx in self.hparams_vision.get("deepstack_visual_indexes", []):
self.is_deepstack_layers[idx] = True
def set_gguf_parameters(self): def set_gguf_parameters(self):
super().set_gguf_parameters() super().set_gguf_parameters()
@@ -4062,10 +4064,11 @@ class Qwen3VLVisionModel(MmprojModel):
rms_norm_eps = self.global_config.get("text_config", {}).get("rms_norm_eps", 1e-6) rms_norm_eps = self.global_config.get("text_config", {}).get("rms_norm_eps", 1e-6)
self.gguf_writer.add_vision_attention_layernorm_eps(rms_norm_eps) self.gguf_writer.add_vision_attention_layernorm_eps(rms_norm_eps)
if self.deepstack_layers: if self.is_deepstack_layers:
self.gguf_writer.add_vision_deepstack_layers(self.deepstack_layers) self.gguf_writer.add_vision_is_deepstack_layers(self.is_deepstack_layers)
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
assert self.hparams_vision is not None
# Skip text model tensors - they go in the text model file # Skip text model tensors - they go in the text model file
if name.startswith("model.language_model.") or name.startswith("lm_head."): if name.startswith("model.language_model.") or name.startswith("lm_head."):
return [] return []
@@ -4075,7 +4078,8 @@ class Qwen3VLVisionModel(MmprojModel):
if name.startswith("visual.deepstack_merger_list."): if name.startswith("visual.deepstack_merger_list."):
prefix, rest = name.split(".", maxsplit=3)[2:] prefix, rest = name.split(".", maxsplit=3)[2:]
idx = int(prefix) # prefix is the layer index, convert to absolute clip layer index!
idx = self.hparams_vision.get("deepstack_visual_indexes", [])[int(prefix)]
target = rest target = rest
tensor_type: gguf.MODEL_TENSOR tensor_type: gguf.MODEL_TENSOR

View File

@@ -278,7 +278,7 @@ class Keys:
USE_GELU = "clip.use_gelu" USE_GELU = "clip.use_gelu"
USE_SILU = "clip.use_silu" USE_SILU = "clip.use_silu"
N_WA_PATTERN = "clip.vision.n_wa_pattern" # used by qwen2.5vl N_WA_PATTERN = "clip.vision.n_wa_pattern" # used by qwen2.5vl
DEEPSTACK_LAYERS = "clip.vision.deepstack_layers" IS_DEEPSTACK_LAYERS = "clip.vision.is_deepstack_layers"
class Attention: class Attention:
HEAD_COUNT = "clip.vision.attention.head_count" HEAD_COUNT = "clip.vision.attention.head_count"

View File

@@ -1074,8 +1074,8 @@ class GGUFWriter:
def add_vision_n_wa_pattern(self, value: int) -> None: def add_vision_n_wa_pattern(self, value: int) -> None:
self.add_uint32(Keys.ClipVision.N_WA_PATTERN, value) self.add_uint32(Keys.ClipVision.N_WA_PATTERN, value)
def add_vision_deepstack_layers(self, layers: Sequence[int]) -> None: def add_vision_is_deepstack_layers(self, layers: Sequence[bool]) -> None:
self.add_array(Keys.ClipVision.DEEPSTACK_LAYERS, layers) self.add_array(Keys.ClipVision.IS_DEEPSTACK_LAYERS, layers)
# audio models # audio models

View File

@@ -39,7 +39,7 @@
#define KEY_FEATURE_LAYER "clip.vision.feature_layer" #define KEY_FEATURE_LAYER "clip.vision.feature_layer"
#define KEY_PROJ_SCALE_FACTOR "clip.vision.projector.scale_factor" #define KEY_PROJ_SCALE_FACTOR "clip.vision.projector.scale_factor"
#define KEY_SPATIAL_MERGE_SIZE "clip.vision.spatial_merge_size" #define KEY_SPATIAL_MERGE_SIZE "clip.vision.spatial_merge_size"
#define KEY_DEEPSTACK_LAYERS "clip.vision.deepstack_layers" #define KEY_IS_DEEPSTACK_LAYERS "clip.vision.is_deepstack_layers"
#define KEY_MM_PATCH_MERGE_TYPE "clip.vision.mm_patch_merge_type" #define KEY_MM_PATCH_MERGE_TYPE "clip.vision.mm_patch_merge_type"
#define KEY_IMAGE_GRID_PINPOINTS "clip.vision.image_grid_pinpoints" #define KEY_IMAGE_GRID_PINPOINTS "clip.vision.image_grid_pinpoints"
@@ -94,6 +94,9 @@
#define TN_TOK_IMG_BREAK "v.token_embd.img_break" // pixtral #define TN_TOK_IMG_BREAK "v.token_embd.img_break" // pixtral
#define TN_TOK_GLM_BOI "adapter.boi" // glm-edge (these embeddings are not in text model) #define TN_TOK_GLM_BOI "adapter.boi" // glm-edge (these embeddings are not in text model)
#define TN_TOK_GLM_EOI "adapter.eoi" // glm-edge (these embeddings are not in text model) #define TN_TOK_GLM_EOI "adapter.eoi" // glm-edge (these embeddings are not in text model)
#define TN_DEEPSTACK_NORM "v.deepstack.%d.norm.%s" // qwen3vl deepstack
#define TN_DEEPSTACK_FC1 "v.deepstack.%d.fc1.%s" // qwen3vl deepstack
#define TN_DEEPSTACK_FC2 "v.deepstack.%d.fc2.%s" // qwen3vl deepstack
// mimicpmv // mimicpmv
#define TN_MINICPMV_POS_EMBD_K "resampler.pos_embed_k" #define TN_MINICPMV_POS_EMBD_K "resampler.pos_embed_k"

View File

@@ -196,7 +196,7 @@ struct clip_hparams {
int32_t n_wa_pattern = 0; int32_t n_wa_pattern = 0;
int32_t spatial_merge_size = 0; int32_t spatial_merge_size = 0;
std::vector<int32_t> deepstack_layers; // qwen3vl deepstack layers std::vector<bool> is_deepstack_layers; // qwen3vl: whether the layer is a deepstack layer
// audio // audio
int32_t n_mel_bins = 0; // whisper preprocessor int32_t n_mel_bins = 0; // whisper preprocessor
@@ -241,6 +241,14 @@ struct clip_layer {
// layer scale (no bias) // layer scale (no bias)
ggml_tensor * ls_1_w = nullptr; ggml_tensor * ls_1_w = nullptr;
ggml_tensor * ls_2_w = nullptr; ggml_tensor * ls_2_w = nullptr;
// qwen3vl deepstack merger
ggml_tensor * deepstack_norm_w = nullptr;
ggml_tensor * deepstack_norm_b = nullptr;
ggml_tensor * deepstack_fc1_w = nullptr;
ggml_tensor * deepstack_fc1_b = nullptr;
ggml_tensor * deepstack_fc2_w = nullptr;
ggml_tensor * deepstack_fc2_b = nullptr;
}; };
struct clip_model { struct clip_model {
@@ -361,17 +369,6 @@ struct clip_model {
ggml_tensor * mm_norm_pre_w = nullptr; ggml_tensor * mm_norm_pre_w = nullptr;
ggml_tensor * mm_norm_mid_w = nullptr; ggml_tensor * mm_norm_mid_w = nullptr;
// qwen3vl deepstack
struct deepstack_merger {
ggml_tensor * norm_w = nullptr;
ggml_tensor * norm_b = nullptr;
ggml_tensor * fc1_w = nullptr;
ggml_tensor * fc1_b = nullptr;
ggml_tensor * fc2_w = nullptr;
ggml_tensor * fc2_b = nullptr;
};
std::vector<deepstack_merger> deepstack_mergers;
bool audio_has_avgpool() const { bool audio_has_avgpool() const {
return proj_type == PROJECTOR_TYPE_QWEN2A return proj_type == PROJECTOR_TYPE_QWEN2A
|| proj_type == PROJECTOR_TYPE_VOXTRAL; || proj_type == PROJECTOR_TYPE_VOXTRAL;
@@ -849,7 +846,6 @@ struct clip_graph {
GGML_ASSERT(model.patch_bias != nullptr); GGML_ASSERT(model.patch_bias != nullptr);
GGML_ASSERT(model.position_embeddings != nullptr); GGML_ASSERT(model.position_embeddings != nullptr);
GGML_ASSERT(model.class_embedding == nullptr); GGML_ASSERT(model.class_embedding == nullptr);
GGML_ASSERT(!hparams.deepstack_layers.empty());
const int batch_size = 1; const int batch_size = 1;
const int n_pos = n_patches; const int n_pos = n_patches;
@@ -986,17 +982,13 @@ struct clip_graph {
cur = ggml_add(ctx0, inpL, cur); cur = ggml_add(ctx0, inpL, cur);
cb(cur, "layer_out", il); cb(cur, "layer_out", il);
if (std::find(hparams.deepstack_layers.begin(), hparams.deepstack_layers.end(), il) != hparams.deepstack_layers.end()) { if (hparams.is_deepstack_layers[il]) {
const int deepstack_idx = std::find(hparams.deepstack_layers.begin(), hparams.deepstack_layers.end(), il) - hparams.deepstack_layers.begin(); ggml_tensor * feat = ggml_reshape_3d(ctx0, cur, n_embd * merge_factor, n_pos / merge_factor, batch_size);
auto & merger = model.deepstack_mergers[deepstack_idx]; feat = build_norm(feat, layer.deepstack_norm_w, layer.deepstack_norm_b, norm_t, eps, il);
ggml_tensor * feat = ggml_dup(ctx0, cur);
feat = ggml_reshape_3d(ctx0, feat, n_embd * merge_factor, n_pos / merge_factor, batch_size);
feat = build_norm(feat, merger.norm_w, merger.norm_b, norm_t, eps, il);
feat = build_ffn(feat, feat = build_ffn(feat,
merger.fc1_w, merger.fc1_b, layer.deepstack_fc1_w, layer.deepstack_fc1_b,
nullptr, nullptr, nullptr, nullptr,
merger.fc2_w, merger.fc2_b, layer.deepstack_fc2_w, layer.deepstack_fc2_b,
ffn_op_type::FFN_GELU, il); ffn_op_type::FFN_GELU, il);
if(!deepstack_features) { if(!deepstack_features) {
@@ -2571,6 +2563,9 @@ struct clip_model_loader {
hparams.vision_feature_layer.insert(layer); hparams.vision_feature_layer.insert(layer);
} }
// set default deepstack layers to false
hparams.is_deepstack_layers.resize(hparams.n_layer, false);
// model-specific params // model-specific params
switch (model.proj_type) { switch (model.proj_type) {
case PROJECTOR_TYPE_MINICPMV: case PROJECTOR_TYPE_MINICPMV:
@@ -2632,7 +2627,7 @@ struct clip_model_loader {
hparams.image_size = 1024; // still need this? hparams.image_size = 1024; // still need this?
hparams.warmup_image_size = hparams.patch_size * 8; hparams.warmup_image_size = hparams.patch_size * 8;
get_u32(KEY_SPATIAL_MERGE_SIZE, hparams.spatial_merge_size, false); get_u32(KEY_SPATIAL_MERGE_SIZE, hparams.spatial_merge_size, false);
get_arr_int(KEY_DEEPSTACK_LAYERS, hparams.deepstack_layers, false); get_arr_bool(KEY_IS_DEEPSTACK_LAYERS, hparams.is_deepstack_layers, false);
} break; } break;
case PROJECTOR_TYPE_LLAMA4: case PROJECTOR_TYPE_LLAMA4:
{ {
@@ -2675,10 +2670,19 @@ struct clip_model_loader {
if (hparams.spatial_merge_size > 0) { if (hparams.spatial_merge_size > 0) {
LOG_INF("%s: spatial_merge_size: %d\n", __func__, hparams.spatial_merge_size); LOG_INF("%s: spatial_merge_size: %d\n", __func__, hparams.spatial_merge_size);
} }
if (!hparams.deepstack_layers.empty()) { if (!hparams.is_deepstack_layers.empty()) {
LOG_INF("%s: deepstack_layers: ", __func__); LOG_INF("%s: deepstack enabled layers: ", __func__);
for (size_t i = 0; i < hparams.deepstack_layers.size(); i++) { bool first = true;
LOG_CNT("%d%s", hparams.deepstack_layers[i], i < hparams.deepstack_layers.size() - 1 ? ", " : "\n"); for (size_t i = 0; i < hparams.is_deepstack_layers.size(); ++i) {
if (hparams.is_deepstack_layers[i]) {
LOG_CNT("%s%zu", first ? "" : ", ", i);
first = false;
}
}
if (first) {
LOG_CNT("none\n");
} else {
LOG_CNT("\n");
} }
} }
} else if (is_audio) { } else if (is_audio) {
@@ -2778,6 +2782,17 @@ struct clip_model_loader {
layer.ff_down_w = get_tensor(string_format(TN_FFN_DOWN, prefix, il, "weight")); layer.ff_down_w = get_tensor(string_format(TN_FFN_DOWN, prefix, il, "weight"));
layer.ff_down_b = get_tensor(string_format(TN_FFN_DOWN, prefix, il, "bias"), false); layer.ff_down_b = get_tensor(string_format(TN_FFN_DOWN, prefix, il, "bias"), false);
// qwen3vl deepstack layer
if (hparams.is_deepstack_layers[il]) {
layer.deepstack_norm_w = get_tensor(string_format(TN_DEEPSTACK_NORM, il, "weight"), false);
layer.deepstack_norm_b = get_tensor(string_format(TN_DEEPSTACK_NORM, il, "bias"), false);
layer.deepstack_fc1_w = get_tensor(string_format(TN_DEEPSTACK_FC1, il, "weight"), false);
layer.deepstack_fc1_b = get_tensor(string_format(TN_DEEPSTACK_FC1, il, "bias"), false);
layer.deepstack_fc2_w = get_tensor(string_format(TN_DEEPSTACK_FC2, il, "weight"), false);
layer.deepstack_fc2_b = get_tensor(string_format(TN_DEEPSTACK_FC2, il, "bias"), false);
}
// some models already exported with legacy (incorrect) naming which is quite messy, let's fix it here // some models already exported with legacy (incorrect) naming which is quite messy, let's fix it here
// note: Qwen model converted from the old surgery script has n_ff = 0, so we cannot use n_ff to check! // note: Qwen model converted from the old surgery script has n_ff = 0, so we cannot use n_ff to check!
bool is_ffn_swapped = ( bool is_ffn_swapped = (
@@ -2919,19 +2934,6 @@ struct clip_model_loader {
model.mm_0_b = get_tensor(string_format(TN_LLAVA_PROJ, 0, "bias")); model.mm_0_b = get_tensor(string_format(TN_LLAVA_PROJ, 0, "bias"));
model.mm_1_w = get_tensor(string_format(TN_LLAVA_PROJ, 2, "weight")); model.mm_1_w = get_tensor(string_format(TN_LLAVA_PROJ, 2, "weight"));
model.mm_1_b = get_tensor(string_format(TN_LLAVA_PROJ, 2, "bias")); model.mm_1_b = get_tensor(string_format(TN_LLAVA_PROJ, 2, "bias"));
if (!hparams.deepstack_layers.empty()) {
model.deepstack_mergers.resize(hparams.deepstack_layers.size());
for (size_t i = 0; i < hparams.deepstack_layers.size(); i++) {
auto & merger = model.deepstack_mergers[i];
merger.norm_w = get_tensor(string_format("v.deepstack.%d.norm.weight", (int)i), false);
merger.norm_b = get_tensor(string_format("v.deepstack.%d.norm.bias", (int)i), false);
merger.fc1_w = get_tensor(string_format("v.deepstack.%d.fc1.weight", (int)i), false);
merger.fc1_b = get_tensor(string_format("v.deepstack.%d.fc1.bias", (int)i), false);
merger.fc2_w = get_tensor(string_format("v.deepstack.%d.fc2.weight", (int)i), false);
merger.fc2_b = get_tensor(string_format("v.deepstack.%d.fc2.bias", (int)i), false);
}
}
} break; } break;
case PROJECTOR_TYPE_GEMMA3: case PROJECTOR_TYPE_GEMMA3:
{ {
@@ -3139,6 +3141,21 @@ struct clip_model_loader {
} }
} }
void get_arr_bool(const std::string & key, std::vector<bool> & output, bool required = true) {
const int i = gguf_find_key(ctx_gguf.get(), key.c_str());
if (i < 0) {
if (required) throw std::runtime_error("Key not found: " + key);
return;
}
const int n = gguf_get_arr_n(ctx_gguf.get(), i);
output.resize(n);
const bool * values = (const bool *)gguf_get_arr_data(ctx_gguf.get(), i);
for (int i = 0; i < n; ++i) {
output[i] = values[i];
}
}
void set_llava_uhd_res_candidates(clip_model & model, const int max_patches_per_side) { void set_llava_uhd_res_candidates(clip_model & model, const int max_patches_per_side) {
auto & hparams = model.hparams; auto & hparams = model.hparams;
for (int x = 1; x <= max_patches_per_side; x++) { for (int x = 1; x <= max_patches_per_side; x++) {
@@ -4632,7 +4649,7 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) {
case PROJECTOR_TYPE_QWEN25VL: case PROJECTOR_TYPE_QWEN25VL:
return ctx->model.mm_1_b->ne[0]; return ctx->model.mm_1_b->ne[0];
case PROJECTOR_TYPE_QWEN3VL: case PROJECTOR_TYPE_QWEN3VL:
return ctx->model.mm_1_b->ne[0] * ((int)ctx->model.hparams.deepstack_layers.size() + 1); // main path + deepstack paths return ctx->model.mm_1_b->ne[0] * (1 + std::count(ctx->model.hparams.is_deepstack_layers.begin(), ctx->model.hparams.is_deepstack_layers.end(), true)); // main path + deepstack paths
case PROJECTOR_TYPE_GEMMA3: case PROJECTOR_TYPE_GEMMA3:
return ctx->model.mm_input_proj_w->ne[0]; return ctx->model.mm_input_proj_w->ne[0];
case PROJECTOR_TYPE_IDEFICS3: case PROJECTOR_TYPE_IDEFICS3: