clip : use FA (#16837)

* clip : use FA

* cont : add warning about unsupported ops

* implement "auto" mode for clip flash attn

* clip : print more detailed op support info during warmup

* cont : remove obsolete comment [no ci]

* improve debugging message

* trailing space

* metal : remove stray return

---------

Co-authored-by: Xuan Son Nguyen <son@huggingface.co>
This commit is contained in:
Georgi Gerganov
2025-11-02 22:21:48 +02:00
committed by GitHub
parent cd5e3b5754
commit 2f966b8ed8
9 changed files with 194 additions and 43 deletions

View File

@@ -6,7 +6,6 @@
#include "clip-impl.h"
#include "ggml.h"
#include "ggml-cpp.h"
#include "ggml-cpu.h"
#include "ggml-alloc.h"
#include "ggml-backend.h"
#include "gguf.h"
@@ -17,17 +16,15 @@
#include <cstring>
#include <fstream>
#include <map>
#include <regex>
#include <stdexcept>
#include <unordered_set>
#include <vector>
#include <sstream>
#include <cinttypes>
#include <limits>
#include <array>
#include <numeric>
#include <functional>
// TODO: allow to pass callback from user code
struct clip_logger_state g_logger_state = {GGML_LOG_LEVEL_CONT, clip_log_callback_default, NULL};
enum ffn_op_type {
@@ -426,12 +423,14 @@ struct clip_ctx {
int max_nodes = 8192;
ggml_backend_sched_ptr sched;
clip_flash_attn_type flash_attn_type = CLIP_FLASH_ATTN_TYPE_AUTO;
// for debugging
bool debug_graph = false;
std::vector<ggml_tensor *> debug_print_tensors;
clip_ctx(clip_context_params & ctx_params) {
flash_attn_type = ctx_params.flash_attn_type;
debug_graph = std::getenv("MTMD_DEBUG_GRAPH") != nullptr;
backend_cpu = ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, nullptr);
if (!backend_cpu) {
@@ -2260,17 +2259,25 @@ private:
ggml_tensor * k = ggml_permute(ctx0, k_cur, 0, 2, 1, 3);
//cb(k, "k", il);
ggml_tensor * v = ggml_permute(ctx0, v_cur, 1, 2, 0, 3);
v = ggml_cont(ctx0, v);
//cb(k, "v", il);
ggml_tensor * cur;
// TODO @ngxson : support flash attention
{
if (ctx->flash_attn_type == CLIP_FLASH_ATTN_TYPE_ENABLED) {
ggml_tensor * v = ggml_permute(ctx0, v_cur, 0, 2, 1, 3);
k = ggml_cast(ctx0, k, GGML_TYPE_F16);
v = ggml_cast(ctx0, v, GGML_TYPE_F16);
cur = ggml_flash_attn_ext(ctx0, q, k, v, kq_mask, kq_scale, 0.0f, 0.0f);
ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32);
cur = ggml_reshape_2d(ctx0, cur, cur->ne[0]*cur->ne[1], cur->ne[2]*cur->ne[3]);
} else {
ggml_tensor * v = ggml_permute(ctx0, v_cur, 1, 2, 0, 3);
v = ggml_cont(ctx0, v);
const auto n_tokens = q->ne[1];
const auto n_head = q->ne[2];
// const auto n_kv = k->ne[1]; // for flash attention
ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
// F32 may not needed for vision encoders?
@@ -3192,7 +3199,87 @@ struct clip_model_loader {
}
}
void alloc_compute_meta(clip_ctx & ctx_clip) {
struct support_info_op {
ggml_tensor * op;
// true if the op runs on the accelerated ctx_clip.backend
bool is_accel = true;
};
struct support_info_graph {
// whether the clip_ctx.backend supports flash attention
bool fattn = true;
ggml_tensor * fattn_op = nullptr; // for debugging
std::vector<support_info_op> ops;
};
static void warmup(clip_ctx & ctx_clip) {
support_info_graph info;
if (ctx_clip.flash_attn_type == CLIP_FLASH_ATTN_TYPE_AUTO) {
// try to enable flash attention to see if it's supported
ctx_clip.flash_attn_type = CLIP_FLASH_ATTN_TYPE_ENABLED;
info = alloc_compute_meta(ctx_clip);
if (!info.fattn && info.fattn_op) {
auto op = info.fattn_op;
LOG_WRN("%s: *****************************************************************\n", __func__);
LOG_WRN("%s: WARNING: flash attention not supported by %s, memory usage will increase\n", __func__, ggml_backend_name(ctx_clip.backend));
LOG_WRN("%s: op params: \n", __func__);
static auto print_shape = [](const char * fn, const char * name, ggml_tensor * t) {
LOG_WRN("%s: %s: type = %s, ne = [%d %d %d %d], nb = [%d %d %d %d]\n", fn,
name, ggml_type_name(t->type),
t->ne[0], t->ne[1], t->ne[2], t->ne[3],
t->nb[0], t->nb[1], t->nb[2], t->nb[3]);
};
print_shape(__func__, " dst", op);
print_shape(__func__, "src0", op->src[0]);
print_shape(__func__, "src1", op->src[1]);
print_shape(__func__, "src2", op->src[2]);
LOG_WRN("%s: please report this on github as an issue\n", __func__);
LOG_WRN("%s: *****************************************************************\n", __func__);
ctx_clip.flash_attn_type = CLIP_FLASH_ATTN_TYPE_DISABLED;
alloc_compute_meta(ctx_clip);
}
} else {
info = alloc_compute_meta(ctx_clip);
if (!info.fattn && ctx_clip.flash_attn_type == CLIP_FLASH_ATTN_TYPE_ENABLED) {
LOG_WRN("%s: flash attention is not supported by the current backend; falling back to CPU (performance will be degraded)\n", __func__);
}
}
LOG_INF("%s: flash attention is %s\n", __func__,
(ctx_clip.flash_attn_type == CLIP_FLASH_ATTN_TYPE_ENABLED) ? "enabled" : "disabled");
// print ops that are not supported by the GPU backend (if there is one)
if (ctx_clip.backend && ctx_clip.backend != ctx_clip.backend_cpu) {
std::vector<support_info_op> unsupported_ops;
for (const auto & op : info.ops) {
if (!op.is_accel) {
unsupported_ops.push_back(op);
}
}
if (!unsupported_ops.empty()) {
LOG_WRN("%s: *****************************************************************\n", __func__);
LOG_WRN("%s: WARNING: the CLIP graph uses unsupported operators by the backend\n", __func__);
LOG_WRN("%s: the performance will be suboptimal \n", __func__);
LOG_WRN("%s: list of unsupported ops (backend=%s):\n", __func__, ggml_backend_name(ctx_clip.backend));
for (const auto & op : unsupported_ops) {
LOG_WRN("%s: %16s: type = %s, ne = [%d %d %d %d]\n", __func__,
ggml_op_name(op.op->op),
ggml_type_name(op.op->type),
op.op->ne[0], op.op->ne[1], op.op->ne[2], op.op->ne[3]);
}
LOG_WRN("%s: flash attention is %s\n", __func__,
(ctx_clip.flash_attn_type == CLIP_FLASH_ATTN_TYPE_ENABLED) ? "enabled" : "disabled");
LOG_WRN("%s: please report this on github as an issue\n", __func__);
LOG_WRN("%s: ref: https://github.com/ggml-org/llama.cpp/pull/16837#issuecomment-3461676118\n", __func__);
LOG_WRN("%s: *****************************************************************\n", __func__);
}
}
}
static support_info_graph alloc_compute_meta(clip_ctx & ctx_clip) {
const auto & hparams = ctx_clip.model.hparams;
ctx_clip.buf_compute_meta.resize(ctx_clip.max_nodes * ggml_tensor_overhead() + ggml_graph_overhead());
@@ -3223,57 +3310,95 @@ struct clip_model_loader {
size / 1024.0 / 1024.0);
}
}
const int n_splits = ggml_backend_sched_get_n_splits(ctx_clip.sched.get());
const int n_nodes = ggml_graph_n_nodes(gf);
LOG_INF("%s: graph splits = %d, nodes = %d\n", __func__, n_splits, n_nodes);
support_info_graph res {
/*.fattn = */ true,
/*.fattn_op = */ nullptr,
/*.ops = */ {},
};
// check op support
for (int i = 0; i < ggml_graph_n_nodes(gf); i++) {
ggml_tensor * node = ggml_graph_node(gf, i);
res.ops.push_back({node, true});
if (!ggml_backend_supports_op(ctx_clip.backend, node)) {
res.ops.back().is_accel = false;
if (node->op == GGML_OP_FLASH_ATTN_EXT) {
res.fattn = false;
res.fattn_op = node;
}
}
}
return res;
}
void get_bool(const std::string & key, bool & output, bool required = true) {
void get_bool(const std::string & key, bool & output, bool required = true) const {
const int i = gguf_find_key(ctx_gguf.get(), key.c_str());
if (i < 0) {
if (required) throw std::runtime_error("Key not found: " + key);
if (required) {
throw std::runtime_error("Key not found: " + key);
}
return;
}
output = gguf_get_val_bool(ctx_gguf.get(), i);
}
void get_i32(const std::string & key, int & output, bool required = true) {
void get_i32(const std::string & key, int & output, bool required = true) const {
const int i = gguf_find_key(ctx_gguf.get(), key.c_str());
if (i < 0) {
if (required) throw std::runtime_error("Key not found: " + key);
if (required) {
throw std::runtime_error("Key not found: " + key);
}
return;
}
output = gguf_get_val_i32(ctx_gguf.get(), i);
}
void get_u32(const std::string & key, int & output, bool required = true) {
void get_u32(const std::string & key, int & output, bool required = true) const {
const int i = gguf_find_key(ctx_gguf.get(), key.c_str());
if (i < 0) {
if (required) throw std::runtime_error("Key not found: " + key);
if (required) {
throw std::runtime_error("Key not found: " + key);
}
return;
}
output = gguf_get_val_u32(ctx_gguf.get(), i);
}
void get_f32(const std::string & key, float & output, bool required = true) {
void get_f32(const std::string & key, float & output, bool required = true) const {
const int i = gguf_find_key(ctx_gguf.get(), key.c_str());
if (i < 0) {
if (required) throw std::runtime_error("Key not found: " + key);
if (required) {
throw std::runtime_error("Key not found: " + key);
}
return;
}
output = gguf_get_val_f32(ctx_gguf.get(), i);
}
void get_string(const std::string & key, std::string & output, bool required = true) {
void get_string(const std::string & key, std::string & output, bool required = true) const {
const int i = gguf_find_key(ctx_gguf.get(), key.c_str());
if (i < 0) {
if (required) throw std::runtime_error("Key not found: " + key);
if (required) {
throw std::runtime_error("Key not found: " + key);
}
return;
}
output = std::string(gguf_get_val_str(ctx_gguf.get(), i));
}
void get_arr_int(const std::string & key, std::vector<int> & output, bool required = true) {
void get_arr_int(const std::string & key, std::vector<int> & output, bool required = true) const {
const int i = gguf_find_key(ctx_gguf.get(), key.c_str());
if (i < 0) {
if (required) throw std::runtime_error("Key not found: " + key);
if (required) {
throw std::runtime_error("Key not found: " + key);
}
return;
}
int n = gguf_get_arr_n(ctx_gguf.get(), i);
@@ -3284,7 +3409,7 @@ struct clip_model_loader {
}
}
void set_llava_uhd_res_candidates(clip_model & model, const int max_patches_per_side) {
static void set_llava_uhd_res_candidates(clip_model & model, const int max_patches_per_side) {
auto & hparams = model.hparams;
for (int x = 1; x <= max_patches_per_side; x++) {
for (int y = 1; y <= max_patches_per_side; y++) {
@@ -3312,24 +3437,22 @@ struct clip_init_result clip_init(const char * fname, struct clip_context_params
ctx_vision = new clip_ctx(ctx_params);
loader.load_hparams(ctx_vision->model, CLIP_MODALITY_VISION);
loader.load_tensors(*ctx_vision);
loader.alloc_compute_meta(*ctx_vision);
loader.warmup(*ctx_vision);
}
if (loader.has_audio) {
ctx_audio = new clip_ctx(ctx_params);
loader.load_hparams(ctx_audio->model, CLIP_MODALITY_AUDIO);
loader.load_tensors(*ctx_audio);
loader.alloc_compute_meta(*ctx_audio);
loader.warmup(*ctx_audio);
}
} catch (const std::exception & e) {
LOG_ERR("%s: failed to load model '%s': %s\n", __func__, fname, e.what());
if (ctx_vision) {
delete ctx_vision;
}
if (ctx_audio) {
delete ctx_audio;
}
delete ctx_vision;
delete ctx_audio;
return {nullptr, nullptr};
}
@@ -3367,10 +3490,10 @@ void clip_image_size_free(struct clip_image_size * load_image_size) {
}
delete load_image_size;
}
void clip_image_u8_free(struct clip_image_u8 * img) { if (img) delete img; }
void clip_image_f32_free(struct clip_image_f32 * img) { if (img) delete img; }
void clip_image_u8_batch_free(struct clip_image_u8_batch * batch) { if (batch) delete batch; }
void clip_image_f32_batch_free(struct clip_image_f32_batch * batch) { if (batch) delete batch; }
void clip_image_u8_free(struct clip_image_u8 * img) { delete img; }
void clip_image_f32_free(struct clip_image_f32 * img) { delete img; }
void clip_image_u8_batch_free(struct clip_image_u8_batch * batch) { delete batch; }
void clip_image_f32_batch_free(struct clip_image_f32_batch * batch) { delete batch; }
size_t clip_image_f32_batch_n_images(const struct clip_image_f32_batch * batch) {
return batch->entries.size();