#pragma once

#include "common.cuh"
#include "vecdotq.cuh"
#include "mma.cuh"

// NOTE(review): the two system header names were lost in this copy of the file.
// <cstdint> is required by the int8_t / int32_t uses below; <climits> is presumed — confirm.
#include <climits>
#include <cstdint>

using namespace ggml_cuda_mma;

#define MMQ_DP4A_MAX_BATCH_SIZE 64 // Max. batch size to use for dp4a MMQ kernels when FP16 tensor cores are available.
#define MMQ_ITER_K 256
#define MMQ_NWARPS 8

// Function pointer types for the three per-type stages of an MMQ kernel:
// loading a tile of quantized x data, accumulating dot products, and writing results back.
typedef void (*load_tiles_mmq_t)(const char * __restrict__ x, int * x_tile, const int kbx0, const int i_max, const int stride);
typedef void (*vec_dot_mmq_t)(const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00);
typedef void (*mmq_write_back_t)(const float * __restrict__ sum, const int32_t * __restrict__ get_rows_to_sorted,
    float * __restrict__ dst, const int stride, const int i_max, const int j_max);

// Selects which member of the block_q8_1_mmq union holds the scales / partial sums (see below).
enum mmq_q8_1_ds_layout {
    MMQ_Q8_1_DS_LAYOUT_D4,
    MMQ_Q8_1_DS_LAYOUT_DS4,
    MMQ_Q8_1_DS_LAYOUT_D2S6,
};

struct block_q8_1_mmq {
    // The y float data is converted to a data layout that can simply be copied to shared memory as a contiguous block.
    // The y float data is first grouped as blocks of 128 values.
    // These blocks are then treated as individual data values and transposed.
    //
    // To avoid shared memory bank conflicts each block is padded with 16 bytes.
    // This padding is also used to store block scales/partial sums.
    // The scales multiplied with the quantized data are equal to the unquantized values.
    // The partial sums are obtained by summing up a subgroup of the contained values (prior to quantization)
    //     and are only needed for performance reasons.
    //
    // The exact data stored depends on the x data type.
    union {
        float d4[4];    // 1 32 bit scale per 32 values, stored as d0,d1,d2,d3
        half2 ds4[4];   // 1 16 bit scale + 1 16 bit partial sum per 32 values, stored as d0,s0,d1,s1,d2,s2,d3,s3
        half  d2s6[8];  // 1 16 bit scale per 64 values + 1 16 bit partial sum per 16 values for the first 96 values,
                        //     stored as d0,d1,s1,s2,s3,s4,s5
                        // NOTE(review): the list above names 7 entries for an 8-element array;
                        //     confirm the exact order against the code that fills d2s6.
    };
    int8_t qs[4*QK8_1]; // 128 values quantized to 8 bit each
};
static_assert(sizeof(block_q8_1_mmq) == 4*QK8_1 + 4*sizeof(half2), "Unexpected block_q8_1_mmq size");
static_assert(sizeof(block_q8_1_mmq) == 4*sizeof(block_q8_1),      "Unexpected block_q8_1_mmq size");

// Returns the scale/partial-sum layout of block_q8_1_mmq that is used for x data of type type_x.
// Aborts via GGML_ABORT for any type that is not listed.
static mmq_q8_1_ds_layout mmq_get_q8_1_ds_layout(const ggml_type type_x) {
    switch (type_x) {
        case GGML_TYPE_Q4_0:
        case GGML_TYPE_Q4_1:
            return MMQ_Q8_1_DS_LAYOUT_DS4;
        case GGML_TYPE_Q5_0:
            return MMQ_Q8_1_DS_LAYOUT_D4;
        case GGML_TYPE_Q5_1:
            return MMQ_Q8_1_DS_LAYOUT_DS4;
        case GGML_TYPE_Q8_0:
            return MMQ_Q8_1_DS_LAYOUT_D4;
        case GGML_TYPE_MXFP4:
            return MMQ_Q8_1_DS_LAYOUT_D4;
        case GGML_TYPE_Q2_K:
            return MMQ_Q8_1_DS_LAYOUT_D2S6;
        case GGML_TYPE_Q3_K:
            return MMQ_Q8_1_DS_LAYOUT_D4;
        case GGML_TYPE_Q4_K:
        case GGML_TYPE_Q5_K:
            return MMQ_Q8_1_DS_LAYOUT_DS4;
        case GGML_TYPE_Q6_K:
        case GGML_TYPE_IQ2_XXS:
        case GGML_TYPE_IQ2_XS:
        case GGML_TYPE_IQ2_S:
        case GGML_TYPE_IQ3_XXS:
        case GGML_TYPE_IQ3_S:
            return MMQ_Q8_1_DS_LAYOUT_D4;
        case GGML_TYPE_IQ1_S:
            return MMQ_Q8_1_DS_LAYOUT_DS4;
        case GGML_TYPE_IQ4_XS:
        case GGML_TYPE_IQ4_NL:
            return MMQ_Q8_1_DS_LAYOUT_D4;
        default:
            GGML_ABORT("fatal error");
            break;
    }
}

// Sizes of the three parts of a dp4a x tile, in 32 bit elements:
// quantized values (qs), scales/minima (dm) and additional sub-block scales (sc).
struct tile_x_sizes {
    int qs;
    int dm;
    int sc;
};

// Host-side upper bound for the tile width mmq_x, given the compute capability cc.
// Must agree with get_mmq_x_max_device() for the architecture the kernel was compiled for.
static int get_mmq_x_max_host(const int cc) {
    return (amd_mfma_available(cc) || turing_mma_available(cc)) ? 128 :
        GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA ?
#ifdef GGML_CUDA_FORCE_MMQ
            128                     : 64;
#else
            MMQ_DP4A_MAX_BATCH_SIZE : 64;
#endif // GGML_CUDA_FORCE_MMQ
}

// Device-side (compile-time) counterpart of get_mmq_x_max_host().
static constexpr __device__ int get_mmq_x_max_device() {
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
    return 128;
#else // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)

#if defined(GGML_USE_HIP)
    return 64;
#else // defined(GGML_USE_HIP)

#if __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
#ifdef GGML_CUDA_FORCE_MMQ
    return 128;
#else // GGML_CUDA_FORCE_MMQ
    return MMQ_DP4A_MAX_BATCH_SIZE;
#endif // GGML_CUDA_FORCE_MMQ
#else // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
    return 64;
#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA

#endif // defined(GGML_USE_HIP)
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
}

// Host-side tile height mmq_y: 64 for RDNA1 and for NVIDIA builds below Volta, otherwise 128.
static int get_mmq_y_host(const int cc) {
    return GGML_CUDA_CC_IS_AMD(cc) ? (GGML_CUDA_CC_IS_RDNA1(cc) ? 64 : 128) :
        ((GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA) ? 128 : 64);
}

// Device-side (compile-time) counterpart of get_mmq_y_host().
static constexpr __device__ int get_mmq_y_device() {
#if defined(GGML_USE_HIP)
#if defined(RDNA1)
    return 64;
#else
    return 128;
#endif // defined RDNA1
#else
#if __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
    return 128;
#else
    return 64;
#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
#endif // defined(GGML_USE_HIP)
}

