@@ -2,6 +2,7 @@
 
 #include "common.cuh"
 #include "vecdotq.cuh"
+#include "mma.cuh"
 
 #include <climits>
 #include <cstdint>
@@ -14,6 +15,7 @@ typedef void (*load_tiles_mmq_t)(
 typedef void (*vec_dot_mmq_t)(
     const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
     const int * __restrict__ y, float * __restrict__ sum, const int & k0);
+typedef void (*mmq_write_back_t)(const float * __restrict__ sum, float * __restrict__ dst, const int & ne0, const int & ne1);
 
 struct block_q8_1_mmq {
     half2  ds[4];
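Note: the new mmq_write_back_t typedef is what lets the tensor-core path added later in this diff swap in its own epilogue: the kernel keeps a single write_back(...) call site and the type-traits struct picks the target at compile time. A minimal standalone sketch of that dispatch pattern, with hypothetical names (write_back_a/write_back_b, traits_sketch) standing in for the real ones:

    // Sketch: constexpr function-pointer dispatch; the call resolves at compile time.
    using write_back_t = void (*)(const float * sum, float * dst, int ne0, int ne1);

    static void write_back_a(const float *, float *, int, int) { /* dp4a-style sum layout */ }
    static void write_back_b(const float *, float *, int, int) { /* mma-style sum layout  */ }

    template <bool use_mma>
    struct traits_sketch {
        static constexpr write_back_t write_back = use_mma ? write_back_b : write_back_a;
    };

    template <bool use_mma>
    void epilogue(const float * sum, float * dst, int ne0, int ne1) {
        constexpr write_back_t write_back = traits_sketch<use_mma>::write_back;
        write_back(sum, dst, ne0, ne1); // fixed target per instantiation, no runtime indirection
    }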
@@ -141,13 +143,13 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__
 }
 
 template <int mmq_x, int mmq_y, int nwarps>
-static __device__ __forceinline__ void vec_dot_q4_0_q8_1_mul_mat(
+static __device__ __forceinline__ void vec_dot_q4_0_q8_1_dp4a(
     const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
     const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
 
     GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
 
-    const float * x_dmf = (const float *) x_dm;
+    const float * x_df = (const float *) x_dm;
     const int   * y_qs = (const int   *) y + 4;
     const half2 * y_ds = (const half2 *) y;
 
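Note: the pointer setup above encodes the layout of block_q8_1_mmq: the per-tile scales (ds, four half2 values, 16 bytes) sit in front of the quantized values, so `(const int *) y + 4` skips exactly sizeof(ds) and lands on the first quant, while `(const half2 *) y` aliases the scales. A host-side layout check under that assumption (half2 modeled as two 16-bit halves; the qs count is illustrative only):

    #include <cstddef>
    #include <cstdint>

    struct half2_sketch { uint16_t x, y; };     // stand-in for CUDA's 4-byte half2

    struct block_q8_1_mmq_sketch {
        half2_sketch ds[4];                     // scale d and sum s per group of values
        int8_t       qs[128];                   // quantized values (count illustrative)
    };

    static_assert(offsetof(block_q8_1_mmq_sketch, qs) == 4*sizeof(int),
                  "skipping 4 ints from the tile start lands on the quants");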
@@ -170,12 +172,76 @@ static __device__ __forceinline__ void vec_dot_q4_0_q8_1_mul_mat(
             }
 
             sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q4_0_q8_1_impl<VDR_Q4_0_Q8_1_MMQ>
-                (&x_ql[i*(WARP_SIZE + 1) + k0], u, x_dmf[i*(WARP_SIZE/QI4_0) + i/QI4_0 + k0/QI4_0],
+                (&x_ql[i*(WARP_SIZE + 1) + k0], u, x_df[i*(WARP_SIZE/QI4_0) + i/QI4_0 + k0/QI4_0],
                 y_ds[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)]);
         }
     }
 }
 
+template <int mmq_x, int mmq_y, int nwarps>
+static __device__ __forceinline__ void vec_dot_q4_0_q8_1_mma(
+    const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
+    const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
+
+    GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
+
+    typedef mma_int_A_I16K8 mma_A;
+    typedef mma_int_B_J8K8  mma_B;
+    typedef mma_int_C_I16J8 mma_C;
+
+    const float * x_df = (const float *) x_dm;
+    const int   * y_qs = (const int   *) y + 4;
+    const half2 * y_ds = (const half2 *) y;
+
+    mma_A A;
+    float dA[mma_C::ne/2];
+
+    const int i0 = threadIdx.y*mma_A::I;
+    static_assert(nwarps*mma_A::I == mmq_y, "nwarps*mma_A::I != mmq_y");
+
+#pragma unroll
+    for (int l = 0; l < mma_A::ne; ++l) {
+        const int i     = i0 + mma_A::get_i(l);
+        const int k     = k0 + mma_A::get_k(l) % QI4_0;
+        const int shift =   4*(mma_A::get_k(l) / QI4_0);
+
+        A.x[l] = __vsubss4((x_ql[i*(WARP_SIZE + 1) + k] >> shift) & 0x0F0F0F0F, 0x08080808);
+    }
+#pragma unroll
+    for (int l = 0; l < mma_C::ne/2; ++l) {
+        const int i = i0 + mma_C::get_i(2*l);
+
+        dA[l] = x_df[i*(WARP_SIZE/QI4_0) + i/QI4_0 + k0/QI4_0];
+    }
+
+    for (int j0 = 0; j0 < mmq_x; j0 += mma_int_B_J8K8::J) {
+        mma_C C;
+        mma_B B;
+        half2 dsB[mma_C::ne/2];
+
+#pragma unroll
+        for (int l = 0; l < mma_B::ne; ++l) {
+            const int j =    j0 + mma_B::get_j(l);
+            const int k = (2*k0 + mma_B::get_k(l)) % WARP_SIZE;
+
+            B.x[l] = y_qs[j*MMQ_TILE_Y_K + k];
+        }
+#pragma unroll
+        for (int l = 0; l < mma_C::ne/2; ++l) {
+            const int j = j0 + mma_C::get_j(l);
+
+            dsB[l] = y_ds[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)];
+        }
+
+        C.mma_K8(A, B);
+
+#pragma unroll
+        for (int l = 0; l < mma_C::ne; ++l) {
+            sum[(j0/B.J)*C.ne + l] += dA[l/2]*__low2float(dsB[l%2])*C.x[l];
+        }
+    }
+}
+
 template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q4_1(
     const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
     int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
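Note: in the new MMA path each thread fills its A fragment by hand. A q4_0 int packs eight 4-bit quants two per byte; `>> shift` with shift 0 or 4 selects the low or high nibble of all four bytes at once, the 0x0F0F0F0F mask isolates them, and __vsubss4 subtracts 8 from each byte (saturating, though [0, 15] - 8 can never saturate) to recenter the quants to [-8, 7]. A host-side illustration of the same SIMD-in-a-word unpack:

    #include <cstdint>
    #include <cstdio>

    int main() {
        const uint32_t q = 0xF0A53C81u;               // four bytes = eight 4-bit quants
        for (int shift = 0; shift <= 4; shift += 4) { // 0: low nibbles, 4: high nibbles
            const uint32_t nib = (q >> shift) & 0x0F0F0F0F;
            for (int b = 0; b < 4; ++b) {
                // per-byte recentering, i.e. what __vsubss4(nib, 0x08080808) computes
                printf("shift=%d byte=%d quant=%d\n", shift, b, (int)((nib >> 8*b) & 0xFF) - 8);
            }
        }
    }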
@@ -215,7 +281,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__
 }
 
 template <int mmq_x, int mmq_y, int nwarps>
-static __device__ __forceinline__ void vec_dot_q4_1_q8_1_mul_mat(
+static __device__ __forceinline__ void vec_dot_q4_1_q8_1_dp4a(
     const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
     const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
 
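Note: the `_dp4a` suffix these functions are renamed to names the instruction the non-tensor-core path reduces with: __dp4a is a 4-way signed byte dot product with accumulate. A host-side reference of those semantics:

    #include <cstdint>
    #include <cstdio>

    // Reference behavior of CUDA's __dp4a(a, b, c) for signed 8-bit operands.
    static int dp4a_ref(int a, int b, int c) {
        for (int k = 0; k < 4; ++k) {
            c += (int8_t)(a >> 8*k) * (int8_t)(b >> 8*k); // per-byte signed product
        }
        return c;
    }

    int main() {
        printf("%d\n", dp4a_ref(0x01020304, 0x01010101, 0)); // prints 10 = 1+2+3+4
    }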
@@ -249,6 +315,70 @@ static __device__ __forceinline__ void vec_dot_q4_1_q8_1_mul_mat(
     }
 }
 
+template <int mmq_x, int mmq_y, int nwarps>
+static __device__ __forceinline__ void vec_dot_q4_1_q8_1_mma(
+    const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
+    const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
+
+    GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
+
+    typedef mma_int_A_I16K8 mma_A;
+    typedef mma_int_B_J8K8  mma_B;
+    typedef mma_int_C_I16J8 mma_C;
+
+    const int   * y_qs = (const int   *) y + 4;
+    const half2 * y_ds = (const half2 *) y;
+
+    mma_A A;
+    half2 dmA[mma_C::ne/2];
+
+    const int i0 = threadIdx.y*mma_A::I;
+    static_assert(nwarps*mma_A::I == mmq_y, "nwarps*mma_A::I != mmq_y");
+
+#pragma unroll
+    for (int l = 0; l < mma_A::ne; ++l) {
+        const int i     = i0 + mma_A::get_i(l);
+        const int k     = k0 + mma_A::get_k(l) % QI4_0;
+        const int shift =   4*(mma_A::get_k(l) / QI4_0);
+
+        A.x[l] = (x_ql[i*(WARP_SIZE + 1) + k] >> shift) & 0x0F0F0F0F;
+    }
+#pragma unroll
+    for (int l = 0; l < mma_C::ne/2; ++l) {
+        const int i = i0 + mma_C::get_i(2*l);
+
+        dmA[l] = x_dm[i*(WARP_SIZE/QI4_0) + i/QI4_0 + k0/QI4_0];
+    }
+
+    for (int j0 = 0; j0 < mmq_x; j0 += mma_int_B_J8K8::J) {
+        mma_C C;
+        mma_B B;
+        half2 dsB[mma_C::ne/2];
+
+#pragma unroll
+        for (int l = 0; l < mma_B::ne; ++l) {
+            const int j =    j0 + mma_B::get_j(l);
+            const int k = (2*k0 + mma_B::get_k(l)) % WARP_SIZE;
+
+            B.x[l] = y_qs[j*MMQ_TILE_Y_K + k];
+        }
+#pragma unroll
+        for (int l = 0; l < mma_C::ne/2; ++l) {
+            const int j = j0 + mma_C::get_j(l);
+
+            dsB[l] = y_ds[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)];
+        }
+
+        C.mma_K8(A, B);
+
+#pragma unroll
+        for (int l = 0; l < mma_C::ne; ++l) {
+            const half2 dmA_dsB = dmA[l/2]*dsB[l%2];
+            sum[(j0/B.J)*C.ne + l] += __low2float(dmA_dsB)*C.x[l] + __high2float(dmA_dsB);
+        }
+    }
+}
+
 template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q5_0(
     const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
     int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
@@ -308,7 +438,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__
 }
 
 template <int mmq_x, int mmq_y, int nwarps>
-static __device__ __forceinline__ void vec_dot_q5_0_q8_1_mul_mat(
+static __device__ __forceinline__ void vec_dot_q5_0_q8_1_dp4a(
     const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
     const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
 
@@ -343,6 +473,68 @@ static __device__ __forceinline__ void vec_dot_q5_0_q8_1_mul_mat(
     }
 }
 
+template <int mmq_x, int mmq_y, int nwarps>
+static __device__ __forceinline__ void vec_dot_q5_0_q8_1_mma(
+    const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
+    const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
+
+    GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
+
+    typedef mma_int_A_I16K8 mma_A;
+    typedef mma_int_B_J8K8  mma_B;
+    typedef mma_int_C_I16J8 mma_C;
+
+    const float * x_df = (const float *) x_dm;
+    const int   * y_qs = (const int   *) y + 4;
+    const float * y_df = (const float *) y;
+
+    mma_A A;
+    float dA[mma_C::ne/2];
+
+    const int i0 = threadIdx.y*mma_A::I;
+    static_assert(nwarps*mma_A::I == mmq_y, "nwarps*mma_A::I != mmq_y");
+
+#pragma unroll
+    for (int l = 0; l < mma_A::ne; ++l) {
+        const int i     =    i0 + mma_A::get_i(l);
+        const int k     = 2*(k0 + mma_A::get_k(l) % QI5_0) + mma_A::get_k(l) / QI5_0;
+
+        A.x[l] = x_ql[i*(2*WARP_SIZE + 1) + k];
+    }
+#pragma unroll
+    for (int l = 0; l < mma_C::ne/2; ++l) {
+        const int i = i0 + mma_C::get_i(2*l);
+
+        dA[l] = x_df[i*(WARP_SIZE/QI5_0) + i/QI5_0 + k0/QI5_0];
+    }
+
+    for (int j0 = 0; j0 < mmq_x; j0 += mma_int_B_J8K8::J) {
+        mma_C C;
+        mma_B B;
+        float dB[mma_C::ne/2];
+
+#pragma unroll
+        for (int l = 0; l < mma_B::ne; ++l) {
+            const int j =    j0 + mma_B::get_j(l);
+            const int k = (2*k0 + mma_B::get_k(l)) % WARP_SIZE;
+
+            B.x[l] = y_qs[j*MMQ_TILE_Y_K + k];
+        }
+#pragma unroll
+        for (int l = 0; l < mma_C::ne/2; ++l) {
+            const int j = j0 + mma_C::get_j(l);
+
+            dB[l] = y_df[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)];
+        }
+
+        C.mma_K8(A, B);
+
+#pragma unroll
+        for (int l = 0; l < mma_C::ne; ++l) {
+            sum[(j0/B.J)*C.ne + l] += dA[l/2]*dB[l%2]*C.x[l];
+        }
+    }
+}
+
 template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q5_1(
     const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
     int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
@@ -400,7 +592,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__
 }
 
 template <int mmq_x, int mmq_y, int nwarps>
-static __device__ __forceinline__ void vec_dot_q5_1_q8_1_mul_mat(
+static __device__ __forceinline__ void vec_dot_q5_1_q8_1_dp4a(
     const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
     const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
 
@@ -434,6 +626,69 @@ static __device__ __forceinline__ void vec_dot_q5_1_q8_1_mul_mat(
     }
 }
 
+template <int mmq_x, int mmq_y, int nwarps>
+static __device__ __forceinline__ void vec_dot_q5_1_q8_1_mma(
+    const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
+    const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
+
+    GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
+
+    typedef mma_int_A_I16K8 mma_A;
+    typedef mma_int_B_J8K8  mma_B;
+    typedef mma_int_C_I16J8 mma_C;
+
+    const int   * y_qs = (const int   *) y + 4;
+    const half2 * y_ds = (const half2 *) y;
+
+    mma_A A;
+    half2 dmA[mma_C::ne/2];
+
+    const int i0 = threadIdx.y*mma_A::I;
+    static_assert(nwarps*mma_A::I == mmq_y, "nwarps*mma_A::I != mmq_y");
+
+#pragma unroll
+    for (int l = 0; l < mma_A::ne; ++l) {
+        const int i     =    i0 + mma_A::get_i(l);
+        const int k     = 2*(k0 + mma_A::get_k(l) % QI5_1) + mma_A::get_k(l) / QI5_1;
+
+        A.x[l] = x_ql[i*(2*WARP_SIZE + 1) + k];
+    }
+#pragma unroll
+    for (int l = 0; l < mma_C::ne/2; ++l) {
+        const int i = i0 + mma_C::get_i(2*l);
+
+        dmA[l] = x_dm[i*(WARP_SIZE/QI5_1) + i/QI5_1 + k0/QI5_1];
+    }
+
+    for (int j0 = 0; j0 < mmq_x; j0 += mma_int_B_J8K8::J) {
+        mma_C C;
+        mma_B B;
+        half2 dsB[mma_C::ne/2];
+
+#pragma unroll
+        for (int l = 0; l < mma_B::ne; ++l) {
+            const int j =    j0 + mma_B::get_j(l);
+            const int k = (2*k0 + mma_B::get_k(l)) % WARP_SIZE;
+
+            B.x[l] = y_qs[j*MMQ_TILE_Y_K + k];
+        }
+#pragma unroll
+        for (int l = 0; l < mma_C::ne/2; ++l) {
+            const int j = j0 + mma_C::get_j(l);
+
+            dsB[l] = y_ds[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)];
+        }
+
+        C.mma_K8(A, B);
+
+#pragma unroll
+        for (int l = 0; l < mma_C::ne; ++l) {
+            const half2 dmA_dsB = dmA[l/2]*dsB[l%2];
+            sum[(j0/B.J)*C.ne + l] += __low2float(dmA_dsB)*C.x[l] + __high2float(dmA_dsB);
+        }
+    }
+}
+
 template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q8_0(
     const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
     int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
@@ -475,7 +730,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__
 }
 
 template <int mmq_x, int mmq_y, int nwarps>
-static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mul_mat(
+static __device__ __forceinline__ void vec_dot_q8_0_q8_1_dp4a(
     const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
     const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
 
@@ -500,6 +755,69 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mul_mat(
     }
 }
 
+template <int mmq_x, int mmq_y, int nwarps>
+static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
+    const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
+    const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
+
+    GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
+
+    typedef mma_int_A_I16K8 mma_A;
+    typedef mma_int_B_J8K8  mma_B;
+    typedef mma_int_C_I16J8 mma_C;
+
+    const float * x_df = (const float *) x_dm;
+    const int   * y_qs = (const int   *) y + 4;
+    const float * y_df = (const float *) y;
+
+    mma_A A;
+    float dA[mma_C::ne/2];
+
+    const int i0 = threadIdx.y*mma_A::I;
+    static_assert(nwarps*mma_A::I == mmq_y, "nwarps*mma_A::I != mmq_y");
+
+#pragma unroll
+    for (int l = 0; l < mma_A::ne; ++l) {
+        const int i = i0 + mma_A::get_i(l);
+        const int k = k0 + mma_A::get_k(l);
+
+        A.x[l] = x_ql[i*(WARP_SIZE + 1) + k];
+    }
+#pragma unroll
+    for (int l = 0; l < mma_C::ne/2; ++l) {
+        const int i = i0 + mma_C::get_i(2*l);
+
+        dA[l] = x_df[i*(WARP_SIZE/QI8_0) + i/QI8_0 + k0/QI8_0];
+    }
+
+    for (int j0 = 0; j0 < mmq_x; j0 += mma_int_B_J8K8::J) {
+        mma_C C;
+        mma_B B;
+        float dB[mma_C::ne/2];
+
+#pragma unroll
+        for (int l = 0; l < mma_B::ne; ++l) {
+            const int j = j0 + mma_B::get_j(l);
+            const int k = k0 + mma_B::get_k(l);
+
+            B.x[l] = y_qs[j*MMQ_TILE_Y_K + k];
+        }
+#pragma unroll
+        for (int l = 0; l < mma_C::ne/2; ++l) {
+            const int j = j0 + mma_C::get_j(l);
+
+            dB[l] = y_df[j*MMQ_TILE_Y_K + k0/QI8_1];
+        }
+
+        C.mma_K8(A, B);
+
+#pragma unroll
+        for (int l = 0; l < mma_C::ne; ++l) {
+            sum[(j0/B.J)*C.ne + l] += C.x[l]*dA[l/2]*dB[l%2];
+        }
+    }
+}
+
 template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q2_K(
     const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
     int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
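Note: q8_0 is the simplest of the five MMA variants: its quants are already signed bytes, so A.x is loaded directly with k = k0 + get_k(l), and both scale factors are plain floats (q8_0 stores no per-block minimum). The final loop is the usual blockwise factorization: C.x[l] holds the raw integer dot product for one (i, j) cell of the 16x8 tile, and

    dst[i][j] += dA[i] * dB[j] * C.x[l]    // C.x[l] = sum over k of qA[i][k]*qB[j][k]

recovers the dequantized result.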
@@ -989,6 +1307,57 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mul_mat(
     }
 }
 
+template<int mmq_x, int mmq_y, int nwarps, bool need_check>
+static __device__ __forceinline__ void mmq_write_back_dp4a(const float * __restrict__ sum, float * __restrict__ dst, const int & ne0, const int & ne1) {
+#pragma unroll
+    for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
+        const int j = blockIdx.y*mmq_x + j0 + threadIdx.y;
+
+        if (j >= ne1) {
+            return;
+        }
+
+#pragma unroll
+        for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
+            const int i = blockIdx.x*mmq_y + i0 + threadIdx.x;
+
+            if (need_check && i >= ne0) {
+                continue;
+            }
+
+            dst[j*ne0 + i] = sum[(j0/nwarps) * (mmq_y/WARP_SIZE) + i0/WARP_SIZE];
+        }
+    }
+}
+
+template<int mmq_x, int mmq_y, int nwarps, bool need_check>
+static __device__ __forceinline__ void mmq_write_back_mma(const float * __restrict__ sum, float * __restrict__ dst, const int & ne0, const int & ne1) {
+    typedef mma_int_C_I16J8 mma_C;
+
+    const int i0 = threadIdx.y*mma_C::I;
+    static_assert(nwarps*mma_C::I == mmq_y, "nwarps*mma_C::I != mmq_y");
+
+#pragma unroll
+    for (int j0 = 0; j0 < mmq_x; j0 += mma_C::J) {
+#pragma unroll
+        for (int l = 0; l < mma_C::ne; ++l) {
+            const int j = blockIdx.y*mmq_x + j0 + mma_C::get_j(l);
+
+            if (j >= ne1) {
+                continue;
+            }
+
+            const int i = blockIdx.x*mmq_y + i0 + mma_C::get_i(l);
+
+            if (need_check && i >= ne0) {
+                continue;
+            }
+
+            dst[j*ne0 + i] = sum[(j0/mma_C::J)*mma_C::ne + l];
+        }
+    }
+}
+
 // -------------------------------------------------------------------------------------------------------------------------------------
 
 template <int mmq_x, int mmq_y, int nwarps, bool need_check, ggml_type type>
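Note: mmq_write_back_mma has to undo the tensor-core accumulator layout: the 16x8 C tile is spread over the warp so that each of the 32 threads owns mma_C::ne = 4 values, and get_i/get_j recover which (i, j) cell a given register holds. A host-side sketch of the standard PTX m16n8 accumulator mapping, which is the layout mma.cuh is assumed to implement:

    #include <cstdio>

    int main() {
        const int I = 16, J = 8, ne = 4;                // 16*8 accumulators over 32 lanes
        for (int lane = 0; lane < 32; ++lane) {
            for (int l = 0; l < ne; ++l) {
                const int i = (l/2)*(I/2) + lane/(J/2); // row: 4 lanes per row, l/2 picks the tile half
                const int j = 2*(lane % (J/2)) + l%2;   // column: lane%4 picks a pair, l%2 within it
                printf("lane=%2d l=%d -> C[%2d][%d]\n", lane, l, i, j);
            }
        }
    }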
@@ -998,35 +1367,65 @@ template <int mmq_x, int mmq_y, int nwarps, bool need_check>
 struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q4_0> {
     static constexpr int              vdr        = VDR_Q4_0_Q8_1_MMQ;
     static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_0<mmq_y, nwarps, need_check>;
-    static constexpr vec_dot_mmq_t    vec_dot    = vec_dot_q4_0_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
+#ifdef INT8_MMA_AVAILABLE
+    static constexpr vec_dot_mmq_t    vec_dot    = vec_dot_q4_0_q8_1_mma<mmq_x, mmq_y, nwarps>;
+    static constexpr mmq_write_back_t write_back = mmq_write_back_mma<mmq_x, mmq_y, nwarps, need_check>;
+#else
+    static constexpr vec_dot_mmq_t    vec_dot    = vec_dot_q4_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
+    static constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, nwarps, need_check>;
+#endif // INT8_MMA_AVAILABLE
 };
 
 template <int mmq_x, int mmq_y, int nwarps, bool need_check>
 struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q4_1> {
     static constexpr int              vdr        = VDR_Q4_1_Q8_1_MMQ;
     static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_1<mmq_y, nwarps, need_check>;
-    static constexpr vec_dot_mmq_t    vec_dot    = vec_dot_q4_1_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
+#ifdef INT8_MMA_AVAILABLE
+    static constexpr vec_dot_mmq_t    vec_dot    = vec_dot_q4_1_q8_1_mma<mmq_x, mmq_y, nwarps>;
+    static constexpr mmq_write_back_t write_back = mmq_write_back_mma<mmq_x, mmq_y, nwarps, need_check>;
+#else
+    static constexpr vec_dot_mmq_t    vec_dot    = vec_dot_q4_1_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
+    static constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, nwarps, need_check>;
+#endif // INT8_MMA_AVAILABLE
 };
 
 template <int mmq_x, int mmq_y, int nwarps, bool need_check>
 struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q5_0> {
     static constexpr int              vdr        = VDR_Q5_0_Q8_1_MMQ;
     static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_0<mmq_y, nwarps, need_check>;
-    static constexpr vec_dot_mmq_t    vec_dot    = vec_dot_q5_0_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
+#ifdef INT8_MMA_AVAILABLE
+    static constexpr vec_dot_mmq_t    vec_dot    = vec_dot_q5_0_q8_1_mma<mmq_x, mmq_y, nwarps>;
+    static constexpr mmq_write_back_t write_back = mmq_write_back_mma<mmq_x, mmq_y, nwarps, need_check>;
+#else
+    static constexpr vec_dot_mmq_t    vec_dot    = vec_dot_q5_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
+    static constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, nwarps, need_check>;
+#endif // INT8_MMA_AVAILABLE
 };
 
 template <int mmq_x, int mmq_y, int nwarps, bool need_check>
 struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q5_1> {
     static constexpr int              vdr        = VDR_Q5_1_Q8_1_MMQ;
     static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_1<mmq_y, nwarps, need_check>;
-    static constexpr vec_dot_mmq_t    vec_dot    = vec_dot_q5_1_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
+#ifdef INT8_MMA_AVAILABLE
+    static constexpr vec_dot_mmq_t    vec_dot    = vec_dot_q5_1_q8_1_mma<mmq_x, mmq_y, nwarps>;
+    static constexpr mmq_write_back_t write_back = mmq_write_back_mma<mmq_x, mmq_y, nwarps, need_check>;
+#else
+    static constexpr vec_dot_mmq_t    vec_dot    = vec_dot_q5_1_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
+    static constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, nwarps, need_check>;
+#endif // INT8_MMA_AVAILABLE
 };
 
 template <int mmq_x, int mmq_y, int nwarps, bool need_check>
 struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q8_0> {
     static constexpr int              vdr        = VDR_Q8_0_Q8_1_MMQ;
     static constexpr load_tiles_mmq_t load_tiles = load_tiles_q8_0<mmq_y, nwarps, need_check>;
-    static constexpr vec_dot_mmq_t    vec_dot    = vec_dot_q8_0_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
+#ifdef INT8_MMA_AVAILABLE
+    static constexpr vec_dot_mmq_t    vec_dot    = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, nwarps>;
+    static constexpr mmq_write_back_t write_back = mmq_write_back_mma<mmq_x, mmq_y, nwarps, need_check>;
++#else
+    static constexpr vec_dot_mmq_t    vec_dot    = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
+    static constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, nwarps, need_check>;
+#endif // INT8_MMA_AVAILABLE
 };
 
 template <int mmq_x, int mmq_y, int nwarps, bool need_check>
@@ -1034,6 +1433,7 @@ struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q2_K> {
     static constexpr int              vdr        = VDR_Q2_K_Q8_1_MMQ;
     static constexpr load_tiles_mmq_t load_tiles = load_tiles_q2_K<mmq_y, nwarps, need_check>;
     static constexpr vec_dot_mmq_t    vec_dot    = vec_dot_q2_K_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
+    static constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, nwarps, need_check>;
 };
 
 template <int mmq_x, int mmq_y, int nwarps, bool need_check>
@@ -1041,6 +1441,7 @@ struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q3_K> {
     static constexpr int              vdr        = VDR_Q3_K_Q8_1_MMQ;
     static constexpr load_tiles_mmq_t load_tiles = load_tiles_q3_K<mmq_y, nwarps, need_check>;
     static constexpr vec_dot_mmq_t    vec_dot    = vec_dot_q3_K_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
+    static constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, nwarps, need_check>;
 };
 
 template <int mmq_x, int mmq_y, int nwarps, bool need_check>
@@ -1048,6 +1449,7 @@ struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q4_K> {
     static constexpr int              vdr        = VDR_Q4_K_Q8_1_MMQ;
     static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_K<mmq_y, nwarps, need_check>;
     static constexpr vec_dot_mmq_t    vec_dot    = vec_dot_q4_K_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
+    static constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, nwarps, need_check>;
 };
 
 template <int mmq_x, int mmq_y, int nwarps, bool need_check>
@@ -1055,6 +1457,7 @@ struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q5_K> {
     static constexpr int              vdr        = VDR_Q5_K_Q8_1_MMQ;
     static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_K<mmq_y, nwarps, need_check>;
     static constexpr vec_dot_mmq_t    vec_dot    = vec_dot_q5_K_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
+    static constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, nwarps, need_check>;
 };
 
 template <int mmq_x, int mmq_y, int nwarps, bool need_check>
@@ -1062,6 +1465,7 @@ struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q6_K> {
     static constexpr int              vdr        = VDR_Q6_K_Q8_1_MMQ;
     static constexpr load_tiles_mmq_t load_tiles = load_tiles_q6_K<mmq_y, nwarps, need_check>;
     static constexpr vec_dot_mmq_t    vec_dot    = vec_dot_q6_K_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
+    static constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, nwarps, need_check>;
 };
 
 static int mmq_need_sum(const ggml_type type_x) {
@@ -1118,6 +1522,7 @@ static __global__ void mul_mat_q(
     constexpr int              vdr        = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::vdr;
     constexpr load_tiles_mmq_t load_tiles = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::load_tiles;
     constexpr vec_dot_mmq_t    vec_dot    = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::vec_dot;
+    constexpr mmq_write_back_t write_back = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::write_back;
 
     constexpr tile_x_sizes txs = get_tile_x_sizes_device<mmq_y>(type);
 
@@ -1137,7 +1542,7 @@ static __global__ void mul_mat_q(
 
     const int * y = (const int *) yc + blockIdx.y*(mmq_x*sizeof(block_q8_1_mmq)/sizeof(int));
 
-    float sum[(mmq_x/nwarps) * (mmq_y/WARP_SIZE)] = {0.0f};
+    float sum[mmq_x*mmq_y / (nwarps*WARP_SIZE)] = {0.0f};
 
     for (int kb0 = 0; kb0 < blocks_per_row_x; kb0 += blocks_per_warp) {
 
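Note: the rewritten accumulator size is the same number whenever the old divisions were exact, e.g. with mmq_x = 16, mmq_y = 64, nwarps = 4, WARP_SIZE = 32:

    (mmq_x/nwarps) * (mmq_y/WARP_SIZE) = (16/4) * (64/32) = 4 * 2    = 8
    mmq_x*mmq_y / (nwarps*WARP_SIZE)   = (16*64) / (4*32) = 1024/128 = 8

The product-first form reads directly as tile elements per thread (mmq_x*mmq_y values spread over nwarps*WARP_SIZE threads), which is the count both mmq_write_back_dp4a and mmq_write_back_mma index into.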
@@ -1164,25 +1569,7 @@ static __global__ void mul_mat_q(
         }
     }
 
-#pragma unroll
-    for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
-        const int j = blockIdx.y*mmq_x + j0 + threadIdx.y;
-
-        if (j >= ne1) {
-            return;
-        }
-
-#pragma unroll
-        for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
-            const int i = blockIdx.x*mmq_y + i0 + threadIdx.x;
-
-            if (need_check && i >= ne0) {
-                continue;
-            }
-
-            dst[j*ne0 + i] = sum[(j0/nwarps) * (mmq_y/WARP_SIZE) + i0/WARP_SIZE];
-        }
-    }
+    write_back(sum, dst, ne0, ne1);
 }
 
 struct mmq_args {
@@ -1256,10 +1643,10 @@ void mul_mat_q_case(const mmq_args & args, cudaStream_t stream) {
             launch_mul_mat_q<type,   8, 4>(args, stream);
             break;
         case  16:
-            launch_mul_mat_q<type,  16, 8>(args, stream);
+            launch_mul_mat_q<type,  16, 4>(args, stream);
             break;
         case  24:
-            launch_mul_mat_q<type,  24, 8>(args, stream);
+            launch_mul_mat_q<type,  24, 4>(args, stream);
             break;
         case  32:
             launch_mul_mat_q<type,  32, 8>(args, stream);