	ggml : add ggml_flash_attn_ext_get_prec

@@ -1746,6 +1746,9 @@ extern "C" {
             struct ggml_tensor * a,
             enum ggml_prec       prec);
 
+    GGML_API enum ggml_prec ggml_flash_attn_ext_get_prec(
+            const struct ggml_tensor * a);
+
     // TODO: needs to be adapted to ggml_flash_attn_ext
     GGML_API struct ggml_tensor * ggml_flash_attn_back(
             struct ggml_context * ctx,
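
The new getter pairs with the existing ggml_flash_attn_ext_set_prec. Below is a minimal build-time sketch of the intended usage; the driver program, the tensor shapes, and the ggml_flash_attn_ext argument list (q, k, v, mask, scale, max_bias, logit_softcap) are assumptions for illustration and are not part of this commit.

// sketch: set a per-node precision on a flash-attention node, then read it back
#include <math.h>
#include <stdio.h>

#include "ggml.h"

int main(void) {
    // illustrative sizes only
    const int64_t head_dim = 128, n_head = 8, n_kv = 256, n_tokens = 32;

    struct ggml_init_params params = {
        /*.mem_size   =*/ 16*1024*1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ true, // only constructing the op, no tensor data needed
    };
    struct ggml_context * ctx = ggml_init(params);

    struct ggml_tensor * q = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, head_dim, n_tokens, n_head, 1);
    struct ggml_tensor * k = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, head_dim, n_kv,     n_head, 1);
    struct ggml_tensor * v = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, head_dim, n_kv,     n_head, 1);

    struct ggml_tensor * cur = ggml_flash_attn_ext(ctx, q, k, v, /*mask =*/ NULL,
            1.0f/sqrtf((float) head_dim), /*max_bias =*/ 0.0f, /*logit_softcap =*/ 0.0f);

    // request F32 accumulation for this node, then query it back through the new API
    ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32);
    const enum ggml_prec prec = ggml_flash_attn_ext_get_prec(cur);

    printf("prec == GGML_PREC_F32 ? %d\n", prec == GGML_PREC_F32);

    ggml_free(ctx);
    return 0;
}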

@@ -13,9 +13,9 @@ static void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, g
     const ggml_tensor * KQV = dst;
     const ggml_tensor * Q   = dst->src[0];
 
-    const int32_t precision = KQV->op_params[3];
+    const enum ggml_prec prec = ggml_flash_attn_ext_get_prec(KQV);
 
-    if (precision != GGML_PREC_DEFAULT) {
+    if (prec != GGML_PREC_DEFAULT) {
         if (Q->ne[1] <= 32 || Q->ne[0] > 128) {
             constexpr int cols_per_block = 16;
             switch (Q->ne[0]) {

@@ -301,11 +301,11 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
 
     ggml_cuda_set_device(ctx.device);
     const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
-    const int32_t precision = KQV->op_params[3];
+    const enum ggml_prec prec = ggml_flash_attn_ext_get_prec(KQV);
 
     // On AMD the tile kernels perform poorly, use the vec kernel instead:
     if (cc >= CC_OFFSET_AMD) {
-        if (precision == GGML_PREC_DEFAULT && fast_fp16_available(cc)) {
+        if (prec == GGML_PREC_DEFAULT && fast_fp16_available(cc)) {
             ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
         } else {
             ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);

@@ -332,7 +332,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
     }
 
     if (Q->ne[1] == 1 && Q->ne[0] % (2*WARP_SIZE) == 0) {
-        if (precision == GGML_PREC_DEFAULT) {
+        if (prec == GGML_PREC_DEFAULT) {
             ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
             return;
         } else if(Q->ne[0] <= 128) {
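
The CUDA changes above replace the raw read of KQV->op_params[3] with the new accessor. As a hypothetical illustration (not part of the commit), the AMD-branch dispatch rule from the @@ -301,11 hunk boils down to the helper below; the function name is made up, and fast_fp16 stands in for fast_fp16_available(cc).

// hypothetical helper, not in the commit: FP16 vec kernel only when the node keeps
// the default precision and the device has fast FP16, otherwise the FP32 path
#include <stdbool.h>

#include "ggml.h"

static bool flash_attn_use_vec_f16(const struct ggml_tensor * KQV, bool fast_fp16) {
    return ggml_flash_attn_ext_get_prec(KQV) == GGML_PREC_DEFAULT && fast_fp16;
}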

@@ -4228,6 +4228,15 @@ void ggml_flash_attn_ext_set_prec(
     ggml_set_op_params_i32(a, 3, prec_i32); // scale is on first pos, max_bias on second
 }
 
+enum ggml_prec ggml_flash_attn_ext_get_prec(
+        const struct ggml_tensor * a) {
+    GGML_ASSERT(a->op == GGML_OP_FLASH_ATTN_EXT);
+
+    const int32_t prec_i32 = ggml_get_op_params_i32(a, 3);
+
+    return (enum ggml_prec) prec_i32;
+}
+
 // ggml_flash_attn_back
 
 struct ggml_tensor * ggml_flash_attn_back(