// Decouple shared memory tile sizes from WARP_SIZE to allow for different warp sizes.
// The K dimension of the tiles has either,
// 1*MMQ_TILE_NE_K==32 (always for TILE_Y_K) or 2*MMQ_TILE_NE_K==64 (typically for TILE_X_K),
// 32 bit elements for the quantized data (does not include scales).
// In other words, the size of the quantized data in the K dimension is a multiple of MMQ_TILE_NE_K.
// The final tile size in K direction is padded to avoid shared memory bank conflicts,
// in terms of 32 bit elements that means K % 2 == 1 for dp4a or K % 8 == 4 for mma.
#define MMQ_TILE_NE_K 32 #define MMQ_DP4A_TXS_Q4_0 tile_x_sizes{mmq_y*MMQ_TILE_NE_K + mmq_y, mmq_y*MMQ_TILE_NE_K/QI4_0 + mmq_y/QI4_0, 0} #define MMQ_DP4A_TXS_Q4_1 tile_x_sizes{mmq_y*MMQ_TILE_NE_K + mmq_y, mmq_y*MMQ_TILE_NE_K/QI4_1 + mmq_y/QI4_1, 0} #define MMQ_DP4A_TXS_Q8_0 tile_x_sizes{mmq_y*MMQ_TILE_NE_K*2 + mmq_y, mmq_y*MMQ_TILE_NE_K*2/QI8_0 + mmq_y/(QI8_0/2), 0} #define MMQ_DP4A_TXS_Q8_0_16 tile_x_sizes{mmq_y*MMQ_TILE_NE_K*2 + mmq_y, mmq_y*MMQ_TILE_NE_K*4/QI8_0 + mmq_y/(QI8_0/4), 0} #define MMQ_DP4A_TXS_Q8_1 tile_x_sizes{mmq_y*MMQ_TILE_NE_K*2 + mmq_y, mmq_y*MMQ_TILE_NE_K*2/QI8_1 + mmq_y/(QI8_1/2), 0} #define MMQ_DP4A_TXS_Q2_K tile_x_sizes{mmq_y*MMQ_TILE_NE_K*2 + mmq_y, mmq_y*MMQ_TILE_NE_K + mmq_y, 0} #define MMQ_DP4A_TXS_Q3_K tile_x_sizes{mmq_y*MMQ_TILE_NE_K*2 + mmq_y, mmq_y, mmq_y*MMQ_TILE_NE_K/8 + mmq_y/8} #define MMQ_DP4A_TXS_Q4_K tile_x_sizes{mmq_y*MMQ_TILE_NE_K + mmq_y, mmq_y*MMQ_TILE_NE_K/QI4_K, mmq_y*MMQ_TILE_NE_K/8 + mmq_y/8} #define MMQ_DP4A_TXS_Q5_K tile_x_sizes{mmq_y*MMQ_TILE_NE_K*2 + mmq_y, mmq_y*MMQ_TILE_NE_K/QI5_K + mmq_y/QI5_K, mmq_y*MMQ_TILE_NE_K/8 + mmq_y/8} #define MMQ_DP4A_TXS_Q6_K tile_x_sizes{mmq_y*MMQ_TILE_NE_K*2 + mmq_y, mmq_y*MMQ_TILE_NE_K/QI6_K + mmq_y/QI6_K, mmq_y*MMQ_TILE_NE_K/8 + mmq_y/8} static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml_type type, int mmq_y) { switch (type) { case GGML_TYPE_Q4_0: return MMQ_DP4A_TXS_Q4_0; case GGML_TYPE_Q4_1: return MMQ_DP4A_TXS_Q4_1; case GGML_TYPE_Q5_0: return MMQ_DP4A_TXS_Q8_0; case GGML_TYPE_Q5_1: return MMQ_DP4A_TXS_Q8_1; case GGML_TYPE_Q8_0: return MMQ_DP4A_TXS_Q8_0; case GGML_TYPE_MXFP4: return MMQ_DP4A_TXS_Q8_1; case GGML_TYPE_Q2_K: return MMQ_DP4A_TXS_Q2_K; case GGML_TYPE_Q3_K: return MMQ_DP4A_TXS_Q3_K; case GGML_TYPE_Q4_K: return MMQ_DP4A_TXS_Q4_K; case GGML_TYPE_Q5_K: return MMQ_DP4A_TXS_Q5_K; case GGML_TYPE_Q6_K: return MMQ_DP4A_TXS_Q6_K; case GGML_TYPE_IQ2_XXS: return MMQ_DP4A_TXS_Q8_0; case GGML_TYPE_IQ2_XS: return MMQ_DP4A_TXS_Q8_0_16; case 
GGML_TYPE_IQ2_S: return MMQ_DP4A_TXS_Q8_0_16; case GGML_TYPE_IQ3_XXS: return MMQ_DP4A_TXS_Q8_0; case GGML_TYPE_IQ3_S: return MMQ_DP4A_TXS_Q8_0; case GGML_TYPE_IQ1_S: return MMQ_DP4A_TXS_Q8_0; case GGML_TYPE_IQ4_XS: return MMQ_DP4A_TXS_Q8_0; case GGML_TYPE_IQ4_NL: return MMQ_DP4A_TXS_Q8_0; default: return tile_x_sizes{0, 0, 0}; } } #define MMQ_MMA_TILE_X_K_Q8_0 (2*MMQ_TILE_NE_K + 2*MMQ_TILE_NE_K/QI8_0 + 4) #define MMQ_MMA_TILE_X_K_Q8_1 (2*MMQ_TILE_NE_K + 2*MMQ_TILE_NE_K/QI8_0 + 4) #define MMQ_MMA_TILE_X_K_Q2_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K + 4) #define MMQ_MMA_TILE_X_K_Q3_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K/2 + 4) #define MMQ_MMA_TILE_X_K_Q6_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K/QI6_K + MMQ_TILE_NE_K/8 + 7) static_assert(MMQ_MMA_TILE_X_K_Q8_0 % 8 == 4, "Wrong padding."); static_assert(MMQ_MMA_TILE_X_K_Q8_1 % 8 == 4, "Wrong padding."); static_assert(MMQ_MMA_TILE_X_K_Q2_K % 8 == 4, "Wrong padding."); static_assert(MMQ_MMA_TILE_X_K_Q3_K % 8 == 4, "Wrong padding."); static_assert(MMQ_MMA_TILE_X_K_Q6_K % 8 == 4, "Wrong padding."); static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) { switch (type) { case GGML_TYPE_Q4_0: return MMQ_MMA_TILE_X_K_Q8_0; case GGML_TYPE_Q4_1: return MMQ_MMA_TILE_X_K_Q8_1; case GGML_TYPE_Q5_0: return MMQ_MMA_TILE_X_K_Q8_0; case GGML_TYPE_Q5_1: return MMQ_MMA_TILE_X_K_Q8_1; case GGML_TYPE_Q8_0: return MMQ_MMA_TILE_X_K_Q8_0; case GGML_TYPE_MXFP4: return MMQ_MMA_TILE_X_K_Q8_1; case GGML_TYPE_Q2_K: return MMQ_MMA_TILE_X_K_Q2_K; case GGML_TYPE_Q3_K: return MMQ_MMA_TILE_X_K_Q3_K; case GGML_TYPE_Q4_K: return MMQ_MMA_TILE_X_K_Q8_1; case GGML_TYPE_Q5_K: return MMQ_MMA_TILE_X_K_Q8_1; case GGML_TYPE_Q6_K: return MMQ_MMA_TILE_X_K_Q6_K; case GGML_TYPE_IQ2_XXS: return MMQ_MMA_TILE_X_K_Q8_0; case GGML_TYPE_IQ2_XS: return MMQ_MMA_TILE_X_K_Q3_K; case GGML_TYPE_IQ2_S: return MMQ_MMA_TILE_X_K_Q3_K; case GGML_TYPE_IQ3_XXS: return MMQ_MMA_TILE_X_K_Q8_0; case GGML_TYPE_IQ3_S: return MMQ_MMA_TILE_X_K_Q8_0; case GGML_TYPE_IQ1_S: return 
MMQ_MMA_TILE_X_K_Q8_0; case GGML_TYPE_IQ4_XS: return MMQ_MMA_TILE_X_K_Q8_0; case GGML_TYPE_IQ4_NL: return MMQ_MMA_TILE_X_K_Q8_0; default: return 0; } } // block_q8_1_mmq has (128 8-bit ints == 32 32-bit ints + 4 32-bit scales) #define MMQ_TILE_Y_K (MMQ_TILE_NE_K + MMQ_TILE_NE_K/QI8_1) static int mmq_get_granularity_host(const int mmq_x, const int cc) { if (amd_mfma_available(cc)) { return mmq_x >= 128 ? 32 : 16; } else if (turing_mma_available(cc) && mmq_x >= 48) { return 16; } else { return 8; } } #if defined(AMD_MFMA_AVAILABLE) static constexpr __device__ int mmq_get_granularity_device(const int mmq_x) { return mmq_x >= 128 ? 32 : 16; } #elif defined(TURING_MMA_AVAILABLE) static constexpr __device__ int mmq_get_granularity_device(const int mmq_x) { return mmq_x >= 48 ? 16 : 8; } #else static constexpr __device__ int mmq_get_granularity_device(const int /*mmq_x*/) { return 8; } #endif // AMD_MFMA_AVAILABLE #if defined(GGML_USE_HIP) static int mmq_get_nwarps_host(const int cc, const int warp_size) { return amd_mfma_available(cc) ? 
8 : 256/warp_size; } #else static int mmq_get_nwarps_host(const int /*cc*/, const int warp_size) { return 256/warp_size; } #endif // (GGML_USE_HIP) static constexpr __device__ int mmq_get_nwarps_device() { #if defined(AMD_MFMA_AVAILABLE) return 8; #else return 256/ggml_cuda_get_physical_warp_size(); #endif // AMD_MFMA_AVAILABLE } // ------------------------------------------------------------ template static __device__ __forceinline__ void load_tiles_q4_0( const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { constexpr int nwarps = mmq_get_nwarps_device(); constexpr int warp_size = ggml_cuda_get_physical_warp_size(); #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + 2*MMQ_TILE_NE_K); #else constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_0, mmq_y); int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + txs.qs); #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_0); constexpr int nrows = warp_size / threads_per_row; const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x; const int kbx = txi / QI4_0; const int kqsx = txi % QI4_0; #pragma unroll for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) { int i = i0 + (nrows == 1 ? 
threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row); if (need_check) { i = min(i, i_max); } const block_q4_0 * bxi = (const block_q4_0 *) x + kbx0 + i*stride + kbx; const int qs0 = get_int_b2(bxi->qs, kqsx); #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI4_0) + kqsx + 0] = __vsubss4((qs0 >> 0) & 0x0F0F0F0F, 0x08080808); x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI4_0) + kqsx + QI4_0] = __vsubss4((qs0 >> 4) & 0x0F0F0F0F, 0x08080808); #else x_qs[i*(MMQ_TILE_NE_K + 1) + txi] = qs0; #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) } constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI4_0; constexpr int rows_per_warp = warp_size / blocks_per_tile_x_row; const int kbxd = threadIdx.x % blocks_per_tile_x_row; #pragma unroll for (int i0 = 0; i0 < mmq_y; i0 += nwarps * rows_per_warp) { int i = i0 + threadIdx.y * rows_per_warp + threadIdx.x / blocks_per_tile_x_row; if (need_check) { i = min(i, i_max); } const block_q4_0 * bxi = (const block_q4_0 *) x + kbx0 + i*stride + kbxd; #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = bxi->d; #else x_df[i*(MMQ_TILE_NE_K/QI4_0) + i/QI4_0 + kbxd] = bxi->d; #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) } } template static __device__ __forceinline__ void vec_dot_q4_0_q8_1_dp4a( const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { constexpr int nwarps = mmq_get_nwarps_device(); constexpr int warp_size = ggml_cuda_get_physical_warp_size(); constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_0, mmq_y); const int * x_qs = (const int *) x; const float * x_df = (const float *) x_qs + txs.qs; const int * y_qs = (const int *) y + 4; const half2 * y_ds = (const half2 *) y; // #pragma unroll for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QR4_0*VDR_Q4_0_Q8_1_MMQ) { const int k0 = k00 + k01; #pragma unroll for (int j0 = 0; 
j0 < mmq_x; j0 += nwarps) { const int j = j0 + threadIdx.y; #pragma unroll for (int i0 = 0; i0 < mmq_y; i0 += warp_size) { const int i = i0 + threadIdx.x; const int kyqs = QI8_1 * ((k01/2) / (QI8_1/2)) + (k01/2) % (QI8_1/2); int u[2*VDR_Q4_0_Q8_1_MMQ]; #pragma unroll for (int l = 0; l < VDR_Q4_0_Q8_1_MMQ; ++l) { u[2*l+0] = y_qs[j*MMQ_TILE_Y_K + kyqs + l]; u[2*l+1] = y_qs[j*MMQ_TILE_Y_K + kyqs + (l + QI4_0)]; } sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q4_0_q8_1_impl (&x_qs[i*(MMQ_TILE_NE_K + 1) + k0/QR4_0], u, x_df[i*(MMQ_TILE_NE_K/QI4_0) + i/QI4_0 + k0/(QR4_0*QI4_0)], y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]); } } } } template static __device__ __forceinline__ void load_tiles_q4_1( const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { constexpr int nwarps = mmq_get_nwarps_device(); constexpr int warp_size = ggml_cuda_get_physical_warp_size(); #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) int * x_qs = (int *) x_tile; half2 * x_dm = (half2 *) (x_qs + 2*MMQ_TILE_NE_K); #else constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_1, mmq_y); int * x_qs = (int *) x_tile; half2 * x_dm = (half2 *) (x_qs + txs.qs); #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_1); constexpr int nrows = warp_size / threads_per_row; const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x; const int kbx = txi / QI4_1; const int kqsx = txi % QI4_1; #pragma unroll for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) { int i = i0 + (nrows == 1 ? 
threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row); if (need_check) { i = min(i, i_max); } const block_q4_1 * bxi = (const block_q4_1 *) x + kbx0 + i*stride + kbx; const int qs0 = get_int_b4(bxi->qs, kqsx); #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI4_1) + kqsx + 0] = (qs0 >> 0) & 0x0F0F0F0F; x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI4_1) + kqsx + QI4_1] = (qs0 >> 4) & 0x0F0F0F0F; #else x_qs[i*(MMQ_TILE_NE_K + 1) + txi] = qs0; #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) } constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI4_1; constexpr int rows_per_warp = warp_size / blocks_per_tile_x_row; const int kbxd = threadIdx.x % blocks_per_tile_x_row; #pragma unroll for (int i0 = 0; i0 < mmq_y; i0 += nwarps * rows_per_warp) { int i = i0 + threadIdx.y * rows_per_warp + threadIdx.x / blocks_per_tile_x_row; if (need_check) { i = min(i, i_max); } const block_q4_1 * bxi = (const block_q4_1 *) x + kbx0 + i*stride + kbxd; #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + kbxd] = bxi->dm; #else x_dm[i*(MMQ_TILE_NE_K/QI4_1) + i/QI4_1 + kbxd] = bxi->dm; #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) } } template static __device__ __forceinline__ void vec_dot_q4_1_q8_1_dp4a( const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { constexpr int nwarps = mmq_get_nwarps_device(); constexpr int warp_size = ggml_cuda_get_physical_warp_size(); constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_1, mmq_y); const int * x_qs = (const int *) x; const half2 * x_dm = (const half2 *) x_qs + txs.qs; const int * y_qs = (const int *) y + 4; const half2 * y_ds = (const half2 *) y; // #pragma unroll for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QR4_1*VDR_Q4_1_Q8_1_MMQ) { const int k0 = k00 + k01; #pragma unroll for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { const int j = j0 
+ threadIdx.y; #pragma unroll for (int i0 = 0; i0 < mmq_y; i0 += warp_size) { const int i = i0 + threadIdx.x; const int kyqs = QI8_1 * ((k01/2) / (QI8_1/2)) + (k01/2) % (QI8_1/2); int u[2*VDR_Q4_1_Q8_1_MMQ]; #pragma unroll for (int l = 0; l < VDR_Q4_1_Q8_1_MMQ; ++l) { u[2*l+0] = y_qs[j*MMQ_TILE_Y_K + kyqs + l]; u[2*l+1] = y_qs[j*MMQ_TILE_Y_K + kyqs + (l + QI4_1)]; } sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q4_1_q8_1_impl (&x_qs[i*(MMQ_TILE_NE_K + 1) + k0/QR4_1], u, x_dm[i*(MMQ_TILE_NE_K/QI4_1) + i/QI4_1 + k0/(QR4_1*QI4_1)], y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]); } } } } template static __device__ __forceinline__ void load_tiles_q5_0( const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { constexpr int nwarps = mmq_get_nwarps_device(); constexpr int warp_size = ggml_cuda_get_physical_warp_size(); #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2); #else constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_0, mmq_y); int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + txs.qs); #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) constexpr int threads_per_row = MMQ_ITER_K / (4 * QR5_0); constexpr int nrows = warp_size / threads_per_row; const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x; const int kbx = txi / QI5_0; const int kqsx = txi % QI5_0; #pragma unroll for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) { int i = i0 + (nrows == 1 ? 
threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row); if (need_check) { i = min(i, i_max); } const block_q5_0 * bxi = (const block_q5_0 *) x + kbx0 + i*stride + kbx; const int ql = get_int_b2(bxi->qs, kqsx); const int qh = get_int_b2(bxi->qh, 0) >> (4 * kqsx); int qs0 = (ql >> 0) & 0x0F0F0F0F; qs0 |= (qh << 4) & 0x00000010; // 0 -> 4 qs0 |= (qh << 11) & 0x00001000; // 1 -> 12 qs0 |= (qh << 18) & 0x00100000; // 2 -> 20 qs0 |= (qh << 25) & 0x10000000; // 3 -> 28 qs0 = __vsubss4(qs0, 0x10101010); // subtract 16 int qs1 = (ql >> 4) & 0x0F0F0F0F; qs1 |= (qh >> 12) & 0x00000010; // 16 -> 4 qs1 |= (qh >> 5) & 0x00001000; // 17 -> 12 qs1 |= (qh << 2) & 0x00100000; // 18 -> 20 qs1 |= (qh << 9) & 0x10000000; // 19 -> 28 qs1 = __vsubss4(qs1, 0x10101010); // subtract 16 #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI5_0) + kqsx + 0] = qs0; x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI5_0) + kqsx + QI5_0] = qs1; #else x_qs[i*(2*MMQ_TILE_NE_K + 1) + kbx*(2*QI5_0) + kqsx + 0] = qs0; x_qs[i*(2*MMQ_TILE_NE_K + 1) + kbx*(2*QI5_0) + kqsx + QI5_0] = qs1; #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) } constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI5_0; constexpr int rows_per_warp = warp_size / blocks_per_tile_x_row; const int kbxd = threadIdx.x % blocks_per_tile_x_row; #pragma unroll for (int i0 = 0; i0 < mmq_y; i0 += nwarps * rows_per_warp) { int i = i0 + threadIdx.y * rows_per_warp + threadIdx.x / blocks_per_tile_x_row; if (need_check) { i = min(i, i_max); } const block_q5_0 * bxi = (const block_q5_0 *) x + kbx0 + i*stride + kbxd; #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = bxi->d; #else x_df[i*(MMQ_TILE_NE_K/QI5_0) + i/QI5_0 + kbxd] = bxi->d; #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) } } template static __device__ __forceinline__ void load_tiles_q5_1( const char * __restrict__ x, int * __restrict__ 
x_tile, const int kbx0, const int i_max, const int stride) { constexpr int nwarps = mmq_get_nwarps_device(); constexpr int warp_size = ggml_cuda_get_physical_warp_size(); #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) int * x_qs = (int *) x_tile; half2 * x_dm = (half2 *) (x_qs + 2*MMQ_TILE_NE_K); #else constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_1, mmq_y); int * x_qs = (int *) x_tile; half2 * x_dm = (half2 *) (x_qs + txs.qs); #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) constexpr int threads_per_row = MMQ_ITER_K / (4 * QR5_1); constexpr int nrows = warp_size / threads_per_row; const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x; const int kbx = txi / QI5_1; const int kqsx = txi % QI5_1; #pragma unroll for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) { int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row); if (need_check) { i = min(i, i_max); } const block_q5_1 * bxi = (const block_q5_1 *) x + kbx0 + i*stride + kbx; const int ql = get_int_b4(bxi->qs, kqsx); const int qh = get_int_b4(bxi->qh, 0) >> (4 * kqsx); int qs0 = (ql >> 0) & 0x0F0F0F0F; qs0 |= (qh << 4) & 0x00000010; // 0 -> 4 qs0 |= (qh << 11) & 0x00001000; // 1 -> 12 qs0 |= (qh << 18) & 0x00100000; // 2 -> 20 qs0 |= (qh << 25) & 0x10000000; // 3 -> 28 int qs1 = (ql >> 4) & 0x0F0F0F0F; qs1 |= (qh >> 12) & 0x00000010; // 16 -> 4 qs1 |= (qh >> 5) & 0x00001000; // 17 -> 12 qs1 |= (qh << 2) & 0x00100000; // 18 -> 20 qs1 |= (qh << 9) & 0x10000000; // 19 -> 28 #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI5_1) + kqsx + 0] = qs0; x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI5_1) + kqsx + QI5_1] = qs1; #else x_qs[i*(2*MMQ_TILE_NE_K + 1) + kbx*(2*QI5_1) + kqsx + 0] = qs0; x_qs[i*(2*MMQ_TILE_NE_K + 1) + kbx*(2*QI5_1) + kqsx + QI5_1] = qs1; #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) } constexpr int 
blocks_per_tile_x_row = MMQ_TILE_NE_K / QI5_1; constexpr int rows_per_warp = warp_size / blocks_per_tile_x_row; const int kbxd = threadIdx.x % blocks_per_tile_x_row; #pragma unroll for (int i0 = 0; i0 < mmq_y; i0 += nwarps * rows_per_warp) { int i = i0 + threadIdx.y * rows_per_warp + threadIdx.x / blocks_per_tile_x_row; if (need_check) { i = min(i, i_max); } const block_q5_1 * bxi = (const block_q5_1 *) x + kbx0 + i*stride + kbxd; #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + kbxd] = bxi->dm; #else x_dm[i*(MMQ_TILE_NE_K/QI5_1) + i/QI5_1 + kbxd] = bxi->dm; #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) } } template static __device__ __forceinline__ void load_tiles_q8_0( const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { constexpr int nwarps = mmq_get_nwarps_device(); constexpr int warp_size = ggml_cuda_get_physical_warp_size(); #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) int * x_qs = (int *) x_tile; float * x_df = (float *) (x_tile + 2*MMQ_TILE_NE_K); #else constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q8_0, mmq_y); int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + txs.qs); #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) // MMQ_ITER_K / (4 * QR8_0) == 64 required. but NV has only 32 threads per warp constexpr int threads_per_row = 32; constexpr int nrows = warp_size / threads_per_row; const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x; const int kbx = txi / QI8_0; const int kqsx = txi % QI8_0; #pragma unroll for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) { int i = i0 + (nrows == 1 ? 
threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row); if (need_check) { i = min(i, i_max); } const block_q8_0 * bxi = (const block_q8_0 *) x + kbx0 + i*stride + kbx; #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 0 + txi] = get_int_b2(bxi[0].qs, kqsx); x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + MMQ_TILE_NE_K + txi] = get_int_b2(bxi[MMQ_TILE_NE_K/QI8_0].qs, kqsx); #else x_qs[i*(2*MMQ_TILE_NE_K + 1) + 0 + txi] = get_int_b2(bxi[0].qs, kqsx); x_qs[i*(2*MMQ_TILE_NE_K + 1) + MMQ_TILE_NE_K + txi] = get_int_b2(bxi[MMQ_TILE_NE_K/QI8_0].qs, kqsx); #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) } constexpr int blocks_per_tile_x_row = 2*MMQ_TILE_NE_K / QI8_0; constexpr int rows_per_warp = warp_size / blocks_per_tile_x_row; const int kbxd = threadIdx.x % blocks_per_tile_x_row; #pragma unroll for (int i0 = 0; i0 < mmq_y; i0 += nwarps * rows_per_warp) { int i = i0 + threadIdx.y * rows_per_warp + threadIdx.x / blocks_per_tile_x_row; if (need_check) { i = min(i, i_max); } const block_q8_0 * bxi = (const block_q8_0 *) x + kbx0 + i*stride + kbxd; #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = bxi->d; #else x_df[i*(2*MMQ_TILE_NE_K/QI8_0) + i/(QI8_0/2) + kbxd] = bxi->d; #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) } } template static __device__ __forceinline__ void load_tiles_mxfp4( const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { constexpr int nwarps = mmq_get_nwarps_device(); constexpr int warp_size = ggml_cuda_get_physical_warp_size(); #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2); #else constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_MXFP4, mmq_y); int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + txs.qs); #endif // defined(AMD_MFMA_AVAILABLE) || 
defined(TURING_MMA_AVAILABLE) constexpr int threads_per_row = MMQ_ITER_K / (4 * QR_MXFP4); constexpr int nrows = warp_size / threads_per_row; const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x; const int kbx = txi / QI_MXFP4; const int kqsx = txi % QI_MXFP4; #pragma unroll for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) { int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row); if (need_check) { i = min(i, i_max); } const block_mxfp4 * bxi = (const block_mxfp4 *) x + kbx0 + i*stride + kbx; const int aux_q4 = get_int_b1(bxi->qs, kqsx); const int2 v = get_int_from_table_16(aux_q4, kvalues_mxfp4); const int k0 = kbx * (2 * QI_MXFP4) + kqsx; #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + k0 + 0] = v.x; x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + k0 + QI_MXFP4] = v.y; #else x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + 0] = v.x; x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + QI_MXFP4] = v.y; #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) } constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI_MXFP4; constexpr int rows_per_warp = warp_size / blocks_per_tile_x_row; const int kbxd = threadIdx.x % blocks_per_tile_x_row; #pragma unroll for (int i0 = 0; i0 < mmq_y; i0 += nwarps * rows_per_warp) { int i = i0 + threadIdx.y * rows_per_warp + threadIdx.x / blocks_per_tile_x_row; if (need_check) { i = min(i, i_max); } const block_mxfp4 * bxi = (const block_mxfp4 *) x + kbx0 + i*stride + kbxd; #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) x_df[i*MMQ_MMA_TILE_X_K_Q8_1 + kbxd] = ggml_cuda_e8m0_to_fp32(bxi->e)*0.5f; #else x_df[i*(MMQ_TILE_NE_K/QI_MXFP4) + i/QI_MXFP4 + kbxd] = ggml_cuda_e8m0_to_fp32(bxi->e)*0.5f; #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) } } template static __device__ __forceinline__ void vec_dot_q8_0_q8_1_dp4a( const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) 
{ constexpr int nwarps = mmq_get_nwarps_device(); constexpr int warp_size = ggml_cuda_get_physical_warp_size(); constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q8_0, mmq_y); const int * x_qs = (const int *) x; const float * x_df = (const float *) x_qs + txs.qs; const int * y_qs = (const int *) y + 4; const float * y_df = (const float *) y; // #pragma unroll for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += VDR_Q8_0_Q8_1_MMQ) { const int k0 = k00 + k01; #pragma unroll for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { const int j = j0 + threadIdx.y; #pragma unroll for (int i0 = 0; i0 < mmq_y; i0 += warp_size) { const int i = i0 + threadIdx.x; sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q8_0_q8_1_impl (&x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k0 % MMQ_TILE_NE_K], x_df[i*(2*MMQ_TILE_NE_K/QI8_0) + i/(QI8_0/2) + k0/QI8_0], y_df[j*MMQ_TILE_Y_K + (k0/QI8_1) % (MMQ_TILE_NE_K/QI8_1)]); } } } } template static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma( const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { #if defined(AMD_MFMA_AVAILABLE) typedef tile<16, 8, int> tile_A; typedef tile<16, 8, int> tile_B; typedef tile<16, 16, int> tile_C; constexpr int granularity = mmq_get_granularity_device(mmq_x); constexpr int rows_per_warp = granularity; constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp. 
y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K); const int * x_qs = (const int *) x; const float * x_df = (const float *) x_qs + 2*MMQ_TILE_NE_K; const int * y_qs = (const int *) y + 4; const float * y_df = (const float *) y; const half2 * y_ds = (const half2 *) y; const int i0 = (threadIdx.y / ntx) * rows_per_warp; for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_0) { const int k0 = k00 + k01; tile_A A[ntx]; #pragma unroll for (int n = 0; n < ntx; ++n) { load_generic(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q8_0 + k0, MMQ_MMA_TILE_X_K_Q8_0); } #pragma unroll for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) { tile_B B; load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); float dB; const int j = j0 + tile_C::get_j(0); if (ds_layout == MMQ_Q8_1_DS_LAYOUT_D4) { dB = y_df[j*MMQ_TILE_Y_K + k01/QI8_1]; } else { dB = __low2float(y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]); } #pragma unroll for (int n = 0; n < ntx; ++n) { tile_C C; mma(C, A[n], B); #pragma unroll for (int l = 0; l < tile_C::ne; ++l) { const int i = i0 + n*tile_A::I + tile_C::get_i(l); const float dA = x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + k0/QI8_0]; sum[(j0/tile_C::J + n)*tile_C::ne + l] += C.x[l]*dA*dB; } } } } #else typedef tile<16, 8, int> tile_A; typedef tile< 8, 8, int> tile_B; typedef tile<16, 8, int> tile_C; constexpr int granularity = mmq_get_granularity_device(mmq_x); constexpr int rows_per_warp = 2 * granularity; constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp. 
// Shift y by (threadIdx.y % ntx) groups of tile_C::J entries (entry stride MMQ_TILE_Y_K).
    y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);

    const int * x_qs = (const int *) x;
    const float * x_df = (const float *) x_qs + 2*MMQ_TILE_NE_K;
    const int * y_qs = (const int *) y + 4;
    const float * y_df = (const float *) y;
    const half2 * y_ds = (const half2 *) y;

    tile_A A[ntx][MMQ_TILE_NE_K/QI8_0];
    float dA[ntx][tile_C::ne/2][MMQ_TILE_NE_K/QI8_0];

    const int i0 = (threadIdx.y/ntx)*rows_per_warp;

#pragma unroll
    for (int n = 0; n < ntx; ++n) {
#pragma unroll
        for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_0) {
            const int k0 = k00 + k01;

            load_ldmatrix(A[n][k01/QI8_0], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q8_0 + k0, MMQ_MMA_TILE_X_K_Q8_0);
        }

#pragma unroll
        for (int l = 0; l < tile_C::ne/2; ++l) {
            const int i = i0 + n*tile_A::I + tile_C::get_i(2*l);

#pragma unroll
            for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_0) {
                const int k0 = k00 + k01;

                dA[n][l][k01/QI8_0] = x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + k0/QI8_0];
            }
        }
    }

#pragma unroll
    for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
#pragma unroll
        for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_0) {
            tile_B B;
            float dB[tile_C::ne/2];

            load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); // faster than load_ldmatrix

#pragma unroll
            for (int l = 0; l < tile_C::ne/2; ++l) {
                const int j = j0 + tile_C::get_j(l);

                if (ds_layout == MMQ_Q8_1_DS_LAYOUT_D4) {
                    dB[l] = y_df[j*MMQ_TILE_Y_K + k01/QI8_1];
                } else {
                    dB[l] = __low2float(y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);
                }
            }

#pragma unroll
            for (int n = 0; n < ntx; ++n) {
                tile_C C;
                mma(C, A[n][k01/QI8_0], B);

#pragma unroll
                for (int l = 0; l < tile_C::ne; ++l) {
                    sum[(j0/tile_C::J + n)*tile_C::ne + l] += C.x[l]*dA[n][l/2][k01/QI8_0]*dB[l%2];
                }
            }
        }
    }
#endif // defined(AMD_MFMA_AVAILABLE)
}

// dp4a variant of the q8_1 x q8_1 tile dot product.
// x holds int quants followed (at offset txs.qs, tile sizes of GGML_TYPE_Q5_1) by half2 values x_dm;
// y holds 4 ints of half2 values y_ds followed by int quants per MMQ_TILE_Y_K-strided entry.
// For each (i, j) = (i0 + threadIdx.x, j0 + threadIdx.y) the result of vec_dot_q8_1_q8_1_impl for every
// k01 step of the tile is added to sum[j0/nwarps*mmq_y/warp_size + i0/warp_size]. k00 is the k offset into x.
// NOTE(review): the template parameter list and the explicit impl template argument were missing from the
// text and have been restored from their use in the body / the loop stride - confirm against callers.
template <int mmq_x, int mmq_y>
static __device__ __forceinline__ void vec_dot_q8_1_q8_1_dp4a(
    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
    constexpr int nwarps = mmq_get_nwarps_device();
    constexpr int warp_size = ggml_cuda_get_physical_warp_size();

    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_1, mmq_y);
    const int * x_qs = (const int *) x;
    const half2 * x_dm = (const half2 *) x_qs + txs.qs;
    const int * y_qs = (const int *) y + 4;
    const half2 * y_ds = (const half2 *) y;

// #pragma unroll
    for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += VDR_Q8_0_Q8_1_MMQ) {
        const int k0 = k00 + k01;

#pragma unroll
        for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
            const int j = j0 + threadIdx.y;

#pragma unroll
            for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
                const int i = i0 + threadIdx.x;

                sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q8_1_q8_1_impl<QR5_1*VDR_Q5_1_Q8_1_MMQ>
                    (&x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01],
                    x_dm[i*(MMQ_TILE_NE_K/QI5_1) + i/QI5_1 + k0/QI8_1], y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);
            }
        }
    }
}

// MMA variant of the q8_1 x q8_1 tile dot product.
// For each A tile (x quants) and B tile (y quants) an integer tile C is computed with mma(); every element
// then contributes dmA.x*dsB.x*C.x[l] + dmA.y*dsB.y to sum, where dmA / dsB are the half2 values stored
// next to the x / y quants converted to float2.
// The AMD_MFMA_AVAILABLE branch loads A per k01 step; the other branch preloads all A tiles and dmA values.
template <int mmq_x, int mmq_y>
static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma(
    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
#if defined(AMD_MFMA_AVAILABLE)
    typedef tile<16, 8, int> tile_A;
    typedef tile<16, 8, int> tile_B;
    typedef tile<16, 16, int> tile_C;

    constexpr int granularity = mmq_get_granularity_device(mmq_x);
    constexpr int rows_per_warp = granularity;
    constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.

    y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);

    const int * x_qs = (const int *) x;
    const half2 * x_dm = (const half2 *) x_qs + 2*MMQ_TILE_NE_K;
    const int * y_qs = (const int *) y + 4;
    const half2 * y_dm = (const half2 *) y;

    const int i0 = (threadIdx.y / ntx) * rows_per_warp;

    for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_1) {
        const int k0 = k00 + k01;

        tile_A A[ntx];
#pragma unroll
        for (int n = 0; n < ntx; ++n) {
            load_generic(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q8_1 + k0, MMQ_MMA_TILE_X_K_Q8_1);
        }

#pragma unroll
        for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
            tile_B B;
            load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);

            const int j = j0 + tile_C::get_j(0);
            const float2 dsB = __half22float2(y_dm[j*MMQ_TILE_Y_K + k01/QI8_1]);

