mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-11-12 10:47:01 +00:00
CUDA: Add mul_mat_id support for the mmf kernel (#15767)
* CUDA: Add mul_mat_id support the mmf Add support for mul_mat_id for bs < 16 * Review: use warp_size, fix should_use_mmf condition * Launch one block per expert, stride along n_expert_used * templatize mul_mat_id * Pad shmem to 16 bytes, add helper function mul_mat_f_switch_ids * Reduce compile times by dividing mmf into f16, bf16 and f32 variants * Divide mmf by ncols_dst * Add missing files * Fix MUSA/HIP builds
This commit is contained in:
@@ -34,6 +34,13 @@ SOURCE_MMQ = """// This file has been autogenerated by generate_cu_files.py, do
|
||||
DECL_MMQ_CASE({type});
|
||||
"""
|
||||
|
||||
SOURCE_MMF = """// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../mmf.cuh"
|
||||
|
||||
DECL_MMF_CASE({type});
|
||||
"""
|
||||
|
||||
|
||||
def get_short_name(long_quant_name):
|
||||
return long_quant_name.replace("GGML_TYPE_", "").lower()
|
||||
@@ -76,3 +83,7 @@ for ncols in [8, 16, 32, 64]:
|
||||
for type in TYPES_MMQ:
|
||||
with open(f"mmq-instance-{get_short_name(type)}.cu", "w") as f:
|
||||
f.write(SOURCE_MMQ.format(type=type))
|
||||
|
||||
for type in range(1, 17):
|
||||
with open(f"mmf-instance-ncols_{type}.cu", "w") as f:
|
||||
f.write(SOURCE_MMF.format(type=type))
|
||||
|
||||
@@ -0,0 +1,5 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../mmf.cuh"
|
||||
|
||||
DECL_MMF_CASE(1);
|
||||
@@ -0,0 +1,5 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../mmf.cuh"
|
||||
|
||||
DECL_MMF_CASE(10);
|
||||
@@ -0,0 +1,5 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../mmf.cuh"
|
||||
|
||||
DECL_MMF_CASE(11);
|
||||
@@ -0,0 +1,5 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../mmf.cuh"
|
||||
|
||||
DECL_MMF_CASE(12);
|
||||
@@ -0,0 +1,5 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../mmf.cuh"
|
||||
|
||||
DECL_MMF_CASE(13);
|
||||
@@ -0,0 +1,5 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../mmf.cuh"
|
||||
|
||||
DECL_MMF_CASE(14);
|
||||
@@ -0,0 +1,5 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../mmf.cuh"
|
||||
|
||||
DECL_MMF_CASE(15);
|
||||
@@ -0,0 +1,5 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../mmf.cuh"
|
||||
|
||||
DECL_MMF_CASE(16);
|
||||
@@ -0,0 +1,5 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../mmf.cuh"
|
||||
|
||||
DECL_MMF_CASE(2);
|
||||
@@ -0,0 +1,5 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../mmf.cuh"
|
||||
|
||||
DECL_MMF_CASE(3);
|
||||
@@ -0,0 +1,5 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../mmf.cuh"
|
||||
|
||||
DECL_MMF_CASE(4);
|
||||
@@ -0,0 +1,5 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../mmf.cuh"
|
||||
|
||||
DECL_MMF_CASE(5);
|
||||
@@ -0,0 +1,5 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../mmf.cuh"
|
||||
|
||||
DECL_MMF_CASE(6);
|
||||
@@ -0,0 +1,5 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../mmf.cuh"
|
||||
|
||||
DECL_MMF_CASE(7);
|
||||
@@ -0,0 +1,5 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../mmf.cuh"
|
||||
|
||||
DECL_MMF_CASE(8);
|
||||
@@ -0,0 +1,5 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../mmf.cuh"
|
||||
|
||||
DECL_MMF_CASE(9);
|
||||
Reference in New Issue
Block a user