Mirror of https://github.com/ggml-org/llama.cpp.git
	CUDA: use tensor cores for MMQ (#7676)
* CUDA: int8 tensor cores for MMQ (legacy quants)
* fix out-of-bounds writes
* __builtin_assume -> GGML_CUDA_ASSUME
* fix writeback returning too early
		| @@ -139,6 +139,7 @@ | |||||||
| #define CC_PASCAL     600 | #define CC_PASCAL     600 | ||||||
| #define MIN_CC_DP4A   610 // minimum compute capability for __dp4a, an intrinsic for byte-wise dot products | #define MIN_CC_DP4A   610 // minimum compute capability for __dp4a, an intrinsic for byte-wise dot products | ||||||
| #define CC_VOLTA      700 | #define CC_VOLTA      700 | ||||||
|  | #define CC_TURING     750 | ||||||
| #define CC_AMPERE     800 | #define CC_AMPERE     800 | ||||||
| #define CC_OFFSET_AMD 1000000 | #define CC_OFFSET_AMD 1000000 | ||||||
| #define CC_RDNA1      (CC_OFFSET_AMD + 1010) | #define CC_RDNA1      (CC_OFFSET_AMD + 1010) | ||||||
| @@ -326,9 +327,17 @@ static __device__ __forceinline__ half2 __shfl_xor(half2 var, int laneMask, int | |||||||
| #endif // defined(__HIP_PLATFORM_AMD__) && HIP_VERSION < 50600000 | #endif // defined(__HIP_PLATFORM_AMD__) && HIP_VERSION < 50600000 | ||||||
| #endif // defined(GGML_USE_HIPBLAS) | #endif // defined(GGML_USE_HIPBLAS) | ||||||
|  |  | ||||||
| #define FP16_AVAILABLE (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ >= CC_PASCAL | #if (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ >= CC_PASCAL | ||||||
|  | #define FP16_AVAILABLE | ||||||
|  | #endif // (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ >= CC_PASCAL | ||||||
|  |  | ||||||
| #define FP16_MMA_AVAILABLE !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_VOLTA | #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_VOLTA | ||||||
|  | #define FP16_MMA_AVAILABLE | ||||||
|  | #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_VOLTA | ||||||
|  |  | ||||||
|  | #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_TURING | ||||||
|  | #define INT8_MMA_AVAILABLE | ||||||
|  | #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_TURING | ||||||
|  |  | ||||||
| static bool fast_fp16_available(const int cc) { | static bool fast_fp16_available(const int cc) { | ||||||
|     return cc >= CC_PASCAL && cc != 610; |     return cc >= CC_PASCAL && cc != 610; | ||||||
| @@ -338,6 +347,10 @@ static bool fp16_mma_available(const int cc) { | |||||||
|     return cc < CC_OFFSET_AMD && cc >= CC_VOLTA; |     return cc < CC_OFFSET_AMD && cc >= CC_VOLTA; | ||||||
| } | } | ||||||
|  |  | ||||||
|  | static bool int8_mma_available(const int cc) { | ||||||
|  |     return cc < CC_OFFSET_AMD && cc >= CC_TURING; | ||||||
|  | } | ||||||
|  |  | ||||||
| [[noreturn]] | [[noreturn]] | ||||||
| static __device__ void no_device_code( | static __device__ void no_device_code( | ||||||
|     const char * file_name, const int line, const char * function_name, const int arch, const char * arch_list) { |     const char * file_name, const int line, const char * function_name, const int arch, const char * arch_list) { | ||||||
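The two checks added above serve different purposes: INT8_MMA_AVAILABLE is a compile-time gate evaluated per target architecture via __CUDA_ARCH__, so device code is compiled with or without the tensor-core path, while int8_mma_available(cc) is the host-side runtime test against the compute capability of the GPU that is actually present (the cc < CC_OFFSET_AMD condition excludes HIP devices, whose cc values carry the AMD offset). As a rough sketch of where such a cc value can come from (illustrative only; the helper name is made up and llama.cpp keeps its own per-device bookkeeping), the 100*major + 10*minor encoding matches constants like CC_TURING = 750:

    // Hypothetical helper, for illustration: derive cc in the same
    // 100*major + 10*minor encoding that CC_PASCAL/CC_VOLTA/CC_TURING use.
    #include <cuda_runtime.h>

    static int compute_capability_of(int device) {
        cudaDeviceProp prop;
        cudaGetDeviceProperties(&prop, device);
        return 100*prop.major + 10*prop.minor;   // e.g. 750 on Turing (sm_75)
    }

    // if (int8_mma_available(compute_capability_of(dev))) { /* pick the MMA kernels */ }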
| @@ -379,7 +392,7 @@ static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) { | |||||||
| } | } | ||||||
|  |  | ||||||
| static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) { | static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) { | ||||||
| #if FP16_AVAILABLE | #ifdef FP16_AVAILABLE | ||||||
|  |  | ||||||
| #if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) | #if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) | ||||||
| #pragma unroll | #pragma unroll | ||||||
| @@ -412,7 +425,7 @@ static __device__ __forceinline__ float warp_reduce_max(float x) { | |||||||
| } | } | ||||||
|  |  | ||||||
| static __device__ __forceinline__ half ggml_cuda_hmax(const half a, const half b) { | static __device__ __forceinline__ half ggml_cuda_hmax(const half a, const half b) { | ||||||
| #if FP16_AVAILABLE | #ifdef FP16_AVAILABLE | ||||||
|  |  | ||||||
| #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && CUDART_VERSION < CUDART_HMAX | #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && CUDART_VERSION < CUDART_HMAX | ||||||
|     return __float2half(fmaxf(__half2float(a), __half2float(b))); |     return __float2half(fmaxf(__half2float(a), __half2float(b))); | ||||||
|   | |||||||
| @@ -74,7 +74,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_0( | |||||||
|  |  | ||||||
|         const int sumi = __dp4a(v, u, 0); |         const int sumi = __dp4a(v, u, 0); | ||||||
|  |  | ||||||
| #if FP16_AVAILABLE | #ifdef FP16_AVAILABLE | ||||||
|         if (std::is_same<T, half>::value) { |         if (std::is_same<T, half>::value) { | ||||||
|             const half2  * Q_ds = (const half2  *) Q_ds_v; |             const half2  * Q_ds = (const half2  *) Q_ds_v; | ||||||
|  |  | ||||||
| @@ -122,7 +122,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_1( | |||||||
|  |  | ||||||
|         const int sumi = __dp4a(v, u, 0); |         const int sumi = __dp4a(v, u, 0); | ||||||
|  |  | ||||||
| #if FP16_AVAILABLE | #ifdef FP16_AVAILABLE | ||||||
|         if (std::is_same<T, half>::value) { |         if (std::is_same<T, half>::value) { | ||||||
|             const half2  * Q_ds = (const half2  *) Q_ds_v; |             const half2  * Q_ds = (const half2  *) Q_ds_v; | ||||||
|  |  | ||||||
| @@ -181,7 +181,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_0( | |||||||
|  |  | ||||||
|         const int sumi = __dp4a(v, u, 0); |         const int sumi = __dp4a(v, u, 0); | ||||||
|  |  | ||||||
| #if FP16_AVAILABLE | #ifdef FP16_AVAILABLE | ||||||
|         if (std::is_same<T, half>::value) { |         if (std::is_same<T, half>::value) { | ||||||
|             const half2  * Q_ds = (const half2  *) Q_ds_v; |             const half2  * Q_ds = (const half2  *) Q_ds_v; | ||||||
|  |  | ||||||
| @@ -236,7 +236,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_1( | |||||||
|  |  | ||||||
|         const int sumi = __dp4a(v, u, 0); |         const int sumi = __dp4a(v, u, 0); | ||||||
|  |  | ||||||
| #if FP16_AVAILABLE | #ifdef FP16_AVAILABLE | ||||||
|         if (std::is_same<T, half>::value) { |         if (std::is_same<T, half>::value) { | ||||||
|             const half2  * Q_ds = (const half2  *) Q_ds_v; |             const half2  * Q_ds = (const half2  *) Q_ds_v; | ||||||
|  |  | ||||||
| @@ -314,7 +314,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_f16( | |||||||
|     GGML_UNUSED(Q_q8); |     GGML_UNUSED(Q_q8); | ||||||
|     GGML_UNUSED(Q_ds_v); |     GGML_UNUSED(Q_ds_v); | ||||||
|  |  | ||||||
| #if FP16_AVAILABLE | #ifdef FP16_AVAILABLE | ||||||
|     if (std::is_same<T, half>::value) { |     if (std::is_same<T, half>::value) { | ||||||
|         const half2 * Q_h2 = (const half2 *) Q_v; |         const half2 * Q_h2 = (const half2 *) Q_v; | ||||||
|  |  | ||||||
| @@ -407,7 +407,7 @@ static __device__ __forceinline__ T dequantize_1_q4_0(const void * __restrict__ | |||||||
|     const int q0 = x[ib].qs[iqs]; |     const int q0 = x[ib].qs[iqs]; | ||||||
|     const int q  = ((q0 >> (4*shift)) & 0x0F) - 8; |     const int q  = ((q0 >> (4*shift)) & 0x0F) - 8; | ||||||
|  |  | ||||||
| #if FP16_AVAILABLE | #ifdef FP16_AVAILABLE | ||||||
|     if (std::is_same<T, half>::value) { |     if (std::is_same<T, half>::value) { | ||||||
|         return ((half) d)*((half) q); |         return ((half) d)*((half) q); | ||||||
|     } |     } | ||||||
| @@ -428,7 +428,7 @@ static __device__ __forceinline__ T dequantize_1_q4_1(const void * __restrict__ | |||||||
|     const int   q0 = x[ib].qs[iqs]; |     const int   q0 = x[ib].qs[iqs]; | ||||||
|     const int   q  = ((q0 >> (4*shift)) & 0x0F); |     const int   q  = ((q0 >> (4*shift)) & 0x0F); | ||||||
|  |  | ||||||
| #if FP16_AVAILABLE | #ifdef FP16_AVAILABLE | ||||||
|     if (std::is_same<T, half>::value) { |     if (std::is_same<T, half>::value) { | ||||||
|         return __low2half(dm)*((half) q) + __high2half(dm); |         return __low2half(dm)*((half) q) + __high2half(dm); | ||||||
|     } |     } | ||||||
| @@ -453,7 +453,7 @@ static __device__ __forceinline__ T dequantize_1_q5_0(const void * __restrict__ | |||||||
|     const int qh  = ((qh0 >> idq) << 4) & 0x10; |     const int qh  = ((qh0 >> idq) << 4) & 0x10; | ||||||
|     const int q   = (ql | qh) - 16; |     const int q   = (ql | qh) - 16; | ||||||
|  |  | ||||||
| #if FP16_AVAILABLE | #ifdef FP16_AVAILABLE | ||||||
|     if (std::is_same<T, half>::value) { |     if (std::is_same<T, half>::value) { | ||||||
|         return ((half) d)*((half) q); |         return ((half) d)*((half) q); | ||||||
|     } |     } | ||||||
| @@ -478,7 +478,7 @@ static __device__ __forceinline__ T dequantize_1_q5_1(const void * __restrict__ | |||||||
|     const int   qh  = ((qh0 >> idq) << 4) & 0x10; |     const int   qh  = ((qh0 >> idq) << 4) & 0x10; | ||||||
|     const int   q   = (ql | qh); |     const int   q   = (ql | qh); | ||||||
|  |  | ||||||
| #if FP16_AVAILABLE | #ifdef FP16_AVAILABLE | ||||||
|     if (std::is_same<T, half>::value) { |     if (std::is_same<T, half>::value) { | ||||||
|         return __low2half(dm)*((half) q) + __high2half(dm); |         return __low2half(dm)*((half) q) + __high2half(dm); | ||||||
|     } |     } | ||||||
| @@ -497,7 +497,7 @@ static __device__ __forceinline__ T dequantize_1_q8_0(const void * __restrict__ | |||||||
|     const T   d = x[ib].d; |     const T   d = x[ib].d; | ||||||
|     const int q = x[ib].qs[iqs]; |     const int q = x[ib].qs[iqs]; | ||||||
|  |  | ||||||
| #if FP16_AVAILABLE | #ifdef FP16_AVAILABLE | ||||||
|     if (std::is_same<T, half>::value) { |     if (std::is_same<T, half>::value) { | ||||||
|         return ((half) d)*((half) q); |         return ((half) d)*((half) q); | ||||||
|     } |     } | ||||||
|   | |||||||
| @@ -43,7 +43,7 @@ static __global__ void flash_attn_tile_ext_f16( | |||||||
|         const int ne1, |         const int ne1, | ||||||
|         const int ne2, |         const int ne2, | ||||||
|         const int ne3) { |         const int ne3) { | ||||||
| #if FP16_AVAILABLE | #ifdef FP16_AVAILABLE | ||||||
|     //In this kernel Q, K, V are matrices while i, j, k are matrix indices. |     //In this kernel Q, K, V are matrices while i, j, k are matrix indices. | ||||||
|  |  | ||||||
|     const int ic0 = (blockIdx.x / parallel_blocks) * ncols; // Index of the Q/QKV column to work on. |     const int ic0 = (blockIdx.x / parallel_blocks) * ncols; // Index of the Q/QKV column to work on. | ||||||
|   | |||||||
| @@ -40,7 +40,7 @@ static __global__ void flash_attn_vec_ext_f16( | |||||||
|         const int ne1, |         const int ne1, | ||||||
|         const int ne2, |         const int ne2, | ||||||
|         const int ne3) { |         const int ne3) { | ||||||
| #if FP16_AVAILABLE | #ifdef FP16_AVAILABLE | ||||||
|     //In this kernel Q, K, V are matrices while i, j, k are matrix indices. |     //In this kernel Q, K, V are matrices while i, j, k are matrix indices. | ||||||
|  |  | ||||||
|     constexpr vec_dot_KQ_f16_t vec_dot_KQ = get_vec_dot_KQ_f16<D>(type_K); |     constexpr vec_dot_KQ_f16_t vec_dot_KQ = get_vec_dot_KQ_f16<D>(type_K); | ||||||
|   | |||||||
| @@ -1,9 +1,9 @@ | |||||||
| #include "common.cuh" | #include "common.cuh" | ||||||
| #include "fattn-common.cuh" | #include "fattn-common.cuh" | ||||||
|  |  | ||||||
| #if FP16_MMA_AVAILABLE | #ifdef FP16_MMA_AVAILABLE | ||||||
| #include <mma.h> | #include <mma.h> | ||||||
| #endif | #endif // FP16_MMA_AVAILABLE | ||||||
|  |  | ||||||
| // D == head size, VKQ_stride == num VKQ rows calculated in parallel: | // D == head size, VKQ_stride == num VKQ rows calculated in parallel: | ||||||
| template<int D, int ncols, int nwarps, int VKQ_stride, int parallel_blocks, typename KQ_acc_t> | template<int D, int ncols, int nwarps, int VKQ_stride, int parallel_blocks, typename KQ_acc_t> | ||||||
| @@ -45,7 +45,7 @@ static __global__ void flash_attn_ext_f16( | |||||||
|         const int ne1, |         const int ne1, | ||||||
|         const int ne2, |         const int ne2, | ||||||
|         const int ne3) { |         const int ne3) { | ||||||
| #if FP16_MMA_AVAILABLE | #ifdef FP16_MMA_AVAILABLE | ||||||
|     //In this kernel Q, K, V are matrices while i, j, k are matrix indices. |     //In this kernel Q, K, V are matrices while i, j, k are matrix indices. | ||||||
|  |  | ||||||
|     const int ic0 = ncols*(blockIdx.x / parallel_blocks); // Index of the first Q/QKV column to work on. |     const int ic0 = ncols*(blockIdx.x / parallel_blocks); // Index of the first Q/QKV column to work on. | ||||||
|   | |||||||
							
								
								
									
ggml-cuda/mma.cuh: new file, 95 additions
							| @@ -0,0 +1,95 @@ | |||||||
|  | #include "common.cuh" | ||||||
|  |  | ||||||
|  | struct mma_int_A_I16K8 { | ||||||
|  |     static constexpr int I  = 16; | ||||||
|  |     static constexpr int K  = 8; | ||||||
|  |     static constexpr int ne = 4; | ||||||
|  |  | ||||||
|  |     int x[ne] = {0}; | ||||||
|  |  | ||||||
|  |     static __device__ __forceinline__ int get_i(const int l) { | ||||||
|  |         const int ret = (l%2) * (I/2) + threadIdx.x / (K/2); | ||||||
|  |         GGML_CUDA_ASSUME(ret >= 0); | ||||||
|  |         GGML_CUDA_ASSUME(ret <  I); | ||||||
|  |         return ret; | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     static __device__ __forceinline__ int get_k(const int l) { | ||||||
|  |         const int ret = (l/2) * (K/2) + threadIdx.x % (K/2); | ||||||
|  |         GGML_CUDA_ASSUME(ret >= 0); | ||||||
|  |         GGML_CUDA_ASSUME(ret <  K); | ||||||
|  |         return ret; | ||||||
|  |     } | ||||||
|  | }; | ||||||
|  |  | ||||||
|  | struct mma_int_B_J8K8 { | ||||||
|  |     static constexpr int J  = 8; | ||||||
|  |     static constexpr int K  = 8; | ||||||
|  |     static constexpr int ne = 2; | ||||||
|  |  | ||||||
|  |     int x[ne] = {0}; | ||||||
|  |  | ||||||
|  |     static __device__ __forceinline__ int get_j(const int /* l */) { | ||||||
|  |         const int ret = threadIdx.x / (K/2); | ||||||
|  |         GGML_CUDA_ASSUME(ret >= 0); | ||||||
|  |         GGML_CUDA_ASSUME(ret <  J); | ||||||
|  |         return ret; | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     static __device__ __forceinline__ int get_k(const int l) { | ||||||
|  |         const int ret = l * (K/2) + threadIdx.x % (K/2); | ||||||
|  |         GGML_CUDA_ASSUME(ret >= 0); | ||||||
|  |         GGML_CUDA_ASSUME(ret <  K); | ||||||
|  |         return ret; | ||||||
|  |     } | ||||||
|  | }; | ||||||
|  |  | ||||||
|  | struct mma_int_C_I16J8 { | ||||||
|  |     static constexpr int I  = 16; | ||||||
|  |     static constexpr int J  = 8; | ||||||
|  |     static constexpr int ne = 4; | ||||||
|  |  | ||||||
|  |     int x[ne] = {0}; | ||||||
|  |  | ||||||
|  |     static __device__ __forceinline__ int get_i(const int l) { | ||||||
|  |         const int ret = (l/2) * (I/2) + threadIdx.x / (J/2); | ||||||
|  |         GGML_CUDA_ASSUME(ret >= 0); | ||||||
|  |         GGML_CUDA_ASSUME(ret <  I); | ||||||
|  |         return ret; | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     static __device__ __forceinline__ int get_j(const int l) { | ||||||
|  |         const int ret = 2 * (threadIdx.x % (J/2)) + l%2; | ||||||
|  |         GGML_CUDA_ASSUME(ret >= 0); | ||||||
|  |         GGML_CUDA_ASSUME(ret <  J); | ||||||
|  |         return ret; | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     __device__ __forceinline__ void mma_K8(const mma_int_A_I16K8 & mma_A, const mma_int_B_J8K8 & mma_B) { | ||||||
|  | #ifdef INT8_MMA_AVAILABLE | ||||||
|  | #if __CUDA_ARCH__ >= CC_AMPERE | ||||||
|  |         asm("mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};" | ||||||
|  |             : "+r"(x[0]), "+r"(x[1]), "+r"(x[2]), "+r"(x[3]) | ||||||
|  |             : "r"(mma_A.x[0]), "r"(mma_A.x[1]), "r"(mma_A.x[2]), "r"(mma_A.x[3]), "r"(mma_B.x[0]), "r"(mma_B.x[1])); | ||||||
|  | #else | ||||||
|  |         // On Turing m16n8k32 mma is not available, use 4x m8n8k16 mma instead: | ||||||
|  |         asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};" | ||||||
|  |             : "+r"(x[0]), "+r"(x[1]) | ||||||
|  |             : "r"(mma_A.x[0]), "r"(mma_B.x[0])); | ||||||
|  |         asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};" | ||||||
|  |             : "+r"(x[2]), "+r"(x[3]) | ||||||
|  |             : "r"(mma_A.x[1]), "r"(mma_B.x[0])); | ||||||
|  |         asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};" | ||||||
|  |             : "+r"(x[0]), "+r"(x[1]) | ||||||
|  |             : "r"(mma_A.x[2]), "r"(mma_B.x[1])); | ||||||
|  |         asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};" | ||||||
|  |             : "+r"(x[2]), "+r"(x[3]) | ||||||
|  |             : "r"(mma_A.x[3]), "r"(mma_B.x[1])); | ||||||
|  | #endif // __CUDA_ARCH__ >= CC_AMPERE | ||||||
|  | #else | ||||||
|  |         GGML_UNUSED(mma_A); | ||||||
|  |         GGML_UNUSED(mma_B); | ||||||
|  |         NO_DEVICE_CODE; | ||||||
|  | #endif // INT8_MMA_AVAILABLE | ||||||
|  |     } | ||||||
|  | }; | ||||||
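The three fragment structs above encode, without ldmatrix, which matrix element each lane of the warp owns: every lane holds ne 32-bit registers, and get_i/get_k (or get_j) map a (lane, register) pair to coordinates, where the K coordinate counts 32-bit ints, i.e. four packed int8 values each, so K = 8 ints corresponds to the k = 32 depth of mma.m16n8k32. On Turing the 16x8x32 product is decomposed into four m8n8k16 operations covering (rows 0..7 / 8..15) of A times (k-halves) of B, each accumulating into the matching half of C, which is why the fallback pairs A.x[0]/A.x[1] with B.x[0] and A.x[2]/A.x[3] with B.x[1]. A small, self-contained sketch (not part of the repo) that re-applies the same index formulas makes the A layout concrete:

    // Illustrative layout dump, assuming the same index math as mma_int_A_I16K8.
    // Each lane of one warp owns 4 registers; k counts 32-bit ints (4 int8 each).
    #include <cstdio>
    #include <cuda_runtime.h>

    __global__ void dump_A_layout() {
        const int lane = threadIdx.x;                 // 0..31
        for (int l = 0; l < 4; ++l) {                 // ne = 4
            const int i = (l % 2) * 8 + lane / 4;     // mma_int_A_I16K8::get_i(l)
            const int k = (l / 2) * 4 + lane % 4;     // mma_int_A_I16K8::get_k(l)
            printf("lane %2d  x[%d] -> A[i=%2d][k=%d]\n", lane, l, i, k);
        }
    }

    int main() {
        dump_A_layout<<<1, 32>>>();
        cudaDeviceSynchronize();
        return 0;
    }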
| @@ -2,6 +2,7 @@ | |||||||
|  |  | ||||||
| #include "common.cuh" | #include "common.cuh" | ||||||
| #include "vecdotq.cuh" | #include "vecdotq.cuh" | ||||||
|  | #include "mma.cuh" | ||||||
|  |  | ||||||
| #include <climits> | #include <climits> | ||||||
| #include <cstdint> | #include <cstdint> | ||||||
| @@ -14,6 +15,7 @@ typedef void (*load_tiles_mmq_t)( | |||||||
| typedef void (*vec_dot_mmq_t)( | typedef void (*vec_dot_mmq_t)( | ||||||
|     const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, |     const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, | ||||||
|     const int * __restrict__ y, float * __restrict__ sum, const int & k0); |     const int * __restrict__ y, float * __restrict__ sum, const int & k0); | ||||||
|  | typedef void (*mmq_write_back_t)(const float * __restrict__ sum, float * __restrict__ dst, const int & ne0, const int & ne1); | ||||||
|  |  | ||||||
| struct block_q8_1_mmq { | struct block_q8_1_mmq { | ||||||
|     half2  ds[4]; |     half2  ds[4]; | ||||||
| @@ -141,13 +143,13 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin | |||||||
| } | } | ||||||
|  |  | ||||||
| template <int mmq_x, int mmq_y, int nwarps> | template <int mmq_x, int mmq_y, int nwarps> | ||||||
| static __device__ __forceinline__ void vec_dot_q4_0_q8_1_mul_mat( | static __device__ __forceinline__ void vec_dot_q4_0_q8_1_dp4a( | ||||||
|     const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, |     const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, | ||||||
|     const int * __restrict__ y, float * __restrict__ sum, const int & k0) { |     const int * __restrict__ y, float * __restrict__ sum, const int & k0) { | ||||||
|  |  | ||||||
|     GGML_UNUSED(x_qh); GGML_UNUSED(x_sc); |     GGML_UNUSED(x_qh); GGML_UNUSED(x_sc); | ||||||
|  |  | ||||||
|     const float * x_dmf = (const float *) x_dm; |     const float * x_df = (const float *) x_dm; | ||||||
|     const int   * y_qs = (const int   *) y + 4; |     const int   * y_qs = (const int   *) y + 4; | ||||||
|     const half2 * y_ds = (const half2 *) y; |     const half2 * y_ds = (const half2 *) y; | ||||||
|  |  | ||||||
| @@ -170,12 +172,76 @@ static __device__ __forceinline__ void vec_dot_q4_0_q8_1_mul_mat( | |||||||
|             } |             } | ||||||
|  |  | ||||||
|             sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q4_0_q8_1_impl<VDR_Q4_0_Q8_1_MMQ> |             sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q4_0_q8_1_impl<VDR_Q4_0_Q8_1_MMQ> | ||||||
|                 (&x_ql[i*(WARP_SIZE + 1) + k0], u, x_dmf[i*(WARP_SIZE/QI4_0) + i/QI4_0 + k0/QI4_0], |                 (&x_ql[i*(WARP_SIZE + 1) + k0], u, x_df[i*(WARP_SIZE/QI4_0) + i/QI4_0 + k0/QI4_0], | ||||||
|                 y_ds[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)]); |                 y_ds[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)]); | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
| } | } | ||||||
|  |  | ||||||
|  | template <int mmq_x, int mmq_y, int nwarps> | ||||||
|  | static __device__ __forceinline__ void vec_dot_q4_0_q8_1_mma( | ||||||
|  |     const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, | ||||||
|  |     const int * __restrict__ y, float * __restrict__ sum, const int & k0) { | ||||||
|  |  | ||||||
|  |     GGML_UNUSED(x_qh); GGML_UNUSED(x_sc); | ||||||
|  |  | ||||||
|  |     typedef mma_int_A_I16K8 mma_A; | ||||||
|  |     typedef mma_int_B_J8K8  mma_B; | ||||||
|  |     typedef mma_int_C_I16J8 mma_C; | ||||||
|  |  | ||||||
|  |     const float * x_df = (const float *) x_dm; | ||||||
|  |     const int   * y_qs = (const int   *) y + 4; | ||||||
|  |     const half2 * y_ds = (const half2 *) y; | ||||||
|  |  | ||||||
|  |     mma_A A; | ||||||
|  |     float dA[mma_C::ne/2]; | ||||||
|  |  | ||||||
|  |     const int i0 = threadIdx.y*mma_A::I; | ||||||
|  |     static_assert(nwarps*mma_A::I == mmq_y, "nwarps*mma_A::I != mmq_y"); | ||||||
|  |  | ||||||
|  | #pragma unroll | ||||||
|  |     for (int l = 0; l < mma_A::ne; ++l) { | ||||||
|  |         const int i     = i0 + mma_A::get_i(l); | ||||||
|  |         const int k     = k0 + mma_A::get_k(l) % QI4_0; | ||||||
|  |         const int shift =   4*(mma_A::get_k(l) / QI4_0); | ||||||
|  |  | ||||||
|  |         A.x[l] = __vsubss4((x_ql[i*(WARP_SIZE + 1) + k] >> shift) & 0x0F0F0F0F, 0x08080808); | ||||||
|  |     } | ||||||
|  | #pragma unroll | ||||||
|  |     for (int l = 0; l < mma_C::ne/2; ++l) { | ||||||
|  |         const int i = i0 + mma_C::get_i(2*l); | ||||||
|  |  | ||||||
|  |         dA[l] = x_df[i*(WARP_SIZE/QI4_0) + i/QI4_0 + k0/QI4_0]; | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     for (int j0 = 0; j0 < mmq_x; j0 += mma_int_B_J8K8::J) { | ||||||
|  |         mma_C C; | ||||||
|  |         mma_B B; | ||||||
|  |         half2 dsB[mma_C::ne/2]; | ||||||
|  |  | ||||||
|  | #pragma unroll | ||||||
|  |         for (int l = 0; l < mma_B::ne; ++l) { | ||||||
|  |             const int j =    j0 + mma_B::get_j(l); | ||||||
|  |             const int k = (2*k0 + mma_B::get_k(l)) % WARP_SIZE; | ||||||
|  |  | ||||||
|  |             B.x[l] = y_qs[j*MMQ_TILE_Y_K + k]; | ||||||
|  |         } | ||||||
|  | #pragma unroll | ||||||
|  |         for (int l = 0; l < mma_C::ne/2; ++l) { | ||||||
|  |             const int j = j0 + mma_C::get_j(l); | ||||||
|  |  | ||||||
|  |             dsB[l] = y_ds[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)]; | ||||||
|  |         } | ||||||
|  |  | ||||||
|  |         C.mma_K8(A, B); | ||||||
|  |  | ||||||
|  | #pragma unroll | ||||||
|  |         for (int l = 0; l < mma_C::ne; ++l) { | ||||||
|  |             sum[(j0/B.J)*C.ne + l] += dA[l/2]*__low2float(dsB[l%2])*C.x[l]; | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  | } | ||||||
|  |  | ||||||
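In the q4_0 fragment load above, each packed tile word stores two 4-bit quants per byte: (x_ql[...] >> shift) & 0x0F0F0F0F selects either the low or the high nibble of every byte (shift is 0 or 4 depending on which k-half the register covers), and __vsubss4(v, 0x08080808) subtracts the q4_0 zero-point of 8 from all four bytes at once, producing the signed int8 values the .s8 MMA consumes. A minimal sketch of just this unpack step (illustrative; the function name is made up):

    // Illustrative only: unpack one 32-bit q4_0 tile word into four signed
    // int8 values, exactly as the A-fragment load above does.
    __device__ __forceinline__ int unpack_q4_0(const int packed, const int shift /* 0 or 4 */) {
        const int nibbles = (packed >> shift) & 0x0F0F0F0F;  // one nibble per byte lane, values 0..15
        return __vsubss4(nibbles, 0x08080808);               // per-byte subtract of the offset 8 -> -8..7
    }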
| template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q4_1( | template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q4_1( | ||||||
|     const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh, |     const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh, | ||||||
|     int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) { |     int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) { | ||||||
| @@ -215,7 +281,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin | |||||||
| } | } | ||||||
|  |  | ||||||
| template <int mmq_x, int mmq_y, int nwarps> | template <int mmq_x, int mmq_y, int nwarps> | ||||||
| static __device__ __forceinline__ void vec_dot_q4_1_q8_1_mul_mat( | static __device__ __forceinline__ void vec_dot_q4_1_q8_1_dp4a( | ||||||
|     const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, |     const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, | ||||||
|     const int * __restrict__ y, float * __restrict__ sum, const int & k0) { |     const int * __restrict__ y, float * __restrict__ sum, const int & k0) { | ||||||
|  |  | ||||||
| @@ -249,6 +315,70 @@ static __device__ __forceinline__ void vec_dot_q4_1_q8_1_mul_mat( | |||||||
|     } |     } | ||||||
| } | } | ||||||
|  |  | ||||||
|  | template <int mmq_x, int mmq_y, int nwarps> | ||||||
|  | static __device__ __forceinline__ void vec_dot_q4_1_q8_1_mma( | ||||||
|  |     const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, | ||||||
|  |     const int * __restrict__ y, float * __restrict__ sum, const int & k0) { | ||||||
|  |  | ||||||
|  |     GGML_UNUSED(x_qh); GGML_UNUSED(x_sc); | ||||||
|  |  | ||||||
|  |     typedef mma_int_A_I16K8 mma_A; | ||||||
|  |     typedef mma_int_B_J8K8  mma_B; | ||||||
|  |     typedef mma_int_C_I16J8 mma_C; | ||||||
|  |  | ||||||
|  |     const int   * y_qs = (const int   *) y + 4; | ||||||
|  |     const half2 * y_ds = (const half2 *) y; | ||||||
|  |  | ||||||
|  |     mma_A A; | ||||||
|  |     half2 dmA[mma_C::ne/2]; | ||||||
|  |  | ||||||
|  |     const int i0 = threadIdx.y*mma_A::I; | ||||||
|  |     static_assert(nwarps*mma_A::I == mmq_y, "nwarps*mma_A::I != mmq_y"); | ||||||
|  |  | ||||||
|  | #pragma unroll | ||||||
|  |     for (int l = 0; l < mma_A::ne; ++l) { | ||||||
|  |         const int i     = i0 + mma_A::get_i(l); | ||||||
|  |         const int k     = k0 + mma_A::get_k(l) % QI4_0; | ||||||
|  |         const int shift =   4*(mma_A::get_k(l) / QI4_0); | ||||||
|  |  | ||||||
|  |         A.x[l] = (x_ql[i*(WARP_SIZE + 1) + k] >> shift) & 0x0F0F0F0F; | ||||||
|  |     } | ||||||
|  | #pragma unroll | ||||||
|  |     for (int l = 0; l < mma_C::ne/2; ++l) { | ||||||
|  |         const int i = i0 + mma_C::get_i(2*l); | ||||||
|  |  | ||||||
|  |         dmA[l] = x_dm[i*(WARP_SIZE/QI4_0) + i/QI4_0 + k0/QI4_0]; | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     for (int j0 = 0; j0 < mmq_x; j0 += mma_int_B_J8K8::J) { | ||||||
|  |         mma_C C; | ||||||
|  |         mma_B B; | ||||||
|  |         half2 dsB[mma_C::ne/2]; | ||||||
|  |  | ||||||
|  | #pragma unroll | ||||||
|  |         for (int l = 0; l < mma_B::ne; ++l) { | ||||||
|  |             const int j =    j0 + mma_B::get_j(l); | ||||||
|  |             const int k = (2*k0 + mma_B::get_k(l)) % WARP_SIZE; | ||||||
|  |  | ||||||
|  |             B.x[l] = y_qs[j*MMQ_TILE_Y_K + k]; | ||||||
|  |         } | ||||||
|  | #pragma unroll | ||||||
|  |         for (int l = 0; l < mma_C::ne/2; ++l) { | ||||||
|  |             const int j = j0 + mma_C::get_j(l); | ||||||
|  |  | ||||||
|  |             dsB[l] = y_ds[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)]; | ||||||
|  |         } | ||||||
|  |  | ||||||
|  |         C.mma_K8(A, B); | ||||||
|  |  | ||||||
|  | #pragma unroll | ||||||
|  |         for (int l = 0; l < mma_C::ne; ++l) { | ||||||
|  |             const half2 dmA_dsB = dmA[l/2]*dsB[l%2]; | ||||||
|  |             sum[(j0/B.J)*C.ne + l] += __low2float(dmA_dsB)*C.x[l] + __high2float(dmA_dsB); | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  | } | ||||||
|  |  | ||||||
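For q4_1 the nibbles stay unsigned in the fragment and the affine part is applied after the integer MMA. Assuming the q8_1 convention used elsewhere in these kernels, where the y metadata stores ds = (d_y, s_y) with s_y = d_y * sum(p_j), the accumulation above corresponds to this sketch of the dot product:

    % x_i = d_x*q_i + m_x (q4_1),  y_j = d_y*p_j (q8_1),  C = \sum_i q_i p_i from the int8 MMA
    \sum_i x_i y_i
        = d_x d_y \sum_i q_i p_i + m_x \sum_i y_i
        \approx \underbrace{d_x d_y}_{\text{low}(dmA \cdot dsB)} \, C
              + \underbrace{m_x s_y}_{\text{high}(dmA \cdot dsB)}

which is exactly the __low2float(dmA_dsB)*C.x[l] + __high2float(dmA_dsB) term; the half2 product dmA[l/2]*dsB[l%2] computes both coefficients with a single packed multiply.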
| template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q5_0( | template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q5_0( | ||||||
|     const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh, |     const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh, | ||||||
|     int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) { |     int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) { | ||||||
| @@ -308,7 +438,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin | |||||||
| } | } | ||||||
|  |  | ||||||
| template <int mmq_x, int mmq_y, int nwarps> | template <int mmq_x, int mmq_y, int nwarps> | ||||||
| static __device__ __forceinline__ void vec_dot_q5_0_q8_1_mul_mat( | static __device__ __forceinline__ void vec_dot_q5_0_q8_1_dp4a( | ||||||
|     const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, |     const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, | ||||||
|     const int * __restrict__ y, float * __restrict__ sum, const int & k0) { |     const int * __restrict__ y, float * __restrict__ sum, const int & k0) { | ||||||
|  |  | ||||||
| @@ -343,6 +473,68 @@ static __device__ __forceinline__ void vec_dot_q5_0_q8_1_mul_mat( | |||||||
|     } |     } | ||||||
| } | } | ||||||
|  |  | ||||||
|  | template <int mmq_x, int mmq_y, int nwarps> | ||||||
|  | static __device__ __forceinline__ void vec_dot_q5_0_q8_1_mma( | ||||||
|  |     const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, | ||||||
|  |     const int * __restrict__ y, float * __restrict__ sum, const int & k0) { | ||||||
|  |  | ||||||
|  |     GGML_UNUSED(x_qh); GGML_UNUSED(x_sc); | ||||||
|  |  | ||||||
|  |     typedef mma_int_A_I16K8 mma_A; | ||||||
|  |     typedef mma_int_B_J8K8  mma_B; | ||||||
|  |     typedef mma_int_C_I16J8 mma_C; | ||||||
|  |  | ||||||
|  |     const float * x_df = (const float *) x_dm; | ||||||
|  |     const int   * y_qs = (const int   *) y + 4; | ||||||
|  |     const float * y_df = (const float *) y; | ||||||
|  |  | ||||||
|  |     mma_A A; | ||||||
|  |     float dA[mma_C::ne/2]; | ||||||
|  |  | ||||||
|  |     const int i0 = threadIdx.y*mma_A::I; | ||||||
|  |     static_assert(nwarps*mma_A::I == mmq_y, "nwarps*mma_A::I != mmq_y"); | ||||||
|  |  | ||||||
|  | #pragma unroll | ||||||
|  |     for (int l = 0; l < mma_A::ne; ++l) { | ||||||
|  |         const int i     =    i0 + mma_A::get_i(l); | ||||||
|  |         const int k     = 2*(k0 + mma_A::get_k(l) % QI5_0) + mma_A::get_k(l) / QI5_0; | ||||||
|  |  | ||||||
|  |         A.x[l] = x_ql[i*(2*WARP_SIZE + 1) + k]; | ||||||
|  |     } | ||||||
|  | #pragma unroll | ||||||
|  |     for (int l = 0; l < mma_C::ne/2; ++l) { | ||||||
|  |         const int i = i0 + mma_C::get_i(2*l); | ||||||
|  |  | ||||||
|  |         dA[l] = x_df[i*(WARP_SIZE/QI5_0) + i/QI5_0 + k0/QI5_0]; | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     for (int j0 = 0; j0 < mmq_x; j0 += mma_int_B_J8K8::J) { | ||||||
|  |         mma_C C; | ||||||
|  |         mma_B B; | ||||||
|  |         float dB[mma_C::ne/2]; | ||||||
|  |  | ||||||
|  | #pragma unroll | ||||||
|  |         for (int l = 0; l < mma_B::ne; ++l) { | ||||||
|  |             const int j =    j0 + mma_B::get_j(l); | ||||||
|  |             const int k = (2*k0 + mma_B::get_k(l)) % WARP_SIZE; | ||||||
|  |  | ||||||
|  |             B.x[l] = y_qs[j*MMQ_TILE_Y_K + k]; | ||||||
|  |         } | ||||||
|  | #pragma unroll | ||||||
|  |         for (int l = 0; l < mma_C::ne/2; ++l) { | ||||||
|  |             const int j = j0 + mma_C::get_j(l); | ||||||
|  |  | ||||||
|  |             dB[l] = y_df[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)]; | ||||||
|  |         } | ||||||
|  |  | ||||||
|  |         C.mma_K8(A, B); | ||||||
|  |  | ||||||
|  | #pragma unroll | ||||||
|  |         for (int l = 0; l < mma_C::ne; ++l) { | ||||||
|  |             sum[(j0/B.J)*C.ne + l] += dA[l/2]*dB[l%2]*C.x[l]; | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  | } | ||||||
|  |  | ||||||
| template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q5_1( | template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q5_1( | ||||||
|     const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh, |     const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh, | ||||||
| @@ -400,7 +592,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin | |||||||
| } | } | ||||||
|  |  | ||||||
| template <int mmq_x, int mmq_y, int nwarps> | template <int mmq_x, int mmq_y, int nwarps> | ||||||
| static __device__ __forceinline__ void vec_dot_q5_1_q8_1_mul_mat( | static __device__ __forceinline__ void vec_dot_q5_1_q8_1_dp4a( | ||||||
|     const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, |     const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, | ||||||
|     const int * __restrict__ y, float * __restrict__ sum, const int & k0) { |     const int * __restrict__ y, float * __restrict__ sum, const int & k0) { | ||||||
|  |  | ||||||
| @@ -434,6 +626,69 @@ static __device__ __forceinline__ void vec_dot_q5_1_q8_1_mul_mat( | |||||||
|     } |     } | ||||||
| } | } | ||||||
|  |  | ||||||
|  | template <int mmq_x, int mmq_y, int nwarps> | ||||||
|  | static __device__ __forceinline__ void vec_dot_q5_1_q8_1_mma( | ||||||
|  |     const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, | ||||||
|  |     const int * __restrict__ y, float * __restrict__ sum, const int & k0) { | ||||||
|  |  | ||||||
|  |     GGML_UNUSED(x_qh); GGML_UNUSED(x_sc); | ||||||
|  |  | ||||||
|  |     typedef mma_int_A_I16K8 mma_A; | ||||||
|  |     typedef mma_int_B_J8K8  mma_B; | ||||||
|  |     typedef mma_int_C_I16J8 mma_C; | ||||||
|  |  | ||||||
|  |     const int   * y_qs = (const int   *) y + 4; | ||||||
|  |     const half2 * y_ds = (const half2 *) y; | ||||||
|  |  | ||||||
|  |     mma_A A; | ||||||
|  |     half2 dmA[mma_C::ne/2]; | ||||||
|  |  | ||||||
|  |     const int i0 = threadIdx.y*mma_A::I; | ||||||
|  |     static_assert(nwarps*mma_A::I == mmq_y, "nwarps*mma_A::I != mmq_y"); | ||||||
|  |  | ||||||
|  | #pragma unroll | ||||||
|  |     for (int l = 0; l < mma_A::ne; ++l) { | ||||||
|  |         const int i     =    i0 + mma_A::get_i(l); | ||||||
|  |         const int k     = 2*(k0 + mma_A::get_k(l) % QI5_1) + mma_A::get_k(l) / QI5_1; | ||||||
|  |  | ||||||
|  |         A.x[l] = x_ql[i*(2*WARP_SIZE + 1) + k]; | ||||||
|  |     } | ||||||
|  | #pragma unroll | ||||||
|  |     for (int l = 0; l < mma_C::ne/2; ++l) { | ||||||
|  |         const int i = i0 + mma_C::get_i(2*l); | ||||||
|  |  | ||||||
|  |         dmA[l] = x_dm[i*(WARP_SIZE/QI5_1) + i/QI5_1 + k0/QI5_1]; | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     for (int j0 = 0; j0 < mmq_x; j0 += mma_int_B_J8K8::J) { | ||||||
|  |         mma_C C; | ||||||
|  |         mma_B B; | ||||||
|  |         half2 dsB[mma_C::ne/2]; | ||||||
|  |  | ||||||
|  | #pragma unroll | ||||||
|  |         for (int l = 0; l < mma_B::ne; ++l) { | ||||||
|  |             const int j =    j0 + mma_B::get_j(l); | ||||||
|  |             const int k = (2*k0 + mma_B::get_k(l)) % WARP_SIZE; | ||||||
|  |  | ||||||
|  |             B.x[l] = y_qs[j*MMQ_TILE_Y_K + k]; | ||||||
|  |         } | ||||||
|  | #pragma unroll | ||||||
|  |         for (int l = 0; l < mma_C::ne/2; ++l) { | ||||||
|  |             const int j = j0 + mma_C::get_j(l); | ||||||
|  |  | ||||||
|  |             dsB[l] = y_ds[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)]; | ||||||
|  |         } | ||||||
|  |  | ||||||
|  |         C.mma_K8(A, B); | ||||||
|  |  | ||||||
|  | #pragma unroll | ||||||
|  |         for (int l = 0; l < mma_C::ne; ++l) { | ||||||
|  |             const half2 dmA_dsB = dmA[l/2]*dsB[l%2]; | ||||||
|  |             sum[(j0/B.J)*C.ne + l] += __low2float(dmA_dsB)*C.x[l] + __high2float(dmA_dsB); | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  | } | ||||||
|  |  | ||||||
| template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q8_0( | template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q8_0( | ||||||
|     const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh, |     const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh, | ||||||
|     int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) { |     int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) { | ||||||
| @@ -475,7 +730,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin | |||||||
| } | } | ||||||
|  |  | ||||||
| template <int mmq_x, int mmq_y, int nwarps> | template <int mmq_x, int mmq_y, int nwarps> | ||||||
| static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mul_mat( | static __device__ __forceinline__ void vec_dot_q8_0_q8_1_dp4a( | ||||||
|     const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, |     const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, | ||||||
|     const int * __restrict__ y, float * __restrict__ sum, const int & k0) { |     const int * __restrict__ y, float * __restrict__ sum, const int & k0) { | ||||||
|  |  | ||||||
| @@ -500,6 +755,69 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mul_mat( | |||||||
|     } |     } | ||||||
| } | } | ||||||
|  |  | ||||||
|  | template <int mmq_x, int mmq_y, int nwarps> | ||||||
|  | static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma( | ||||||
|  |     const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, | ||||||
|  |     const int * __restrict__ y, float * __restrict__ sum, const int & k0) { | ||||||
|  |  | ||||||
|  |     GGML_UNUSED(x_qh); GGML_UNUSED(x_sc); | ||||||
|  |  | ||||||
|  |     typedef mma_int_A_I16K8 mma_A; | ||||||
|  |     typedef mma_int_B_J8K8  mma_B; | ||||||
|  |     typedef mma_int_C_I16J8 mma_C; | ||||||
|  |  | ||||||
|  |     const float * x_df = (const float *) x_dm; | ||||||
|  |     const int   * y_qs = (const int   *) y + 4; | ||||||
|  |     const float * y_df = (const float *) y; | ||||||
|  |  | ||||||
|  |     mma_A A; | ||||||
|  |     float dA[mma_C::ne/2]; | ||||||
|  |  | ||||||
|  |     const int i0 = threadIdx.y*mma_A::I; | ||||||
|  |     static_assert(nwarps*mma_A::I == mmq_y, "nwarps*mma_A::I != mmq_y"); | ||||||
|  |  | ||||||
|  | #pragma unroll | ||||||
|  |     for (int l = 0; l < mma_A::ne; ++l) { | ||||||
|  |         const int i = i0 + mma_A::get_i(l); | ||||||
|  |         const int k = k0 + mma_A::get_k(l); | ||||||
|  |  | ||||||
|  |         A.x[l] = x_ql[i*(WARP_SIZE + 1) + k]; | ||||||
|  |     } | ||||||
|  | #pragma unroll | ||||||
|  |     for (int l = 0; l < mma_C::ne/2; ++l) { | ||||||
|  |         const int i = i0 + mma_C::get_i(2*l); | ||||||
|  |  | ||||||
|  |         dA[l] = x_df[i*(WARP_SIZE/QI8_0) + i/QI8_0 + k0/QI8_0]; | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     for (int j0 = 0; j0 < mmq_x; j0 += mma_int_B_J8K8::J) { | ||||||
|  |         mma_C C; | ||||||
|  |         mma_B B; | ||||||
|  |         float dB[mma_C::ne/2]; | ||||||
|  |  | ||||||
|  | #pragma unroll | ||||||
|  |         for (int l = 0; l < mma_B::ne; ++l) { | ||||||
|  |             const int j = j0 + mma_B::get_j(l); | ||||||
|  |             const int k = k0 + mma_B::get_k(l); | ||||||
|  |  | ||||||
|  |             B.x[l] = y_qs[j*MMQ_TILE_Y_K + k]; | ||||||
|  |         } | ||||||
|  | #pragma unroll | ||||||
|  |         for (int l = 0; l < mma_C::ne/2; ++l) { | ||||||
|  |             const int j = j0 + mma_C::get_j(l); | ||||||
|  |  | ||||||
|  |             dB[l] = y_df[j*MMQ_TILE_Y_K + k0/QI8_1]; | ||||||
|  |         } | ||||||
|  |  | ||||||
|  |         C.mma_K8(A, B); | ||||||
|  |  | ||||||
|  | #pragma unroll | ||||||
|  |         for (int l = 0; l < mma_C::ne; ++l) { | ||||||
|  |             sum[(j0/B.J)*C.ne + l] += C.x[l]*dA[l/2]*dB[l%2]; | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  | } | ||||||
|  |  | ||||||
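The q8_0 case is the simplest of the legacy quants: there is no offset, both per-group scale factors are read back as plain floats (x_df and y_df above), and each output element reduces to one product of the scales with the integer MMA result, with no correction term like the q4_1/q5_1 __high2float part. As a worked equation (same notation as the q4_1 sketch):

    \sum_i x_i y_i = d_x d_y \sum_i q_i p_i = dA \cdot dB \cdot C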
| template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q2_K( | template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q2_K( | ||||||
|     const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh, |     const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh, | ||||||
|     int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) { |     int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) { | ||||||
| @@ -989,6 +1307,57 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mul_mat( | |||||||
|     } |     } | ||||||
| } | } | ||||||
|  |  | ||||||
|  | template<int mmq_x, int mmq_y, int nwarps, bool need_check> | ||||||
|  | static __device__ __forceinline__ void mmq_write_back_dp4a(const float * __restrict__ sum, float * __restrict__ dst, const int & ne0, const int & ne1) { | ||||||
|  | #pragma unroll | ||||||
|  |     for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { | ||||||
|  |         const int j = blockIdx.y*mmq_x + j0 + threadIdx.y; | ||||||
|  |  | ||||||
|  |         if (j >= ne1) { | ||||||
|  |             return; | ||||||
|  |         } | ||||||
|  |  | ||||||
|  | #pragma unroll | ||||||
|  |         for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) { | ||||||
|  |             const int i = blockIdx.x*mmq_y + i0 + threadIdx.x; | ||||||
|  |  | ||||||
|  |             if (need_check && i >= ne0) { | ||||||
|  |                 continue; | ||||||
|  |             } | ||||||
|  |  | ||||||
|  |             dst[j*ne0 + i] = sum[(j0/nwarps) * (mmq_y/WARP_SIZE) + i0/WARP_SIZE]; | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  | } | ||||||
|  |  | ||||||
|  | template<int mmq_x, int mmq_y, int nwarps, bool need_check> | ||||||
|  | static __device__ __forceinline__ void mmq_write_back_mma(const float * __restrict__ sum, float * __restrict__ dst, const int & ne0, const int & ne1) { | ||||||
|  |     typedef mma_int_C_I16J8 mma_C; | ||||||
|  |  | ||||||
|  |     const int i0 = threadIdx.y*mma_C::I; | ||||||
|  |     static_assert(nwarps*mma_C::I == mmq_y, "nwarps*mma_C::I != mmq_y"); | ||||||
|  |  | ||||||
|  | #pragma unroll | ||||||
|  |     for (int j0 = 0; j0 < mmq_x; j0 += mma_C::J) { | ||||||
|  | #pragma unroll | ||||||
|  |         for (int l = 0; l < mma_C::ne; ++l) { | ||||||
|  |             const int j = blockIdx.y*mmq_x + j0 + mma_C::get_j(l); | ||||||
|  |  | ||||||
|  |             if (j >= ne1) { | ||||||
|  |                 continue; | ||||||
|  |             } | ||||||
|  |  | ||||||
|  |             const int i = blockIdx.x*mmq_y + i0 + mma_C::get_i(l); | ||||||
|  |  | ||||||
|  |             if (need_check && i >= ne0) { | ||||||
|  |                 continue; | ||||||
|  |             } | ||||||
|  |  | ||||||
|  |             dst[j*ne0 + i] = sum[(j0/mma_C::J)*mma_C::ne + l]; | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  | } | ||||||
|  |  | ||||||
| // ------------------------------------------------------------------------------------------------------------------------------------- | // ------------------------------------------------------------------------------------------------------------------------------------- | ||||||
|  |  | ||||||
| template <int mmq_x, int mmq_y, int nwarps, bool need_check, ggml_type type> | template <int mmq_x, int mmq_y, int nwarps, bool need_check, ggml_type type> | ||||||
| @@ -998,35 +1367,65 @@ template <int mmq_x, int mmq_y, int nwarps, bool need_check> | |||||||
| struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q4_0> { | struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q4_0> { | ||||||
|     static constexpr int              vdr        = VDR_Q4_0_Q8_1_MMQ; |     static constexpr int              vdr        = VDR_Q4_0_Q8_1_MMQ; | ||||||
|     static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_0<mmq_y, nwarps, need_check>; |     static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_0<mmq_y, nwarps, need_check>; | ||||||
|     static constexpr vec_dot_mmq_t    vec_dot    = vec_dot_q4_0_q8_1_mul_mat<mmq_x, mmq_y, nwarps>; | #ifdef INT8_MMA_AVAILABLE | ||||||
|  |     static constexpr vec_dot_mmq_t    vec_dot    = vec_dot_q4_0_q8_1_mma<mmq_x, mmq_y, nwarps>; | ||||||
|  |     static constexpr mmq_write_back_t write_back = mmq_write_back_mma<mmq_x, mmq_y, nwarps, need_check>; | ||||||
|  | #else | ||||||
|  |     static constexpr vec_dot_mmq_t    vec_dot    = vec_dot_q4_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>; | ||||||
|  |     static constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, nwarps, need_check>; | ||||||
|  | #endif // INT8_MMA_AVAILABLE | ||||||
| }; | }; | ||||||
|  |  | ||||||
| template <int mmq_x, int mmq_y, int nwarps, bool need_check> | template <int mmq_x, int mmq_y, int nwarps, bool need_check> | ||||||
| struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q4_1> { | struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q4_1> { | ||||||
|     static constexpr int              vdr        = VDR_Q4_1_Q8_1_MMQ; |     static constexpr int              vdr        = VDR_Q4_1_Q8_1_MMQ; | ||||||
|     static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_1<mmq_y, nwarps, need_check>; |     static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_1<mmq_y, nwarps, need_check>; | ||||||
|     static constexpr vec_dot_mmq_t    vec_dot    = vec_dot_q4_1_q8_1_mul_mat<mmq_x, mmq_y, nwarps>; | #ifdef INT8_MMA_AVAILABLE | ||||||
|  |     static constexpr vec_dot_mmq_t    vec_dot    = vec_dot_q4_1_q8_1_mma<mmq_x, mmq_y, nwarps>; | ||||||
|  |     static constexpr mmq_write_back_t write_back = mmq_write_back_mma<mmq_x, mmq_y, nwarps, need_check>; | ||||||
|  | #else | ||||||
|  |     static constexpr vec_dot_mmq_t    vec_dot    = vec_dot_q4_1_q8_1_dp4a<mmq_x, mmq_y, nwarps>; | ||||||
|  |     static constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, nwarps, need_check>; | ||||||
|  | #endif // INT8_MMA_AVAILABLE | ||||||
| }; | }; | ||||||
|  |  | ||||||
| template <int mmq_x, int mmq_y, int nwarps, bool need_check> | template <int mmq_x, int mmq_y, int nwarps, bool need_check> | ||||||
| struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q5_0> { | struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q5_0> { | ||||||
|     static constexpr int              vdr        = VDR_Q5_0_Q8_1_MMQ; |     static constexpr int              vdr        = VDR_Q5_0_Q8_1_MMQ; | ||||||
|     static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_0<mmq_y, nwarps, need_check>; |     static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_0<mmq_y, nwarps, need_check>; | ||||||
|     static constexpr vec_dot_mmq_t    vec_dot    = vec_dot_q5_0_q8_1_mul_mat<mmq_x, mmq_y, nwarps>; | #ifdef INT8_MMA_AVAILABLE | ||||||
|  |     static constexpr vec_dot_mmq_t    vec_dot    = vec_dot_q5_0_q8_1_mma<mmq_x, mmq_y, nwarps>; | ||||||
|  |     static constexpr mmq_write_back_t write_back = mmq_write_back_mma<mmq_x, mmq_y, nwarps, need_check>; | ||||||
|  | #else | ||||||
|  |     static constexpr vec_dot_mmq_t    vec_dot    = vec_dot_q5_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>; | ||||||
|  |     static constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, nwarps, need_check>; | ||||||
|  | #endif // INT8_MMA_AVAILABLE | ||||||
| }; | }; | ||||||
|  |  | ||||||
| template <int mmq_x, int mmq_y, int nwarps, bool need_check> | template <int mmq_x, int mmq_y, int nwarps, bool need_check> | ||||||
| struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q5_1> { | struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q5_1> { | ||||||
|     static constexpr int              vdr        = VDR_Q5_1_Q8_1_MMQ; |     static constexpr int              vdr        = VDR_Q5_1_Q8_1_MMQ; | ||||||
|     static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_1<mmq_y, nwarps, need_check>; |     static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_1<mmq_y, nwarps, need_check>; | ||||||
|     static constexpr vec_dot_mmq_t    vec_dot    = vec_dot_q5_1_q8_1_mul_mat<mmq_x, mmq_y, nwarps>; | #ifdef INT8_MMA_AVAILABLE | ||||||
|  |     static constexpr vec_dot_mmq_t    vec_dot    = vec_dot_q5_1_q8_1_mma<mmq_x, mmq_y, nwarps>; | ||||||
|  |     static constexpr mmq_write_back_t write_back = mmq_write_back_mma<mmq_x, mmq_y, nwarps, need_check>; | ||||||
|  | #else | ||||||
|  |     static constexpr vec_dot_mmq_t    vec_dot    = vec_dot_q5_1_q8_1_dp4a<mmq_x, mmq_y, nwarps>; | ||||||
|  |     static constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, nwarps, need_check>; | ||||||
|  | #endif // INT8_MMA_AVAILABLE | ||||||
| }; | }; | ||||||
|  |  | ||||||
| template <int mmq_x, int mmq_y, int nwarps, bool need_check> | template <int mmq_x, int mmq_y, int nwarps, bool need_check> | ||||||
| struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q8_0> { | struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q8_0> { | ||||||
|     static constexpr int              vdr        = VDR_Q8_0_Q8_1_MMQ; |     static constexpr int              vdr        = VDR_Q8_0_Q8_1_MMQ; | ||||||
|     static constexpr load_tiles_mmq_t load_tiles = load_tiles_q8_0<mmq_y, nwarps, need_check>; |     static constexpr load_tiles_mmq_t load_tiles = load_tiles_q8_0<mmq_y, nwarps, need_check>; | ||||||
|     static constexpr vec_dot_mmq_t    vec_dot    = vec_dot_q8_0_q8_1_mul_mat<mmq_x, mmq_y, nwarps>; | #ifdef INT8_MMA_AVAILABLE | ||||||
|  |     static constexpr vec_dot_mmq_t    vec_dot    = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, nwarps>; | ||||||
|  |     static constexpr mmq_write_back_t write_back = mmq_write_back_mma<mmq_x, mmq_y, nwarps, need_check>; | ||||||
|  | #else | ||||||
|  |     static constexpr vec_dot_mmq_t    vec_dot    = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>; | ||||||
|  |     static constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, nwarps, need_check>; | ||||||
|  | #endif // INT8_MMA_AVAILABLE | ||||||
| }; | }; | ||||||
|  |  | ||||||
| template <int mmq_x, int mmq_y, int nwarps, bool need_check> | template <int mmq_x, int mmq_y, int nwarps, bool need_check> | ||||||
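The mmq_type_traits specializations above are the switch that actually moves MMQ onto tensor cores: the #ifdef picks either the *_mma functions plus mmq_write_back_mma or the *_dp4a functions plus mmq_write_back_dp4a for each compilation pass, stores them as constexpr function pointers, and mul_mat_q reads them back as constexpr, so there is no runtime branch and the tensor-core path is never instantiated on architectures where INT8_MMA_AVAILABLE is undefined. A toy sketch of the same idiom, with invented names, just to isolate the mechanism:

    // Toy illustration of compile-time dispatch via constexpr function pointers
    // inside a traits struct (names are made up, not from the repo).
    #define TOY_FAST_AVAILABLE  // imagine this being defined only for some architectures

    typedef float (*combine_t)(const float, const float);

    static __device__ __forceinline__ float combine_fast(const float a, const float b) {
        return fmaf(a, b, 1.0f);          // stand-in for the "fast" path
    }
    static __device__ __forceinline__ float combine_slow(const float a, const float b) {
        return a*b + 1.0f;                // stand-in for the fallback path
    }

    struct toy_traits {
    #ifdef TOY_FAST_AVAILABLE
        static constexpr combine_t combine = combine_fast;
    #else
        static constexpr combine_t combine = combine_slow;
    #endif // TOY_FAST_AVAILABLE
    };

    static __global__ void toy_kernel(const float * x, const float * y, float * dst, const int n) {
        constexpr combine_t combine = toy_traits::combine;  // resolved at compile time
        const int i = blockIdx.x*blockDim.x + threadIdx.x;
        if (i < n) {
            dst[i] = combine(x[i], y[i]);
        }
    }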
| @@ -1034,6 +1433,7 @@ struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q2_K> { | |||||||
|     static constexpr int              vdr        = VDR_Q2_K_Q8_1_MMQ; |     static constexpr int              vdr        = VDR_Q2_K_Q8_1_MMQ; | ||||||
|     static constexpr load_tiles_mmq_t load_tiles = load_tiles_q2_K<mmq_y, nwarps, need_check>; |     static constexpr load_tiles_mmq_t load_tiles = load_tiles_q2_K<mmq_y, nwarps, need_check>; | ||||||
|     static constexpr vec_dot_mmq_t    vec_dot    = vec_dot_q2_K_q8_1_mul_mat<mmq_x, mmq_y, nwarps>; |     static constexpr vec_dot_mmq_t    vec_dot    = vec_dot_q2_K_q8_1_mul_mat<mmq_x, mmq_y, nwarps>; | ||||||
|  |     static constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, nwarps, need_check>; | ||||||
| }; | }; | ||||||
|  |  | ||||||
| template <int mmq_x, int mmq_y, int nwarps, bool need_check> | template <int mmq_x, int mmq_y, int nwarps, bool need_check> | ||||||
| @@ -1041,6 +1441,7 @@ struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q3_K> { | |||||||
|     static constexpr int              vdr        = VDR_Q3_K_Q8_1_MMQ; |     static constexpr int              vdr        = VDR_Q3_K_Q8_1_MMQ; | ||||||
|     static constexpr load_tiles_mmq_t load_tiles = load_tiles_q3_K<mmq_y, nwarps, need_check>; |     static constexpr load_tiles_mmq_t load_tiles = load_tiles_q3_K<mmq_y, nwarps, need_check>; | ||||||
|     static constexpr vec_dot_mmq_t    vec_dot    = vec_dot_q3_K_q8_1_mul_mat<mmq_x, mmq_y, nwarps>; |     static constexpr vec_dot_mmq_t    vec_dot    = vec_dot_q3_K_q8_1_mul_mat<mmq_x, mmq_y, nwarps>; | ||||||
|  |     static constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, nwarps, need_check>; | ||||||
| }; | }; | ||||||
|  |  | ||||||
| template <int mmq_x, int mmq_y, int nwarps, bool need_check> | template <int mmq_x, int mmq_y, int nwarps, bool need_check> | ||||||
| @@ -1048,6 +1449,7 @@ struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q4_K> { | |||||||
|     static constexpr int              vdr        = VDR_Q4_K_Q8_1_MMQ; |     static constexpr int              vdr        = VDR_Q4_K_Q8_1_MMQ; | ||||||
|     static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_K<mmq_y, nwarps, need_check>; |     static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_K<mmq_y, nwarps, need_check>; | ||||||
|     static constexpr vec_dot_mmq_t    vec_dot    = vec_dot_q4_K_q8_1_mul_mat<mmq_x, mmq_y, nwarps>; |     static constexpr vec_dot_mmq_t    vec_dot    = vec_dot_q4_K_q8_1_mul_mat<mmq_x, mmq_y, nwarps>; | ||||||
|  |     static constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, nwarps, need_check>; | ||||||
| }; | }; | ||||||
|  |  | ||||||
| template <int mmq_x, int mmq_y, int nwarps, bool need_check> | template <int mmq_x, int mmq_y, int nwarps, bool need_check> | ||||||
| @@ -1055,6 +1457,7 @@ struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q5_K> { | |||||||
|     static constexpr int              vdr        = VDR_Q5_K_Q8_1_MMQ; |     static constexpr int              vdr        = VDR_Q5_K_Q8_1_MMQ; | ||||||
|     static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_K<mmq_y, nwarps, need_check>; |     static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_K<mmq_y, nwarps, need_check>; | ||||||
|     static constexpr vec_dot_mmq_t    vec_dot    = vec_dot_q5_K_q8_1_mul_mat<mmq_x, mmq_y, nwarps>; |     static constexpr vec_dot_mmq_t    vec_dot    = vec_dot_q5_K_q8_1_mul_mat<mmq_x, mmq_y, nwarps>; | ||||||
|  |     static constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, nwarps, need_check>; | ||||||
| }; | }; | ||||||
|  |  | ||||||
| template <int mmq_x, int mmq_y, int nwarps, bool need_check> | template <int mmq_x, int mmq_y, int nwarps, bool need_check> | ||||||
| @@ -1062,6 +1465,7 @@ struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q6_K> { | |||||||
|     static constexpr int              vdr        = VDR_Q6_K_Q8_1_MMQ; |     static constexpr int              vdr        = VDR_Q6_K_Q8_1_MMQ; | ||||||
|     static constexpr load_tiles_mmq_t load_tiles = load_tiles_q6_K<mmq_y, nwarps, need_check>; |     static constexpr load_tiles_mmq_t load_tiles = load_tiles_q6_K<mmq_y, nwarps, need_check>; | ||||||
|     static constexpr vec_dot_mmq_t    vec_dot    = vec_dot_q6_K_q8_1_mul_mat<mmq_x, mmq_y, nwarps>; |     static constexpr vec_dot_mmq_t    vec_dot    = vec_dot_q6_K_q8_1_mul_mat<mmq_x, mmq_y, nwarps>; | ||||||
|  |     static constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, nwarps, need_check>; | ||||||
| }; | }; | ||||||
|  |  | ||||||
| static int mmq_need_sum(const ggml_type type_x) { | static int mmq_need_sum(const ggml_type type_x) { | ||||||
| @@ -1118,6 +1522,7 @@ static __global__ void mul_mat_q( | |||||||
|     constexpr int              vdr        = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::vdr; |     constexpr int              vdr        = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::vdr; | ||||||
|     constexpr load_tiles_mmq_t load_tiles = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::load_tiles; |     constexpr load_tiles_mmq_t load_tiles = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::load_tiles; | ||||||
|     constexpr vec_dot_mmq_t    vec_dot    = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::vec_dot; |     constexpr vec_dot_mmq_t    vec_dot    = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::vec_dot; | ||||||
|  |     constexpr mmq_write_back_t write_back = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::write_back; | ||||||
|  |  | ||||||
|     constexpr tile_x_sizes txs = get_tile_x_sizes_device<mmq_y>(type); |     constexpr tile_x_sizes txs = get_tile_x_sizes_device<mmq_y>(type); | ||||||
|  |  | ||||||
| @@ -1137,7 +1542,7 @@ static __global__ void mul_mat_q( | |||||||
|  |  | ||||||
|     const int * y = (const int *) yc + blockIdx.y*(mmq_x*sizeof(block_q8_1_mmq)/sizeof(int)); |     const int * y = (const int *) yc + blockIdx.y*(mmq_x*sizeof(block_q8_1_mmq)/sizeof(int)); | ||||||
|  |  | ||||||
|     float sum[(mmq_x/nwarps) * (mmq_y/WARP_SIZE)] = {0.0f}; |     float sum[mmq_x*mmq_y / (nwarps*WARP_SIZE)] = {0.0f}; | ||||||
|  |  | ||||||
|     for (int kb0 = 0; kb0 < blocks_per_row_x; kb0 += blocks_per_warp) { |     for (int kb0 = 0; kb0 < blocks_per_row_x; kb0 += blocks_per_warp) { | ||||||
|  |  | ||||||
| @@ -1164,25 +1569,7 @@ static __global__ void mul_mat_q( | |||||||
|         } |         } | ||||||
|     } |     } | ||||||
|  |  | ||||||
| #pragma unroll |     write_back(sum, dst, ne0, ne1); | ||||||
|     for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { |  | ||||||
|         const int j = blockIdx.y*mmq_x + j0 + threadIdx.y; |  | ||||||
|  |  | ||||||
|         if (j >= ne1) { |  | ||||||
|             return; |  | ||||||
|         } |  | ||||||
|  |  | ||||||
| #pragma unroll |  | ||||||
|         for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) { |  | ||||||
|             const int i = blockIdx.x*mmq_y + i0 + threadIdx.x; |  | ||||||
|  |  | ||||||
|             if (need_check && i >= ne0) { |  | ||||||
|                 continue; |  | ||||||
|             } |  | ||||||
|  |  | ||||||
|             dst[j*ne0 + i] = sum[(j0/nwarps) * (mmq_y/WARP_SIZE) + i0/WARP_SIZE]; |  | ||||||
|         } |  | ||||||
|     } |  | ||||||
| } | } | ||||||
|  |  | ||||||
| struct mmq_args { | struct mmq_args { | ||||||
| @@ -1256,10 +1643,10 @@ void mul_mat_q_case(const mmq_args & args, cudaStream_t stream) { | |||||||
|             launch_mul_mat_q<type,   8, 4>(args, stream); |             launch_mul_mat_q<type,   8, 4>(args, stream); | ||||||
|             break; |             break; | ||||||
|         case  16: |         case  16: | ||||||
|             launch_mul_mat_q<type,  16, 8>(args, stream); |             launch_mul_mat_q<type,  16, 4>(args, stream); | ||||||
|             break; |             break; | ||||||
|         case  24: |         case  24: | ||||||
|             launch_mul_mat_q<type,  24, 8>(args, stream); |             launch_mul_mat_q<type,  24, 4>(args, stream); | ||||||
|             break; |             break; | ||||||
|         case  32: |         case  32: | ||||||
|             launch_mul_mat_q<type,  32, 8>(args, stream); |             launch_mul_mat_q<type,  32, 8>(args, stream); | ||||||
|   | |||||||
Author: Johannes Gäßler