#pragma unroll
            for (int n = 0; n < ntx; ++n) {
                tile_C C;
                mma(C, A[n], B);

#pragma unroll
                for (int l = 0; l < tile_C::ne; ++l) {
                    const int i = i0 + n*tile_A::I + tile_C::get_i(l);
                    float2 dmA = __half22float2(x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + k0/QI8_1]);
                    sum[(j0/tile_C::J + n)*tile_C::ne + l] += dmA.x*dsB.x*C.x[l];
                    sum[(j0/tile_C::J + n)*tile_C::ne + l] += dmA.y*dsB.y;
                }
            }
        }
    }
#else
    typedef tile<16, 8, int> tile_A;
    typedef tile< 8, 8, int> tile_B;
    typedef tile<16, 8, int> tile_C;

    constexpr int granularity = mmq_get_granularity_device(mmq_x);
    constexpr int rows_per_warp = 2 * granularity;
    constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.

    y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);

    const int * x_qs = (const int *) x;
    const half2 * x_dm = (const half2 *) x_qs + 2*MMQ_TILE_NE_K;
    const int * y_qs = (const int *) y + 4;
    const half2 * y_dm = (const half2 *) y;

    tile_A A[ntx][MMQ_TILE_NE_K/QI8_1];
    float2 dmA[ntx][tile_C::ne/2][MMQ_TILE_NE_K/QI8_1];

    const int i0 = (threadIdx.y/ntx)*rows_per_warp;

#pragma unroll
    for (int n = 0; n < ntx; ++n) {
#pragma unroll
        for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_1) {
            const int k0 = k00 + k01;

            load_ldmatrix(A[n][k01/QI8_1], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q8_1 + k0, MMQ_MMA_TILE_X_K_Q8_1);
        }

#pragma unroll
        for (int l = 0; l < tile_C::ne/2; ++l) {
            const int i = i0 + n*tile_A::I + tile_C::get_i(2*l);

#pragma unroll
            for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_1) {
                const int k0 = k00 + k01;

                dmA[n][l][k01/QI8_1] = __half22float2(x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + k0/QI8_1]);
            }
        }
    }

#pragma unroll
    for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
#pragma unroll
        for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_1) {
            tile_B B;
            float2 dsB[tile_C::ne/2];

            load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); // faster than load_ldmatrix

#pragma unroll
            for (int l = 0; l < tile_C::ne/2; ++l) {
                const int j = j0 + tile_C::get_j(l);

                dsB[l] = __half22float2(y_dm[j*MMQ_TILE_Y_K + k01/QI8_1]);
            }

#pragma unroll
            for (int n = 0; n < ntx; ++n) {
                tile_C C;
                mma(C, A[n][k01/QI8_1], B);

#pragma unroll
                for (int l = 0; l < tile_C::ne; ++l) {
                    sum[(j0/tile_C::J + n)*tile_C::ne + l] += dmA[n][l/2][k01/QI8_1].x*dsB[l%2].x*C.x[l];
                    sum[(j0/tile_C::J + n)*tile_C::ne + l] += dmA[n][l/2][k01/QI8_1].y*dsB[l%2].y;
                }
            }
        }
    }
#endif // defined(AMD_MFMA_AVAILABLE)
}

// Used for Q3_K, IQ2_S, and IQ2_XS
// dp4a variant: x holds int quants followed (at offset MMQ_DP4A_TXS_Q8_0_16.qs) by float scales x_df,
// y holds float scales y_df followed (after 4 ints) by int quants.
// For each (i, j) = (i0 + threadIdx.x, j0 + threadIdx.y) the result of vec_dot_q8_0_16_q8_1_impl is added
// to sum for every QI8_0-wide k01 step. The impl receives a pointer into x_df and a single y_df value.
template <int mmq_x, int mmq_y>
static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_dp4a(
    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
    constexpr int nwarps = mmq_get_nwarps_device();
    constexpr int warp_size = ggml_cuda_get_physical_warp_size();

    constexpr tile_x_sizes txs = MMQ_DP4A_TXS_Q8_0_16;
    const int * x_qs = (const int *) x;
    const float * x_df = (const float *) x_qs + txs.qs;
    const int * y_qs = (const int *) y + 4;
    const float * y_df = (const float *) y;

// #pragma unroll
    for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_0) {
        const int k0 = k00 + k01;

#pragma unroll
        for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
            const int j = j0 + threadIdx.y;

#pragma unroll
            for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
                const int i = i0 + threadIdx.x;

                sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q8_0_16_q8_1_impl<QI8_0>(
                    &x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0],
                    &y_qs[j*MMQ_TILE_Y_K + k01],
                    &x_df[i*(2*MMQ_TILE_NE_K*2/QI8_0) + i/(QI8_0/4) + k0/(QI8_0/2)],
                    y_df[j*MMQ_TILE_Y_K + k01/QI8_1]);
            }
        }
    }
}

// Used for Q3_K, IQ2_S, and IQ2_XS:
// MMA variant: x quants use the MMQ_MMA_TILE_X_K_Q3_K row stride with one float scale per 4 ints.
// - AMD_MFMA_AVAILABLE: steps k01 by 4, loads A and B through a tile<64, 2, int> view and uses half of
//   the y scale (y_df / 2) for each product.
// - TURING_MMA_AVAILABLE: preloads all A tiles and scales, then per k01 step multiplies two 4-wide
//   halves (B[0], B[1]) with their own x scales and a shared y scale.
// - otherwise: no device code is generated.
template <int mmq_x, int mmq_y>
static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma(
    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
#if defined(AMD_MFMA_AVAILABLE)
    typedef tile<16, 8, int> tile_A;
    typedef tile<16, 8, int> tile_B;
    typedef tile<16, 16, int> tile_C;
    typedef tile<64, 2, int> tile_load;

    constexpr int granularity = mmq_get_granularity_device(mmq_x);
    constexpr int rows_per_warp = granularity;
    constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.

    y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);

    const int * x_qs = (const int *) x;
    const float * x_df = (const float *) x_qs + MMQ_TILE_NE_K*2;
    const int * y_qs = (const int *) y + 4;
    const float * y_df = (const float *) y;

    const int i0 = (threadIdx.y / ntx) * rows_per_warp;

    for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 4) {
        const int k0 = k00 + k01;

        tile_A A[ntx];
#pragma unroll
        for (int n = 0; n < ntx; ++n) {
            load_generic(((tile_load *) A)[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q3_K + k0, MMQ_MMA_TILE_X_K_Q3_K);
        }

#pragma unroll
        for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
            tile_B B[1];
            load_generic(((tile_load *) B)[0], y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);

            const int j = j0 + tile_C::get_j(0);
            const float dB = y_df[j*MMQ_TILE_Y_K + k01/QI8_1] / 2;

#pragma unroll
            for (int n = 0; n < ntx; ++n) {
                tile_C C;
                mma(C, A[n], B[0]);

#pragma unroll
                for (int l = 0; l < tile_C::ne; ++l) {
                    const int i = i0 + n*tile_C::I + tile_C::get_i(l);
                    sum[(j0/tile_C::J + n)*tile_C::ne + l] += C.x[l] * x_df[i*MMQ_MMA_TILE_X_K_Q3_K + k0/4] * dB;
                }
            }
        }
    }
#elif defined(TURING_MMA_AVAILABLE)
    typedef tile<16, 4, int> tile_A;
    typedef tile<16, 8, int> tile_A_8;
    typedef tile< 8, 4, int> tile_B;
    typedef tile<16, 8, int> tile_C;

    constexpr int granularity = mmq_get_granularity_device(mmq_x);
    constexpr int rows_per_warp = 2 * granularity;
    constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.

    y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);

    const int * x_qs = (const int *) x;
    const float * x_df = (const float *) x_qs + MMQ_TILE_NE_K*2;
    const int * y_qs = (const int *) y + 4;
    const float * y_df = (const float *) y;

    const int i0 = (threadIdx.y / ntx) * (ntx*tile_A::I);

    tile_A A[ntx][8];
    float dA[ntx][tile_C::ne/2][8];

#pragma unroll
    for (int n = 0; n < ntx; ++n) {
#pragma unroll
        for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 8) {
            const int k0 = k00 + k01;

            // Two adjacent 16x4 A tiles are filled at once through a 16x8 view.
            load_ldmatrix(((tile_A_8 *) A[n])[k01/8], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q3_K + k0, MMQ_MMA_TILE_X_K_Q3_K);
        }

#pragma unroll
        for (int l = 0; l < tile_C::ne/2; ++l) {
            const int i = i0 + n*tile_C::I + tile_C::get_i(2*l);

#pragma unroll
            for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 4) {
                const int k0 = k00 + k01;

                dA[n][l][k01/4] = x_df[i*MMQ_MMA_TILE_X_K_Q3_K + k0/4];
            }
        }
    }

#pragma unroll
    for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
#pragma unroll
        for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QR3_K*VDR_Q3_K_Q8_1_MMQ) {
            tile_B B[2];
            float dB[tile_C::ne/2];

            // Here load_generic is faster than load_ldmatrix.
            load_generic(B[0], y_qs + j0*MMQ_TILE_Y_K + (k01 + 0), MMQ_TILE_Y_K);
            load_generic(B[1], y_qs + j0*MMQ_TILE_Y_K + (k01 + tile_B::J), MMQ_TILE_Y_K);

#pragma unroll
            for (int l = 0; l < tile_C::ne/2; ++l) {
                const int j = j0 + tile_C::get_j(l);

                dB[l] = y_df[j*MMQ_TILE_Y_K + k01/QI8_1];
            }

#pragma unroll
            for (int n = 0; n < ntx; ++n) {
                tile_C C[2];
                mma(C[0], A[n][k01/4 + 0], B[0]);
                mma(C[1], A[n][k01/4 + 1], B[1]);

#pragma unroll
                for (int l = 0; l < tile_C::ne; ++l) {
                    sum[(j0/tile_C::J + n)*tile_C::ne + l] += dB[l%2]*(C[0].x[l]*dA[n][l/2][k01/4 + 0] + C[1].x[l]*dA[n][l/2][k01/4 + 1]);
                }
            }
        }
    }
#else
    GGML_UNUSED_VARS(x, y, sum, k00);
    NO_DEVICE_CODE;
#endif // AMD_MFMA_AVAILABLE
}

// Unpacks rows of block_q2_K data (blocks kbx0 + i*stride of x) into x_tile.
// Each thread reads one 32 bit word of packed quants and splits it into QR2_K ints holding one 2 bit value
// per byte (mask 0x03030303). The low / high nibble of scales[kqsx] is multiplied with dm.x / dm.y and
// stored as half2 in x_dm. With need_check the row index is clamped to i_max.
// The row strides differ between the MMA layout (MMQ_MMA_TILE_X_K_Q2_K) and the dp4a layout.
template <int mmq_y, bool need_check>
static __device__ __forceinline__ void load_tiles_q2_K(
    const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
    constexpr int nwarps = mmq_get_nwarps_device();

#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
    int * x_qs = (int *) x_tile;
    half2 * x_dm = (half2 *) (x_qs + 2*MMQ_TILE_NE_K);
#else
    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q2_K, mmq_y);
    int * x_qs = (int *) x_tile;
    half2 * x_dm = (half2 *) (x_qs + txs.qs);
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)

    constexpr int threads_per_row = MMQ_ITER_K / (4 * QR2_K);
    constexpr int nrows = ggml_cuda_get_physical_warp_size() / threads_per_row;
    const int kqsx = threadIdx.x % threads_per_row;

#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) {
        int i = i0 + threadIdx.y*nrows + threadIdx.x/threads_per_row;

        if (need_check) {
            i = min(i, i_max);
        }

        const block_q2_K * bxi = (const block_q2_K *) x + kbx0 + i*stride;

        const int x_ql_0 = get_int_b2(bxi->qs, kqsx);

#pragma unroll
        for (int l = 0; l < QR2_K; ++l) {
            const int k = (kqsx/8)*32 + l*8 + kqsx % 8;

            const int x_qs_k = (x_ql_0 >> (2*l)) & 0x03030303;

#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
            x_qs[i*MMQ_MMA_TILE_X_K_Q2_K + k] = x_qs_k;
#else
            x_qs[i*(2*MMQ_TILE_NE_K + 1) + k] = x_qs_k;
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
        }

        const int sc_m = bxi->scales[kqsx];
#ifdef FAST_FP16_AVAILABLE
        const half2 x_dm_ik = __hmul2(bxi->dm, make_half2(sc_m & 0x0F, sc_m >> 4));
#else
        const float2 bxi_dmf = __half22float2(bxi->dm);
        const half2 x_dm_ik = make_half2(bxi_dmf.x*(sc_m & 0x0F), bxi_dmf.y*(sc_m >> 4));
#endif // FAST_FP16_AVAILABLE

#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
        x_dm[i*MMQ_MMA_TILE_X_K_Q2_K + kqsx] = x_dm_ik;
#else
        x_dm[i*(MMQ_TILE_NE_K + 1) + kqsx] = x_dm_ik;
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
    }
}

// dp4a variant of the q2_K x q8_1 tile dot product.
// The first half2 of each y entry is converted to float2 once per j (y_df); its .x is used for the first
// half of the k01 range and its .y for the second half. The remaining y_ds values are passed by pointer to
// vec_dot_q2_K_q8_1_impl_mmq, which is instantiated with ns = 2 for k01 < MMQ_TILE_NE_K/2 and ns = 1 above.
// NOTE(review): the explicit <ns> template argument was missing from the text and has been restored
// (ns is otherwise unused) - confirm against the impl declaration.
template <int mmq_x, int mmq_y>
static __device__ __forceinline__ void vec_dot_q2_K_q8_1_dp4a(
    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
    constexpr int nwarps = mmq_get_nwarps_device();
    constexpr int warp_size = ggml_cuda_get_physical_warp_size();

    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q2_K, mmq_y);
    const int * x_qs = (const int *) x;
    const half2 * x_dm = (const half2 *) x_qs + txs.qs;
    const int * y_qs = (const int *) y + 4;
    const half2 * y_ds = (const half2 *) y;

    float2 y_df[mmq_x/nwarps];
#pragma unroll
    for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
        const int j = j0 + threadIdx.y;

        y_df[j0/nwarps] = __half22float2(y_ds[j*MMQ_TILE_Y_K]);
    }

#pragma unroll
    for (int k01 = 0; k01 < MMQ_TILE_NE_K/2; k01 += QR2_K*VDR_Q2_K_Q8_1_MMQ) {
        const int k0 = k00 + k01;

#pragma unroll
        for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
            const int j = j0 + threadIdx.y;

#pragma unroll
            for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
                const int i = i0 + threadIdx.x;

                constexpr int ns = 2;
                sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q2_K_q8_1_impl_mmq<ns>(
                    &x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01],
                    &x_dm[i*(MMQ_TILE_NE_K + 1) + k0/4], k01 < MMQ_TILE_NE_K/2 ? y_df[j0/nwarps].x : y_df[j0/nwarps].y,
                    &y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]);
            }
        }
    }

    // Some compilers fail to unroll the loop over k01 if there is a conditional statement for ns in the inner loop.
    // As a workaround 2 separate loops are used instead.
#pragma unroll
    for (int k01 = MMQ_TILE_NE_K/2; k01 < MMQ_TILE_NE_K; k01 += QR2_K*VDR_Q2_K_Q8_1_MMQ) {
        const int k0 = k00 + k01;

#pragma unroll
        for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
            const int j = j0 + threadIdx.y;

#pragma unroll
            for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
                const int i = i0 + threadIdx.x;

                constexpr int ns = 1;
                sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q2_K_q8_1_impl_mmq<ns>(
                    &x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01],
                    &x_dm[i*(MMQ_TILE_NE_K + 1) + k0/4], k01 < MMQ_TILE_NE_K/2 ? y_df[j0/nwarps].x : y_df[j0/nwarps].y,
                    &y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]);
            }
        }
    }
}

// MMA variant of the q2_K x q8_1 tile dot product.
// Per element the quant product Cd is scaled with dm.x and the y scale; the dm.y term is subtracted either
// via the stored y partial sums (k01 < MMQ_TILE_NE_K*3/4) or, for the last quarter of the k01 range, via
// an extra mma() of B with an all-ones A tile (0x01010101), which sums the y quants.
// - AMD_MFMA_AVAILABLE: steps k01 by 4 and loads tiles through a tile<64, 2, int> view.
// - TURING_MMA_AVAILABLE: preloads A tiles plus dA / mA and handles two 4-wide halves per k01 step.
// - otherwise: no device code is generated.
template <int mmq_x, int mmq_y>
static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
#if defined(AMD_MFMA_AVAILABLE)
    typedef tile<16, 8, int> tile_A;
    typedef tile<16, 8, int> tile_B;
    typedef tile<16, 16, int> tile_C;
    typedef tile<64, 2, int> tile_load;

    constexpr int granularity = mmq_get_granularity_device(mmq_x);
    constexpr int rows_per_warp = granularity;
    constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.

    y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);

    const int * x_qs = (const int *) x;
    const half2 * x_dm = (const half2 *) x_qs + MMQ_TILE_NE_K*2;
    const int * y_qs = (const int *) y + 4;
    const half2 * y_ds = (const half2 *) y;

    const int i0 = (threadIdx.y / ntx) * rows_per_warp;

    for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 4) {
        const int k0 = k00 + k01;

        tile_A A[ntx];
#pragma unroll
        for (int n = 0; n < ntx; ++n) {
            load_generic(((tile_load *) A)[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q2_K + k0, MMQ_MMA_TILE_X_K_Q2_K);
        }

#pragma unroll
        for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
            tile_B B[1];
            load_generic(((tile_load *) B)[0], y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);

            const int j = j0 + tile_C::get_j(0);
            const float dB = (k01 < MMQ_TILE_NE_K/2) ? __half22float2(y_ds[j*MMQ_TILE_Y_K]).x/2 : __half22float2(y_ds[j*MMQ_TILE_Y_K]).y/2;
            const float sB = (k01 >= MMQ_TILE_NE_K * 3/4) ? 0
                                              : (((k01/4)%2) ? __half22float2(y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]).y
                                                             : __half22float2(y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]).x);

            tile_C Cm;
            if (k01 >= MMQ_TILE_NE_K * 3/4) {
                tile_A A1;
                A1.x[0] = 0x01010101;
                A1.x[1] = 0x01010101;
                mma(Cm, A1, B[0]);
            }

#pragma unroll
            for (int n = 0; n < ntx; ++n) {
                tile_C Cd;
                mma(Cd, A[n], B[0]);

#pragma unroll
                for (int l = 0; l < tile_C::ne; ++l) {
                    const int i = i0 + n*tile_C::I + tile_C::get_i(l);
                    const float2 dm = __half22float2(x_dm[i*MMQ_MMA_TILE_X_K_Q2_K + k0/4]);
                    float tmp = Cd.x[l]*dm.x;
                    if (k01 >= MMQ_TILE_NE_K * 3/4) {
                        tmp -= Cm.x[l]*dm.y;
                    }
                    sum[(j0/tile_C::J + n)*tile_C::ne + l] += tmp*dB;
                    sum[(j0/tile_C::J + n)*tile_C::ne + l] -= dm.y*sB;
                }
            }
        }
    }
#elif defined(TURING_MMA_AVAILABLE)
    typedef tile<16, 4, int> tile_A;
    typedef tile<16, 8, int> tile_A_8;
    typedef tile< 8, 4, int> tile_B;
    typedef tile<16, 8, int> tile_C;

    constexpr int granularity = mmq_get_granularity_device(mmq_x);
    constexpr int rows_per_warp = 2 * granularity;
    constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.

    y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);

    const int * x_qs = (const int *) x;
    const half2 * x_dm = (const half2 *) x_qs + MMQ_TILE_NE_K*2;
    const int * y_qs = (const int *) y + 4;
    const half2 * y_ds = (const half2 *) y;

    const int i0 = (threadIdx.y / ntx) * (ntx*tile_A::I);

    tile_A A[ntx][8];
    float dA[ntx][tile_C::ne/2][8];
    float mA[ntx][tile_C::ne/2][8];

#pragma unroll
    for (int n = 0; n < ntx; ++n) {
#pragma unroll
        for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_1) {
            const int k0 = k00 + k01;

            load_ldmatrix(((tile_A_8 *) A[n])[k01/QI8_1], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q2_K + k0, MMQ_MMA_TILE_X_K_Q2_K);
        }
    }

