clip : use FA (#16837)

* clip : use FA

* cont : add warning about unsupported ops

* implement "auto" mode for clip flash attn

* clip : print more detailed op support info during warmup

* cont : remove obsolete comment [no ci]

* improve debugging message

* trailing space

* metal : remove stray return

---------

Co-authored-by: Xuan Son Nguyen <son@huggingface.co>
This commit is contained in:
Georgi Gerganov
2025-11-02 22:21:48 +02:00
committed by GitHub
parent cd5e3b5754
commit 2f966b8ed8
9 changed files with 194 additions and 43 deletions

View File

@@ -19,7 +19,6 @@
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <limits>
#include <vector>
// represents raw image data, layout is RGBRGBRGB...
@@ -92,6 +91,15 @@ const char * mtmd_default_marker() {
return "<__media__>";
}
static clip_flash_attn_type mtmd_get_clip_flash_attn_type(enum llama_flash_attn_type flash_attn_type) {
switch (flash_attn_type) {
case LLAMA_FLASH_ATTN_TYPE_AUTO: return CLIP_FLASH_ATTN_TYPE_AUTO;
case LLAMA_FLASH_ATTN_TYPE_DISABLED: return CLIP_FLASH_ATTN_TYPE_DISABLED;
case LLAMA_FLASH_ATTN_TYPE_ENABLED: return CLIP_FLASH_ATTN_TYPE_ENABLED;
}
return CLIP_FLASH_ATTN_TYPE_AUTO;
}
mtmd_context_params mtmd_context_params_default() {
mtmd_context_params params;
params.use_gpu = true;
@@ -100,6 +108,7 @@ mtmd_context_params mtmd_context_params_default() {
params.verbosity = GGML_LOG_LEVEL_INFO;
params.image_marker = MTMD_DEFAULT_IMAGE_MARKER;
params.media_marker = mtmd_default_marker();
params.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_AUTO;
return params;
}
@@ -164,6 +173,7 @@ struct mtmd_context {
clip_context_params ctx_clip_params;
ctx_clip_params.use_gpu = ctx_params.use_gpu;
ctx_clip_params.verbosity = ctx_params.verbosity;
ctx_clip_params.flash_attn_type = mtmd_get_clip_flash_attn_type(ctx_params.flash_attn_type);
auto res = clip_init(mmproj_fname, ctx_clip_params);
ctx_v = res.ctx_v;
ctx_a = res.ctx_a;
@@ -378,9 +388,7 @@ mtmd_context * mtmd_init_from_file(const char * mmproj_fname,
}
void mtmd_free(mtmd_context * ctx) {
if (ctx) {
delete ctx;
}
delete ctx;
}
struct mtmd_tokenizer {