Hide latency of bias and gate-loading (#16847)

This is realised by loading the bias and gate values into registers before
the dot-product is computed, effectively batching their loads with the
dot-product's own memory traffic. Since many threads are still alive at this
point, the warp scheduler has enough eligible warps to hide the cost of the
two additional float loads.
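
To see the pattern in isolation: issue the bias load early, let the
dot-product arithmetic cover its latency, and consume the register-held value
only in the epilogue. Below is a minimal, self-contained sketch of that idea
in a toy CUDA kernel; `bias_prefetch_demo` and all of its names are
illustrative and are not part of mul_mat_vec_q:

#include <cstdio>
#include <cuda_runtime.h>

// Toy kernel illustrating the prefetch idea: one block per output row, the
// bias load is issued before the dot-product loop and consumed only after
// the reduction, so the arithmetic hides the load latency.
__global__ void bias_prefetch_demo(const float * x, const float * y,
                                   const float * bias, float * out, int n) {
    const int row = blockIdx.x;

    // Prefetch: the load is issued here, but no thread stalls on it until
    // `b` is actually read after the reduction below.
    const float b = bias[row];

    float sum = 0.0f;
    for (int i = threadIdx.x; i < n; i += blockDim.x) {
        sum += x[row * n + i] * y[i];   // this work overlaps the bias load
    }

    // Block-wide tree reduction (assumes blockDim.x is a power of two <= 256).
    __shared__ float partial[256];
    partial[threadIdx.x] = sum;
    __syncthreads();
    for (int s = blockDim.x / 2; s > 0; s >>= 1) {
        if (threadIdx.x < s) {
            partial[threadIdx.x] += partial[threadIdx.x + s];
        }
        __syncthreads();
    }

    if (threadIdx.x == 0) {
        out[row] = partial[0] + b;      // bias comes from a register, no load here
    }
}

int main() {
    const int rows = 4, n = 1024, threads = 256;
    float *x, *y, *bias, *out;
    cudaMallocManaged(&x,    rows * n * sizeof(float));
    cudaMallocManaged(&y,    n * sizeof(float));
    cudaMallocManaged(&bias, rows * sizeof(float));
    cudaMallocManaged(&out,  rows * sizeof(float));
    for (int i = 0; i < rows * n; ++i) x[i] = 1.0f;
    for (int i = 0; i < n; ++i)        y[i] = 1.0f;
    for (int r = 0; r < rows; ++r)     bias[r] = (float) r;

    bias_prefetch_demo<<<rows, threads>>>(x, y, bias, out, n);
    cudaDeviceSynchronize();
    for (int r = 0; r < rows; ++r) {
        printf("out[%d] = %.1f\n", r, out[r]);   // expect 1024.0 + r
    }
    cudaFree(x); cudaFree(y); cudaFree(bias); cudaFree(out);
    return 0;
}

The compiler is free to reschedule the load, but issuing it before the loop
gives it the whole loop body to complete before `b` is first read. The actual
kernel goes one step further and restricts the load to the threads that will
survive to the write-back, which is what the guard in the diff below is about.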
Author: Oliver Simons
Date:   2025-10-30 04:34:15 +01:00 (committed by GitHub)
Parent: b9ce940177
Commit: 8b11deea46

@@ -190,12 +190,28 @@ static __global__ void mul_mat_vec_q(
     const uint32_t channel_bias = ids ? channel_x : channel_dst;
+    float x_biases[ncols_dst][rows_per_cuda_block]    = { { 0.0f } };
+    float gate_biases[ncols_dst][rows_per_cuda_block] = { { 0.0f } };
     if constexpr (has_fusion) {
         if (use_bias) {
             x_bias = x_bias + sample_dst*stride_sample_dst + channel_bias*stride_channel_dst + row0;
+            // 1. Hide latency by prefetching bias and gate here
+            // 2. load only on threads that won't die after partial sum calculation
+            if (threadIdx.x < rows_per_cuda_block && threadIdx.y == 0 &&
+                (rows_per_cuda_block == 1 || uint32_t(row0 + threadIdx.x) < stride_col_dst)) {
+                for (int j = 0; j < ncols_dst; ++j) {
+                    x_biases[j][threadIdx.x] = x_bias[j * stride_col_dst + threadIdx.x];
+                }
+            }
         }
         if (use_gate_bias) {
             gate_bias = gate_bias + sample_dst*stride_sample_dst + channel_bias*stride_channel_dst + row0;
+            if (threadIdx.x < rows_per_cuda_block && threadIdx.y == 0 &&
+                (rows_per_cuda_block == 1 || uint32_t(row0 + threadIdx.x) < stride_col_dst)) {
+                for (int j = 0; j < ncols_dst; ++j) {
+                    gate_biases[j][threadIdx.x] = gate_bias[j * stride_col_dst + threadIdx.x];
+                }
+            }
         }
     }
@@ -283,12 +299,12 @@ static __global__ void mul_mat_vec_q(
         float result = tmp[j][threadIdx.x];
         if constexpr (has_fusion) {
             if (use_bias) {
-                result += x_bias[j*stride_col_dst + threadIdx.x];
+                result += x_biases[j][threadIdx.x];
             }
             if (use_gate) {
                 float gate_value = tmp_gate[j][threadIdx.x];
                 if (use_gate_bias) {
-                    gate_value += gate_bias[j*stride_col_dst + threadIdx.x];
+                    gate_value += gate_biases[j][threadIdx.x];
                 }
                 switch (active_glu) {
                     case GGML_GLU_OP_SWIGLU:
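
Why the guard in the first hunk: only a subset of threads survives the
partial-sum reduction and writes results, so the prefetch is restricted to
exactly those threads to avoid redundant global loads. A commented
restatement of the predicate (identifiers as in mul_mat_vec_q above; the
reading of the last clause is my interpretation, not stated in the commit):

if (threadIdx.x < rows_per_cuda_block   // one lane per row handled by this block
    && threadIdx.y == 0                 // only the first warp-row issues the loads
    // skip rows past the end of dst when the row count is not a multiple of
    // rows_per_cuda_block (mirrors the bounds check of the write-back path)
    && (rows_per_cuda_block == 1 || uint32_t(row0 + threadIdx.x) < stride_col_dst)) {
    // ... prefetch x_bias / gate_bias into x_biases / gate_biases ...
}

The second hunk then reads the bias values from the per-thread register
arrays filled earlier, so no global load remains on the epilogue's critical
path.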