#pragma unroll
    for (int n = 0; n < ntx; ++n) {
#pragma unroll
        for (int l = 0; l < tile_C::ne/2; ++l) {
            const int i = i0 + n*tile_C::I + tile_C::get_i(2*l);

#pragma unroll
            for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_1/2) {
                const int k0 = k00 + k01;

                const float2 dm = __half22float2(x_dm[i*MMQ_MMA_TILE_X_K_Q2_K + k0/(QI8_1/2)]);

                dA[n][l][k01/(QI8_1/2)] = dm.x;
                mA[n][l][k01/(QI8_1/2)] = dm.y;
            }
        }
    }

#pragma unroll
    for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
        float2 dB[tile_C::ne/2];

#pragma unroll
        for (int l = 0; l < tile_C::ne/2; ++l) {
            const int j = j0 + tile_C::get_j(l);

            dB[l] = __half22float2(y_ds[j*MMQ_TILE_Y_K]);
        }

#pragma unroll
        for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_1) {
            tile_B B[2];

            // Here load_generic is faster than load_ldmatrix.
            load_generic(B[0], y_qs + j0*MMQ_TILE_Y_K + (k01 + 0), MMQ_TILE_Y_K);
            load_generic(B[1], y_qs + j0*MMQ_TILE_Y_K + (k01 + tile_B::J), MMQ_TILE_Y_K);

            tile_C Cm[2];
            if (k01 >= MMQ_TILE_NE_K * 3/4) {
                tile_A A1;
                A1.x[0] = 0x01010101;
                A1.x[1] = 0x01010101;
                mma(Cm[0], A1, B[0]);
                mma(Cm[1], A1, B[1]);
            }

#pragma unroll
            for (int n = 0; n < ntx; ++n) {
                tile_C Cd[2];

                mma(Cd[0], A[n][k01/4 + 0], B[0]);
                mma(Cd[1], A[n][k01/4 + 1], B[1]);

#pragma unroll
                for (int l = 0; l < tile_C::ne; ++l) {
                    float tmp = Cd[0].x[l]*dA[n][l/2][k01/4 + 0] + Cd[1].x[l]*dA[n][l/2][k01/4 + 1];
                    if (k01 >= MMQ_TILE_NE_K * 3/4) {
                        tmp -= Cm[0].x[l]*mA[n][l/2][k01/4 + 0] + Cm[1].x[l]*mA[n][l/2][k01/4 + 1];
                    }
                    sum[(j0/tile_C::J + n)*tile_C::ne + l] += tmp*(k01 < MMQ_TILE_NE_K/2 ? dB[l%2].x : dB[l%2].y);
                }
            }
        }

#pragma unroll
        for (int k01 = 0; k01 < MMQ_TILE_NE_K * 3/4; k01 += QI8_1) {
            float2 sB[tile_C::ne/2];

#pragma unroll
            for (int l = 0; l < tile_C::ne/2; ++l) {
                const int j = j0 + tile_C::get_j(l);

                sB[l] = __half22float2(y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]);
            }

#pragma unroll
            for (int n = 0; n < ntx; ++n) {
#pragma unroll
                for (int l = 0; l < tile_C::ne; ++l) {
                    sum[(j0/tile_C::J + n)*tile_C::ne + l] -= mA[n][l/2][k01/4 + 0]*sB[l%2].x;
                    sum[(j0/tile_C::J + n)*tile_C::ne + l] -= mA[n][l/2][k01/4 + 1]*sB[l%2].y;
                }
            }
        }
    }
#else
    GGML_UNUSED_VARS(x, y, sum, k00);
    NO_DEVICE_CODE;
#endif // AMD_MFMA_AVAILABLE
}

// Unpacks rows of block_q3_K data (blocks kbx0 + i*stride of x) into x_tile.
// Quants: the low 2 bits come from qs, the third bit from hmask; 4 is subtracted from every byte
// (__vsubss4 with 0x04040404). Scales: 4 low bits and 2 high bits are combined per byte and 32 is
// subtracted (0x20202020). In the MMA layout each scale byte is multiplied with d and stored as float in
// x_df; in the dp4a layout the packed scales go to x_sc and d is stored once per row in x_df.
// With need_check the row index is clamped to i_max.
template <int mmq_y, bool need_check>
static __device__ __forceinline__ void load_tiles_q3_K(
    const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
    constexpr int nwarps = mmq_get_nwarps_device();
    constexpr int warp_size = ggml_cuda_get_physical_warp_size();

#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
    int * x_qs = (int *) x_tile;
    float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
#else
    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q3_K, mmq_y);
    int * x_qs = (int *) x_tile;
    float * x_df = (float *) (x_qs + txs.qs);
    int * x_sc = (int *) (x_df + txs.dm);
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)

    constexpr int threads_per_row = MMQ_ITER_K / (4 * QR3_K);
    constexpr int nrows = warp_size / threads_per_row;
    const int kqsx = threadIdx.x % threads_per_row;

#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) {
        int i = i0 + threadIdx.y*nrows + threadIdx.x/threads_per_row;

        if (need_check) {
            i = min(i, i_max);
        }

        const block_q3_K * bxi = (const block_q3_K *) x + kbx0 + i*stride;

        const int x_ql_0 = get_int_b2(bxi->qs, kqsx);
        const int x_qh_0 = get_int_b2(bxi->hmask, kqsx % (QI3_K/2)) >> (4 * (kqsx / (QI3_K/2)));

#pragma unroll
        for (int l = 0; l < QR3_K; ++l) {
            const int k = (kqsx/8)*32 + l*8 + kqsx % 8;

            const int x_ql_k = (x_ql_0 >> (2*l)) & 0x03030303;
            const int x_qh_k = ((x_qh_0 >> l) << 2) & 0x04040404;

            const int x_qs_k = __vsubss4(x_ql_k | x_qh_k, 0x04040404);

#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
            x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + k] = x_qs_k;
#else
            x_qs[i*(2*MMQ_TILE_NE_K + 1) + k] = x_qs_k;
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
        }
    }

    constexpr int rows_per_warp = warp_size / 4;
#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps*rows_per_warp) {
        int i = i0 + threadIdx.y*rows_per_warp + threadIdx.x/4;

        if (need_check) {
            i = min(i, i_max);
        }

        const block_q3_K * bxi = (const block_q3_K *) x + kbx0 + i*stride;

        const int ksc = threadIdx.x % 4;

        const int ksc_low = ksc % (QI3_K/8);
        const int shift_low = 4 * (ksc / (QI3_K/8));
        const int sc_low = (get_int_b2(bxi->scales, ksc_low) >> shift_low) & 0x0F0F0F0F;

        const int ksc_high = QI3_K/8;
        const int shift_high = 2 * ksc;
        const int sc_high = ((get_int_b2(bxi->scales, ksc_high) >> shift_high) << 4) & 0x30303030;

        const int sc = __vsubss4(sc_low | sc_high, 0x20202020);

#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
        // View the 4 packed signed scale bytes individually.
        const int8_t * sc8 = (const int8_t *) &sc;
        const float d = bxi->d;

#pragma unroll
        for (int l = 0; l < int(sizeof(int)); ++l) {
            x_df[i*MMQ_MMA_TILE_X_K_Q3_K + sizeof(int)*ksc + l] = d*sc8[l];
        }
#else
        x_sc[i*(MMQ_TILE_NE_K/8) + i/8 + ksc] = sc;
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
    }

#if !(defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE))
#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps*warp_size) {
        int i = (i0 + threadIdx.y*warp_size + threadIdx.x) % mmq_y;

        if (need_check) {
            i = min(i, i_max);
        }

        const block_q3_K * bxi = (const block_q3_K *) x + kbx0 + i*stride;

        x_df[i] = bxi->d;
    }
#endif // !(defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE))
}

// dp4a variant of the q3_K x q8_1 tile dot product.
// The x tile is laid out as int quants, then float values x_df (one per row i), then packed int8 scales
// x_sc. For each (i, j) = (i0 + threadIdx.x, j0 + threadIdx.y) the result of vec_dot_q3_K_q8_1_impl_mmq is
// added to sum for every k01 step.
template <int mmq_x, int mmq_y>
static __device__ __forceinline__ void vec_dot_q3_K_q8_1_dp4a(
    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
    constexpr int nwarps = mmq_get_nwarps_device();
    constexpr int warp_size = ggml_cuda_get_physical_warp_size();

    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q3_K, mmq_y);
    const int * x_qs = (const int *) x;
    const float * x_df = (const float *) x_qs + txs.qs;
    const int * x_sc = (const int *) x_df + txs.dm;
    const int * y_qs = (const int *) y + 4;
    const float * y_df = (const float *) y;

// #pragma unroll
    for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QR3_K*VDR_Q3_K_Q8_1_MMQ) {
        const int k0 = k00 + k01;

#pragma unroll
        for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
            const int j = j0 + threadIdx.y;

#pragma unroll
            for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
                const int i = i0 + threadIdx.x;

                const int8_t * scales = ((const int8_t *) (x_sc + i*(MMQ_TILE_NE_K/8) + i/8)) + k0/4;

                sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q3_K_q8_1_impl_mmq(
                    &x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01], scales,
                    x_df[i], y_df[j*MMQ_TILE_Y_K + k01/QI8_1]);
            }
        }
    }
}

// Returns 4 unpacked 6 bit values (one per byte) from the packed q4_K / q5_K scales array;
// ksc selects which group of 4 is produced, see the table below.
static __device__ __forceinline__ int unpack_scales_q45_K(const int * scales, const int ksc) {
    // scale arrangement after the following two lines:
    //   - ksc == 0: sc0, sc1, sc2, sc3
    //   - ksc == 1: sc4, sc5, sc6, sc7
    //   - ksc == 2:  m0,  m1,  m2,  m3
    //   - ksc == 3:  m4,  m5,  m6,  m7
    return ((scales[(ksc%2) + (ksc!=0)] >> (4 * (ksc & (ksc/2)))) & 0x0F0F0F0F) | // lower 4 bits
           ((scales[ksc/2]              >> (2 * (ksc % 2)))       & 0x30303030);  // upper 2 bits
}

// Unpacks rows of block_q4_K data (blocks kbx0 + i*stride of x) into x_tile.
// MMA layout: each packed word is split into its low and high nibbles (two ints, 8 apart) and x_dm holds
// per scale byte the half2 (dm.x*sc, -dm.y*m) built from the unpacked scales / mins.
// dp4a layout: the packed words are stored unchanged, x_dm[i] holds dm per row and x_sc the unpacked
// scale words. With need_check the row index is clamped to i_max.
template <int mmq_y, bool need_check>
static __device__ __forceinline__ void load_tiles_q4_K(
    const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
    constexpr int nwarps = mmq_get_nwarps_device();
    constexpr int warp_size = ggml_cuda_get_physical_warp_size();

#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
    int * x_qs = (int *) x_tile;
    half2 * x_dm = (half2 *) (x_qs + 2*MMQ_TILE_NE_K);
#else
    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_K, mmq_y);
    int * x_qs = (int *) x_tile;
    half2 * x_dm = (half2 *) (x_qs + txs.qs);
    int * x_sc = (int *) (x_dm + txs.dm);
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)

    constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_K);
    constexpr int nrows = warp_size / threads_per_row;
    const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x;

#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) {
        int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row);

        if (need_check) {
            i = min(i, i_max);
        }

        const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride;
        const int qs0 = get_int_b4(bxi->qs, txi);

#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
        x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 16*(txi/8) + txi % 8 + 0] = (qs0 >> 0) & 0x0F0F0F0F;
        x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 16*(txi/8) + txi % 8 + 8] = (qs0 >> 4) & 0x0F0F0F0F;
#else
        x_qs[i*(MMQ_TILE_NE_K + 1) + txi] = qs0;
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
    }

#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
    constexpr int rows_per_warp = warp_size / 2;
#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps*rows_per_warp) {
#if defined(AMD_MFMA_AVAILABLE)
        // Need if on AMD instead of % because warp_size == 64
        // This causes double work and throughput loss (MI300X)
        // H100 loses about 100 t/s with 'if' condition over '%'
        int i = i0 + threadIdx.y*rows_per_warp + threadIdx.x/2;
        if (i < mmq_y) {
#else
        int i = (i0 + threadIdx.y*rows_per_warp + threadIdx.x/2) % mmq_y;
        {
#endif // defined(AMD_MFMA_AVAILABLE)
            if (need_check) {
                i = min(i, i_max);
            }

            const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride;

            const int * scales = (const int *) bxi->scales;
            const int ksc = threadIdx.x % 2;

            const int sc32 = unpack_scales_q45_K(scales, ksc + 0);
            const int m32 = unpack_scales_q45_K(scales, ksc + 2);

            const uint8_t * sc8 = (const uint8_t *) &sc32;
            const uint8_t * m8 = (const uint8_t *) &m32;

            const half2 dm = bxi->dm * make_half2(1.0f, -1.0f);

#pragma unroll
            // int() cast keeps the comparison signed, as in load_tiles_q3_K / load_tiles_q5_K.
            for (int l = 0; l < int(sizeof(int)); ++l) {
                x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + sizeof(int)*ksc + l] = dm*make_half2(sc8[l], m8[l]);
            }
        }
    }
#else
#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps*warp_size) {
        int i = (i0 + threadIdx.y*warp_size + threadIdx.x) % mmq_y;

        if (need_check) {
            i = min(i, i_max);
        }

        const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride;

        x_dm[i] = bxi->dm;
    }

    constexpr int rows_per_warp = warp_size / 4;
#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps*rows_per_warp) {
        int i = (i0 + threadIdx.y*rows_per_warp + threadIdx.x/(MMQ_TILE_NE_K/8)) % mmq_y;

        if (need_check) {
            i = min(i, i_max);
        }

        const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride + (threadIdx.x % (MMQ_TILE_NE_K/8)) / (QI4_K/8);

        const int * scales = (const int *) bxi->scales;

        const int ksc = threadIdx.x % (MMQ_TILE_NE_K/8);
        const int scales8 = unpack_scales_q45_K(scales, ksc);

        x_sc[i*(MMQ_TILE_NE_K/8) + i/8 + ksc] = scales8;
    }
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
}

// dp4a variant of the q4_K x q8_1 tile dot product.
// The x tile is laid out as packed int quants, then half2 x_dm (one per row i), then unpacked scale words
// x_sc. sc points at the scale bytes for the current k position; sc+8 is passed as the second byte pointer.
// For each (i, j) = (i0 + threadIdx.x, j0 + threadIdx.y) the result of vec_dot_q4_K_q8_1_impl_mmq is added
// to sum for every k01 step.
template <int mmq_x, int mmq_y>
static __device__ __forceinline__ void vec_dot_q4_K_q8_1_dp4a(
    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
    constexpr int nwarps = mmq_get_nwarps_device();
    constexpr int warp_size = ggml_cuda_get_physical_warp_size();

    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_K, mmq_y);
    const int * x_qs = (const int *) x;
    const half2 * x_dm = (const half2 *) x_qs + txs.qs;
    const int * x_sc = (const int *) x_dm + txs.dm;
    const int * y_qs = (const int *) y + 4;
    const half2 * y_ds = (const half2 *) y;

// #pragma unroll
    for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QR4_K*VDR_Q4_K_Q8_1_MMQ) {
        const int k0 = k00 + k01;

#pragma unroll
        for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
            const int j = j0 + threadIdx.y;

#pragma unroll
            for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
                const int i = i0 + threadIdx.x;

                const uint8_t * sc = (const uint8_t *) &x_sc[i * (MMQ_TILE_NE_K/8) + i/8 + k0/32] + 2*(k01/16);

                sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q4_K_q8_1_impl_mmq(
                    &x_qs[i*(MMQ_TILE_NE_K + 1) + k0/2], &y_qs[j*MMQ_TILE_Y_K + k01], sc, sc+8,
                    x_dm[i], &y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);
            }
        }
    }
}

// Unpacks rows of block_q5_K data (blocks kbx0 + i*stride of x) into x_tile.
// Quants: the low / high nibbles of each packed word are combined with one extra bit each from qh
// (shifted to bit 4 of every byte) and written to positions kq0 / kq1.
// Scales: MMA layout stores per scale byte the half2 (dm.x*sc, -dm.y*m); dp4a layout stores dm per row in
// x_dm and the unpacked scale words in x_sc. With need_check the row index is clamped to i_max.
template <int mmq_y, bool need_check>
static __device__ __forceinline__ void load_tiles_q5_K(
    const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
    constexpr int nwarps = mmq_get_nwarps_device();
    constexpr int warp_size = ggml_cuda_get_physical_warp_size();

#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
    int * x_qs = (int *) x_tile;
    half2 * x_dm = (half2 *) (x_qs + MMQ_TILE_NE_K*2);
#else
    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_K, mmq_y);
    int * x_qs = (int *) x_tile;
    half2 * x_dm = (half2 *) (x_qs + txs.qs);
    int * x_sc = (int *) (x_dm + txs.dm);
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)

    constexpr int threads_per_row = MMQ_ITER_K / (4 * QR5_K);
    constexpr int nrows = warp_size / threads_per_row;
    const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x;

#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) {
        int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row);

        if (need_check) {
            i = min(i, i_max);
        }

        const block_q5_K * bxi = (const block_q5_K *) x + kbx0 + i*stride;
        const int ky = QR5_K*txi;

        const int ql = get_int_b4(bxi->qs, txi);
        const int ql0 = (ql >> 0) & 0x0F0F0F0F;
        const int ql1 = (ql >> 4) & 0x0F0F0F0F;

        const int qh = get_int_b4(bxi->qh, txi % (QI5_K/4));
        const int qh0 = ((qh >> (2 * (txi / (QI5_K/4)) + 0)) << 4) & 0x10101010;
        const int qh1 = ((qh >> (2 * (txi / (QI5_K/4)) + 1)) << 4) & 0x10101010;

        const int kq0 = ky - ky % (QI5_K/2) + txi % (QI5_K/4) + 0;
        const int kq1 = ky - ky % (QI5_K/2) + txi % (QI5_K/4) + QI5_K/4;

#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
        x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kq0] = ql0 | qh0;
        x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kq1] = ql1 | qh1;
#else
        x_qs[i*(2*MMQ_TILE_NE_K + 1) + kq0] = ql0 | qh0;
        x_qs[i*(2*MMQ_TILE_NE_K + 1) + kq1] = ql1 | qh1;
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
    }

#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
    constexpr int rows_per_warp = warp_size / 2;
#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps*rows_per_warp) {
#if defined(AMD_MFMA_AVAILABLE)
        // Need if on AMD instead of % because warp_size == 64
        // This causes double work and throughput loss (MI300X)
        // H100 loses about 100 t/s with 'if' condition over '%'
        int i = i0 + threadIdx.y*rows_per_warp + threadIdx.x/2;
        if (i < mmq_y) {
#else
        int i = (i0 + threadIdx.y*rows_per_warp + threadIdx.x/2) % mmq_y;
        {
#endif // defined(AMD_MFMA_AVAILABLE)
            if (need_check) {
                i = min(i, i_max);
            }

            const block_q5_K * bxi = (const block_q5_K *) x + kbx0 + i*stride;

            const int * scales = (const int *) bxi->scales;
            const int ksc = threadIdx.x % 2;

            const int sc32 = unpack_scales_q45_K(scales, ksc + 0);
            const int m32 = unpack_scales_q45_K(scales, ksc + 2);

            const uint8_t * sc8 = (const uint8_t *) &sc32;
            const uint8_t * m8 = (const uint8_t *) &m32;

            const half2 dm = bxi->dm * make_half2(1.0f, -1.0f);

#pragma unroll
            for (int l = 0; l < int(sizeof(int)); ++l) {
                x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + sizeof(int)*ksc + l] = dm*make_half2(sc8[l], m8[l]);
            }
        }
    }
#else
#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps*warp_size) {
        int i = (i0 + threadIdx.y*warp_size + threadIdx.x) % mmq_y;

        if (need_check) {
            i = min(i, i_max);
        }

        const block_q5_K * bxi = (const block_q5_K *) x + kbx0 + i*stride;

        x_dm[i] = bxi->dm;
    }

    constexpr int rows_per_warp = warp_size / 4;
#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps*rows_per_warp) {
        int i = (i0 + threadIdx.y*rows_per_warp + threadIdx.x/(MMQ_TILE_NE_K/8)) % mmq_y;

        if (need_check) {
            i = min(i, i_max);
        }

        const block_q5_K * bxi = (const block_q5_K *) x + kbx0 + i*stride;

        const int * scales = (const int *) bxi->scales;

        const int ksc = threadIdx.x % (MMQ_TILE_NE_K/8);
        const int scales8 = unpack_scales_q45_K(scales, ksc);

        x_sc[i*(MMQ_TILE_NE_K/8) + i/8 + ksc] = scales8;
    }
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
}

