mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-11-01 09:01:57 +00:00
model : add LightOnOCR-1B model (#16764)
* model : add LightOnOCR-1B model * add test
This commit is contained in:
@@ -2460,17 +2460,20 @@ class ArceeModel(LlamaModel):
|
||||
)
|
||||
class LlavaVisionModel(MmprojModel):
|
||||
img_break_tok_id = -1
|
||||
use_break_tok = True
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
if self.hparams.get("model_type") == "pixtral":
|
||||
# layer_norm_eps is not in config.json, it is hard-coded in modeling_pixtral.py
|
||||
self.hparams["layer_norm_eps"] = self.hparams.get("layer_norm_eps", 1e-5)
|
||||
if self.use_break_tok:
|
||||
self.img_break_tok_id = self.get_token_id("[IMG_BREAK]")
|
||||
elif self.is_mistral_format:
|
||||
# hparams is already vision config here so norm_eps is only defined in global_config.
|
||||
self.hparams["norm_eps"] = self.global_config.get("norm_eps", None)
|
||||
assert self.hparams["norm_eps"] is not None, "norm_eps not found in params.json"
|
||||
if self.use_break_tok:
|
||||
self.img_break_tok_id = self.find_vparam(["image_break_token_id"])
|
||||
else:
|
||||
raise ValueError(f"Unsupported model type: {self.hparams['model_type']}")
|
||||
@@ -3962,6 +3965,10 @@ class Qwen3Model(Qwen2Model):
|
||||
return torch.stack([true_row, false_row], dim=0)
|
||||
|
||||
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
|
||||
if "model.vision_" in name:
|
||||
# skip multimodal tensors
|
||||
return []
|
||||
|
||||
if self.is_rerank:
|
||||
is_tied_head = self.is_tied_embeddings and "embed_tokens" in name
|
||||
is_real_head = not self.is_tied_embeddings and "lm_head" in name
|
||||
@@ -9435,6 +9442,21 @@ class PixtralModel(LlavaVisionModel):
|
||||
return super().map_tensor_name(name, try_suffixes)
|
||||
|
||||
|
||||
@ModelBase.register("LightOnOCRForConditionalGeneration")
|
||||
class LightOnOCRVisionModel(LlavaVisionModel):
|
||||
is_mistral_format = False
|
||||
use_break_tok = False
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
super().set_gguf_parameters()
|
||||
self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.LIGHTONOCR)
|
||||
|
||||
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None):
|
||||
name = name.replace("model.vision_encoder.", "vision_tower.")
|
||||
name = name.replace("model.vision_projection.", "multi_modal_projector.")
|
||||
return super().modify_tensors(data_torch, name, bid)
|
||||
|
||||
|
||||
@ModelBase.register("KimiVLForConditionalGeneration")
|
||||
class KimiVLModel(MmprojModel):
|
||||
def __init__(self, *args, **kwargs):
|
||||
|
||||
@@ -3062,6 +3062,7 @@ class VisionProjectorType:
|
||||
VOXTRAL = "voxtral"
|
||||
LFM2 = "lfm2"
|
||||
KIMIVL = "kimivl"
|
||||
LIGHTONOCR = "lightonocr"
|
||||
|
||||
|
||||
# Items here are (block size, type size)
|
||||
|
||||
@@ -139,6 +139,7 @@ enum projector_type {
|
||||
PROJECTOR_TYPE_VOXTRAL,
|
||||
PROJECTOR_TYPE_LFM2,
|
||||
PROJECTOR_TYPE_KIMIVL,
|
||||
PROJECTOR_TYPE_LIGHTONOCR,
|
||||
PROJECTOR_TYPE_UNKNOWN,
|
||||
};
|
||||
|
||||
@@ -161,6 +162,7 @@ static std::map<projector_type, std::string> PROJECTOR_TYPE_NAMES = {
|
||||
{ PROJECTOR_TYPE_VOXTRAL, "voxtral"},
|
||||
{ PROJECTOR_TYPE_LFM2, "lfm2"},
|
||||
{ PROJECTOR_TYPE_KIMIVL, "kimivl"},
|
||||
{ PROJECTOR_TYPE_LIGHTONOCR,"lightonocr"},
|
||||
};
|
||||
|
||||
static projector_type clip_projector_type_from_string(const std::string & str) {
|
||||
|
||||
@@ -621,7 +621,7 @@ struct clip_graph {
|
||||
}
|
||||
|
||||
// arrangement of the [IMG_BREAK] token
|
||||
{
|
||||
if (model.token_embd_img_break) {
|
||||
// not efficient, but works
|
||||
// the trick is to view the embeddings as a 3D tensor with shape [n_embd, n_patches_per_row, n_rows]
|
||||
// and then concatenate the [IMG_BREAK] token to the end of each row, aka n_patches_per_row dimension
|
||||
@@ -2095,6 +2095,7 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
|
||||
res = graph.build_siglip();
|
||||
} break;
|
||||
case PROJECTOR_TYPE_PIXTRAL:
|
||||
case PROJECTOR_TYPE_LIGHTONOCR:
|
||||
{
|
||||
res = graph.build_pixtral();
|
||||
} break;
|
||||
@@ -2380,6 +2381,7 @@ struct clip_model_loader {
|
||||
get_u32(KEY_PROJ_SCALE_FACTOR, hparams.proj_scale_factor, false);
|
||||
} break;
|
||||
case PROJECTOR_TYPE_PIXTRAL:
|
||||
case PROJECTOR_TYPE_LIGHTONOCR:
|
||||
{
|
||||
hparams.rope_theta = 10000.0f;
|
||||
hparams.warmup_image_size = hparams.patch_size * 8;
|
||||
@@ -2722,6 +2724,15 @@ struct clip_model_loader {
|
||||
model.mm_input_norm_w = get_tensor(TN_MM_INP_NORM, false);
|
||||
model.mm_patch_merger_w = get_tensor(TN_MM_PATCH_MERGER, false);
|
||||
} break;
|
||||
case PROJECTOR_TYPE_LIGHTONOCR:
|
||||
{
|
||||
model.mm_1_w = get_tensor(string_format(TN_LLAVA_PROJ, 1, "weight"));
|
||||
model.mm_1_b = get_tensor(string_format(TN_LLAVA_PROJ, 1, "bias"), false);
|
||||
model.mm_2_w = get_tensor(string_format(TN_LLAVA_PROJ, 2, "weight"));
|
||||
model.mm_2_b = get_tensor(string_format(TN_LLAVA_PROJ, 2, "bias"), false);
|
||||
model.mm_input_norm_w = get_tensor(TN_MM_INP_NORM, false);
|
||||
model.mm_patch_merger_w = get_tensor(TN_MM_PATCH_MERGER, false);
|
||||
} break;
|
||||
case PROJECTOR_TYPE_ULTRAVOX:
|
||||
{
|
||||
model.conv1d_1_w = get_tensor(string_format(TN_CONV1D, 1, "weight"));
|
||||
@@ -3622,7 +3633,9 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str
|
||||
res_imgs->entries.push_back(std::move(img_f32));
|
||||
return true;
|
||||
|
||||
} else if (ctx->proj_type() == PROJECTOR_TYPE_PIXTRAL) {
|
||||
} else if (ctx->proj_type() == PROJECTOR_TYPE_PIXTRAL
|
||||
|| ctx->proj_type() == PROJECTOR_TYPE_LIGHTONOCR
|
||||
) {
|
||||
clip_image_u8 resized_image;
|
||||
auto new_size = image_manipulation::calc_size_preserved_ratio(original_size, params.patch_size, params.image_size);
|
||||
image_manipulation::bilinear_resize(*img, resized_image, new_size.width, new_size.height);
|
||||
@@ -3865,12 +3878,17 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im
|
||||
n_patches = x_patch * y_patch;
|
||||
} break;
|
||||
case PROJECTOR_TYPE_PIXTRAL:
|
||||
case PROJECTOR_TYPE_LIGHTONOCR:
|
||||
{
|
||||
// dynamic size
|
||||
int n_merge = params.spatial_merge_size;
|
||||
int n_patches_x = img->nx / patch_size / (n_merge > 0 ? n_merge : 1);
|
||||
int n_patches_y = img->ny / patch_size / (n_merge > 0 ? n_merge : 1);
|
||||
if (ctx->model.token_embd_img_break) {
|
||||
n_patches = n_patches_y * n_patches_x + n_patches_y - 1; // + one [IMG_BREAK] per row, except the last row
|
||||
} else {
|
||||
n_patches = n_patches_y * n_patches_x;
|
||||
}
|
||||
} break;
|
||||
case PROJECTOR_TYPE_VOXTRAL:
|
||||
case PROJECTOR_TYPE_ULTRAVOX:
|
||||
@@ -4247,6 +4265,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
|
||||
} break;
|
||||
case PROJECTOR_TYPE_PIXTRAL:
|
||||
case PROJECTOR_TYPE_KIMIVL:
|
||||
case PROJECTOR_TYPE_LIGHTONOCR:
|
||||
{
|
||||
// set the 2D positions
|
||||
int n_patches_per_col = image_size_width / patch_size;
|
||||
@@ -4377,6 +4396,7 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) {
|
||||
return ctx->model.mm_model_peg_0_b->ne[0];
|
||||
case PROJECTOR_TYPE_MLP:
|
||||
case PROJECTOR_TYPE_PIXTRAL:
|
||||
case PROJECTOR_TYPE_LIGHTONOCR:
|
||||
return ctx->model.mm_2_w->ne[1];
|
||||
case PROJECTOR_TYPE_MLP_NORM:
|
||||
return ctx->model.mm_3_b->ne[0];
|
||||
|
||||
@@ -275,6 +275,11 @@ struct mtmd_context {
|
||||
img_beg = "<img>";
|
||||
img_end = "</img>";
|
||||
|
||||
} else if (proj == PROJECTOR_TYPE_LIGHTONOCR) {
|
||||
// <|im_start|> ... (image embeddings) ... <|im_end|>
|
||||
img_beg = "<|im_start|>";
|
||||
img_end = "<|im_end|>";
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -70,6 +70,7 @@ add_test_vision "ggml-org/InternVL3-1B-Instruct-GGUF:Q8_0"
|
||||
add_test_vision "ggml-org/Qwen2.5-Omni-3B-GGUF:Q4_K_M"
|
||||
add_test_vision "ggml-org/LFM2-VL-450M-GGUF:Q8_0"
|
||||
add_test_vision "ggml-org/granite-docling-258M-GGUF:Q8_0"
|
||||
add_test_vision "ggml-org/LightOnOCR-1B-1025-GGUF:Q8_0"
|
||||
|
||||
add_test_audio "ggml-org/ultravox-v0_5-llama-3_2-1b-GGUF:Q8_0"
|
||||
add_test_audio "ggml-org/Qwen2.5-Omni-3B-GGUF:Q4_K_M"
|
||||
|
||||
Reference in New Issue
Block a user