ggml: adds CONV_2D op and direct GEMM Vulkan implementation (#14316)
* ggml/ggml-vulkan/test-backend-ops: adds CONV_2D for Vulkan
* ggml-vulkan: adds f32 scalar shader to compute 2D convolution directly with GEMM (no need for im2col)
* test-backend-ops: adds test_case_ref to check the validity/performance of ops against reference implementations having different graphs; adds tests
* Performance fixes: minimized branch divergence, uses collectives to eliminate redundant calculation, macros removed
* Kernel shared memory size check
* Updates test-backend-ops to support graphs for performance measurement
* Apple/Win32 compile errors fixed
* Subgroup size used to determine tile size -> fixes llvmpipe errors
* Collectives disabled by default
* Intel support is disabled as the performance is poor
* Conv2d enabled for Intel with disabled collectives, disabled for Apple
* test-backend-ops modifications are reverted
* Trailing spaces and missing override fixed
* Triggering pipeline relaunch
* Code formatted with .clang-format
Committed via GitHub. Parent: 90083283ec. Commit: a979ca22db.
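As context for the diff below: the new shader computes the convolution as an implicit GEMM. The kernel is viewed as a K x CRS matrix (K = Cout, CRS = Cin*KH*KW) and the input is gathered on the fly into a CRS x NPQ matrix (NPQ = N*OH*OW), so no im2col buffer is ever materialized. The following plain C++ reference restates that computation for readability; it is a sketch under the tensor layouts documented in the shader, not the ggml implementation itself.

// Reference implicit-GEMM conv2d (readability sketch, not the ggml code):
// dst[k, npq] = sum over crs of knl[k, crs] * gathered_src[crs, npq],
// with the CRS x NPQ matrix formed on the fly instead of materialized.
#include <cstdint>
#include <vector>

std::vector<float> conv2d_ref(const std::vector<float> & knl,  // [KW, KH, Cin, Cout], KW fastest
                              const std::vector<float> & src,  // [W, H, Cin, N], W fastest
                              int Cout, int Cin, int N, int KW, int KH, int W, int H,
                              int s0, int s1, int p0, int p1, int d0, int d1) {
    const int OW = (W + 2 * p0 - d0 * (KW - 1) - 1) / s0 + 1;
    const int OH = (H + 2 * p1 - d1 * (KH - 1) - 1) / s1 + 1;
    std::vector<float> dst((size_t) N * Cout * OH * OW, 0.0f);
    for (int n = 0; n < N; n++)
    for (int k = 0; k < Cout; k++)                       // GEMM row (K = Cout)
    for (int oh = 0; oh < OH; oh++)
    for (int ow = 0; ow < OW; ow++) {                    // GEMM column (flattened NPQ)
        float acc = 0.0f;
        for (int c = 0; c < Cin; c++)
        for (int kh = 0; kh < KH; kh++)
        for (int kw = 0; kw < KW; kw++) {                // flattened CRS reduction
            int h = oh * s1 + kh * d1 - p1;
            int w = ow * s0 + kw * d0 - p0;
            if (h < 0 || h >= H || w < 0 || w >= W) continue;  // zero padding
            acc += knl[((size_t)(k * Cin + c) * KH + kh) * KW + kw] *
                   src[((size_t)(n * Cin + c) * H + h) * W + w];
        }
        dst[((size_t)(n * Cout + k) * OH + oh) * OW + ow] = acc;
    }
    return dst;
}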
265 lines added: ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp (new file)
@@ -0,0 +1,265 @@
#version 450

#ifdef USE_COLLECTIVES
# extension GL_KHR_shader_subgroup_shuffle : enable
#endif

#include "types.comp"

// TODO: make this a spec constant
#define SHMEM_PAD 0

// shape notation: [dim(N), ..., dim(0)] -- stride(dim(j)) >= stride(dim(i)) if i > j
layout(binding = 0) readonly buffer A {
    A_TYPE knl_data[];
};  // src0 - kernel: [KW, KH, Cin, Cout]

layout(binding = 1) readonly buffer B {
    B_TYPE src_data[];
};  // src1 - input: [W, H, Cin, N] -- channel_first format

layout(binding = 2) writeonly buffer D {
    D_TYPE dst_data[];
};  // dst - result: [OW, OH, Cout, N]

layout(push_constant) uniform parameter {
    // I/O channels, batch size
    uint32_t Cout;
    uint32_t Cin;
    uint32_t N;

    // Tensor spatial sizes: kernel, input, output
    uint32_t KW;
    uint32_t KH;
    uint32_t W;
    uint32_t H;
    uint32_t OW;
    uint32_t OH;

    // Parameters: stride, padding, dilation - 0=x (width), 1=y (height)
    uint32_t s0;
    uint32_t s1;
    uint32_t p0;
    uint32_t p1;
    uint32_t d0;
    uint32_t d1;

    // Strides in elements
    uint32_t nb01;
    uint32_t nb02;
    uint32_t nb03;

    uint32_t nb11;
    uint32_t nb12;
    uint32_t nb13;

    uint32_t nb1;
    uint32_t nb2;
    uint32_t nb3;
} p;
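// Note: the nb* push constants are strides in elements (ggml tensors carry
// byte strides, so the host side is expected to convert), which lets the
// shader address non-contiguous views of the kernel, input and output.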

layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
// Blocktile sizes
layout(constant_id = 1) const uint BS_K   = 128;
layout(constant_id = 2) const uint BS_CRS = 16;
layout(constant_id = 3) const uint BS_NPQ = 128;
// Thread-tile sizes
layout(constant_id = 4) const uint TS_K = 8;
layout(constant_id = 5) const uint use_collectives = 1;

uint32_t       tid     = gl_LocalInvocationID.x;
const uint32_t WG_SIZE = gl_WorkGroupSize.x;

// Ceiling division: number of blocks of size block_size needed to cover work_size.
uint splitWork(uint work_size, uint block_size) {
    return (block_size + work_size - 1) / block_size;
}
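// Presumably the host launches splitWork(K, BS_K) x splitWork(NPQ, BS_NPQ)
// workgroups, one per output blocktile, matching the gl_WorkGroupID.x/.y
// usage for B_idx_K / B_idx_NPQ further below.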

uint32_t K   = p.Cout;
uint32_t CRS = p.Cin * p.KH * p.KW;
uint32_t NPQ = p.N * p.OH * p.OW;

uint32_t n_elems_out = K * NPQ;

// Number of blocktiles per input
uint32_t NB_CRS = splitWork(CRS, BS_CRS);

const uint32_t Ash_stride = BS_CRS + SHMEM_PAD;
const uint32_t Bsh_stride = BS_NPQ + SHMEM_PAD;

const uint32_t Ash_numel = BS_K * BS_CRS;
const uint32_t Bsh_numel = BS_CRS * BS_NPQ;

const uint32_t Ash_len = BS_K * Ash_stride;
const uint32_t Bsh_len = BS_CRS * Bsh_stride;

shared float Ash[Ash_len];  // K x CRS
shared float Bsh[Bsh_len];  // CRS x NPQ

// Threadtile sizes
const uint32_t TS_NPQ = BS_K * BS_NPQ / WG_SIZE / TS_K;

// Number of threadtiles per blocktile
const uint32_t NT_K   = BS_K / TS_K;
const uint32_t NT_NPQ = BS_NPQ / TS_NPQ;

float regA[TS_K];
float regB[TS_NPQ];
float regC[TS_K][TS_NPQ];
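// Worked example with the default spec constants and WG_SIZE = 256:
//   TS_NPQ = 128*128/256/8 = 8, NT_K = 128/8 = 16, NT_NPQ = 128/8 = 16,
//   so NT_K * NT_NPQ = 256 = WG_SIZE: every thread owns exactly one
//   TS_K x TS_NPQ = 8x8 threadtile, i.e. 64 accumulators in regC.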

/*
Compute
KxCRS @ CRSxNPQ = K x NPQ
K   = Cout
C   = Cin
R,S = KH,KW
P,Q = OH,OW
*/

uint32_t B_idx_K   = gl_WorkGroupID.x;
uint32_t B_idx_NPQ = gl_WorkGroupID.y;

uint32_t T_y = tid / NT_NPQ;
uint32_t T_x = tid % NT_NPQ;

uint32_t       Ar    = tid / BS_CRS;
uint32_t       Ac    = tid % BS_CRS;
const uint32_t ArpWg = WG_SIZE / BS_CRS;

uint32_t       Br    = tid / BS_NPQ;
uint32_t       Bc    = tid % BS_NPQ;
const uint32_t BrpWg = WG_SIZE / BS_NPQ;
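// Cooperative-load mapping: for tile A each thread owns column Ac of row Ar,
// and the whole workgroup covers ArpWg rows per pass, so a BS_K-row tile takes
// BS_K/ArpWg passes of the r_offset loops below (analogously Br/Bc/BrpWg for B).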

void main() {
    for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) {
        for (uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++) {
            regC[T_ly][T_lx] = 0.0;
        }
    }
    /* Advance block in CRS dim */
    for (uint32_t B_idx_CRS = 0; B_idx_CRS < NB_CRS; B_idx_CRS++) {
        uint32_t CRS_idx_a;
        uint32_t Cin_idx_a;
        uint32_t KH_idx_a;
        uint32_t KW_idx_a;

#ifdef USE_COLLECTIVES
        uint32_t cached_CRS_idx;
        uint32_t cached_Cin_idx;
        uint32_t cached_KH_idx;
        uint32_t cached_KW_idx;
        if (use_collectives == 1) {
            cached_CRS_idx = B_idx_CRS * BS_CRS + gl_SubgroupInvocationID;
            cached_Cin_idx = cached_CRS_idx / (p.KW * p.KH);
            uint32_t cached_CRS_remainder = (cached_CRS_idx - cached_Cin_idx * p.KW * p.KH);
            cached_KH_idx = cached_CRS_remainder / p.KW;
            cached_KW_idx = cached_CRS_remainder - cached_KH_idx * p.KW;

            CRS_idx_a = subgroupShuffle(cached_CRS_idx, Ac);
            Cin_idx_a = subgroupShuffle(cached_Cin_idx, Ac);
            KH_idx_a  = subgroupShuffle(cached_KH_idx, Ac);
            KW_idx_a  = subgroupShuffle(cached_KW_idx, Ac);
        } else {
            CRS_idx_a = B_idx_CRS * BS_CRS + Ac;  // Global CRS_idx_a (column index of A)
            Cin_idx_a = CRS_idx_a / (p.KW * p.KH);
            uint32_t CRS_remainder = CRS_idx_a - Cin_idx_a * p.KW * p.KH;
            KH_idx_a = CRS_remainder / p.KW;
            KW_idx_a = CRS_remainder - KH_idx_a * p.KW;
        }
#else
        CRS_idx_a = B_idx_CRS * BS_CRS + Ac;  // Global CRS_idx_a (column index of A)
        Cin_idx_a = CRS_idx_a / (p.KW * p.KH);
        uint32_t CRS_remainder = CRS_idx_a - Cin_idx_a * p.KW * p.KH;
        KH_idx_a = CRS_remainder / p.KW;
        KW_idx_a = CRS_remainder - KH_idx_a * p.KW;
#endif
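        // Collectives above: all BS_CRS columns handled by a subgroup need the
        // same small set of CRS -> (Cin, KH, KW) decompositions, so each lane
        // computes the division/modulo chain for one CRS index and the rest
        // fetch it with subgroupShuffle, trading redundant integer divisions
        // for a single shuffle.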

        /* Load kernel to A_block: (BS_K x BS_CRS) */
        for (uint32_t r_offset = 0; r_offset < BS_K; r_offset += ArpWg) {
            uint32_t B_ly    = r_offset + Ar;
            uint32_t B_lx    = Ac;
            uint32_t K_idx   = B_idx_K * BS_K + B_ly; /* Global K_idx (row index of A) */
            uint32_t knl_idx = min(KW_idx_a + KH_idx_a * p.nb01 + Cin_idx_a * p.nb02 + K_idx * p.nb03, K * CRS - 1);
            float    val     = knl_data[knl_idx];
            if (K_idx >= K || CRS_idx_a >= CRS) {
                val = 0.0;
            }
            Ash[B_ly * Ash_stride + B_lx] = val;
        }
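        // Out-of-bounds handling in the load above is branchless on the memory
        // side: the address is clamped with min() so the load is always valid,
        // and the value is zeroed afterwards when the logical (K, CRS)
        // coordinates fall outside the tensor.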
        /* Load input to B_block: (BS_CRS x BS_NPQ) */
        for (uint32_t r_offset = 0; r_offset < BS_CRS; r_offset += BrpWg) {
            uint32_t B_ly          = r_offset + Br; /* Row index of B block */
            uint32_t B_lx          = Bc;
            uint32_t NPQ_idx       = B_idx_NPQ * BS_NPQ + B_lx; /* Global NPQ index (column index of B) */
            uint32_t N_idx         = NPQ_idx / (p.OH * p.OW);
            uint32_t NPQ_remainder = NPQ_idx - N_idx * p.OH * p.OW;
            uint32_t OH_idx        = NPQ_remainder / p.OW;
            uint32_t OW_idx        = NPQ_remainder - OH_idx * p.OW;

            uint32_t CRS_idx_b;
            uint32_t Cin_idx_b;
            uint32_t KH_idx_b;
            uint32_t KW_idx_b;
#ifdef USE_COLLECTIVES
            if (use_collectives == 1) {
                CRS_idx_b = subgroupShuffle(cached_CRS_idx, r_offset + Br);
                Cin_idx_b = subgroupShuffle(cached_Cin_idx, r_offset + Br);
                KH_idx_b  = subgroupShuffle(cached_KH_idx, r_offset + Br);
                KW_idx_b  = subgroupShuffle(cached_KW_idx, r_offset + Br);
            } else {
                CRS_idx_b = B_idx_CRS * BS_CRS + B_ly; /* Global CRS index (row index of B) */
                Cin_idx_b = CRS_idx_b / (p.KW * p.KH);
                uint32_t CRS_remainder = CRS_idx_b - Cin_idx_b * p.KW * p.KH;
                KH_idx_b = CRS_remainder / p.KW;
                KW_idx_b = CRS_remainder - KH_idx_b * p.KW;
            }
#else
            CRS_idx_b = B_idx_CRS * BS_CRS + B_ly; /* Global CRS index (row index of B) */
            Cin_idx_b = CRS_idx_b / (p.KW * p.KH);
            uint32_t CRS_remainder = CRS_idx_b - Cin_idx_b * p.KW * p.KH;
            KH_idx_b = CRS_remainder / p.KW;
            KW_idx_b = CRS_remainder - KH_idx_b * p.KW;
#endif

            // Unsigned arithmetic: a "negative" coordinate wraps around to a huge
            // value, so the H_idx >= p.H / W_idx >= p.W checks below catch it too.
            uint32_t H_idx = OH_idx * p.s1 + KH_idx_b * p.d1 - p.p1;
            uint32_t W_idx = OW_idx * p.s0 + KW_idx_b * p.d0 - p.p0;
            uint32_t src_idx =
                min(max(W_idx + H_idx * p.nb11 + Cin_idx_b * p.nb12 + N_idx * p.nb13, 0), p.Cin * p.N * p.W * p.H - 1);
            float val = src_data[src_idx];
            if (CRS_idx_b >= CRS || NPQ_idx >= NPQ || H_idx < 0 || H_idx >= p.H || W_idx < 0 || W_idx >= p.W) {
                val = 0.0;
            }
            Bsh[B_ly * Bsh_stride + B_lx] = val;
        }
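        // In the load above, each B column is one output pixel and each row one
        // (Cin, KH, KW) tap: the gather uses H_idx = OH_idx*s1 + KH_idx*d1 - p1
        // (and likewise W_idx), i.e. the standard stride/dilation/padding
        // relation, zero-filling taps that land in the padding. This is im2col
        // performed on the fly instead of via an intermediate buffer.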
        barrier();
        for (uint32_t CRS_lidx = 0; CRS_lidx < BS_CRS; CRS_lidx++) {
            for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) {
                regA[T_ly] = Ash[(T_y * TS_K + T_ly) * Ash_stride + CRS_lidx];
            }
            for (uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++) {
                regB[T_lx] = Bsh[CRS_lidx * Bsh_stride + T_x * TS_NPQ + T_lx];
            }
            for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) {
                for (uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++) {
                    regC[T_ly][T_lx] = fma(regA[T_ly], regB[T_lx], regC[T_ly][T_lx]);
                }
            }
        }
        barrier();
    }
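    // The CRS loop above does register-tiled rank-1 updates: per CRS slice each
    // thread loads a TS_K fragment of A and a TS_NPQ fragment of B into
    // regA/regB once, then reuses each value TS_NPQ (resp. TS_K) times in the
    // fma grid, amortizing shared-memory traffic over the whole threadtile.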
    /* Save C */
    for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) {
        for (uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++) {
            uint32_t K_idx   = B_idx_K * BS_K + T_y * TS_K + T_ly;
            uint32_t NPQ_idx = B_idx_NPQ * BS_NPQ + T_x * TS_NPQ + T_lx;
            uint32_t N_idx   = NPQ_idx / (p.OH * p.OW);
            uint32_t OH_idx  = (NPQ_idx - N_idx * p.OH * p.OW) / p.OW;
            uint32_t OW_idx  = NPQ_idx - N_idx * p.OH * p.OW - OH_idx * p.OW;
            uint32_t dst_idx = OW_idx + OH_idx * p.nb1 + K_idx * p.nb2 + N_idx * p.nb3;
            if (K_idx < K && NPQ_idx < NPQ) {
                dst_data[dst_idx] = regC[T_ly][T_lx];
            }
        }
    }
}
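For reference, the two unflattenings the shader performs per element (CRS into kernel coordinates, NPQ into output coordinates), restated as runnable C++; the names are ours, not ggml's.

#include <cstdint>
#include <cstdio>

// CRS -> (Cin, KH, KW): unflatten the GEMM reduction index (KW fastest).
void decompose_crs(uint32_t crs, uint32_t KW, uint32_t KH,
                   uint32_t * cin, uint32_t * kh, uint32_t * kw) {
    *cin         = crs / (KW * KH);
    uint32_t rem = crs - *cin * KW * KH;
    *kh          = rem / KW;
    *kw          = rem - *kh * KW;
}

// NPQ -> (N, OH, OW): unflatten the GEMM column index (OW fastest).
void decompose_npq(uint32_t npq, uint32_t OW, uint32_t OH,
                   uint32_t * n, uint32_t * oh, uint32_t * ow) {
    *n           = npq / (OH * OW);
    uint32_t rem = npq - *n * OH * OW;
    *oh          = rem / OW;
    *ow          = rem - *oh * OW;
}

int main() {
    // Example: 3x3 kernel with Cin = 4 gives CRS = 36; decompose index 17.
    uint32_t cin, kh, kw;
    decompose_crs(17, /*KW=*/3, /*KH=*/3, &cin, &kh, &kw);
    std::printf("crs 17 -> cin %u kh %u kw %u\n", cin, kh, kw);  // 1, 2, 2
    return 0;
}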

@@ -655,6 +655,8 @@ void process_shaders() {

    string_to_spv("opt_step_adamw_f32", "opt_step_adamw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));

    string_to_spv("conv2d_f32", "conv2d_mm.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"USE_COLLECTIVES", "1"}});

    string_to_spv("conv2d_dw_whcn_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"WHCN", "1"}}));
    string_to_spv("conv2d_dw_cwhn_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"CWHN", "1"}}));