// dp4a variant of the q5_K x q8_1 tile dot product.
// The x tile is laid out as int quants, then half2 x_dm (one per row i), then unpacked scale words x_sc.
// sc points at the scale bytes for the current k position; sc+8 is passed as the second byte pointer.
// For each (i, j) = (i0 + threadIdx.x, j0 + threadIdx.y) the result of vec_dot_q5_K_q8_1_impl_mmq is added
// to sum for every k01 step.
template <int mmq_x, int mmq_y>
static __device__ __forceinline__ void vec_dot_q5_K_q8_1_dp4a(
    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
    constexpr int nwarps = mmq_get_nwarps_device();
    constexpr int warp_size = ggml_cuda_get_physical_warp_size();

    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_K, mmq_y);
    const int * x_qs = (const int *) x;
    const half2 * x_dm = (const half2 *) x_qs + txs.qs;
    const int * x_sc = (const int *) x_dm + txs.dm;
    const int * y_qs = (const int *) y + 4;
    const half2 * y_ds = (const half2 *) y;

// #pragma unroll
    for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QR5_K*VDR_Q5_K_Q8_1_MMQ) {
        const int k0 = k00 + k01;

#pragma unroll
        for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
            const int j = j0 + threadIdx.y;

#pragma unroll
            for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
                const int i = i0 + threadIdx.x;

                const uint8_t * sc = ((const uint8_t *) &x_sc[i * (MMQ_TILE_NE_K/8) + i/8 + k00/32]) + 2*(k01/16);

                sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q5_K_q8_1_impl_mmq(
                    &x_qs[i*(QR5_K*MMQ_TILE_NE_K + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01], sc, sc+8,
                    x_dm[i], &y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);
            }
        }
    }
}

template <int mmq_y, bool need_check>
static __device__ __forceinline__ void load_tiles_q6_K(
    const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
    constexpr int nwarps = mmq_get_nwarps_device();
    constexpr int warp_size = ggml_cuda_get_physical_warp_size();

#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
    int * x_qs = (int *) x_tile;
    float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
    int * x_sc = (int *) (x_df + MMQ_TILE_NE_K/QI6_K);
#else
    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q6_K, mmq_y);
    int * x_qs = (int *) x_tile;
    float * x_df = (float *) (x_qs + txs.qs);
    int * x_sc = (int *) (x_df + txs.dm);
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)

    constexpr int threads_per_row = MMQ_ITER_K / (4 * QR6_K);
    constexpr int nrows = warp_size / threads_per_row;
    const int txi = warp_size > threads_per_row ?
threadIdx.x % threads_per_row : threadIdx.x; #pragma unroll for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) { int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row); if (need_check) { i = min(i, i_max); } const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride; const int ql = get_int_b2(bxi->ql, txi); const int ql0 = (ql >> 0) & 0x0F0F0F0F; const int ql1 = (ql >> 4) & 0x0F0F0F0F; const int qh = get_int_b2(bxi->qh, (QI6_K/4) * (txi / (QI6_K/2)) + txi % (QI6_K/4)); const int qh0 = ((qh >> ((txi & 0x08) >> 2)) << 4) & 0x30303030; const int qh1 = (qh >> ((txi & 0x08) >> 2)) & 0x30303030; const int kq0 = 2*txi - txi % (QI6_K/2) + 0; const int kq1 = 2*txi - txi % (QI6_K/2) + QI6_K/2; #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) x_qs[i*MMQ_MMA_TILE_X_K_Q6_K + kq0] = __vsubss4(ql0 | qh0, 0x20202020); x_qs[i*MMQ_MMA_TILE_X_K_Q6_K + kq1] = __vsubss4(ql1 | qh1, 0x20202020); #else x_qs[i*(2*MMQ_TILE_NE_K + 1) + kq0] = __vsubss4(ql0 | qh0, 0x20202020); x_qs[i*(2*MMQ_TILE_NE_K + 1) + kq1] = __vsubss4(ql1 | qh1, 0x20202020); #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) } #pragma unroll for (int i0 = 0; i0 < mmq_y; i0 += nwarps*warp_size) { int i = (i0 + threadIdx.y*warp_size + threadIdx.x) % mmq_y; if (need_check) { i = min(i, i_max); } const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride; #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) x_df[i*MMQ_MMA_TILE_X_K_Q6_K] = bxi->d; #else x_df[i*(MMQ_TILE_NE_K/QI6_K) + i/QI6_K] = bxi->d; #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) } constexpr int rows_per_warp = warp_size / 4; #pragma unroll for (int i0 = 0; i0 < mmq_y; i0 += nwarps*rows_per_warp) { int i = (i0 + threadIdx.y*rows_per_warp + threadIdx.x/(MMQ_TILE_NE_K/8)) % mmq_y; if (need_check) { i = min(i, i_max); } const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride + (threadIdx.x % (MMQ_TILE_NE_K/8)) / 4; #if 
defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) x_sc[i*MMQ_MMA_TILE_X_K_Q6_K + threadIdx.x%4] = get_int_b2(bxi->scales, threadIdx.x % (MMQ_TILE_NE_K/8)); #else x_sc[i*(MMQ_TILE_NE_K/8) + i/8 + threadIdx.x%(MMQ_TILE_NE_K/8)] = get_int_b2(bxi->scales, threadIdx.x%(QI6_K/8)); #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) } }

// Adds the q6_K (x tile) times q8_1 (y tile) partial dot products of one k-slice to the per-thread accumulators.
//   x   : x tile; split below into quantized ints (x_qs), float scales (x_df, at offset txs.qs)
//         and packed int8 sub-scales (x_sc, a further txs.dm elements on).
//   y   : y tile; floats are read from its start and quantized ints from an offset of 4 ints,
//         consecutive columns are MMQ_TILE_Y_K elements apart.
//   sum : per-thread accumulators, one slot per (j0, i0) pair visited below.
//   k00 : base k offset that is added to every k01 step when indexing the x tile.
// The x_df / x_sc row strides used here are the ones the #else branch of load_tiles_q6_K writes with.
// NOTE(review): mmq_x and mmq_y are used but not declared here - presumably template parameters whose
// list is not visible after `template` in this text; confirm against the original file.
template static __device__ __forceinline__ void vec_dot_q6_K_q8_1_dp4a(
        const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
    constexpr int nwarps    = mmq_get_nwarps_device();
    constexpr int warp_size = ggml_cuda_get_physical_warp_size();

    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q6_K, mmq_y);

    const int   * x_qs = (const int   *) x;
    const float * x_df = (const float *) x_qs + txs.qs;
    const int   * x_sc = (const int   *) x_df + txs.dm;
    const int   * y_qs = (const int   *) y + 4;
    const float * y_df = (const float *) y;

    // Number of x-tile ints consumed per call of the inner dot product.
    constexpr int k_step = QR6_K*VDR_Q6_K_Q8_1_MMQ;

    // The unroll pragma on this outer loop is disabled in the original code and stays disabled here.
// #pragma unroll
    for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += k_step) {
        const int k0 = k00 + k01;

#pragma unroll
        for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
            const int col = j0 + threadIdx.y;

            // The y operands depend only on the column, so they are formed once per j0.
            const int   * y_q_col = &y_qs[col*MMQ_TILE_Y_K + k01];
            const float * y_d_col = &y_df[col*MMQ_TILE_Y_K + k01/QI8_1];

#pragma unroll
            for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
                const int row = i0 + threadIdx.x;

                const int8_t * row_sc = (const int8_t *) &x_sc[row * (MMQ_TILE_NE_K/8) + row/8 + k0/16];
                const float    row_d  = x_df[row*(MMQ_TILE_NE_K/QI6_K) + row/QI6_K];

                sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q6_K_q8_1_impl_mmq(
                    &x_qs[row*(QR6_K*MMQ_TILE_NE_K + 1) + k0], y_q_col, row_sc, row_d, y_d_col);
            }
        }
    }
}

template static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma( const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { #if defined(AMD_MFMA_AVAILABLE) typedef tile<16, 8, int> tile_A; typedef tile<16, 8, int> tile_B; typedef tile<16, 16, int> tile_C; typedef tile<64, 2, int> tile_load; constexpr int granularity = mmq_get_granularity_device(mmq_x); constexpr int rows_per_warp = 
granularity; constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp. y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K); const int * x_qs = (const int *) x; const float * x_df = (const float *) x_qs + MMQ_TILE_NE_K*2; const int * x_sc = (const int *) x_df + MMQ_TILE_NE_K/QI6_K; const int * y_qs = (const int *) y + 4; const float * y_df = (const float *) y; const int i0 = (threadIdx.y / ntx) * rows_per_warp; for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 4) { const int k0 = k00 + k01; tile_A A[ntx]; #pragma unroll for (int n = 0; n < ntx; ++n) { load_generic(((tile_load *) A)[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q6_K + k0, MMQ_MMA_TILE_X_K_Q6_K); } #pragma unroll for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) { tile_B B[1]; load_generic(((tile_load *) B)[0], y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); const int j = j0 + tile_C::get_j(0); const float dB = y_df[j*MMQ_TILE_Y_K + k01/QI8_1] / 2; #pragma unroll for (int n = 0; n < ntx; ++n) { tile_C C; mma(C, A[n], B[0]); #pragma unroll for (int l = 0; l < tile_C::ne; ++l) { const int i = i0 + n*tile_C::I + tile_C::get_i(l); const int8_t * sc = (const int8_t *) (x_sc + i*MMQ_MMA_TILE_X_K_Q6_K + k00/16); sum[(j0/tile_C::J + n)*tile_C::ne + l] += C.x[l] * sc[k01/4] * x_df[i*MMQ_MMA_TILE_X_K_Q6_K] * dB; } } } } #elif defined(TURING_MMA_AVAILABLE) typedef tile<16, 4, int> tile_A; typedef tile< 8, 4, int> tile_B; typedef tile<16, 8, int> tile_C; constexpr int granularity = mmq_get_granularity_device(mmq_x); constexpr int rows_per_warp = 2 * granularity; constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp. 
y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K); const int * x_qs = (const int *) x; const float * x_df = (const float *) x_qs + MMQ_TILE_NE_K*2; const int * x_sc = (const int *) x_df + MMQ_TILE_NE_K/QI6_K; const int * y_qs = (const int *) y + 4; const float * y_df = (const float *) y; const int i0 = (threadIdx.y / ntx) * (ntx*tile_A::I); tile_A A[ntx][8]; int scA[ntx][tile_C::ne/2][8]; float dA[ntx][tile_C::ne/2]; #pragma unroll for (int n = 0; n < ntx; ++n) { #pragma unroll for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 8) { const int k0 = k00 + k01; load_ldmatrix(A[n][k01/4 + 0], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q6_K + (k0 + 0), MMQ_MMA_TILE_X_K_Q6_K); load_ldmatrix(A[n][k01/4 + 1], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q6_K + (k0 + tile_A::J), MMQ_MMA_TILE_X_K_Q6_K); } #pragma unroll for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 16) { const int k0 = k00 + k01; #pragma unroll for (int l = 0; l < tile_C::ne/2; ++l) { const int i = i0 + n*tile_C::I + tile_C::get_i(2*l); const int sc_packed = x_sc[i*MMQ_MMA_TILE_X_K_Q6_K + k0/16]; const int8_t * sc = (const int8_t *) &sc_packed; #pragma unroll for (int ksc = 0; ksc < sizeof(int); ++ksc) { scA[n][l][k01/4 + ksc] = sc[ksc]; } } } #pragma unroll for (int l = 0; l < tile_C::ne/2; ++l) { const int i = i0 + n*tile_C::I + tile_C::get_i(2*l); dA[n][l] = x_df[i*MMQ_MMA_TILE_X_K_Q6_K]; } } #pragma unroll for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) { float tmp[ntx][tile_C::ne] = {{0.0f}}; #pragma unroll for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 8) { tile_B B[2]; float dB[tile_C::ne/2]; // Here load_generic is faster than load_ldmatrix. 
load_generic(B[0], y_qs + j0*MMQ_TILE_Y_K + 0 + k01, MMQ_TILE_Y_K); load_generic(B[1], y_qs + j0*MMQ_TILE_Y_K + tile_B::J + k01, MMQ_TILE_Y_K); #pragma unroll for (int l = 0; l < tile_C::ne/2; ++l) { const int j = j0 + tile_C::get_j(l); dB[l] = y_df[j*MMQ_TILE_Y_K + k01/QI8_1]; } #pragma unroll for (int n = 0; n < ntx; ++n) { tile_C C[2]; mma(C[0], A[n][k01/4 + 0], B[0]); mma(C[1], A[n][k01/4 + 1], B[1]); #pragma unroll for (int l = 0; l < tile_C::ne; ++l) { tmp[n][l] += (C[0].x[l]*scA[n][l/2][k01/4 + 0] + C[1].x[l]*scA[n][l/2][k01/4 + 1])*dB[l%2]; } } } #pragma unroll for (int n = 0; n < ntx; ++n) { #pragma unroll for (int l = 0; l < tile_C::ne; ++l) { sum[(j0/tile_C::J + n)*tile_C::ne + l] += tmp[n][l]*dA[n][l/2]; } } } #else GGML_UNUSED_VARS(x, y, sum, k00); NO_DEVICE_CODE; #endif // AMD_MFMA_AVAILABLE }

// Fills the x tile from IQ4_NL blocks.
//   x       : source data, reinterpreted as an array of block_iq4_nl.
//   x_tile  : destination tile; quantized ints are stored at its start, float block scales after them
//             (at MMQ_TILE_NE_K*2 ints on the MMA paths, at txs.qs ints otherwise).
//   kbx0    : index of the first source block of row 0.
//   i_max   : largest row index that may be read; rows are clamped to it when need_check is set.
//   stride  : distance in blocks between consecutive source rows.
// Pass 1 expands each packed 4-bit word through kvalues_iq4nl into two ints stored QI4_NL apart.
// Pass 2 converts each block's half-precision d to float.
// NOTE(review): mmq_y and need_check are used but not declared here - presumably template parameters
// whose list is not visible after `template` in this text; confirm against the original file.
template static __device__ __forceinline__ void load_tiles_iq4_nl(
        const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
    constexpr int nwarps    = mmq_get_nwarps_device();
    constexpr int warp_size = ggml_cuda_get_physical_warp_size();

    int * x_qs = (int *) x_tile;
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
    float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
#else
    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ4_NL, mmq_y);
    float * x_df = (float *) (x_qs + txs.qs);
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)

    // First source block of row 0; row and block offsets are added per thread below.
    const block_iq4_nl * blocks = (const block_iq4_nl *) x + kbx0;

    // ---- Pass 1: quantized values ----
    constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_NL);
    constexpr int nrows           = warp_size / threads_per_row;

    const int lane_in_row = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x;
    const int src_block   = lane_in_row / QI4_NL; // block within the row
    const int src_word    = lane_in_row % QI4_NL; // packed int within that block
    const int dst_k       = src_block * (2 * QI4_NL) + src_word;

    // Row handled by this thread within one pass of the loop; does not depend on i0.
    const int qs_row_in_pass = nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row;

#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) {
        int row = i0 + qs_row_in_pass;

        if (need_check) {
            row = min(row, i_max);
        }

        const block_iq4_nl * src = blocks + row*stride + src_block;

        const int2 expanded = get_int_from_table_16(get_int_b2(src->qs, src_word), kvalues_iq4nl);

#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
        int * dst_row = x_qs + row*MMQ_MMA_TILE_X_K_Q8_0;
#else
        int * dst_row = x_qs + row*(2*MMQ_TILE_NE_K + 1);
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)

        dst_row[dst_k + 0]      = expanded.x;
        dst_row[dst_k + QI4_NL] = expanded.y;
    }

    // ---- Pass 2: block scales ----
    constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI4_NL;
    constexpr int rows_per_warp         = warp_size / blocks_per_tile_x_row;

    const int scale_block       = threadIdx.x % blocks_per_tile_x_row;
    const int scale_row_in_pass = threadIdx.y * rows_per_warp + threadIdx.x / blocks_per_tile_x_row;

#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * rows_per_warp) {
        int row = i0 + scale_row_in_pass;

        if (need_check) {
            row = min(row, i_max);
        }

        const block_iq4_nl * src = blocks + row*stride + scale_block;

#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
        x_df[row*MMQ_MMA_TILE_X_K_Q8_0 + scale_block] = __half2float(src->d);
#else
        x_df[row*(MMQ_TILE_NE_K/QI4_NL) + row/QI4_NL + scale_block] = __half2float(src->d);
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
    }
}

