#include "ggml.h" #include "mmf.cuh" #include "mmid.cuh" void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst) { GGML_ASSERT( src1->type == GGML_TYPE_F32); GGML_ASSERT(!ids || ids->type == GGML_TYPE_I32); GGML_ASSERT( dst->type == GGML_TYPE_F32); GGML_TENSOR_BINARY_OP_LOCALS; const size_t ts_src0 = ggml_type_size(src0->type); const size_t ts_src1 = ggml_type_size(src1->type); const size_t ts_dst = ggml_type_size(dst->type); GGML_ASSERT(ne13 == ne3); GGML_ASSERT( nb00 == ts_src0); GGML_ASSERT( nb10 == ts_src1); GGML_ASSERT(!ids || ids->nb[0] == ggml_type_size(ids->type)); GGML_ASSERT( nb0 == ts_dst); const float * src1_d = (const float *) src1->data; const int32_t * ids_d = ids ? (const int32_t *) ids->data : nullptr; float * dst_d = (float *) dst->data; const int64_t s01 = src0->nb[1] / ts_src0; const int64_t s11 = src1->nb[1] / ts_src1; const int64_t s1 = dst->nb[1] / ts_dst; const int64_t s02 = src0->nb[2] / ts_src0; const int64_t s12 = src1->nb[2] / ts_src1; const int64_t s2 = dst->nb[2] / ts_dst; const int64_t s03 = src0->nb[3] / ts_src0; const int64_t s13 = src1->nb[3] / ts_src1; const int64_t s3 = dst->nb[3] / ts_dst; const int64_t ids_s0 = ids ? ids->nb[0] / ggml_type_size(ids->type) : 0; const int64_t ids_s1 = ids ? ids->nb[1] / ggml_type_size(ids->type) : 0; mmf_ids_data ids_info{}; mmf_ids_data * ids_info_ptr = nullptr; ggml_cuda_pool_alloc ids_src_compact_dev; ggml_cuda_pool_alloc ids_dst_compact_dev; ggml_cuda_pool_alloc expert_bounds_dev; // For MUL_MAT_ID the memory layout is different than for MUL_MAT: const int64_t ncols_dst = ids ? ne2 : ne1; const int64_t nchannels_dst = ids ? ne1 : ne2; const int64_t stride_col_dst = ids ? s2 : s1; const int64_t stride_col_y = ids ? s12 : s11; const int64_t stride_channel_dst = ids ? s1 : s2; int64_t stride_channel_y = ids ? s11 : s12; int64_t nchannels_y = ids ? 
    // mul_mat_id: handle broadcast
    if (ids && nchannels_y == 1) {
        stride_channel_y = 0;
        nchannels_y      = ids->ne[0];
    }

    // For larger MUL_MAT_ID batches, precompute compacted id mappings and per-expert bounds on
    // the device so the kernels can process the rows of each expert as one contiguous range.
    if (ids && ncols_dst > 16) {
        const int64_t n_expert_used = ids->ne[0];
        const int64_t n_experts     = ne02;
        const int64_t n_tokens      = ne12;
        const int64_t ne_get_rows   = n_tokens * n_expert_used;

        ids_src_compact_dev.alloc(ctx.pool(), ne_get_rows);
        ids_dst_compact_dev.alloc(ctx.pool(), ne_get_rows);
        expert_bounds_dev.alloc(ctx.pool(), n_experts + 1);

        const int si1  = static_cast<int>(ids_s1);
        const int sis1 = static_cast<int>(src1->nb[2] / src1->nb[1]);
        GGML_ASSERT(sis1 > 0);

        ggml_cuda_launch_mm_ids_helper(ids_d, ids_src_compact_dev.get(), ids_dst_compact_dev.get(), expert_bounds_dev.get(),
            static_cast<int>(n_experts), static_cast<int>(n_tokens), static_cast<int>(n_expert_used),
            static_cast<int>(ne11), si1, sis1, ctx.stream());
        CUDA_CHECK(cudaGetLastError());

        ids_info.ids_src_compact   = ids_src_compact_dev.get();
        ids_info.ids_dst_compact   = ids_dst_compact_dev.get();
        ids_info.expert_bounds_dev = expert_bounds_dev.get();
        ids_info.n_experts         = static_cast<int>(n_experts);
        ids_info.sis1              = sis1;
        ids_info_ptr = &ids_info;
    }

    switch (src0->type) {
        case GGML_TYPE_F32: {
            const float * src0_d = (const float *) src0->data;
            constexpr int vals_per_T = 1;
            mul_mat_f_switch_cols_per_block(
                src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst,
                ids_s0, ids_s1, ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst,
                ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream(), ids_info_ptr);
        } break;
        case GGML_TYPE_F16: {
            // F16 is processed as half2, so counts and strides along ne00 are halved (vals_per_T == 2).
            const half2 * src0_d = (const half2 *) src0->data;
            constexpr int vals_per_T = 2;
            mul_mat_f_switch_cols_per_block(
                src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst,
                ids_s0, ids_s1, ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst,
                ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream(), ids_info_ptr);
        } break;
        case GGML_TYPE_BF16: {
            // BF16 is processed as nv_bfloat162, analogous to the F16 case.
            const nv_bfloat162 * src0_d = (const nv_bfloat162 *) src0->data;
            constexpr int vals_per_T = 2;
            mul_mat_f_switch_cols_per_block(
                src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst,
                ids_s0, ids_s1, ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst,
                ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream(), ids_info_ptr);
        } break;
        default:
            GGML_ABORT("unsupported type: %s", ggml_type_name(src0->type));
    }
}

bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const int64_t * src0_ne, const int src1_ncols, bool mul_mat_id) {
    if (ggml_is_quantized(type)) {
        return false;
    }

    // The kernels require the src0 row length and row count to be multiples of the tile sizes:
    if (src0_ne[0] % (warp_size * (4/ggml_type_size(type))) != 0) {
        return false;
    }
    if (src0_ne[1] % MMF_ROWS_PER_BLOCK != 0) {
        return false;
    }

    if (mul_mat_id) {
        if (src0_ne[1] <= 1024 && src1_ncols > 512) {
            return false;
        } else if (src0_ne[1] > 1024 && src1_ncols > 128) {
            return false;
        }
    } else {
        if (src1_ncols > 16) {
            return false;
        }
    }

    switch (type) {
        case GGML_TYPE_F32:
            return ampere_mma_available(cc);
        case GGML_TYPE_F16:
            return turing_mma_available(cc);
        case GGML_TYPE_BF16:
            return ampere_mma_available(cc);
        default:
            return false;
    }
}
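
// Illustrative sketch of how this path might be gated and invoked; the helper below and the way
// cc/warp_size are obtained are assumptions for illustration only (the real dispatch logic lives
// elsewhere in the CUDA backend) and the block is excluded from the build. For example, with
// GGML_TYPE_F16 and warp_size == 32 the gate requires src0->ne[0] to be a multiple of
// 32 * (4/2) == 64 and src0->ne[1] to be a multiple of MMF_ROWS_PER_BLOCK.
#if 0
static bool example_try_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0,
        const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst,
        const int cc, const int warp_size) {
    // For plain MUL_MAT the relevant column count is src1->ne[1]; for MUL_MAT_ID the tokens in
    // src1->ne[2] act as the column dimension, mirroring ncols_dst above.
    const int src1_ncols = static_cast<int>(ids ? src1->ne[2] : src1->ne[1]);
    if (!ggml_cuda_should_use_mmf(src0->type, cc, warp_size, src0->ne, src1_ncols, ids != nullptr)) {
        return false; // caller falls back to another mul_mat implementation
    }
    ggml_cuda_mul_mat_f(ctx, src0, src1, ids, dst);
    return true;
}
#endif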