	CPU/CUDA: Gemma 2 FlashAttention support (#8542)
* CPU/CUDA: Gemma 2 FlashAttention support
* apply logit_softcap to scale in kernel
* disable logit softcapping tests on Metal
* remove metal check
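The key trick noted in "apply logit_softcap to scale in kernel" is that the softcap division can be folded into the precomputed scale once per tensor, so the inner loop only pays for one extra tanhf. Below is a minimal standalone C sketch, not part of the commit, checking numerically that the fused form used by the kernel matches the reference definition cap * tanh(q·k * scale / cap); the head size and cap values are made up for illustration.

#include <assert.h>
#include <math.h>

// reference definition of a softcapped attention logit
static float softcap_reference(float q_dot_k, float scale, float cap) {
    return cap * tanhf(q_dot_k * scale / cap);
}

// fused form: divide the cap into `scale` once, cf. `scale /= logit_softcap` below
static float softcap_fused(float q_dot_k, float scale, float cap) {
    const float scale_fused = scale / cap; // one-time setup outside the inner loop
    return cap * tanhf(q_dot_k * scale_fused);
}

int main(void) {
    const float scale = 1.0f / sqrtf(256.0f); // hypothetical head size of 256
    const float cap   = 50.0f;                // illustrative softcap value
    for (float x = -1000.0f; x <= 1000.0f; x += 50.0f) {
        assert(fabsf(softcap_reference(x, scale, cap) - softcap_fused(x, scale, cap)) < 1e-4f);
    }
    return 0;
}

Both paths agree up to rounding; the fused variant is what the CPU kernel in the diff below implements.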
@@ -7095,7 +7095,8 @@ struct ggml_tensor * ggml_flash_attn_ext(
         struct ggml_tensor  * v,
         struct ggml_tensor  * mask,
         float                 scale,
-        float                 max_bias) {
+        float                 max_bias,
+        float                 logit_softcap) {
     GGML_ASSERT(ggml_can_mul_mat(k, q));
     // TODO: check if vT can be multiplied by (k*qT)
 
@@ -7122,7 +7123,7 @@ struct ggml_tensor * ggml_flash_attn_ext(
     int64_t ne[4] = { q->ne[0], q->ne[2], q->ne[1], q->ne[3] };
     struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
 
-    float params[] = { scale, max_bias };
+    float params[] = { scale, max_bias, logit_softcap };
     ggml_set_op_params(result, params, sizeof(params));
 
     result->op   = GGML_OP_FLASH_ATTN_EXT;
@@ -7142,7 +7143,7 @@ void ggml_flash_attn_ext_set_prec(
 
     const int32_t prec_i32 = (int32_t) prec;
 
-    ggml_set_op_params_i32(a, 2, prec_i32); // scale is on first pos, max_bias on second
+    ggml_set_op_params_i32(a, 3, prec_i32); // scale is on first pos, max_bias on second
 }
 
 // ggml_flash_attn_back
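The hunk above is why the precision flag moves from op_params[2] to op_params[3]: the float block written by ggml_flash_attn_ext now holds three values, { scale, max_bias, logit_softcap }, so the int32 precision value has to live in the next slot. A standalone sketch, not part of the commit, with op_params modeled as a plain int32_t array, showing the packing done on the graph-building side and the memcpy read-back done by the compute kernel:

#include <stdint.h>
#include <stdio.h>
#include <string.h>

int main(void) {
    int32_t op_params[8] = {0}; // stand-in for the tensor's raw op_params words

    // graph-building side, cf. ggml_set_op_params(result, params, sizeof(params))
    // and ggml_set_op_params_i32(a, 3, prec_i32) in the hunks above
    const float params[3] = { 0.125f, 0.0f, 50.0f }; // scale, max_bias, logit_softcap (illustrative)
    memcpy(op_params, params, sizeof(params));       // occupies slots 0..2
    op_params[3] = 1;                                // stand-in for a GGML_PREC_* value, now in slot 3

    // compute-kernel side: read the floats back the same way the f16 kernel does
    float scale, max_bias, logit_softcap;
    memcpy(&scale,         (const float *) op_params + 0, sizeof(float));
    memcpy(&max_bias,      (const float *) op_params + 1, sizeof(float));
    memcpy(&logit_softcap, (const float *) op_params + 2, sizeof(float));

    printf("scale=%g max_bias=%g logit_softcap=%g prec=%d\n",
           scale, max_bias, logit_softcap, op_params[3]);
    return 0;
}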
@@ -15271,11 +15272,17 @@ static void ggml_compute_forward_flash_attn_ext_f16(
     const int ir0 = dr*ith;
     const int ir1 = MIN(ir0 + dr, nr);
 
-    float scale    = 1.0f;
-    float max_bias = 0.0f;
+    float scale         = 1.0f;
+    float max_bias      = 0.0f;
+    float logit_softcap = 0.0f;
 
-    memcpy(&scale,    (float *) dst->op_params + 0, sizeof(float));
-    memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float));
+    memcpy(&scale,         (float *) dst->op_params + 0, sizeof(float));
+    memcpy(&max_bias,      (float *) dst->op_params + 1, sizeof(float));
+    memcpy(&logit_softcap, (float *) dst->op_params + 2, sizeof(float));
+
+    if (logit_softcap != 0) {
+        scale /= logit_softcap;
+    }
 
     const uint32_t n_head      = neq2;
     const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head));
@@ -15339,7 +15346,13 @@ static void ggml_compute_forward_flash_attn_ext_f16(
             const char * k_data = (const char *) k->data + ( ic*nbk1 + ik2*nbk2 + ik3*nbk3);
             kq_vec_dot(D, &s, 0, k_data, 0, Q_q, 0, 1);
 
-            s = s*scale + mv; // scale KQ value and apply mask
+            s = s*scale; // scale KQ value
+
+            if (logit_softcap != 0.0f) {
+                s = logit_softcap*tanhf(s);
+            }
+
+            s += mv; // apply mask
 
             const float Mold = M;
 
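In the rewritten inner loop the order of operations is: scale the raw Q·K dot product (with the cap already divided into scale), softcap it with tanhf, and only then add the mask/ALiBi term mv. A small sketch, not part of the commit, of that per-logit pipeline in isolation; the helper function is hypothetical, the variable names mirror the hunk above.

#include <math.h>
#include <stdio.h>

// hypothetical helper mirroring the per-logit steps of the kernel above:
// s is the raw Q·K dot product, scale already has logit_softcap divided in,
// mv is the mask / ALiBi bias for this position (-INFINITY when masked out)
static float flash_attn_logit(float s, float scale, float logit_softcap, float mv) {
    s = s*scale;                    // scale KQ value
    if (logit_softcap != 0.0f) {
        s = logit_softcap*tanhf(s); // softcap keeps |s| below logit_softcap
    }
    s += mv;                        // apply mask after softcapping
    return s;
}

int main(void) {
    const float cap   = 50.0f;                         // illustrative softcap
    const float scale = (1.0f / sqrtf(256.0f)) / cap;  // cf. scale /= logit_softcap
    printf("unmasked: %f\n", flash_attn_logit(4000.0f, scale, cap, 0.0f));
    printf("masked:   %f\n", flash_attn_logit(4000.0f, scale, cap, -INFINITY));
    return 0;
}

Masked positions still end up at -inf because the mask is added last, matching the old behaviour where s*scale + mv was a single expression.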
@@ -15348,7 +15361,7 @@ static void ggml_compute_forward_flash_attn_ext_f16(
 
             const char * v_data = ((const char *) v->data + (ic*nbv1 + iv2*nbv2 + iv3*nbv3));
 
-            if (v->type== GGML_TYPE_F16) {
+            if (v->type == GGML_TYPE_F16) {
                 if (s > M) {
                     // s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f
                     M = s;
@@ -15415,7 +15428,7 @@ static void ggml_compute_forward_flash_attn_ext(
         const struct ggml_tensor * v,
         const struct ggml_tensor * mask,
         struct ggml_tensor * dst) {
-    switch (dst->op_params[2]) {
+    switch (dst->op_params[3]) {
         case GGML_PREC_DEFAULT:
         case GGML_PREC_F32:
             {