template static __device__ __forceinline__ void load_tiles_iq2_xxs( const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { constexpr int nwarps = mmq_get_nwarps_device(); constexpr int warp_size = ggml_cuda_get_physical_warp_size(); #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2); #else constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ2_XXS, mmq_y); int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + txs.qs); #endif 
// defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR2_XXS)) / 2; constexpr int nrows = warp_size / threads_per_row; const int kqsx = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x; #pragma unroll for (int i0 = 0; i0 < mmq_y; i0 += nwarps * nrows) { int i = i0 + threadIdx.y*nrows + threadIdx.x/threads_per_row; if (need_check) { i = min(i, i_max); } const block_iq2_xxs * bxi = (const block_iq2_xxs *) x + kbx0 + i*stride; const int q2 = get_int_b2(bxi->qs, 2*kqsx+0); const uint8_t * aux8 = (const uint8_t *) &q2; const uint32_t aux32 = get_int_b2(bxi->qs, 2*kqsx+1); #pragma unroll for (int l = 0; l < QR2_XXS; ++l) { const int * grid_pos = (const int *) (iq2xxs_grid + aux8[l]); const int signs_packed = ksigns_iq2xs[(aux32 >> (7*l)) & 0x7F]; const int signs0 = __vcmpne4(((signs_packed & 0x03) << 7) | ((signs_packed & 0x0C) << 21), 0x00000000); const int grid0 = __vsub4(grid_pos[0] ^ signs0, signs0); const int signs1 = __vcmpne4(((signs_packed & 0x30) << 3) | ((signs_packed & 0xC0) << 17), 0x00000000); const int grid1 = __vsub4(grid_pos[1] ^ signs1, signs1); #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 0)] = grid0; x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 1)] = grid1; #else x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 0)] = grid0; x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 1)] = grid1; #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) } const int ls = aux32 >> 28; const float d = bxi->d; #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = (ls*d + d/2)/4; #else x_df[i*(MMQ_TILE_NE_K/4) + i/4 + kqsx] = (ls*d + d/2)/4; #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) } } template static __device__ __forceinline__ void load_tiles_iq2_xs( const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const 
int i_max, const int stride) { constexpr int nwarps = mmq_get_nwarps_device(); constexpr int warp_size = ggml_cuda_get_physical_warp_size(); #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2); #else constexpr tile_x_sizes txs = MMQ_DP4A_TXS_Q8_0_16; int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + txs.qs); #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR2_XS)) / 2; constexpr int nrows = warp_size / threads_per_row; const int kqsx = threadIdx.x % threads_per_row; #pragma unroll for (int i0 = 0; i0 < mmq_y; i0 += nwarps * nrows) { int i = i0 + threadIdx.y*nrows + threadIdx.x/threads_per_row; if (need_check) { i = min(i, i_max); } const block_iq2_xs * bxi = (const block_iq2_xs *) x + kbx0 + i*stride; const int2 q2_packed = make_int2(get_int_b2(bxi->qs, 2*kqsx+0), get_int_b2(bxi->qs, 2*kqsx+1)); const uint16_t * q2 = (const uint16_t *) &q2_packed; #pragma unroll for (int l = 0; l < QR2_XS; ++l) { const uint32_t * grid_pos = (const uint32_t *)(iq2xs_grid + (q2[l] & 0x000001FF)); const uint32_t * signs = (const uint32_t *)(ksigns64 + (q2[l] >> 9)); const int grid_l = __vsub4(grid_pos[0] ^ signs[0], signs[0]); const int grid_h = __vsub4(grid_pos[1] ^ signs[1], signs[1]); #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 0)] = grid_l; x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 1)] = grid_h; #else x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 0)] = grid_l; x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 1)] = grid_h; #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) } const int ls = bxi->scales[kqsx]; const float d = bxi->d; #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4; x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+1] = ((ls >> 4)*d 
+ d/2)/4; #else x_df[i*(2*MMQ_TILE_NE_K*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4; x_df[i*(2*MMQ_TILE_NE_K*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4; #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) } } template static __device__ __forceinline__ void load_tiles_iq2_s( const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { constexpr int nwarps = mmq_get_nwarps_device(); constexpr int warp_size = ggml_cuda_get_physical_warp_size(); #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2); #else constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ2_S, mmq_y); int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + txs.qs); #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR2_S)) / 2; constexpr int nrows = warp_size / threads_per_row; const int kqsx = threadIdx.x % threads_per_row; #pragma unroll for (int i0 = 0; i0 < mmq_y; i0 += nwarps * nrows) { int i = i0 + threadIdx.y*nrows + threadIdx.x/threads_per_row; if (need_check) { i = min(i, i_max); } const block_iq2_s * bxi = (const block_iq2_s *) x + kbx0 + i*stride; const int qs_packed = get_int_b2(bxi->qs, kqsx); const uint8_t * qs = (const uint8_t *) &qs_packed; const int qh = bxi->qh[kqsx]; const int signs_packed_32 = get_int_b2(bxi->qs, QK_K/32 + kqsx); const uint8_t * signs_packed_8 = (const uint8_t *) &signs_packed_32; #pragma unroll for (int l = 0; l < QR2_S; ++l) { const int * grid_pos = (const int *)(iq2s_grid + (qs[l] | ((qh << (8-2*l)) & 0x300))); const int signs0 = __vcmpne4(((signs_packed_8[l] & 0x03) << 7) | ((signs_packed_8[l] & 0x0C) << 21), 0x00000000); const int signs1 = __vcmpne4(((signs_packed_8[l] & 0x30) << 3) | ((signs_packed_8[l] & 0xC0) << 17), 0x00000000); const int grid_l = __vsub4(grid_pos[0] ^ signs0, 
signs0); const int grid_h = __vsub4(grid_pos[1] ^ signs1, signs1); #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 0)] = grid_l; x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 1)] = grid_h; #else x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 0)] = grid_l; x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 1)] = grid_h; #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) } const int ls = bxi->scales[kqsx]; const float d = bxi->d; #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4; x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4; #else x_df[i*(2*MMQ_TILE_NE_K*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4; x_df[i*(2*MMQ_TILE_NE_K*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4; #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) } } template static __device__ __forceinline__ void load_tiles_iq3_xxs( const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { constexpr int nwarps = mmq_get_nwarps_device(); constexpr int warp_size = ggml_cuda_get_physical_warp_size(); #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2); #else constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ3_XXS, mmq_y); int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + txs.qs); #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR3_XXS)) / 2; constexpr int nrows = warp_size / threads_per_row; const int kqsx = threadIdx.x % threads_per_row; #pragma unroll for (int i0 = 0; i0 < mmq_y; i0 += nwarps * nrows) { int i = i0 + threadIdx.y*nrows + threadIdx.x/threads_per_row; if (need_check) { i = min(i, i_max); } const block_iq3_xxs * bxi = (const 
block_iq3_xxs *) x + kbx0 + i*stride; const int2 q3_packed = make_int2(get_int_b2(bxi->qs, 2*kqsx+0), get_int_b2(bxi->qs, 2*kqsx+1)); const uint8_t * q3 = (const uint8_t *) &q3_packed; const uint32_t aux32 = get_int_b2(bxi->qs, QK_K/16 + kqsx); #pragma unroll for (int l = 0; l < QR3_XXS; ++l) { const int2 grid_pos = make_int2(iq3xxs_grid[q3[2*l+0]], iq3xxs_grid[q3[2*l+1]]); const int * signs = (const int *)(ksigns64 + ((aux32 >> (7*l)) & 0x7F)); const int grid_l = __vsub4(grid_pos.x ^ signs[0], signs[0]); const int grid_h = __vsub4(grid_pos.y ^ signs[1], signs[1]); #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 0)] = grid_l; x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 1)] = grid_h; #else x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 0)] = grid_l; x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 1)] = grid_h; #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) } const int ls = aux32 >> 28; const float d = bxi->d; #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = (ls*d + d/2)/2; #else x_df[i*(MMQ_TILE_NE_K/4) + i/4 + kqsx] = (ls*d + d/2)/2; #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) } } template static __device__ __forceinline__ void load_tiles_iq3_s( const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { constexpr int nwarps = mmq_get_nwarps_device(); constexpr int warp_size = ggml_cuda_get_physical_warp_size(); #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2); #else constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ3_S, mmq_y); int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + txs.qs); #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR3_S)) / 2; 
constexpr int nrows = warp_size / threads_per_row; const int kqsx = threadIdx.x % threads_per_row; #pragma unroll for (int i0 = 0; i0 < mmq_y; i0 += nwarps * nrows) { int i = i0 + threadIdx.y*nrows + threadIdx.x/threads_per_row; if (need_check) { i = min(i, i_max); } const block_iq3_s * bxi = (const block_iq3_s *) x + kbx0 + i*stride; const int2 qs_packed = make_int2(get_int_b2(bxi->qs, 2*kqsx+0), get_int_b2(bxi->qs, 2*kqsx+1)); const uint8_t * qs = (const uint8_t *) &qs_packed; const int qh = bxi->qh[kqsx]; const int signs_packed_32 = get_int_b2(bxi->signs, kqsx); const uint8_t * signs_packed_8 = (const uint8_t *) &signs_packed_32; #pragma unroll for (int l = 0; l < QR3_S; ++l) { const int2 grid_pos = make_int2( iq3s_grid[qs[2*l+0] | ((qh << (8 - 2*l)) & 0x100)], iq3s_grid[qs[2*l+1] | ((qh << (7 - 2*l)) & 0x100)]); const int signs0 = __vcmpne4(((signs_packed_8[l] & 0x03) << 7) | ((signs_packed_8[l] & 0x0C) << 21), 0x00000000); const int signs1 = __vcmpne4(((signs_packed_8[l] & 0x30) << 3) | ((signs_packed_8[l] & 0xC0) << 17), 0x00000000); const int grid_l = __vsub4(grid_pos.x ^ signs0, signs0); const int grid_h = __vsub4(grid_pos.y ^ signs1, signs1); #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l+0)] = grid_l; x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l+1)] = grid_h; #else x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l+0)] = grid_l; x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l+1)] = grid_h; #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) } const int ls = 1 + 2*((bxi->scales[kqsx/2] >> (((2*kqsx) << 1) & 0x04)) & 0x0F); const float d = bxi->d; #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = ls*d; #else x_df[i*(MMQ_TILE_NE_K/4) + i/4 + kqsx] = ls*d; #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) } }

// Fills the x tile from IQ1_S blocks.
//   x       : source data, reinterpreted as an array of block_iq1_s.
//   x_tile  : destination tile; quantized ints are stored at its start, half2 (scale, scale*delta) pairs
//             after them (at MMQ_TILE_NE_K*2 ints on the MMA paths, at txs.qs ints otherwise).
//   kbx0    : index of the first source block of row 0.
//   i_max   : largest row index that may be read; rows are clamped to it when need_check is set.
//   stride  : distance in blocks between consecutive source rows.
// Each thread handles one group kqsx of one row: QR1_S/2 grid words are looked up in iq1s_grid_gpu,
// each is split into its low and high nibbles, and the resulting ints are stored at 8*kqsx + 0..7.
// NOTE(review): mmq_y and need_check are used but not declared here - presumably template parameters
// whose list is not visible after `template` in this text; confirm against the original file.
template static __device__ __forceinline__ void load_tiles_iq1_s(
        const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
    constexpr int nwarps = mmq_get_nwarps_device();
    constexpr int warp_size = ggml_cuda_get_physical_warp_size();

#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
    int   * x_qs = (int   *)  x_tile;
    half2 * x_ds = (half2 *) (x_qs + MMQ_TILE_NE_K*2);
#else
    // Tile sizes are requested for the type this loader handles (IQ1_S) instead of GGML_TYPE_IQ3_S,
    // in line with the other loaders, which each pass their own type tag. Only txs.qs is used here.
    // NOTE(review): assumes mmq_get_dp4a_tile_x_sizes has an IQ1_S entry with the same qs size - confirm.
    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ1_S, mmq_y);
    int   * x_qs = (int   *)  x_tile;
    half2 * x_ds = (half2 *) (x_qs + txs.qs);
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)

    constexpr int threads_per_row = MMQ_ITER_K / (4 * QR1_S);
    constexpr int nrows = warp_size / threads_per_row;
    const int kqsx = threadIdx.x % threads_per_row;

#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * nrows) {
        int i = i0 + threadIdx.y*nrows + threadIdx.x/threads_per_row;

        if (need_check) {
            i = min(i, i_max);
        }

        const block_iq1_s * bxi = (const block_iq1_s *) x + kbx0 + i*stride;

        // One packed int carries the low 8 index bits of the grid entries used below, one byte per entry.
        const int       qs_packed = get_int_b2(bxi->qs, kqsx);
        const uint8_t * qs        = (const uint8_t *) &qs_packed;

        const int qh = bxi->qh[kqsx];

#pragma unroll
        for (int l = 0; l < QR1_S/2; ++l) {
            // Grid index = low byte from qs plus 3 further bits taken from qh (bits 3*l .. 3*l+2).
            const int grid = iq1s_grid_gpu[qs[l] | (((qh >> (3*l)) & 0x07) << 8)];

            const int grid0 = (grid >> 0) & 0x0F0F0F0F;
            const int grid1 = (grid >> 4) & 0x0F0F0F0F;

#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
            x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 8*kqsx + (2*l+0)] = grid0;
            x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 8*kqsx + (2*l+1)] = grid1;
#else
            x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l+0)] = grid0;
            x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l+1)] = grid1;
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
        }

        // Bits 12..14 of qh hold a 3-bit value s; the block scale d is multiplied by 2*s + 1.
        const float d1q   = __half2float(bxi->d) * (((qh >> 11) & 0x0E) + 1);
        // Bit 15 of qh selects the offset: -1 + IQ1S_DELTA when clear, -1 - IQ1S_DELTA when set.
        const float delta = -1.0f + IQ1S_DELTA - (qh & 0x8000) * (2.0f*IQ1S_DELTA/0x8000);

#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
        x_ds[i*MMQ_MMA_TILE_X_K_Q8_1 + kqsx] = make_half2(d1q, d1q*delta);
#else
        x_ds[i*(MMQ_TILE_NE_K/4) + i/4 + kqsx] = make_half2(d1q, d1q*delta);
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
    }
}

template static __device__ __forceinline__ void load_tiles_iq4_xs( const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { constexpr int nwarps = mmq_get_nwarps_device(); constexpr int warp_size = ggml_cuda_get_physical_warp_size(); #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2); #else constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ4_XS, mmq_y); int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + txs.qs); #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_XS); constexpr int nrows = warp_size / threads_per_row; const int kqsx = threadIdx.x % threads_per_row; #pragma unroll for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) { int i = i0 + (nrows == 1 ? 
threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row); if (need_check) { i = min(i, i_max); } const block_iq4_xs * bxi = (const block_iq4_xs *) x + kbx0 + i*stride; const int aux_q4 = get_int_b4(bxi->qs, kqsx); const int2 v = get_int_from_table_16(aux_q4, kvalues_iq4nl); const int k0 = 8 * (kqsx / 4) + kqsx % 4; #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 0] = v.x; x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 4] = v.y; #else x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + 0] = v.x; x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + 4] = v.y; #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) } constexpr int rows_per_warp = warp_size / 8; #pragma unroll for (int i0 = 0; i0 < mmq_y; i0 += nwarps * rows_per_warp) { int i = i0 + threadIdx.y * rows_per_warp + threadIdx.x / (MMQ_TILE_NE_K/4); if (need_check) { i = min(i, i_max); } const block_iq4_xs * bxi = (const block_iq4_xs *) x + kbx0 + i*stride; const float d = __half2float(bxi->d); const int ls = ((bxi->scales_l[(threadIdx.x % 8)/2] >> (4*(threadIdx.x % 2))) & 0x0F) | (((bxi->scales_h >> (2*(threadIdx.x % 8))) & 0x03) << 4); #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + threadIdx.x % 8] = d * (ls - 32); #else x_df[i*(MMQ_TILE_NE_K/4) + i/4 + threadIdx.x % 8] = d * (ls - 32); #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) } } template static __device__ __forceinline__ void mmq_write_back_dp4a( const float * __restrict__ sum, const int32_t * __restrict__ ids_dst, float * __restrict__ dst, const int stride, const int i_max, const int j_max) { constexpr int nwarps = mmq_get_nwarps_device(); constexpr int warp_size = ggml_cuda_get_physical_warp_size(); #pragma unroll for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { const int j = j0 + threadIdx.y; if (j > j_max) { return; } #pragma unroll for (int i0 = 0; i0 < mmq_y; i0 += warp_size) { const int i = i0 + threadIdx.x; if (need_check && i > i_max) { 
continue; } dst[ids_dst[j]*stride + i] = sum[(j0/nwarps) * (mmq_y/warp_size) + i0/warp_size]; } } }

// Writes the per-thread accumulators in `sum` to `dst`, using tile_C's element-to-(i, j) mapping.
//   sum     : per-thread accumulators, read as sum[(j0/tile_C::J + n)*tile_C::ne + l].
//   ids_dst : maps tile column j to a destination column; the store goes to dst[ids_dst[j]*stride + i].
//   dst     : output buffer.
//   stride  : distance in floats between consecutive destination columns.
//   i_max   : largest row index that may be written; only enforced when need_check is set.
//   j_max   : largest column index that may be written; always enforced.
// NOTE(review): need_check, mmq_x and mmq_y are used but not declared here - presumably template parameters
// whose list is not visible after `template` in this text; confirm against the original file.
template static __device__ __forceinline__ void mmq_write_back_mma(
        const float * __restrict__ sum, const int * __restrict__ ids_dst, float * __restrict__ dst,
        const int stride, const int i_max, const int j_max) {

    constexpr int granularity = mmq_get_granularity_device(mmq_x);
    constexpr int nwarps = mmq_get_nwarps_device();

#if defined(AMD_MFMA_AVAILABLE)
    constexpr int tileC_IJ = mmq_get_granularity_device(0);
    // NOTE(review): tile_C is declared here without visible template arguments, and tileC_IJ is not used
    // anywhere else in this function - presumably the arguments (built from tileC_IJ) were lost; confirm.
    typedef tile tile_C;
    constexpr int rows_per_warp = granularity;
#else
    typedef tile<16, 8, int> tile_C;
    constexpr int rows_per_warp = 2 * granularity;
#endif // defined(AMD_MFMA_AVAILABLE)
    constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.

    // First row covered by this warp's group of ntx minitiles.
    const int i0 = (threadIdx.y / ntx) * (ntx*tile_C::I);
#if defined(TURING_MMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
    static_assert(nwarps*tile_C::I == mmq_y, "nwarps*tile_C::I != mmq_y");
#else
    // nwarps is only needed by the static_assert above; mark it used on the other paths.
    GGML_UNUSED(nwarps);
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)

#pragma unroll
    for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
#pragma unroll
        for (int n = 0; n < ntx; ++n) {
#pragma unroll
            for (int l = 0; l < tile_C::ne; ++l) {
                // Column of element l; the warp's position within its group of ntx selects the column block.
                const int j = j0 + (threadIdx.y % ntx) * tile_C::J + tile_C::get_j(l);

                // Skip columns past the valid range; `continue` rather than `return` because
                // later elements l may map to valid columns.
                if (j > j_max) {
                    continue;
                }

                // Row of element l within minitile n.
                const int i = i0 + n*tile_C::I + tile_C::get_i(l);

                if (need_check && i > i_max) {
                    continue;
                }

                dst[ids_dst[j]*stride + i] = sum[(j0/tile_C::J + n)*tile_C::ne + l];
            }
        }
    }
}

// ------------------------------------------------------------------------------------------------------------------------------------- template struct mmq_type_traits; template struct mmq_type_traits { static constexpr int vdr = VDR_Q4_0_Q8_1_MMQ; static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_0; static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q4_0_q8_1_dp4a; }; template struct mmq_type_traits { 
// Per-quantization-type kernel selection tables.
// Each mmq_type_traits specialization bundles four compile-time constants:
//   vdr          - a VDR_*_Q8_1_MMQ constant (not read anywhere in the visible part of this file),
//   load_tiles   - the load_tiles_mmq_t function that mul_mat_q_process_tile calls to fill the x tile,
//   vec_dot_mma  - the vec_dot_mmq_t used when AMD_MFMA_AVAILABLE or TURING_MMA_AVAILABLE is defined,
//   vec_dot_dp4a - the vec_dot_mmq_t used otherwise.
//
// NOTE(review): every "template struct mmq_type_traits {" header below has no template parameter or argument list.
// They look like they were stripped during extraction; the type named in each comment is inferred from the
// VDR_* / load_tiles_* identifiers only - confirm against the original file.
//
// NOTE(review): this span starts inside a specialization whose header lies before it (presumably GGML_TYPE_Q4_1).
    static constexpr int              vdr          = VDR_Q4_1_Q8_1_MMQ;
    static constexpr load_tiles_mmq_t load_tiles   = load_tiles_q4_1;
    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q8_1_q8_1_mma;
    static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q4_1_q8_1_dp4a;
};

// Presumably GGML_TYPE_Q5_0.
template
struct mmq_type_traits {
    static constexpr int              vdr          = VDR_Q5_0_Q8_1_MMQ;
    static constexpr load_tiles_mmq_t load_tiles   = load_tiles_q5_0;
    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q8_0_q8_1_mma;
    static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a;
};

// Presumably GGML_TYPE_Q5_1.
template
struct mmq_type_traits {
    static constexpr int              vdr          = VDR_Q5_1_Q8_1_MMQ;
    static constexpr load_tiles_mmq_t load_tiles   = load_tiles_q5_1;
    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q8_1_q8_1_mma;
    static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q8_1_q8_1_dp4a;
};

// Presumably GGML_TYPE_Q8_0.
template
struct mmq_type_traits {
    static constexpr int              vdr          = VDR_Q8_0_Q8_1_MMQ;
    static constexpr load_tiles_mmq_t load_tiles   = load_tiles_q8_0;
    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q8_0_q8_1_mma;
    static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a;
};

// Presumably GGML_TYPE_MXFP4.
template
struct mmq_type_traits {
    static constexpr int              vdr          = VDR_MXFP4_Q8_1_MMQ;
    static constexpr load_tiles_mmq_t load_tiles   = load_tiles_mxfp4;
    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q8_0_q8_1_mma;
    static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a;
};

// Presumably GGML_TYPE_Q2_K. Uses dedicated q2_K dot products on both code paths.
template
struct mmq_type_traits {
    static constexpr int              vdr          = VDR_Q2_K_Q8_1_MMQ;
    static constexpr load_tiles_mmq_t load_tiles   = load_tiles_q2_K;
    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q2_K_q8_1_mma;
    static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q2_K_q8_1_dp4a;
};

// Presumably GGML_TYPE_Q3_K. The MMA path uses the q8_0_16 dot product, the dp4a path a dedicated q3_K one.
template
struct mmq_type_traits {
    static constexpr int              vdr          = VDR_Q3_K_Q8_1_MMQ;
    static constexpr load_tiles_mmq_t load_tiles   = load_tiles_q3_K;
    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q8_0_16_q8_1_mma;
    static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q3_K_q8_1_dp4a;
};

// Presumably GGML_TYPE_Q4_K.
template
struct mmq_type_traits {
    static constexpr int              vdr          = VDR_Q4_K_Q8_1_MMQ;
    static constexpr load_tiles_mmq_t load_tiles   = load_tiles_q4_K;
    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q8_1_q8_1_mma;
    static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q4_K_q8_1_dp4a;
};

// Presumably GGML_TYPE_Q5_K.
template
struct mmq_type_traits {
    static constexpr int              vdr          = VDR_Q5_K_Q8_1_MMQ;
    static constexpr load_tiles_mmq_t load_tiles   = load_tiles_q5_K;
    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q8_1_q8_1_mma;
    static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q5_K_q8_1_dp4a;
};

// Presumably GGML_TYPE_Q6_K. Uses dedicated q6_K dot products on both code paths.
template
struct mmq_type_traits {
    static constexpr int              vdr          = VDR_Q6_K_Q8_1_MMQ;
    static constexpr load_tiles_mmq_t load_tiles   = load_tiles_q6_K;
    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q6_K_q8_1_mma;
    static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q6_K_q8_1_dp4a;
};

// Presumably GGML_TYPE_IQ2_XXS.
template
struct mmq_type_traits {
    static constexpr int              vdr          = VDR_IQ2_XXS_Q8_1_MMQ;
    static constexpr load_tiles_mmq_t load_tiles   = load_tiles_iq2_xxs;
    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q8_0_q8_1_mma;
    static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a;
};

// Presumably GGML_TYPE_IQ2_XS.
template
struct mmq_type_traits {
    static constexpr int              vdr          = VDR_IQ2_XS_Q8_1_MMQ;
    static constexpr load_tiles_mmq_t load_tiles   = load_tiles_iq2_xs;
    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q8_0_16_q8_1_mma;
    static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q8_0_16_q8_1_dp4a;
};

// Presumably GGML_TYPE_IQ2_S.
template
struct mmq_type_traits {
    static constexpr int              vdr          = VDR_IQ2_S_Q8_1_MMQ;
    static constexpr load_tiles_mmq_t load_tiles   = load_tiles_iq2_s;
    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q8_0_16_q8_1_mma;
    static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q8_0_16_q8_1_dp4a;
};

// Presumably GGML_TYPE_IQ3_XXS.
template
struct mmq_type_traits {
    static constexpr int              vdr          = VDR_IQ3_XXS_Q8_1_MMQ;
    static constexpr load_tiles_mmq_t load_tiles   = load_tiles_iq3_xxs;
    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q8_0_q8_1_mma;
    static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a;
};

// Header of the next specialization; its "struct mmq_type_traits {" follows directly after this span.
template
// Remaining mmq_type_traits specializations (the "template" keyword of the first one precedes this span).
// NOTE(review): the template parameter/argument lists of these headers appear to have been stripped during
// extraction; the type named in each comment is inferred from the VDR_* / load_tiles_* identifiers - confirm.

