opencl: fix FA for f32 (#16584)

This commit is contained in:
lhez
2025-10-15 10:48:28 -07:00
committed by GitHub
parent f9fb33f263
commit d93f8439b0

View File

@@ -4,6 +4,7 @@
#define ACC_TYPE4 float4
#define DATA_TYPE float
#define DATA_TYPE4 float4
#define MASK_DATA_TYPE half
#define CONVERT_ACC4(x) (x)
#define CONVERT_DATA4(x) (x)
@@ -148,7 +149,7 @@ __kernel void flash_attn_f32(
if (k_row1 >= n_kv) score1 = -INFINITY;
if (mask_base != NULL) {
const global DATA_TYPE* mask_ptr = (const global DATA_TYPE*)(mask_base + my_query_row * mask_nb1);
const global MASK_DATA_TYPE* mask_ptr = (const global MASK_DATA_TYPE*)(mask_base + my_query_row * mask_nb1);
if (k_row0 < n_kv) score0 += slope * (ACC_TYPE)mask_ptr[k_row0];
if (k_row1 < n_kv) score1 += slope * (ACC_TYPE)mask_ptr[k_row1];
}
@@ -281,7 +282,7 @@ __kernel void flash_attn_f32_q1(
}
ACC_TYPE score = (dot_acc.s0 + dot_acc.s1 + dot_acc.s2 + dot_acc.s3) * scale;
if (mask_base != NULL) {
const global DATA_TYPE* mask_ptr = (const global DATA_TYPE*)(mask_base);
const global MASK_DATA_TYPE* mask_ptr = (const global MASK_DATA_TYPE*)(mask_base);
score += slope * (ACC_TYPE)mask_ptr[k_idx];
}
if (logit_softcap > 0.0f) {
@@ -317,7 +318,7 @@ __kernel void flash_attn_f32_q1(
}
ACC_TYPE score = (dot_acc.s0 + dot_acc.s1 + dot_acc.s2 + dot_acc.s3) * scale;
if (mask_base != NULL) {
const global DATA_TYPE* mask_ptr = (const global DATA_TYPE*)(mask_base);
const global MASK_DATA_TYPE* mask_ptr = (const global MASK_DATA_TYPE*)(mask_base);
score += slope * (ACC_TYPE)mask_ptr[k_idx];
}
if (logit_softcap > 0.0f) {