mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-10-27 08:21:30 +00:00
CUDA: refactor and deduplicate vector FA kernels (#16208)
* CUDA: refactor and deduplicate vector FA kernels
This commit is contained in:
@@ -586,17 +586,42 @@ static __device__ __forceinline__ void ggml_cuda_mad(float & acc, const half2 v,
|
||||
#endif // defined(GGML_USE_HIP) && (defined(RDNA2) || defined(RDNA3) || defined(RDNA4) || defined(GCN5) || defined(CDNA))
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ void ggml_cuda_mad(half2 & acc, const half2 v, const half2 u) {
|
||||
#ifdef FAST_FP16_AVAILABLE
|
||||
acc += v*u;
|
||||
#else
|
||||
const float2 tmpv = __half22float2(v);
|
||||
const float2 tmpu = __half22float2(u);
|
||||
float2 tmpacc = __half22float2(acc);
|
||||
tmpacc.x += tmpv.x * tmpu.x;
|
||||
tmpacc.y += tmpv.y * tmpu.y;
|
||||
acc = make_half2(tmpacc.x, tmpacc.y);
|
||||
#endif // FAST_FP16_AVAILABLE
|
||||
}
|
||||
|
||||
// Aligned memory transfers of 8/16 bytes can be faster than 2 transfers with 4 bytes, especially on AMD.
|
||||
template <int nbytes>
|
||||
template <int nbytes, int alignment = 0>
|
||||
static __device__ __forceinline__ void ggml_cuda_memcpy_1(void * __restrict__ dst, const void * __restrict__ src) {
|
||||
if constexpr (nbytes == 4) {
|
||||
*(int *) dst = *(const int *) src;
|
||||
} else if constexpr (nbytes == 8) {
|
||||
*(int2 *) dst = *(const int2 *) src;
|
||||
} else if constexpr (nbytes == 16) {
|
||||
*(int4 *) dst = *(const int4 *) src;
|
||||
} else {
|
||||
static_assert(nbytes == 0 && nbytes == -1, "bad nbytes");
|
||||
if constexpr (alignment != 0) {
|
||||
static_assert(nbytes % alignment == 0, "bad alignment");
|
||||
}
|
||||
constexpr int nb_per_cpy = alignment == 0 ? nbytes : alignment;
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < nbytes/nb_per_cpy; ++i) {
|
||||
if constexpr (nb_per_cpy == 1) {
|
||||
((char *) dst)[i] = ((const char *) src)[i];
|
||||
} else if constexpr (nb_per_cpy == 2) {
|
||||
((short *) dst)[i] = ((const short *) src)[i];
|
||||
} else if constexpr (nb_per_cpy == 4) {
|
||||
((int *) dst)[i] = ((const int *) src)[i];
|
||||
} else if constexpr (nb_per_cpy == 8) {
|
||||
((int2 *) dst)[i] = ((const int2 *) src)[i];
|
||||
} else if constexpr (nb_per_cpy == 16) {
|
||||
((int4 *) dst)[i] = ((const int4 *) src)[i];
|
||||
} else {
|
||||
static_assert(nbytes == 0 && nbytes == -1, "bad nbytes");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -33,276 +33,230 @@ typedef void (* fattn_kernel_t)(
|
||||
const int32_t ne31, const int32_t ne32, const int32_t ne33,
|
||||
const int32_t nb31, const int32_t nb32, const int64_t nb33);
|
||||
|
||||
typedef half (*vec_dot_KQ_f16_t)(
|
||||
const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8 , const void * __restrict__ Q_ds);
|
||||
typedef float (*vec_dot_KQ_f32_t)(
|
||||
typedef float (*vec_dot_KQ_t)(
|
||||
const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8 , const void * __restrict__ Q_ds);
|
||||
|
||||
template<typename T, int D, int warp_size>
|
||||
static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_0(
|
||||
const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
|
||||
|
||||
const block_q4_0 * K_q4_0 = (const block_q4_0 *) K_c;
|
||||
GGML_UNUSED(Q_v);
|
||||
|
||||
T sum = 0.0f;
|
||||
|
||||
#pragma unroll
|
||||
for (int k_KQ_0 = 0; k_KQ_0 < int(D/sizeof(int)); k_KQ_0 += warp_size) {
|
||||
const int k_KQ = k_KQ_0 + threadIdx.x;
|
||||
|
||||
const int ib = k_KQ / QI8_1;
|
||||
const int iqs4 = k_KQ % QI4_0;
|
||||
const int shift = k_KQ & (QI8_1/2);
|
||||
|
||||
const int v = (get_int_b2(K_q4_0[ib].qs, iqs4) >> shift) & 0x0F0F0F0F;
|
||||
const int u = Q_q8[k_KQ_0/warp_size];
|
||||
|
||||
const int sumi = ggml_cuda_dp4a(v, u, 0);
|
||||
|
||||
#ifdef FP16_AVAILABLE
|
||||
if (std::is_same<T, half>::value) {
|
||||
const half2 * Q_ds = (const half2 *) Q_ds_v;
|
||||
|
||||
const half2 sum2 = __half2half2(K_q4_0[ib].d) * Q_ds[k_KQ_0/warp_size];
|
||||
sum += (T) (((half) sumi)*__low2half(sum2) - __high2half(sum2) /* *8/QI8_1 == 1 */);
|
||||
} else
|
||||
#endif // FP16_AVAILABLE
|
||||
{
|
||||
const float2 * Q_ds = (const float2 *) Q_ds_v;
|
||||
|
||||
sum += (T) (__half2float(K_q4_0[ib].d) * (sumi*Q_ds[k_KQ_0/warp_size].x - (8/QI8_1)*Q_ds[k_KQ_0/warp_size].y));
|
||||
}
|
||||
}
|
||||
|
||||
return sum;
|
||||
}
|
||||
|
||||
template<typename T, int D, int warp_size>
|
||||
static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_1(
|
||||
const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
|
||||
|
||||
const block_q4_1 * K_q4_1 = (const block_q4_1 *) K_c;
|
||||
GGML_UNUSED(Q_v);
|
||||
|
||||
T sum = 0.0f;
|
||||
|
||||
#pragma unroll
|
||||
for (int k_KQ_0 = 0; k_KQ_0 < int(D/sizeof(int)); k_KQ_0 += warp_size) {
|
||||
const int k_KQ = k_KQ_0 + threadIdx.x;
|
||||
|
||||
const int ib = k_KQ / QI8_1;
|
||||
const int iqs4 = k_KQ % QI4_1;
|
||||
const int shift = k_KQ & (QI8_1/2);
|
||||
|
||||
const int v = (get_int_b4(K_q4_1[ib].qs, iqs4) >> shift) & 0x0F0F0F0F;
|
||||
const int u = Q_q8[k_KQ_0/warp_size];
|
||||
|
||||
const int sumi = ggml_cuda_dp4a(v, u, 0);
|
||||
|
||||
#ifdef FP16_AVAILABLE
|
||||
if (std::is_same<T, half>::value) {
|
||||
const half2 * Q_ds = (const half2 *) Q_ds_v;
|
||||
|
||||
const half2 d4d8_m4s8 = K_q4_1[ib].dm * Q_ds[k_KQ_0/warp_size];
|
||||
const half2 sumid4d8_m4s8scaled = d4d8_m4s8 * make_half2(sumi, 1.0f/QI8_1);
|
||||
sum += (T) (__low2half(sumid4d8_m4s8scaled) + __high2half(sumid4d8_m4s8scaled));
|
||||
} else
|
||||
#endif // FP16_AVAILABLE
|
||||
{
|
||||
const float2 * Q_ds = (const float2 *) Q_ds_v;
|
||||
|
||||
const float sumid4d8 = __low2float(K_q4_1[ib].dm)*Q_ds[k_KQ_0/warp_size].x * sumi;
|
||||
const float m4s8scaled = __high2float(K_q4_1[ib].dm)*Q_ds[k_KQ_0/warp_size].y / QI8_1;
|
||||
|
||||
sum += (T) (sumid4d8 + m4s8scaled);
|
||||
}
|
||||
}
|
||||
|
||||
return sum;
|
||||
}
|
||||
|
||||
template<typename T, int D, int warp_size>
|
||||
static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_0(
|
||||
const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
|
||||
|
||||
const block_q5_0 * K_q5_0 = (const block_q5_0 *) K_c;
|
||||
GGML_UNUSED(Q_v);
|
||||
|
||||
T sum = 0.0f;
|
||||
|
||||
#pragma unroll
|
||||
for (int k_KQ_0 = 0; k_KQ_0 < int(D/sizeof(int)); k_KQ_0 += warp_size) {
|
||||
const int k_KQ = k_KQ_0 + threadIdx.x;
|
||||
|
||||
const int ib = k_KQ / QI8_1;
|
||||
const int iqs4 = k_KQ % QI5_0;
|
||||
const int iqs8 = k_KQ % QI8_1;
|
||||
const int shift = k_KQ & (QI8_1/2);
|
||||
|
||||
int v = (get_int_b2(K_q5_0[ib].qs, iqs4) >> shift) & 0x0F0F0F0F;
|
||||
const int vh = get_int_b2(K_q5_0[ib].qh, 0) >> (iqs8 * QI5_0);
|
||||
v |= (vh << 4) & 0x00000010; // 0 -> 4
|
||||
v |= (vh << 11) & 0x00001000; // 1 -> 12
|
||||
v |= (vh << 18) & 0x00100000; // 2 -> 20
|
||||
v |= (vh << 25) & 0x10000000; // 3 -> 28
|
||||
|
||||
const int u = Q_q8[k_KQ_0/warp_size];
|
||||
|
||||
const int sumi = ggml_cuda_dp4a(v, u, 0);
|
||||
|
||||
#ifdef FP16_AVAILABLE
|
||||
if (std::is_same<T, half>::value) {
|
||||
const half2 * Q_ds = (const half2 *) Q_ds_v;
|
||||
|
||||
const half2 sum2 = __half2half2(K_q5_0[ib].d) * Q_ds[k_KQ_0/warp_size];
|
||||
sum += (T) (((half) sumi)*__low2half(sum2) - __high2half(sum2)*__float2half(2.0f)) /* *16/QI8_1 == 2 */;
|
||||
} else
|
||||
#endif // FP16_AVAILABLE
|
||||
{
|
||||
const float2 * Q_ds = (const float2 *) Q_ds_v;
|
||||
|
||||
sum += (T) (__half2float(K_q5_0[ib].d) * (sumi*Q_ds[k_KQ_0/warp_size].x - (16/QI8_1)*Q_ds[k_KQ_0/warp_size].y));
|
||||
}
|
||||
}
|
||||
|
||||
return sum;
|
||||
}
|
||||
|
||||
template<typename T, int D, int warp_size>
|
||||
static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_1(
|
||||
const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
|
||||
|
||||
const block_q5_1 * K_q5_1 = (const block_q5_1 *) K_c;
|
||||
GGML_UNUSED(Q_v);
|
||||
|
||||
T sum = 0.0f;
|
||||
|
||||
#pragma unroll
|
||||
for (int k_KQ_0 = 0; k_KQ_0 < int(D/sizeof(int)); k_KQ_0 += warp_size) {
|
||||
const int k_KQ = k_KQ_0 + threadIdx.x;
|
||||
|
||||
const int ib = k_KQ / QI8_1;
|
||||
const int iqs4 = k_KQ % QI5_1;
|
||||
const int iqs8 = k_KQ % QI8_1;
|
||||
const int shift = k_KQ & (QI8_1/2);
|
||||
|
||||
int v = (get_int_b2(K_q5_1[ib].qs, iqs4) >> shift) & 0x0F0F0F0F;
|
||||
const int vh = get_int_b2(K_q5_1[ib].qh, 0) >> (iqs8 * QI5_1);
|
||||
v |= (vh << 4) & 0x00000010; // 0 -> 4
|
||||
v |= (vh << 11) & 0x00001000; // 1 -> 12
|
||||
v |= (vh << 18) & 0x00100000; // 2 -> 20
|
||||
v |= (vh << 25) & 0x10000000; // 3 -> 28
|
||||
|
||||
const int u = Q_q8[k_KQ_0/warp_size];
|
||||
|
||||
const int sumi = ggml_cuda_dp4a(v, u, 0);
|
||||
|
||||
#ifdef FP16_AVAILABLE
|
||||
if (std::is_same<T, half>::value) {
|
||||
const half2 * Q_ds = (const half2 *) Q_ds_v;
|
||||
|
||||
const half2 d5d8_m5s8 = K_q5_1[ib].dm * Q_ds[k_KQ_0/warp_size];
|
||||
const half2 sumid5d8_m5s8scaled = d5d8_m5s8 * make_half2(sumi, 1.0f/QI8_1);
|
||||
sum += (T) (__low2half(sumid5d8_m5s8scaled) + __high2half(sumid5d8_m5s8scaled));
|
||||
} else
|
||||
#endif // FP16_AVAILABLE
|
||||
{
|
||||
const float2 * Q_ds = (const float2 *) Q_ds_v;
|
||||
|
||||
const float sumid5d8 = __low2float(K_q5_1[ib].dm)*Q_ds[k_KQ_0/warp_size].x * sumi;
|
||||
const float m5s8scaled = __high2float(K_q5_1[ib].dm)*Q_ds[k_KQ_0/warp_size].y / QI8_1;
|
||||
|
||||
sum += (T) (sumid5d8 + m5s8scaled);
|
||||
}
|
||||
}
|
||||
|
||||
return sum;
|
||||
}
|
||||
|
||||
template <typename T, int D, int warp_size>
|
||||
static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q8_0(
|
||||
const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
|
||||
|
||||
const block_q8_0 * K_q8_0 = (const block_q8_0 *) K_c;
|
||||
GGML_UNUSED(Q_v);
|
||||
|
||||
T sum = 0.0f;
|
||||
|
||||
#pragma unroll
|
||||
for (int k_KQ_0 = 0; k_KQ_0 < int(D/sizeof(int)); k_KQ_0 += warp_size) {
|
||||
const int k_KQ = k_KQ_0 + threadIdx.x;
|
||||
|
||||
const int ib = k_KQ / QI8_0;
|
||||
const int iqs = k_KQ % QI8_0;
|
||||
|
||||
const int v = get_int_b2(K_q8_0[ib].qs, iqs);
|
||||
|
||||
T Q_d;
|
||||
if (std::is_same<T, half>::value) {
|
||||
const half2 * Q_ds = (const half2 *) Q_ds_v;
|
||||
Q_d = __low2half(Q_ds[k_KQ_0/warp_size]);
|
||||
} else {
|
||||
const float2 * Q_ds = (const float2 *) Q_ds_v;
|
||||
Q_d = Q_ds[k_KQ_0/warp_size].x;
|
||||
}
|
||||
|
||||
sum += vec_dot_q8_0_q8_1_impl<T, 1>(&v, &Q_q8[k_KQ_0/warp_size], K_q8_0[ib].d, Q_d);
|
||||
}
|
||||
|
||||
return sum;
|
||||
}
|
||||
|
||||
template <typename T, int D, int warp_size>
|
||||
static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_f16(
|
||||
template <int D, int nthreads>
|
||||
static __device__ __forceinline__ float vec_dot_fattn_vec_KQ_f16(
|
||||
const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8 , const void * __restrict__ Q_ds_v) {
|
||||
|
||||
const half2 * K_h2 = (const half2 *) K_c;
|
||||
GGML_UNUSED(Q_q8);
|
||||
GGML_UNUSED(Q_ds_v);
|
||||
|
||||
#ifdef FP16_AVAILABLE
|
||||
if (std::is_same<T, half>::value) {
|
||||
const half2 * Q_h2 = (const half2 *) Q_v;
|
||||
|
||||
half2 sum2 = make_half2(0.0f, 0.0f);
|
||||
|
||||
#pragma unroll
|
||||
for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += warp_size) {
|
||||
const int k_KQ = k_KQ_0 + threadIdx.x;
|
||||
|
||||
const half2 K_ik = K_h2[k_KQ];
|
||||
sum2 += K_ik * Q_h2[k_KQ_0/warp_size];
|
||||
}
|
||||
|
||||
return __low2half(sum2) + __high2half(sum2);
|
||||
}
|
||||
#endif // FP16_AVAILABLE
|
||||
|
||||
const float2 * Q_f2 = (const float2 *) Q_v;
|
||||
constexpr int cpy_nb = ggml_cuda_get_max_cpy_bytes();
|
||||
constexpr int cpy_ne = cpy_nb / 4;
|
||||
|
||||
float sum = 0.0f;
|
||||
|
||||
#pragma unroll
|
||||
for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += warp_size) {
|
||||
const int k_KQ = k_KQ_0 + threadIdx.x;
|
||||
|
||||
const half2 K_ik = K_h2[k_KQ];
|
||||
sum += __low2float(K_ik) * Q_f2[k_KQ_0/warp_size].x;
|
||||
sum += __high2float(K_ik) * Q_f2[k_KQ_0/warp_size].y;
|
||||
for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += nthreads*cpy_ne) {
|
||||
half2 tmp[cpy_ne];
|
||||
ggml_cuda_memcpy_1<sizeof(tmp)>(tmp, K_h2 + k_KQ_0 + (threadIdx.x % nthreads)*cpy_ne);
|
||||
#pragma unroll
|
||||
for (int k_KQ_1 = 0; k_KQ_1 < cpy_ne; ++k_KQ_1) {
|
||||
#ifdef FAST_FP16_AVAILABLE
|
||||
ggml_cuda_mad(sum, tmp[k_KQ_1] , ((const half2 *) Q_v)[k_KQ_0/nthreads + k_KQ_1]);
|
||||
#else
|
||||
ggml_cuda_mad(sum, __half22float2(tmp[k_KQ_1]), ((const float2 *) Q_v)[k_KQ_0/nthreads + k_KQ_1]);
|
||||
#endif // FP16_AVAILABLE
|
||||
}
|
||||
}
|
||||
|
||||
return sum;
|
||||
}
|
||||
|
||||
template <typename Tds>
|
||||
template<int D, int nthreads>
|
||||
static __device__ __forceinline__ float vec_dot_fattn_vec_KQ_q4_0(
|
||||
const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
|
||||
|
||||
const block_q4_0 * K_q4_0 = (const block_q4_0 *) K_c;
|
||||
GGML_UNUSED(Q_v);
|
||||
|
||||
float sum = 0.0f;
|
||||
|
||||
#pragma unroll
|
||||
for (int k_KQ_0 = 0; k_KQ_0 < int(D/sizeof(int)); k_KQ_0 += nthreads) {
|
||||
const int k_KQ = k_KQ_0 + (nthreads == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads);
|
||||
|
||||
const int ib = k_KQ / QI8_1;
|
||||
const int iqs4 = k_KQ % QI4_0;
|
||||
const int shift = k_KQ & (QI8_1/2);
|
||||
|
||||
int v;
|
||||
ggml_cuda_memcpy_1<sizeof(int), 2>(&v, K_q4_0[ib].qs + sizeof(int)*iqs4);
|
||||
v = (v >> shift) & 0x0F0F0F0F;
|
||||
const int u = Q_q8[k_KQ_0/nthreads];
|
||||
|
||||
const int sumi = ggml_cuda_dp4a(v, u, 0);
|
||||
|
||||
const float2 Q_ds = ((const float2 *) Q_ds_v)[k_KQ_0/nthreads];
|
||||
sum += __half2float(K_q4_0[ib].d) * (sumi*Q_ds.x - (8/QI8_1)*Q_ds.y);
|
||||
}
|
||||
|
||||
return sum;
|
||||
}
|
||||
|
||||
template<int D, int nthreads>
|
||||
static __device__ __forceinline__ float vec_dot_fattn_vec_KQ_q4_1(
|
||||
const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
|
||||
|
||||
const block_q4_1 * K_q4_1 = (const block_q4_1 *) K_c;
|
||||
GGML_UNUSED(Q_v);
|
||||
|
||||
float sum = 0.0f;
|
||||
|
||||
#pragma unroll
|
||||
for (int k_KQ_0 = 0; k_KQ_0 < int(D/sizeof(int)); k_KQ_0 += nthreads) {
|
||||
const int k_KQ = k_KQ_0 + (nthreads == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads);
|
||||
|
||||
const int ib = k_KQ / QI8_1;
|
||||
const int iqs4 = k_KQ % QI4_1;
|
||||
const int shift = k_KQ & (QI8_1/2);
|
||||
|
||||
int v;
|
||||
ggml_cuda_memcpy_1<sizeof(int)>(&v, K_q4_1[ib].qs + sizeof(int)*iqs4);
|
||||
v = (v >> shift) & 0x0F0F0F0F;
|
||||
const int u = Q_q8[k_KQ_0/nthreads];
|
||||
|
||||
const int sumi = ggml_cuda_dp4a(v, u, 0);
|
||||
|
||||
const float2 K_dm = __half22float2(K_q4_1[ib].dm);
|
||||
const float2 Q_ds = ((const float2 *) Q_ds_v)[k_KQ_0/nthreads];
|
||||
|
||||
sum += K_dm.x*Q_ds.x*sumi + K_dm.y*Q_ds.y/QI8_1;
|
||||
}
|
||||
|
||||
return sum;
|
||||
}
|
||||
|
||||
template<int D, int nthreads>
|
||||
static __device__ __forceinline__ float vec_dot_fattn_vec_KQ_q5_0(
|
||||
const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
|
||||
|
||||
const block_q5_0 * K_q5_0 = (const block_q5_0 *) K_c;
|
||||
GGML_UNUSED(Q_v);
|
||||
|
||||
float sum = 0.0f;
|
||||
|
||||
#pragma unroll
|
||||
for (int k_KQ_0 = 0; k_KQ_0 < int(D/sizeof(int)); k_KQ_0 += nthreads) {
|
||||
const int k_KQ = k_KQ_0 + (nthreads == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads);
|
||||
|
||||
const int ib = k_KQ / QI8_1;
|
||||
const int iqs4 = k_KQ % QI5_0;
|
||||
const int iqs8 = k_KQ % QI8_1;
|
||||
const int shift = k_KQ & (QI8_1/2);
|
||||
|
||||
int v;
|
||||
ggml_cuda_memcpy_1<sizeof(int), 2>(&v, K_q5_0[ib].qs + sizeof(int)*iqs4);
|
||||
v = (v >> shift) & 0x0F0F0F0F;
|
||||
|
||||
{
|
||||
int vh;
|
||||
ggml_cuda_memcpy_1<sizeof(int), 2>(&vh, K_q5_0[ib].qh);
|
||||
vh >>= iqs8 * QI5_0;
|
||||
|
||||
v |= (vh << 4) & 0x00000010; // 0 -> 4
|
||||
v |= (vh << 11) & 0x00001000; // 1 -> 12
|
||||
v |= (vh << 18) & 0x00100000; // 2 -> 20
|
||||
v |= (vh << 25) & 0x10000000; // 3 -> 28
|
||||
}
|
||||
|
||||
const int u = Q_q8[k_KQ_0/nthreads];
|
||||
|
||||
const int sumi = ggml_cuda_dp4a(v, u, 0);
|
||||
|
||||
const float2 Q_ds = ((const float2 *) Q_ds_v)[k_KQ_0/nthreads];
|
||||
|
||||
sum += __half2float(K_q5_0[ib].d) * (sumi*Q_ds.x - (16/QI8_1)*Q_ds.y);
|
||||
}
|
||||
|
||||
return sum;
|
||||
}
|
||||
|
||||
template<int D, int nthreads>
|
||||
static __device__ __forceinline__ float vec_dot_fattn_vec_KQ_q5_1(
|
||||
const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
|
||||
|
||||
const block_q5_1 * K_q5_1 = (const block_q5_1 *) K_c;
|
||||
GGML_UNUSED(Q_v);
|
||||
|
||||
float sum = 0.0f;
|
||||
|
||||
#pragma unroll
|
||||
for (int k_KQ_0 = 0; k_KQ_0 < int(D/sizeof(int)); k_KQ_0 += nthreads) {
|
||||
const int k_KQ = k_KQ_0 + (nthreads == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads);
|
||||
|
||||
const int ib = k_KQ / QI8_1;
|
||||
const int iqs4 = k_KQ % QI5_1;
|
||||
const int iqs8 = k_KQ % QI8_1;
|
||||
const int shift = k_KQ & (QI8_1/2);
|
||||
|
||||
int v;
|
||||
ggml_cuda_memcpy_1<sizeof(int)>(&v, K_q5_1[ib].qs + sizeof(int)*iqs4);
|
||||
v = (v >> shift) & 0x0F0F0F0F;
|
||||
|
||||
{
|
||||
int vh;
|
||||
ggml_cuda_memcpy_1<sizeof(int)>(&vh, K_q5_1[ib].qh);
|
||||
vh >>= iqs8 * QI5_0;
|
||||
|
||||
v |= (vh << 4) & 0x00000010; // 0 -> 4
|
||||
v |= (vh << 11) & 0x00001000; // 1 -> 12
|
||||
v |= (vh << 18) & 0x00100000; // 2 -> 20
|
||||
v |= (vh << 25) & 0x10000000; // 3 -> 28
|
||||
}
|
||||
|
||||
const int u = Q_q8[k_KQ_0/nthreads];
|
||||
|
||||
const int sumi = ggml_cuda_dp4a(v, u, 0);
|
||||
|
||||
const float2 K_dm = __half22float2(K_q5_1[ib].dm);
|
||||
const float2 Q_ds = ((const float2 *) Q_ds_v)[k_KQ_0/nthreads];
|
||||
|
||||
sum += K_dm.x*Q_ds.x*sumi + K_dm.y*Q_ds.y/QI8_1;
|
||||
}
|
||||
|
||||
return sum;
|
||||
}
|
||||
|
||||
template <int D, int nthreads>
|
||||
static __device__ __forceinline__ float vec_dot_fattn_vec_KQ_q8_0(
|
||||
const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
|
||||
|
||||
const block_q8_0 * K_q8_0 = (const block_q8_0 *) K_c;
|
||||
GGML_UNUSED(Q_v);
|
||||
|
||||
float sum = 0.0f;
|
||||
|
||||
#pragma unroll
|
||||
for (int k_KQ_0 = 0; k_KQ_0 < int(D/sizeof(int)); k_KQ_0 += nthreads) {
|
||||
const int k_KQ = k_KQ_0 + (nthreads == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads);
|
||||
|
||||
const int ib = k_KQ / QI8_0;
|
||||
const int iqs = k_KQ % QI8_0;
|
||||
|
||||
int v;
|
||||
ggml_cuda_memcpy_1<sizeof(v), 2>(&v, K_q8_0[ib].qs + 4*iqs);
|
||||
|
||||
const float2 * Q_ds = (const float2 *) Q_ds_v;
|
||||
const float Q_d = Q_ds[k_KQ_0/nthreads].x;
|
||||
|
||||
sum += vec_dot_q8_0_q8_1_impl<float, 1>(&v, &Q_q8[k_KQ_0/nthreads], K_q8_0[ib].d, Q_d);
|
||||
}
|
||||
|
||||
return sum;
|
||||
}
|
||||
|
||||
template <typename Tds, int ni>
|
||||
static __device__ __forceinline__ void quantize_q8_1_to_shared(
|
||||
const float * __restrict__ x, const float scale, int * __restrict__ yq32, void * __restrict__ yds) {
|
||||
|
||||
float vals[sizeof(int)] = {0.0f};
|
||||
#pragma unroll
|
||||
for (int l = 0; l < int(sizeof(int)); ++l) {
|
||||
vals[l] = scale * x[4*threadIdx.x + l];
|
||||
vals[l] = (ni == WARP_SIZE || threadIdx.x < ni) ? scale * x[4*threadIdx.x + l] : 0.0f;
|
||||
}
|
||||
|
||||
float amax = fabsf(vals[0]);
|
||||
@@ -330,7 +284,7 @@ static __device__ __forceinline__ void quantize_q8_1_to_shared(
|
||||
}
|
||||
|
||||
yq32[threadIdx.x] = q32;
|
||||
if (threadIdx.x % QI8_1 == 0) {
|
||||
if (threadIdx.x % QI8_1 == 0 && (ni == WARP_SIZE || threadIdx.x < ni)) {
|
||||
if (std::is_same<Tds, half2>::value) {
|
||||
((half2 *) yds)[threadIdx.x/QI8_1] = make_half2(d, sum);
|
||||
} else {
|
||||
@@ -339,167 +293,276 @@ static __device__ __forceinline__ void quantize_q8_1_to_shared(
|
||||
}
|
||||
}
|
||||
|
||||
typedef half (*dequantize_1_f16_t)(const void *, const int64_t);
|
||||
typedef float (*dequantize_1_f32_t)(const void *, const int64_t);
|
||||
typedef void (*dequantize_V_t)(const void *, void *, const int64_t);
|
||||
|
||||
template <typename T>
|
||||
static __device__ __forceinline__ T dequantize_1_q4_0(const void * __restrict__ vx, const int64_t i) {
|
||||
template <typename T, int ne>
|
||||
static __device__ __forceinline__ void dequantize_V_f16(const void * __restrict__ vx, void * __restrict__ dst, const int64_t i0) {
|
||||
if constexpr (std::is_same_v<T, half>) {
|
||||
ggml_cuda_memcpy_1<ne*sizeof(half)>(dst, (const half *) vx + i0);
|
||||
} else if constexpr (std::is_same_v<T, float>) {
|
||||
static_assert(ne % 2 == 0, "bad ne");
|
||||
half2 tmp[ne/2];
|
||||
ggml_cuda_memcpy_1<ne*sizeof(half)>(tmp, (const half *) vx + i0);
|
||||
float2 * dst_f2 = (float2 *) dst;
|
||||
#pragma unroll
|
||||
for (int l = 0; l < ne/2; ++l) {
|
||||
dst_f2[l] = __half22float2(tmp[l]);
|
||||
}
|
||||
} else {
|
||||
static_assert(std::is_same_v<T, void>, "unsupported type");
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, int ne>
|
||||
static __device__ __forceinline__ void dequantize_V_q4_0(const void * __restrict__ vx, void * __restrict__ dst, const int64_t i0) {
|
||||
const block_q4_0 * x = (const block_q4_0 *) vx;
|
||||
|
||||
const int64_t ib = i / QK4_0;
|
||||
const int iqs = i % (QK4_0/2);
|
||||
const int shift = (i % QK4_0) / (QK4_0/2);
|
||||
const int64_t ib = i0 / QK4_0;
|
||||
const int iqs = i0 % (QK4_0/2);
|
||||
const int shift = (i0 % QK4_0) / (QK4_0/2);
|
||||
|
||||
const T d = x[ib].d;
|
||||
const int q0 = x[ib].qs[iqs];
|
||||
const int q = ((q0 >> (4*shift)) & 0x0F) - 8;
|
||||
int q;
|
||||
static_assert(ne == 2 || ne == 4, "bad ne");
|
||||
ggml_cuda_memcpy_1<ne, 2>(&q, x[ib].qs + iqs);
|
||||
q >>= 4*shift;
|
||||
q &= 0x0F0F0F0F;
|
||||
q = __vsubss4(q, 0x08080808);
|
||||
|
||||
const int8_t * q8 = (const int8_t *) &q;
|
||||
|
||||
#ifdef FP16_AVAILABLE
|
||||
if (std::is_same<T, half>::value) {
|
||||
return ((half) d)*((half) q);
|
||||
}
|
||||
#endif // FP16_AVAILABLE
|
||||
if constexpr (std::is_same_v<T, half>) {
|
||||
const half2 d = __half2half2(x[ib].d);
|
||||
|
||||
return ((float) d)*((float) q);
|
||||
#pragma unroll
|
||||
for (int l0 = 0; l0 < ne; l0 += 2) {
|
||||
((half2 *) dst)[l0/2] = d * make_half2(q8[l0 + 0], q8[l0 + 1]);
|
||||
}
|
||||
} else
|
||||
#endif // FP16_AVAILABLE
|
||||
if constexpr (std::is_same_v<T, float>) {
|
||||
const float d = x[ib].d;
|
||||
|
||||
#pragma unroll
|
||||
for (int l = 0; l < ne; ++l) {
|
||||
((float *) dst)[l] = d * q8[l];
|
||||
}
|
||||
} else {
|
||||
static_assert(std::is_same_v<T, void>, "bad type");
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static __device__ __forceinline__ T dequantize_1_q4_1(const void * __restrict__ vx, const int64_t i) {
|
||||
template <typename T, int ne>
|
||||
static __device__ __forceinline__ void dequantize_V_q4_1(const void * __restrict__ vx, void * __restrict__ dst, const int64_t i0) {
|
||||
const block_q4_1 * x = (const block_q4_1 *) vx;
|
||||
|
||||
const int64_t ib = i / QK4_1;
|
||||
const int iqs = i % (QK4_1/2);
|
||||
const int shift = (i % QK4_1) / (QK4_1/2);
|
||||
const int64_t ib = i0 / QK4_1;
|
||||
const int iqs = i0 % (QK4_1/2);
|
||||
const int shift = (i0 % QK4_1) / (QK4_1/2);
|
||||
|
||||
const half2 dm = x[ib].dm;
|
||||
const int q0 = x[ib].qs[iqs];
|
||||
const int q = ((q0 >> (4*shift)) & 0x0F);
|
||||
int q;
|
||||
static_assert(ne == 2 || ne == 4, "bad ne");
|
||||
ggml_cuda_memcpy_1<ne>(&q, x[ib].qs + iqs);
|
||||
q >>= 4*shift;
|
||||
q &= 0x0F0F0F0F;
|
||||
|
||||
const int8_t * q8 = (const int8_t *) &q;
|
||||
|
||||
#ifdef FP16_AVAILABLE
|
||||
if (std::is_same<T, half>::value) {
|
||||
return __low2half(dm)*((half) q) + __high2half(dm);
|
||||
}
|
||||
#endif // FP16_AVAILABLE
|
||||
if constexpr (std::is_same_v<T, half>) {
|
||||
const half2 dm = x[ib].dm;
|
||||
const half2 d = __half2half2( __low2half(dm));
|
||||
const half2 m = __half2half2(__high2half(dm));
|
||||
|
||||
return __low2float(dm)*((float) q) + __high2float(dm);
|
||||
#pragma unroll
|
||||
for (int l0 = 0; l0 < ne; l0 += 2) {
|
||||
((half2 *) dst)[l0/2] = d * make_half2(q8[l0 + 0], q8[l0 + 1]) + m;
|
||||
}
|
||||
} else
|
||||
#endif // FP16_AVAILABLE
|
||||
if constexpr (std::is_same_v<T, float>) {
|
||||
const float2 dm = __half22float2(x[ib].dm);
|
||||
|
||||
#pragma unroll
|
||||
for (int l = 0; l < ne; ++l) {
|
||||
((float *) dst)[l] = dm.x * q8[l] + dm.y;
|
||||
}
|
||||
} else {
|
||||
static_assert(std::is_same_v<T, void>, "bad type");
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static __device__ __forceinline__ T dequantize_1_q5_0(const void * __restrict__ vx, const int64_t i) {
|
||||
template <typename T, int ne>
|
||||
static __device__ __forceinline__ void dequantize_V_q5_0(const void * __restrict__ vx, void * __restrict__ dst, const int64_t i0) {
|
||||
const block_q5_0 * x = (const block_q5_0 *) vx;
|
||||
|
||||
const int64_t ib = i / QK5_0;
|
||||
const int idq = i % QK5_0;
|
||||
const int iqs = i % (QK5_0/2);
|
||||
const int shift = (i % QK5_0) / (QK5_0/2);
|
||||
const int64_t ib = i0 / QK5_0;
|
||||
const int idq = i0 % QK5_0;
|
||||
const int iqs = i0 % (QK5_0/2);
|
||||
const int shift = (i0 % QK5_0) / (QK5_0/2);
|
||||
|
||||
const T d = x[ib].d;
|
||||
const int ql0 = x[ib].qs[iqs];
|
||||
const int qh0 = get_int_b2(x[ib].qh, 0);
|
||||
const int ql = ((ql0 >> (4*shift)) & 0x0F);
|
||||
const int qh = ((qh0 >> idq) << 4) & 0x10;
|
||||
const int q = (ql | qh) - 16;
|
||||
int q;
|
||||
static_assert(ne == 2 || ne == 4, "bad ne");
|
||||
ggml_cuda_memcpy_1<ne, 2>(&q, x[ib].qs + iqs);
|
||||
q >>= 4*shift;
|
||||
q &= 0x0F0F0F0F;
|
||||
|
||||
{
|
||||
int qh;
|
||||
ggml_cuda_memcpy_1<ne, 2>(&qh, x[ib].qh);
|
||||
#pragma unroll
|
||||
for (int l = 0; l < ne; ++l) {
|
||||
q |= ((qh >> (idq + l)) & 0x00000001) << (8*l + 4);
|
||||
}
|
||||
}
|
||||
|
||||
q = __vsubss4(q, 0x10101010);
|
||||
|
||||
const int8_t * q8 = (const int8_t *) &q;
|
||||
|
||||
#ifdef FP16_AVAILABLE
|
||||
if (std::is_same<T, half>::value) {
|
||||
return ((half) d)*((half) q);
|
||||
}
|
||||
#endif // FP16_AVAILABLE
|
||||
if constexpr (std::is_same_v<T, half>) {
|
||||
const half2 d = __half2half2(x[ib].d);
|
||||
|
||||
return ((float) d)*((float) q);
|
||||
#pragma unroll
|
||||
for (int l0 = 0; l0 < ne; l0 += 2) {
|
||||
((half2 *) dst)[l0/2] = d * make_half2(q8[l0 + 0], q8[l0 + 1]);
|
||||
}
|
||||
} else
|
||||
#endif // FP16_AVAILABLE
|
||||
if constexpr (std::is_same_v<T, float>) {
|
||||
const float d = x[ib].d;
|
||||
|
||||
#pragma unroll
|
||||
for (int l = 0; l < ne; ++l) {
|
||||
((float *) dst)[l] = d * q8[l];
|
||||
}
|
||||
} else {
|
||||
static_assert(std::is_same_v<T, void>, "bad type");
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static __device__ __forceinline__ T dequantize_1_q5_1(const void * __restrict__ vx, const int64_t i) {
|
||||
template <typename T, int ne>
|
||||
static __device__ __forceinline__ void dequantize_V_q5_1(const void * __restrict__ vx, void * __restrict__ dst, const int64_t i0) {
|
||||
const block_q5_1 * x = (const block_q5_1 *) vx;
|
||||
|
||||
const int64_t ib = i / QK5_1;
|
||||
const int idq = i % QK5_1;
|
||||
const int iqs = i % (QK5_1/2);
|
||||
const int shift = (i % QK5_1) / (QK5_1/2);
|
||||
const int64_t ib = i0 / QK5_1;
|
||||
const int idq = i0 % QK5_1;
|
||||
const int iqs = i0 % (QK5_1/2);
|
||||
const int shift = (i0 % QK5_1) / (QK5_1/2);
|
||||
|
||||
const half2 dm = x[ib].dm;
|
||||
const int ql0 = x[ib].qs[iqs];
|
||||
const int qh0 = get_int_b4(x[ib].qh, 0);
|
||||
const int ql = ((ql0 >> (4*shift)) & 0x0F);
|
||||
const int qh = ((qh0 >> idq) << 4) & 0x10;
|
||||
const int q = (ql | qh);
|
||||
int q;
|
||||
static_assert(ne == 2 || ne == 4, "bad ne");
|
||||
ggml_cuda_memcpy_1<ne>(&q, x[ib].qs + iqs);
|
||||
q >>= 4*shift;
|
||||
q &= 0x0F0F0F0F;
|
||||
|
||||
{
|
||||
int qh;
|
||||
ggml_cuda_memcpy_1<ne>(&qh, x[ib].qh);
|
||||
#pragma unroll
|
||||
for (int l = 0; l < ne; ++l) {
|
||||
q |= ((qh >> (idq + l)) & 0x00000001) << (8*l + 4);
|
||||
}
|
||||
}
|
||||
|
||||
const int8_t * q8 = (const int8_t *) &q;
|
||||
|
||||
#ifdef FP16_AVAILABLE
|
||||
if (std::is_same<T, half>::value) {
|
||||
return __low2half(dm)*((half) q) + __high2half(dm);
|
||||
}
|
||||
#endif // FP16_AVAILABLE
|
||||
if constexpr (std::is_same_v<T, half>) {
|
||||
const half2 dm = x[ib].dm;
|
||||
const half2 d = __half2half2( __low2half(dm));
|
||||
const half2 m = __half2half2(__high2half(dm));
|
||||
|
||||
return __low2float(dm)*((float) q) + __high2float(dm);
|
||||
#pragma unroll
|
||||
for (int l0 = 0; l0 < ne; l0 += 2) {
|
||||
((half2 *) dst)[l0/2] = d * make_half2(q8[l0 + 0], q8[l0 + 1]) + m;
|
||||
}
|
||||
} else
|
||||
#endif // FP16_AVAILABLE
|
||||
if constexpr (std::is_same_v<T, float>) {
|
||||
const float2 dm = __half22float2(x[ib].dm);
|
||||
|
||||
#pragma unroll
|
||||
for (int l = 0; l < ne; ++l) {
|
||||
((float *) dst)[l] = dm.x * q8[l] + dm.y;
|
||||
}
|
||||
} else {
|
||||
static_assert(std::is_same_v<T, void>, "bad type");
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static __device__ __forceinline__ T dequantize_1_q8_0(const void * __restrict__ vx, const int64_t i) {
|
||||
template <typename T, int ne>
|
||||
static __device__ __forceinline__ void dequantize_V_q8_0(const void * __restrict__ vx, void * __restrict__ dst, const int64_t i0) {
|
||||
const block_q8_0 * x = (const block_q8_0 *) vx;
|
||||
|
||||
const int64_t ib = i / QK8_0;
|
||||
const int iqs = i % QK8_0;
|
||||
const int64_t ib = i0 / QK8_0;
|
||||
const int iqs = i0 % QK8_0;
|
||||
|
||||
const T d = x[ib].d;
|
||||
const int q = x[ib].qs[iqs];
|
||||
static_assert(ne % 2 == 0, "bad ne");
|
||||
int8_t qs[ne];
|
||||
ggml_cuda_memcpy_1<ne, 2>(qs, x[ib].qs + iqs);
|
||||
|
||||
#ifdef FP16_AVAILABLE
|
||||
if (std::is_same<T, half>::value) {
|
||||
return ((half) d)*((half) q);
|
||||
}
|
||||
if constexpr (std::is_same<T, half>::value) {
|
||||
const half2 d = __half2half2(x[ib].d);
|
||||
|
||||
#pragma unroll
|
||||
for (int l0 = 0; l0 < ne; l0 += 2) {
|
||||
((half2 *) dst)[l0/2] = d * make_half2(qs[l0 + 0], qs[l0 + 1]);
|
||||
}
|
||||
} else
|
||||
#endif // FP16_AVAILABLE
|
||||
if constexpr (std::is_same<T, float>::value) {
|
||||
const float d = x[ib].d;
|
||||
|
||||
return ((float) d)*((float) q);
|
||||
#pragma unroll
|
||||
for (int l = 0; l < ne; ++l) {
|
||||
((float *) dst)[l] = d * qs[l];
|
||||
}
|
||||
} else {
|
||||
static_assert(std::is_same_v<T, void>, "unsupported type");
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static __device__ __forceinline__ T dequantize_1_f16(const void * __restrict__ vx, const int64_t i) {
|
||||
const half * x = (const half *) vx;
|
||||
|
||||
return x[i];
|
||||
template <ggml_type type_K, int D, int nthreads>
|
||||
constexpr __device__ vec_dot_KQ_t get_vec_dot_KQ() {
|
||||
if constexpr (type_K == GGML_TYPE_F16) {
|
||||
return vec_dot_fattn_vec_KQ_f16<D, nthreads>;
|
||||
} else if constexpr (type_K == GGML_TYPE_Q4_0) {
|
||||
return vec_dot_fattn_vec_KQ_q4_0<D, nthreads>;
|
||||
} else if constexpr (type_K == GGML_TYPE_Q4_1) {
|
||||
return vec_dot_fattn_vec_KQ_q4_1<D, nthreads>;
|
||||
} else if constexpr (type_K == GGML_TYPE_Q5_0) {
|
||||
return vec_dot_fattn_vec_KQ_q5_0<D, nthreads>;
|
||||
} else if constexpr (type_K == GGML_TYPE_Q5_1) {
|
||||
return vec_dot_fattn_vec_KQ_q5_1<D, nthreads>;
|
||||
} else if constexpr (type_K == GGML_TYPE_Q8_0) {
|
||||
return vec_dot_fattn_vec_KQ_q8_0<D, nthreads>;
|
||||
} else {
|
||||
static_assert(type_K == -1, "bad type");
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
template <int D, int warp_size = WARP_SIZE>
|
||||
constexpr __device__ vec_dot_KQ_f16_t get_vec_dot_KQ_f16(ggml_type type_K) {
|
||||
return type_K == GGML_TYPE_Q4_0 ? vec_dot_fattn_vec_KQ_q4_0<half, D, warp_size> :
|
||||
type_K == GGML_TYPE_Q4_1 ? vec_dot_fattn_vec_KQ_q4_1<half, D, warp_size> :
|
||||
type_K == GGML_TYPE_Q5_0 ? vec_dot_fattn_vec_KQ_q5_0<half, D, warp_size> :
|
||||
type_K == GGML_TYPE_Q5_1 ? vec_dot_fattn_vec_KQ_q5_1<half, D, warp_size> :
|
||||
type_K == GGML_TYPE_Q8_0 ? vec_dot_fattn_vec_KQ_q8_0<half, D, warp_size> :
|
||||
type_K == GGML_TYPE_F16 ? vec_dot_fattn_vec_KQ_f16<half, D, warp_size> :
|
||||
nullptr;
|
||||
}
|
||||
|
||||
template <int D, int warp_size = WARP_SIZE>
|
||||
constexpr __device__ vec_dot_KQ_f32_t get_vec_dot_KQ_f32(ggml_type type_K) {
|
||||
return type_K == GGML_TYPE_Q4_0 ? vec_dot_fattn_vec_KQ_q4_0<float, D, warp_size> :
|
||||
type_K == GGML_TYPE_Q4_1 ? vec_dot_fattn_vec_KQ_q4_1<float, D, warp_size> :
|
||||
type_K == GGML_TYPE_Q5_0 ? vec_dot_fattn_vec_KQ_q5_0<float, D, warp_size> :
|
||||
type_K == GGML_TYPE_Q5_1 ? vec_dot_fattn_vec_KQ_q5_1<float, D, warp_size> :
|
||||
type_K == GGML_TYPE_Q8_0 ? vec_dot_fattn_vec_KQ_q8_0<float, D, warp_size> :
|
||||
type_K == GGML_TYPE_F16 ? vec_dot_fattn_vec_KQ_f16<float, D, warp_size> :
|
||||
nullptr;
|
||||
}
|
||||
|
||||
constexpr __device__ dequantize_1_f16_t get_dequantize_1_f16(ggml_type type_V) {
|
||||
return type_V == GGML_TYPE_Q4_0 ? dequantize_1_q4_0<half> :
|
||||
type_V == GGML_TYPE_Q4_1 ? dequantize_1_q4_1<half> :
|
||||
type_V == GGML_TYPE_Q5_0 ? dequantize_1_q5_0<half> :
|
||||
type_V == GGML_TYPE_Q5_1 ? dequantize_1_q5_1<half> :
|
||||
type_V == GGML_TYPE_Q8_0 ? dequantize_1_q8_0<half> :
|
||||
type_V == GGML_TYPE_F16 ? dequantize_1_f16<half> :
|
||||
nullptr;
|
||||
}
|
||||
|
||||
constexpr __device__ dequantize_1_f32_t get_dequantize_1_f32(ggml_type type_V) {
|
||||
return type_V == GGML_TYPE_Q4_0 ? dequantize_1_q4_0<float> :
|
||||
type_V == GGML_TYPE_Q4_1 ? dequantize_1_q4_1<float> :
|
||||
type_V == GGML_TYPE_Q5_0 ? dequantize_1_q5_0<float> :
|
||||
type_V == GGML_TYPE_Q5_1 ? dequantize_1_q5_1<float> :
|
||||
type_V == GGML_TYPE_Q8_0 ? dequantize_1_q8_0<float> :
|
||||
type_V == GGML_TYPE_F16 ? dequantize_1_f16<float> :
|
||||
nullptr;
|
||||
template <ggml_type type_V, typename T, int ne>
|
||||
constexpr __device__ dequantize_V_t get_dequantize_V() {
|
||||
if constexpr (type_V == GGML_TYPE_F16) {
|
||||
return dequantize_V_f16<T, ne>;
|
||||
} else if constexpr (type_V == GGML_TYPE_Q4_0) {
|
||||
return dequantize_V_q4_0<T, ne>;
|
||||
} else if constexpr (type_V == GGML_TYPE_Q4_1) {
|
||||
return dequantize_V_q4_1<T, ne>;
|
||||
} else if constexpr (type_V == GGML_TYPE_Q5_0) {
|
||||
return dequantize_V_q5_0<T, ne>;
|
||||
} else if constexpr (type_V == GGML_TYPE_Q5_1) {
|
||||
return dequantize_V_q5_1<T, ne>;
|
||||
} else if constexpr (type_V == GGML_TYPE_Q8_0) {
|
||||
return dequantize_V_q8_0<T, ne>;
|
||||
} else {
|
||||
static_assert(type_V == -1, "bad type");
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
template <int ncols1>
|
||||
@@ -870,7 +933,7 @@ void launch_fattn(
|
||||
const int efficiency_percent = 100 * nblocks_total / (nwaves*blocks_per_wave);
|
||||
|
||||
// Stop trying configurations with more waves if we already have good efficiency to avoid excessive overhead.
|
||||
if (efficiency_percent_best >= 90 && nwaves > nwaves_best) {
|
||||
if (efficiency_percent_best >= 95 && nwaves > nwaves_best) {
|
||||
break;
|
||||
}
|
||||
|
||||
|
||||
@@ -1,495 +0,0 @@
|
||||
#include "common.cuh"
|
||||
#include "fattn-common.cuh"
|
||||
|
||||
// Currenlty llvm with the amdgcn target dose not support unrolling loops
|
||||
// that contain a break that can not be resolved at compile time.
|
||||
#ifdef __clang__
|
||||
#pragma clang diagnostic push
|
||||
#pragma clang diagnostic ignored "-Wpass-failed"
|
||||
#endif // __clang__
|
||||
template<int D, int ncols, ggml_type type_K, ggml_type type_V, bool use_logit_softcap> // D == head size
|
||||
#ifndef GGML_USE_HIP
|
||||
__launch_bounds__(D, 1)
|
||||
#endif // GGML_USE_HIP
|
||||
static __global__ void flash_attn_vec_ext_f16(
|
||||
const char * __restrict__ Q,
|
||||
const char * __restrict__ K,
|
||||
const char * __restrict__ V,
|
||||
const char * __restrict__ mask,
|
||||
const char * __restrict__ sinks,
|
||||
const int * __restrict__ KV_max,
|
||||
float * __restrict__ dst,
|
||||
float2 * __restrict__ dst_meta,
|
||||
const float scale,
|
||||
const float max_bias,
|
||||
const float m0,
|
||||
const float m1,
|
||||
const uint32_t n_head_log2,
|
||||
const float logit_softcap,
|
||||
const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03,
|
||||
const int32_t nb01, const int32_t nb02, const int32_t nb03,
|
||||
const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13,
|
||||
const int32_t nb11, const int32_t nb12, const int64_t nb13,
|
||||
const int32_t nb21, const int32_t nb22, const int64_t nb23,
|
||||
const int32_t ne31, const int32_t ne32, const int32_t ne33,
|
||||
const int32_t nb31, const int32_t nb32, const int64_t nb33) {
|
||||
#if defined(FLASH_ATTN_AVAILABLE) && defined(FP16_AVAILABLE)
|
||||
|
||||
// Skip unused kernel variants for faster compilation:
|
||||
if (use_logit_softcap && !(D == 128 || D == 256)) {
|
||||
NO_DEVICE_CODE;
|
||||
return;
|
||||
}
|
||||
#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
|
||||
if (ncols > 1) {
|
||||
NO_DEVICE_CODE;
|
||||
return;
|
||||
}
|
||||
#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
|
||||
|
||||
//In this kernel Q, K, V are matrices while i, j, k are matrix indices.
|
||||
|
||||
constexpr vec_dot_KQ_f16_t vec_dot_KQ = get_vec_dot_KQ_f16<D>(type_K);
|
||||
constexpr bool Q_q8_1 = type_K != GGML_TYPE_F16;
|
||||
constexpr dequantize_1_f16_t dequantize_1_v = get_dequantize_1_f16(type_V);
|
||||
|
||||
const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on.
|
||||
|
||||
const int sequence = blockIdx.z / ne02;
|
||||
const int head = blockIdx.z - sequence*ne02;
|
||||
const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
|
||||
Q += nb03*sequence + nb02* head + nb01*ic0;
|
||||
K += nb13*sequence + nb12*(head / gqa_ratio);
|
||||
V += nb23*sequence + nb22*(head / gqa_ratio);
|
||||
|
||||
const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0);
|
||||
const float * sinksf = (const float *) (sinks);
|
||||
|
||||
const float slopef = get_alibi_slope(max_bias, head, n_head_log2, m0, m1);
|
||||
const half slopeh = __float2half(slopef);
|
||||
|
||||
static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
|
||||
constexpr int nwarps = D / WARP_SIZE;
|
||||
const int tid = WARP_SIZE*threadIdx.y + threadIdx.x;
|
||||
__builtin_assume(tid < D);
|
||||
|
||||
__shared__ half KQ[ncols*D];
|
||||
half2 * KQ2 = (half2 *) KQ;
|
||||
|
||||
half kqmax[ncols];
|
||||
half kqsum[ncols];
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols; ++j) {
|
||||
kqmax[j] = -HALF_MAX_HALF;
|
||||
kqsum[j] = 0.0f;
|
||||
}
|
||||
|
||||
__shared__ half kqmax_shared[ncols][WARP_SIZE];
|
||||
__shared__ half kqsum_shared[ncols][WARP_SIZE];
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols; ++j) {
|
||||
if (threadIdx.y == 0) {
|
||||
kqmax_shared[j][threadIdx.x] = -HALF_MAX_HALF;
|
||||
kqsum_shared[j][threadIdx.x] = 0.0f;
|
||||
}
|
||||
}
|
||||
|
||||
__shared__ half maskh_shared[ncols*D];
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols; ++j) {
|
||||
maskh_shared[j*D + tid] = 0.0f;
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// Convert Q to half2 (f16 K) or q8_1 (quantized K) and store in registers:
|
||||
half2 Q_h2[ncols][D/(2*WARP_SIZE)];
|
||||
int Q_i32[ncols][D/(sizeof(int)*QK8_1) == 0 ? 1 : D/(sizeof(int)*QK8_1)];
|
||||
half2 Q_ds[ncols][D/QK8_1 == 0 ? 1 : D/QK8_1];
|
||||
if (Q_q8_1) {
|
||||
#pragma unroll
|
||||
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
|
||||
const int j = j0 + threadIdx.y;
|
||||
|
||||
if (j0 + nwarps > ncols && j >= ncols) {
|
||||
break;
|
||||
}
|
||||
|
||||
// Reuse KQ as temporary storage for converting Q to q8_1:
|
||||
int * tmp_q_i32 = (int *) &KQ[j*D];
|
||||
half2 * tmp_q_ds = (half2 *) (tmp_q_i32 + D/sizeof(int));
|
||||
|
||||
// Set memory to zero if out of bounds:
|
||||
if (ncols > 2 && ic0 + j >= ne01) {
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < D/sizeof(int); i0 += WARP_SIZE) {
|
||||
const int i = i0 + threadIdx.x;
|
||||
|
||||
tmp_q_i32[i] = 0;
|
||||
}
|
||||
if (threadIdx.x < D/QK8_1) {
|
||||
tmp_q_ds[threadIdx.x] = make_half2(0.0f, 0.0f);
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
const float * Q_f = (const float *) (Q + j*nb01);
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < D/sizeof(int); i0 += WARP_SIZE) {
|
||||
quantize_q8_1_to_shared<half2>(Q_f + 4*i0, scale, tmp_q_i32, tmp_q_ds);
|
||||
}
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols; ++j) {
|
||||
int * tmp_q_i32 = (int *) &KQ[j*D];
|
||||
half2 * tmp_q_ds = (half2 *) (tmp_q_i32 + D/sizeof(int));
|
||||
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < D/sizeof(int); i0 += WARP_SIZE) {
|
||||
const int i = i0 + threadIdx.x;
|
||||
|
||||
Q_i32[j][i0/WARP_SIZE] = tmp_q_i32[i];
|
||||
Q_ds[j][i0/WARP_SIZE] = tmp_q_ds[i/QI8_1];
|
||||
}
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
} else {
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols; ++j) {
|
||||
const float2 * Q_f2_j = (const float2 *) (Q + j*nb01);
|
||||
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
|
||||
const int i = i0 + threadIdx.x;
|
||||
|
||||
const float2 tmp = ncols <= 2 || ic0 + j < ne01 ? Q_f2_j[i] : make_float2(0.0f, 0.0f);
|
||||
Q_h2[j][i0/WARP_SIZE] = make_half2(scale, scale) * make_half2(tmp.x, tmp.y);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols; ++j) {
|
||||
KQ[j*D + tid] = -HALF_MAX_HALF;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
half2 VKQ[ncols] = {{0.0f, 0.0f}};
|
||||
|
||||
const int k_VKQ_max = KV_max ? KV_max[sequence*gridDim.x + blockIdx.x] : ne11;
|
||||
K += blockIdx.y*D * nb11;
|
||||
V += blockIdx.y*D * nb21;
|
||||
maskh += blockIdx.y*D;
|
||||
for (int k_VKQ_0 = blockIdx.y*D; k_VKQ_0 < k_VKQ_max; k_VKQ_0 += gridDim.y*D,
|
||||
// Increment pointers after each loop:
|
||||
K += gridDim.y*D*nb11, V += gridDim.y*D*nb21, maskh += gridDim.y*D) {
|
||||
|
||||
// Calculate KQ tile and keep track of new maximum KQ values:
|
||||
|
||||
if (mask) {
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols; ++j) {
|
||||
maskh_shared[j*D + tid] = slopeh*maskh[j*ne11 + tid];
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
// For unknown reasons using a half array of size 1 for kqmax_new causes a performance regression,
|
||||
// see https://github.com/ggerganov/llama.cpp/pull/7061 .
|
||||
// Therefore this variable is defined twice but only used once (so that the compiler can optimize out the unused variable).
|
||||
half kqmax_new = kqmax[0];
|
||||
half kqmax_new_arr[ncols];
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols; ++j) {
|
||||
kqmax_new_arr[j] = kqmax[j];
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += nwarps) {
|
||||
const int i_KQ = i_KQ_0 + threadIdx.y;
|
||||
|
||||
if ((i_KQ_0 + nwarps > D && i_KQ >= D) || (FATTN_KQ_STRIDE % D != 0 && k_VKQ_0 + i_KQ >= ne11)) {
|
||||
break;
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols; ++j) {
|
||||
half sum = vec_dot_KQ(K + i_KQ*nb11, Q_h2[j], Q_i32[j], Q_ds[j]);
|
||||
sum = warp_reduce_sum((float)sum);
|
||||
|
||||
if (use_logit_softcap) {
|
||||
sum = logit_softcap*tanhf(sum);
|
||||
}
|
||||
|
||||
sum += maskh_shared[j*D + i_KQ];
|
||||
|
||||
if (ncols == 1) {
|
||||
kqmax_new = ggml_cuda_hmax(kqmax_new, sum);
|
||||
} else {
|
||||
kqmax_new_arr[j] = ggml_cuda_hmax(kqmax_new_arr[j], sum);
|
||||
}
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
KQ[j*D + i_KQ] = sum;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols; ++j) {
|
||||
half kqmax_new_j = ncols == 1 ? kqmax_new : kqmax_new_arr[j];
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
kqmax_shared[j][threadIdx.y] = kqmax_new_j;
|
||||
}
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols; ++j) {
|
||||
half kqmax_new_j = kqmax_shared[j][threadIdx.x];
|
||||
kqmax_new_j = warp_reduce_max(kqmax_new_j);
|
||||
|
||||
const half KQ_max_scale = hexp(kqmax[j] - kqmax_new_j);
|
||||
kqmax[j] = kqmax_new_j;
|
||||
|
||||
const half val = hexp(KQ[j*D + tid] - kqmax[j]);
|
||||
kqsum[j] = kqsum[j]*KQ_max_scale + val;
|
||||
KQ[j*D + tid] = val;
|
||||
|
||||
VKQ[j] *= __half2half2(KQ_max_scale);
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
#pragma unroll
|
||||
for (int k0 = 0; k0 < D; k0 += 2) {
|
||||
if (FATTN_KQ_STRIDE % D != 0 && k_VKQ_0 + k0 >= ne11) {
|
||||
break;
|
||||
}
|
||||
|
||||
half2 V_k;
|
||||
reinterpret_cast<half&>(V_k.x) = dequantize_1_v(V + (k0 + 0)*nb21, tid);
|
||||
reinterpret_cast<half&>(V_k.y) = dequantize_1_v(V + (k0 + 1)*nb21, tid);
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols; ++j) {
|
||||
VKQ[j] += V_k*KQ2[j*(D/2) + k0/2];
|
||||
}
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
if (sinksf && blockIdx.y == 0) {
|
||||
const half sink = __float2half(sinksf[head]);
|
||||
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols; ++j) {
|
||||
if (threadIdx.x == 0) {
|
||||
kqmax_shared[j][threadIdx.y] = fmaxf(kqmax[j], sink);
|
||||
}
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols; ++j) {
|
||||
half kqmax_new_j = kqmax_shared[j][threadIdx.x];
|
||||
kqmax_new_j = warp_reduce_max(kqmax_new_j);
|
||||
|
||||
const half KQ_max_scale = hexp(kqmax[j] - kqmax_new_j);
|
||||
kqmax[j] = kqmax_new_j;
|
||||
|
||||
const half val = hexp(sink - kqmax[j]);
|
||||
kqsum[j] = kqsum[j]*KQ_max_scale;
|
||||
|
||||
if (tid == 0) {
|
||||
kqsum[j] += val;
|
||||
}
|
||||
|
||||
VKQ[j] *= __half2half2(KQ_max_scale);
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols; ++j) {
|
||||
kqsum[j] = warp_reduce_sum((float)kqsum[j]);
|
||||
if (threadIdx.x == 0) {
|
||||
kqsum_shared[j][threadIdx.y] = kqsum[j];
|
||||
}
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
#pragma unroll
|
||||
for (int j_VKQ = 0; j_VKQ < ncols; ++j_VKQ) {
|
||||
if (ncols > 2 && ic0 + j_VKQ >= ne01) {
|
||||
break;
|
||||
}
|
||||
|
||||
kqsum[j_VKQ] = kqsum_shared[j_VKQ][threadIdx.x];
|
||||
kqsum[j_VKQ] = warp_reduce_sum((float)kqsum[j_VKQ]);
|
||||
|
||||
half dst_val = (__low2half(VKQ[j_VKQ]) + __high2half(VKQ[j_VKQ]));
|
||||
if (gridDim.y == 1) {
|
||||
dst_val /= kqsum[j_VKQ];
|
||||
}
|
||||
dst[(((sequence*ne01 + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y)*D + tid] = dst_val;
|
||||
}
|
||||
|
||||
if (gridDim.y != 1 && tid < ncols && (ncols <= 2 || ic0 + tid < ne01)) {
|
||||
dst_meta[((sequence*ne01 + ic0 + tid)*ne02 + head)*gridDim.y + blockIdx.y] = make_float2(kqmax[tid], kqsum[tid]);
|
||||
}
|
||||
#else
|
||||
GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale,
|
||||
max_bias, m0, m1, n_head_log2, logit_softcap,
|
||||
ne00, ne01, ne02, ne03,
|
||||
nb01, nb02, nb03,
|
||||
ne10, ne11, ne12, ne13,
|
||||
nb11, nb12, nb13,
|
||||
nb21, nb22, nb23,
|
||||
ne31, ne32, ne33,
|
||||
nb31, nb32, nb33);
|
||||
NO_DEVICE_CODE;
|
||||
#endif // defined(FLASH_ATTN_AVAILABLE) && defined(FP16_AVAILABLE)
|
||||
}
|
||||
#ifdef __clang__
|
||||
#pragma clang diagnostic pop
|
||||
#endif // __clang__
|
||||
|
||||
template <int D, int cols_per_block, ggml_type type_K, ggml_type type_V, bool use_logit_softcap>
|
||||
void ggml_cuda_flash_attn_ext_vec_f16_case_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
constexpr int nwarps = D/WARP_SIZE;
|
||||
fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f16<D, cols_per_block, type_K, type_V, use_logit_softcap>;
|
||||
constexpr bool need_f16_K = D != 128;
|
||||
constexpr bool need_f16_V = D != 128 && D != 64;
|
||||
constexpr size_t nbytes_shared = 0;
|
||||
launch_fattn<D, cols_per_block, 1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, D, need_f16_K, need_f16_V, false);
|
||||
}
|
||||
|
||||
template <int D, ggml_type type_K, ggml_type type_V>
|
||||
void ggml_cuda_flash_attn_ext_vec_f16_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
const ggml_tensor * KQV = dst;
|
||||
const ggml_tensor * Q = dst->src[0];
|
||||
const ggml_tensor * K = dst->src[1];
|
||||
const ggml_tensor * V = dst->src[2];
|
||||
|
||||
const int32_t precision = KQV->op_params[3];
|
||||
GGML_ASSERT(precision == GGML_PREC_DEFAULT);
|
||||
|
||||
GGML_ASSERT(K->type == type_K);
|
||||
GGML_ASSERT(V->type == type_V);
|
||||
|
||||
float logit_softcap;
|
||||
memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
|
||||
|
||||
const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
|
||||
|
||||
if (Q->ne[1] == 1 || GGML_CUDA_CC_IS_NVIDIA(cc)) {
|
||||
constexpr int cols_per_block = 1;
|
||||
if (logit_softcap == 0.0f) {
|
||||
constexpr bool use_logit_softcap = false;
|
||||
ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
|
||||
} else {
|
||||
constexpr bool use_logit_softcap = true;
|
||||
ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
if (Q->ne[1] == 2) {
|
||||
constexpr int cols_per_block = 2;
|
||||
if (logit_softcap == 0.0f) {
|
||||
constexpr bool use_logit_softcap = false;
|
||||
ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
|
||||
} else {
|
||||
constexpr bool use_logit_softcap = true;
|
||||
ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
if (Q->ne[1] <= 4) {
|
||||
constexpr int cols_per_block = 4;
|
||||
if (logit_softcap == 0.0f) {
|
||||
constexpr bool use_logit_softcap = false;
|
||||
ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
|
||||
} else {
|
||||
constexpr bool use_logit_softcap = true;
|
||||
ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
constexpr int cols_per_block = 8;
|
||||
if (logit_softcap == 0.0f) {
|
||||
constexpr bool use_logit_softcap = false;
|
||||
ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
|
||||
} else {
|
||||
constexpr bool use_logit_softcap = true;
|
||||
ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
|
||||
}
|
||||
}
|
||||
|
||||
#define DECL_FATTN_VEC_F16_CASE(D, type_K, type_V) \
|
||||
template void ggml_cuda_flash_attn_ext_vec_f16_case \
|
||||
<D, type_K, type_V>(ggml_backend_cuda_context & ctx, ggml_tensor * dst) \
|
||||
|
||||
extern DECL_FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q4_0);
|
||||
extern DECL_FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q4_1);
|
||||
extern DECL_FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q5_0);
|
||||
extern DECL_FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q5_1);
|
||||
extern DECL_FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q8_0);
|
||||
extern DECL_FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_F16);
|
||||
|
||||
extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0);
|
||||
extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q4_0);
|
||||
extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q4_0);
|
||||
extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q4_0);
|
||||
extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0);
|
||||
extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q4_0);
|
||||
|
||||
extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_1);
|
||||
extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q4_1);
|
||||
extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q4_1);
|
||||
extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q4_1);
|
||||
extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q4_1);
|
||||
extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q4_1);
|
||||
|
||||
extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_0);
|
||||
extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_0);
|
||||
extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_0);
|
||||
extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_0);
|
||||
extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_0);
|
||||
extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q5_0);
|
||||
|
||||
extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_1);
|
||||
extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_1);
|
||||
extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_1);
|
||||
extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_1);
|
||||
extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_1);
|
||||
extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q5_1);
|
||||
|
||||
extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q8_0);
|
||||
extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q8_0);
|
||||
extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q8_0);
|
||||
extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q8_0);
|
||||
extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0);
|
||||
extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q8_0);
|
||||
|
||||
extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_F16);
|
||||
extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_F16);
|
||||
extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_F16);
|
||||
extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_F16);
|
||||
extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_F16);
|
||||
extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_F16);
|
||||
|
||||
extern DECL_FATTN_VEC_F16_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16);
|
||||
@@ -1,486 +0,0 @@
|
||||
#include "common.cuh"
|
||||
#include "fattn-common.cuh"
|
||||
|
||||
// Currenlty llvm with the amdgcn target dose not support unrolling loops
|
||||
// that contain a break that can not be resolved at compile time.
|
||||
#ifdef __clang__
|
||||
#pragma clang diagnostic push
|
||||
#pragma clang diagnostic ignored "-Wpass-failed"
|
||||
#endif // __clang__
|
||||
template<int D, int ncols, ggml_type type_K, ggml_type type_V, bool use_logit_softcap> // D == head size
|
||||
#ifndef GGML_USE_HIP
|
||||
__launch_bounds__(D, 1)
|
||||
#endif // GGML_USE_HIP
|
||||
static __global__ void flash_attn_vec_ext_f32(
|
||||
const char * __restrict__ Q,
|
||||
const char * __restrict__ K,
|
||||
const char * __restrict__ V,
|
||||
const char * __restrict__ mask,
|
||||
const char * __restrict__ sinks,
|
||||
const int * __restrict__ KV_max,
|
||||
float * __restrict__ dst,
|
||||
float2 * __restrict__ dst_meta,
|
||||
const float scale,
|
||||
const float max_bias,
|
||||
const float m0,
|
||||
const float m1,
|
||||
const uint32_t n_head_log2,
|
||||
const float logit_softcap,
|
||||
const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03,
|
||||
const int32_t nb01, const int32_t nb02, const int32_t nb03,
|
||||
const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13,
|
||||
const int32_t nb11, const int32_t nb12, const int64_t nb13,
|
||||
const int32_t nb21, const int32_t nb22, const int64_t nb23,
|
||||
const int32_t ne31, const int32_t ne32, const int32_t ne33,
|
||||
const int32_t nb31, const int32_t nb32, const int64_t nb33) {
|
||||
#ifdef FLASH_ATTN_AVAILABLE
|
||||
|
||||
// Skip unused kernel variants for faster compilation:
|
||||
if (use_logit_softcap && !(D == 128 || D == 256)) {
|
||||
GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale,
|
||||
max_bias, m0, m1, n_head_log2, logit_softcap,
|
||||
ne00, ne01, ne02, ne03,
|
||||
nb01, nb02, nb03,
|
||||
ne10, ne11, ne12, ne13,
|
||||
nb11, nb12, nb13,
|
||||
nb21, nb22, nb23,
|
||||
ne31, ne32, ne33,
|
||||
nb31, nb32, nb33);
|
||||
NO_DEVICE_CODE;
|
||||
return;
|
||||
}
|
||||
#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
|
||||
if (ncols > 1) {
|
||||
NO_DEVICE_CODE;
|
||||
return;
|
||||
}
|
||||
#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
|
||||
|
||||
//In this kernel Q, K, V are matrices while i, j, k are matrix indices.
|
||||
|
||||
constexpr vec_dot_KQ_f32_t vec_dot_KQ = get_vec_dot_KQ_f32<D>(type_K);
|
||||
constexpr bool Q_q8_1 = type_K != GGML_TYPE_F16;
|
||||
constexpr dequantize_1_f32_t dequantize_1_v = get_dequantize_1_f32(type_V);
|
||||
|
||||
const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on.
|
||||
|
||||
const int sequence = blockIdx.z / ne02;
|
||||
const int head = blockIdx.z - sequence*ne02;
|
||||
const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
|
||||
Q += nb03*sequence + nb02* head + nb01*ic0;
|
||||
K += nb13*sequence + nb12*(head / gqa_ratio);
|
||||
V += nb23*sequence + nb22*(head / gqa_ratio);
|
||||
|
||||
const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0);
|
||||
const float * sinksf = (const float *) (sinks);
|
||||
|
||||
const float slope = get_alibi_slope(max_bias, head, n_head_log2, m0, m1);
|
||||
|
||||
static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
|
||||
constexpr int nwarps = D / WARP_SIZE;
|
||||
const int tid = WARP_SIZE*threadIdx.y + threadIdx.x;
|
||||
__builtin_assume(tid < D);
|
||||
|
||||
__shared__ float KQ[ncols*D];
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols; ++j) {
|
||||
KQ[j*D + tid] = -FLT_MAX/2.0f;
|
||||
}
|
||||
|
||||
float kqmax[ncols];
|
||||
float kqsum[ncols];
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols; ++j) {
|
||||
kqmax[j] = -FLT_MAX/2.0f;
|
||||
kqsum[j] = 0.0f;
|
||||
}
|
||||
|
||||
__shared__ float kqmax_shared[ncols][WARP_SIZE];
|
||||
__shared__ float kqsum_shared[ncols][WARP_SIZE];
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols; ++j) {
|
||||
if (threadIdx.y == 0) {
|
||||
kqmax_shared[j][threadIdx.x] = -FLT_MAX/2.0f;
|
||||
kqsum_shared[j][threadIdx.x] = 0.0f;
|
||||
}
|
||||
}
|
||||
|
||||
__shared__ float maskf_shared[ncols*D];
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols; ++j) {
|
||||
maskf_shared[j*D + tid] = 0.0f;
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// Convert Q to float2 (f16 K) or q8_1 (quantized K) and store in registers:
|
||||
float2 Q_f2[ncols][D/(2*WARP_SIZE)];
|
||||
int Q_i32[ncols][D/(sizeof(int)*QK8_1) == 0 ? 1 : D >= D/(sizeof(int)*QK8_1)];
|
||||
float2 Q_ds[ncols][D/QK8_1 == 0 ? 1 : D/QK8_1];
|
||||
if (Q_q8_1) {
|
||||
#pragma unroll
|
||||
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
|
||||
const int j = j0 + threadIdx.y;
|
||||
|
||||
if (j0 + nwarps > ncols && j >= ncols) {
|
||||
break;
|
||||
}
|
||||
|
||||
// Reuse KQ as temporary storage for converting Q to q8_1:
|
||||
int * tmp_q_i32 = (int *) &KQ[j*D];
|
||||
float2 * tmp_q_ds = (float2 *) (tmp_q_i32 + D/sizeof(int));
|
||||
|
||||
// Set memory to zero if out of bounds:
|
||||
if (ncols > 2 && ic0 + j >= ne01) {
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < int(D/sizeof(int)); i0 += WARP_SIZE) {
|
||||
const int i = i0 + threadIdx.x;
|
||||
|
||||
tmp_q_i32[i] = 0;
|
||||
}
|
||||
if (threadIdx.x < D/QK8_1) {
|
||||
tmp_q_ds[threadIdx.x] = make_float2(0.0f, 0.0f);
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
const float * Q_f = (const float *) (Q + j*nb01);
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < int(D/sizeof(int)); i0 += WARP_SIZE) {
|
||||
quantize_q8_1_to_shared<float2>(Q_f + 4*i0, scale, tmp_q_i32, tmp_q_ds);
|
||||
}
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols; ++j) {
|
||||
int * tmp_q_i32 = (int *) &KQ[j*D];
|
||||
float2 * tmp_q_ds = (float2 *) (tmp_q_i32 + D/sizeof(int));
|
||||
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < int(D/sizeof(int)); i0 += WARP_SIZE) {
|
||||
const int i = i0 + threadIdx.x;
|
||||
|
||||
Q_i32[j][i0/WARP_SIZE] = tmp_q_i32[i];
|
||||
Q_ds[j][i0/WARP_SIZE] = tmp_q_ds[i/QI8_1];
|
||||
}
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
} else {
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols; ++j) {
|
||||
const float2 * Q_f2_j = (const float2 *) (Q + j*nb01);
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
|
||||
const int i = i0 + threadIdx.x;
|
||||
|
||||
Q_f2[j][i0/WARP_SIZE] = ncols <= 2 || ic0 + j < ne01 ? Q_f2_j[i] : make_float2(0.0f, 0.0f);
|
||||
Q_f2[j][i0/WARP_SIZE].x *= scale;
|
||||
Q_f2[j][i0/WARP_SIZE].y *= scale;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
float VKQ[ncols] = {0.0f};
|
||||
|
||||
const int k_VKQ_max = KV_max ? KV_max[sequence*gridDim.x + blockIdx.x] : ne11;
|
||||
K += blockIdx.y*D * nb11;
|
||||
V += blockIdx.y*D * nb21;
|
||||
maskh += blockIdx.y*D;
|
||||
for (int k_VKQ_0 = blockIdx.y*D; k_VKQ_0 < k_VKQ_max; k_VKQ_0 += gridDim.y*D,
|
||||
// Increment pointers after each loop:
|
||||
K += gridDim.y*D*nb11, V += gridDim.y*D*nb21, maskh += gridDim.y*D) {
|
||||
|
||||
// Calculate KQ tile and keep track of new maximum KQ values:
|
||||
|
||||
if (mask) {
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols; ++j) {
|
||||
maskf_shared[j*D + tid] = slope*__half2float(maskh[j*ne11 + tid]);
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
float kqmax_new_arr[ncols];
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols; ++j) {
|
||||
kqmax_new_arr[j] = kqmax[j];
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += nwarps) {
|
||||
const int i_KQ = i_KQ_0 + threadIdx.y;
|
||||
|
||||
if ((i_KQ_0 + nwarps > D && i_KQ >= D) || (FATTN_KQ_STRIDE % D != 0 && k_VKQ_0 + i_KQ >= ne11)) {
|
||||
break;
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols; ++j) {
|
||||
float sum = vec_dot_KQ(K + i_KQ*nb11, Q_f2[j], Q_i32[j], Q_ds[j]);
|
||||
sum = warp_reduce_sum(sum);
|
||||
|
||||
if (use_logit_softcap) {
|
||||
sum = logit_softcap*tanhf(sum);
|
||||
}
|
||||
|
||||
sum += maskf_shared[j*D + i_KQ];
|
||||
|
||||
kqmax_new_arr[j] = fmaxf(kqmax_new_arr[j], sum);
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
KQ[j*D + i_KQ] = sum;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols; ++j) {
|
||||
float kqmax_new_j = kqmax_new_arr[j];
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
kqmax_shared[j][threadIdx.y] = kqmax_new_j;
|
||||
}
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols; ++j) {
|
||||
float kqmax_new_j = kqmax_shared[j][threadIdx.x];
|
||||
kqmax_new_j = warp_reduce_max(kqmax_new_j);
|
||||
|
||||
const float KQ_max_scale = expf(kqmax[j] - kqmax_new_j);
|
||||
kqmax[j] = kqmax_new_j;
|
||||
|
||||
const float val = expf(KQ[j*D + tid] - kqmax[j]);
|
||||
kqsum[j] = kqsum[j]*KQ_max_scale + val;
|
||||
KQ[j*D + tid] = val;
|
||||
|
||||
VKQ[j] *= KQ_max_scale;
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
#pragma unroll
|
||||
for (int k = 0; k < D; ++k) {
|
||||
if (FATTN_KQ_STRIDE % D != 0 && k_VKQ_0 + k >= ne11) {
|
||||
break;
|
||||
}
|
||||
|
||||
const float V_ki = dequantize_1_v(V + k*nb21, tid);
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols; ++j) {
|
||||
VKQ[j] += V_ki*KQ[j*D + k];
|
||||
}
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
if (sinksf && blockIdx.y == 0) {
|
||||
const float sink = sinksf[head];
|
||||
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols; ++j) {
|
||||
if (threadIdx.x == 0) {
|
||||
kqmax_shared[j][threadIdx.y] = fmaxf(kqmax[j], sink);
|
||||
}
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols; ++j) {
|
||||
float kqmax_new_j = kqmax_shared[j][threadIdx.x];
|
||||
kqmax_new_j = warp_reduce_max(kqmax_new_j);
|
||||
|
||||
const float KQ_max_scale = expf(kqmax[j] - kqmax_new_j);
|
||||
kqmax[j] = kqmax_new_j;
|
||||
|
||||
const float val = expf(sink - kqmax[j]);
|
||||
kqsum[j] = kqsum[j]*KQ_max_scale;
|
||||
|
||||
if (tid == 0) {
|
||||
kqsum[j] += val;
|
||||
}
|
||||
|
||||
VKQ[j] *= KQ_max_scale;
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols; ++j) {
|
||||
kqsum[j] = warp_reduce_sum(kqsum[j]);
|
||||
if (threadIdx.x == 0) {
|
||||
kqsum_shared[j][threadIdx.y] = kqsum[j];
|
||||
}
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
#pragma unroll
|
||||
for (int j_VKQ = 0; j_VKQ < ncols; ++j_VKQ) {
|
||||
if (ncols > 2 && ic0 + j_VKQ >= ne01) {
|
||||
break;
|
||||
}
|
||||
|
||||
kqsum[j_VKQ] = kqsum_shared[j_VKQ][threadIdx.x];
|
||||
kqsum[j_VKQ] = warp_reduce_sum(kqsum[j_VKQ]);
|
||||
|
||||
float dst_val = VKQ[j_VKQ];
|
||||
if (gridDim.y == 1) {
|
||||
dst_val /= kqsum[j_VKQ];
|
||||
}
|
||||
dst[(((sequence*ne01 + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y)*D + tid] = dst_val;
|
||||
}
|
||||
|
||||
if (gridDim.y != 1 && tid < ncols && (ncols <= 2 || ic0 + tid < ne01)) {
|
||||
dst_meta[((sequence*ne01 + ic0 + tid)*ne02 + head)*gridDim.y + blockIdx.y] = make_float2(kqmax[tid], kqsum[tid]);
|
||||
}
|
||||
#else
|
||||
GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale,
|
||||
max_bias, m0, m1, n_head_log2, logit_softcap,
|
||||
ne00, ne01, ne02, ne03,
|
||||
nb01, nb02, nb03,
|
||||
ne10, ne11, ne12, ne13,
|
||||
nb11, nb12, nb13,
|
||||
nb21, nb22, nb23,
|
||||
ne31, ne32, ne33,
|
||||
nb31, nb32, nb33);
|
||||
NO_DEVICE_CODE;
|
||||
#endif // FLASH_ATTN_AVAILABLE
|
||||
}
|
||||
#ifdef __clang__
|
||||
#pragma clang diagnostic pop
|
||||
#endif // __clang__
|
||||
|
||||
template <int D, int cols_per_block, ggml_type type_K, ggml_type type_V, bool use_logit_softcap>
|
||||
void ggml_cuda_flash_attn_ext_vec_f32_case_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
constexpr int nwarps = D/WARP_SIZE;
|
||||
fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f32<D, cols_per_block, type_K, type_V, use_logit_softcap>;
|
||||
constexpr bool need_f16_K = D != 128;
|
||||
constexpr bool need_f16_V = D != 128 && D != 64;
|
||||
constexpr size_t nbytes_shared = 0;
|
||||
launch_fattn<D, cols_per_block, 1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, D, need_f16_K, need_f16_V, false);
|
||||
}
|
||||
|
||||
template <int D, ggml_type type_K, ggml_type type_V>
|
||||
void ggml_cuda_flash_attn_ext_vec_f32_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
const ggml_tensor * KQV = dst;
|
||||
const ggml_tensor * Q = dst->src[0];
|
||||
const ggml_tensor * K = dst->src[1];
|
||||
const ggml_tensor * V = dst->src[2];
|
||||
|
||||
GGML_ASSERT(K->type == type_K);
|
||||
GGML_ASSERT(V->type == type_V);
|
||||
|
||||
float logit_softcap;
|
||||
memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
|
||||
|
||||
const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
|
||||
|
||||
if (Q->ne[1] == 1 || GGML_CUDA_CC_IS_NVIDIA(cc)) {
|
||||
constexpr int cols_per_block = 1;
|
||||
if (logit_softcap == 0.0f) {
|
||||
constexpr bool use_logit_softcap = false;
|
||||
ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
|
||||
} else {
|
||||
constexpr bool use_logit_softcap = true;
|
||||
ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
if (Q->ne[1] == 2) {
|
||||
constexpr int cols_per_block = 2;
|
||||
if (logit_softcap == 0.0f) {
|
||||
constexpr bool use_logit_softcap = false;
|
||||
ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
|
||||
} else {
|
||||
constexpr bool use_logit_softcap = true;
|
||||
ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
if (Q->ne[1] <= 4) {
|
||||
constexpr int cols_per_block = 4;
|
||||
if (logit_softcap == 0.0f) {
|
||||
constexpr bool use_logit_softcap = false;
|
||||
ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
|
||||
} else {
|
||||
constexpr bool use_logit_softcap = true;
|
||||
ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
constexpr int cols_per_block = 8;
|
||||
if (logit_softcap == 0.0f) {
|
||||
constexpr bool use_logit_softcap = false;
|
||||
ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
|
||||
} else {
|
||||
constexpr bool use_logit_softcap = true;
|
||||
ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
|
||||
}
|
||||
}
|
||||
|
||||
#define DECL_FATTN_VEC_F32_CASE(D, type_K, type_V) \
|
||||
template void ggml_cuda_flash_attn_ext_vec_f32_case \
|
||||
<D, type_K, type_V>(ggml_backend_cuda_context & ctx, ggml_tensor * dst) \
|
||||
|
||||
extern DECL_FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q4_0);
|
||||
extern DECL_FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q4_1);
|
||||
extern DECL_FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q5_0);
|
||||
extern DECL_FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q5_1);
|
||||
extern DECL_FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q8_0);
|
||||
extern DECL_FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_F16);
|
||||
|
||||
extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0);
|
||||
extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q4_0);
|
||||
extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q4_0);
|
||||
extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q4_0);
|
||||
extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0);
|
||||
extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q4_0);
|
||||
|
||||
extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_1);
|
||||
extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q4_1);
|
||||
extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q4_1);
|
||||
extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q4_1);
|
||||
extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q4_1);
|
||||
extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q4_1);
|
||||
|
||||
extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_0);
|
||||
extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_0);
|
||||
extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_0);
|
||||
extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_0);
|
||||
extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_0);
|
||||
extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q5_0);
|
||||
|
||||
extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_1);
|
||||
extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_1);
|
||||
extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_1);
|
||||
extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_1);
|
||||
extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_1);
|
||||
extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q5_1);
|
||||
|
||||
extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q8_0);
|
||||
extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q8_0);
|
||||
extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q8_0);
|
||||
extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q8_0);
|
||||
extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0);
|
||||
extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q8_0);
|
||||
|
||||
extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_F16);
|
||||
extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_F16);
|
||||
extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_F16);
|
||||
extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_F16);
|
||||
extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_F16);
|
||||
extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_F16);
|
||||
|
||||
extern DECL_FATTN_VEC_F32_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16);
|
||||
593
ggml/src/ggml-cuda/fattn-vec.cuh
Normal file
593
ggml/src/ggml-cuda/fattn-vec.cuh
Normal file
@@ -0,0 +1,593 @@
|
||||
#include "common.cuh"
|
||||
#include "fattn-common.cuh"
|
||||
|
||||
static int ggml_cuda_fattn_vec_get_nthreads_host(const int cc) {
|
||||
return 128;
|
||||
GGML_UNUSED(cc);
|
||||
}
|
||||
|
||||
static constexpr __device__ int ggml_cuda_fattn_vec_get_nthreads_device() {
|
||||
return 128;
|
||||
}
|
||||
|
||||
// Currenlty llvm with the amdgcn target dose not support unrolling loops
|
||||
// that contain a break that can not be resolved at compile time.
|
||||
#ifdef __clang__
|
||||
#pragma clang diagnostic push
|
||||
#pragma clang diagnostic ignored "-Wpass-failed"
|
||||
#endif // __clang__
|
||||
template<int D, int ncols, ggml_type type_K, ggml_type type_V, bool use_logit_softcap> // D == head size
|
||||
__launch_bounds__(ggml_cuda_fattn_vec_get_nthreads_device(), 1)
|
||||
static __global__ void flash_attn_ext_vec(
|
||||
const char * __restrict__ Q,
|
||||
const char * __restrict__ K,
|
||||
const char * __restrict__ V,
|
||||
const char * __restrict__ mask,
|
||||
const char * __restrict__ sinks,
|
||||
const int * __restrict__ KV_max,
|
||||
float * __restrict__ dst,
|
||||
float2 * __restrict__ dst_meta,
|
||||
const float scale,
|
||||
const float max_bias,
|
||||
const float m0,
|
||||
const float m1,
|
||||
const uint32_t n_head_log2,
|
||||
const float logit_softcap,
|
||||
const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03,
|
||||
const int32_t nb01, const int32_t nb02, const int32_t nb03,
|
||||
const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13,
|
||||
const int32_t nb11, const int32_t nb12, const int64_t nb13,
|
||||
const int32_t nb21, const int32_t nb22, const int64_t nb23,
|
||||
const int32_t ne31, const int32_t ne32, const int32_t ne33,
|
||||
const int32_t nb31, const int32_t nb32, const int64_t nb33) {
|
||||
#ifdef FLASH_ATTN_AVAILABLE
|
||||
|
||||
// Skip unused kernel variants for faster compilation:
|
||||
if (use_logit_softcap && !(D == 128 || D == 256)) {
|
||||
GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale,
|
||||
max_bias, m0, m1, n_head_log2, logit_softcap,
|
||||
ne00, ne01, ne02, ne03,
|
||||
nb01, nb02, nb03,
|
||||
ne10, ne11, ne12, ne13,
|
||||
nb11, nb12, nb13,
|
||||
nb21, nb22, nb23,
|
||||
ne31, ne32, ne33,
|
||||
nb31, nb32, nb33);
|
||||
NO_DEVICE_CODE;
|
||||
return;
|
||||
}
|
||||
|
||||
//In this kernel Q, K, V are matrices while i, j, k are matrix indices.
|
||||
|
||||
constexpr int cpy_nb = ggml_cuda_get_max_cpy_bytes();
|
||||
constexpr int cpy_ne = cpy_nb / 4;
|
||||
|
||||
#ifdef GGML_USE_HIP
|
||||
#ifdef RDNA
|
||||
constexpr int nthreads_KQ_q = 2;
|
||||
#else
|
||||
constexpr int nthreads_KQ_q = 4;
|
||||
#endif // RDNA
|
||||
constexpr int nthreads_V_q = (D/4 < 32 ? D/4 : 32);
|
||||
#else
|
||||
constexpr int nthreads_KQ_q = (D/4 < 32 ? D/4 : 32);
|
||||
constexpr int nthreads_V_q = (D/4 < 32 ? D/4 : 32);
|
||||
#endif // GGML_USE_HIP
|
||||
|
||||
constexpr int nthreads = ggml_cuda_fattn_vec_get_nthreads_device();
|
||||
constexpr int nthreads_KQ = type_K == GGML_TYPE_F16 ? 128 / cpy_nb : nthreads_KQ_q;
|
||||
constexpr int nthreads_V = type_V == GGML_TYPE_F16 ? 128 / cpy_nb : nthreads_V_q;
|
||||
|
||||
static_assert(WARP_SIZE % nthreads_KQ == 0, "bad nthreads_K");
|
||||
static_assert(WARP_SIZE % nthreads_V == 0, "bad nthreads_V");
|
||||
|
||||
constexpr int V_rows_per_thread = type_V == GGML_TYPE_F16 ? 2*cpy_ne : 4;
|
||||
constexpr int V_cols_per_iter = WARP_SIZE / nthreads_V;
|
||||
|
||||
constexpr vec_dot_KQ_t vec_dot_KQ = get_vec_dot_KQ<type_K, D, nthreads_KQ>();
|
||||
constexpr bool Q_q8_1 = type_K != GGML_TYPE_F16;
|
||||
#ifdef FAST_FP16_AVAILABLE
|
||||
constexpr dequantize_V_t dequantize_V = get_dequantize_V<type_V, half, V_rows_per_thread>();
|
||||
#else
|
||||
constexpr dequantize_V_t dequantize_V = get_dequantize_V<type_V, float, V_rows_per_thread>();
|
||||
#endif // FAST_FP16_AVAILABLE
|
||||
|
||||
const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on.
|
||||
|
||||
const int sequence = blockIdx.z / ne02;
|
||||
const int head = blockIdx.z - sequence*ne02;
|
||||
const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
|
||||
Q += nb03*sequence + nb02* head + nb01*ic0;
|
||||
K += nb13*sequence + nb12*(head / gqa_ratio);
|
||||
V += nb23*sequence + nb22*(head / gqa_ratio);
|
||||
|
||||
const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0);
|
||||
|
||||
const float slope = get_alibi_slope(max_bias, head, n_head_log2, m0, m1);
|
||||
|
||||
static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
|
||||
constexpr int nwarps = nthreads / WARP_SIZE;
|
||||
const int tid = WARP_SIZE*threadIdx.y + threadIdx.x;
|
||||
__builtin_assume(tid < nthreads);
|
||||
|
||||
constexpr int ne_KQ = ncols*D;
|
||||
constexpr int ne_combine = nwarps*V_cols_per_iter*D;
|
||||
#ifdef FAST_FP16_AVAILABLE
|
||||
half2 VKQ[ncols][(D/2)/nthreads_V] = {{{0.0f, 0.0f}}};
|
||||
__shared__ half KQ[ne_KQ > ne_combine ? ne_KQ : ne_combine];
|
||||
#else
|
||||
float2 VKQ[ncols][(D/2)/nthreads_V] = {{{0.0f, 0.0f}}};
|
||||
__shared__ float KQ[ne_KQ > ne_combine ? ne_KQ : ne_combine];
|
||||
#endif // FAST_FP16_AVAILABLE
|
||||
|
||||
float KQ_max[ncols];
|
||||
float KQ_sum[ncols];
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols; ++j) {
|
||||
KQ_max[j] = -FLT_MAX/2.0f;
|
||||
KQ_sum[j] = 0.0f;
|
||||
}
|
||||
|
||||
// Convert Q to float2 (f16 K) or q8_1 (quantized K) and store in registers:
|
||||
#ifdef FAST_FP16_AVAILABLE
|
||||
half2 Q_reg[ncols][(D/2)/nthreads_KQ]; // Will be initialized completely.
|
||||
#else
|
||||
float2 Q_reg[ncols][(D/2)/nthreads_KQ] = {{{0.0f, 0.0f}}}; // May be only partially initialized.
|
||||
#endif // FAST_FP16_AVAILABLE
|
||||
int Q_i32[ncols][1 > D/(sizeof(int)*nthreads_KQ) ? 1 : D/(sizeof(int)*nthreads_KQ)];
|
||||
float2 Q_ds[ncols][1 > D/(sizeof(int)*nthreads_KQ) ? 1 : D/(sizeof(int)*nthreads_KQ)];
|
||||
if constexpr (Q_q8_1) {
|
||||
#pragma unroll
|
||||
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
|
||||
const int j = j0 + threadIdx.y;
|
||||
|
||||
if (j0 + nwarps > ncols && j >= ncols) {
|
||||
break;
|
||||
}
|
||||
|
||||
// Reuse KQ as temporary storage for converting Q to q8_1:
|
||||
int * tmp_q_i32 = (int *) &KQ[j*D];
|
||||
float2 * tmp_q_ds = (float2 *) (tmp_q_i32 + D/sizeof(int));
|
||||
|
||||
// Set memory to zero if out of bounds:
|
||||
if (ncols > 1 && ic0 + j >= ne01) {
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < int(D/sizeof(int)); i0 += WARP_SIZE) {
|
||||
const int i = i0 + threadIdx.x;
|
||||
|
||||
if (i0 + WARP_SIZE <= D/sizeof(int) || i < D/sizeof(int)) {
|
||||
tmp_q_i32[i] = 0;
|
||||
}
|
||||
}
|
||||
if (threadIdx.x < D/QK8_1) {
|
||||
tmp_q_ds[threadIdx.x] = make_float2(0.0f, 0.0f);
|
||||
}
|
||||
} else {
|
||||
const float * Q_f = (const float *) (Q + j*nb01);
|
||||
constexpr int nthreads_quantize = D/sizeof(int) < WARP_SIZE ? D/sizeof(int) : WARP_SIZE;
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < int(D/sizeof(int)); i0 += nthreads_quantize) {
|
||||
quantize_q8_1_to_shared<float2, nthreads_quantize>
|
||||
(Q_f + i0*sizeof(int), scale, tmp_q_i32 + i0, tmp_q_ds + i0/QI8_1);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols; ++j) {
|
||||
int * tmp_q_i32 = (int *) &KQ[j*D];
|
||||
float2 * tmp_q_ds = (float2 *) (tmp_q_i32 + D/sizeof(int));
|
||||
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < int(D/sizeof(int)); i0 += nthreads_KQ) {
|
||||
const int i = i0 + (nthreads_KQ == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_KQ);
|
||||
|
||||
Q_i32[j][i0/nthreads_KQ] = tmp_q_i32[i];
|
||||
Q_ds[j][i0/nthreads_KQ] = tmp_q_ds[i/QI8_1];
|
||||
}
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
} else {
|
||||
#ifdef FAST_FP16_AVAILABLE
|
||||
const half2 scale_h2 = make_half2(scale, scale);
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols; ++j) {
|
||||
const float2 * Q_j = (const float2 *) (Q + j*nb01);
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < D/2; i0 += nthreads_KQ*cpy_ne) {
|
||||
const int i = i0 + (nthreads_KQ == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_KQ)*cpy_ne;
|
||||
|
||||
float2 tmp[cpy_ne] = {{0.0f, 0.0f}};
|
||||
if (ncols == 1 || ic0 + j < ne01) {
|
||||
ggml_cuda_memcpy_1<cpy_nb>(tmp, &Q_j[i]);
|
||||
ggml_cuda_memcpy_1<cpy_nb>(tmp + cpy_ne/2, &Q_j[i + cpy_ne/2]);
|
||||
}
|
||||
#pragma unroll
|
||||
for (int i1 = 0; i1 < cpy_ne; ++i1) {
|
||||
Q_reg[j][i0/nthreads_KQ + i1] = make_half2(tmp[i1].x, tmp[i1].y);
|
||||
}
|
||||
}
|
||||
#pragma unroll
|
||||
for (int k = 0; k < (D/2)/nthreads_KQ; ++k) {
|
||||
Q_reg[j][k] *= scale_h2;
|
||||
}
|
||||
}
|
||||
#else
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols; ++j) {
|
||||
const float2 * Q_j = (const float2 *) (Q + j*nb01);
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < D/2; i0 += nthreads_KQ*cpy_ne) {
|
||||
const int i = i0 + (nthreads_KQ == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_KQ)*cpy_ne;
|
||||
if (ncols == 1 || ic0 + j < ne01) {
|
||||
ggml_cuda_memcpy_1<cpy_nb>(&Q_reg[j][i0/nthreads_KQ], &Q_j[i]);
|
||||
ggml_cuda_memcpy_1<cpy_nb>(&Q_reg[j][i0/nthreads_KQ + cpy_ne/2], &Q_j[i + cpy_ne/2]);
|
||||
}
|
||||
}
|
||||
#pragma unroll
|
||||
for (int k = 0; k < (D/2)/nthreads_KQ; ++k) {
|
||||
Q_reg[j][k].x *= scale;
|
||||
Q_reg[j][k].y *= scale;
|
||||
}
|
||||
}
|
||||
#endif // FAST_FP16_AVAILABLE
|
||||
}
|
||||
|
||||
const int k_VKQ_max = KV_max ? KV_max[sequence*gridDim.x + blockIdx.x] : ne11;
|
||||
K += blockIdx.y*nthreads * nb11;
|
||||
V += blockIdx.y*nthreads * nb21;
|
||||
maskh += blockIdx.y*nthreads;
|
||||
for (int k_VKQ_0 = blockIdx.y*nthreads; k_VKQ_0 < k_VKQ_max; k_VKQ_0 += gridDim.y*nthreads,
|
||||
// Increment pointers after each loop:
|
||||
K += gridDim.y*nthreads*nb11, V += gridDim.y*nthreads*nb21, maskh += gridDim.y*nthreads) {
|
||||
|
||||
// Calculate KQ tile and keep track of new maximum KQ values:
|
||||
float KQ_reg[ncols]; // KQ in registers.
|
||||
|
||||
float KQ_max_new[ncols];
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols; ++j) {
|
||||
KQ_max_new[j] = KQ_max[j];
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int i_KQ_0 = 0; i_KQ_0 < nthreads_KQ; ++i_KQ_0) {
|
||||
const int i_KQ = threadIdx.y*WARP_SIZE + (nthreads_KQ == WARP_SIZE ? 0 : (threadIdx.x & ~(nthreads_KQ-1))) + i_KQ_0;
|
||||
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols; ++j) {
|
||||
float sum = vec_dot_KQ(K + i_KQ*nb11, Q_reg[j], Q_i32[j], Q_ds[j]);
|
||||
sum = warp_reduce_sum<nthreads_KQ>(sum);
|
||||
|
||||
if (use_logit_softcap) {
|
||||
sum = logit_softcap*tanhf(sum);
|
||||
}
|
||||
|
||||
if (mask) {
|
||||
sum += slope*__half2float(maskh[j*ne11 + i_KQ]);
|
||||
}
|
||||
|
||||
KQ_max_new[j] = fmaxf(KQ_max_new[j], sum);
|
||||
|
||||
if ((nthreads_KQ == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_KQ) == i_KQ_0) {
|
||||
KQ_reg[j] = sum;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols; ++j) {
|
||||
#pragma unroll
|
||||
for (int offset = nthreads_KQ; offset < WARP_SIZE; offset <<= 1) {
|
||||
KQ_max_new[j] = fmaxf(KQ_max_new[j], __shfl_xor_sync(0xFFFFFFFF, KQ_max_new[j], offset, WARP_SIZE));
|
||||
}
|
||||
const float KQ_max_scale = expf(KQ_max[j] - KQ_max_new[j]);
|
||||
KQ_max[j] = KQ_max_new[j];
|
||||
|
||||
KQ_reg[j] = expf(KQ_reg[j] - KQ_max[j]);
|
||||
KQ_sum[j] = KQ_sum[j]*KQ_max_scale + KQ_reg[j];
|
||||
KQ[j*nthreads + tid] = KQ_reg[j];
|
||||
|
||||
#ifdef FAST_FP16_AVAILABLE
|
||||
const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale, KQ_max_scale);
|
||||
#pragma unroll
|
||||
for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V) {
|
||||
VKQ[j][i_VKQ_0/nthreads_V] *= KQ_max_scale_h2;
|
||||
}
|
||||
#else
|
||||
#pragma unroll
|
||||
for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V) {
|
||||
VKQ[j][i_VKQ_0/nthreads_V].x *= KQ_max_scale;
|
||||
VKQ[j][i_VKQ_0/nthreads_V].y *= KQ_max_scale;
|
||||
}
|
||||
#endif // FAST_FP16_AVAILABLE
|
||||
}
|
||||
|
||||
#ifndef GGML_USE_HIP
|
||||
__syncwarp();
|
||||
#endif // GGML_USE_HIP
|
||||
|
||||
#pragma unroll
|
||||
for (int k0 = 0; k0 < WARP_SIZE; k0 += V_cols_per_iter) {
|
||||
const int k = threadIdx.y*WARP_SIZE + k0 + (nthreads_V == WARP_SIZE ? 0 : threadIdx.x / nthreads_V);
|
||||
|
||||
#ifdef FAST_FP16_AVAILABLE
|
||||
half2 KQ_k[ncols];
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols; ++j) {
|
||||
KQ_k[j] = __half2half2(KQ[j*nthreads + k]);
|
||||
}
|
||||
#pragma unroll
|
||||
for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V*V_rows_per_thread/2) {
|
||||
half2 tmp[V_rows_per_thread/2];
|
||||
dequantize_V(V + k*nb21, tmp,
|
||||
2*i_VKQ_0 + (nthreads_V == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_V)*V_rows_per_thread);
|
||||
#pragma unroll
|
||||
for (int i_VKQ_1 = 0; i_VKQ_1 < V_rows_per_thread/2; ++i_VKQ_1) {
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols; ++j) {
|
||||
VKQ[j][i_VKQ_0/nthreads_V + i_VKQ_1] += tmp[i_VKQ_1]*KQ_k[j];
|
||||
}
|
||||
}
|
||||
}
|
||||
#else
|
||||
float KQ_k[ncols];
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols; ++j) {
|
||||
KQ_k[j] = KQ[j*nthreads + k];
|
||||
}
|
||||
#pragma unroll
|
||||
for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V*V_rows_per_thread/2) {
|
||||
float2 tmp[V_rows_per_thread/2];
|
||||
dequantize_V(V + k*nb21, tmp,
|
||||
2*i_VKQ_0 + (nthreads_V == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_V)*V_rows_per_thread);
|
||||
#pragma unroll
|
||||
for (int i_VKQ_1 = 0; i_VKQ_1 < V_rows_per_thread/2; ++i_VKQ_1) {
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols; ++j) {
|
||||
VKQ[j][i_VKQ_0/nthreads_V + i_VKQ_1].x += tmp[i_VKQ_1].x*KQ_k[j];
|
||||
VKQ[j][i_VKQ_0/nthreads_V + i_VKQ_1].y += tmp[i_VKQ_1].y*KQ_k[j];
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif // FAST_FP16_AVAILABLE
|
||||
}
|
||||
}
|
||||
|
||||
if (sinks && blockIdx.y == 0) {
|
||||
const float sink = ((const float *) sinks)[head];
|
||||
|
||||
#pragma unroll
|
||||
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
|
||||
const int j = j0 + threadIdx.y;
|
||||
|
||||
if (j0 + nwarps > ncols && j >= ncols) {
|
||||
break;
|
||||
}
|
||||
|
||||
const float kqmax_new_j = fmaxf(sink, KQ_max[j]);
|
||||
const float KQ_max_scale = expf(KQ_max[j] - kqmax_new_j);
|
||||
KQ_max[j] = kqmax_new_j;
|
||||
|
||||
KQ_sum[j] = KQ_sum[j]*KQ_max_scale + (threadIdx.x == 0 ? expf(sink - KQ_max[j]) : 0.0f);
|
||||
|
||||
#ifdef FAST_FP16_AVAILABLE
|
||||
const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale, KQ_max_scale);
|
||||
#pragma unroll
|
||||
for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V) {
|
||||
VKQ[j][i_VKQ_0/nthreads_V] *= KQ_max_scale_h2;
|
||||
}
|
||||
#else
|
||||
#pragma unroll
|
||||
for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V) {
|
||||
VKQ[j][i_VKQ_0/nthreads_V].x *= KQ_max_scale;
|
||||
VKQ[j][i_VKQ_0/nthreads_V].y *= KQ_max_scale;
|
||||
}
|
||||
#endif // FAST_FP16_AVAILABLE
|
||||
}
|
||||
}
|
||||
|
||||
__shared__ float KQ_max_shared[ncols][WARP_SIZE];
|
||||
__shared__ float KQ_sum_shared[ncols][WARP_SIZE];
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols; ++j) {
|
||||
if (threadIdx.y == 0) {
|
||||
KQ_max_shared[j][threadIdx.x] = -FLT_MAX/2.0f;
|
||||
KQ_sum_shared[j][threadIdx.x] = 0.0f;
|
||||
}
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols; ++j) {
|
||||
if (threadIdx.x == 0) {
|
||||
KQ_max_shared[j][threadIdx.y] = KQ_max[j];
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
#pragma unroll
|
||||
for (int j_VKQ = 0; j_VKQ < ncols; ++j_VKQ) {
|
||||
if (ncols > 1 && ic0 + j_VKQ >= ne01) {
|
||||
break;
|
||||
}
|
||||
|
||||
float kqmax_new = KQ_max_shared[j_VKQ][threadIdx.x];
|
||||
kqmax_new = warp_reduce_max(kqmax_new);
|
||||
const float kqmax_scale = expf(KQ_max[j_VKQ] - kqmax_new);
|
||||
KQ_max[j_VKQ] = kqmax_new;
|
||||
|
||||
#ifdef FAST_FP16_AVAILABLE
|
||||
half2 * VKQ_tmp = (half2 *) KQ + threadIdx.y*(V_cols_per_iter*D/2)
|
||||
+ (nthreads_V == WARP_SIZE ? 0 : threadIdx.x / nthreads_V)*(D/2);
|
||||
|
||||
const half2 kqmax_scale_h2 = make_half2(kqmax_scale, kqmax_scale);
|
||||
#pragma unroll
|
||||
for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V) {
|
||||
VKQ[j_VKQ][i_VKQ_0/nthreads_V] *= kqmax_scale_h2;
|
||||
}
|
||||
#pragma unroll
|
||||
for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V*V_rows_per_thread/2) {
|
||||
const int i_VKQ = i_VKQ_0 + (nthreads_V == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_V)*(V_rows_per_thread/2);
|
||||
|
||||
ggml_cuda_memcpy_1<V_rows_per_thread*sizeof(half)>(VKQ_tmp + i_VKQ, &VKQ[j_VKQ][i_VKQ_0/nthreads_V]);
|
||||
}
|
||||
#else
|
||||
float2 * VKQ_tmp = (float2 *) KQ + threadIdx.y*(V_cols_per_iter*D/2)
|
||||
+ (nthreads_V == WARP_SIZE ? 0 : threadIdx.x / nthreads_V)*(D/2);
|
||||
|
||||
#pragma unroll
|
||||
for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V) {
|
||||
VKQ[j_VKQ][i_VKQ_0/nthreads_V].x *= kqmax_scale;
|
||||
VKQ[j_VKQ][i_VKQ_0/nthreads_V].y *= kqmax_scale;
|
||||
}
|
||||
#pragma unroll
|
||||
for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V*V_rows_per_thread/2) {
|
||||
const int i_VKQ = i_VKQ_0 + (nthreads_V == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_V)*(V_rows_per_thread/2);
|
||||
|
||||
ggml_cuda_memcpy_1<V_rows_per_thread/2*sizeof(float)>(VKQ_tmp + i_VKQ, &VKQ[j_VKQ][i_VKQ_0/nthreads_V]);
|
||||
ggml_cuda_memcpy_1<V_rows_per_thread/2*sizeof(float)>(VKQ_tmp + i_VKQ + V_rows_per_thread/4, &VKQ[j_VKQ][i_VKQ_0/nthreads_V + V_rows_per_thread/4]);
|
||||
}
|
||||
#endif // FAST_FP16_AVAILABLE
|
||||
|
||||
KQ_sum[j_VKQ] *= kqmax_scale;
|
||||
KQ_sum[j_VKQ] = warp_reduce_sum(KQ_sum[j_VKQ]);
|
||||
if (threadIdx.x == 0) {
|
||||
KQ_sum_shared[j_VKQ][threadIdx.y] = KQ_sum[j_VKQ];
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
if (nthreads <= D || tid < D) {
|
||||
KQ_sum[j_VKQ] = KQ_sum_shared[j_VKQ][threadIdx.x];
|
||||
KQ_sum[j_VKQ] = warp_reduce_sum(KQ_sum[j_VKQ]);
|
||||
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < D; i0 += nthreads) {
|
||||
float dst_val = 0;
|
||||
#pragma unroll
|
||||
for (int w = 0; w < nwarps; ++w) {
|
||||
#pragma unroll
|
||||
for (int v = 0; v < V_cols_per_iter; ++v) {
|
||||
dst_val += float(KQ[w*V_cols_per_iter*D + v*D + i0 + tid]);
|
||||
}
|
||||
}
|
||||
if (gridDim.y == 1) {
|
||||
dst_val /= KQ_sum[j_VKQ];
|
||||
}
|
||||
dst[(((sequence*ne01 + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y)*D + i0 + tid] = dst_val;
|
||||
}
|
||||
}
|
||||
|
||||
if (j_VKQ < ncols-1) {
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
if (gridDim.y != 1 && tid < ncols && (ncols == 1 || ic0 + tid < ne01)) {
|
||||
dst_meta[((sequence*ne01 + ic0 + tid)*ne02 + head)*gridDim.y + blockIdx.y] = make_float2(KQ_max[tid], KQ_sum[tid]);
|
||||
}
|
||||
#else
|
||||
GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale,
|
||||
max_bias, m0, m1, n_head_log2, logit_softcap,
|
||||
ne00, ne01, ne02, ne03,
|
||||
nb01, nb02, nb03,
|
||||
ne10, ne11, ne12, ne13,
|
||||
nb11, nb12, nb13,
|
||||
nb21, nb22, nb23,
|
||||
ne31, ne32, ne33,
|
||||
nb31, nb32, nb33);
|
||||
NO_DEVICE_CODE;
|
||||
#endif // FLASH_ATTN_AVAILABLE
|
||||
}
|
||||
#ifdef __clang__
|
||||
#pragma clang diagnostic pop
|
||||
#endif // __clang__
|
||||
|
||||
template <int D, int cols_per_block, ggml_type type_K, ggml_type type_V, bool use_logit_softcap>
|
||||
void ggml_cuda_flash_attn_ext_vec_case_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
|
||||
|
||||
const int nthreads = ggml_cuda_fattn_vec_get_nthreads_host(cc);
|
||||
const int nwarps = nthreads / WARP_SIZE;
|
||||
fattn_kernel_t fattn_kernel = flash_attn_ext_vec<D, cols_per_block, type_K, type_V, use_logit_softcap>;
|
||||
constexpr bool need_f16_K = false;
|
||||
constexpr bool need_f16_V = false;
|
||||
constexpr size_t nbytes_shared = 0;
|
||||
launch_fattn<D, cols_per_block, 1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, D, need_f16_K, need_f16_V, false);
|
||||
}
|
||||
|
||||
template <int D, ggml_type type_K, ggml_type type_V>
|
||||
void ggml_cuda_flash_attn_ext_vec_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
const ggml_tensor * KQV = dst;
|
||||
const ggml_tensor * Q = dst->src[0];
|
||||
const ggml_tensor * K = dst->src[1];
|
||||
const ggml_tensor * V = dst->src[2];
|
||||
|
||||
GGML_ASSERT(K->type == type_K);
|
||||
GGML_ASSERT(V->type == type_V);
|
||||
|
||||
float logit_softcap;
|
||||
memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
|
||||
|
||||
const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
|
||||
|
||||
if (Q->ne[1] == 1) {
|
||||
constexpr int cols_per_block = 1;
|
||||
if (logit_softcap == 0.0f) {
|
||||
constexpr bool use_logit_softcap = false;
|
||||
ggml_cuda_flash_attn_ext_vec_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
|
||||
} else {
|
||||
constexpr bool use_logit_softcap = true;
|
||||
ggml_cuda_flash_attn_ext_vec_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
constexpr int cols_per_block = 2;
|
||||
if (logit_softcap == 0.0f) {
|
||||
constexpr bool use_logit_softcap = false;
|
||||
ggml_cuda_flash_attn_ext_vec_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
|
||||
} else {
|
||||
constexpr bool use_logit_softcap = true;
|
||||
ggml_cuda_flash_attn_ext_vec_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
|
||||
}
|
||||
}
|
||||
|
||||
#define DECL_FATTN_VEC_CASE(D, type_K, type_V) \
|
||||
template void ggml_cuda_flash_attn_ext_vec_case \
|
||||
<D, type_K, type_V>(ggml_backend_cuda_context & ctx, ggml_tensor * dst) \
|
||||
|
||||
#define EXTERN_DECL_FATTN_VEC_CASES(D, type_K) \
|
||||
extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_F16); \
|
||||
extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_Q4_0); \
|
||||
extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_Q4_1); \
|
||||
extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_Q5_0); \
|
||||
extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_Q5_1); \
|
||||
extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_Q8_0); \
|
||||
|
||||
EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_F16)
|
||||
EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_Q4_0)
|
||||
EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_Q4_1)
|
||||
EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_Q5_0)
|
||||
EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_Q5_1)
|
||||
EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_Q8_0)
|
||||
|
||||
EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_F16)
|
||||
EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_Q4_0)
|
||||
EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_Q4_1)
|
||||
EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_Q5_0)
|
||||
EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_Q5_1)
|
||||
EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_Q8_0)
|
||||
|
||||
EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_F16)
|
||||
EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q4_0)
|
||||
EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q4_1)
|
||||
EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q5_0)
|
||||
EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q5_1)
|
||||
EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q8_0)
|
||||
@@ -2,8 +2,7 @@
|
||||
#include "fattn-common.cuh"
|
||||
#include "fattn-mma-f16.cuh"
|
||||
#include "fattn-tile.cuh"
|
||||
#include "fattn-vec-f16.cuh"
|
||||
#include "fattn-vec-f32.cuh"
|
||||
#include "fattn-vec.cuh"
|
||||
#include "fattn-wmma-f16.cuh"
|
||||
#include "fattn.cuh"
|
||||
|
||||
@@ -117,151 +116,68 @@ static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, gg
|
||||
}
|
||||
}
|
||||
|
||||
#define FATTN_VEC_F16_CASE(D, type_K, type_V) \
|
||||
if (Q->ne[0] == (D) && K->type == (type_K) && V->type == (type_V)) { \
|
||||
ggml_cuda_flash_attn_ext_vec_f16_case<D, type_K, type_V>(ctx, dst); \
|
||||
return; \
|
||||
} \
|
||||
#define FATTN_VEC_CASE(D, type_K, type_V) \
|
||||
if (Q->ne[0] == (D) && K->type == (type_K) && V->type == (type_V)) { \
|
||||
ggml_cuda_flash_attn_ext_vec_case<D, type_K, type_V>(ctx, dst); \
|
||||
return; \
|
||||
} \
|
||||
|
||||
static void ggml_cuda_flash_attn_ext_vec_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
#define FATTN_VEC_CASES_ALL_D(type_K, type_V) \
|
||||
FATTN_VEC_CASE( 64, type_K, type_V) \
|
||||
FATTN_VEC_CASE(128, type_K, type_V) \
|
||||
FATTN_VEC_CASE(256, type_K, type_V) \
|
||||
|
||||
static void ggml_cuda_flash_attn_ext_vec(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
ggml_tensor * Q = dst->src[0];
|
||||
ggml_tensor * K = dst->src[1];
|
||||
ggml_tensor * V = dst->src[2];
|
||||
|
||||
#ifdef GGML_CUDA_FA_ALL_QUANTS
|
||||
FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q4_0)
|
||||
FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q4_1)
|
||||
FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q5_0)
|
||||
FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q5_1)
|
||||
FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q8_0)
|
||||
FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_F16 )
|
||||
FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_F16)
|
||||
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_F16)
|
||||
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_1, GGML_TYPE_F16)
|
||||
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_F16)
|
||||
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_F16)
|
||||
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_F16)
|
||||
|
||||
FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0)
|
||||
FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q4_0)
|
||||
FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q4_0)
|
||||
FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q4_0)
|
||||
FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0)
|
||||
FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q4_0)
|
||||
FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_Q4_0)
|
||||
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q4_0)
|
||||
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_1, GGML_TYPE_Q4_0)
|
||||
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_Q4_0)
|
||||
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_Q4_0)
|
||||
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q4_0)
|
||||
|
||||
FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_1)
|
||||
FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q4_1)
|
||||
FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q4_1)
|
||||
FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q4_1)
|
||||
FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q4_1)
|
||||
FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q4_1)
|
||||
FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_Q4_1)
|
||||
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q4_1)
|
||||
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_1, GGML_TYPE_Q4_1)
|
||||
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_Q4_1)
|
||||
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_Q4_1)
|
||||
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q4_1)
|
||||
|
||||
FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_0)
|
||||
FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_0)
|
||||
FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_0)
|
||||
FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_0)
|
||||
FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_0)
|
||||
FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q5_0)
|
||||
FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_Q5_0)
|
||||
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q5_0)
|
||||
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_1, GGML_TYPE_Q5_0)
|
||||
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_Q5_0)
|
||||
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_Q5_0)
|
||||
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q5_0)
|
||||
|
||||
FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_1)
|
||||
FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_1)
|
||||
FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_1)
|
||||
FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_1)
|
||||
FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_1)
|
||||
FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q5_1)
|
||||
FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_Q5_1)
|
||||
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q5_1)
|
||||
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_1, GGML_TYPE_Q5_1)
|
||||
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_Q5_1)
|
||||
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_Q5_1)
|
||||
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q5_1)
|
||||
|
||||
FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q8_0)
|
||||
FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q8_0)
|
||||
FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q8_0)
|
||||
FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q8_0)
|
||||
FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)
|
||||
FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q8_0)
|
||||
|
||||
FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_F16)
|
||||
FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_F16)
|
||||
FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_F16)
|
||||
FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_F16)
|
||||
FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_F16)
|
||||
FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_F16)
|
||||
|
||||
FATTN_VEC_F16_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16)
|
||||
FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_Q8_0)
|
||||
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q8_0)
|
||||
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_1, GGML_TYPE_Q8_0)
|
||||
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_Q8_0)
|
||||
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_Q8_0)
|
||||
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)
|
||||
#else
|
||||
FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0)
|
||||
|
||||
FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)
|
||||
|
||||
FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_F16)
|
||||
FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_F16)
|
||||
FATTN_VEC_F16_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16)
|
||||
#endif // GGML_CUDA_FA_ALL_QUANTS
|
||||
|
||||
GGML_ABORT("fatal error");
|
||||
}
|
||||
|
||||
#define FATTN_VEC_F32_CASE(D, type_K, type_V) \
|
||||
if (Q->ne[0] == (D) && K->type == (type_K) && V->type == (type_V)) { \
|
||||
ggml_cuda_flash_attn_ext_vec_f32_case<D, type_K, type_V>(ctx, dst); \
|
||||
return; \
|
||||
} \
|
||||
|
||||
static void ggml_cuda_flash_attn_ext_vec_f32(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
ggml_tensor * Q = dst->src[0];
|
||||
ggml_tensor * K = dst->src[1];
|
||||
ggml_tensor * V = dst->src[2];
|
||||
|
||||
#ifdef GGML_CUDA_FA_ALL_QUANTS
|
||||
FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q4_0)
|
||||
FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q4_1)
|
||||
FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q5_0)
|
||||
FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q5_1)
|
||||
FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q8_0)
|
||||
FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_F16)
|
||||
|
||||
FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0)
|
||||
FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q4_0)
|
||||
FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q4_0)
|
||||
FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q4_0)
|
||||
FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0)
|
||||
FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q4_0)
|
||||
|
||||
FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_1)
|
||||
FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q4_1)
|
||||
FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q4_1)
|
||||
FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q4_1)
|
||||
FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q4_1)
|
||||
FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q4_1)
|
||||
|
||||
FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_0)
|
||||
FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_0)
|
||||
FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_0)
|
||||
FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_0)
|
||||
FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_0)
|
||||
FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q5_0)
|
||||
|
||||
FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_1)
|
||||
FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_1)
|
||||
FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_1)
|
||||
FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_1)
|
||||
FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_1)
|
||||
FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q5_1)
|
||||
|
||||
FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q8_0)
|
||||
FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q8_0)
|
||||
FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q8_0)
|
||||
FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q8_0)
|
||||
FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)
|
||||
FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q8_0)
|
||||
|
||||
FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_F16)
|
||||
FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_F16)
|
||||
FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_F16)
|
||||
FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_F16)
|
||||
FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_F16)
|
||||
FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_F16)
|
||||
|
||||
FATTN_VEC_F32_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16)
|
||||
#else
|
||||
FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0)
|
||||
|
||||
FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)
|
||||
|
||||
FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_F16)
|
||||
FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_F16)
|
||||
FATTN_VEC_F32_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16)
|
||||
FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_F16)
|
||||
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q4_0)
|
||||
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)
|
||||
#endif // GGML_CUDA_FA_ALL_QUANTS
|
||||
|
||||
GGML_ABORT("fatal error");
|
||||
@@ -271,8 +187,7 @@ static void ggml_cuda_flash_attn_ext_vec_f32(ggml_backend_cuda_context & ctx, gg
|
||||
enum best_fattn_kernel {
|
||||
BEST_FATTN_KERNEL_NONE = 0,
|
||||
BEST_FATTN_KERNEL_TILE = 200,
|
||||
BEST_FATTN_KERNEL_VEC_F32 = 100,
|
||||
BEST_FATTN_KERNEL_VEC_F16 = 110,
|
||||
BEST_FATTN_KERNEL_VEC = 100,
|
||||
BEST_FATTN_KERNEL_WMMA_F16 = 300,
|
||||
BEST_FATTN_KERNEL_MMA_F16 = 400,
|
||||
};
|
||||
@@ -283,7 +198,6 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
|
||||
return BEST_FATTN_KERNEL_NONE;
|
||||
#endif// FLASH_ATTN_AVAILABLE
|
||||
|
||||
const ggml_tensor * KQV = dst;
|
||||
const ggml_tensor * Q = dst->src[0];
|
||||
const ggml_tensor * K = dst->src[1];
|
||||
const ggml_tensor * V = dst->src[2];
|
||||
@@ -293,8 +207,6 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
|
||||
GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);
|
||||
|
||||
const int cc = ggml_cuda_info().devices[device].cc;
|
||||
const int warp_size = ggml_cuda_info().devices[device].warp_size;
|
||||
const enum ggml_prec prec = ggml_flash_attn_ext_get_prec(KQV);
|
||||
|
||||
switch (K->ne[0]) {
|
||||
case 64:
|
||||
@@ -343,31 +255,6 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
|
||||
#endif // GGML_CUDA_FA_ALL_QUANTS
|
||||
case GGML_TYPE_Q4_0:
|
||||
case GGML_TYPE_Q8_0:
|
||||
#ifdef GGML_CUDA_FA_ALL_QUANTS
|
||||
if (K->ne[0] != 128 && K->ne[0] != 64) {
|
||||
return BEST_FATTN_KERNEL_NONE;
|
||||
}
|
||||
#else
|
||||
if (K->ne[0] != 128) {
|
||||
return BEST_FATTN_KERNEL_NONE;
|
||||
}
|
||||
#endif // GGML_CUDA_FA_ALL_QUANTS
|
||||
break;
|
||||
default:
|
||||
return BEST_FATTN_KERNEL_NONE;
|
||||
}
|
||||
|
||||
switch (V->type) {
|
||||
case GGML_TYPE_F16:
|
||||
break;
|
||||
case GGML_TYPE_Q4_1:
|
||||
case GGML_TYPE_Q5_0:
|
||||
case GGML_TYPE_Q5_1:
|
||||
case GGML_TYPE_Q4_0:
|
||||
case GGML_TYPE_Q8_0:
|
||||
if (K->ne[0] != 128) {
|
||||
return BEST_FATTN_KERNEL_NONE;
|
||||
}
|
||||
break;
|
||||
default:
|
||||
return BEST_FATTN_KERNEL_NONE;
|
||||
@@ -377,30 +264,39 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
|
||||
return BEST_FATTN_KERNEL_NONE;
|
||||
}
|
||||
|
||||
const bool can_use_vector_kernel = Q->ne[0] <= 256 && Q->ne[0] % (2*warp_size) == 0;
|
||||
const bool can_use_vector_kernel = Q->ne[0] <= 256 && Q->ne[0] % 64 == 0;
|
||||
|
||||
// If Turing tensor cores available, use them except for some cases with batch size 1:
|
||||
if (turing_mma_available(cc)) {
|
||||
const bool gqa_opt_applies = gqa_ratio % 2 == 0 && mask; // The mma-based kernels have GQA-specific optimizations
|
||||
const bool mma_needs_data_conversion = K->type != GGML_TYPE_F16 || V->type != GGML_TYPE_F16;
|
||||
const bool mma_faster_for_rtx4000 = Q->ne[3] > 1 || (gqa_ratio > 4 && K->ne[1] >= 8192);
|
||||
const bool mma_faster_for_bs1 = gqa_opt_applies && !mma_needs_data_conversion &&
|
||||
(cc < GGML_CUDA_CC_ADA_LOVELACE || mma_faster_for_rtx4000);
|
||||
if (Q->ne[1] == 1 && can_use_vector_kernel && !mma_faster_for_bs1) {
|
||||
if (prec == GGML_PREC_DEFAULT && fast_fp16_available(cc)) {
|
||||
return BEST_FATTN_KERNEL_VEC_F16;
|
||||
best_fattn_kernel best = BEST_FATTN_KERNEL_MMA_F16;
|
||||
|
||||
if (can_use_vector_kernel) {
|
||||
if (K->type == GGML_TYPE_F16 && V->type == GGML_TYPE_F16) {
|
||||
if (cc >= GGML_CUDA_CC_ADA_LOVELACE && Q->ne[1] == 1 && Q->ne[3] == 1 && !(gqa_ratio > 4 && K->ne[1] >= 8192)) {
|
||||
best = BEST_FATTN_KERNEL_VEC;
|
||||
}
|
||||
} else {
|
||||
if (cc >= GGML_CUDA_CC_ADA_LOVELACE) {
|
||||
if (Q->ne[1] <= 2) {
|
||||
best = BEST_FATTN_KERNEL_VEC;
|
||||
}
|
||||
} else {
|
||||
if (Q->ne[1] == 1) {
|
||||
best = BEST_FATTN_KERNEL_VEC;
|
||||
}
|
||||
}
|
||||
}
|
||||
if ((gqa_ratio % 2 != 0 || !mask) && Q->ne[1] == 1) {
|
||||
best = BEST_FATTN_KERNEL_VEC; // GQA-specific optimizations in the mma kernel do not apply.
|
||||
}
|
||||
return BEST_FATTN_KERNEL_VEC_F32;
|
||||
}
|
||||
return BEST_FATTN_KERNEL_MMA_F16;
|
||||
|
||||
return best;
|
||||
}
|
||||
|
||||
// Use kernels specializes for small batch sizes if possible:
|
||||
// Use kernels specialized for small batch sizes if possible:
|
||||
if (Q->ne[1] <= 8 && can_use_vector_kernel) {
|
||||
if (prec == GGML_PREC_DEFAULT && fast_fp16_available(cc)) {
|
||||
return BEST_FATTN_KERNEL_VEC_F16;
|
||||
}
|
||||
return BEST_FATTN_KERNEL_VEC_F32;
|
||||
return BEST_FATTN_KERNEL_VEC;
|
||||
}
|
||||
|
||||
// For large batch sizes, use the WMMA kernel if possible:
|
||||
@@ -420,11 +316,8 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
|
||||
case BEST_FATTN_KERNEL_TILE:
|
||||
ggml_cuda_flash_attn_ext_tile(ctx, dst);
|
||||
break;
|
||||
case BEST_FATTN_KERNEL_VEC_F32:
|
||||
ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
|
||||
break;
|
||||
case BEST_FATTN_KERNEL_VEC_F16:
|
||||
ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
|
||||
case BEST_FATTN_KERNEL_VEC:
|
||||
ggml_cuda_flash_attn_ext_vec(ctx, dst);
|
||||
break;
|
||||
case BEST_FATTN_KERNEL_WMMA_F16:
|
||||
ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst);
|
||||
|
||||
@@ -1,5 +0,0 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-vec-f16.cuh"
|
||||
|
||||
DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_F16);
|
||||
@@ -1,5 +0,0 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-vec-f16.cuh"
|
||||
|
||||
DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q4_0);
|
||||
@@ -1,5 +0,0 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-vec-f16.cuh"
|
||||
|
||||
DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q4_1);
|
||||
@@ -1,5 +0,0 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-vec-f16.cuh"
|
||||
|
||||
DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q5_0);
|
||||
@@ -1,5 +0,0 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-vec-f16.cuh"
|
||||
|
||||
DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q5_1);
|
||||
@@ -1,5 +0,0 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-vec-f16.cuh"
|
||||
|
||||
DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q8_0);
|
||||
@@ -1,5 +0,0 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-vec-f16.cuh"
|
||||
|
||||
DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_F16);
|
||||
@@ -1,5 +0,0 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-vec-f16.cuh"
|
||||
|
||||
DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0);
|
||||
@@ -1,5 +0,0 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-vec-f16.cuh"
|
||||
|
||||
DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_1);
|
||||
@@ -1,5 +0,0 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-vec-f16.cuh"
|
||||
|
||||
DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_0);
|
||||
@@ -1,5 +0,0 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-vec-f16.cuh"
|
||||
|
||||
DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_1);
|
||||
@@ -1,5 +0,0 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-vec-f16.cuh"
|
||||
|
||||
DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q8_0);
|
||||
@@ -1,5 +0,0 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-vec-f16.cuh"
|
||||
|
||||
DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_F16);
|
||||
@@ -1,5 +0,0 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-vec-f16.cuh"
|
||||
|
||||
DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q4_0);
|
||||
@@ -1,5 +0,0 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-vec-f16.cuh"
|
||||
|
||||
DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q4_1);
|
||||
@@ -1,5 +0,0 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-vec-f16.cuh"
|
||||
|
||||
DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_0);
|
||||
@@ -1,5 +0,0 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-vec-f16.cuh"
|
||||
|
||||
DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_1);
|
||||
@@ -1,5 +0,0 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-vec-f16.cuh"
|
||||
|
||||
DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q8_0);
|
||||
@@ -1,5 +0,0 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-vec-f16.cuh"
|
||||
|
||||
DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_F16);
|
||||
@@ -1,5 +0,0 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-vec-f16.cuh"
|
||||
|
||||
DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q4_0);
|
||||
@@ -1,5 +0,0 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-vec-f16.cuh"
|
||||
|
||||
DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q4_1);
|
||||
@@ -1,5 +0,0 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-vec-f16.cuh"
|
||||
|
||||
DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_0);
|
||||
@@ -1,5 +0,0 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-vec-f16.cuh"
|
||||
|
||||
DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_1);
|
||||
@@ -1,5 +0,0 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-vec-f16.cuh"
|
||||
|
||||
DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q8_0);
|
||||
@@ -1,5 +0,0 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-vec-f16.cuh"
|
||||
|
||||
DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_F16);
|
||||
@@ -1,5 +0,0 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-vec-f16.cuh"
|
||||
|
||||
DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q4_0);
|
||||
@@ -1,5 +0,0 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-vec-f16.cuh"
|
||||
|
||||
DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q4_1);
|
||||
@@ -1,5 +0,0 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-vec-f16.cuh"
|
||||
|
||||
DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_0);
|
||||
@@ -1,5 +0,0 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-vec-f16.cuh"
|
||||
|
||||
DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_1);
|
||||
@@ -1,5 +0,0 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-vec-f16.cuh"
|
||||
|
||||
DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q8_0);
|
||||
@@ -1,5 +0,0 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-vec-f16.cuh"
|
||||
|
||||
DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_F16);
|
||||
@@ -1,5 +0,0 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-vec-f16.cuh"
|
||||
|
||||
DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0);
|
||||
@@ -1,5 +0,0 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-vec-f16.cuh"
|
||||
|
||||
DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q4_1);
|
||||
@@ -1,5 +0,0 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-vec-f16.cuh"
|
||||
|
||||
DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_0);
|
||||
@@ -1,5 +0,0 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-vec-f16.cuh"
|
||||
|
||||
DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_1);
|
||||
@@ -1,5 +0,0 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-vec-f16.cuh"
|
||||
|
||||
DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0);
|
||||
@@ -1,5 +0,0 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-vec-f16.cuh"
|
||||
|
||||
DECL_FATTN_VEC_F16_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16);
|
||||
@@ -1,5 +0,0 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-vec-f16.cuh"
|
||||
|
||||
DECL_FATTN_VEC_F16_CASE(64, GGML_TYPE_F16, GGML_TYPE_F16);
|
||||
@@ -1,5 +0,0 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-vec-f16.cuh"
|
||||
|
||||
DECL_FATTN_VEC_F16_CASE(64, GGML_TYPE_F16, GGML_TYPE_Q4_0);
|
||||
@@ -1,5 +0,0 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-vec-f16.cuh"
|
||||
|
||||
DECL_FATTN_VEC_F16_CASE(64, GGML_TYPE_F16, GGML_TYPE_Q4_1);
|
||||
@@ -1,5 +0,0 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-vec-f16.cuh"
|
||||
|
||||
DECL_FATTN_VEC_F16_CASE(64, GGML_TYPE_F16, GGML_TYPE_Q5_0);
|
||||
@@ -1,5 +0,0 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-vec-f16.cuh"
|
||||
|
||||
DECL_FATTN_VEC_F16_CASE(64, GGML_TYPE_F16, GGML_TYPE_Q5_1);
|
||||
@@ -1,5 +0,0 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-vec-f16.cuh"
|
||||
|
||||
DECL_FATTN_VEC_F16_CASE(64, GGML_TYPE_F16, GGML_TYPE_Q8_0);
|
||||
@@ -1,5 +0,0 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-vec-f32.cuh"
|
||||
|
||||
DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_F16);
|
||||
@@ -1,5 +0,0 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-vec-f32.cuh"
|
||||
|
||||
DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q4_0);
|
||||
@@ -1,5 +0,0 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-vec-f32.cuh"
|
||||
|
||||
DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q4_1);
|
||||
@@ -1,5 +0,0 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-vec-f32.cuh"
|
||||
|
||||
DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q5_0);
|
||||
@@ -1,5 +0,0 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-vec-f32.cuh"
|
||||
|
||||
DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q5_1);
|
||||
@@ -1,5 +0,0 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-vec-f32.cuh"
|
||||
|
||||
DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q8_0);
|
||||
@@ -1,5 +0,0 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-vec-f32.cuh"
|
||||
|
||||
DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_F16);
|
||||
@@ -1,5 +0,0 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-vec-f32.cuh"
|
||||
|
||||
DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0);
|
||||
@@ -1,5 +0,0 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-vec-f32.cuh"
|
||||
|
||||
DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_1);
|
||||
@@ -1,5 +0,0 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-vec-f32.cuh"
|
||||
|
||||
DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_0);
|
||||
@@ -1,5 +0,0 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-vec-f32.cuh"
|
||||
|
||||
DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_1);
|
||||
@@ -1,5 +0,0 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-vec-f32.cuh"
|
||||
|
||||
DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q8_0);
|
||||
@@ -1,5 +0,0 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-vec-f32.cuh"
|
||||
|
||||
DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_F16);
|
||||
@@ -1,5 +0,0 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-vec-f32.cuh"
|
||||
|
||||
DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q4_0);
|
||||
@@ -1,5 +0,0 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-vec-f32.cuh"
|
||||
|
||||
DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q4_1);
|
||||
@@ -1,5 +0,0 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-vec-f32.cuh"
|
||||
|
||||
DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_0);
|
||||
@@ -1,5 +0,0 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-vec-f32.cuh"
|
||||
|
||||
DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_1);
|
||||
@@ -1,5 +0,0 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-vec-f32.cuh"
|
||||
|
||||
DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q8_0);
|
||||
@@ -1,5 +0,0 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-vec-f32.cuh"
|
||||
|
||||
DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_F16);
|
||||
@@ -1,5 +0,0 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-vec-f32.cuh"
|
||||
|
||||
DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q4_0);
|
||||
@@ -1,5 +0,0 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-vec-f32.cuh"
|
||||
|
||||
DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q4_1);
|
||||
@@ -1,5 +0,0 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-vec-f32.cuh"
|
||||
|
||||
DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_0);
|
||||
@@ -1,5 +0,0 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-vec-f32.cuh"
|
||||
|
||||
DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_1);
|
||||
@@ -1,5 +0,0 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-vec-f32.cuh"
|
||||
|
||||
DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q8_0);
|
||||
@@ -1,5 +0,0 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-vec-f32.cuh"
|
||||
|
||||
DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_F16);
|
||||
@@ -1,5 +0,0 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-vec-f32.cuh"
|
||||
|
||||
DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q4_0);
|
||||
@@ -1,5 +0,0 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-vec-f32.cuh"
|
||||
|
||||
DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q4_1);
|
||||
@@ -1,5 +0,0 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-vec-f32.cuh"
|
||||
|
||||
DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_0);
|
||||
@@ -1,5 +0,0 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-vec-f32.cuh"
|
||||
|
||||
DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_1);
|
||||
@@ -1,5 +0,0 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-vec-f32.cuh"
|
||||
|
||||
DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q8_0);
|
||||
@@ -1,5 +0,0 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-vec-f32.cuh"
|
||||
|
||||
DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_F16);
|
||||
@@ -1,5 +0,0 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-vec-f32.cuh"
|
||||
|
||||
DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0);
|
||||
@@ -1,5 +0,0 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-vec-f32.cuh"
|
||||
|
||||
DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q4_1);
|
||||
@@ -1,5 +0,0 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-vec-f32.cuh"
|
||||
|
||||
DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_0);
|
||||
@@ -1,5 +0,0 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-vec-f32.cuh"
|
||||
|
||||
DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_1);
|
||||
@@ -1,5 +0,0 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-vec-f32.cuh"
|
||||
|
||||
DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0);
|
||||
@@ -1,5 +0,0 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-vec-f32.cuh"
|
||||
|
||||
DECL_FATTN_VEC_F32_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16);
|
||||
@@ -1,5 +0,0 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-vec-f32.cuh"
|
||||
|
||||
DECL_FATTN_VEC_F32_CASE(64, GGML_TYPE_F16, GGML_TYPE_F16);
|
||||
@@ -1,5 +0,0 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-vec-f32.cuh"
|
||||
|
||||
DECL_FATTN_VEC_F32_CASE(64, GGML_TYPE_F16, GGML_TYPE_Q4_0);
|
||||
@@ -1,5 +0,0 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-vec-f32.cuh"
|
||||
|
||||
DECL_FATTN_VEC_F32_CASE(64, GGML_TYPE_F16, GGML_TYPE_Q4_1);
|
||||
@@ -1,5 +0,0 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-vec-f32.cuh"
|
||||
|
||||
DECL_FATTN_VEC_F32_CASE(64, GGML_TYPE_F16, GGML_TYPE_Q5_0);
|
||||
@@ -1,5 +0,0 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-vec-f32.cuh"
|
||||
|
||||
DECL_FATTN_VEC_F32_CASE(64, GGML_TYPE_F16, GGML_TYPE_Q5_1);
|
||||
@@ -1,5 +0,0 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-vec-f32.cuh"
|
||||
|
||||
DECL_FATTN_VEC_F32_CASE(64, GGML_TYPE_F16, GGML_TYPE_Q8_0);
|
||||
@@ -0,0 +1,7 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-vec.cuh"
|
||||
|
||||
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_F16, GGML_TYPE_F16);
|
||||
DECL_FATTN_VEC_CASE(128, GGML_TYPE_F16, GGML_TYPE_F16);
|
||||
DECL_FATTN_VEC_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16);
|
||||
@@ -0,0 +1,7 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-vec.cuh"
|
||||
|
||||
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q4_0);
|
||||
DECL_FATTN_VEC_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q4_0);
|
||||
DECL_FATTN_VEC_CASE(256, GGML_TYPE_F16, GGML_TYPE_Q4_0);
|
||||
@@ -0,0 +1,7 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-vec.cuh"
|
||||
|
||||
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q4_1);
|
||||
DECL_FATTN_VEC_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q4_1);
|
||||
DECL_FATTN_VEC_CASE(256, GGML_TYPE_F16, GGML_TYPE_Q4_1);
|
||||
@@ -0,0 +1,7 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-vec.cuh"
|
||||
|
||||
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q5_0);
|
||||
DECL_FATTN_VEC_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q5_0);
|
||||
DECL_FATTN_VEC_CASE(256, GGML_TYPE_F16, GGML_TYPE_Q5_0);
|
||||
@@ -0,0 +1,7 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-vec.cuh"
|
||||
|
||||
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q5_1);
|
||||
DECL_FATTN_VEC_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q5_1);
|
||||
DECL_FATTN_VEC_CASE(256, GGML_TYPE_F16, GGML_TYPE_Q5_1);
|
||||
@@ -0,0 +1,7 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-vec.cuh"
|
||||
|
||||
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q8_0);
|
||||
DECL_FATTN_VEC_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q8_0);
|
||||
DECL_FATTN_VEC_CASE(256, GGML_TYPE_F16, GGML_TYPE_Q8_0);
|
||||
@@ -0,0 +1,7 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-vec.cuh"
|
||||
|
||||
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_0, GGML_TYPE_F16);
|
||||
DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_F16);
|
||||
DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_0, GGML_TYPE_F16);
|
||||
@@ -0,0 +1,7 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-vec.cuh"
|
||||
|
||||
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0);
|
||||
DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0);
|
||||
DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0);
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user