// Presumably GGML_TYPE_IQ3_S.
struct mmq_type_traits {
    static constexpr int              vdr          = VDR_IQ3_S_Q8_1_MMQ;
    static constexpr load_tiles_mmq_t load_tiles   = load_tiles_iq3_s;
    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q8_0_q8_1_mma;
    static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a;
};

// Presumably GGML_TYPE_IQ1_S.
template
struct mmq_type_traits {
    static constexpr int              vdr          = VDR_IQ1_S_Q8_1_MMQ;
    static constexpr load_tiles_mmq_t load_tiles   = load_tiles_iq1_s;
    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q8_1_q8_1_mma;
    static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q8_1_q8_1_dp4a;
};

// Presumably GGML_TYPE_IQ4_NL.
template
struct mmq_type_traits {
    static constexpr int              vdr          = VDR_IQ4_NL_Q8_1_MMQ;
    static constexpr load_tiles_mmq_t load_tiles   = load_tiles_iq4_nl;
    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q8_0_q8_1_mma;
    static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a;
};

// Presumably GGML_TYPE_IQ4_XS.
template
struct mmq_type_traits {
    static constexpr int              vdr          = VDR_IQ4_XS_Q8_1_MMQ;
    static constexpr load_tiles_mmq_t load_tiles   = load_tiles_iq4_xs;
    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q8_0_q8_1_mma;
    static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a;
};

// Accumulates one output tile over the k blocks [kb0_start, kb0_stop) and writes the result back.
//
// Parameters:
//   x, offset_x                - quantized x data and the block offset of this tile; forwarded to load_tiles
//                                as offset_x + kb0.
//   y                          - y data as ints, already offset to this tile's columns by the caller.
//   ids_dst                    - per-column destination indices handed to write_back.
//   dst, stride_col_dst        - output pointer/stride used when fixup is false.
//   tmp_fixup                  - scratch buffer used when fixup is true; this block writes its mmq_x*mmq_y
//                                partial sums at tmp_fixup + blockIdx.x*(mmq_x*mmq_y).
//   stride_row_x               - forwarded to load_tiles.
//   ncols_y                    - multiplier applied to the k-dependent offset into y.
//   tile_x_max_i, tile_y_max_j - bounds forwarded to load_tiles / write_back.
//   kb0_start, kb0_stop        - half-open range of k blocks, advanced in steps of MMQ_ITER_K / qk.
//
// Dynamic shared memory layout (in ints): the first mmq_x entries are skipped, then tile_y, then tile_x,
// which starts GGML_PAD(mmq_x*MMQ_TILE_Y_K, nwarps*warp_size) ints after tile_y.
//
// NOTE(review): the template parameter list is missing (apparently stripped during extraction). The body uses
// mmq_x, need_check-style compile-time values and `fixup` without declaring them, so they are presumably
// template parameters; `ggml_cuda_type_traits::qk` and `mmq_type_traits::...` likewise lack their arguments.
template
static __device__ __forceinline__ void mul_mat_q_process_tile(
        const char * __restrict__ x, const int offset_x, const int * __restrict__ y,
        const int * __restrict__ ids_dst, float * __restrict__ dst, float * __restrict__ tmp_fixup,
        const int stride_row_x, const int ncols_y, const int stride_col_dst,
        const int tile_x_max_i, const int tile_y_max_j, const int kb0_start, const int kb0_stop) {

    constexpr int              warp_size  = ggml_cuda_get_physical_warp_size();
    constexpr int              nwarps     = mmq_get_nwarps_device();
    constexpr int              qk         = ggml_cuda_type_traits::qk;
    constexpr int              mmq_y      = get_mmq_y_device();
    constexpr load_tiles_mmq_t load_tiles = mmq_type_traits::load_tiles;

    extern __shared__ int data_mul_mat_q[];
    int * tile_y = data_mul_mat_q + mmq_x; // skip the mmq_x ints at the start of shared memory
    int * tile_x = tile_y + GGML_PAD(mmq_x*MMQ_TILE_Y_K, nwarps*warp_size);

    // Pick the dot-product and write-back implementations at compile time.
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
    constexpr vec_dot_mmq_t    vec_dot    = mmq_type_traits::vec_dot_mma;
    constexpr mmq_write_back_t write_back = mmq_write_back_mma;
#else
    constexpr vec_dot_mmq_t    vec_dot    = mmq_type_traits::vec_dot_dp4a;
    constexpr mmq_write_back_t write_back = mmq_write_back_dp4a;
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)

    constexpr int blocks_per_iter = MMQ_ITER_K / qk;

    // Per-thread accumulators: the mmq_x*mmq_y tile entries are split evenly over nwarps*warp_size threads.
    float sum[mmq_x*mmq_y / (nwarps*warp_size)] = {0.0f};

    for (int kb0 = kb0_start; kb0 < kb0_stop; kb0 += blocks_per_iter) {
        load_tiles(x, tile_x, offset_x + kb0, tile_x_max_i, stride_row_x);

        // First y chunk of this iteration (offset 0*sizeof(block_q8_1_mmq)/sizeof(int)); paired with k00 == 0.
        {
            const int * by0 = y + ncols_y*(kb0*(qk*sizeof(block_q8_1_mmq) / (4*QK8_1*sizeof(int))) + 0*sizeof(block_q8_1_mmq)/sizeof(int));
#pragma unroll
            for (int l0 = 0; l0 < mmq_x*MMQ_TILE_Y_K; l0 += nwarps*warp_size) {
                int l = l0 + threadIdx.y*warp_size + threadIdx.x;

                tile_y[l] = by0[l];
            }
        }

        // Barrier: tile_y (written above) must be complete before any thread reads it in vec_dot.
        __syncthreads();

        vec_dot(tile_x, tile_y, sum, 0);

        // Barrier: all reads of tile_y must finish before it is overwritten with the second chunk.
        __syncthreads();

        // Second y chunk (offset 1*sizeof(block_q8_1_mmq)/sizeof(int)); paired with k00 == MMQ_TILE_NE_K.
        {
            const int * by0 = y + ncols_y*(kb0*(qk*sizeof(block_q8_1_mmq) / (4*QK8_1*sizeof(int))) + 1*sizeof(block_q8_1_mmq)/sizeof(int));
#pragma unroll
            for (int l0 = 0; l0 < mmq_x*MMQ_TILE_Y_K; l0 += nwarps*warp_size) {
                int l = l0 + threadIdx.y*warp_size + threadIdx.x;

                tile_y[l] = by0[l];
            }
        }

        __syncthreads();

        vec_dot(tile_x, tile_y, sum, MMQ_TILE_NE_K);

        // Barrier: the next iteration reloads tile_x/tile_y, so all reads must be done first.
        __syncthreads();
    }

    if (fixup) {
        // Partial result: store the whole tile densely (stride mmq_y, bounds mmq_y x mmq_x) in this block's
        // slot of the fixup buffer.
        write_back(sum, ids_dst, tmp_fixup + blockIdx.x*(mmq_x*mmq_y), mmq_y, mmq_y, mmq_x);
    } else {
        write_back(sum, ids_dst, dst, stride_col_dst, tile_x_max_i, tile_y_max_j);
    }
}


// The mul_mat_q kernel implements "stream-k" work partitioning as described in https://arxiv.org/abs/2301.03598

// NOTE(review): the template parameter list of mul_mat_q is missing here (apparently stripped during extraction).
// The launch bounds below use ggml_cuda_get_physical_warp_size()*mmq_get_nwarps_device() as the maximum block
// size; the second argument (minimum blocks per multiprocessor) is 2 for the listed HIP architectures and for
// CUDA below Volta, 1 for CUDA Volta and newer (that branch continues after this span).
template
#if defined(GGML_USE_HIP)
#if defined(RDNA4) || defined(RDNA3) || defined(RDNA2) || defined(CDNA) || defined(GCN)
    __launch_bounds__(ggml_cuda_get_physical_warp_size()*mmq_get_nwarps_device(), 2)
#endif // defined(RDNA4) || defined(RDNA3) || defined(RDNA2) || defined(CDNA) || defined(GCN)
#else
#if __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
// Continuation of the launch-bounds selection for mul_mat_q (the "#if __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA"
// that opens this branch directly precedes this span).
    __launch_bounds__(ggml_cuda_get_physical_warp_size()*mmq_get_nwarps_device(), 1)
#else
    __launch_bounds__(ggml_cuda_get_physical_warp_size()*mmq_get_nwarps_device(), 2)
#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
#endif // defined(GGML_USE_HIP)
// Matrix multiplication kernel for quantized x data.
//
// Two work distributions exist:
//   * conventional tiling (compiled for HIP without CDNA and for CUDA below Volta): one block per output tile,
//     indexed by blockIdx.x (it), blockIdx.y (jt) and blockIdx.z (sample/channel);
//   * stream-k: the whole (tile, k block) space is flattened into one index "kbc" and split evenly over
//     gridDim.x blocks. Tiles that a block finishes are written to dst; a trailing unfinished tile is written
//     to tmp_fixup for mul_mat_q_stream_k_fixup to add onto dst.
//
// When ids_dst is non-null, expert_bounds[zt] / expert_bounds[zt + 1] delimit the columns handled for index zt
// and the destination column of each tile column is read from ids_dst.
//
// The first mmq_x ints of dynamic shared memory hold the destination column ids; mul_mat_q_process_tile
// places its tiles after them.
//
// NOTE(review): mmq_x, need_check-style values and the argument of ggml_cuda_type_traits::qk are presumably
// template parameters whose declaration was stripped during extraction - confirm against the original file.
static __global__ void mul_mat_q(
        const char * __restrict__ x, const int * __restrict__ y, const int32_t * __restrict__ ids_dst,
        const int32_t * __restrict__ expert_bounds, float * __restrict__ dst, float * __restrict__ tmp_fixup,
        const int ncols_x, const int nrows_x, const int ncols_dst, const int stride_row_x, const int ncols_y, const int stride_col_dst,
        const int channel_ratio, const int nchannels_y, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
        const int sample_ratio, const int nsamples_y, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst,
        const int ncols_max) {

    // Skip unused template specializations for faster compilation:
    if (mmq_x > get_mmq_x_max_device() || mmq_x % mmq_get_granularity_device(mmq_x) != 0) {
        NO_DEVICE_CODE;
        return;
    }

    constexpr int nwarps    = mmq_get_nwarps_device();
    constexpr int warp_size = ggml_cuda_get_physical_warp_size();

    constexpr int qk    = ggml_cuda_type_traits::qk;
    constexpr int mmq_y = get_mmq_y_device();

    const int ntx = (ncols_max + mmq_x - 1) / mmq_x; // Number of tiles x
    const int nty = (nrows_x   + mmq_y - 1) / mmq_y; // Number of tiles y

    // Initialize the ids for writing back data with just the index.
    // For regular matrix multiplications this is never changed.
    // For MoE the correct indices are loaded from ids_dst.
    extern __shared__ int ids_dst_shared[]; // Stored at beginning of shared memory.
#pragma unroll
    for (int j0 = 0; j0 < mmq_x; j0 += nwarps*warp_size) {
        const int j = j0 + threadIdx.y*warp_size + threadIdx.x;

        // Only the last, partially filled step can run past mmq_x.
        if (j0 + nwarps*warp_size > mmq_x && j >= mmq_x) {
            break;
        }

        ids_dst_shared[j] = j;
    }
    __syncthreads();

    // On non-CDNA AMD or old CUDA the performance with stream-k was worse, use conventional tiling instead:
#if (defined(GGML_USE_HIP) && !defined(CDNA)) || __CUDA_ARCH__ < GGML_CUDA_CC_VOLTA
    {
        // blockIdx.z enumerates (sample, channel) pairs with the channel index varying fastest.
        const int wt = blockIdx.z / nchannels_y;
        const int zt = blockIdx.z - wt*nchannels_y;
        const int jt = blockIdx.y;
        const int it = blockIdx.x;

        // Defaults for regular matrix multiplication:
        int col_low    = 0;
        int col_high   = ncols_dst;
        int col_diff   = ncols_dst;
        int offset_y   = wt*stride_sample_y   + zt*stride_channel_y;
        int offset_dst = wt*stride_sample_dst + zt*stride_channel_dst + jt*mmq_x*stride_col_dst;

        if (ids_dst) {
            col_low  = expert_bounds[zt + 0];
            col_high = expert_bounds[zt + 1];
            col_diff = col_high - col_low;

            offset_y   = 0;
            offset_dst = 0;

            // This tile lies entirely past the column range for zt: nothing to compute.
            if (jt*mmq_x >= col_diff) {
                return;
            }

            // __syncthreads(); // There is no previous tile that could cause a race condition.
#pragma unroll
            for (int j0 = 0; j0 < mmq_x; j0 += nwarps*warp_size) {
                const int j = j0 + threadIdx.y*warp_size + threadIdx.x;

                if (j0 + nwarps*warp_size > mmq_x && j >= mmq_x) {
                    break;
                }

                ids_dst_shared[j] = ids_dst[col_low + jt*mmq_x + j];
            }
            __syncthreads();
        }

        offset_y   += (col_low + jt*mmq_x)*(sizeof(block_q8_1_mmq)/sizeof(int));
        offset_dst += it*mmq_y;

        const int tile_x_max_i = nrows_x  - it*mmq_y - 1;
        const int tile_y_max_j = col_diff - jt*mmq_x - 1;

        const int offset_x = (wt/sample_ratio)*stride_sample_x + (zt/channel_ratio)*stride_channel_x + it*mmq_y*stride_row_x;

        // With one block per tile the full k range [0, ncols_x/qk) is processed, so no fixup is needed.
        constexpr bool fixup = false;
        mul_mat_q_process_tile
            (x, offset_x, y + offset_y, ids_dst_shared, dst + offset_dst, tmp_fixup, stride_row_x, ncols_y, stride_col_dst,
             tile_x_max_i, tile_y_max_j, 0, ncols_x/qk);
        return;
    }
#endif // (defined(GGML_USE_HIP) && !defined(CDNA)) || __CUDA_ARCH__ < GGML_CUDA_CC_VOLTA

    const     int64_t blocks_per_ne00 = ncols_x / qk;
    constexpr int     blocks_per_iter = MMQ_ITER_K / qk;

    // kbc == k block continuous, current index in continuous ijk space.
    int64_t kbc      = (int64_t) blockIdx.x     *nsamples_y*nchannels_y*ntx*nty*blocks_per_ne00 / gridDim.x;
    int64_t kbc_stop = (int64_t)(blockIdx.x + 1)*nsamples_y*nchannels_y*ntx*nty*blocks_per_ne00 / gridDim.x;

    // Round both ends down so that the offset within a tile's k range is a multiple of blocks_per_iter.
    kbc      -= (kbc      % blocks_per_ne00) % blocks_per_iter;
    kbc_stop -= (kbc_stop % blocks_per_ne00) % blocks_per_iter;

    // kb0 == k index when doing the matrix multiplication for an output tile.
    int kb0_start = kbc % blocks_per_ne00;
    int kb0_stop  = min(blocks_per_ne00, kb0_start + kbc_stop - kbc);

    // Process every tile whose k range this block runs to the end (kb0_stop == blocks_per_ne00).
    while (kbc < kbc_stop && kb0_stop == blocks_per_ne00) {
        // Decompose kbc into tile indices; the order from slowest to fastest is it, wt, zt, jt, k block.
        // NOTE(review): kbc is int64_t but tmp is int, so the value is narrowed here - presumably the total
        // index space is expected to fit into 32 bits; confirm.
        int tmp = kbc;
        const int it = tmp / (nsamples_y*nchannels_y*ntx*blocks_per_ne00);
        tmp -= it * (nsamples_y*nchannels_y*ntx*blocks_per_ne00);
        const int wt = tmp / (nchannels_y*ntx*blocks_per_ne00);
        tmp -= wt * (nchannels_y*ntx*blocks_per_ne00);
        const int zt = tmp / (ntx*blocks_per_ne00);
        tmp -= zt * (ntx*blocks_per_ne00);
        const int jt = tmp / blocks_per_ne00;

        // Defaults for regular matrix multiplication:
        int col_low    = 0;
        int col_high   = ncols_dst;
        int col_diff   = ncols_dst;
        int offset_y   = wt*stride_sample_y   + zt*stride_channel_y;
        int offset_dst = wt*stride_sample_dst + zt*stride_channel_dst + jt*mmq_x*stride_col_dst;

        if (ids_dst) {
            col_low  = expert_bounds[zt + 0];
            col_high = expert_bounds[zt + 1];
            col_diff = col_high - col_low;

            offset_y   = 0;
            offset_dst = 0;

            // Tile past the column range for zt: skip it and advance to the start of the next tile.
            if (jt*mmq_x >= col_diff) {
                kbc += blocks_per_ne00;
                kbc -= kbc % blocks_per_ne00;

                kb0_start = 0;
                kb0_stop  = min(blocks_per_ne00, kbc_stop - kbc);

                continue;
            }

            // Barrier before overwriting ids_dst_shared: a previous tile may still be reading it.
            __syncthreads();
#pragma unroll
            for (int j0 = 0; j0 < mmq_x; j0 += nwarps*warp_size) {
                const int j = j0 + threadIdx.y*warp_size + threadIdx.x;

                if (j0 + nwarps*warp_size > mmq_x && j >= mmq_x) {
                    break;
                }

                ids_dst_shared[j] = ids_dst[col_low + jt*mmq_x + j];
            }
            __syncthreads();
        }

        offset_y   += (col_low + jt*mmq_x)*(sizeof(block_q8_1_mmq)/sizeof(int));
        offset_dst += it*mmq_y;

        const int tile_x_max_i = nrows_x  - it*mmq_y - 1;
        const int tile_y_max_j = col_diff - jt*mmq_x - 1;

        const int offset_x = (wt/sample_ratio)*stride_sample_x + (zt/channel_ratio)*stride_channel_x + it*mmq_y*stride_row_x;

        constexpr bool fixup = false; // All but (potentially) the last iterations write their data to dst rather than the fixup buffer.
        mul_mat_q_process_tile
            (x, offset_x, y + offset_y, ids_dst_shared, dst + offset_dst, tmp_fixup, stride_row_x, ncols_y, stride_col_dst,
             tile_x_max_i, tile_y_max_j, kb0_start, kb0_stop);

        // Advance to the first k block of the next tile.
        kbc += blocks_per_ne00;
        kbc -= kbc % blocks_per_ne00;

        kb0_start = 0;
        kb0_stop  = min(blocks_per_ne00, kbc_stop - kbc);
    }

    // No trailing partial tile left for this block.
    if (kbc >= kbc_stop) {
        return;
    }

    // Same index decomposition as in the loop above, now for the trailing partial tile.
    int tmp = kbc;
    const int it = tmp / (nsamples_y*nchannels_y*ntx*blocks_per_ne00);
    tmp -= it * (nsamples_y*nchannels_y*ntx*blocks_per_ne00);
    const int wt = tmp / (nchannels_y*ntx*blocks_per_ne00);
    tmp -= wt * (nchannels_y*ntx*blocks_per_ne00);
    const int zt = tmp / (ntx*blocks_per_ne00);
    tmp -= zt * (ntx*blocks_per_ne00);
    const int jt = tmp / blocks_per_ne00;

    // Defaults for regular matrix multiplication:
    int col_low    = 0;
    int col_high   = ncols_dst;
    int col_diff   = ncols_dst;
    int offset_y   = wt*stride_sample_y   + zt*stride_channel_y;
    int offset_dst = wt*stride_sample_dst + zt*stride_channel_dst + jt*mmq_x*stride_col_dst;

    if (ids_dst) {
        col_low  = expert_bounds[zt + 0];
        col_high = expert_bounds[zt + 1];
        col_diff = col_high - col_low;

        offset_y   = 0;
        offset_dst = 0;

        if (jt*mmq_x >= col_diff) {
            return;
        }

        // The memory layout for the fixup buffer is always contiguous, therefore reset ids:
        __syncthreads();
#pragma unroll
        for (int j0 = 0; j0 < mmq_x; j0 += nwarps*warp_size) {
            const int j = j0 + threadIdx.y*warp_size + threadIdx.x;

            if (j0 + nwarps*warp_size > mmq_x && j >= mmq_x) {
                break;
            }

            ids_dst_shared[j] = j;
        }
        __syncthreads();
    }

    offset_y   += (col_low + jt*mmq_x)*(sizeof(block_q8_1_mmq)/sizeof(int));
    offset_dst += it*mmq_y;

    const int tile_x_max_i = nrows_x  - it*mmq_y - 1;
    const int tile_y_max_j = col_diff - jt*mmq_x - 1;

    const int offset_x = (wt/sample_ratio)*stride_sample_x + (zt/channel_ratio)*stride_channel_x + it*mmq_y*stride_row_x;

    constexpr bool fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks.
// Tail of mul_mat_q (its definition starts before this span): process the trailing partial tile with
// fixup == true so that the partial sums go to the fixup buffer instead of dst.
    mul_mat_q_process_tile
        (x, offset_x, y + offset_y, ids_dst_shared, dst + offset_dst, tmp_fixup, stride_row_x, ncols_y, stride_col_dst,
         tile_x_max_i, tile_y_max_j, kb0_start, kb0_stop);
}


