Mirror of https://github.com/ggml-org/llama.cpp.git (synced 2025-10-30 08:42:00 +00:00)
	CUDA: GEMM for FP32/FP16/BF16 and ne11 <= 16 (#15131)
* CUDA: GEMM for FP32/FP16/BF16 and ne11 <= 16
		| @@ -233,9 +233,13 @@ typedef float2 dfloat2; | ||||
| #endif // defined(GGML_USE_HIP) && defined(CDNA) && !defined(GGML_HIP_NO_MMQ_MFMA) | ||||
|  | ||||
| #if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING | ||||
| #define NEW_MMA_AVAILABLE | ||||
| #define TURING_MMA_AVAILABLE | ||||
| #endif // !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING | ||||
|  | ||||
| #if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE | ||||
| #define AMPERE_MMA_AVAILABLE | ||||
| #endif // !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE | ||||
|  | ||||
| #if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE | ||||
| #define CP_ASYNC_AVAILABLE | ||||
| #endif // !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE | ||||
| @@ -303,10 +307,14 @@ static bool amd_mfma_available(const int cc) { | ||||
| } | ||||
|  | ||||
| // Volta technically had FP16 tensor cores but they work very differently compared to Turing and later. | ||||
| static bool new_mma_available(const int cc) { | ||||
| static bool turing_mma_available(const int cc) { | ||||
|     return GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_TURING; | ||||
| } | ||||
|  | ||||
| static bool ampere_mma_available(const int cc) { | ||||
|     return cc < GGML_CUDA_CC_OFFSET_AMD && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_AMPERE; | ||||
| } | ||||
|  | ||||
| static bool cp_async_available(const int cc) { | ||||
|     return cc < GGML_CUDA_CC_OFFSET_AMD && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_AMPERE; | ||||
| } | ||||
|   | ||||
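The common.cuh hunks above rename NEW_MMA_AVAILABLE to TURING_MMA_AVAILABLE and introduce an AMPERE_MMA_AVAILABLE compile-time gate together with a host-side ampere_mma_available(cc) check for the new tf32/bf16 mma paths. The sketch below is a minimal, self-contained illustration of how such a compile-time/run-time pair is typically used together; the DEMO_* names and the kernel are invented for the example and are not part of the commit.

// Illustrative only: a compile-time gate for device code paired with a run-time
// check for host-side dispatch, mirroring the pattern of TURING_MMA_AVAILABLE /
// AMPERE_MMA_AVAILABLE and turing_mma_available() / ampere_mma_available().
#include <cuda_runtime.h>
#include <cstdio>

#define DEMO_CC_TURING 750   // same scale as __CUDA_ARCH__ (sm_75)
#define DEMO_CC_AMPERE 800   // sm_80

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= DEMO_CC_AMPERE
#define DEMO_AMPERE_MMA_AVAILABLE
#endif

// Never launched here; it only shows which branch gets compiled for a given arch.
__global__ void demo_which_mma(int * out) {
#if defined(DEMO_AMPERE_MMA_AVAILABLE)
    *out = 2;  // tf32/bf16 mma.sync code would be compiled here
#elif defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= DEMO_CC_TURING
    *out = 1;  // fp16/int8 mma.sync code (Turing)
#else
    *out = 0;  // dp4a / scalar fallback
#endif
}

// Host-side analogue of ampere_mma_available(): decide at run time whether the
// Ampere path can be dispatched on the current device.
static bool demo_ampere_mma_available(const int cc) {
    return cc >= DEMO_CC_AMPERE;
}

int main() {
    int dev = 0, major = 0, minor = 0;
    cudaGetDevice(&dev);
    cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, dev);
    cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, dev);
    const int cc = 100*major + 10*minor;  // e.g. sm_86 -> 860
    printf("cc = %d, ampere mma usable: %d\n", cc, (int) demo_ampere_mma_available(cc));
    return 0;
}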
| @@ -418,7 +418,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( | ||||
|         float        * const __restrict__ KQ_max, | ||||
|         float        * const __restrict__ KQ_rowsum, | ||||
|         const int kb0) { | ||||
| #ifdef NEW_MMA_AVAILABLE | ||||
| #ifdef TURING_MMA_AVAILABLE | ||||
|     typedef fattn_mma_f16_config<DKQ, DV> c; | ||||
|  | ||||
| #ifdef CP_ASYNC_AVAILABLE | ||||
| @@ -776,7 +776,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( | ||||
|     GGML_UNUSED(VKQ_C); GGML_UNUSED(KQ_max); GGML_UNUSED(KQ_rowsum); | ||||
|     GGML_UNUSED(kb0); GGML_UNUSED(tile_Q); | ||||
|     NO_DEVICE_CODE; | ||||
| #endif // NEW_MMA_AVAILABLE | ||||
| #endif // TURING_MMA_AVAILABLE | ||||
| } | ||||
|  | ||||
| template<int DKQ, int DV, int ncols1, int ncols2, int nwarps, int ntiles, bool use_logit_softcap, bool mla, bool needs_fixup, bool is_fixup> | ||||
| @@ -800,7 +800,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( | ||||
|         const int jt, | ||||
|         const int kb0_start, | ||||
|         const int kb0_stop) { | ||||
| #ifdef NEW_MMA_AVAILABLE | ||||
| #ifdef TURING_MMA_AVAILABLE | ||||
|     //In this kernel Q, K, V are matrices while i, j, k are matrix indices. | ||||
|  | ||||
|     typedef fattn_mma_f16_config<DKQ, DV> c; | ||||
| @@ -1196,7 +1196,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( | ||||
|     GGML_UNUSED(stride_Q2); GGML_UNUSED(stride_K); GGML_UNUSED(stride_V); GGML_UNUSED(stride_mask); | ||||
|     GGML_UNUSED(jt); GGML_UNUSED(kb0_start); GGML_UNUSED(kb0_stop); | ||||
|     NO_DEVICE_CODE; | ||||
| #endif // NEW_MMA_AVAILABLE | ||||
| #endif // TURING_MMA_AVAILABLE | ||||
| } | ||||
|  | ||||
| template<int DKQ, int DV, int ncols1, int ncols2, int nwarps, int ntiles, bool use_logit_softcap, bool mla> | ||||
| @@ -1223,7 +1223,7 @@ static __global__ void flash_attn_ext_f16( | ||||
|                             const int32_t nb21, const int32_t nb22, const int64_t nb23, | ||||
|                             const int32_t ne31, const int32_t ne32, const int32_t ne33, | ||||
|                             const int32_t nb31, const int32_t nb32, const int64_t nb33) { | ||||
| #if defined(FLASH_ATTN_AVAILABLE) && defined(NEW_MMA_AVAILABLE) | ||||
| #if defined(FLASH_ATTN_AVAILABLE) && defined(TURING_MMA_AVAILABLE) | ||||
|  | ||||
|     // Skip unused kernel variants for faster compilation: | ||||
|     if (use_logit_softcap && !(DKQ == 128 || DKQ == 256)) { | ||||
| @@ -1354,7 +1354,7 @@ static __global__ void flash_attn_ext_f16( | ||||
|     GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(ne33); | ||||
|     GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb33); | ||||
|     NO_DEVICE_CODE; | ||||
| #endif // defined(FLASH_ATTN_AVAILABLE) && defined(NEW_MMA_AVAILABLE) | ||||
| #endif // defined(FLASH_ATTN_AVAILABLE) && defined(TURING_MMA_AVAILABLE) | ||||
| } | ||||
|  | ||||
| template <int DKQ, int DV, int ncols1, int ncols2> | ||||
|   | ||||
| @@ -327,7 +327,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst | ||||
|     const bool gqa_opt_applies = ((Q->ne[2] / K->ne[2]) % 2 == 0) && mask; // The mma-based kernels have GQA-specific optimizations | ||||
|     const bool mma_needs_data_conversion = K->type != GGML_TYPE_F16 || V->type != GGML_TYPE_F16; | ||||
|     const bool mma_faster_for_rtx4000 = Q->ne[3] > 1 || (Q->ne[2] > 4*K->ne[2] && K->ne[1] >= 8192); | ||||
|     const bool mma_faster_for_bs1 = new_mma_available(cc) && gqa_opt_applies && !mma_needs_data_conversion && | ||||
|     const bool mma_faster_for_bs1 = turing_mma_available(cc) && gqa_opt_applies && !mma_needs_data_conversion && | ||||
|         (cc < GGML_CUDA_CC_ADA_LOVELACE || mma_faster_for_rtx4000); | ||||
|     const bool can_use_vector_kernel = Q->ne[0] <= 256 && Q->ne[0] % (2*warp_size) == 0; | ||||
|     if (Q->ne[1] == 1 && can_use_vector_kernel && !mma_faster_for_bs1) { | ||||
| @@ -340,7 +340,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst | ||||
|     } | ||||
|  | ||||
|     // The MMA implementation needs Turing or newer, use the old WMMA code for Volta: | ||||
|     if (fp16_mma_available(cc) && !new_mma_available(cc)) { | ||||
|     if (fp16_mma_available(cc) && !turing_mma_available(cc)) { | ||||
|         ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst); | ||||
|         return; | ||||
|     } | ||||
|   | ||||
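The fattn.cu hunk above only renames the availability check, but it sits inside the kernel-selection logic of ggml_cuda_flash_attn_ext(): for batch size 1 the vector kernel is preferred unless the MMA kernel is expected to be faster, and hardware with FP16 tensor cores but no Turing-style MMA (Volta) falls back to the old WMMA kernel. A condensed, illustrative restatement with all shape details reduced to booleans (names are made up for the example):

// Condensed FlashAttention kernel choice; mirrors the structure of
// ggml_cuda_flash_attn_ext() above, not the actual implementation.
#include <cstdio>

static const char * demo_pick_fattn_kernel(bool batch_size_1, bool can_use_vector_kernel,
                                            bool mma_faster_for_bs1, bool fp16_mma, bool turing_mma) {
    if (batch_size_1 && can_use_vector_kernel && !mma_faster_for_bs1) {
        return "vector kernel";
    }
    if (fp16_mma && !turing_mma) {
        return "wmma kernel (Volta)";  // the MMA kernels need Turing or newer
    }
    return "mma kernel";
}

int main() {
    printf("%s\n", demo_pick_fattn_kernel(true,  true,  false, true, true));  // batch size 1 -> vector kernel
    printf("%s\n", demo_pick_fattn_kernel(false, false, false, true, true));  // larger batch on Turing+ -> mma kernel
    return 0;
}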
| @@ -22,8 +22,9 @@ | ||||
| #include "ggml-cuda/fattn.cuh" | ||||
| #include "ggml-cuda/getrows.cuh" | ||||
| #include "ggml-cuda/im2col.cuh" | ||||
| #include "ggml-cuda/mmf.cuh" | ||||
| #include "ggml-cuda/mmq.cuh" | ||||
| #include "ggml-cuda/mmv.cuh" | ||||
| #include "ggml-cuda/mmvf.cuh" | ||||
| #include "ggml-cuda/mmvq.cuh" | ||||
| #include "ggml-cuda/norm.cuh" | ||||
| #include "ggml-cuda/opt-step-adamw.cuh" | ||||
| @@ -2008,7 +2009,9 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor | ||||
|     const bool bad_padding_clear = ggml_backend_buffer_get_usage(src0->buffer) == GGML_BACKEND_BUFFER_USAGE_COMPUTE | ||||
|         && ggml_nbytes(src0) != ggml_backend_buffer_get_alloc_size(src0->buffer, src0) && src0->view_src; | ||||
|  | ||||
|     bool use_mul_mat_vec   = (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16) | ||||
|     bool use_mul_mat_vec_f = (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16) | ||||
|         && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32; | ||||
|     bool use_mul_mat_f     = !ggml_is_quantized(src0->type) | ||||
|         && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32; | ||||
|     bool use_mul_mat_vec_q = ggml_is_quantized(src0->type) && !bad_padding_clear | ||||
|         && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 | ||||
| @@ -2028,14 +2031,18 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor | ||||
|             } | ||||
|  | ||||
|             const int cc            = ggml_cuda_info().devices[id].cc; | ||||
|             const int warp_size     = ggml_cuda_info().devices[id].warp_size; | ||||
|             use_mul_mat_q           = use_mul_mat_q             && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]); | ||||
|             use_mul_mat_vec         = use_mul_mat_vec           && ggml_cuda_should_use_mmv(src0->type, cc, src0->ne, src1->ne[1]); | ||||
|             use_mul_mat_f           = use_mul_mat_f             && ggml_cuda_should_use_mmf(src0->type, cc, warp_size, src0->ne, src1->ne[1]); | ||||
|             use_mul_mat_vec_f       = use_mul_mat_vec_f         && ggml_cuda_should_use_mmvf(src0->type, cc, src0->ne, src1->ne[1]); | ||||
|             any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16   || !fast_fp16_hardware_available(cc); | ||||
|         } | ||||
|     } else { | ||||
|         const int cc            = ggml_cuda_info().devices[ctx.device].cc; | ||||
|         const int warp_size     = ggml_cuda_info().devices[ctx.device].warp_size; | ||||
|         use_mul_mat_q           = use_mul_mat_q             && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]); | ||||
|         use_mul_mat_vec         = use_mul_mat_vec           && ggml_cuda_should_use_mmv(src0->type, cc, src0->ne, src1->ne[1]); | ||||
|         use_mul_mat_f           = use_mul_mat_f             && ggml_cuda_should_use_mmf(src0->type, cc, warp_size, src0->ne, src1->ne[1]); | ||||
|         use_mul_mat_vec_f       = use_mul_mat_vec_f         && ggml_cuda_should_use_mmvf(src0->type, cc, src0->ne, src1->ne[1]); | ||||
|         any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16   || !fast_fp16_hardware_available(cc); | ||||
|     } | ||||
|  | ||||
| @@ -2053,10 +2060,12 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor | ||||
|     bool use_batched_cublas_bf16 = src0->type == GGML_TYPE_BF16 && bf16_mma_hardware_available(cc); | ||||
|     bool use_batched_cublas_f32  = src0->type == GGML_TYPE_F32; | ||||
|  | ||||
|     if (!split && use_mul_mat_vec) { | ||||
|     if (!split && use_mul_mat_vec_f) { | ||||
|         // the custom F16 vector kernel can be used over batched cuBLAS GEMM | ||||
|         // but this is only faster for GPUs without tensor cores or with a thin src0 matrix (particularly KQV in attention) | ||||
|         ggml_cuda_mul_mat_vec(ctx, src0, src1, nullptr, dst); | ||||
|         ggml_cuda_mul_mat_vec_f(ctx, src0, src1, nullptr, dst); | ||||
|     } else if (!split && use_mul_mat_f) { | ||||
|         ggml_cuda_mul_mat_f(ctx, src0, src1, nullptr, dst); | ||||
|     } else if (!split && use_mul_mat_vec_q) { | ||||
|         ggml_cuda_mul_mat_vec_q(ctx, src0, src1, nullptr, dst); | ||||
|     } else if (!split && use_mul_mat_q) { | ||||
| @@ -2065,8 +2074,8 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor | ||||
|         && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) { | ||||
|         // general KQ + KQV multi-batch without FlashAttention | ||||
|         ggml_cuda_mul_mat_batched_cublas(ctx, src0, src1, dst); | ||||
|     } else if (use_mul_mat_vec) { | ||||
|         ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_vec, nullptr); | ||||
|     } else if (use_mul_mat_vec_f) { | ||||
|         ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_vec_f, nullptr); | ||||
|     } else if (use_mul_mat_vec_q) { | ||||
|         ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_vec_q, quantize_row_q8_1_cuda); | ||||
|     } else if (use_mul_mat_q) { | ||||
| @@ -2094,7 +2103,7 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor * | ||||
|             if (ggml_is_quantized(src0->type)) { | ||||
|                 ggml_cuda_mul_mat_vec_q(ctx, src0, src1, ids, dst); | ||||
|             } else { | ||||
|                 ggml_cuda_mul_mat_vec(ctx, src0, src1, ids, dst); | ||||
|                 ggml_cuda_mul_mat_vec_f(ctx, src0, src1, ids, dst); | ||||
|             } | ||||
|             return; | ||||
|         } | ||||
| @@ -3516,7 +3525,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g | ||||
| #endif // FLASH_ATTN_AVAILABLE | ||||
|             if (op->src[1]->ne[0] != op->src[2]->ne[0]) { | ||||
|                 const int cc = ggml_cuda_info().devices[dev_ctx->device].cc; | ||||
|                 if (!new_mma_available(cc)) { | ||||
|                 if (!turing_mma_available(cc)) { | ||||
|                     return false; | ||||
|                 } | ||||
|                 const int gqa_ratio = op->src[0]->ne[2] / op->src[1]->ne[2]; | ||||
|   | ||||
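In the ggml-cuda.cu hunks above, the new mul_mat_f path is wired into the existing dispatch: for non-split buffers the float vector kernel is tried first, then the new tensor-core GEMM, then the quantized vector and tile kernels, and finally cuBLAS. The sketch below is a simplified restatement of that precedence; the function name and labels are invented, and the real code additionally handles split buffers and the batched-cuBLAS special case.

// Simplified restatement of the kernel-selection order in ggml_cuda_mul_mat()
// (non-split case only; illustrative, not the actual implementation).
#include <cstdio>

static const char * demo_pick_mul_mat_kernel(bool use_mmvf, bool use_mmf, bool use_mmvq, bool use_mmq) {
    if (use_mmvf) return "mul_mat_vec_f";  // F32/F16/BF16 vector kernel for thin src1
    if (use_mmf)  return "mul_mat_f";      // new tensor-core GEMM, ne11 <= 16
    if (use_mmvq) return "mul_mat_vec_q";  // quantized vector kernel
    if (use_mmq)  return "mul_mat_q";      // quantized tile kernel
    return "cublas";                       // fall back to (batched) cuBLAS
}

int main() {
    // e.g. an F16 weight matrix with 8 src1 columns on an Ampere GPU:
    printf("%s\n", demo_pick_mul_mat_kernel(/*mmvf=*/false, /*mmf=*/true, /*mmvq=*/false, /*mmq=*/false));
    return 0;
}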
| @@ -23,13 +23,13 @@ | ||||
| static __device__ __forceinline__ int ggml_cuda_movmatrix(const int x) { | ||||
|     int ret = 0; | ||||
|  | ||||
| #ifdef NEW_MMA_AVAILABLE | ||||
| #ifdef TURING_MMA_AVAILABLE | ||||
|     asm("movmatrix.sync.aligned.m8n8.trans.b16 %0, %1;" | ||||
|         : "=r"(ret) : "r"(x)); | ||||
| #else | ||||
|     GGML_UNUSED(x); | ||||
|     NO_DEVICE_CODE; | ||||
| #endif // defined(NEW_MMA_AVAILABLE) | ||||
| #endif // defined(TURING_MMA_AVAILABLE) | ||||
|     return ret; | ||||
| } | ||||
|  | ||||
| @@ -167,6 +167,38 @@ namespace ggml_cuda_mma { | ||||
|         } | ||||
|     }; | ||||
|  | ||||
|     template <int I_, int J_> | ||||
|     struct tile<I_, J_, nv_bfloat162> { | ||||
|         static constexpr int I  = I_; | ||||
|         static constexpr int J  = J_; | ||||
|         static constexpr int ne = I * J / WARP_SIZE; | ||||
|         nv_bfloat162 x[ne] = {{0.0f, 0.0f}}; | ||||
|  | ||||
|         static __device__ __forceinline__ int get_i(const int l) { | ||||
|             if constexpr (I == 8 && J == 8) { | ||||
|                 return threadIdx.x / 4; | ||||
|             } else if constexpr (I == 16 && J == 4) { | ||||
|                 return l * 8 + threadIdx.x / 4; | ||||
|             } else if constexpr (I == 16 && J == 8) { | ||||
|                 return (l % 2) * 8 + threadIdx.x / 4; | ||||
|             } else { | ||||
|                 static_assert(I == -1 && J == -1, "template specialization not implemented"); | ||||
|             } | ||||
|         } | ||||
|  | ||||
|         static __device__ __forceinline__ int get_j(const int l) { | ||||
|             if constexpr (I == 8 && J == 8) { | ||||
|                 return l * 4 + threadIdx.x % 4; | ||||
|             } else if constexpr (I == 16 && J == 4) { | ||||
|                 return threadIdx.x % 4; | ||||
|             } else if constexpr (I == 16 && J == 8) { | ||||
|                 return (l / 2) * 4 + threadIdx.x % 4; | ||||
|             } else { | ||||
|                 static_assert(I == -1 && J == -1, "template specialization not implemented"); | ||||
|             } | ||||
|         } | ||||
|     }; | ||||
|  | ||||
|     template <int I, int J> | ||||
|     static __device__ __forceinline__ tile<I, J/2, half2> get_half2(const tile<I, J, float> & tile_float) { | ||||
|         tile<I, J/2, half2> ret; | ||||
| @@ -209,7 +241,7 @@ namespace ggml_cuda_mma { | ||||
|     template <typename T> | ||||
|     static __device__ __forceinline__ void load_ldmatrix( | ||||
|             tile<8, 8, T> & t, const T * __restrict__ xs0, const int stride) { | ||||
| #ifdef NEW_MMA_AVAILABLE | ||||
| #ifdef TURING_MMA_AVAILABLE | ||||
|         int * xi = (int *) t.x; | ||||
|         const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride + ((threadIdx.x / t.I) * (t.J / 2)) % t.J; | ||||
|         asm volatile("ldmatrix.sync.aligned.m8n8.x2.b16 {%0, %1}, [%2];" | ||||
| @@ -217,13 +249,13 @@ namespace ggml_cuda_mma { | ||||
|             : "l"(xs)); | ||||
| #else | ||||
|         load_generic(t, xs0, stride); | ||||
| #endif // NEW_MMA_AVAILABLE | ||||
| #endif // TURING_MMA_AVAILABLE | ||||
|     } | ||||
|  | ||||
|     template <typename T> | ||||
|     static __device__ __forceinline__ void load_ldmatrix( | ||||
|             tile<16, 4, T> & t, const T * __restrict__ xs0, const int stride) { | ||||
| #ifdef NEW_MMA_AVAILABLE | ||||
| #ifdef TURING_MMA_AVAILABLE | ||||
|         int * xi = (int *) t.x; | ||||
|         const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride; | ||||
|         asm volatile("ldmatrix.sync.aligned.m8n8.x2.b16 {%0, %1}, [%2];" | ||||
| @@ -232,13 +264,13 @@ namespace ggml_cuda_mma { | ||||
| #else | ||||
|         load_generic(xs0, stride); | ||||
|         GGML_UNUSED(t); | ||||
| #endif // NEW_MMA_AVAILABLE | ||||
| #endif // TURING_MMA_AVAILABLE | ||||
|     } | ||||
|  | ||||
|     template <typename T> | ||||
|     static __device__ __forceinline__ void load_ldmatrix( | ||||
|             tile<16, 8, T> & t, const T * __restrict__ xs0, const int stride) { | ||||
| #if defined(NEW_MMA_AVAILABLE) | ||||
| #if defined(TURING_MMA_AVAILABLE) | ||||
|         int * xi = (int * ) t.x; | ||||
|         const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride + (threadIdx.x / t.I) * (t.J / 2); | ||||
|         asm volatile("ldmatrix.sync.aligned.m8n8.x4.b16 {%0, %1, %2, %3}, [%4];" | ||||
| @@ -246,13 +278,13 @@ namespace ggml_cuda_mma { | ||||
|             : "l"(xs)); | ||||
| #else | ||||
|         load_generic(t, xs0, stride); | ||||
| #endif // NEW_MMA_AVAILABLE | ||||
| #endif // TURING_MMA_AVAILABLE | ||||
|     } | ||||
|  | ||||
|     template <typename T> | ||||
|     static __device__ __forceinline__ void load_ldmatrix_trans( | ||||
|             tile<16, 8, T> & t, const T * __restrict__ xs0, const int stride) { | ||||
| #ifdef NEW_MMA_AVAILABLE | ||||
| #ifdef TURING_MMA_AVAILABLE | ||||
|         int * xi = (int * ) t.x; | ||||
|         const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride + (threadIdx.x / t.I) * (t.J / 2); | ||||
|         asm volatile("ldmatrix.sync.aligned.m8n8.x4.trans.b16 {%0, %1, %2, %3}, [%4];" | ||||
| @@ -263,12 +295,12 @@ namespace ggml_cuda_mma { | ||||
|         GGML_UNUSED(xs0); | ||||
|         GGML_UNUSED(stride); | ||||
|         NO_DEVICE_CODE; | ||||
| #endif // NEW_MMA_AVAILABLE | ||||
| #endif // TURING_MMA_AVAILABLE | ||||
|     } | ||||
|  | ||||
|     static __device__ __forceinline__ void mma( | ||||
|             tile<16, 8, int> & D, const tile<16, 4, int> & A, const tile<8, 4, int> & B) { | ||||
| #ifdef NEW_MMA_AVAILABLE | ||||
| #ifdef TURING_MMA_AVAILABLE | ||||
| #if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE | ||||
|         asm("mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};" | ||||
|             : "+r"(D.x[0]), "+r"(D.x[1]), "+r"(D.x[2]), "+r"(D.x[3]) | ||||
| @@ -287,12 +319,12 @@ namespace ggml_cuda_mma { | ||||
|         GGML_UNUSED(A); | ||||
|         GGML_UNUSED(B); | ||||
|         NO_DEVICE_CODE; | ||||
| #endif // NEW_MMA_AVAILABLE | ||||
| #endif // TURING_MMA_AVAILABLE | ||||
|     } | ||||
|  | ||||
|     static __device__ __forceinline__ void mma( | ||||
|             tile<16, 8, int> & D, const tile<16, 8, int> & A, const tile<8, 8, int> & B) { | ||||
| #ifdef NEW_MMA_AVAILABLE | ||||
| #ifdef TURING_MMA_AVAILABLE | ||||
| #if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE | ||||
|         asm("mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};" | ||||
|             : "+r"(D.x[0]), "+r"(D.x[1]), "+r"(D.x[2]), "+r"(D.x[3]) | ||||
| @@ -317,12 +349,12 @@ namespace ggml_cuda_mma { | ||||
|         GGML_UNUSED(A); | ||||
|         GGML_UNUSED(B); | ||||
|         NO_DEVICE_CODE; | ||||
| #endif // NEW_MMA_AVAILABLE | ||||
| #endif // TURING_MMA_AVAILABLE | ||||
|     } | ||||
|  | ||||
|     static __device__ __forceinline__ void mma( | ||||
|             tile<16, 4, half2> & D, const tile<16, 8, half2> & A, const tile<8, 8, half2> & B) { | ||||
| #ifdef NEW_MMA_AVAILABLE | ||||
| #ifdef TURING_MMA_AVAILABLE | ||||
|         const int * Axi = (const int *) A.x; | ||||
|         const int * Bxi = (const int *) B.x; | ||||
|         int       * Dxi = (int       *) D.x; | ||||
| @@ -344,12 +376,12 @@ namespace ggml_cuda_mma { | ||||
|         GGML_UNUSED(A); | ||||
|         GGML_UNUSED(B); | ||||
|         NO_DEVICE_CODE; | ||||
| #endif // NEW_MMA_AVAILABLE | ||||
| #endif // TURING_MMA_AVAILABLE | ||||
|     } | ||||
|  | ||||
|     static __device__ __forceinline__ void mma( | ||||
|             tile<16, 8, half2> & D, const tile<16, 8, half2> & A, const tile<16, 8, half2> & B) { | ||||
| #ifdef NEW_MMA_AVAILABLE | ||||
| #ifdef TURING_MMA_AVAILABLE | ||||
|         const int * Axi = (const int *) A.x; | ||||
|         const int * Bxi = (const int *) B.x; | ||||
|         int       * Dxi = (int       *) D.x; | ||||
| @@ -380,12 +412,29 @@ namespace ggml_cuda_mma { | ||||
|         GGML_UNUSED(A); | ||||
|         GGML_UNUSED(B); | ||||
|         NO_DEVICE_CODE; | ||||
| #endif // NEW_MMA_AVAILABLE | ||||
| #endif // TURING_MMA_AVAILABLE | ||||
|     } | ||||
|  | ||||
|     static __device__ __forceinline__ void mma( | ||||
|             tile<16, 8, float> & D, const tile<16, 8, float> & A, const tile<8, 8, float> & B) { | ||||
| #ifdef AMPERE_MMA_AVAILABLE | ||||
|         const int * Axi = (const int *) A.x; | ||||
|         const int * Bxi = (const int *) B.x; | ||||
|         int       * Dxi = (int       *) D.x; | ||||
|         asm("mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};" | ||||
|             : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3]) | ||||
|             : "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[1])); | ||||
| #else | ||||
|         GGML_UNUSED(D); | ||||
|         GGML_UNUSED(A); | ||||
|         GGML_UNUSED(B); | ||||
|         NO_DEVICE_CODE; | ||||
| #endif // AMPERE_MMA_AVAILABLE | ||||
|     } | ||||
|  | ||||
|     static __device__ __forceinline__ void mma( | ||||
|             tile<16, 8, float> & D, const tile<16, 8, half2> & A, const tile<8, 8, half2> & B) { | ||||
| #ifdef NEW_MMA_AVAILABLE | ||||
| #ifdef TURING_MMA_AVAILABLE | ||||
|         const int * Axi = (const int *) A.x; | ||||
|         const int * Bxi = (const int *) B.x; | ||||
|         int       * Dxi = (int       *) D.x; | ||||
| @@ -407,12 +456,29 @@ namespace ggml_cuda_mma { | ||||
|         GGML_UNUSED(A); | ||||
|         GGML_UNUSED(B); | ||||
|         NO_DEVICE_CODE; | ||||
| #endif // NEW_MMA_AVAILABLE | ||||
| #endif // TURING_MMA_AVAILABLE | ||||
|     } | ||||
|  | ||||
|     static __device__ __forceinline__ void mma( | ||||
|             tile<16, 8, float> & D, const tile<16, 8, nv_bfloat162> & A, const tile<8, 8, nv_bfloat162> & B) { | ||||
| #ifdef AMPERE_MMA_AVAILABLE | ||||
|         const int * Axi = (const int *) A.x; | ||||
|         const int * Bxi = (const int *) B.x; | ||||
|         int       * Dxi = (int       *) D.x; | ||||
|         asm("mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};" | ||||
|             : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3]) | ||||
|             : "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[1])); | ||||
| #else | ||||
|         GGML_UNUSED(D); | ||||
|         GGML_UNUSED(A); | ||||
|         GGML_UNUSED(B); | ||||
|         NO_DEVICE_CODE; | ||||
| #endif // AMPERE_MMA_AVAILABLE | ||||
|     } | ||||
|  | ||||
|     static __device__ __forceinline__ void mma( | ||||
|             tile<16, 16, float> & D, const tile<16, 8, half2> & A, const tile<16, 8, half2> & B) { | ||||
| #ifdef NEW_MMA_AVAILABLE | ||||
| #ifdef TURING_MMA_AVAILABLE | ||||
|         const int * Axi = (const int *) A.x; | ||||
|         const int * Bxi = (const int *) B.x; | ||||
|         int       * Dxi = (int       *) D.x; | ||||
| @@ -443,7 +509,7 @@ namespace ggml_cuda_mma { | ||||
|         GGML_UNUSED(A); | ||||
|         GGML_UNUSED(B); | ||||
|         NO_DEVICE_CODE; | ||||
| #endif // NEW_MMA_AVAILABLE | ||||
| #endif // TURING_MMA_AVAILABLE | ||||
|     } | ||||
|  | ||||
|     static __device__ __forceinline__ void mma( | ||||
|   | ||||
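The mma.cuh hunks above add a tile specialization for nv_bfloat162 and two Ampere-only mma() overloads, m16n8k8 with tf32 inputs and m16n8k16 with bf16 inputs, both accumulating into fp32. As a quick sanity check (not from the commit), the per-thread register counts implied by those shapes match the Axi[0..3], Bxi[0..1], Dxi[0..3] operands used in the inline PTX, assuming each fragment is spread evenly over a 32-thread warp with one 32-bit register per slot:

// Per-thread 32-bit register counts for the new mma shapes; the results (4 A, 2 B,
// 4 D registers per thread) line up with the asm operand lists above.
#include <cstdio>

int main() {
    const int warp = 32;
    // m16n8k8, tf32: A is 16x8 fp32 (tile<16,8,float>), B covers 8x8 fp32
    // (tile<8,8,float>), D is 16x8 fp32.
    printf("tf32: A=%d B=%d D=%d regs/thread\n", 16*8/warp, 8*8/warp, 16*8/warp);
    // m16n8k16, bf16: A covers 16x16 bf16 values packed two per register
    // (tile<16,8,nv_bfloat162>), B covers 8x16 bf16 values (tile<8,8,nv_bfloat162>),
    // D is 16x8 fp32.
    printf("bf16: A=%d B=%d D=%d regs/thread\n", 16*16/2/warp, 8*16/2/warp, 16*8/warp);
    return 0;
}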
							
								
								
									
ggml/src/ggml-cuda/mmf.cu (new file, 431 lines)
							| @@ -0,0 +1,431 @@ | ||||
| #include "ggml.h" | ||||
| #include "common.cuh" | ||||
| #include "mma.cuh" | ||||
| #include "mmf.cuh" | ||||
|  | ||||
| using namespace ggml_cuda_mma; | ||||
|  | ||||
| #define MMF_ROWS_PER_BLOCK 32 | ||||
|  | ||||
| template <typename T, int rows_per_block, int cols_per_block, int nwarps> | ||||
| __launch_bounds__(ggml_cuda_get_physical_warp_size()*nwarps, 1) | ||||
| static __global__ void mul_mat_f( | ||||
|         const T * __restrict__ x, const float * __restrict__ y, const int32_t * __restrict__ ids, float * __restrict__ dst, | ||||
|         const int ncols, const int nchannels_y, const int stride_row, const int stride_col_y, const int stride_col_dst, | ||||
|         const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst, | ||||
|         const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) { | ||||
| #if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) | ||||
|     typedef tile<16, 8, T>     tile_A; | ||||
|     typedef tile< 8, 8, T>     tile_B; | ||||
|     typedef tile<16, 8, float> tile_C; | ||||
|  | ||||
|     constexpr int warp_size = ggml_cuda_get_physical_warp_size(); | ||||
|     constexpr int tile_k_padded = warp_size + 4; | ||||
|     constexpr int ntA = rows_per_block / tile_A::I; | ||||
|     constexpr int ntB = (cols_per_block + tile_B::I - 1) / tile_B::I; | ||||
|  | ||||
|     const int row0        = blockIdx.x * rows_per_block; | ||||
|     const int channel_dst = blockIdx.y; | ||||
|     const int channel_x   = channel_dst / channel_ratio; | ||||
|     const int channel_y   = channel_dst; | ||||
|     const int sample_dst  = blockIdx.z; | ||||
|     const int sample_x    = sample_dst / sample_ratio; | ||||
|     const int sample_y    = sample_dst; | ||||
|  | ||||
|     x   += int64_t(sample_x)  *stride_sample_x   + channel_x  *stride_channel_x   + row0*stride_row ; | ||||
|     y   += int64_t(sample_y)  *stride_sample_y   + channel_y  *stride_channel_y; | ||||
|     dst += int64_t(sample_dst)*stride_sample_dst + channel_dst*stride_channel_dst; | ||||
|  | ||||
|     const float2 * y2 = (const float2 *) y; | ||||
|  | ||||
|     extern __shared__ char data_mmv[]; | ||||
|  | ||||
|     tile_C C[ntA][ntB]; | ||||
|  | ||||
|     T * tile_xy = (T *) data_mmv + threadIdx.y*(tile_A::I * tile_k_padded); | ||||
|  | ||||
|     for (int col = threadIdx.y*warp_size + threadIdx.x; col < ncols; col += nwarps*warp_size) { | ||||
|         tile_A A[ntA][warp_size / tile_A::J]; | ||||
| #pragma unroll | ||||
|         for (int itA = 0; itA < ntA; ++itA) { | ||||
| #pragma unroll | ||||
|             for (int i = 0; i < tile_A::I; ++i) { | ||||
|                 tile_xy[i*tile_k_padded + threadIdx.x] = x[(itA*tile_A::I + i)*stride_row  + col]; | ||||
|             } | ||||
| #pragma unroll | ||||
|             for (int k0 = 0; k0 < warp_size; k0 += tile_A::J) { | ||||
|                 load_ldmatrix(A[itA][k0/tile_A::J], tile_xy + k0, tile_k_padded); | ||||
|             } | ||||
|         } | ||||
|  | ||||
| #pragma unroll | ||||
|         for (int itB = 0; itB < ntB; ++itB) { | ||||
|             if constexpr (std::is_same_v<T, float>) { | ||||
| #pragma unroll | ||||
|                 for (int j0 = 0; j0 < tile_B::I; ++j0) { | ||||
|                     const int j = j0 + itB*tile_B::I; | ||||
|  | ||||
|                     tile_xy[j0*tile_k_padded + threadIdx.x] = j < cols_per_block ? y[j*stride_col_y + col] : 0.0f; | ||||
|                 } | ||||
|             } else if constexpr (std::is_same_v<T, half2> || std::is_same_v<T, nv_bfloat162>) { | ||||
| #pragma unroll | ||||
|                 for (int j0 = 0; j0 < tile_B::I; ++j0) { | ||||
|                     const int j = j0 + itB*tile_B::I; | ||||
|  | ||||
|                     const float2 tmp = j < cols_per_block ? y2[j*stride_col_y + col] : make_float2(0.0f, 0.0f); | ||||
|                     tile_xy[j0*tile_k_padded + threadIdx.x] = {tmp.x, tmp.y}; | ||||
|                 } | ||||
|             } else { | ||||
|                 static_assert(std::is_same_v<T, void>, "unsupported type"); | ||||
|             } | ||||
| #pragma unroll | ||||
|             for (int k0 = 0; k0 < warp_size; k0 += tile_B::J) { | ||||
|                 tile_B B; | ||||
|                 load_ldmatrix(B, tile_xy + k0, tile_k_padded); | ||||
| #pragma unroll | ||||
|                 for (int itA = 0; itA < ntA; ++itA) { | ||||
|                     mma(C[itA][itB], A[itA][k0/tile_B::J], B); | ||||
|                 } | ||||
|             } | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     float * buf_iw = (float *) data_mmv; | ||||
|     constexpr int kiw = nwarps*rows_per_block + 4; | ||||
|  | ||||
|     if (nwarps > 1) { | ||||
|         __syncthreads(); | ||||
|     } | ||||
| #pragma unroll | ||||
|     for (int itB = 0; itB < ntB; ++itB) { | ||||
| #pragma unroll | ||||
|         for (int itA = 0; itA < ntA; ++itA) { | ||||
| #pragma unroll | ||||
|             for (int l = 0; l < tile_C::ne; ++l) { | ||||
|                 const int i = threadIdx.y*rows_per_block + itA*tile_C::I + tile_C::get_i(l); | ||||
|                 const int j = itB*tile_C::J + tile_C::get_j(l); | ||||
|                 buf_iw[j*kiw + i] = C[itA][itB].x[l]; | ||||
|             } | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     if (nwarps > 1) { | ||||
|         __syncthreads(); | ||||
|     } | ||||
|  | ||||
| #pragma unroll | ||||
|     for (int j0 = 0; j0 < cols_per_block; j0 += nwarps) { | ||||
|         const int j = j0 + threadIdx.y; | ||||
|  | ||||
|         if (j0 + nwarps > cols_per_block && j >= cols_per_block) { | ||||
|             return; | ||||
|         } | ||||
|  | ||||
|         float sum = 0.0f; | ||||
|         static_assert(rows_per_block == warp_size, "need loop/check"); | ||||
| #pragma unroll | ||||
|         for (int i0 = 0; i0 < nwarps*rows_per_block; i0 += rows_per_block) { | ||||
|             const int i = i0 + threadIdx.x; | ||||
|  | ||||
|             sum += buf_iw[j*kiw + i]; | ||||
|         } | ||||
|         dst[j*stride_col_dst + row0 + threadIdx.x] = sum; | ||||
|     } | ||||
| #else | ||||
|     NO_DEVICE_CODE; | ||||
|     GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(ids); GGML_UNUSED(dst); | ||||
|     GGML_UNUSED(ncols); GGML_UNUSED(nchannels_y); GGML_UNUSED(stride_row); GGML_UNUSED(stride_col_y); GGML_UNUSED(stride_col_dst); | ||||
|     GGML_UNUSED(channel_ratio); GGML_UNUSED(stride_channel_x); GGML_UNUSED(stride_channel_y); GGML_UNUSED(stride_channel_dst); | ||||
|     GGML_UNUSED(sample_ratio); GGML_UNUSED(stride_sample_x); GGML_UNUSED(stride_sample_y); GGML_UNUSED(stride_sample_dst); | ||||
| #endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) | ||||
| } | ||||
|  | ||||
| template <typename T, int cols_per_block> | ||||
| static void mul_mat_f_cuda( | ||||
|         const T * x, const float * y, const int32_t * ids, float * dst, | ||||
|         const int64_t ncols_x, const int64_t nrows_x, | ||||
|         const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst, | ||||
|         const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst, | ||||
|         const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x, | ||||
|         const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst, | ||||
|         cudaStream_t stream) { | ||||
|     typedef tile<16, 8, T>     tile_A; | ||||
|     typedef tile< 8, 8, T>     tile_B; | ||||
|     typedef tile<16, 8, float> tile_C; | ||||
|  | ||||
|     GGML_ASSERT(!ids && "mul_mat_id not implemented"); | ||||
|  | ||||
|     GGML_ASSERT(ncols_x      % 2 == 0); | ||||
|     GGML_ASSERT(stride_row   % 2 == 0); | ||||
|     GGML_ASSERT(stride_col_y % 2 == 0); | ||||
|     GGML_ASSERT(ids || nchannels_dst % nchannels_x == 0); | ||||
|     GGML_ASSERT(       nsamples_dst  % nsamples_x  == 0); | ||||
|     const int64_t channel_ratio = nchannels_dst / nchannels_x; | ||||
|     const int64_t sample_ratio  = nsamples_dst  / nsamples_x; | ||||
|  | ||||
|     const int device = ggml_cuda_get_device(); | ||||
|     const int warp_size = ggml_cuda_info().devices[device].warp_size; | ||||
|  | ||||
|     int64_t nwarps_best     = 1; | ||||
|     int64_t niter_best      = (ncols_x + warp_size*2 - 1) / (warp_size*2); | ||||
|     int64_t max_block_size  = 256; | ||||
|     for (int64_t nwarps = 2; nwarps <= max_block_size/warp_size; nwarps++) { | ||||
|         const int64_t niter = (ncols_x + nwarps*warp_size*2 - 1) / (nwarps*warp_size*2); | ||||
|         if (niter < niter_best) { | ||||
|             niter_best  = niter; | ||||
|             nwarps_best = nwarps; | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     constexpr int rows_per_block = MMF_ROWS_PER_BLOCK; | ||||
|     const int nbytes_shared_iter = nwarps_best * tile_A::I * (warp_size + 4) * 4; | ||||
|     const int nbytes_shared_combine = GGML_PAD(cols_per_block, tile_B::I) * (nwarps_best*rows_per_block + 4) * 4; | ||||
|     const int nbytes_shared = std::max(nbytes_shared_iter, nbytes_shared_combine); | ||||
|     const dim3 block_nums(nrows_x/rows_per_block, nchannels_dst, nsamples_dst); | ||||
|     const dim3 block_dims(warp_size, nwarps_best, 1); | ||||
|     switch (nwarps_best) { | ||||
|         case 1: { | ||||
|             mul_mat_f<T, rows_per_block, cols_per_block, 1><<<block_nums, block_dims, nbytes_shared, stream>>> | ||||
|                 (x, y, ids, dst, ncols_x, nchannels_y, stride_row, stride_col_y, stride_col_dst, | ||||
|                  channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, | ||||
|                  sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); | ||||
|         } break; | ||||
|         case 2: { | ||||
|             mul_mat_f<T, rows_per_block, cols_per_block, 2><<<block_nums, block_dims, nbytes_shared, stream>>> | ||||
|                 (x, y, ids, dst, ncols_x, nchannels_y, stride_row, stride_col_y, stride_col_dst, | ||||
|                  channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, | ||||
|                  sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); | ||||
|         } break; | ||||
|         case 3: { | ||||
|             mul_mat_f<T, rows_per_block, cols_per_block, 3><<<block_nums, block_dims, nbytes_shared, stream>>> | ||||
|                 (x, y, ids, dst, ncols_x, nchannels_y, stride_row, stride_col_y, stride_col_dst, | ||||
|                  channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, | ||||
|                  sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); | ||||
|         } break; | ||||
|         case 4: { | ||||
|             mul_mat_f<T, rows_per_block, cols_per_block, 4><<<block_nums, block_dims, nbytes_shared, stream>>> | ||||
|                 (x, y, ids, dst, ncols_x, nchannels_y, stride_row, stride_col_y, stride_col_dst, | ||||
|                  channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, | ||||
|                  sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); | ||||
|         } break; | ||||
|         case 5: { | ||||
|             mul_mat_f<T, rows_per_block, cols_per_block, 5><<<block_nums, block_dims, nbytes_shared, stream>>> | ||||
|                 (x, y, ids, dst, ncols_x, nchannels_y, stride_row, stride_col_y, stride_col_dst, | ||||
|                  channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, | ||||
|                  sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); | ||||
|         } break; | ||||
|         case 6: { | ||||
|             mul_mat_f<T, rows_per_block, cols_per_block, 6><<<block_nums, block_dims, nbytes_shared, stream>>> | ||||
|                 (x, y, ids, dst, ncols_x, nchannels_y, stride_row, stride_col_y, stride_col_dst, | ||||
|                  channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, | ||||
|                  sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); | ||||
|         } break; | ||||
|         case 7: { | ||||
|             mul_mat_f<T, rows_per_block, cols_per_block, 7><<<block_nums, block_dims, nbytes_shared, stream>>> | ||||
|                 (x, y, ids, dst, ncols_x, nchannels_y, stride_row, stride_col_y, stride_col_dst, | ||||
|                  channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, | ||||
|                  sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); | ||||
|         } break; | ||||
|         case 8: { | ||||
|             mul_mat_f<T, rows_per_block, cols_per_block, 8><<<block_nums, block_dims, nbytes_shared, stream>>> | ||||
|                 (x, y, ids, dst, ncols_x, nchannels_y, stride_row, stride_col_y, stride_col_dst, | ||||
|                  channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, | ||||
|                  sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); | ||||
|         } break; | ||||
|         default: { | ||||
|             GGML_ABORT("fatal error"); | ||||
|         } break; | ||||
|     } | ||||
| } | ||||
|  | ||||
| template <typename T> | ||||
| static void mul_mat_f_switch_cols_per_block( | ||||
|         const T * x, const float * y, const int32_t * ids, float * dst, | ||||
|         const int64_t ncols_x, const int64_t nrows_x, const int64_t ncols_dst, | ||||
|         const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst, | ||||
|         const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst, | ||||
|         const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x, | ||||
|         const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst, | ||||
|         cudaStream_t stream) { | ||||
|     switch (ncols_dst) { | ||||
|         case  1: { | ||||
|             mul_mat_f_cuda<T,  1>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst, | ||||
|                 nchannels_x, nchannels_y,  nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, | ||||
|                 nsamples_x,                nsamples_dst,  stride_sample_x,  stride_sample_y,  stride_sample_dst,  stream); | ||||
|         } break; | ||||
|         case  2: { | ||||
|             mul_mat_f_cuda<T,  2>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst, | ||||
|                 nchannels_x, nchannels_y,  nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, | ||||
|                 nsamples_x,                nsamples_dst,  stride_sample_x,  stride_sample_y,  stride_sample_dst,  stream); | ||||
|         } break; | ||||
|         case  3: { | ||||
|             mul_mat_f_cuda<T,  3>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst, | ||||
|                 nchannels_x, nchannels_y,  nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, | ||||
|                 nsamples_x,                nsamples_dst,  stride_sample_x,  stride_sample_y,  stride_sample_dst,  stream); | ||||
|         } break; | ||||
|         case  4: { | ||||
|             mul_mat_f_cuda<T,  4>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst, | ||||
|                 nchannels_x, nchannels_y,  nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, | ||||
|                 nsamples_x,                nsamples_dst,  stride_sample_x,  stride_sample_y,  stride_sample_dst,  stream); | ||||
|         } break; | ||||
|         case  5: { | ||||
|             mul_mat_f_cuda<T,  5>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst, | ||||
|                 nchannels_x, nchannels_y,  nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, | ||||
|                 nsamples_x,                nsamples_dst,  stride_sample_x,  stride_sample_y,  stride_sample_dst,  stream); | ||||
|         } break; | ||||
|         case  6: { | ||||
|             mul_mat_f_cuda<T,  6>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst, | ||||
|                 nchannels_x, nchannels_y,  nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, | ||||
|                 nsamples_x,                nsamples_dst,  stride_sample_x,  stride_sample_y,  stride_sample_dst,  stream); | ||||
|         } break; | ||||
|         case  7: { | ||||
|             mul_mat_f_cuda<T,  7>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst, | ||||
|                 nchannels_x, nchannels_y,  nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, | ||||
|                 nsamples_x,                nsamples_dst,  stride_sample_x,  stride_sample_y,  stride_sample_dst,  stream); | ||||
|         } break; | ||||
|         case  8: { | ||||
|             mul_mat_f_cuda<T,  8>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst, | ||||
|                 nchannels_x, nchannels_y,  nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, | ||||
|                 nsamples_x,                nsamples_dst,  stride_sample_x,  stride_sample_y,  stride_sample_dst,  stream); | ||||
|         } break; | ||||
|         case  9: { | ||||
|             mul_mat_f_cuda<T,  9>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst, | ||||
|                 nchannels_x, nchannels_y,  nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, | ||||
|                 nsamples_x,                nsamples_dst,  stride_sample_x,  stride_sample_y,  stride_sample_dst,  stream); | ||||
|         } break; | ||||
|         case 10: { | ||||
|             mul_mat_f_cuda<T, 10>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst, | ||||
|                 nchannels_x, nchannels_y,  nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, | ||||
|                 nsamples_x,                nsamples_dst,  stride_sample_x,  stride_sample_y,  stride_sample_dst,  stream); | ||||
|         } break; | ||||
|         case 11: { | ||||
|             mul_mat_f_cuda<T, 11>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst, | ||||
|                 nchannels_x, nchannels_y,  nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, | ||||
|                 nsamples_x,                nsamples_dst,  stride_sample_x,  stride_sample_y,  stride_sample_dst,  stream); | ||||
|         } break; | ||||
|         case 12: { | ||||
|             mul_mat_f_cuda<T, 12>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst, | ||||
|                 nchannels_x, nchannels_y,  nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, | ||||
|                 nsamples_x,                nsamples_dst,  stride_sample_x,  stride_sample_y,  stride_sample_dst,  stream); | ||||
|         } break; | ||||
|         case 13: { | ||||
|             mul_mat_f_cuda<T, 13>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst, | ||||
|                 nchannels_x, nchannels_y,  nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, | ||||
|                 nsamples_x,                nsamples_dst,  stride_sample_x,  stride_sample_y,  stride_sample_dst,  stream); | ||||
|         } break; | ||||
|         case 14: { | ||||
|             mul_mat_f_cuda<T, 14>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst, | ||||
|                 nchannels_x, nchannels_y,  nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, | ||||
|                 nsamples_x,                nsamples_dst,  stride_sample_x,  stride_sample_y,  stride_sample_dst,  stream); | ||||
|         } break; | ||||
|         case 15: { | ||||
|             mul_mat_f_cuda<T, 15>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst, | ||||
|                 nchannels_x, nchannels_y,  nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, | ||||
|                 nsamples_x,                nsamples_dst,  stride_sample_x,  stride_sample_y,  stride_sample_dst,  stream); | ||||
|         } break; | ||||
|         case 16: { | ||||
|             mul_mat_f_cuda<T, 16>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst, | ||||
|                 nchannels_x, nchannels_y,  nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, | ||||
|                 nsamples_x,                nsamples_dst,  stride_sample_x,  stride_sample_y,  stride_sample_dst,  stream); | ||||
|         } break; | ||||
|         default: { | ||||
|             GGML_ABORT("fatal error"); | ||||
|         } break; | ||||
|     } | ||||
| } | ||||
|  | ||||
| void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst) { | ||||
|     GGML_ASSERT(        src1->type == GGML_TYPE_F32); | ||||
|     GGML_ASSERT(!ids ||  ids->type == GGML_TYPE_I32); | ||||
|     GGML_ASSERT(         dst->type == GGML_TYPE_F32); | ||||
|  | ||||
|     GGML_TENSOR_BINARY_OP_LOCALS; | ||||
|  | ||||
|     const size_t ts_src0 = ggml_type_size(src0->type); | ||||
|     const size_t ts_src1 = ggml_type_size(src1->type); | ||||
|     const size_t ts_dst  = ggml_type_size(dst->type); | ||||
|  | ||||
|     GGML_ASSERT(ne13 == ne3); | ||||
|  | ||||
|     GGML_ASSERT(        nb00       == ts_src0); | ||||
|     GGML_ASSERT(        nb10       == ts_src1); | ||||
|     GGML_ASSERT(!ids || ids->nb[0] == ggml_type_size(ids->type)); | ||||
|     GGML_ASSERT(        nb0        == ts_dst); | ||||
|  | ||||
|     const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc; | ||||
|     const enum ggml_prec prec = fast_fp16_available(cc) ? ggml_prec(dst->op_params[0]) : GGML_PREC_F32; | ||||
|  | ||||
|     const float   * src1_d =       (const float   *) src1->data; | ||||
|     const int32_t *  ids_d = ids ? (const int32_t *)  ids->data : nullptr; | ||||
|     float         *  dst_d =       (float         *)  dst->data; | ||||
|  | ||||
|     const int64_t s01 = src0->nb[1] / ts_src0; | ||||
|     const int64_t s11 = src1->nb[1] / ts_src1; | ||||
|     const int64_t s1  =  dst->nb[1] / ts_dst; | ||||
|     const int64_t s02 = src0->nb[2] / ts_src0; | ||||
|     const int64_t s12 = src1->nb[2] / ts_src1; | ||||
|     const int64_t s2  =  dst->nb[2] / ts_dst; | ||||
|     const int64_t s03 = src0->nb[3] / ts_src0; | ||||
|     const int64_t s13 = src1->nb[3] / ts_src1; | ||||
|     const int64_t s3  =  dst->nb[3] / ts_dst; | ||||
|  | ||||
|     // For MUL_MAT_ID the memory layout is different than for MUL_MAT: | ||||
|     const int64_t ncols_dst          = ids ? ne2  : ne1; | ||||
|     const int64_t nchannels_y        = ids ? ne11 : ne12; | ||||
|     const int64_t nchannels_dst      = ids ? ne1  : ne2; | ||||
|     const int64_t stride_channel_dst = ids ? s1   : s2; | ||||
|     const int64_t stride_channel_y   = ids ? s11  : s12; | ||||
|  | ||||
|     GGML_ASSERT(!ids || ncols_dst == 1); | ||||
|  | ||||
|     switch (src0->type) { | ||||
|         case GGML_TYPE_F32: { | ||||
|             const float * src0_d = (const float *) src0->data; | ||||
|             constexpr int vals_per_T = 1; | ||||
|             mul_mat_f_switch_cols_per_block( | ||||
|                 src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, s11/vals_per_T, s1, | ||||
|                 ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst, | ||||
|                 ne03,              ne3,           s03/vals_per_T, s13,              s3,                 ctx.stream()); | ||||
|         } break; | ||||
|         case GGML_TYPE_F16: { | ||||
|             const half2 * src0_d = (const half2 *) src0->data; | ||||
|             constexpr int vals_per_T = 2; | ||||
|             mul_mat_f_switch_cols_per_block( | ||||
|                 src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, s11/vals_per_T, s1, | ||||
|                 ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst, | ||||
|                 ne03,              ne3,           s03/vals_per_T, s13,              s3,                 ctx.stream()); | ||||
|         } break; | ||||
|         case GGML_TYPE_BF16: { | ||||
|             const nv_bfloat162 * src0_d = (const nv_bfloat162 *) src0->data; | ||||
|             constexpr int vals_per_T = 2; | ||||
|             mul_mat_f_switch_cols_per_block( | ||||
|                 src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, s11/vals_per_T, s1, | ||||
|                 ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst, | ||||
|                 ne03,              ne3,           s03/vals_per_T, s13,              s3,                 ctx.stream()); | ||||
|         } break; | ||||
|         default: | ||||
|             GGML_ABORT("unsupported type: %s", ggml_type_name(src0->type)); | ||||
|     } | ||||
| } | ||||
|  | ||||
| bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const int64_t * src0_ne, int64_t ne11) { | ||||
|     if (src0_ne[0] % (warp_size * (4/ggml_type_size(type))) != 0) { | ||||
|         return false; | ||||
|     } | ||||
|     if (src0_ne[1] % MMF_ROWS_PER_BLOCK != 0) { | ||||
|         return false; | ||||
|     } | ||||
|     if (ne11 > 16) { | ||||
|         return false; | ||||
|     } | ||||
|     switch (type) { | ||||
|         case GGML_TYPE_F32: | ||||
|             return ampere_mma_available(cc); | ||||
|         case GGML_TYPE_F16: | ||||
|             return turing_mma_available(cc); | ||||
|         case GGML_TYPE_BF16: | ||||
|             return ampere_mma_available(cc); | ||||
|         default: | ||||
|             return false; | ||||
|     } | ||||
| } | ||||
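mul_mat_f_cuda() above chooses the number of warps per block by minimizing the number of passes over the columns of x, capped at 256 threads per block, and keeps the smaller warp count on ties. Below is a standalone restatement of just that heuristic (warp_size assumed to be 32 for illustration):

// Mirrors the nwarps_best selection loop in mul_mat_f_cuda(); illustrative only.
#include <cstdint>
#include <cstdio>

static int64_t demo_pick_nwarps(int64_t ncols_x, int64_t warp_size = 32, int64_t max_block_size = 256) {
    int64_t nwarps_best = 1;
    int64_t niter_best  = (ncols_x + warp_size*2 - 1) / (warp_size*2);
    for (int64_t nwarps = 2; nwarps <= max_block_size/warp_size; nwarps++) {
        const int64_t niter = (ncols_x + nwarps*warp_size*2 - 1) / (nwarps*warp_size*2);
        if (niter < niter_best) {  // strictly better only: prefer fewer warps on ties
            niter_best  = niter;
            nwarps_best = nwarps;
        }
    }
    return nwarps_best;
}

int main() {
    const int64_t test_cols[] = {64, 256, 4096};
    for (const int64_t ncols : test_cols) {
        printf("ncols_x=%4lld -> nwarps=%lld\n", (long long) ncols, (long long) demo_pick_nwarps(ncols));
    }
    return 0;
}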
							
								
								
									
ggml/src/ggml-cuda/mmf.cuh (new file, 5 lines)
							| @@ -0,0 +1,5 @@ | ||||
| #include "common.cuh" | ||||
|  | ||||
| void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst); | ||||
|  | ||||
| bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const int64_t * src0_ne, int64_t ne11); | ||||
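ggml_cuda_should_use_mmf() gates the new path on both shape and architecture: ne00 must be a multiple of warp_size*(4/ggml_type_size(type)) (32 for F32, 64 for F16/BF16 with a 32-thread warp), ne01 must be a multiple of MMF_ROWS_PER_BLOCK (32), ne11 must not exceed 16, and F16 requires Turing-class MMA while F32 and BF16 require Ampere. A small worked example covering only the shape checks (illustrative; the architecture checks are omitted here):

// Shape-only part of the MMF eligibility check; turing_mma_available /
// ampere_mma_available are not modeled in this sketch.
#include <cstdint>
#include <cstdio>

static bool demo_mmf_shape_ok(int64_t ne00, int64_t ne01, int64_t ne11, int64_t type_size, int warp_size = 32) {
    const int64_t col_align = warp_size * (4 / type_size);  // 32 for F32 (4 bytes), 64 for F16/BF16 (2 bytes)
    return ne00 % col_align == 0 && ne01 % 32 == 0 && ne11 <= 16;
}

int main() {
    printf("%d\n", (int) demo_mmf_shape_ok(4096, 4096,  8, 2));  // F16 4096x4096, 8 cols -> eligible (also needs Turing+)
    printf("%d\n", (int) demo_mmf_shape_ok(4096, 4096, 32, 2));  // 32 cols -> too wide, MMQ/cuBLAS instead
    printf("%d\n", (int) demo_mmf_shape_ok(1000, 4096,  8, 4));  // F32, ne00 not a multiple of 32 -> not eligible
    return 0;
}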
| @@ -310,7 +310,7 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) { | ||||
|         return false; | ||||
|     } | ||||
|  | ||||
|     if (new_mma_available(cc)) { | ||||
|     if (turing_mma_available(cc)) { | ||||
|         return true; | ||||
|     } | ||||
|  | ||||
|   | ||||
| @@ -92,7 +92,7 @@ struct tile_x_sizes { | ||||
| }; | ||||
|  | ||||
| static int get_mmq_x_max_host(const int cc) { | ||||
|     return (amd_mfma_available(cc) || new_mma_available(cc)) ? 128 : | ||||
|     return (amd_mfma_available(cc) || turing_mma_available(cc)) ? 128 : | ||||
|         GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA ? | ||||
| #ifdef GGML_CUDA_FORCE_MMQ | ||||
|             128                     : 64; | ||||
| @@ -102,9 +102,9 @@ static int get_mmq_x_max_host(const int cc) { | ||||
| } | ||||
|  | ||||
| static constexpr __device__ int get_mmq_x_max_device() { | ||||
| #if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) | ||||
| #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) | ||||
|     return 128; | ||||
| #else // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) | ||||
| #else // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) | ||||
|  | ||||
| #if defined(GGML_USE_HIP) | ||||
|     return 64; | ||||
| @@ -121,7 +121,7 @@ static constexpr __device__ int get_mmq_x_max_device() { | ||||
| #endif // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA | ||||
|  | ||||
| #endif // defined(GGML_USE_HIP) | ||||
| #endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) | ||||
| #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) | ||||
| } | ||||
|  | ||||
| static int get_mmq_y_host(const int cc) { | ||||
| @@ -233,7 +233,7 @@ static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) { | ||||
| static int mmq_get_granularity_host(const int mmq_x, const int cc) { | ||||
|     if (amd_mfma_available(cc)) { | ||||
|         return mmq_x >= 128 ? 32 : 16; | ||||
|     } else if (new_mma_available(cc) && mmq_x >= 48) { | ||||
|     } else if (turing_mma_available(cc) && mmq_x >= 48) { | ||||
|         return 16; | ||||
|     } else { | ||||
|         return 8; | ||||
| @@ -244,7 +244,7 @@ static int mmq_get_granularity_host(const int mmq_x, const int cc) { | ||||
| static constexpr __device__ int mmq_get_granularity_device(const int mmq_x) { | ||||
|     return mmq_x >= 128 ? 32 : 16; | ||||
| } | ||||
| #elif defined(NEW_MMA_AVAILABLE) | ||||
| #elif defined(TURING_MMA_AVAILABLE) | ||||
| static constexpr __device__ int mmq_get_granularity_device(const int mmq_x) { | ||||
|     return mmq_x >= 48 ? 16 : 8; | ||||
| } | ||||
| @@ -279,14 +279,14 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa | ||||
|     constexpr int nwarps = mmq_get_nwarps_device(); | ||||
|     constexpr int warp_size = ggml_cuda_get_physical_warp_size(); | ||||
|  | ||||
| #if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) | ||||
| #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) | ||||
|     int   * x_qs = (int   *)  x_tile; | ||||
|     float * x_df = (float *) (x_qs + 2*MMQ_TILE_NE_K); | ||||
| #else | ||||
|     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_0, mmq_y); | ||||
|     int   * x_qs = (int   *)  x_tile; | ||||
|     float * x_df = (float *) (x_qs + txs.qs); | ||||
| #endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) | ||||
| #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) | ||||
|  | ||||
|     constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_0); | ||||
|     constexpr int nrows = warp_size / threads_per_row; | ||||
| @@ -305,12 +305,12 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa | ||||
|         const block_q4_0 * bxi = (const block_q4_0 *) x + kbx0 + i*stride + kbx; | ||||
|         const int qs0 = get_int_b2(bxi->qs, kqsx); | ||||
|  | ||||
| #if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) | ||||
| #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) | ||||
|         x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI4_0) + kqsx + 0]     = __vsubss4((qs0 >> 0) & 0x0F0F0F0F, 0x08080808); | ||||
|         x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI4_0) + kqsx + QI4_0] = __vsubss4((qs0 >> 4) & 0x0F0F0F0F, 0x08080808); | ||||
| #else | ||||
|         x_qs[i*(MMQ_TILE_NE_K + 1) + txi] = qs0; | ||||
| #endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) | ||||
| #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) | ||||
|     } | ||||
|  | ||||
|     constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI4_0; | ||||
| @@ -327,11 +327,11 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa | ||||
|  | ||||
|         const block_q4_0 * bxi = (const block_q4_0 *) x + kbx0 + i*stride + kbxd; | ||||
|  | ||||
| #if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) | ||||
| #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) | ||||
|         x_df[i*MMQ_MMA_TILE_X_K_Q8_0           + kbxd] = bxi->d; | ||||
| #else | ||||
|         x_df[i*(MMQ_TILE_NE_K/QI4_0) + i/QI4_0 + kbxd] = bxi->d; | ||||
| #endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) | ||||
| #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) | ||||
|     } | ||||
| } | ||||
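For reference (this sketch is not part of the commit), the host-side loop below reproduces what the MMA-path stores above compute for one packed q4_0 word: the two mask-and-shift expressions split each byte into its low and high nibble, and `__vsubss4(x, 0x08080808)` maps the raw [0, 15] quants to the signed range [-8, 7].

```cpp
// Host-side sketch of the q4_0 nibble unpacking used in the MMA path above.
// Illustrative only; the packed input value is arbitrary.
#include <cstdint>
#include <cstdio>

int main() {
    const uint32_t qs0 = 0x9F30A471u;               // 8 packed 4-bit quants
    const uint32_t lo  = (qs0 >> 0) & 0x0F0F0F0Fu;  // low nibble of every byte
    const uint32_t hi  = (qs0 >> 4) & 0x0F0F0F0Fu;  // high nibble of every byte
    for (int b = 0; b < 4; ++b) {
        // Per-byte subtraction of 8 matches __vsubss4(x, 0x08080808): [0, 15] -> [-8, 7].
        const int q_lo = int((lo >> (8*b)) & 0xFF) - 8;
        const int q_hi = int((hi >> (8*b)) & 0xFF) - 8;
        printf("byte %d: q_lo=%3d q_hi=%3d\n", b, q_lo, q_hi);
    }
    return 0;
}
```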
|  | ||||
| @@ -382,14 +382,14 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa | ||||
|     constexpr int nwarps = mmq_get_nwarps_device(); | ||||
|     constexpr int warp_size = ggml_cuda_get_physical_warp_size(); | ||||
|  | ||||
| #if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) | ||||
| #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) | ||||
|     int   * x_qs = (int   *)  x_tile; | ||||
|     half2 * x_dm = (half2 *) (x_qs + 2*MMQ_TILE_NE_K); | ||||
| #else | ||||
|     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_1, mmq_y); | ||||
|     int   * x_qs = (int   *)  x_tile; | ||||
|     half2 * x_dm = (half2 *) (x_qs + txs.qs); | ||||
| #endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) | ||||
| #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) | ||||
|  | ||||
|     constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_1); | ||||
|     constexpr int nrows = warp_size / threads_per_row; | ||||
| @@ -408,12 +408,12 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa | ||||
|         const block_q4_1 * bxi = (const block_q4_1 *) x + kbx0 + i*stride + kbx; | ||||
|         const int qs0 = get_int_b4(bxi->qs, kqsx); | ||||
|  | ||||
| #if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) | ||||
| #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) | ||||
|         x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI4_1) + kqsx + 0]     = (qs0 >> 0) & 0x0F0F0F0F; | ||||
|         x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI4_1) + kqsx + QI4_1] = (qs0 >> 4) & 0x0F0F0F0F; | ||||
| #else | ||||
|         x_qs[i*(MMQ_TILE_NE_K + 1) + txi] = qs0; | ||||
| #endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) | ||||
| #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) | ||||
|     } | ||||
|  | ||||
|     constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI4_1; | ||||
| @@ -430,11 +430,11 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa | ||||
|  | ||||
|         const block_q4_1 * bxi = (const block_q4_1 *) x + kbx0 + i*stride + kbxd; | ||||
|  | ||||
| #if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) | ||||
| #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) | ||||
|         x_dm[i*MMQ_MMA_TILE_X_K_Q8_1           + kbxd] = bxi->dm; | ||||
| #else | ||||
|         x_dm[i*(MMQ_TILE_NE_K/QI4_1) + i/QI4_1 + kbxd] = bxi->dm; | ||||
| #endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) | ||||
| #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) | ||||
|     } | ||||
| } | ||||
|  | ||||
| @@ -485,14 +485,14 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa | ||||
|     constexpr int nwarps = mmq_get_nwarps_device(); | ||||
|     constexpr int warp_size = ggml_cuda_get_physical_warp_size(); | ||||
|  | ||||
| #if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) | ||||
| #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) | ||||
|     int   * x_qs = (int   *)  x_tile; | ||||
|     float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2); | ||||
| #else | ||||
|     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_0, mmq_y); | ||||
|     int   * x_qs = (int   *)  x_tile; | ||||
|     float * x_df = (float *) (x_qs + txs.qs); | ||||
| #endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) | ||||
| #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) | ||||
|  | ||||
|     constexpr int threads_per_row = MMQ_ITER_K / (4 * QR5_0); | ||||
|     constexpr int nrows = warp_size / threads_per_row; | ||||
| @@ -527,13 +527,13 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa | ||||
|         qs1    |= (qh <<  9)   & 0x10000000;  // 19 -> 28 | ||||
|         qs1     = __vsubss4(qs1, 0x10101010); // subtract 16 | ||||
|  | ||||
| #if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) | ||||
| #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) | ||||
|         x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI5_0) + kqsx + 0]     = qs0; | ||||
|         x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI5_0) + kqsx + QI5_0] = qs1; | ||||
| #else | ||||
|         x_qs[i*(2*MMQ_TILE_NE_K + 1) + kbx*(2*QI5_0) + kqsx + 0]     = qs0; | ||||
|         x_qs[i*(2*MMQ_TILE_NE_K + 1) + kbx*(2*QI5_0) + kqsx + QI5_0] = qs1; | ||||
| #endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) | ||||
| #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) | ||||
|     } | ||||
|  | ||||
|     constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI5_0; | ||||
| @@ -550,11 +550,11 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa | ||||
|  | ||||
|         const block_q5_0 * bxi = (const block_q5_0 *) x + kbx0 + i*stride + kbxd; | ||||
|  | ||||
| #if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) | ||||
| #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) | ||||
|         x_df[i*MMQ_MMA_TILE_X_K_Q8_0           + kbxd] = bxi->d; | ||||
| #else | ||||
|         x_df[i*(MMQ_TILE_NE_K/QI5_0) + i/QI5_0 + kbxd] = bxi->d; | ||||
| #endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) | ||||
| #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) | ||||
|     } | ||||
| } | ||||
|  | ||||
| @@ -563,14 +563,14 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa | ||||
|     constexpr int nwarps = mmq_get_nwarps_device(); | ||||
|     constexpr int warp_size = ggml_cuda_get_physical_warp_size(); | ||||
|  | ||||
| #if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) | ||||
| #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) | ||||
|     int   * x_qs = (int   *)  x_tile; | ||||
|     half2 * x_dm = (half2 *) (x_qs + 2*MMQ_TILE_NE_K); | ||||
| #else | ||||
|     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_1, mmq_y); | ||||
|     int   * x_qs = (int   *)  x_tile; | ||||
|     half2 * x_dm = (half2 *) (x_qs + txs.qs); | ||||
| #endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) | ||||
| #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) | ||||
|  | ||||
|     constexpr int threads_per_row = MMQ_ITER_K / (4 * QR5_1); | ||||
|     constexpr int nrows = warp_size / threads_per_row; | ||||
| @@ -603,13 +603,13 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa | ||||
|         qs1    |= (qh <<  2) & 0x00100000; // 18 -> 20 | ||||
|         qs1    |= (qh <<  9) & 0x10000000; // 19 -> 28 | ||||
|  | ||||
| #if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) | ||||
| #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) | ||||
|         x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI5_1) + kqsx + 0]     = qs0; | ||||
|         x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI5_1) + kqsx + QI5_1] = qs1; | ||||
| #else | ||||
|         x_qs[i*(2*MMQ_TILE_NE_K + 1) + kbx*(2*QI5_1) + kqsx + 0]     = qs0; | ||||
|         x_qs[i*(2*MMQ_TILE_NE_K + 1) + kbx*(2*QI5_1) + kqsx + QI5_1] = qs1; | ||||
| #endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) | ||||
| #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) | ||||
|     } | ||||
|  | ||||
|     constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI5_1; | ||||
| @@ -626,11 +626,11 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa | ||||
|  | ||||
|         const block_q5_1 * bxi = (const block_q5_1 *) x + kbx0 + i*stride + kbxd; | ||||
|  | ||||
| #if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) | ||||
| #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) | ||||
|         x_dm[i*MMQ_MMA_TILE_X_K_Q8_1           + kbxd] = bxi->dm; | ||||
| #else | ||||
|         x_dm[i*(MMQ_TILE_NE_K/QI5_1) + i/QI5_1 + kbxd] = bxi->dm; | ||||
| #endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) | ||||
| #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) | ||||
|     } | ||||
| } | ||||
|  | ||||
| @@ -639,14 +639,14 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa | ||||
|     constexpr int nwarps = mmq_get_nwarps_device(); | ||||
|     constexpr int warp_size = ggml_cuda_get_physical_warp_size(); | ||||
|  | ||||
| #if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) | ||||
| #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) | ||||
|     int   * x_qs = (int   *)  x_tile; | ||||
|     float * x_df = (float *) (x_tile + 2*MMQ_TILE_NE_K); | ||||
| #else | ||||
|     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q8_0, mmq_y); | ||||
|     int   * x_qs = (int   *)  x_tile; | ||||
|     float * x_df = (float *) (x_qs + txs.qs); | ||||
| #endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) | ||||
| #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) | ||||
|  | ||||
|     // MMQ_ITER_K / (4 * QR8_0) == 64 threads per row would be required, but NVIDIA has only 32 threads per warp. | ||||
|     constexpr int threads_per_row = 32; | ||||
| @@ -665,13 +665,13 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa | ||||
|  | ||||
|         const block_q8_0 * bxi = (const block_q8_0 *) x + kbx0 + i*stride + kbx; | ||||
|  | ||||
| #if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) | ||||
| #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) | ||||
|         x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 0             + txi] = get_int_b2(bxi[0].qs,                   kqsx); | ||||
|         x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + MMQ_TILE_NE_K + txi] = get_int_b2(bxi[MMQ_TILE_NE_K/QI8_0].qs, kqsx); | ||||
| #else | ||||
|         x_qs[i*(2*MMQ_TILE_NE_K + 1) + 0             + txi] = get_int_b2(bxi[0].qs,                   kqsx); | ||||
|         x_qs[i*(2*MMQ_TILE_NE_K + 1) + MMQ_TILE_NE_K + txi] = get_int_b2(bxi[MMQ_TILE_NE_K/QI8_0].qs, kqsx); | ||||
| #endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) | ||||
| #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) | ||||
|     } | ||||
|  | ||||
|     constexpr int blocks_per_tile_x_row = 2*MMQ_TILE_NE_K / QI8_0; | ||||
| @@ -688,11 +688,11 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa | ||||
|  | ||||
|         const block_q8_0 * bxi = (const block_q8_0 *) x + kbx0 + i*stride + kbxd; | ||||
|  | ||||
| #if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) | ||||
| #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) | ||||
|         x_df[i*MMQ_MMA_TILE_X_K_Q8_0                 + kbxd] = bxi->d; | ||||
| #else | ||||
|         x_df[i*(2*MMQ_TILE_NE_K/QI8_0) + i/(QI8_0/2) + kbxd] = bxi->d; | ||||
| #endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) | ||||
| #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) | ||||
|     } | ||||
| } | ||||
|  | ||||
| @@ -701,14 +701,14 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa | ||||
|     constexpr int nwarps = mmq_get_nwarps_device(); | ||||
|     constexpr int warp_size = ggml_cuda_get_physical_warp_size(); | ||||
|  | ||||
| #if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) | ||||
| #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) | ||||
|     int   * x_qs = (int   *)  x_tile; | ||||
|     float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2); | ||||
| #else | ||||
|     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_MXFP4, mmq_y); | ||||
|     int   * x_qs = (int   *)  x_tile; | ||||
|     float * x_df = (float *) (x_qs + txs.qs); | ||||
| #endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) | ||||
| #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) | ||||
|  | ||||
|     constexpr int threads_per_row = MMQ_ITER_K / (4 * QR_MXFP4); | ||||
|     constexpr int nrows = warp_size / threads_per_row; | ||||
| @@ -730,13 +730,13 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa | ||||
|         const int2 v = get_int_from_table_16(aux_q4, kvalues_mxfp4); | ||||
|         const int k0 = kbx * (2 * QI_MXFP4) + kqsx; | ||||
|  | ||||
| #if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) | ||||
| #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) | ||||
|         x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + k0 + 0]        = v.x; | ||||
|         x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + k0 + QI_MXFP4] = v.y; | ||||
| #else | ||||
|         x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + 0]        = v.x; | ||||
|         x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + QI_MXFP4] = v.y; | ||||
| #endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) | ||||
| #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) | ||||
|     } | ||||
|  | ||||
|     constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI_MXFP4; | ||||
| @@ -753,11 +753,11 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa | ||||
|  | ||||
|         const block_mxfp4 * bxi = (const block_mxfp4 *) x + kbx0 + i*stride + kbxd; | ||||
|  | ||||
| #if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) | ||||
| #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) | ||||
|         x_df[i*MMQ_MMA_TILE_X_K_Q8_1                 + kbxd] = ggml_cuda_e8m0_to_fp32(bxi->e)*0.5f; | ||||
| #else | ||||
|         x_df[i*(MMQ_TILE_NE_K/QI_MXFP4) + i/QI_MXFP4 + kbxd] = ggml_cuda_e8m0_to_fp32(bxi->e)*0.5f; | ||||
| #endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) | ||||
| #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) | ||||
|     } | ||||
| } | ||||
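The MXFP4 block scale above is decoded from an E8M0 byte and multiplied by 0.5f; a brief, hedged sketch of the exponent-only decode is given below, assuming the byte is simply a biased power-of-two exponent (the real `ggml_cuda_e8m0_to_fp32` may handle edge encodings differently). The extra 0.5f factor presumably compensates for the `kvalues_mxfp4` lookup table storing the FP4 values at twice their nominal magnitude; that reading is an assumption, not something stated in the commit.

```cpp
// Sketch of an E8M0 (exponent-only) scale decode: value = 2^(e - 127).
// Illustrative assumption only; not the implementation from common.cuh.
#include <cmath>
#include <cstdint>
#include <cstdio>

static float e8m0_to_fp32_sketch(uint8_t e) {
    return std::ldexp(1.0f, int(e) - 127);  // 2^(e - 127)
}

int main() {
    printf("%g %g %g\n",
           e8m0_to_fp32_sketch(127),   // 1
           e8m0_to_fp32_sketch(130),   // 8
           e8m0_to_fp32_sketch(120));  // 0.0078125
    return 0;
}
```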
|  | ||||
| @@ -1178,7 +1178,7 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma( | ||||
|             } | ||||
|         } | ||||
|     } | ||||
| #elif defined(NEW_MMA_AVAILABLE) | ||||
| #elif defined(TURING_MMA_AVAILABLE) | ||||
|  | ||||
|     typedef tile<16, 4, int> tile_A; | ||||
|     typedef tile<16, 8, int> tile_A_8; | ||||
| @@ -1264,14 +1264,14 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa | ||||
|     const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { | ||||
|     constexpr int nwarps = mmq_get_nwarps_device(); | ||||
|  | ||||
| #if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) | ||||
| #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) | ||||
|     int   * x_qs = (int   *)  x_tile; | ||||
|     half2 * x_dm = (half2 *) (x_qs + 2*MMQ_TILE_NE_K); | ||||
| #else | ||||
|     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q2_K, mmq_y); | ||||
|     int   * x_qs = (int   *)  x_tile; | ||||
|     half2 * x_dm = (half2 *) (x_qs + txs.qs); | ||||
| #endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) | ||||
| #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) | ||||
|  | ||||
|     constexpr int threads_per_row = MMQ_ITER_K / (4 * QR2_K); | ||||
|     constexpr int nrows = ggml_cuda_get_physical_warp_size() / threads_per_row; | ||||
| @@ -1295,11 +1295,11 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa | ||||
|  | ||||
|             const int x_qs_k = (x_ql_0 >> (2*l)) & 0x03030303; | ||||
|  | ||||
| #if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) | ||||
| #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) | ||||
|             x_qs[i*MMQ_MMA_TILE_X_K_Q2_K + k] = x_qs_k; | ||||
| #else | ||||
|             x_qs[i*(2*MMQ_TILE_NE_K + 1) + k] = x_qs_k; | ||||
| #endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) | ||||
| #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) | ||||
|         } | ||||
|  | ||||
|         const int sc_m = bxi->scales[kqsx]; | ||||
| @@ -1310,11 +1310,11 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa | ||||
|         const half2 x_dm_ik = make_half2(bxi_dmf.x*(sc_m & 0x0F), bxi_dmf.y*(sc_m >> 4)); | ||||
| #endif // FAST_FP16_AVAILABLE | ||||
|  | ||||
| #if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) | ||||
| #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) | ||||
|         x_dm[i*MMQ_MMA_TILE_X_K_Q2_K + kqsx] = x_dm_ik; | ||||
| #else | ||||
|         x_dm[i*(MMQ_TILE_NE_K + 1)   + kqsx] = x_dm_ik; | ||||
| #endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) | ||||
| #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) | ||||
|     } | ||||
| } | ||||
|  | ||||
| @@ -1452,7 +1452,7 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma( | ||||
|             } | ||||
|         } | ||||
|     } | ||||
| #elif defined(NEW_MMA_AVAILABLE) | ||||
| #elif defined(TURING_MMA_AVAILABLE) | ||||
|  | ||||
|     typedef tile<16, 4, int> tile_A; | ||||
|     typedef tile<16, 8, int> tile_A_8; | ||||
| @@ -1582,7 +1582,7 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa | ||||
|     constexpr int nwarps = mmq_get_nwarps_device(); | ||||
|     constexpr int warp_size = ggml_cuda_get_physical_warp_size(); | ||||
|  | ||||
| #if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) | ||||
| #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) | ||||
|     int   * x_qs = (int   *)  x_tile; | ||||
|     float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2); | ||||
| #else | ||||
| @@ -1590,7 +1590,7 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa | ||||
|     int   * x_qs = (int   *)  x_tile; | ||||
|     float * x_df = (float *) (x_qs + txs.qs); | ||||
|     int   * x_sc = (int   *) (x_df + txs.dm); | ||||
| #endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) | ||||
| #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) | ||||
|  | ||||
|     constexpr int threads_per_row = MMQ_ITER_K / (4 * QR3_K); | ||||
|     constexpr int nrows = warp_size / threads_per_row; | ||||
| @@ -1618,11 +1618,11 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa | ||||
|  | ||||
|             const int x_qs_k = __vsubss4(x_ql_k | x_qh_k, 0x04040404); | ||||
|  | ||||
| #if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) | ||||
| #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) | ||||
|             x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + k] = x_qs_k; | ||||
| #else | ||||
|             x_qs[i*(2*MMQ_TILE_NE_K + 1) + k] = x_qs_k; | ||||
| #endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) | ||||
| #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) | ||||
|         } | ||||
|     } | ||||
|  | ||||
| @@ -1649,7 +1649,7 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa | ||||
|  | ||||
|         const int sc = __vsubss4(sc_low | sc_high, 0x20202020); | ||||
|  | ||||
| #if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) | ||||
| #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) | ||||
|         const int8_t * sc8 = (const int8_t *) &sc; | ||||
|         const float d = bxi->d; | ||||
|  | ||||
| @@ -1659,10 +1659,10 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa | ||||
|         } | ||||
| #else | ||||
|         x_sc[i*(MMQ_TILE_NE_K/8) + i/8 + ksc] = sc; | ||||
| #endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) | ||||
| #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) | ||||
|     } | ||||
|  | ||||
| #if !(defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)) | ||||
| #if !(defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)) | ||||
| #pragma unroll | ||||
|     for (int i0 = 0; i0 < mmq_y; i0 += nwarps*warp_size) { | ||||
|         int i = (i0 + threadIdx.y*warp_size + threadIdx.x) % mmq_y; | ||||
| @@ -1675,7 +1675,7 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa | ||||
|  | ||||
|         x_df[i] = bxi->d; | ||||
|     } | ||||
| #endif // !(defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)) | ||||
| #endif // !(defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)) | ||||
| } | ||||
|  | ||||
| template <int mmq_x, int mmq_y> | ||||
| @@ -1728,7 +1728,7 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa | ||||
|     constexpr int nwarps = mmq_get_nwarps_device(); | ||||
|     constexpr int warp_size = ggml_cuda_get_physical_warp_size(); | ||||
|  | ||||
| #if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) | ||||
| #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) | ||||
|     int   * x_qs = (int   *)  x_tile; | ||||
|     half2 * x_dm = (half2 *) (x_qs + 2*MMQ_TILE_NE_K); | ||||
| #else | ||||
| @@ -1736,7 +1736,7 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa | ||||
|     int   * x_qs = (int   *)  x_tile; | ||||
|     half2 * x_dm = (half2 *) (x_qs + txs.qs); | ||||
|     int   * x_sc = (int   *) (x_dm + txs.dm); | ||||
| #endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) | ||||
| #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) | ||||
|  | ||||
|     constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_K); | ||||
|     constexpr int nrows = warp_size / threads_per_row; | ||||
| @@ -1753,15 +1753,15 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa | ||||
|         const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride; | ||||
|         const int qs0 = get_int_b4(bxi->qs, txi); | ||||
|  | ||||
| #if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) | ||||
| #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) | ||||
|         x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 16*(txi/8) + txi % 8 + 0] = (qs0 >> 0) & 0x0F0F0F0F; | ||||
|         x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 16*(txi/8) + txi % 8 + 8] = (qs0 >> 4) & 0x0F0F0F0F; | ||||
| #else | ||||
|         x_qs[i*(MMQ_TILE_NE_K + 1) + txi] = qs0; | ||||
| #endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) | ||||
| #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) | ||||
|     } | ||||
|  | ||||
| #if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) | ||||
| #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) | ||||
|     constexpr int rows_per_warp = warp_size / 2; | ||||
| #pragma unroll | ||||
|     for (int i0 = 0; i0 < mmq_y; i0 += nwarps*rows_per_warp) { | ||||
| @@ -1829,7 +1829,7 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa | ||||
|  | ||||
|         x_sc[i*(MMQ_TILE_NE_K/8) + i/8 + ksc] = scales8; | ||||
|     } | ||||
| #endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) | ||||
| #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) | ||||
| } | ||||
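The q4_K store above interleaves low and high nibbles in groups of eight; the small host-side sketch below (not part of the commit) just prints the destination offsets produced by `16*(txi/8) + txi % 8 + {0, 8}` so the layout is easier to see.

```cpp
// Print the q4_K destination offsets used above: low nibbles at 16*(txi/8) + txi%8,
// high nibbles 8 slots later. Illustrative only.
#include <cstdio>

int main() {
    for (int txi = 0; txi < 16; ++txi) {
        const int k_lo = 16*(txi/8) + txi % 8 + 0;
        const int k_hi = 16*(txi/8) + txi % 8 + 8;
        printf("txi=%2d -> low at %2d, high at %2d\n", txi, k_lo, k_hi);
    }
    return 0;
}
```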
|  | ||||
| template <int mmq_x, int mmq_y> | ||||
| @@ -1872,7 +1872,7 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa | ||||
|     constexpr int nwarps = mmq_get_nwarps_device(); | ||||
|     constexpr int warp_size = ggml_cuda_get_physical_warp_size(); | ||||
|  | ||||
| #if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) | ||||
| #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) | ||||
|     int   * x_qs = (int   *)  x_tile; | ||||
|     half2 * x_dm = (half2 *) (x_qs + MMQ_TILE_NE_K*2); | ||||
| #else | ||||
| @@ -1880,7 +1880,7 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa | ||||
|     int   * x_qs = (int   *)  x_tile; | ||||
|     half2 * x_dm = (half2 *) (x_qs + txs.qs); | ||||
|     int   * x_sc = (int   *) (x_dm + txs.dm); | ||||
| #endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) | ||||
| #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) | ||||
|  | ||||
|     constexpr int threads_per_row = MMQ_ITER_K / (4 * QR5_K); | ||||
|     constexpr int nrows = warp_size / threads_per_row; | ||||
| @@ -1908,16 +1908,16 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa | ||||
|         const int kq0 = ky - ky % (QI5_K/2) + txi % (QI5_K/4) + 0; | ||||
|         const int kq1 = ky - ky % (QI5_K/2) + txi % (QI5_K/4) + QI5_K/4; | ||||
|  | ||||
| #if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) | ||||
| #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) | ||||
|         x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kq0] = ql0 | qh0; | ||||
|         x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kq1] = ql1 | qh1; | ||||
| #else | ||||
|         x_qs[i*(2*MMQ_TILE_NE_K + 1) + kq0] = ql0 | qh0; | ||||
|         x_qs[i*(2*MMQ_TILE_NE_K + 1) + kq1] = ql1 | qh1; | ||||
| #endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) | ||||
| #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) | ||||
|     } | ||||
|  | ||||
| #if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) | ||||
| #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) | ||||
|     constexpr int rows_per_warp = warp_size / 2; | ||||
| #pragma unroll | ||||
|     for (int i0 = 0; i0 < mmq_y; i0 += nwarps*rows_per_warp) { | ||||
| @@ -1986,7 +1986,7 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa | ||||
|  | ||||
|         x_sc[i*(MMQ_TILE_NE_K/8) + i/8 + ksc] = scales8; | ||||
|     } | ||||
| #endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) | ||||
| #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) | ||||
| } | ||||
|  | ||||
| template <int mmq_x, int mmq_y> | ||||
| @@ -2029,7 +2029,7 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa | ||||
|     constexpr int nwarps = mmq_get_nwarps_device(); | ||||
|     constexpr int warp_size = ggml_cuda_get_physical_warp_size(); | ||||
|  | ||||
| #if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) | ||||
| #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) | ||||
|     int   * x_qs = (int   *)  x_tile; | ||||
|     float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2); | ||||
|     int   * x_sc = (int   *) (x_df + MMQ_TILE_NE_K/QI6_K); | ||||
| @@ -2038,7 +2038,7 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa | ||||
|     int   * x_qs = (int   *)  x_tile; | ||||
|     float * x_df = (float *) (x_qs + txs.qs); | ||||
|     int   * x_sc = (int   *) (x_df + txs.dm); | ||||
| #endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) | ||||
| #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) | ||||
|  | ||||
|     constexpr int threads_per_row = MMQ_ITER_K / (4 * QR6_K); | ||||
|     constexpr int nrows = warp_size / threads_per_row; | ||||
| @@ -2065,13 +2065,13 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa | ||||
|         const int kq0 = 2*txi - txi % (QI6_K/2) + 0; | ||||
|         const int kq1 = 2*txi - txi % (QI6_K/2) + QI6_K/2; | ||||
|  | ||||
| #if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) | ||||
| #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) | ||||
|         x_qs[i*MMQ_MMA_TILE_X_K_Q6_K + kq0] = __vsubss4(ql0 | qh0, 0x20202020); | ||||
|         x_qs[i*MMQ_MMA_TILE_X_K_Q6_K + kq1] = __vsubss4(ql1 | qh1, 0x20202020); | ||||
| #else | ||||
|         x_qs[i*(2*MMQ_TILE_NE_K + 1) + kq0] = __vsubss4(ql0 | qh0, 0x20202020); | ||||
|         x_qs[i*(2*MMQ_TILE_NE_K + 1) + kq1] = __vsubss4(ql1 | qh1, 0x20202020); | ||||
| #endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) | ||||
| #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) | ||||
|     } | ||||
|  | ||||
| #pragma unroll | ||||
| @@ -2084,11 +2084,11 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa | ||||
|  | ||||
|         const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride; | ||||
|  | ||||
| #if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) | ||||
| #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) | ||||
|         x_df[i*MMQ_MMA_TILE_X_K_Q6_K]           = bxi->d; | ||||
| #else | ||||
|         x_df[i*(MMQ_TILE_NE_K/QI6_K) + i/QI6_K] = bxi->d; | ||||
| #endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) | ||||
| #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) | ||||
|     } | ||||
|  | ||||
|     constexpr int rows_per_warp = warp_size / 4; | ||||
| @@ -2102,11 +2102,11 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa | ||||
|  | ||||
|         const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride + (threadIdx.x % (MMQ_TILE_NE_K/8)) / 4; | ||||
|  | ||||
| #if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) | ||||
| #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) | ||||
|         x_sc[i*MMQ_MMA_TILE_X_K_Q6_K + threadIdx.x%4] = get_int_b2(bxi->scales, threadIdx.x % (MMQ_TILE_NE_K/8)); | ||||
| #else | ||||
|         x_sc[i*(MMQ_TILE_NE_K/8) + i/8 + threadIdx.x%(MMQ_TILE_NE_K/8)] = get_int_b2(bxi->scales, threadIdx.x%(QI6_K/8)); | ||||
| #endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) | ||||
| #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) | ||||
|     } | ||||
| } | ||||
|  | ||||
| @@ -2199,7 +2199,7 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma( | ||||
|             } | ||||
|         } | ||||
|     } | ||||
| #elif defined(NEW_MMA_AVAILABLE) | ||||
| #elif defined(TURING_MMA_AVAILABLE) | ||||
|  | ||||
|     typedef tile<16, 4, int> tile_A; | ||||
|     typedef tile< 8, 4, int> tile_B; | ||||
| @@ -2311,14 +2311,14 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa | ||||
|     constexpr int nwarps = mmq_get_nwarps_device(); | ||||
|     constexpr int warp_size = ggml_cuda_get_physical_warp_size(); | ||||
|  | ||||
| #if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) | ||||
| #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) | ||||
|     int   * x_qs = (int   *)  x_tile; | ||||
|     float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2); | ||||
| #else | ||||
|     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ4_NL, mmq_y); | ||||
|     int   * x_qs = (int   *)  x_tile; | ||||
|     float * x_df = (float *) (x_qs + txs.qs); | ||||
| #endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) | ||||
| #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) | ||||
|  | ||||
|     constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_NL); | ||||
|     constexpr int nrows = warp_size / threads_per_row; | ||||
| @@ -2340,13 +2340,13 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa | ||||
|         const int2 v = get_int_from_table_16(aux_q4, kvalues_iq4nl); | ||||
|         const int k0 = kbx * (2 * QI4_NL) + kqsx; | ||||
|  | ||||
| #if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) | ||||
| #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) | ||||
|         x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 0]      = v.x; | ||||
|         x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + QI4_NL] = v.y; | ||||
| #else | ||||
|         x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + 0]      = v.x; | ||||
|         x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + QI4_NL] = v.y; | ||||
| #endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) | ||||
| #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) | ||||
|     } | ||||
|  | ||||
|     constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI4_NL; | ||||
| @@ -2363,11 +2363,11 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa | ||||
|  | ||||
|         const block_iq4_nl * bxi = (const block_iq4_nl *) x + kbx0 + i*stride + kbxd; | ||||
|  | ||||
| #if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) | ||||
| #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) | ||||
|         x_df[i*MMQ_MMA_TILE_X_K_Q8_0             + kbxd] = __half2float(bxi->d); | ||||
| #else | ||||
|         x_df[i*(MMQ_TILE_NE_K/QI4_NL) + i/QI4_NL + kbxd] = __half2float(bxi->d); | ||||
| #endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) | ||||
| #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) | ||||
|     } | ||||
| } | ||||
|  | ||||
| @@ -2376,14 +2376,14 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa | ||||
|     constexpr int nwarps = mmq_get_nwarps_device(); | ||||
|     constexpr int warp_size = ggml_cuda_get_physical_warp_size(); | ||||
|  | ||||
| #if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) | ||||
| #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) | ||||
|     int   * x_qs = (int   *)  x_tile; | ||||
|     float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2); | ||||
| #else | ||||
|     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ2_XXS, mmq_y); | ||||
|     int   * x_qs = (int   *)  x_tile; | ||||
|     float * x_df = (float *) (x_qs + txs.qs); | ||||
| #endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) | ||||
| #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) | ||||
|  | ||||
|     constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR2_XXS)) / 2; | ||||
|     constexpr int nrows = warp_size / threads_per_row; | ||||
| @@ -2414,22 +2414,22 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa | ||||
|             const int signs1 = __vcmpne4(((signs_packed & 0x30) << 3) | ((signs_packed & 0xC0) << 17), 0x00000000); | ||||
|             const int grid1 = __vsub4(grid_pos[1] ^ signs1, signs1); | ||||
|  | ||||
| #if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) | ||||
| #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) | ||||
|             x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 0)] = grid0; | ||||
|             x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 1)] = grid1; | ||||
| #else | ||||
|             x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 0)] = grid0; | ||||
|             x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 1)] = grid1; | ||||
| #endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) | ||||
| #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) | ||||
|         } | ||||
|  | ||||
|         const int ls = aux32 >> 28; | ||||
|         const float d = bxi->d; | ||||
| #if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) | ||||
| #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) | ||||
|         x_df[i*MMQ_MMA_TILE_X_K_Q8_0   + kqsx] = (ls*d + d/2)/4; | ||||
| #else | ||||
|         x_df[i*(MMQ_TILE_NE_K/4) + i/4 + kqsx] = (ls*d + d/2)/4; | ||||
| #endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) | ||||
| #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) | ||||
|     } | ||||
| } | ||||
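The iq2_xxs scale written above, `(ls*d + d/2)/4`, is the same quantity as `d*(ls + 0.5f)*0.25f` with the terms grouped differently; the short check below (illustrative, not part of the commit) confirms the two forms agree for every 4-bit `ls`.

```cpp
// Check that (ls*d + d/2)/4 == d*(ls + 0.5f)*0.25f for all 4-bit scale values.
// d is an arbitrary example super-block scale.
#include <cassert>
#include <cmath>

int main() {
    const float d = 0.125f;
    for (int ls = 0; ls < 16; ++ls) {
        const float a = (ls*d + d/2)/4;
        const float b = d*(ls + 0.5f)*0.25f;
        assert(std::fabs(a - b) < 1e-7f);
    }
    return 0;
}
```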
|  | ||||
| @@ -2438,14 +2438,14 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa | ||||
|     constexpr int nwarps = mmq_get_nwarps_device(); | ||||
|     constexpr int warp_size = ggml_cuda_get_physical_warp_size(); | ||||
|  | ||||
| #if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) | ||||
| #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) | ||||
|     int   * x_qs = (int   *)  x_tile; | ||||
|     float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2); | ||||
| #else | ||||
|     constexpr tile_x_sizes txs = MMQ_DP4A_TXS_Q8_0_16; | ||||
|     int   * x_qs = (int   *)  x_tile; | ||||
|     float * x_df = (float *) (x_qs + txs.qs); | ||||
| #endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) | ||||
| #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) | ||||
|  | ||||
|     constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR2_XS)) / 2; | ||||
|     constexpr int nrows = warp_size / threads_per_row; | ||||
| @@ -2472,24 +2472,24 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa | ||||
|             const int grid_l = __vsub4(grid_pos[0] ^ signs[0], signs[0]); | ||||
|             const int grid_h = __vsub4(grid_pos[1] ^ signs[1], signs[1]); | ||||
|  | ||||
| #if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) | ||||
| #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) | ||||
|             x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 0)] = grid_l; | ||||
|             x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 1)] = grid_h; | ||||
| #else | ||||
|             x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 0)] = grid_l; | ||||
|             x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 1)] = grid_h; | ||||
| #endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) | ||||
| #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) | ||||
|         } | ||||
|  | ||||
|         const int ls = bxi->scales[kqsx]; | ||||
|         const float d = bxi->d; | ||||
| #if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) | ||||
| #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) | ||||
|         x_df[i*MMQ_MMA_TILE_X_K_Q3_K                   + 2*kqsx+0] = ((ls &  0x0F)*d + d/2)/4; | ||||
|         x_df[i*MMQ_MMA_TILE_X_K_Q3_K                   + 2*kqsx+1] = ((ls >>    4)*d + d/2)/4; | ||||
| #else | ||||
|         x_df[i*(2*MMQ_TILE_NE_K*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+0] = ((ls &  0x0F)*d + d/2)/4; | ||||
|         x_df[i*(2*MMQ_TILE_NE_K*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+1] = ((ls >>    4)*d + d/2)/4; | ||||
| #endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) | ||||
| #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) | ||||
|     } | ||||
| } | ||||
|  | ||||
| @@ -2498,14 +2498,14 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa | ||||
|     constexpr int nwarps = mmq_get_nwarps_device(); | ||||
|     constexpr int warp_size = ggml_cuda_get_physical_warp_size(); | ||||
|  | ||||
| #if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) | ||||
| #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) | ||||
|     int   * x_qs = (int   *)  x_tile; | ||||
|     float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2); | ||||
| #else | ||||
|     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ2_S, mmq_y); | ||||
|     int   * x_qs = (int   *)  x_tile; | ||||
|     float * x_df = (float *) (x_qs + txs.qs); | ||||
| #endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) | ||||
| #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) | ||||
|  | ||||
|     constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR2_S)) / 2; | ||||
|     constexpr int nrows = warp_size / threads_per_row; | ||||
| @@ -2539,24 +2539,24 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa | ||||
|             const int grid_l = __vsub4(grid_pos[0] ^ signs0, signs0); | ||||
|             const int grid_h = __vsub4(grid_pos[1] ^ signs1, signs1); | ||||
|  | ||||
| #if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) | ||||
| #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) | ||||
|             x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 0)] = grid_l; | ||||
|             x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 1)] = grid_h; | ||||
| #else | ||||
|             x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 0)] = grid_l; | ||||
|             x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 1)] = grid_h; | ||||
| #endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) | ||||
| #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) | ||||
|         } | ||||
|  | ||||
|         const int ls = bxi->scales[kqsx]; | ||||
|         const float d = bxi->d; | ||||
| #if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) | ||||
| #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) | ||||
|         x_df[i*MMQ_MMA_TILE_X_K_Q3_K                   + 2*kqsx+0] = ((ls &  0x0F)*d + d/2)/4; | ||||
|         x_df[i*MMQ_MMA_TILE_X_K_Q3_K                   + 2*kqsx+1] = ((ls >>    4)*d + d/2)/4; | ||||
| #else | ||||
|         x_df[i*(2*MMQ_TILE_NE_K*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+0] = ((ls &  0x0F)*d + d/2)/4; | ||||
|         x_df[i*(2*MMQ_TILE_NE_K*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+1] = ((ls >>    4)*d + d/2)/4; | ||||
| #endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) | ||||
| #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) | ||||
|     } | ||||
| } | ||||
|  | ||||
| @@ -2565,14 +2565,14 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa | ||||
|     constexpr int nwarps = mmq_get_nwarps_device(); | ||||
|     constexpr int warp_size = ggml_cuda_get_physical_warp_size(); | ||||
|  | ||||
| #if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) | ||||
| #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) | ||||
|     int   * x_qs = (int   *)  x_tile; | ||||
|     float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2); | ||||
| #else | ||||
|     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ3_XXS, mmq_y); | ||||
|     int   * x_qs = (int   *)  x_tile; | ||||
|     float * x_df = (float *) (x_qs + txs.qs); | ||||
| #endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) | ||||
| #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) | ||||
|  | ||||
|     constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR3_XXS)) / 2; | ||||
|     constexpr int nrows = warp_size / threads_per_row; | ||||
| @@ -2601,22 +2601,22 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa | ||||
|             const int grid_l = __vsub4(grid_pos.x ^ signs[0], signs[0]); | ||||
|             const int grid_h = __vsub4(grid_pos.y ^ signs[1], signs[1]); | ||||
|  | ||||
| #if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) | ||||
| #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) | ||||
|             x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 0)] = grid_l; | ||||
|             x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 1)] = grid_h; | ||||
| #else | ||||
|             x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 0)] = grid_l; | ||||
|             x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 1)] = grid_h; | ||||
| #endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) | ||||
| #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) | ||||
|         } | ||||
|  | ||||
|         const int ls = aux32 >> 28; | ||||
|         const float d = bxi->d; | ||||
| #if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) | ||||
| #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) | ||||
|         x_df[i*MMQ_MMA_TILE_X_K_Q8_0     + kqsx] = (ls*d + d/2)/2; | ||||
| #else | ||||
|         x_df[i*(MMQ_TILE_NE_K/4) + i/4   + kqsx] = (ls*d + d/2)/2; | ||||
| #endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) | ||||
| #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) | ||||
|     } | ||||
| } | ||||
|  | ||||
| @@ -2625,14 +2625,14 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa | ||||
|     constexpr int nwarps = mmq_get_nwarps_device(); | ||||
|     constexpr int warp_size = ggml_cuda_get_physical_warp_size(); | ||||
|  | ||||
| #if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) | ||||
| #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) | ||||
|     int   * x_qs = (int   *)  x_tile; | ||||
|     float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2); | ||||
| #else | ||||
|     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ3_S, mmq_y); | ||||
|     int   * x_qs = (int   *)  x_tile; | ||||
|     float * x_df = (float *) (x_qs + txs.qs); | ||||
| #endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) | ||||
| #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) | ||||
|  | ||||
|     constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR3_S)) / 2; | ||||
|     constexpr int nrows = warp_size / threads_per_row; | ||||
| @@ -2668,22 +2668,22 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa | ||||
|             const int grid_l = __vsub4(grid_pos.x ^ signs0, signs0); | ||||
|             const int grid_h = __vsub4(grid_pos.y ^ signs1, signs1); | ||||
|  | ||||
| #if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) | ||||
| #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) | ||||
|             x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l+0)] = grid_l; | ||||
|             x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l+1)] = grid_h; | ||||
| #else | ||||
|             x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l+0)] = grid_l; | ||||
|             x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l+1)] = grid_h; | ||||
| #endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) | ||||
| #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) | ||||
|         } | ||||
|  | ||||
|         const int ls = 1 + 2*((bxi->scales[kqsx/2] >> (((2*kqsx) << 1) & 0x04)) & 0x0F); | ||||
|         const float d = bxi->d; | ||||
| #if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) | ||||
| #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) | ||||
|         x_df[i*MMQ_MMA_TILE_X_K_Q8_0     + kqsx] = ls*d; | ||||
| #else | ||||
|         x_df[i*(MMQ_TILE_NE_K/4) + i/4   + kqsx] = ls*d; | ||||
| #endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) | ||||
| #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) | ||||
|     } | ||||
| } | ||||
|  | ||||
| @@ -2692,14 +2692,14 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa | ||||
|     constexpr int nwarps = mmq_get_nwarps_device(); | ||||
|     constexpr int warp_size = ggml_cuda_get_physical_warp_size(); | ||||
|  | ||||
| #if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) | ||||
| #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) | ||||
|     int   * x_qs = (int   *)  x_tile; | ||||
|     half2 * x_ds = (half2 *) (x_qs + MMQ_TILE_NE_K*2); | ||||
| #else | ||||
|     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ3_S, mmq_y); | ||||
|     int   * x_qs = (int   *)  x_tile; | ||||
|     half2 * x_ds = (half2 *) (x_qs + txs.qs); | ||||
| #endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) | ||||
| #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) | ||||
|  | ||||
|     constexpr int threads_per_row = MMQ_ITER_K / (4 * QR1_S); | ||||
|     constexpr int nrows = warp_size / threads_per_row; | ||||
| @@ -2727,23 +2727,23 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa | ||||
|             const int grid0 = (grid >> 0) & 0x0F0F0F0F; | ||||
|             const int grid1 = (grid >> 4) & 0x0F0F0F0F; | ||||
|  | ||||
| #if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) | ||||
| #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) | ||||
|             x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 8*kqsx + (2*l+0)] = grid0; | ||||
|             x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 8*kqsx + (2*l+1)] = grid1; | ||||
| #else | ||||
|             x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l+0)] = grid0; | ||||
|             x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l+1)] = grid1; | ||||
| #endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) | ||||
| #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) | ||||
|         } | ||||
|  | ||||
|         const float  d1q   = __half2float(bxi->d) * (((qh >> 11) & 0x0E) + 1); | ||||
|         const float  delta = -1.0f + IQ1S_DELTA - (qh & 0x8000) * (2.0f*IQ1S_DELTA/0x8000); | ||||
|  | ||||
| #if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) | ||||
| #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) | ||||
|         x_ds[i*MMQ_MMA_TILE_X_K_Q8_1     + kqsx] = make_half2(d1q, d1q*delta); | ||||
| #else | ||||
|         x_ds[i*(MMQ_TILE_NE_K/4) + i/4   + kqsx] = make_half2(d1q, d1q*delta); | ||||
| #endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) | ||||
| #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) | ||||
|     } | ||||
| } | ||||
|  | ||||
| @@ -2752,14 +2752,14 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa | ||||
|     constexpr int nwarps = mmq_get_nwarps_device(); | ||||
|     constexpr int warp_size = ggml_cuda_get_physical_warp_size(); | ||||
|  | ||||
| #if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) | ||||
| #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) | ||||
|     int   * x_qs = (int   *)  x_tile; | ||||
|     float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2); | ||||
| #else | ||||
|     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ4_XS, mmq_y); | ||||
|     int   * x_qs = (int   *)  x_tile; | ||||
|     float * x_df = (float *) (x_qs + txs.qs); | ||||
| #endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) | ||||
| #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) | ||||
|  | ||||
|     constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_XS); | ||||
|     constexpr int nrows = warp_size / threads_per_row; | ||||
| @@ -2779,13 +2779,13 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa | ||||
|         const int2 v = get_int_from_table_16(aux_q4, kvalues_iq4nl); | ||||
|         const int k0 = 8 * (kqsx / 4) + kqsx % 4; | ||||
|  | ||||
| #if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) | ||||
| #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) | ||||
|         x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 0] = v.x; | ||||
|         x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 4] = v.y; | ||||
| #else | ||||
|         x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + 0] = v.x; | ||||
|         x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + 4] = v.y; | ||||
| #endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) | ||||
| #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) | ||||
|     } | ||||
|  | ||||
|     constexpr int rows_per_warp = warp_size / 8; | ||||
| @@ -2804,11 +2804,11 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa | ||||
|         const int ls = ((bxi->scales_l[(threadIdx.x % 8)/2] >> (4*(threadIdx.x % 2))) & 0x0F) | ||||
|             | (((bxi->scales_h >> (2*(threadIdx.x % 8))) & 0x03) << 4); | ||||
|  | ||||
| #if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) | ||||
| #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) | ||||
|         x_df[i*MMQ_MMA_TILE_X_K_Q8_0   + threadIdx.x % 8] = d * (ls - 32); | ||||
| #else | ||||
|         x_df[i*(MMQ_TILE_NE_K/4) + i/4 + threadIdx.x % 8] = d * (ls - 32); | ||||
| #endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) | ||||
| #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) | ||||
|     } | ||||
| } | ||||
|  | ||||
| @@ -2859,9 +2859,9 @@ static __device__ __forceinline__ void mmq_write_back_mma( | ||||
|     constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp. | ||||
|  | ||||
|     const int i0 = (threadIdx.y / ntx) * (ntx*tile_C::I); | ||||
| #if defined(NEW_MMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) | ||||
| #if defined(TURING_MMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) | ||||
|     static_assert(nwarps*tile_C::I == mmq_y, "nwarps*tile_C::I != mmq_y"); | ||||
| #endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) | ||||
| #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) | ||||
|  | ||||
| #pragma unroll | ||||
|     for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) { | ||||
| @@ -3061,13 +3061,13 @@ static __device__ __forceinline__ void mul_mat_q_process_tile( | ||||
|     int * tile_y = data_mul_mat_q + mmq_x; | ||||
|     int * tile_x = tile_y + GGML_PAD(mmq_x*MMQ_TILE_Y_K, nwarps*warp_size); | ||||
|  | ||||
| #if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) | ||||
| #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) | ||||
|     constexpr vec_dot_mmq_t    vec_dot    = mmq_type_traits<mmq_x, mmq_y, need_check, type>::vec_dot_mma; | ||||
|     constexpr mmq_write_back_t write_back = mmq_write_back_mma<type, mmq_x, mmq_y, need_check>; | ||||
| #else | ||||
|     constexpr vec_dot_mmq_t    vec_dot    = mmq_type_traits<mmq_x, mmq_y, need_check, type>::vec_dot_dp4a; | ||||
|     constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, need_check>; | ||||
| #endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) | ||||
| #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) | ||||
|  | ||||
|     constexpr int blocks_per_iter = MMQ_ITER_K / qk; | ||||
|  | ||||
| @@ -3534,7 +3534,7 @@ static size_t mmq_get_nbytes_shared(const int mmq_x, const int mmq_y, const int | ||||
|     const tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(type, mmq_y); | ||||
|     const int mmq_tile_x_k = mmq_get_mma_tile_x_k(type); | ||||
|     const size_t nbs_ids = mmq_x*sizeof(int); | ||||
|     const size_t nbs_x = (new_mma_available(cc) || amd_mfma_available(cc)) ? mmq_y*mmq_tile_x_k*sizeof(int) : txs.qs*sizeof(int) + txs.dm*sizeof(half2) + txs.sc*sizeof(int); | ||||
|     const size_t nbs_x = (turing_mma_available(cc) || amd_mfma_available(cc)) ? mmq_y*mmq_tile_x_k*sizeof(int) : txs.qs*sizeof(int) + txs.dm*sizeof(half2) + txs.sc*sizeof(int); | ||||
|     const size_t nbs_y = mmq_x*sizeof(block_q8_1_mmq); | ||||
|     return nbs_ids + nbs_x + GGML_PAD(nbs_y, nwarps*warp_size*sizeof(int)); | ||||
| } | ||||
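The shared-memory budget above is ids + x tile + a padded y tile; the sketch below mirrors only its structure, assuming `GGML_PAD(x, n)` rounds `x` up to the next multiple of `n`. All sizes are hypothetical placeholders, not values taken from the kernel.

```cpp
// Rough sketch of the shared-memory accounting: ids + x tile + padded y tile.
// Every constant below is a hypothetical placeholder; only the structure mirrors the code above.
#include <cstddef>
#include <cstdio>

static size_t pad_to(size_t x, size_t n) {  // assumed behaviour of GGML_PAD(x, n)
    return ((x + n - 1) / n) * n;
}

int main() {
    const size_t mmq_x = 64, mmq_y = 128, mmq_tile_x_k = 36;  // placeholder tile geometry
    const size_t nwarps = 8, warp_size = 32;
    const size_t sizeof_block_q8_1_mmq = 144;                 // placeholder struct size
    const size_t nbs_ids = mmq_x*sizeof(int);
    const size_t nbs_x   = mmq_y*mmq_tile_x_k*sizeof(int);    // MMA path of the expression above
    const size_t nbs_y   = mmq_x*sizeof_block_q8_1_mmq;
    const size_t total   = nbs_ids + nbs_x + pad_to(nbs_y, nwarps*warp_size*sizeof(int));
    printf("shared memory: %zu bytes\n", total);
    return 0;
}
```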
|   | ||||
| @@ -1,9 +1,9 @@ | ||||
| #include "ggml.h" | ||||
| #include "common.cuh" | ||||
| #include "mmv.cuh" | ||||
| #include "mmvf.cuh" | ||||
|  | ||||
| template <typename T, typename type_acc, int ncols_dst, int block_size> | ||||
| static __global__ void mul_mat_vec( | ||||
| static __global__ void mul_mat_vec_f( | ||||
|         const T * __restrict__ x, const float * __restrict__ y, const int32_t * __restrict__ ids, float * __restrict__ dst, | ||||
|         const int ncols2, const int nchannels_y, const int stride_row, const int stride_col_y2, const int stride_col_dst, | ||||
|         const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst, | ||||
| @@ -37,7 +37,7 @@ static __global__ void mul_mat_vec( | ||||
|  | ||||
|     float sumf[ncols_dst] = {0.0f}; | ||||
|  | ||||
|     if constexpr (std::is_same<T, float>::value) { | ||||
|     if constexpr (std::is_same_v<T, float>) { | ||||
|         const float2 * x2 = (const float2 *) x; | ||||
|  | ||||
|         for (int col2 = tid; col2 < ncols2; col2 += block_size) { | ||||
| @@ -50,10 +50,10 @@ static __global__ void mul_mat_vec( | ||||
|                 sumf[j] += tmpx.y*tmpy.y; | ||||
|             } | ||||
|         } | ||||
|     } else if constexpr (std::is_same<T, half>::value) { | ||||
|     } else if constexpr (std::is_same_v<T, half>) { | ||||
|         const half2 * x2 = (const half2 *) x; | ||||
|  | ||||
|         if (std::is_same<type_acc, float>::value) { | ||||
|         if (std::is_same_v<type_acc, float>) { | ||||
|             for (int col2 = tid; col2 < ncols2; col2 += block_size) { | ||||
|                 const float2 tmpx = __half22float2(x2[col2]); | ||||
|  | ||||
| @@ -86,7 +86,7 @@ static __global__ void mul_mat_vec( | ||||
|             NO_DEVICE_CODE; | ||||
| #endif // FP16_AVAILABLE | ||||
|         } | ||||
|     } else if constexpr (std::is_same<T, nv_bfloat16>::value) { | ||||
|     } else if constexpr (std::is_same_v<T, nv_bfloat16>) { | ||||
|         const int * x2 = (const int *) x; | ||||
|         for (int col2 = tid; col2 < ncols2; col2 += block_size) { | ||||
|             const int tmpx = x2[col2]; | ||||
| @@ -98,7 +98,7 @@ static __global__ void mul_mat_vec( | ||||
|             } | ||||
|         } | ||||
|     } else { | ||||
|         static_assert(std::is_same<T, void>::value, "unsupported type"); | ||||
|         static_assert(std::is_same_v<T, void>, "unsupported type"); | ||||
|     } | ||||
|  | ||||
| #pragma unroll | ||||
| @@ -126,7 +126,7 @@ static __global__ void mul_mat_vec( | ||||
| } | ||||
|  | ||||
| template <typename T, typename type_acc, int ncols_dst> | ||||
| static void launch_mul_mat_vec_cuda( | ||||
| static void launch_mul_mat_vec_f_cuda( | ||||
|         const T * x, const float * y, const int32_t * ids, float * dst, | ||||
|         const int64_t ncols, const int64_t nrows, | ||||
|         const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst, | ||||
| @@ -141,11 +141,9 @@ static void launch_mul_mat_vec_cuda( | ||||
|     GGML_ASSERT(       nsamples_dst  % nsamples_x  == 0); | ||||
|     const int64_t channel_ratio = nchannels_dst / nchannels_x; | ||||
|     const int64_t sample_ratio  = nsamples_dst  / nsamples_x; | ||||
|     int device; | ||||
|     int warp_size; | ||||
|  | ||||
|     CUDA_CHECK(cudaGetDevice(&device)); | ||||
|     warp_size = ggml_cuda_info().devices[device].warp_size; | ||||
|     const int device = ggml_cuda_get_device(); | ||||
|     const int warp_size = ggml_cuda_info().devices[device].warp_size; | ||||
|  | ||||
|     int64_t block_size_best = warp_size; | ||||
|     int64_t niter_best      = (ncols + 2*warp_size - 1) / (2*warp_size); | ||||
| @@ -161,54 +159,54 @@ static void launch_mul_mat_vec_cuda( | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     const int smem = warp_size*sizeof(float); | ||||
|     const int nbytes_shared = warp_size*sizeof(float); | ||||
|     const dim3 block_nums(nrows, nchannels_dst, nsamples_dst); | ||||
|     const dim3 block_dims(block_size_best, 1, 1); | ||||
|     switch (block_size_best) { | ||||
|         case   32: { | ||||
|             mul_mat_vec<T, type_acc, ncols_dst,  32><<<block_nums, block_dims, smem, stream>>> | ||||
|             mul_mat_vec_f<T, type_acc, ncols_dst,  32><<<block_nums, block_dims, nbytes_shared, stream>>> | ||||
|                 (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst, | ||||
|                  channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, | ||||
|                  sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); | ||||
|         } break; | ||||
|         case   64: { | ||||
|             mul_mat_vec<T, type_acc, ncols_dst,  64><<<block_nums, block_dims, smem, stream>>> | ||||
|             mul_mat_vec_f<T, type_acc, ncols_dst,  64><<<block_nums, block_dims, nbytes_shared, stream>>> | ||||
|                 (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst, | ||||
|                  channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, | ||||
|                  sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); | ||||
|         } break; | ||||
|         case   96: { | ||||
|             mul_mat_vec<T, type_acc, ncols_dst,  96><<<block_nums, block_dims, smem, stream>>> | ||||
|             mul_mat_vec_f<T, type_acc, ncols_dst,  96><<<block_nums, block_dims, nbytes_shared, stream>>> | ||||
|                 (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst, | ||||
|                  channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, | ||||
|                  sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); | ||||
|         } break; | ||||
|         case  128: { | ||||
|             mul_mat_vec<T, type_acc, ncols_dst, 128><<<block_nums, block_dims, smem, stream>>> | ||||
|             mul_mat_vec_f<T, type_acc, ncols_dst, 128><<<block_nums, block_dims, nbytes_shared, stream>>> | ||||
|                 (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst, | ||||
|                  channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, | ||||
|                  sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); | ||||
|         } break; | ||||
|         case  160: { | ||||
|             mul_mat_vec<T, type_acc, ncols_dst, 160><<<block_nums, block_dims, smem, stream>>> | ||||
|             mul_mat_vec_f<T, type_acc, ncols_dst, 160><<<block_nums, block_dims, nbytes_shared, stream>>> | ||||
|                 (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst, | ||||
|                  channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, | ||||
|                  sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); | ||||
|         } break; | ||||
|         case  192: { | ||||
|             mul_mat_vec<T, type_acc, ncols_dst, 192><<<block_nums, block_dims, smem, stream>>> | ||||
|             mul_mat_vec_f<T, type_acc, ncols_dst, 192><<<block_nums, block_dims, nbytes_shared, stream>>> | ||||
|                 (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst, | ||||
|                  channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, | ||||
|                  sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); | ||||
|         } break; | ||||
|         case  224: { | ||||
|             mul_mat_vec<T, type_acc, ncols_dst, 224><<<block_nums, block_dims, smem, stream>>> | ||||
|             mul_mat_vec_f<T, type_acc, ncols_dst, 224><<<block_nums, block_dims, nbytes_shared, stream>>> | ||||
|                 (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst, | ||||
|                  channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, | ||||
|                  sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); | ||||
|         } break; | ||||
|         case  256: { | ||||
|             mul_mat_vec<T, type_acc, ncols_dst, 256><<<block_nums, block_dims, smem, stream>>> | ||||
|             mul_mat_vec_f<T, type_acc, ncols_dst, 256><<<block_nums, block_dims, nbytes_shared, stream>>> | ||||
|                 (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst, | ||||
|                  channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, | ||||
|                  sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); | ||||
| @@ -220,7 +218,7 @@ static void launch_mul_mat_vec_cuda( | ||||
| } | ||||
|  | ||||
| template <typename T, typename type_acc> | ||||
| static void mul_mat_vec_cuda_switch_ncols_dst( | ||||
| static void mul_mat_vec_f_cuda_switch_ncols_dst( | ||||
|         const T * x, const float * y, const int32_t * ids, float * dst, | ||||
|         const int64_t ncols, const int64_t nrows, const int64_t ncols_dst, | ||||
|         const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst, | ||||
| @@ -230,49 +228,49 @@ static void mul_mat_vec_cuda_switch_ncols_dst( | ||||
|         cudaStream_t stream) { | ||||
|     switch (ncols_dst) { | ||||
|         case 1: | ||||
|             launch_mul_mat_vec_cuda<T, type_acc, 1> | ||||
|             launch_mul_mat_vec_f_cuda<T, type_acc, 1> | ||||
|                 (x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst, | ||||
|                  nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, | ||||
|                  stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); | ||||
|             break; | ||||
|         case 2: | ||||
|             launch_mul_mat_vec_cuda<T, type_acc, 2> | ||||
|             launch_mul_mat_vec_f_cuda<T, type_acc, 2> | ||||
|                 (x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst, | ||||
|                  nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, | ||||
|                  stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); | ||||
|             break; | ||||
|         case 3: | ||||
|             launch_mul_mat_vec_cuda<T, type_acc, 3> | ||||
|             launch_mul_mat_vec_f_cuda<T, type_acc, 3> | ||||
|                 (x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst, | ||||
|                  nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, | ||||
|                  stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); | ||||
|             break; | ||||
|         case 4: | ||||
|             launch_mul_mat_vec_cuda<T, type_acc, 4> | ||||
|             launch_mul_mat_vec_f_cuda<T, type_acc, 4> | ||||
|                 (x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst, | ||||
|                  nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, | ||||
|                  stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); | ||||
|             break; | ||||
|         case 5: | ||||
|             launch_mul_mat_vec_cuda<T, type_acc, 5> | ||||
|             launch_mul_mat_vec_f_cuda<T, type_acc, 5> | ||||
|                 (x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst, | ||||
|                  nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, | ||||
|                  stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); | ||||
|             break; | ||||
|         case 6: | ||||
|             launch_mul_mat_vec_cuda<T, type_acc, 6> | ||||
|             launch_mul_mat_vec_f_cuda<T, type_acc, 6> | ||||
|                 (x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst, | ||||
|                  nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, | ||||
|                  stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); | ||||
|             break; | ||||
|         case 7: | ||||
|             launch_mul_mat_vec_cuda<T, type_acc, 7> | ||||
|             launch_mul_mat_vec_f_cuda<T, type_acc, 7> | ||||
|                 (x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst, | ||||
|                  nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, | ||||
|                  stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); | ||||
|             break; | ||||
|         case 8: | ||||
|             launch_mul_mat_vec_cuda<T, type_acc, 8> | ||||
|             launch_mul_mat_vec_f_cuda<T, type_acc, 8> | ||||
|                 (x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst, | ||||
|                  nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, | ||||
|                  stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); | ||||
| @@ -284,7 +282,7 @@ static void mul_mat_vec_cuda_switch_ncols_dst( | ||||
| } | ||||
|  | ||||
| template<typename T> | ||||
| static void mul_mat_vec_cuda( | ||||
| static void mul_mat_vec_f_cuda( | ||||
|         const T * x, const float * y, const int32_t * ids, float * dst, | ||||
|         const int64_t ncols, const int64_t nrows, const int64_t ncols_dst, | ||||
|         const int64_t stride_row, const int64_t stride_col_y, const int stride_col_dst, | ||||
| @@ -292,22 +290,22 @@ static void mul_mat_vec_cuda( | ||||
|         const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x, | ||||
|         const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst, | ||||
|         enum ggml_prec prec, cudaStream_t stream) { | ||||
|     if constexpr(std::is_same<T, half>::value) { | ||||
|     if constexpr(std::is_same_v<T, half>) { | ||||
|         if (prec == GGML_PREC_DEFAULT) { | ||||
|             mul_mat_vec_cuda_switch_ncols_dst<T, half> | ||||
|             mul_mat_vec_f_cuda_switch_ncols_dst<T, half> | ||||
|                 (x, y, ids, dst, ncols, nrows, ncols_dst, stride_row, stride_col_y, stride_col_dst, | ||||
|                  nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, | ||||
|                  stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); | ||||
|             return; | ||||
|         } | ||||
|     } | ||||
|     mul_mat_vec_cuda_switch_ncols_dst<T, float> | ||||
|     mul_mat_vec_f_cuda_switch_ncols_dst<T, float> | ||||
|         (x, y, ids, dst, ncols, nrows, ncols_dst, stride_row, stride_col_y, stride_col_dst, | ||||
|          nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, | ||||
|          stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); | ||||
| } | ||||
|  | ||||
| void ggml_cuda_mul_mat_vec(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst) { | ||||
| void ggml_cuda_mul_mat_vec_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst) { | ||||
|     GGML_ASSERT(        src1->type == GGML_TYPE_F32); | ||||
|     GGML_ASSERT(!ids ||  ids->type == GGML_TYPE_I32); | ||||
|     GGML_ASSERT(         dst->type == GGML_TYPE_F32); | ||||
| @@ -355,19 +353,19 @@ void ggml_cuda_mul_mat_vec(ggml_backend_cuda_context & ctx, const ggml_tensor * | ||||
|     switch (src0->type) { | ||||
|         case GGML_TYPE_F32: { | ||||
|             const float * src0_d = (const float *) src0->data; | ||||
|             mul_mat_vec_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, ncols_dst, s01, s11, s1, | ||||
|             mul_mat_vec_f_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, ncols_dst, s01, s11, s1, | ||||
|                 ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst, | ||||
|                 ne03,              ne3,           s03, s13,              s3,                 prec, ctx.stream()); | ||||
|         } break; | ||||
|         case GGML_TYPE_F16: { | ||||
|             const half * src0_d = (const half *) src0->data; | ||||
|             mul_mat_vec_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, ncols_dst, s01, s11, s1, | ||||
|             mul_mat_vec_f_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, ncols_dst, s01, s11, s1, | ||||
|                 ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst, | ||||
|                 ne03,              ne3,           s03, s13,              s3,                 prec, ctx.stream()); | ||||
|         } break; | ||||
|         case GGML_TYPE_BF16: { | ||||
|             const nv_bfloat16 * src0_d = (const nv_bfloat16 *) src0->data; | ||||
|             mul_mat_vec_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, ncols_dst, s01, s11, s1, | ||||
|             mul_mat_vec_f_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, ncols_dst, s01, s11, s1, | ||||
|                 ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst, | ||||
|                 ne03,              ne3,           s03, s13,              s3,                 prec, ctx.stream()); | ||||
|         } break; | ||||
| @@ -376,7 +374,7 @@ void ggml_cuda_mul_mat_vec(ggml_backend_cuda_context & ctx, const ggml_tensor * | ||||
|     } | ||||
| } | ||||
|  | ||||
| void ggml_cuda_op_mul_mat_vec( | ||||
| void ggml_cuda_op_mul_mat_vec_f( | ||||
|     ggml_backend_cuda_context & ctx, | ||||
|     const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i, | ||||
|     const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols, | ||||
| @@ -414,19 +412,19 @@ void ggml_cuda_op_mul_mat_vec( | ||||
|     switch (src0->type) { | ||||
|         case GGML_TYPE_F32: { | ||||
|             const float * src0_d = (const float *) src0_dd_i; | ||||
|             mul_mat_vec_cuda(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst, | ||||
|             mul_mat_vec_f_cuda(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst, | ||||
|                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, | ||||
|                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream); | ||||
|         } break; | ||||
|         case GGML_TYPE_F16: { | ||||
|             const half * src0_d = (const half *) src0_dd_i; | ||||
|             mul_mat_vec_cuda(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst, | ||||
|             mul_mat_vec_f_cuda(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst, | ||||
|                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, | ||||
|                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream); | ||||
|         } break; | ||||
|         case GGML_TYPE_BF16: { | ||||
|             const nv_bfloat16 * src0_d = (const nv_bfloat16 *) src0_dd_i; | ||||
|             mul_mat_vec_cuda(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst, | ||||
|             mul_mat_vec_f_cuda(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst, | ||||
|                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, | ||||
|                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream); | ||||
|         } break; | ||||
| @@ -442,15 +440,15 @@ void ggml_cuda_op_mul_mat_vec( | ||||
|     GGML_UNUSED(src1_padded_row_size); | ||||
| } | ||||
|  | ||||
| bool ggml_cuda_should_use_mmv(enum ggml_type type, int cc, const int64_t * src0_ne, int64_t ne11) { | ||||
| bool ggml_cuda_should_use_mmvf(enum ggml_type type, int cc, const int64_t * src0_ne, int64_t ne11) { | ||||
|     if (src0_ne[0] % 2 != 0) { | ||||
|         return false; | ||||
|     } | ||||
|     switch (type) { | ||||
|         case GGML_TYPE_F32: | ||||
|             if (GGML_CUDA_CC_IS_NVIDIA(cc)) { | ||||
|                 if (cc >= GGML_CUDA_CC_ADA_LOVELACE) { | ||||
|                     return ne11 <= 8; | ||||
|                 if (ampere_mma_available(cc)) { | ||||
|                     return ne11 <= 3; | ||||
|                 } | ||||
|                 if (cc >= GGML_CUDA_CC_TURING) { | ||||
|                     return ne11 <= 4; | ||||
| @@ -466,6 +464,9 @@ bool ggml_cuda_should_use_mmv(enum ggml_type type, int cc, const int64_t * src0_ | ||||
|         case GGML_TYPE_F16: | ||||
|             if (GGML_CUDA_CC_IS_NVIDIA(cc)) { | ||||
|                 const bool src0_small = (src0_ne[1] <= 512 || src0_ne[2]*src0_ne[3] == 1); | ||||
|                 if (ampere_mma_available(cc)) { | ||||
|                     return src0_small && ne11 == 1; | ||||
|                 } | ||||
|                 if (cc >= GGML_CUDA_CC_ADA_LOVELACE) { | ||||
|                     return src0_small && ne11 <= 4; | ||||
|                 } | ||||
| @@ -486,6 +487,9 @@ bool ggml_cuda_should_use_mmv(enum ggml_type type, int cc, const int64_t * src0_ | ||||
|         case GGML_TYPE_BF16: | ||||
|             if (GGML_CUDA_CC_IS_NVIDIA(cc)) { | ||||
|                 const bool src0_small = (src0_ne[1] <= 512 || src0_ne[2]*src0_ne[3] == 1); | ||||
|                 if (ampere_mma_available(cc)) { | ||||
|                     return src0_small && ne11 == 1; | ||||
|                 } | ||||
|                 if (cc >= GGML_CUDA_CC_ADA_LOVELACE) { | ||||
|                     return src0_small && ne11 <= 4; | ||||
|                 } | ||||
| @@ -1,11 +1,11 @@ | ||||
| #include "common.cuh" | ||||
|  | ||||
| void ggml_cuda_mul_mat_vec(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst); | ||||
| void ggml_cuda_mul_mat_vec_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst); | ||||
|  | ||||
| void ggml_cuda_op_mul_mat_vec( | ||||
| void ggml_cuda_op_mul_mat_vec_f( | ||||
|     ggml_backend_cuda_context & ctx, | ||||
|     const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i, | ||||
|     const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols, | ||||
|     const int64_t src1_padded_row_size, cudaStream_t stream); | ||||
|  | ||||
| bool ggml_cuda_should_use_mmv(enum ggml_type type, int cc, const int64_t * src0_ne, int64_t ne11); | ||||
| bool ggml_cuda_should_use_mmvf(enum ggml_type type, int cc, const int64_t * src0_ne, int64_t ne11); | ||||
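Finally, a hypothetical caller sketch showing how the renamed entry points in this header fit together; the real dispatch logic lives elsewhere in the backend and is only approximated here.

// Not the actual ggml-cuda dispatcher; a hedged sketch of how the predicate
// gates the custom float mul-mat-vec kernel for a given batch size ne11.
#include "mmvf.cuh"

static void choose_float_mul_mat_path(ggml_backend_cuda_context & ctx, const int cc,
        const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
    if (ggml_cuda_should_use_mmvf(src0->type, cc, src0->ne, src1->ne[1])) {
        ggml_cuda_mul_mat_vec_f(ctx, src0, src1, /*ids=*/nullptr, dst); // few dst columns: custom kernel
    } else {
        // larger ne11: e.g. the new tensor-core GEMM path or cuBLAS (not shown)
    }
}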
ggml/src/ggml-cuda/vendors/hip.h (vendored, 1 changed line)
							| @@ -200,6 +200,7 @@ | ||||
| #endif | ||||
|  | ||||
| typedef hip_bfloat16 nv_bfloat16; | ||||
| typedef short2 nv_bfloat162; // FIXME: there is no 2x BF16 type defined in bfloat16.h; ad-hoc compilation fix | ||||
|  | ||||
| typedef int8_t int8x4_t __attribute__((ext_vector_type(4))); | ||||
| typedef uint8_t uint8x4_t __attribute__((ext_vector_type(4))); | ||||
|   | ||||
ggml/src/ggml-cuda/vendors/musa.h (vendored, 3 changed lines)
							| @@ -137,4 +137,5 @@ | ||||
| #define cudaStreamEndCapture musaStreamEndCapture | ||||
| #define cudaOccupancyMaxActiveBlocksPerMultiprocessor musaOccupancyMaxActiveBlocksPerMultiprocessor | ||||
|  | ||||
| typedef mt_bfloat16 nv_bfloat16; | ||||
| typedef __mt_bfloat16 nv_bfloat16; | ||||
| typedef __mt_bfloat162 nv_bfloat162; | ||||
|   | ||||