Mirror of https://github.com/ggml-org/llama.cpp.git (synced 2025-10-28 08:31:25 +00:00)
opencl: fix FA (flash attention) for f32 (#16584)
This commit is contained in:
@@ -4,6 +4,7 @@
|
|||||||
#define ACC_TYPE4 float4
|
#define ACC_TYPE4 float4
|
||||||
#define DATA_TYPE float
|
#define DATA_TYPE float
|
||||||
#define DATA_TYPE4 float4
|
#define DATA_TYPE4 float4
|
||||||
|
#define MASK_DATA_TYPE half
|
||||||
#define CONVERT_ACC4(x) (x)
|
#define CONVERT_ACC4(x) (x)
|
||||||
#define CONVERT_DATA4(x) (x)
|
#define CONVERT_DATA4(x) (x)
|
||||||
|
|
||||||
@@ -148,7 +149,7 @@ __kernel void flash_attn_f32(
|
|||||||
if (k_row1 >= n_kv) score1 = -INFINITY;
|
if (k_row1 >= n_kv) score1 = -INFINITY;
|
||||||
|
|
||||||
if (mask_base != NULL) {
|
if (mask_base != NULL) {
|
||||||
const global DATA_TYPE* mask_ptr = (const global DATA_TYPE*)(mask_base + my_query_row * mask_nb1);
|
const global MASK_DATA_TYPE* mask_ptr = (const global MASK_DATA_TYPE*)(mask_base + my_query_row * mask_nb1);
|
||||||
if (k_row0 < n_kv) score0 += slope * (ACC_TYPE)mask_ptr[k_row0];
|
if (k_row0 < n_kv) score0 += slope * (ACC_TYPE)mask_ptr[k_row0];
|
||||||
if (k_row1 < n_kv) score1 += slope * (ACC_TYPE)mask_ptr[k_row1];
|
if (k_row1 < n_kv) score1 += slope * (ACC_TYPE)mask_ptr[k_row1];
|
||||||
}
|
}
|
||||||
@@ -281,7 +282,7 @@ __kernel void flash_attn_f32_q1(
|
|||||||
}
|
}
|
||||||
ACC_TYPE score = (dot_acc.s0 + dot_acc.s1 + dot_acc.s2 + dot_acc.s3) * scale;
|
ACC_TYPE score = (dot_acc.s0 + dot_acc.s1 + dot_acc.s2 + dot_acc.s3) * scale;
|
||||||
if (mask_base != NULL) {
|
if (mask_base != NULL) {
|
||||||
const global DATA_TYPE* mask_ptr = (const global DATA_TYPE*)(mask_base);
|
const global MASK_DATA_TYPE* mask_ptr = (const global MASK_DATA_TYPE*)(mask_base);
|
||||||
score += slope * (ACC_TYPE)mask_ptr[k_idx];
|
score += slope * (ACC_TYPE)mask_ptr[k_idx];
|
||||||
}
|
}
|
||||||
if (logit_softcap > 0.0f) {
|
if (logit_softcap > 0.0f) {
|
||||||
@@ -317,7 +318,7 @@ __kernel void flash_attn_f32_q1(
|
|||||||
}
|
}
|
||||||
ACC_TYPE score = (dot_acc.s0 + dot_acc.s1 + dot_acc.s2 + dot_acc.s3) * scale;
|
ACC_TYPE score = (dot_acc.s0 + dot_acc.s1 + dot_acc.s2 + dot_acc.s3) * scale;
|
||||||
if (mask_base != NULL) {
|
if (mask_base != NULL) {
|
||||||
const global DATA_TYPE* mask_ptr = (const global DATA_TYPE*)(mask_base);
|
const global MASK_DATA_TYPE* mask_ptr = (const global MASK_DATA_TYPE*)(mask_base);
|
||||||
score += slope * (ACC_TYPE)mask_ptr[k_idx];
|
score += slope * (ACC_TYPE)mask_ptr[k_idx];
|
||||||
}
|
}
|
||||||
if (logit_softcap > 0.0f) {
|
if (logit_softcap > 0.0f) {
|
||||||
|
|||||||
Reference in New Issue
Block a user