// Adds the partial tiles that mul_mat_q stored in the fixup buffer onto dst.
//
// Block b recomputes the same [kbc0, kbc0_stop) range that mul_mat_q assigned to block b. If that range starts
// in the middle of a tile and runs to the end of it, block b wrote the end of that tile to dst, while the
// beginning was stored in the fixup buffer by one or more preceding blocks. Those partial sums are accumulated
// here and added to dst.
//
// Parameters:
//   ids_dst, expert_bounds - when ids_dst is non-null, expert_bounds[zt] / expert_bounds[zt + 1] delimit the
//                            columns for index zt and ids_dst maps them to destination columns.
//   dst                    - output that already holds the results written by mul_mat_q.
//   tmp_last_tile          - fixup buffer; block b's partial tile is at offset b*(mmq_x*mmq_y), stored with
//                            stride mmq_y per tile column.
//   remaining parameters   - same meaning as the parameters of the same name in mul_mat_q.
//
// The index arithmetic (kbc rounding, decomposition into it/wt/zt/jt, destination column lookup) must stay in
// sync with mul_mat_q, otherwise partial sums are added to the wrong tile.
//
// NOTE(review): the template parameter list is missing (apparently stripped during extraction). mmq_x and
// need_check are used as compile-time values and ggml_cuda_type_traits::qk lacks its argument - confirm
// against the original file.
template
static __global__ void mul_mat_q_stream_k_fixup(
        const int32_t * ids_dst, const int32_t * expert_bounds, float * __restrict__ dst, const float * __restrict__ tmp_last_tile,
        const int ncols_x, const int nrows_x, const int ncols_dst, const int stride_col_dst,
        const int nchannels_y, const int stride_channel_dst, const int nsamples_y, const int stride_sample_dst,
        const int ncols_max) {
    constexpr int     mmq_y           = get_mmq_y_device();
    constexpr int     qk              = ggml_cuda_type_traits::qk;
    constexpr int     blocks_per_iter = MMQ_ITER_K / qk;
    const     int64_t blocks_per_ne00 = ncols_x / qk;

    constexpr int nwarps    = mmq_get_nwarps_device();
    constexpr int warp_size = ggml_cuda_get_physical_warp_size();

    // Per-thread accumulators, indexed as (j0/nwarps)*(mmq_y/warp_size) + i0/warp_size.
    float sum[mmq_x*mmq_y / (nwarps*warp_size)] = {0.0f};

    const int ntx = (ncols_max + mmq_x - 1) / mmq_x;
    const int nty = (nrows_x   + mmq_y - 1) / mmq_y;

    const int bidx0 = blockIdx.x;

    // kbc == k block continuous, current index in continuous ijk space.
    int64_t kbc0      = (int64_t) bidx0     *nsamples_y*nchannels_y*ntx*nty*blocks_per_ne00 / gridDim.x;
    int64_t kbc0_stop = (int64_t)(bidx0 + 1)*nsamples_y*nchannels_y*ntx*nty*blocks_per_ne00 / gridDim.x;

    kbc0      -= (kbc0      % blocks_per_ne00) % blocks_per_iter;
    kbc0_stop -= (kbc0_stop % blocks_per_ne00) % blocks_per_iter;

    const bool did_not_have_any_data   = kbc0 == kbc0_stop;
    const bool wrote_beginning_of_tile = kbc0 % blocks_per_ne00 == 0;
    const bool did_not_write_last      = kbc0/blocks_per_ne00 == kbc0_stop/blocks_per_ne00 && kbc0_stop % blocks_per_ne00 != 0;
    if (did_not_have_any_data || wrote_beginning_of_tile || did_not_write_last) {
        return;
    }

    bool any_fixup = false;

    // Iterate over previous blocks and sum up partial sums written to fixup buffer.
    // All CUDA blocks that get here must have a previous block that needs a fixup.
    int64_t bidx = bidx0 - 1;
    int64_t kbc_stop = kbc0;
    while(true) {
        int64_t kbc = bidx*nsamples_y*nchannels_y*ntx*nty*blocks_per_ne00 / gridDim.x;
        kbc -= (kbc % blocks_per_ne00) % blocks_per_iter;

        if (kbc == kbc_stop) { // Did not have any data.
            bidx--;
            kbc_stop = kbc;
            continue;
        }

        any_fixup = true;

#pragma unroll
        for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
            const int j = j0 + threadIdx.y;

#pragma unroll
            for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
                const int i = i0 + threadIdx.x;

                sum[(j0/nwarps) * (mmq_y/warp_size) + i0/warp_size] += tmp_last_tile[bidx*(mmq_x*mmq_y) + j*mmq_y + i];
            }
        }

        // If this block started in a previous tile we are done and don't need to combine additional partial results.
        if (kbc % blocks_per_ne00 == 0 || kbc/blocks_per_ne00 < kbc0/blocks_per_ne00) {
            break;
        }
        bidx--;
        kbc_stop = kbc;
    }

    if (!any_fixup) {
        return;
    }

    // Decompose kbc0 into tile indices exactly as mul_mat_q does.
    int tmp = kbc0;
    const int it = tmp / (nsamples_y*nchannels_y*ntx*blocks_per_ne00);
    tmp -= it * (nsamples_y*nchannels_y*ntx*blocks_per_ne00);
    const int wt = tmp / (nchannels_y*ntx*blocks_per_ne00);
    tmp -= wt * (nchannels_y*ntx*blocks_per_ne00);
    const int zt = tmp / (ntx*blocks_per_ne00);
    tmp -= zt * (ntx*blocks_per_ne00);
    const int jt = tmp / blocks_per_ne00;

    if (!ids_dst) {
        const int offset_dst = wt*stride_sample_dst + zt*stride_channel_dst + jt*mmq_x*stride_col_dst + it*mmq_y;
        dst += offset_dst;

        const int i_max = nrows_x   - it*mmq_y - 1;
        const int j_max = ncols_dst - jt*mmq_x - 1;

#pragma unroll
        for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
            const int j = j0 + threadIdx.y;

            // j only grows with j0, so every later column of this thread is out of range as well.
            if (j > j_max) {
                return;
            }

#pragma unroll
            for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
                const int i = i0 + threadIdx.x;

                if (need_check && i > i_max) {
                    continue;
                }

                dst[j*stride_col_dst + i] += sum[(j0/nwarps) * (mmq_y/warp_size) + i0/warp_size];
            }
        }
        return;
    }

    __shared__ int ids_dst_shared[mmq_x];
    const int col_low  = expert_bounds[zt + 0];
    const int col_high = expert_bounds[zt + 1];
    const int col_diff = col_high - col_low;

    // mul_mat_q skips tiles that start at or past col_diff, so there is nothing to add for them. The condition
    // is identical for all threads of the block, so returning before the barrier below is safe. It also avoids
    // reading ids_dst beyond the column range of zt.
    if (jt*mmq_x >= col_diff) {
        return;
    }

    // Destination columns of THIS tile: the tile starts jt*mmq_x columns after col_low, which is the same
    // lookup mul_mat_q performs. Omitting jt*mmq_x would redirect every tile with jt > 0 to the destination
    // columns of tile 0.
    for (int j = threadIdx.y*warp_size + threadIdx.x; j < mmq_x; j += nwarps*warp_size) {
        ids_dst_shared[j] = ids_dst[col_low + jt*mmq_x + j];
    }
    __syncthreads();

    const int offset_dst = it*mmq_y;
    dst += offset_dst;

    const int i_max = nrows_x  - it*mmq_y - 1;
    const int j_max = col_diff - jt*mmq_x - 1;

#pragma unroll
    for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
        const int j = j0 + threadIdx.y;

        if (j > j_max) {
            return;
        }

#pragma unroll
        for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
            const int i = i0 + threadIdx.x;

            if (need_check && i > i_max) {
                continue;
            }

            dst[ids_dst_shared[j]*stride_col_dst + i] += sum[(j0/nwarps) * (mmq_y/warp_size) + i0/warp_size];
        }
    }
}

// Host-side argument bundle for launching the MMQ kernels.
// launch_mul_mat_q forwards these fields to mul_mat_q / mul_mat_q_stream_k_fixup; note that nrows_dst is
// passed as the kernels' stride_col_dst argument.
struct mmq_args {
    const char * x; ggml_type type_x; const int * y; const int32_t * ids_dst; const int32_t * expert_bounds; float * dst;
    int64_t ncols_x; int64_t nrows_x; int64_t ncols_dst; int64_t stride_row_x; int64_t ncols_y; int64_t nrows_dst;
    int64_t nchannels_x; int64_t nchannels_y; int64_t stride_channel_x; int64_t stride_channel_y; int64_t stride_channel_dst;
    int64_t nsamples_x; int64_t nsamples_y; int64_t stride_sample_x; int64_t stride_sample_y; int64_t stride_sample_dst;
    bool use_stream_k; int64_t ncols_max;
};

// Returns the number of bytes of dynamic shared memory needed by mul_mat_q: the destination ids, the x tile
// and the padded y tile (the expression for the x tile continues after this span).
// NOTE(review): the template parameter list is missing (apparently stripped during extraction); `type` is
// presumably its only parameter - confirm.
template
static size_t mmq_get_nbytes_shared(const int mmq_x, const int mmq_y, const int cc, const int warp_size, const int nwarps) {
    const tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(type, mmq_y);
    const int mmq_tile_x_k = mmq_get_mma_tile_x_k(type);
    const size_t nbs_ids = mmq_x*sizeof(int);
    const size_t nbs_x = (turing_mma_available(cc) || amd_mfma_available(cc)) ?
// Continuation of mmq_get_nbytes_shared (the "const size_t nbs_x = (...) ?" that opens this expression
// directly precedes this span): x tile size for the MMA path, otherwise the sum of the dp4a tile parts.
        mmq_y*mmq_tile_x_k*sizeof(int) :
        txs.qs*sizeof(int) + txs.dm*sizeof(half2) + txs.sc*sizeof(int);
    const size_t nbs_y = mmq_x*sizeof(block_q8_1_mmq);
    // The y tile is padded to a multiple of nwarps*warp_size ints, matching the GGML_PAD used for the
    // tile_x offset in mul_mat_q_process_tile.
    return nbs_ids + nbs_x + GGML_PAD(nbs_y, nwarps*warp_size*sizeof(int));
}

// Launches mul_mat_q (and, for stream-k, mul_mat_q_stream_k_fixup) for one tile width.
//
// Without args.use_stream_k the kernel runs on a (nty, ntx, nchannels_y*nsamples_y) tiling grid with no fixup
// buffer. With stream-k the grid has one block per multiprocessor (nsm); a fixup buffer of
// nsm * mmq_x*mmq_y elements is allocated from the context pool and the fixup kernel is launched afterwards,
// but only when ntx*nty*ntzw is not a multiple of nsm.
// need_check is false when args.nrows_x is a multiple of mmq_y and true otherwise.
//
// NOTE(review): the template parameter list, the template arguments of mul_mat_q /
// mul_mat_q_stream_k_fixup / mmq_get_nbytes_shared / ggml_cuda_pool_alloc and the contents of the "<<>>"
// launch configurations are missing (apparently stripped during extraction). mmq_x and need_check are
// presumably among the template arguments, and the two CUDA_SET_SHARED_MEMORY_LIMIT calls presumably refer to
// the need_check == false and need_check == true instantiations - confirm against the original file.
// NOTE(review): no launch error check is visible after the kernel launches; presumably handled elsewhere.
template
static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & args, cudaStream_t stream) {
    const int id = ggml_cuda_get_device();
    const int cc = ggml_cuda_info().devices[id].cc;
    const int nsm = ggml_cuda_info().devices[id].nsm;
    const int warp_size = ggml_cuda_info().devices[id].warp_size;
    const int nwarps = mmq_get_nwarps_host(cc, warp_size);
    const int mmq_y = get_mmq_y_host(cc);

    const dim3 block_dims(warp_size, nwarps, 1);

    const int nbytes_shared = mmq_get_nbytes_shared(mmq_x, mmq_y, cc, warp_size, nwarps);

    CUDA_SET_SHARED_MEMORY_LIMIT((mul_mat_q), nbytes_shared);
    CUDA_SET_SHARED_MEMORY_LIMIT((mul_mat_q), nbytes_shared);

    const int nty  = (args.nrows_x   + mmq_y - 1) / mmq_y;
    const int ntx  = (args.ncols_max + mmq_x - 1) / mmq_x;
    const int ntzw = args.nchannels_y * args.nsamples_y;
    const dim3 block_nums_xy_tiling(nty, ntx, ntzw);

    // The x channel/sample counts must divide the y counts so that integer ratios can be passed to the kernel.
    GGML_ASSERT(args.nchannels_y % args.nchannels_x == 0);
    GGML_ASSERT(args.nsamples_y  % args.nsamples_x  == 0);
    const int channel_ratio = args.nchannels_y / args.nchannels_x;
    const int sample_ratio  = args.nsamples_y  / args.nsamples_x;

    if (!args.use_stream_k) {
        if (args.nrows_x % mmq_y == 0) {
            constexpr bool need_check = false;
            mul_mat_q<<>>
                (args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, nullptr,
                 args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst,
                 channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst,
                 sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst,
                 args.ncols_max);
        } else {
            constexpr bool need_check = true;
            mul_mat_q<<>>
                (args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, nullptr,
                 args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst,
                 channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst,
                 sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst,
                 args.ncols_max);
        }
        return;
    }

    const dim3 block_nums_stream_k(nsm, 1, 1);
    // If the number of tiles is a multiple of the block count, no fixup launch is issued.
    const bool fixup_needed = ntx*nty*ntzw % nsm != 0;

    ggml_cuda_pool & pool = ctx.pool(id);
    ggml_cuda_pool_alloc tmp_fixup(pool);
    if (fixup_needed) {
        // One mmq_x*mmq_y slot per stream-k block.
        tmp_fixup.alloc(block_nums_stream_k.x * mmq_x*mmq_y);
    }

    if (args.nrows_x % mmq_y == 0) {
        constexpr bool need_check = false;

        mul_mat_q<<>>
            (args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr,
             args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst,
             channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst,
             sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst,
             args.ncols_max);

        if (!fixup_needed) {
            return;
        }

        mul_mat_q_stream_k_fixup<<>>
            (args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr, args.ncols_x, args.nrows_x, args.ncols_dst,
             args.nrows_dst, args.nchannels_y, args.stride_channel_dst, args.nsamples_y, args.stride_sample_dst,
             args.ncols_max);
    } else {
        constexpr bool need_check = true;

        mul_mat_q<<>>
            (args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr,
             args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst,
             channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst,
             sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst,
             args.ncols_max);

        if (!fixup_needed) {
            return;
        }

        mul_mat_q_stream_k_fixup<<>>
            (args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr, args.ncols_x, args.nrows_x, args.ncols_dst,
             args.nrows_dst, args.nchannels_y, args.stride_channel_dst, args.nsamples_y, args.stride_sample_dst,
             args.ncols_max);
    }
}

// Header of mul_mat_q_case; its "void mul_mat_q_case(" follows directly after this span.
template
// Chooses the tile width mmq_x for the current device and dispatches to launch_mul_mat_q.
// (The "template" keyword of this definition directly precedes this span.)
//
// Candidates are the multiples of 8 up to get_mmq_x_max_host(cc). A candidate is skipped when it is not a
// multiple of its granularity or when its shared memory requirement exceeds smpbo. Among the rest the one
// giving the fewest tiles over args.ncols_max wins; on ties the smaller width is kept because the comparison
// is strict, and the search stops once a single tile suffices. If no candidate qualifies, mmq_x_best stays 0
// and the default branch aborts.
//
// NOTE(review): the template parameter list and the template arguments of mmq_get_nbytes_shared and of every
// launch_mul_mat_q call are missing (apparently stripped during extraction). Each case presumably
// instantiates launch_mul_mat_q with its case value as mmq_x - confirm against the original file.
void mul_mat_q_case(ggml_backend_cuda_context & ctx, const mmq_args & args, cudaStream_t stream) {
    const int    id        = ggml_cuda_get_device();
    const int    cc        = ggml_cuda_info().devices[id].cc;
    const size_t smpbo     = ggml_cuda_info().devices[id].smpbo;
    const int    warp_size = ggml_cuda_info().devices[id].warp_size;
    const int    nwarps    = mmq_get_nwarps_host(cc, warp_size);

    const int mmq_x_max = get_mmq_x_max_host(cc);
    const int mmq_y = get_mmq_y_host(cc);

    int mmq_x_best  = 0;
    int ntiles_x_best = INT_MAX;

    for (int mmq_x = 8; mmq_x <= mmq_x_max && ntiles_x_best > 1; mmq_x += 8) {
        const int granularity = mmq_get_granularity_host(mmq_x, cc);

        if (mmq_x % granularity != 0 || mmq_get_nbytes_shared(mmq_x, mmq_y, cc, warp_size, nwarps) > smpbo) {
            continue;
        }

        const int ntiles_x = (args.ncols_max + mmq_x - 1) / mmq_x;

        if (ntiles_x < ntiles_x_best) {
            mmq_x_best    = mmq_x;
            ntiles_x_best = ntiles_x;
        }
    }

    // Map the runtime choice onto a compile-time instantiation.
    switch (mmq_x_best) {
        case   8:
            launch_mul_mat_q(ctx, args, stream);
            break;
        case  16:
            launch_mul_mat_q(ctx, args, stream);
            break;
        case  24:
            launch_mul_mat_q(ctx, args, stream);
            break;
        case  32:
            launch_mul_mat_q(ctx, args, stream);
            break;
        case  40:
            launch_mul_mat_q(ctx, args, stream);
            break;
        case  48:
            launch_mul_mat_q(ctx, args, stream);
            break;
        case  56:
            launch_mul_mat_q(ctx, args, stream);
            break;
        case  64:
            launch_mul_mat_q(ctx, args, stream);
            break;
        case  72:
            launch_mul_mat_q(ctx, args, stream);
            break;
        case  80:
            launch_mul_mat_q(ctx, args, stream);
            break;
        case  88:
            launch_mul_mat_q(ctx, args, stream);
            break;
        case  96:
            launch_mul_mat_q(ctx, args, stream);
            break;
        case 104:
            launch_mul_mat_q(ctx, args, stream);
            break;
        case 112:
            launch_mul_mat_q(ctx, args, stream);
            break;
        case 120:
            launch_mul_mat_q(ctx, args, stream);
            break;
        case 128:
            launch_mul_mat_q(ctx, args, stream);
            break;
        default:
            // Reached when no candidate width passed the checks above (mmq_x_best == 0) or when the chosen
            // width has no case label.
            fprintf(stderr, "mmq_x_best=%d\n", mmq_x_best);
            GGML_ABORT("fatal error");
            break;
    }
}

// Declares an explicit instantiation of mul_mat_q_case; combined with "extern" below, one declaration is
// emitted per supported quantization type.
// NOTE(review): the template argument after mul_mat_q_case is missing (presumably the macro parameter `type`).
// The macro body ends with a line continuation followed by an empty line, which must be kept.
#define DECL_MMQ_CASE(type) \
    template void mul_mat_q_case(ggml_backend_cuda_context & ctx, const mmq_args & args, cudaStream_t stream) \

extern DECL_MMQ_CASE(GGML_TYPE_Q4_0);
extern DECL_MMQ_CASE(GGML_TYPE_Q4_1);
extern DECL_MMQ_CASE(GGML_TYPE_Q5_0);
extern DECL_MMQ_CASE(GGML_TYPE_Q5_1);
extern DECL_MMQ_CASE(GGML_TYPE_Q8_0);
extern DECL_MMQ_CASE(GGML_TYPE_MXFP4);
extern DECL_MMQ_CASE(GGML_TYPE_Q2_K);
extern DECL_MMQ_CASE(GGML_TYPE_Q3_K);
extern DECL_MMQ_CASE(GGML_TYPE_Q4_K);
extern DECL_MMQ_CASE(GGML_TYPE_Q5_K);
extern DECL_MMQ_CASE(GGML_TYPE_Q6_K);
extern DECL_MMQ_CASE(GGML_TYPE_IQ2_XXS);
extern DECL_MMQ_CASE(GGML_TYPE_IQ2_XS);
extern DECL_MMQ_CASE(GGML_TYPE_IQ2_S);
extern DECL_MMQ_CASE(GGML_TYPE_IQ3_XXS);
extern DECL_MMQ_CASE(GGML_TYPE_IQ3_S);
extern DECL_MMQ_CASE(GGML_TYPE_IQ1_S);
extern DECL_MMQ_CASE(GGML_TYPE_IQ4_NL);
extern DECL_MMQ_CASE(GGML_TYPE_IQ4_XS);

// -------------------------------------------------------------------------------------------------------------------------

// Host entry points; only declared here, their definitions are not part of this header section.

void ggml_cuda_mul_mat_q(
        ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst);

void ggml_cuda_op_mul_mat_q(
    ggml_backend_cuda_context & ctx,
    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
    const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
    const int64_t src1_padded_row_size, cudaStream_t stream);

bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11);