CUDA: Add mul_mat_id support for the mmf kernel (#15767)

* CUDA: Add mul_mat_id support the mmf

Add support for mul_mat_id for bs < 16

* Review: use warp_size, fix should_use_mmf condition

* Launch one block per expert, stride along n_expert_used

* templatize mul_mat_id

* Pad shmem to 16 bytes, add helper function mul_mat_f_switch_ids

* Reduce compile times by dividing mmf into f16, bf16 and f32 variants

* Divide mmf by ncols_dst

* Add missing files

* Fix MUSA/HIP builds
This commit is contained in:
Aman Gupta
2025-09-09 14:38:02 +08:00
committed by GitHub
parent 550cf726e1
commit a972faebed
23 changed files with 603 additions and 350 deletions

View File

@@ -34,6 +34,13 @@ SOURCE_MMQ = """// This file has been autogenerated by generate_cu_files.py, do
DECL_MMQ_CASE({type});
"""
SOURCE_MMF = """// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../mmf.cuh"
DECL_MMF_CASE({type});
"""
def get_short_name(long_quant_name):
return long_quant_name.replace("GGML_TYPE_", "").lower()
@@ -76,3 +83,7 @@ for ncols in [8, 16, 32, 64]:
for type in TYPES_MMQ:
with open(f"mmq-instance-{get_short_name(type)}.cu", "w") as f:
f.write(SOURCE_MMQ.format(type=type))
for type in range(1, 17):
with open(f"mmf-instance-ncols_{type}.cu", "w") as f:
f.write(SOURCE_MMF.format(type=type))

View File

@@ -0,0 +1,5 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../mmf.cuh"
DECL_MMF_CASE(1);

View File

@@ -0,0 +1,5 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../mmf.cuh"
DECL_MMF_CASE(10);

View File

@@ -0,0 +1,5 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../mmf.cuh"
DECL_MMF_CASE(11);

View File

@@ -0,0 +1,5 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../mmf.cuh"
DECL_MMF_CASE(12);

View File

@@ -0,0 +1,5 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../mmf.cuh"
DECL_MMF_CASE(13);

View File

@@ -0,0 +1,5 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../mmf.cuh"
DECL_MMF_CASE(14);

View File

@@ -0,0 +1,5 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../mmf.cuh"
DECL_MMF_CASE(15);

View File

@@ -0,0 +1,5 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../mmf.cuh"
DECL_MMF_CASE(16);

View File

@@ -0,0 +1,5 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../mmf.cuh"
DECL_MMF_CASE(2);

View File

@@ -0,0 +1,5 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../mmf.cuh"
DECL_MMF_CASE(3);

View File

@@ -0,0 +1,5 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../mmf.cuh"
DECL_MMF_CASE(4);

View File

@@ -0,0 +1,5 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../mmf.cuh"
DECL_MMF_CASE(5);

View File

@@ -0,0 +1,5 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../mmf.cuh"
DECL_MMF_CASE(6);

View File

@@ -0,0 +1,5 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../mmf.cuh"
DECL_MMF_CASE(7);

View File

@@ -0,0 +1,5 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../mmf.cuh"
DECL_MMF_CASE(8);

View File

@@ -0,0 +1,5 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../mmf.cuh"
DECL_MMF_CASE(9);