opencl: add attn sinks support for FA kernels (#15706)

This commit is contained in:
rmatif
2025-09-02 08:26:53 +02:00
committed by GitHub
parent 2f853687b3
commit 97669e4073
4 changed files with 102 additions and 16 deletions

View File

@@ -49,7 +49,9 @@ __kernel void flash_attn_f16(
const ulong mask_nb2,
const ulong mask_nb3,
const int mask_ne2,
const int mask_ne3
const int mask_ne3,
const global void* sinks_void,
const ulong sinks_offset
) {
const int tid = get_local_id(0);
const int block_q_idx = get_group_id(0);
@@ -171,6 +173,20 @@ __kernel void flash_attn_f16(
}
if (my_query_row < n_q) {
if (sinks_void != NULL) {
const global ACC_TYPE* sinks_ptr = (const global ACC_TYPE*)((const global char*)sinks_void + sinks_offset);
const ACC_TYPE m_sink = sinks_ptr[head_idx];
const ACC_TYPE m_final = max(m_i, m_sink);
const ACC_TYPE scale_o = exp(m_i - m_final);
#pragma unroll
for (int i = 0; i < DV_VEC; ++i) {
o_acc[i] *= scale_o;
}
l_i = l_i * exp(m_i - m_final) + exp(m_sink - m_final);
}
const ulong o_row_offset = batch_idx * o_nb3 + my_query_row * o_nb2 + head_idx * o_nb1;
global DATA_TYPE4 *o_row = (global DATA_TYPE4 *)(o_base + o_row_offset);
if (l_i > 0.0f) {
@@ -214,7 +230,9 @@ __kernel void flash_attn_f16_q1(
const ulong mask_nb2,
const ulong mask_nb3,
const int mask_ne2,
const int mask_ne3
const int mask_ne3,
const global void* sinks_void,
const ulong sinks_offset
) {
const int tid = get_local_id(0);
const int head_batch_idx = get_global_id(1);
@@ -247,7 +265,12 @@ __kernel void flash_attn_f16_q1(
float slope = get_alibi_slope(max_bias, head_idx, n_head_log2, m0, m1);
ACC_TYPE m_i = -INFINITY;
const global ACC_TYPE* sinks_ptr = NULL;
if (sinks_void != NULL) {
sinks_ptr = (const global ACC_TYPE*)((const global char*)sinks_void + sinks_offset);
}
ACC_TYPE m_i = (sinks_ptr != NULL) ? sinks_ptr[head_idx] : -INFINITY;
for (int k_idx = tid; k_idx < n_kv; k_idx += Q1_WG_SIZE) {
const ulong k_row_offset = batch_idx * k_nb3 + head_kv_idx * k_nb2 + k_idx * k_nb1;
const global DATA_TYPE4* k_ptr = (const global DATA_TYPE4*)(k_base + k_row_offset);
@@ -320,7 +343,11 @@ __kernel void flash_attn_f16_q1(
const ulong o_row_offset = batch_idx * o_nb3 + head_idx * o_nb1;
global DATA_TYPE4 *o_row = (global DATA_TYPE4 *)(o_base + o_row_offset);
const ACC_TYPE l_final = local_l[0];
ACC_TYPE l_final = local_l[0];
if (sinks_ptr != NULL) {
l_final += exp(sinks_ptr[head_idx] - m_final);
}
if (l_final > 0.0f) {
const ACC_TYPE l_inv = 1.0f / l_final;

View File

@@ -49,7 +49,9 @@ __kernel void flash_attn_f32(
const ulong mask_nb2,
const ulong mask_nb3,
const int mask_ne2,
const int mask_ne3
const int mask_ne3,
const global void* sinks_void,
const ulong sinks_offset
) {
const int tid = get_local_id(0);
const int block_q_idx = get_group_id(0);
@@ -171,6 +173,20 @@ __kernel void flash_attn_f32(
}
if (my_query_row < n_q) {
if (sinks_void != NULL) {
const global ACC_TYPE* sinks_ptr = (const global ACC_TYPE*)((const global char*)sinks_void + sinks_offset);
const ACC_TYPE m_sink = sinks_ptr[head_idx];
const ACC_TYPE m_final = max(m_i, m_sink);
const ACC_TYPE scale_o = exp(m_i - m_final);
#pragma unroll
for (int i = 0; i < DV_VEC; ++i) {
o_acc[i] *= scale_o;
}
l_i = l_i * exp(m_i - m_final) + exp(m_sink - m_final);
}
const ulong o_row_offset = batch_idx * o_nb3 + my_query_row * o_nb2 + head_idx * o_nb1;
global DATA_TYPE4 *o_row = (global DATA_TYPE4 *)(o_base + o_row_offset);
if (l_i > 0.0f) {
@@ -214,7 +230,9 @@ __kernel void flash_attn_f32_q1(
const ulong mask_nb2,
const ulong mask_nb3,
const int mask_ne2,
const int mask_ne3
const int mask_ne3,
const global void* sinks_void,
const ulong sinks_offset
) {
const int tid = get_local_id(0);
const int head_batch_idx = get_global_id(1);
@@ -247,7 +265,12 @@ __kernel void flash_attn_f32_q1(
float slope = get_alibi_slope(max_bias, head_idx, n_head_log2, m0, m1);
ACC_TYPE m_i = -INFINITY;
const global ACC_TYPE* sinks_ptr = NULL;
if (sinks_void != NULL) {
sinks_ptr = (const global ACC_TYPE*)((const global char*)sinks_void + sinks_offset);
}
ACC_TYPE m_i = (sinks_ptr != NULL) ? sinks_ptr[head_idx] : -INFINITY;
for (int k_idx = tid; k_idx < n_kv; k_idx += Q1_WG_SIZE) {
const ulong k_row_offset = batch_idx * k_nb3 + head_kv_idx * k_nb2 + k_idx * k_nb1;
const global DATA_TYPE4* k_ptr = (const global DATA_TYPE4*)(k_base + k_row_offset);
@@ -320,7 +343,11 @@ __kernel void flash_attn_f32_q1(
const ulong o_row_offset = batch_idx * o_nb3 + head_idx * o_nb1;
global DATA_TYPE4 *o_row = (global DATA_TYPE4 *)(o_base + o_row_offset);
const ACC_TYPE l_final = local_l[0];
ACC_TYPE l_final = local_l[0];
if (sinks_ptr != NULL) {
l_final += exp(sinks_ptr[head_idx] - m_final);
}
if (l_final > 0.0f) {
const ACC_TYPE l_inv = 1.0f / l_final;

View File

@@ -52,7 +52,9 @@ __kernel void flash_attn_f32_f16(
const ulong mask_nb2,
const ulong mask_nb3,
const int mask_ne2,
const int mask_ne3
const int mask_ne3,
const global void* sinks_void,
const ulong sinks_offset
) {
const int tid = get_local_id(0);
const int block_q_idx = get_group_id(0);
@@ -174,6 +176,20 @@ __kernel void flash_attn_f32_f16(
}
if (my_query_row < n_q) {
if (sinks_void != NULL) {
const global ACC_TYPE* sinks_ptr = (const global ACC_TYPE*)((const global char*)sinks_void + sinks_offset);
const ACC_TYPE m_sink = sinks_ptr[head_idx];
const ACC_TYPE m_final = max(m_i, m_sink);
const ACC_TYPE scale_o = exp(m_i - m_final);
#pragma unroll
for (int i = 0; i < DV_VEC; ++i) {
o_acc[i] *= scale_o;
}
l_i = l_i * exp(m_i - m_final) + exp(m_sink - m_final);
}
const ulong o_row_offset = batch_idx * o_nb3 + my_query_row * o_nb2 + head_idx * o_nb1;
global O_DATA_TYPE4 *o_row = (global O_DATA_TYPE4 *)(o_base + o_row_offset);
if (l_i > 0.0f) {
@@ -217,7 +233,9 @@ __kernel void flash_attn_f32_f16_q1(
const ulong mask_nb2,
const ulong mask_nb3,
const int mask_ne2,
const int mask_ne3
const int mask_ne3,
const global void* sinks_void,
const ulong sinks_offset
) {
const int tid = get_local_id(0);
const int head_batch_idx = get_global_id(1);
@@ -250,7 +268,12 @@ __kernel void flash_attn_f32_f16_q1(
float slope = get_alibi_slope(max_bias, head_idx, n_head_log2, m0, m1);
ACC_TYPE m_i = -INFINITY;
const global ACC_TYPE* sinks_ptr = NULL;
if (sinks_void != NULL) {
sinks_ptr = (const global ACC_TYPE*)((const global char*)sinks_void + sinks_offset);
}
ACC_TYPE m_i = (sinks_ptr != NULL) ? sinks_ptr[head_idx] : -INFINITY;
for (int k_idx = tid; k_idx < n_kv; k_idx += Q1_WG_SIZE) {
const ulong k_row_offset = batch_idx * k_nb3 + head_kv_idx * k_nb2 + k_idx * k_nb1;
const global KV_DATA_TYPE4* k_ptr = (const global KV_DATA_TYPE4*)(k_base + k_row_offset);
@@ -323,7 +346,11 @@ __kernel void flash_attn_f32_f16_q1(
const ulong o_row_offset = batch_idx * o_nb3 + head_idx * o_nb1;
global O_DATA_TYPE4 *o_row = (global O_DATA_TYPE4 *)(o_base + o_row_offset);
const ACC_TYPE l_final = local_l[0];
ACC_TYPE l_final = local_l[0];
if (sinks_ptr != NULL) {
l_final += exp(sinks_ptr[head_idx] - m_final);
}
if (l_final > 0.0f) {
const ACC_TYPE l_inv = 1.0f / l_final;