CUDA: Add mul_mat_id support for the mmf kernel (#15767)

* CUDA: Add mul_mat_id support the mmf

Add support for mul_mat_id for bs < 16

* Review: use warp_size, fix should_use_mmf condition

* Launch one block per expert, stride along n_expert_used

* templatize mul_mat_id

* Pad shmem to 16 bytes, add helper function mul_mat_f_switch_ids

* Reduce compile times by dividing mmf into f16, bf16 and f32 variants

* Divide mmf by ncols_dst

* Add missing files

* Fix MUSA/HIP builds
This commit is contained in:
Aman Gupta
2025-09-09 14:38:02 +08:00
committed by GitHub
parent 550cf726e1
commit a972faebed
23 changed files with 603 additions and 350 deletions

View File

@@ -1,3 +1,4 @@
#pragma once
// This file contains primitives that expose the tensor core PTX instructions for CUDA code.
// The primitives can be used in a similar way as the nvcuda::wmma interface but with a well-defined memory layout.
// The documentation for the PTX instructions can be found under: