mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-11-03 09:22:01 +00:00 
			
		
		
		
	@@ -2204,11 +2204,7 @@ kernel void kernel_flash_attn_ext_f16(
 | 
				
			|||||||
        // pointer to the mask
 | 
					        // pointer to the mask
 | 
				
			||||||
        device const half * mp = (device const half *) (mask + iq1*nb31);
 | 
					        device const half * mp = (device const half *) (mask + iq1*nb31);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        // prepare diagonal scale matrix
 | 
					        float slope = 1.0f;
 | 
				
			||||||
        simdgroup_float8x8 mscale(scale);
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        // prepare diagonal slope matrix
 | 
					 | 
				
			||||||
        simdgroup_float8x8 mslope(1.0f);
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
        // ALiBi
 | 
					        // ALiBi
 | 
				
			||||||
        if (max_bias > 0.0f) {
 | 
					        if (max_bias > 0.0f) {
 | 
				
			||||||
@@ -2217,7 +2213,7 @@ kernel void kernel_flash_attn_ext_f16(
 | 
				
			|||||||
            const float base = h < n_head_log2 ? m0 : m1;
 | 
					            const float base = h < n_head_log2 ? m0 : m1;
 | 
				
			||||||
            const int   exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
 | 
					            const int   exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            mslope = simdgroup_float8x8(pow(base, exph));
 | 
					            slope = pow(base, exph);
 | 
				
			||||||
        }
 | 
					        }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        // loop over the KV cache
 | 
					        // loop over the KV cache
 | 
				
			||||||
@@ -2242,18 +2238,20 @@ kernel void kernel_flash_attn_ext_f16(
 | 
				
			|||||||
                        simdgroup_multiply_accumulate(mqk, mq[i], mk, mqk);
 | 
					                        simdgroup_multiply_accumulate(mqk, mq[i], mk, mqk);
 | 
				
			||||||
                    }
 | 
					                    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                    simdgroup_store(mqk, ss + 8*cc, TF, 0, false);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                    const short tx = tiisg%4;
 | 
				
			||||||
 | 
					                    const short ty = tiisg/4;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
                    if (mask != q) {
 | 
					                    if (mask != q) {
 | 
				
			||||||
                        // mqk = mqk*scale + mask*slope
 | 
					                        // mqk = mqk*scale + mask*slope
 | 
				
			||||||
                        simdgroup_half8x8 mm;
 | 
					                        ss[8*cc + ty*TF + 2*tx + 0] = scale*ss[8*cc + ty*TF + 2*tx + 0] + slope*mp[ic + 8*cc + ty*nb31/sizeof(half) + 2*tx + 0];
 | 
				
			||||||
                        simdgroup_load(mm, mp + ic + 8*cc, nb31/sizeof(half), 0, false);
 | 
					                        ss[8*cc + ty*TF + 2*tx + 1] = scale*ss[8*cc + ty*TF + 2*tx + 1] + slope*mp[ic + 8*cc + ty*nb31/sizeof(half) + 2*tx + 1];
 | 
				
			||||||
                        simdgroup_multiply(mm, mslope, mm);
 | 
					 | 
				
			||||||
                        simdgroup_multiply_accumulate(mqk, mqk, mscale, mm);
 | 
					 | 
				
			||||||
                    } else {
 | 
					                    } else {
 | 
				
			||||||
                        // mqk = mqk*scale
 | 
					                        // mqk = mqk*scale
 | 
				
			||||||
                        simdgroup_multiply(mqk, mscale, mqk);
 | 
					                        ss[8*cc + ty*TF + 2*tx + 0] *= scale;
 | 
				
			||||||
 | 
					                        ss[8*cc + ty*TF + 2*tx + 1] *= scale;
 | 
				
			||||||
                    }
 | 
					                    }
 | 
				
			||||||
 | 
					 | 
				
			||||||
                    simdgroup_store(mqk, ss + 8*cc, TF, 0, false);
 | 
					 | 
				
			||||||
                }
 | 
					                }
 | 
				
			||||||
            }
 | 
					            }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -2816,8 +2814,7 @@ kernel void kernel_cpy_f32_f16(
 | 
				
			|||||||
    for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
 | 
					    for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
 | 
				
			||||||
        device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
 | 
					        device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        // TODO: is there a better way to handle -INFINITY?
 | 
					        dst_data[i00] = src[0];
 | 
				
			||||||
        dst_data[i00] = src[0] == -INFINITY ? -MAXHALF : src[0];
 | 
					 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user