Merge branch 'master' into gg/llama-kv-cache

ggml-ci
This commit is contained in:
Georgi Gerganov
2025-02-27 15:10:18 +02:00
100 changed files with 4248 additions and 1355 deletions

View File

@@ -124,15 +124,26 @@ struct ContentView: View {
}
}
}.sheet(isPresented: $showingHelp) { // Sheet for help modal
VStack(alignment: .leading) {
NavigationView {
VStack(alignment: .leading) {
Text("1. Make sure the model is in GGUF Format")
.padding()
Text("2. Copy the download link of the quantized model")
.padding()
VStack(alignment: .leading) {
Text("1. Make sure the model is in GGUF Format")
.padding()
Text("2. Copy the download link of the quantized model")
.padding()
}
Spacer()
}
Spacer()
}
.navigationTitle("Help")
.navigationBarTitleDisplayMode(.inline)
.toolbar {
ToolbarItem(placement: .navigationBarTrailing) {
Button("Done") {
showingHelp = false
}
}
}
}
}
}
}

View File

@@ -0,0 +1,183 @@
# Granite Vision
Download the model and point your `GRANITE_MODEL` environment variable to the path.
```bash
$ git clone https://huggingface.co/ibm-granite/granite-vision-3.1-2b-preview
$ export GRANITE_MODEL=./granite-vision-3.1-2b-preview
```
### 1. Running llava surgery v2.
First, we need to run the llava surgery script as shown below:
`python llava_surgery_v2.py -C -m $GRANITE_MODEL`
You should see two new files (`llava.clip` and `llava.projector`) written into your model's directory, as shown below.
```bash
$ ls $GRANITE_MODEL | grep -i llava
llava.clip
llava.projector
```
We should see that the projector and visual encoder get split out into the llava files. Quick check to make sure they aren't empty:
```python
import os
import torch
MODEL_PATH = os.getenv("GRANITE_MODEL")
if not MODEL_PATH:
raise ValueError("env var GRANITE_MODEL is unset!")
encoder_tensors = torch.load(os.path.join(MODEL_PATH, "llava.clip"))
projector_tensors = torch.load(os.path.join(MODEL_PATH, "llava.projector"))
assert len(encoder_tensors) > 0
assert len(projector_tensors) > 0
```
If you actually inspect the `.keys()` of the loaded tensors, you should see a lot of `vision_model` tensors in the `encoder_tensors`, and 5 tensors (`'multi_modal_projector.linear_1.bias'`, `'multi_modal_projector.linear_1.weight'`, `'multi_modal_projector.linear_2.bias'`, `'multi_modal_projector.linear_2.weight'`, `'image_newline'`) in the multimodal `projector_tensors`.
### 2. Creating the Visual Component GGUF
To create the GGUF for the visual components, we need to write a config for the visual encoder; make sure the config contains the correct `image_grid_pinpoints`
Note: we refer to this file as `$VISION_CONFIG` later on.
```json
{
"_name_or_path": "siglip-model",
"architectures": [
"SiglipVisionModel"
],
"image_grid_pinpoints": [
[384,768],
[384,1152],
[384,1536],
[384,1920],
[384,2304],
[384,2688],
[384,3072],
[384,3456],
[384,3840],
[768,384],
[768,768],
[768,1152],
[768,1536],
[768,1920],
[1152,384],
[1152,768],
[1152,1152],
[1536,384],
[1536,768],
[1920,384],
[1920,768],
[2304,384],
[2688,384],
[3072,384],
[3456,384],
[3840,384]
],
"mm_patch_merge_type": "spatial_unpad",
"hidden_size": 1152,
"image_size": 384,
"intermediate_size": 4304,
"model_type": "siglip_vision_model",
"num_attention_heads": 16,
"num_hidden_layers": 27,
"patch_size": 14,
"layer_norm_eps": 1e-6,
"hidden_act": "gelu_pytorch_tanh",
"projection_dim": 0,
"vision_feature_layer": [-24, -20, -12, -1]
}
```
Create a new directory to hold the visual components, and copy the llava.clip/projector files, as well as the vision config into it.
```bash
$ ENCODER_PATH=$PWD/visual_encoder
$ mkdir $ENCODER_PATH
$ cp $GRANITE_MODEL/llava.clip $ENCODER_PATH/pytorch_model.bin
$ cp $GRANITE_MODEL/llava.projector $ENCODER_PATH/
$ cp $VISION_CONFIG $ENCODER_PATH/config.json
```
At which point you should have something like this:
```bash
$ ls $ENCODER_PATH
config.json llava.projector pytorch_model.bin
```
Now convert the components to GGUF; Note that we also override the image mean/std dev to `[.5,.5,.5]` since we use the siglip visual encoder - in the transformers model, you can find these numbers in the [preprocessor_config.json](https://huggingface.co/ibm-granite/granite-vision-3.1-2b-preview/blob/main/preprocessor_config.json).
```bash
$ python convert_image_encoder_to_gguf.py \
-m $ENCODER_PATH \
--llava-projector $ENCODER_PATH/llava.projector \
--output-dir $ENCODER_PATH \
--clip-model-is-vision \
--clip-model-is-siglip \
--image-mean 0.5 0.5 0.5 --image-std 0.5 0.5 0.5
```
this will create the first GGUF file at `$ENCODER_PATH/mmproj-model-f16.gguf`; we will refer to the abs path of this file as the `$VISUAL_GGUF_PATH.`
### 3. Creating the LLM GGUF.
The granite vision model contains a granite LLM as its language model. For now, the easiest way to get the GGUF for LLM is by loading the composite model in `transformers` and exporting the LLM so that it can be directly converted with the normal conversion path.
First, set the `LLM_EXPORT_PATH` to the path to export the `transformers` LLM to.
```
$ export LLM_EXPORT_PATH=$PWD/granite_vision_llm
```
```python
import os
import transformers
MODEL_PATH = os.getenv("GRANITE_MODEL")
if not MODEL_PATH:
raise ValueError("env var GRANITE_MODEL is unset!")
LLM_EXPORT_PATH = os.getenv("LLM_EXPORT_PATH")
if not MODEL_PATH:
raise ValueError("env var LLM_EXPORT_PATH is unset!")
tokenizer = transformers.AutoTokenizer.from_pretrained(MODEL_PATH)
# NOTE: granite vision support was added to transformers very recently (4.49);
# if you get size mismatches, your version is too old.
# If you are running with an older version, set `ignore_mismatched_sizes=True`
# as shown below; it won't be loaded correctly, but the LLM part of the model that
# we are exporting will be loaded correctly.
model = transformers.AutoModelForImageTextToText.from_pretrained(MODEL_PATH, ignore_mismatched_sizes=True)
tokenizer.save_pretrained(LLM_EXPORT_PATH)
model.language_model.save_pretrained(LLM_EXPORT_PATH)
```
Now you can convert the exported LLM to GGUF with the normal converter in the root of the llama cpp project.
```bash
$ LLM_GGUF_PATH=$LLM_EXPORT_PATH/granite_llm.gguf
...
$ python convert_hf_to_gguf.py --outfile $LLM_GGUF_PATH $LLM_EXPORT_PATH
```
### 4. Running the Model in Llama cpp
Build llama cpp normally; you should have a target binary named `llama-llava-cli`, which you can pass two binaries to. Sample usage:
Note - the test image shown below can be found [here](https://github-production-user-asset-6210df.s3.amazonaws.com/10740300/415512792-d90d5562-8844-4f34-a0a5-77f62d5a58b5.jpg?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAVCODYLSA53PQK4ZA%2F20250221%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20250221T054145Z&X-Amz-Expires=300&X-Amz-Signature=86c60be490aa49ef7d53f25d6c973580a8273904fed11ed2453d0a38240ee40a&X-Amz-SignedHeaders=host).
```bash
$ ./build/bin/llama-llava-cli -m $LLM_GGUF_PATH \
--mmproj $VISUAL_GGUF_PATH \
--image cherry_blossom.jpg \
-c 16384 \
-p "<|system|>\nA chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.\n<|user|>\n\<image>\nWhat type of flowers are in this picture?\n<|assistant|>\n" \
--temp 0
```
Sample response: `The flowers in the picture are cherry blossoms, which are known for their delicate pink petals and are often associated with the beauty of spring.`

View File

@@ -101,8 +101,27 @@ python ./examples/convert_legacy_llama.py ../llava-v1.6-vicuna-7b/ --skip-unknow
```
**note** llava-1.6 needs more context than llava-1.5, at least 3000 is needed (just run it at -c 4096)
**note** llava-1.6 greatly benefits from batched prompt processing (defaults work)
**note** if the language model in step `6)` is incompatible with the legacy conversion script, the easiest way handle the LLM model conversion is to load the model in transformers, and export only the LLM from the llava next model.
```python
import os
import transformers
model_path = ...
llm_export_path = ...
tokenizer = transformers.AutoTokenizer.from_pretrained(model_path)
model = transformers.AutoModelForImageTextToText.from_pretrained(model_path)
tokenizer.save_pretrained(llm_export_path)
model.language_model.save_pretrained(llm_export_path)
```
Then, you can convert the LLM using the `convert_hf_to_gguf.py` script, which handles more LLM architectures.
## llava-cli templating and llava-1.6 prompting
llava-1.5 models all use the same vicuna prompt, here you can just add your image question like `-p "Provide a full description."`

View File

@@ -40,6 +40,7 @@
#include <map>
#include <regex>
#include <stdexcept>
#include <unordered_set>
#include <vector>
#include <sstream>
#include <cinttypes>
@@ -120,6 +121,7 @@ static std::string format(const char * fmt, ...) {
#define KEY_IMAGE_MEAN "clip.vision.image_mean"
#define KEY_IMAGE_STD "clip.vision.image_std"
#define KEY_PROJ_TYPE "clip.projector_type"
#define KEY_FEATURE_LAYER "clip.vision.feature_layer"
#define KEY_MM_PATCH_MERGE_TYPE "clip.vision.mm_patch_merge_type"
#define KEY_IMAGE_GRID_PINPOINTS "clip.vision.image_grid_pinpoints"
@@ -444,8 +446,9 @@ struct clip_hparams {
char mm_patch_merge_type[32] = "flat"; // spatial_unpad or flat (default)
int32_t image_grid_pinpoints[32];
std::vector<int32_t> image_grid_pinpoints;
int32_t image_crop_resolution;
std::unordered_set<int32_t> vision_feature_layer;
};
struct clip_layer {
@@ -585,6 +588,7 @@ struct clip_ctx {
struct clip_vision_model vision_model;
projector_type proj_type = PROJECTOR_TYPE_MLP;
int32_t max_feature_layer;
float image_mean[3];
float image_std[3];
bool use_gelu = false;
@@ -651,7 +655,6 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
const int hidden_size = hparams.hidden_size;
const int n_head = hparams.n_head;
const int d_head = hidden_size / n_head;
int n_layer = hparams.n_layer;
const float eps = hparams.eps;
int mrope_sections[4] = {d_head/4, d_head/4, d_head/4, d_head/4};
@@ -752,13 +755,19 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.pre_ln_w), model.pre_ln_b);
}
std::vector<struct ggml_tensor *> embedding_stack;
const auto & vision_feature_layer = hparams.vision_feature_layer;
// loop over layers
if (ctx->has_minicpmv_projector || ctx->has_glm_projector || ctx->has_qwen2vl_merger) {
n_layer += 1;
}
for (int il = 0; il < n_layer - 1; il++) {
for (int il = 0; il < ctx->max_feature_layer; il++) {
struct ggml_tensor * cur = embeddings; // embeddings = residual, cur = hidden_states
// If this is an embedding feature layer, save the output.
// NOTE: 0 index here refers to the input to the encoder.
if (vision_feature_layer.find(il) != vision_feature_layer.end()) {
embedding_stack.push_back(embeddings);
}
//const size_t nb_q_w = model.layers[il].q_w->nb[0];
// layernorm1
@@ -846,7 +855,6 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
cur = ggml_add(ctx0, embeddings, cur);
embeddings = cur;
}
// post-layernorm
@@ -857,6 +865,19 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.post_ln_w), model.post_ln_b);
}
// final layer is a vision feature layer
if (vision_feature_layer.find(ctx->max_feature_layer) != vision_feature_layer.end()) {
embedding_stack.push_back(embeddings);
}
// If feature layers are explicitly set, stack them (if we have multiple)
if (!embedding_stack.empty()) {
embeddings = embedding_stack[0];
for (size_t i = 1; i < embedding_stack.size(); i++) {
embeddings = ggml_concat(ctx0, embeddings, embedding_stack[i], 0);
}
}
// llava projector
if (ctx->has_llava_projector) {
embeddings = ggml_reshape_2d(ctx0, embeddings, embeddings->ne[0], embeddings->ne[1]);
@@ -1443,14 +1464,26 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
int idx = get_key_idx(ctx, KEY_IMAGE_GRID_PINPOINTS);
int n = gguf_get_arr_n(ctx, idx);
const int32_t * pinpoints = (const int32_t *)gguf_get_arr_data(ctx, idx);
for (int i = 0; i < 32 && i < n && pinpoints[i] != 0; ++i) {
hparams.image_grid_pinpoints[i] = pinpoints[i];
for (int i = 0; i < n; ++i) {
hparams.image_grid_pinpoints.push_back(pinpoints[i]);
}
if (n < 32)
hparams.image_grid_pinpoints[n] = 0;
} catch (std::runtime_error & /*e*/) {
hparams.image_grid_pinpoints[0]=0;
}
} catch (std::runtime_error & /*e*/) { }
// Load the vision feature layer indices if they are explicitly provided;
// if multiple vision feature layers are present, the values will be concatenated
// to form the final visual features.
// NOTE: gguf conversions should standardize the values of the vision feature layer to
// be non-negative, since we use -1 to mark values as unset here.
try {
int idx = get_key_idx(ctx, KEY_FEATURE_LAYER);
int n = gguf_get_arr_n(ctx, idx);
const int32_t * vision_feature_layer = (const int32_t *)gguf_get_arr_data(ctx, idx);
for (int i = 0; i < n; ++i) {
hparams.vision_feature_layer.insert(vision_feature_layer[i]);
}
} catch (std::runtime_error & /*e*/) { }
try {
int idx = get_key_idx(ctx, KEY_MM_PATCH_MERGE_TYPE);
@@ -1476,6 +1509,9 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
new_clip->image_std[i] = std_data[i];
}
// Calculate the deepest feature layer based on hparams and projector type
new_clip->max_feature_layer = get_deepest_feature_layer(new_clip);
if (verbosity >= 2) {
LOG_INF("\n%s: vision model hparams\n", __func__);
LOG_INF("image_size %d\n", hparams.image_size);
@@ -1489,8 +1525,13 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
LOG_INF("v_image_mean %f %f %f\n", new_clip->image_mean[0], new_clip->image_mean[1], new_clip->image_mean[2]);
LOG_INF("v_image_std %f %f %f\n", new_clip->image_std[0], new_clip->image_std[1], new_clip->image_std[2]);
LOG_INF("v_image_grid_pinpoints: ");
for (int i = 0; i < 32 && (hparams.image_grid_pinpoints[i] != 0); ++i) {
LOG_INF("%d ", hparams.image_grid_pinpoints[i]);
for (const auto & pp : hparams.image_grid_pinpoints) {
LOG_INF("%d ", pp);
}
LOG_INF("\n");
LOG_INF("v_vision_feature_layer: ");
for (const auto & feature_layer: hparams.vision_feature_layer) {
LOG_INF("%d ", feature_layer);
}
LOG_INF("\n");
LOG_INF("v_mm_patch_merge_type: %s\n", hparams.mm_patch_merge_type);
@@ -1729,11 +1770,11 @@ void clip_image_f32_batch_free(struct clip_image_f32_batch * batch) {
}
}
static void build_clip_img_from_data(const stbi_uc * data, int nx, int ny, clip_image_u8 * img) {
void clip_build_img_from_pixels(const unsigned char * rgb_pixels, int nx, int ny, clip_image_u8 * img) {
img->nx = nx;
img->ny = ny;
img->buf.resize(3 * nx * ny);
memcpy(img->buf.data(), data, img->buf.size());
memcpy(img->buf.data(), rgb_pixels, img->buf.size());
}
bool clip_image_load_from_file(const char * fname, clip_image_u8 * img) {
@@ -1743,7 +1784,7 @@ bool clip_image_load_from_file(const char * fname, clip_image_u8 * img) {
LOG_ERR("%s: failed to load image '%s'\n", __func__, fname);
return false;
}
build_clip_img_from_data(data, nx, ny, img);
clip_build_img_from_pixels(data, nx, ny, img);
stbi_image_free(data);
return true;
}
@@ -1755,7 +1796,7 @@ bool clip_image_load_from_bytes(const unsigned char * bytes, size_t bytes_length
LOG_ERR("%s: failed to decode image bytes\n", __func__);
return false;
}
build_clip_img_from_data(data, nx, ny, img);
clip_build_img_from_pixels(data, nx, ny, img);
stbi_image_free(data);
return true;
}
@@ -2235,10 +2276,10 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, cli
}
}
} else {
if (params.image_grid_pinpoints[0] != 0) {
if (!params.image_grid_pinpoints.empty()) {
// "spatial_unpad" with "anyres" processing for llava-1.6
std::vector<std::pair<int, int>> possible_resolutions;
for (int i = 0; i < 32 && params.image_grid_pinpoints[i] != 0; i+=2) {
for (size_t i = 0; i < params.image_grid_pinpoints.size(); i+=2) {
possible_resolutions.push_back({params.image_grid_pinpoints[i], params.image_grid_pinpoints[i+1]});
}
std::pair<int, int> best_resolution = select_best_resolution({img->nx, img->ny}, possible_resolutions);
@@ -2404,7 +2445,14 @@ const char * clip_patch_merge_type(const struct clip_ctx * ctx) {
}
const int32_t * clip_image_grid(const struct clip_ctx * ctx) {
return ctx->vision_model.hparams.image_grid_pinpoints;
if (ctx->vision_model.hparams.image_grid_pinpoints.size()) {
return &ctx->vision_model.hparams.image_grid_pinpoints.front();
}
return nullptr;
}
size_t get_clip_image_grid_size(const struct clip_ctx * ctx) {
return ctx->vision_model.hparams.image_grid_pinpoints.size();
}
int clip_n_patches(const struct clip_ctx * ctx) {
@@ -2929,6 +2977,28 @@ bool clip_is_qwen2vl(const struct clip_ctx * ctx) {
return ctx->has_qwen2vl_merger;
}
// Determine the number of encoder layers to iterate over
int get_deepest_feature_layer(const struct clip_ctx * ctx) {
// Get the index of the second to last layer; this is the
// default for models that have a llava projector
const auto & hparams = ctx->vision_model.hparams;
int n_layer = hparams.n_layer - 1;
int deepest_feature_layer = -1;
// Handle other projectors; incrementing here indicates that we
// should use the last encoder layer for the vision features.
if (ctx->has_minicpmv_projector || ctx->has_glm_projector || ctx->has_qwen2vl_merger) {
n_layer += 1;
}
// If we set explicit vision feature layers, only go up to the deepest one
for (const auto & feature_layer : hparams.vision_feature_layer) {
if (feature_layer > deepest_feature_layer) {
deepest_feature_layer = feature_layer;
}
}
return deepest_feature_layer < 0 ? n_layer : deepest_feature_layer;
}
bool clip_encode_float_image (struct clip_ctx * ctx, int n_threads, float * img, int h, int w, float * vec) {
clip_image_f32 clip_img;

View File

@@ -55,6 +55,7 @@ CLIP_API int32_t clip_hidden_size(const struct clip_ctx * ctx);
CLIP_API const char * clip_patch_merge_type(const struct clip_ctx * ctx);
CLIP_API const int32_t * clip_image_grid(const struct clip_ctx * ctx);
CLIP_API size_t get_clip_image_grid_size(const struct clip_ctx * ctx);
CLIP_API int clip_n_patches (const struct clip_ctx * ctx);
CLIP_API int clip_n_patches_by_img (const struct clip_ctx * ctx, struct clip_image_f32 * img);
@@ -73,6 +74,12 @@ CLIP_API void clip_image_f32_free(struct clip_image_f32 * img);
CLIP_API void clip_image_u8_batch_free (struct clip_image_u8_batch * batch);
CLIP_API void clip_image_f32_batch_free(struct clip_image_f32_batch * batch);
/**
* Build image from pixels decoded by other libraries instead of stb_image.h for better performance.
* The memory layout is RGBRGBRGB..., input buffer length must be 3*nx*ny bytes
*/
CLIP_API void clip_build_img_from_pixels(const unsigned char * rgb_pixels, int nx, int ny, struct clip_image_u8 * img);
CLIP_API bool clip_image_load_from_file(const char * fname, struct clip_image_u8 * img);
/** interpret bytes as an image file with length bytes_length, and use the result to populate img */
@@ -89,11 +96,13 @@ CLIP_API bool clip_image_batch_encode(struct clip_ctx * ctx, int n_threads, cons
CLIP_API bool clip_model_quantize(const char * fname_inp, const char * fname_out, int itype);
CLIP_API int clip_is_minicpmv(const struct clip_ctx * ctx);
CLIP_API bool clip_is_glm(const struct clip_ctx * ctx);
CLIP_API bool clip_is_qwen2vl(const struct clip_ctx * ctx);
CLIP_API int get_deepest_feature_layer(const struct clip_ctx * ctx);
CLIP_API bool clip_encode_float_image (struct clip_ctx * ctx, int n_threads, float * img, int h, int w, float * vec);
CLIP_API bool clip_is_glm(const struct clip_ctx * ctx);
#ifdef __cplusplus
}

View File

@@ -6,7 +6,7 @@ import re
import torch
import numpy as np
from gguf import *
from transformers import CLIPModel, CLIPProcessor, CLIPVisionModel
from transformers import CLIPModel, CLIPProcessor, CLIPVisionModel, SiglipVisionModel
TEXT = "clip.text"
VISION = "clip.vision"
@@ -37,6 +37,18 @@ def should_skip_tensor(name: str, has_text: bool, has_vision: bool, has_llava: b
def get_tensor_name(name: str) -> str:
# Standardize the transformers llava next keys for
# image newline / mm projector with the classes in haotian-liu LLaVA
if name == "image_newline":
return "model.image_newline"
if name.startswith("multi_modal_projector"):
name = name.replace("multi_modal_projector", "mm")
if "linear_1" in name:
name = name.replace("linear_1", "0")
if "linear_2" in name:
name = name.replace("linear_2", "2")
return name
if "projection" in name:
return name
if "mm_projector" in name:
@@ -83,8 +95,14 @@ ap.add_argument("--vision-only", action="store_true", required=False,
help="Save a vision-only model. It can't be used to encode texts")
ap.add_argument("--clip-model-is-vision", action="store_true", required=False,
help="The clip model is a pure vision model (ShareGPT4V vision extract for example)")
ap.add_argument("--clip-model-is-openclip", action="store_true", required=False,
# Selectable visual encoders that are compatible with this script
encoder_group = ap.add_mutually_exclusive_group()
encoder_group.add_argument("--clip-model-is-openclip", action="store_true", required=False,
help="The clip model is from openclip (for ViT-SO400M type))")
encoder_group.add_argument("--clip-model-is-siglip", action="store_true", required=False,
help="the visual encoder is Siglip.")
ap.add_argument("--llava-projector", help="Path to llava.projector file. If specified, save an image encoder for LLaVA models.")
ap.add_argument("--projector-type", help="Type of projector. Possible values: mlp, ldp, ldpv2", choices=["mlp", "ldp", "ldpv2"], default="mlp")
ap.add_argument("-o", "--output-dir", help="Directory to save GGUF files. Default is the original model directory", default=None)
@@ -109,7 +127,12 @@ if args.use_f32:
# output in the same directory as the model if output_dir is None
dir_model = args.model_dir
if args.clip_model_is_vision or not os.path.exists(dir_model + "/vocab.json") or args.clip_model_is_openclip:
if (
args.clip_model_is_vision or
not os.path.exists(dir_model + "/vocab.json") or
args.clip_model_is_openclip or
args.clip_model_is_siglip
):
vocab = None
tokens = None
else:
@@ -137,7 +160,10 @@ ftype = 1
if args.use_f32:
ftype = 0
if args.clip_model_is_vision or args.clip_model_is_openclip:
if args.clip_model_is_siglip:
model = SiglipVisionModel.from_pretrained(dir_model)
processor = None
elif args.clip_model_is_vision or args.clip_model_is_openclip:
model = CLIPVisionModel.from_pretrained(dir_model)
processor = None
else:
@@ -187,26 +213,71 @@ else:
if has_text_encoder:
assert t_hparams is not None
assert tokens is not None
if args.clip_model_is_siglip:
text_projection_dim = 0
else:
text_projection_dim = t_hparams.get("projection_dim", config["projection_dim"])
# text_model hparams
fout.add_uint32(k(KEY_CONTEXT_LENGTH, TEXT), t_hparams["max_position_embeddings"])
fout.add_uint32(k(KEY_EMBEDDING_LENGTH, TEXT), t_hparams["hidden_size"])
fout.add_uint32(k(KEY_FEED_FORWARD_LENGTH, TEXT), t_hparams["intermediate_size"])
fout.add_uint32("clip.text.projection_dim", t_hparams.get("projection_dim", config["projection_dim"]))
fout.add_uint32("clip.text.projection_dim", text_projection_dim)
fout.add_uint32(k(KEY_ATTENTION_HEAD_COUNT, TEXT), t_hparams["num_attention_heads"])
fout.add_float32(k(KEY_ATTENTION_LAYERNORM_EPS, TEXT), t_hparams["layer_norm_eps"])
fout.add_uint32(k(KEY_BLOCK_COUNT, TEXT), t_hparams["num_hidden_layers"])
fout.add_token_list(tokens)
def get_non_negative_vision_feature_layers(v_hparams):
"""
Determine the vision feature layer(s) for the llava model, which are indices into the
hidden states of the visual encoder. Note that the hidden states array generally takes the
form:
[<emb input>, <output of enc block 0>, ... <output of enc block num_hidden_layers>]
so feature indices should be offset as n+1 to get the output of encoder block n.
We convert all vision feature layers to non-negative so that -1 can be used in
the model as an unset value. If no vision feature layer is found, we leave it unset.
"""
num_hidden_layers = v_hparams["num_hidden_layers"]
to_non_negative = lambda layer_idx: layer_idx if layer_idx >= 0 else num_hidden_layers + layer_idx + 1
feature_layers_key = None
# Key used for llava models in transformers
if "vision_feature_layer" in config:
feature_layers_key = "vision_feature_layer"
# Key used for llava models in the original format
elif "mm_vision_select_layer" in config:
feature_layers_key = "mm_vision_select_layer"
if feature_layers_key is not None:
feature_layers = config[feature_layers_key]
if isinstance(feature_layers, int):
feature_layers = [feature_layers]
return [to_non_negative(feature_layer) for feature_layer in feature_layers]
# Determine if we have explicitly specified vision feature layers in our config
feature_layers = get_non_negative_vision_feature_layers(v_hparams)
if has_vision_encoder:
# vision_model hparams
# Siglip does not have a visual projector; set projection dim to 0
if args.clip_model_is_siglip:
visual_projection_dim = 0
else:
visual_projection_dim = v_hparams.get("projection_dim", config["projection_dim"])
# set vision_model hparams
fout.add_uint32("clip.vision.image_size", v_hparams["image_size"])
fout.add_uint32("clip.vision.patch_size", v_hparams["patch_size"])
fout.add_uint32(k(KEY_EMBEDDING_LENGTH, VISION), v_hparams["hidden_size"])
fout.add_uint32(k(KEY_FEED_FORWARD_LENGTH, VISION), v_hparams["intermediate_size"])
fout.add_uint32("clip.vision.projection_dim", v_hparams.get("projection_dim", config["projection_dim"]))
fout.add_uint32("clip.vision.projection_dim", visual_projection_dim)
fout.add_uint32(k(KEY_ATTENTION_HEAD_COUNT, VISION), v_hparams["num_attention_heads"])
fout.add_float32(k(KEY_ATTENTION_LAYERNORM_EPS, VISION), v_hparams["layer_norm_eps"])
block_count = v_hparams["num_hidden_layers"] - 1 if has_llava_projector else v_hparams["num_hidden_layers"]
if feature_layers:
block_count = max(feature_layers)
else:
block_count = v_hparams["num_hidden_layers"] - 1 if has_llava_projector else v_hparams["num_hidden_layers"]
fout.add_uint32(k(KEY_BLOCK_COUNT, VISION), block_count)
# /**
# "image_grid_pinpoints": [
@@ -258,7 +329,8 @@ if has_vision_encoder:
fout.add_string("clip.vision.mm_patch_merge_type", v_hparams["mm_patch_merge_type"])
if "mm_projector_type" in v_hparams:
fout.add_string("clip.vision.mm_projector_type", v_hparams["mm_projector_type"])
if feature_layers:
fout.add_array("clip.vision.feature_layer", feature_layers)
if processor is not None:
image_mean = processor.image_processor.image_mean if args.image_mean is None or args.image_mean == default_image_mean else args.image_mean # pyright: ignore[reportAttributeAccessIssue]
@@ -274,7 +346,13 @@ fout.add_bool("clip.use_gelu", use_gelu)
if has_llava_projector:
model.vision_model.encoder.layers.pop(-1)
# By default, we drop the last layer for llava projector
# models unless we have explicitly set vision feature layers
if feature_layers is None:
model.vision_model.encoder.layers.pop(-1)
else:
model.vision_model.encoder.layers = model.vision_model.encoder.layers[:max(feature_layers)]
projector = torch.load(args.llava_projector)
for name, data in projector.items():
name = get_tensor_name(name)

View File

@@ -353,9 +353,10 @@ static bool encode_image_with_clip(clip_ctx * ctx_clip, int n_threads, const cli
LOG_INF("%s: %d segments encoded in %8.2f ms\n", __func__, (int)img_res_v.size, (t_img_enc_batch_us - t_img_enc_start_us) / 1000.0);
const int32_t * image_grid = clip_image_grid(ctx_clip);
const size_t num_gridpoints = get_clip_image_grid_size(ctx_clip);
std::vector<std::pair<int, int>> grid_pinpoints;
for (int i = 0; i < 32 && image_grid[i] != 0; i += 2) {
for (size_t i = 0; i < num_gridpoints; i += 2) {
grid_pinpoints.push_back({image_grid[i], image_grid[i+1]});
}
@@ -405,7 +406,8 @@ bool llava_validate_embed_size(const llama_context * ctx_llama, const clip_ctx *
}
bool llava_image_embed_make_with_clip_img(clip_ctx * ctx_clip, int n_threads, const clip_image_u8 * img, float ** image_embd_out, int * n_img_pos_out) {
int num_max_patches = 6;
// Granite vision uses up to 10 patches + base patch
int num_max_patches = 11;
if (clip_is_minicpmv(ctx_clip)) {
num_max_patches = 10;
}

View File

@@ -33,6 +33,33 @@ def save_model(model, file_path, file_type):
else:
torch.save(model, file_path)
# Helpers to match weight names from specific components or
# determine if a saved shard contains that component
def is_vision_tower(weight_name):
return (
weight_name.startswith("model.vision_tower") or
weight_name.startswith("vit.") or
weight_name.startswith("vision_tower")
)
def is_newline(weight_name):
return (
weight_name.startswith("model.image_newline") or
weight_name.startswith("image_newline")
)
def is_mm_projector(weight_name):
return (
weight_name.startswith("model.mm_projector") or
weight_name.startswith("vision_proj.") or
weight_name.startswith("multi_modal_projector")
)
def newline_criteria(checkpoint):
return any(is_newline(k) for k in checkpoint.keys())
def proj_criteria(checkpoint):
return any(is_mm_projector(k) for k in checkpoint.keys())
# Adapted function to clean vision tower from checkpoint
def clean_vision_tower_from_checkpoint(checkpoint_path):
@@ -40,7 +67,7 @@ def clean_vision_tower_from_checkpoint(checkpoint_path):
# file_type = 'pytorch'
model_path = os.path.dirname(checkpoint_path)
print(f"Searching for vision tower tensors in {checkpoint_path}")
clip_tensors = [k for k, v in checkpoint.items() if (k.startswith("model.vision_tower") or k.startswith("vit."))]
clip_tensors = [k for k, v in checkpoint.items() if is_vision_tower(k)]
if len(clip_tensors) > 0:
print(f"Found {len(clip_tensors)} tensors to extract from {checkpoint_path}")
@@ -84,12 +111,6 @@ def find_relevant_checkpoints(checkpoint_paths, newline_criteria, projector):
return newline_checkpoint_path, projector_checkpoint_path
def newline_criteria(checkpoint):
return any(k.startswith("model.image_newline") for k in checkpoint.keys())
def proj_criteria(checkpoint):
return any(k.startswith("model.mm_projector") or k.startswith("vision_proj.") for k in checkpoint.keys())
# Command-line interface setup
ap = argparse.ArgumentParser()
@@ -123,14 +144,14 @@ first_checkpoint = None
if newline_checkpoint_path is not None:
print(f"Taking newline from {newline_checkpoint_path}")
first_checkpoint, file_type = load_model(newline_checkpoint_path)
first_mm_tensors = [k for k, v in first_checkpoint.items() if k.startswith("model.image_newline")]
first_mm_tensors = [k for k, v in first_checkpoint.items() if is_newline(k)]
# Load the checkpoint
mm_tensors = []
last_checkpoint = None
if projector_checkpoint_path is not None:
last_checkpoint, file_type = load_model(projector_checkpoint_path)
mm_tensors = [k for k, v in last_checkpoint.items() if k.startswith("model.mm_projector") or k.startswith("vision_proj.")]
mm_tensors = [k for k, v in last_checkpoint.items() if is_mm_projector(k)]
if len(mm_tensors) == 0:
if last_checkpoint is not None:
@@ -155,5 +176,5 @@ if len(projector) > 0:
save_model(projector, f"{args.model}/llava.projector", 'pytorch')
print("Done!")
print(f"Now you can convert {args.model} to a a regular LLaMA GGUF file.")
print(f"Now you can convert {args.model} to a regular LLaMA GGUF file.")
print(f"Also, use {args.model}/llava.projector to prepare a llava-encoder.gguf file.")

View File

@@ -323,25 +323,17 @@ class File {
return 0;
}
std::string read_all(const std::string & filename){
open(filename, "r");
lock();
if (!file) {
printe("Error opening file '%s': %s", filename.c_str(), strerror(errno));
return "";
}
std::string to_string() {
fseek(file, 0, SEEK_END);
size_t size = ftell(file);
const size_t size = ftell(file);
fseek(file, 0, SEEK_SET);
std::string out;
out.resize(size);
size_t read_size = fread(&out[0], 1, size, file);
const size_t read_size = fread(&out[0], 1, size, file);
if (read_size != size) {
printe("Error reading file '%s': %s", filename.c_str(), strerror(errno));
return "";
printe("Error reading file: %s", strerror(errno));
}
return out;
}
@@ -985,7 +977,8 @@ static int generate(LlamaData & llama_data, const std::string & prompt, std::str
}
static int read_user_input(std::string & user_input) {
static const char * prompt_prefix = "> ";
static const char * prompt_prefix_env = std::getenv("LLAMA_PROMPT_PREFIX");
static const char * prompt_prefix = prompt_prefix_env ? prompt_prefix_env : "> ";
#ifdef WIN32
printf("\r" LOG_CLR_TO_EOL LOG_COL_DEFAULT "%s", prompt_prefix);
@@ -1098,59 +1091,66 @@ static int get_user_input(std::string & user_input, const std::string & user) {
// Reads a chat template file to be used
static std::string read_chat_template_file(const std::string & chat_template_file) {
if(chat_template_file.empty()){
return "";
}
File file;
std::string chat_template = "";
chat_template = file.read_all(chat_template_file);
if(chat_template.empty()){
if (!file.open(chat_template_file, "r")) {
printe("Error opening chat template file '%s': %s", chat_template_file.c_str(), strerror(errno));
return "";
}
return chat_template;
return file.to_string();
}
static int process_user_message(const Opt & opt, const std::string & user_input, LlamaData & llama_data,
const common_chat_templates_ptr & chat_templates, int & prev_len,
const bool stdout_a_terminal) {
add_message("user", opt.user.empty() ? user_input : opt.user, llama_data);
int new_len;
if (apply_chat_template_with_error_handling(chat_templates.get(), llama_data, true, new_len, opt.use_jinja) < 0) {
return 1;
}
std::string prompt(llama_data.fmtted.begin() + prev_len, llama_data.fmtted.begin() + new_len);
std::string response;
if (generate_response(llama_data, prompt, response, stdout_a_terminal)) {
return 1;
}
if (!opt.user.empty()) {
return 2;
}
add_message("assistant", response, llama_data);
if (apply_chat_template_with_error_handling(chat_templates.get(), llama_data, false, prev_len, opt.use_jinja) < 0) {
return 1;
}
return 0;
}
// Main chat loop function
static int chat_loop(LlamaData & llama_data, const std::string & user, const std::string & chat_template_file, bool use_jinja) {
static int chat_loop(LlamaData & llama_data, const Opt & opt) {
int prev_len = 0;
llama_data.fmtted.resize(llama_n_ctx(llama_data.context.get()));
std::string chat_template = "";
if(!chat_template_file.empty()){
chat_template = read_chat_template_file(chat_template_file);
std::string chat_template;
if (!opt.chat_template_file.empty()) {
chat_template = read_chat_template_file(opt.chat_template_file);
}
auto chat_templates = common_chat_templates_init(llama_data.model.get(), chat_template.empty() ? nullptr : chat_template);
common_chat_templates_ptr chat_templates = common_chat_templates_init(llama_data.model.get(), chat_template);
static const bool stdout_a_terminal = is_stdout_a_terminal();
while (true) {
// Get user input
std::string user_input;
if (get_user_input(user_input, user) == 1) {
if (get_user_input(user_input, opt.user) == 1) {
return 0;
}
add_message("user", user.empty() ? user_input : user, llama_data);
int new_len;
if (apply_chat_template_with_error_handling(chat_templates.get(), llama_data, true, new_len, use_jinja) < 0) {
const int ret = process_user_message(opt, user_input, llama_data, chat_templates, prev_len, stdout_a_terminal);
if (ret == 1) {
return 1;
}
std::string prompt(llama_data.fmtted.begin() + prev_len, llama_data.fmtted.begin() + new_len);
std::string response;
if (generate_response(llama_data, prompt, response, stdout_a_terminal)) {
return 1;
}
if (!user.empty()) {
} else if (ret == 2) {
break;
}
add_message("assistant", response, llama_data);
if (apply_chat_template_with_error_handling(chat_templates.get(), llama_data, false, prev_len, use_jinja) < 0) {
return 1;
}
}
return 0;
@@ -1208,7 +1208,7 @@ int main(int argc, const char ** argv) {
return 1;
}
if (chat_loop(llama_data, opt.user, opt.chat_template_file, opt.use_jinja)) {
if (chat_loop(llama_data, opt)) {
return 1;
}

View File

@@ -13,6 +13,7 @@ Set of LLM REST APIs and a simple web front end to interact with llama.cpp.
* Multimodal (wip)
* Monitoring endpoints
* Schema-constrained JSON response format
* [Function calling](../../docs/function-calling.md) / tool use for ~any model
The project is under active development, and we are [looking for feedback and contributors](https://github.com/ggml-org/llama.cpp/issues/4216).
@@ -1120,381 +1121,9 @@ curl http://localhost:8080/v1/chat/completions \
*Tool call support*
[Function calling](https://platform.openai.com/docs/guides/function-calling) is supported for all models (see https://github.com/ggml-org/llama.cpp/pull/9639):
[OpenAI-style function calling](https://platform.openai.com/docs/guides/function-calling) is supported with the `--jinja` flag (and may require a `--chat-template-file` override to get the right tool-use compatible Jinja template; worst case, `--chat-template chatml` may also work).
- Requires `--jinja` flag
- Native tool call formats supported:
- Llama 3.1 / 3.3 (including builtin tools support - tool names for `wolfram_alpha`, `web_search` / `brave_search`, `code_interpreter`), Llama 3.2
- Functionary v3.1 / v3.2
- Hermes 2/3, Qwen 2.5
- Mistral Nemo
- Firefunction v2
- Command R7B
- DeepSeek R1 (WIP / seems reluctant to call any tools?)
<details>
<summary>Show some common templates and which format handler they use</summary>
| Template | Format |
|----------|--------|
| Almawave-Velvet-14B.jinja | Hermes 2 Pro |
| AtlaAI-Selene-1-Mini-Llama-3.1-8B.jinja | Llama 3.x |
| CohereForAI-aya-expanse-8b.jinja | Generic |
| CohereForAI-c4ai-command-r-plus-default.jinja | Generic |
| CohereForAI-c4ai-command-r-plus-rag.jinja | Generic |
| CohereForAI-c4ai-command-r-plus-tool_use.jinja | Generic |
| CohereForAI-c4ai-command-r7b-12-2024-default.jinja | Command R7B (extract reasoning) |
| CohereForAI-c4ai-command-r7b-12-2024-rag.jinja | Command R7B (extract reasoning) |
| CohereForAI-c4ai-command-r7b-12-2024-tool_use.jinja | Command R7B (extract reasoning) |
| CohereForAI-c4ai-command-r7b-12-2024.jinja | Generic |
| DavieLion-Llama-3.2-1B-SPIN-iter3.jinja | Generic |
| Delta-Vector-Rei-12B.jinja | Mistral Nemo |
| EpistemeAI-Mistral-Nemo-Instruct-12B-Philosophy-Math.jinja | Mistral Nemo |
| FlofloB-83k_continued_pretraining_Qwen2.5-0.5B-Instruct_Unsloth_merged_16bit.jinja | Hermes 2 Pro |
| FlofloB-test_continued_pretraining_Phi-3-mini-4k-instruct_Unsloth_merged_16bit.jinja | Generic |
| HelpingAI-HAI-SER.jinja | Generic |
| HuggingFaceTB-SmolLM2-1.7B-Instruct.jinja | Generic |
| HuggingFaceTB-SmolLM2-135M-Instruct.jinja | Generic |
| HuggingFaceTB-SmolLM2-360M-Instruct.jinja | Generic |
| INSAIT-Institute-BgGPT-Gemma-2-27B-IT-v1.0.jinja | Generic |
| Ihor-Text2Graph-R1-Qwen2.5-0.5b.jinja | Hermes 2 Pro |
| Infinigence-Megrez-3B-Instruct.jinja | Generic |
| Josephgflowers-TinyLlama_v1.1_math_code-world-test-1.jinja | Generic |
| LGAI-EXAONE-EXAONE-3.5-2.4B-Instruct.jinja | Generic |
| LGAI-EXAONE-EXAONE-3.5-7.8B-Instruct.jinja | Generic |
| LatitudeGames-Wayfarer-12B.jinja | Generic |
| Magpie-Align-Llama-3-8B-Magpie-Align-v0.1.jinja | Generic |
| Magpie-Align-Llama-3.1-8B-Magpie-Align-v0.1.jinja | Generic |
| MaziyarPanahi-calme-3.2-instruct-78b.jinja | Generic |
| MiniMaxAI-MiniMax-Text-01.jinja | Generic |
| MiniMaxAI-MiniMax-VL-01.jinja | Generic |
| NaniDAO-deepseek-r1-qwen-2.5-32B-ablated.jinja | DeepSeek R1 (extract reasoning) |
| NexaAIDev-Octopus-v2.jinja | Generic |
| NousResearch-Hermes-2-Pro-Llama-3-8B-default.jinja | Generic |
| NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja | Hermes 2 Pro |
| NousResearch-Hermes-2-Pro-Mistral-7B-default.jinja | Generic |
| NousResearch-Hermes-2-Pro-Mistral-7B-tool_use.jinja | Hermes 2 Pro |
| NousResearch-Hermes-3-Llama-3.1-70B-default.jinja | Generic |
| NousResearch-Hermes-3-Llama-3.1-70B-tool_use.jinja | Hermes 2 Pro |
| NovaSky-AI-Sky-T1-32B-Flash.jinja | Hermes 2 Pro |
| NovaSky-AI-Sky-T1-32B-Preview.jinja | Hermes 2 Pro |
| OnlyCheeini-greesychat-turbo.jinja | Generic |
| Orenguteng-Llama-3.1-8B-Lexi-Uncensored-V2.jinja | Llama 3.x |
| OrionStarAI-Orion-14B-Chat.jinja | Generic |
| PowerInfer-SmallThinker-3B-Preview.jinja | Generic |
| PrimeIntellect-INTELLECT-1-Instruct.jinja | Generic |
| Qwen-QVQ-72B-Preview.jinja | Generic |
| Qwen-QwQ-32B-Preview.jinja | Hermes 2 Pro |
| Qwen-Qwen1.5-7B-Chat.jinja | Generic |
| Qwen-Qwen2-7B-Instruct.jinja | Generic |
| Qwen-Qwen2-VL-72B-Instruct.jinja | Generic |
| Qwen-Qwen2-VL-7B-Instruct.jinja | Generic |
| Qwen-Qwen2.5-0.5B.jinja | Hermes 2 Pro |
| Qwen-Qwen2.5-1.5B-Instruct.jinja | Hermes 2 Pro |
| Qwen-Qwen2.5-14B-Instruct-1M.jinja | Hermes 2 Pro |
| Qwen-Qwen2.5-14B.jinja | Hermes 2 Pro |
| Qwen-Qwen2.5-32B-Instruct.jinja | Hermes 2 Pro |
| Qwen-Qwen2.5-32B.jinja | Hermes 2 Pro |
| Qwen-Qwen2.5-3B-Instruct.jinja | Hermes 2 Pro |
| Qwen-Qwen2.5-72B-Instruct.jinja | Hermes 2 Pro |
| Qwen-Qwen2.5-7B-Instruct-1M.jinja | Hermes 2 Pro |
| Qwen-Qwen2.5-7B-Instruct.jinja | Hermes 2 Pro |
| Qwen-Qwen2.5-7B.jinja | Hermes 2 Pro |
| Qwen-Qwen2.5-Coder-32B-Instruct.jinja | Hermes 2 Pro |
| Qwen-Qwen2.5-Coder-7B-Instruct.jinja | Hermes 2 Pro |
| Qwen-Qwen2.5-Math-1.5B.jinja | Hermes 2 Pro |
| Qwen-Qwen2.5-Math-7B-Instruct.jinja | Hermes 2 Pro |
| Qwen-Qwen2.5-VL-3B-Instruct.jinja | Hermes 2 Pro |
| Qwen-Qwen2.5-VL-72B-Instruct.jinja | Hermes 2 Pro |
| Qwen-Qwen2.5-VL-7B-Instruct.jinja | Hermes 2 Pro |
| RWKV-Red-Team-ARWKV-7B-Preview-0.1.jinja | Hermes 2 Pro |
| SakanaAI-TinySwallow-1.5B-Instruct.jinja | Hermes 2 Pro |
| SakanaAI-TinySwallow-1.5B.jinja | Hermes 2 Pro |
| Sao10K-70B-L3.3-Cirrus-x1.jinja | Llama 3.x |
| SentientAGI-Dobby-Mini-Leashed-Llama-3.1-8B.jinja | Llama 3.x |
| SentientAGI-Dobby-Mini-Unhinged-Llama-3.1-8B.jinja | Llama 3.x |
| Steelskull-L3.3-Damascus-R1.jinja | Llama 3.x |
| Steelskull-L3.3-MS-Nevoria-70b.jinja | Llama 3.x |
| Steelskull-L3.3-Nevoria-R1-70b.jinja | Llama 3.x |
| THUDM-glm-4-9b-chat.jinja | Generic |
| THUDM-glm-edge-1.5b-chat.jinja | Generic |
| Tarek07-Progenitor-V1.1-LLaMa-70B.jinja | Llama 3.x |
| TheBloke-FusionNet_34Bx2_MoE-AWQ.jinja | Generic |
| TinyLlama-TinyLlama-1.1B-Chat-v1.0.jinja | Generic |
| UCLA-AGI-Mistral7B-PairRM-SPPO-Iter3.jinja | Generic |
| ValiantLabs-Llama3.1-8B-Enigma.jinja | Llama 3.x |
| abacusai-Fewshot-Metamath-OrcaVicuna-Mistral.jinja | Generic |
| ai21labs-AI21-Jamba-1.5-Large.jinja | Generic |
| allenai-Llama-3.1-Tulu-3-405B-SFT.jinja | Generic |
| allenai-Llama-3.1-Tulu-3-405B.jinja | Generic |
| allenai-Llama-3.1-Tulu-3-8B.jinja | Generic |
| arcee-ai-Virtuoso-Lite.jinja | Hermes 2 Pro |
| arcee-ai-Virtuoso-Medium-v2.jinja | Hermes 2 Pro |
| arcee-ai-Virtuoso-Small-v2.jinja | Hermes 2 Pro |
| avemio-GRAG-NEMO-12B-ORPO-HESSIAN-AI.jinja | Generic |
| bespokelabs-Bespoke-Stratos-7B.jinja | Hermes 2 Pro |
| bfuzzy1-acheron-m1a-llama.jinja | Generic |
| bofenghuang-vigogne-2-70b-chat.jinja | Generic |
| bytedance-research-UI-TARS-72B-DPO.jinja | Generic |
| bytedance-research-UI-TARS-7B-DPO.jinja | Generic |
| bytedance-research-UI-TARS-7B-SFT.jinja | Generic |
| carsenk-phi3.5_mini_exp_825_uncensored.jinja | Generic |
| cyberagent-DeepSeek-R1-Distill-Qwen-14B-Japanese.jinja | DeepSeek R1 (extract reasoning) |
| cyberagent-DeepSeek-R1-Distill-Qwen-32B-Japanese.jinja | DeepSeek R1 (extract reasoning) |
| databricks-dbrx-instruct.jinja | Generic |
| deepseek-ai-DeepSeek-Coder-V2-Instruct.jinja | Generic |
| deepseek-ai-DeepSeek-Coder-V2-Lite-Base.jinja | Generic |
| deepseek-ai-DeepSeek-Coder-V2-Lite-Instruct.jinja | Generic |
| deepseek-ai-DeepSeek-R1-Distill-Llama-70B.jinja | DeepSeek R1 (extract reasoning) |
| deepseek-ai-DeepSeek-R1-Distill-Llama-8B.jinja | DeepSeek R1 (extract reasoning) |
| deepseek-ai-DeepSeek-R1-Distill-Qwen-1.5B.jinja | DeepSeek R1 (extract reasoning) |
| deepseek-ai-DeepSeek-R1-Distill-Qwen-14B.jinja | DeepSeek R1 (extract reasoning) |
| deepseek-ai-DeepSeek-R1-Distill-Qwen-32B.jinja | DeepSeek R1 (extract reasoning) |
| deepseek-ai-DeepSeek-R1-Distill-Qwen-7B.jinja | DeepSeek R1 (extract reasoning) |
| deepseek-ai-DeepSeek-R1-Zero.jinja | DeepSeek R1 (extract reasoning) |
| deepseek-ai-DeepSeek-R1.jinja | DeepSeek R1 (extract reasoning) |
| deepseek-ai-DeepSeek-V2-Lite.jinja | Generic |
| deepseek-ai-DeepSeek-V2.5.jinja | DeepSeek R1 (extract reasoning) |
| deepseek-ai-DeepSeek-V3.jinja | DeepSeek R1 (extract reasoning) |
| deepseek-ai-deepseek-coder-33b-instruct.jinja | Generic |
| deepseek-ai-deepseek-coder-6.7b-instruct.jinja | Generic |
| deepseek-ai-deepseek-coder-7b-instruct-v1.5.jinja | Generic |
| deepseek-ai-deepseek-llm-67b-chat.jinja | Generic |
| deepseek-ai-deepseek-llm-7b-chat.jinja | Generic |
| dicta-il-dictalm2.0-instruct.jinja | Generic |
| ehristoforu-Falcon3-8B-Franken-Basestruct.jinja | Hermes 2 Pro |
| fireworks-ai-llama-3-firefunction-v2.jinja | FireFunction v2 |
| godlikehhd-alpaca_data_sampled_ifd_new_5200.jinja | Hermes 2 Pro |
| godlikehhd-alpaca_data_score_max_0.7_2600.jinja | Hermes 2 Pro |
| google-gemma-2-27b-it.jinja | Generic |
| google-gemma-2-2b-it.jinja | Generic |
| google-gemma-2-2b-jpn-it.jinja | Generic |
| google-gemma-7b-it.jinja | Generic |
| huihui-ai-DeepSeek-R1-Distill-Llama-70B-abliterated.jinja | DeepSeek R1 (extract reasoning) |
| huihui-ai-DeepSeek-R1-Distill-Llama-8B-abliterated.jinja | DeepSeek R1 (extract reasoning) |
| huihui-ai-DeepSeek-R1-Distill-Qwen-14B-abliterated-v2.jinja | DeepSeek R1 (extract reasoning) |
| huihui-ai-DeepSeek-R1-Distill-Qwen-32B-abliterated.jinja | DeepSeek R1 (extract reasoning) |
| huihui-ai-DeepSeek-R1-Distill-Qwen-7B-abliterated-v2.jinja | DeepSeek R1 (extract reasoning) |
| huihui-ai-Qwen2.5-14B-Instruct-1M-abliterated.jinja | Hermes 2 Pro |
| ibm-granite-granite-3.1-8b-instruct.jinja | Generic |
| indischepartij-MiniCPM-3B-OpenHermes-2.5-v2.jinja | Generic |
| inflatebot-MN-12B-Mag-Mell-R1.jinja | Generic |
| jinaai-ReaderLM-v2.jinja | Generic |
| kms7530-chemeng_qwen-math-7b_24_1_100_1_nonmath.jinja | Hermes 2 Pro |
| knifeayumu-Cydonia-v1.3-Magnum-v4-22B.jinja | Mistral Nemo |
| langgptai-qwen1.5-7b-chat-sa-v0.1.jinja | Generic |
| lightblue-DeepSeek-R1-Distill-Qwen-7B-Japanese.jinja | DeepSeek R1 (extract reasoning) |
| mattshumer-Reflection-Llama-3.1-70B.jinja | Generic |
| meetkai-functionary-medium-v3.1.jinja | Functionary v3.1 Llama 3.1 |
| meetkai-functionary-medium-v3.2.jinja | Functionary v3.2 |
| meta-llama-Llama-2-7b-chat-hf.jinja | Generic |
| meta-llama-Llama-3.1-8B-Instruct.jinja | Llama 3.x |
| meta-llama-Llama-3.2-11B-Vision-Instruct.jinja | Llama 3.x |
| meta-llama-Llama-3.2-1B-Instruct.jinja | Llama 3.x |
| meta-llama-Llama-3.2-3B-Instruct.jinja | Llama 3.x |
| meta-llama-Llama-3.3-70B-Instruct.jinja | Llama 3.x |
| meta-llama-Meta-Llama-3-8B-Instruct.jinja | Generic |
| meta-llama-Meta-Llama-3.1-8B-Instruct.jinja | Llama 3.x |
| microsoft-Phi-3-medium-4k-instruct.jinja | Generic |
| microsoft-Phi-3-mini-4k-instruct.jinja | Generic |
| microsoft-Phi-3-small-8k-instruct.jinja | Generic |
| microsoft-Phi-3.5-mini-instruct.jinja | Generic |
| microsoft-Phi-3.5-vision-instruct.jinja | Generic |
| microsoft-phi-4.jinja | Generic |
| migtissera-Tess-3-Mistral-Nemo-12B.jinja | Generic |
| ministral-Ministral-3b-instruct.jinja | Generic |
| mistralai-Codestral-22B-v0.1.jinja | Generic |
| mistralai-Mistral-7B-Instruct-v0.1.jinja | Generic |
| mistralai-Mistral-7B-Instruct-v0.2.jinja | Generic |
| mistralai-Mistral-7B-Instruct-v0.3.jinja | Mistral Nemo |
| mistralai-Mistral-Large-Instruct-2407.jinja | Mistral Nemo |
| mistralai-Mistral-Large-Instruct-2411.jinja | Generic |
| mistralai-Mistral-Nemo-Instruct-2407.jinja | Mistral Nemo |
| mistralai-Mistral-Small-24B-Instruct-2501.jinja | Generic |
| mistralai-Mixtral-8x7B-Instruct-v0.1.jinja | Generic |
| mkurman-Qwen2.5-14B-DeepSeek-R1-1M.jinja | Hermes 2 Pro |
| mlabonne-AlphaMonarch-7B.jinja | Generic |
| mlx-community-Josiefied-Qwen2.5-0.5B-Instruct-abliterated-v1-float32.jinja | Hermes 2 Pro |
| mlx-community-Qwen2.5-VL-7B-Instruct-8bit.jinja | Hermes 2 Pro |
| mobiuslabsgmbh-DeepSeek-R1-ReDistill-Qwen-1.5B-v1.1.jinja | DeepSeek R1 (extract reasoning) |
| netcat420-MFANNv0.20.jinja | Generic |
| netcat420-MFANNv0.24.jinja | Generic |
| netease-youdao-Confucius-o1-14B.jinja | Hermes 2 Pro |
| nvidia-AceMath-7B-RM.jinja | Hermes 2 Pro |
| nvidia-Eagle2-1B.jinja | Hermes 2 Pro |
| nvidia-Eagle2-9B.jinja | Hermes 2 Pro |
| nvidia-Llama-3.1-Nemotron-70B-Instruct-HF.jinja | Llama 3.x |
| onnx-community-DeepSeek-R1-Distill-Qwen-1.5B-ONNX.jinja | DeepSeek R1 (extract reasoning) |
| open-thoughts-OpenThinker-7B.jinja | Hermes 2 Pro |
| openchat-openchat-3.5-0106.jinja | Generic |
| pankajmathur-orca_mini_v6_8b.jinja | Generic |
| princeton-nlp-Mistral-7B-Base-SFT-RDPO.jinja | Generic |
| princeton-nlp-Mistral-7B-Instruct-DPO.jinja | Generic |
| princeton-nlp-Mistral-7B-Instruct-RDPO.jinja | Generic |
| prithivMLmods-Bellatrix-Tiny-1.5B-R1.jinja | Hermes 2 Pro |
| prithivMLmods-Bellatrix-Tiny-1B-R1.jinja | Llama 3.x |
| prithivMLmods-Bellatrix-Tiny-1B-v3.jinja | Generic |
| prithivMLmods-Bellatrix-Tiny-3B-R1.jinja | Llama 3.x |
| prithivMLmods-Blaze-14B-xElite.jinja | Generic |
| prithivMLmods-Calcium-Opus-14B-Elite2-R1.jinja | Hermes 2 Pro |
| prithivMLmods-Calme-Ties-78B.jinja | Generic |
| prithivMLmods-Calme-Ties2-78B.jinja | Generic |
| prithivMLmods-Calme-Ties3-78B.jinja | Generic |
| prithivMLmods-ChemQwen2-vL.jinja | Generic |
| prithivMLmods-GWQ2b.jinja | Generic |
| prithivMLmods-LatexMind-2B-Codec.jinja | Generic |
| prithivMLmods-Llama-3.2-6B-AlgoCode.jinja | Llama 3.x |
| prithivMLmods-Megatron-Opus-14B-Exp.jinja | Hermes 2 Pro |
| prithivMLmods-Megatron-Opus-14B-Stock.jinja | Hermes 2 Pro |
| prithivMLmods-Megatron-Opus-7B-Exp.jinja | Hermes 2 Pro |
| prithivMLmods-Omni-Reasoner-Merged.jinja | Hermes 2 Pro |
| prithivMLmods-Omni-Reasoner4-Merged.jinja | Hermes 2 Pro |
| prithivMLmods-Primal-Opus-14B-Optimus-v1.jinja | Hermes 2 Pro |
| prithivMLmods-QwQ-Math-IO-500M.jinja | Hermes 2 Pro |
| prithivMLmods-Qwen-7B-Distill-Reasoner.jinja | DeepSeek R1 (extract reasoning) |
| prithivMLmods-Qwen2.5-1.5B-DeepSeek-R1-Instruct.jinja | Hermes 2 Pro |
| prithivMLmods-Qwen2.5-14B-DeepSeek-R1-1M.jinja | Hermes 2 Pro |
| prithivMLmods-Qwen2.5-32B-DeepSeek-R1-Instruct.jinja | Hermes 2 Pro |
| prithivMLmods-Qwen2.5-7B-DeepSeek-R1-1M.jinja | Hermes 2 Pro |
| prithivMLmods-Triangulum-v2-10B.jinja | Hermes 2 Pro |
| qingy2024-Falcon3-2x10B-MoE-Instruct.jinja | Hermes 2 Pro |
| rubenroy-Zurich-14B-GCv2-5m.jinja | Hermes 2 Pro |
| rubenroy-Zurich-7B-GCv2-5m.jinja | Hermes 2 Pro |
| silma-ai-SILMA-Kashif-2B-Instruct-v1.0.jinja | Generic |
| simplescaling-s1-32B.jinja | Hermes 2 Pro |
| sometimesanotion-Lamarck-14B-v0.7.jinja | Hermes 2 Pro |
| sonthenguyen-zephyr-sft-bnb-4bit-DPO-mtbr-180steps.jinja | Generic |
| sthenno-tempesthenno-icy-0130.jinja | Generic |
| sumink-qwft.jinja | Hermes 2 Pro |
| teknium-OpenHermes-2.5-Mistral-7B.jinja | Generic |
| thirdeyeai-elevate360m.jinja | Generic |
| tiiuae-Falcon3-10B-Instruct.jinja | Hermes 2 Pro |
| unsloth-DeepSeek-R1-Distill-Llama-8B-unsloth-bnb-4bit.jinja | DeepSeek R1 (extract reasoning) |
| unsloth-DeepSeek-R1-Distill-Llama-8B.jinja | DeepSeek R1 (extract reasoning) |
| unsloth-DeepSeek-R1.jinja | DeepSeek R1 (extract reasoning) |
| unsloth-Mistral-Small-24B-Instruct-2501-unsloth-bnb-4bit.jinja | Generic |
| upstage-solar-pro-preview-instruct.jinja | Generic |
| whyhow-ai-PatientSeek.jinja | Generic |
| xwen-team-Xwen-72B-Chat.jinja | Hermes 2 Pro |
| xwen-team-Xwen-7B-Chat.jinja | Hermes 2 Pro |
This table can be generated with:
```bash
./build/bin/test-chat ../minja/build/tests/*.jinja 2>/dev/null
```
</details>
- Generic tool call is supported when the template isn't recognized by native format handlers (you'll see `Chat format: Generic` in the logs).
- Use `--chat-template-file` to override the template when appropriate (see examples below)
- Generic support may consume more tokens and be less efficient than a model's native format.
- Run with:
```shell
# Native support:
llama-server --jinja -fa -hf bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M
llama-server --jinja -fa -hf bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q6_K_L
llama-server --jinja -fa -hf bartowski/functionary-small-v3.2-GGUF:Q4_K_M
llama-server --jinja -fa -hf bartowski/Llama-3.3-70B-Instruct-GGUF:Q4_K_M
# Native support for DeepSeek R1 works best w/ our own template (official template buggy)
llama-server --jinja -fa -hf bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q6_K_L \
--chat-template-file models/templates/llama-cpp-deepseek-r1.jinja
llama-server --jinja -fa -hf bartowski/DeepSeek-R1-Distill-Qwen-32B-GGUF:Q4_K_M \
--chat-template-file models/templates/llama-cpp-deepseek-r1.jinja
# Native support requires the right template for these GGUFs:
llama-server --jinja -fa -hf bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M \
--chat-template-file <( python scripts/get_chat_template.py NousResearch/Hermes-2-Pro-Llama-3-8B tool_use )
llama-server --jinja -fa -hf bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M \
--chat-template-file <( python scripts/get_chat_template.py NousResearch/Hermes-3-Llama-3.1-8B tool_use )
llama-server --jinja -fa -hf bartowski/firefunction-v2-GGUF -hff firefunction-v2-IQ1_M.gguf \
--chat-template-file <( python scripts/get_chat_template.py fireworks-ai/llama-3-firefunction-v2 tool_use )
llama-server --jinja -fa -hf bartowski/c4ai-command-r7b-12-2024-GGUF:Q6_K_L \
--chat-template-file <( python scripts/get_chat_template.py CohereForAI/c4ai-command-r7b-12-2024 tool_use )
# Generic format support
llama-server --jinja -fa -hf bartowski/phi-4-GGUF:Q4_0
llama-server --jinja -fa -hf bartowski/gemma-2-2b-it-GGUF:Q8_0
llama-server --jinja -fa -hf bartowski/c4ai-command-r-v01-GGUF:Q2_K
```
- Test in CLI:
```bash
curl http://localhost:8080/v1/chat/completions -d '{
"model": "gpt-3.5-turbo",
"tools": [
{
"type":"function",
"function":{
"name":"python",
"description":"Runs code in an ipython interpreter and returns the result of the execution after 60 seconds.",
"parameters":{
"type":"object",
"properties":{
"code":{
"type":"string",
"description":"The code to run in the ipython interpreter."
}
},
"required":["code"]
}
}
}
],
"messages": [
{
"role": "user",
"content": "Print a hello world message with python."
}
]
}'
```
<details>
<summary>Show output</summary>
```json
{
"choices": [
{
"finish_reason": "tool",
"index": 0,
"message": {
"content": null,
"tool_calls": [
{
"name": "python",
"arguments": "{\"code\":\" \\nprint(\\\"Hello, World!\\\")\"}"
}
],
"role": "assistant"
}
}
],
"created": 1727287211,
"model": "gpt-3.5-turbo",
"object": "chat.completion",
"usage": {
"completion_tokens": 16,
"prompt_tokens": 44,
"total_tokens": 60
},
"id": "chatcmpl-Htbgh9feMmGM0LEH2hmQvwsCxq3c6Ni8"
}
```
</details>
**See our [Function calling](../../docs/function-calling.md) docs** for more details, supported native tool call styles (generic tool call style is used as fallback) / examples of use.
### POST `/v1/embeddings`: OpenAI-compatible embeddings API

View File

@@ -7,6 +7,8 @@
// increase max payload length to allow use of larger context size
#define CPPHTTPLIB_FORM_URL_ENCODED_PAYLOAD_MAX_LENGTH 1048576
// disable Nagle's algorithm
#define CPPHTTPLIB_TCP_NODELAY true
#include "httplib.h"
// Change JSON_ASSERT from assert() to GGML_ASSERT:
@@ -519,8 +521,13 @@ static json oaicompat_completion_params_parse(const json & body) {
throw std::runtime_error("Only one completion choice is allowed");
}
// Handle "echo" field
if (json_value(body, "echo", false)) {
throw std::runtime_error("Only no echo is supported");
}
// Params supported by OAI but unsupported by llama.cpp
static const std::vector<std::string> unsupported_params { "best_of", "echo", "suffix" };
static const std::vector<std::string> unsupported_params { "best_of", "suffix" };
for (const auto & param : unsupported_params) {
if (body.contains(param)) {
throw std::runtime_error("Unsupported param: " + param);
@@ -596,7 +603,7 @@ static json oaicompat_completion_params_parse(
inputs.tool_choice = common_chat_tool_choice_parse_oaicompat(json_value(body, "tool_choice", std::string("auto")));
inputs.json_schema = json_schema.is_null() ? "" : json_schema.dump();
inputs.grammar = grammar;
inputs.add_generation_prompt = true;
inputs.add_generation_prompt = json_value(body, "add_generation_prompt", true);
inputs.use_jinja = use_jinja;
inputs.parallel_tool_calls = json_value(body, "parallel_tool_calls", false);
inputs.extract_reasoning = reasoning_format != COMMON_REASONING_FORMAT_NONE;

View File

@@ -3,7 +3,7 @@
# MIT license
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: MIT
export ONEAPI_DEVICE_SELECTOR="level_zero:0"
source /opt/intel/oneapi/setvars.sh
#export GGML_SYCL_DEBUG=1
@@ -13,7 +13,7 @@ source /opt/intel/oneapi/setvars.sh
INPUT_PROMPT="Building a website can be done in 10 simple steps:\nStep 1:"
MODEL_FILE=models/llama-2-7b.Q4_0.gguf
NGL=33
CONEXT=8192
CONEXT=4096
if [ $# -gt 0 ]; then
GGML_SYCL_DEVICE=$1