mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-11-07 09:57:00 +00:00
CUDA: Remove unneded bias/gate dims in fused mmvq (#16858)
* CUDA: Remove unneded bias/gate dims in fused mmvq Pointed out [here](https://github.com/ggml-org/llama.cpp/pull/16847#discussion_r2476798989) that only a single value is needed per target col per thread * Apply suggestions from code review Co-authored-by: Johannes Gäßler <johannesg@5d6.de> * Fix "Error 991-D: extra braces are nonstandard" during compilation --------- Co-authored-by: Johannes Gäßler <johannesg@5d6.de>
This commit is contained in:
@@ -190,8 +190,8 @@ static __global__ void mul_mat_vec_q(
|
|||||||
|
|
||||||
const uint32_t channel_bias = ids ? channel_x : channel_dst;
|
const uint32_t channel_bias = ids ? channel_x : channel_dst;
|
||||||
|
|
||||||
float x_biases[ncols_dst][rows_per_cuda_block] = { { 0.0f } };
|
float x_biases[ncols_dst] = { 0.0f };
|
||||||
float gate_biases[ncols_dst][rows_per_cuda_block] = { { 0.0f } };
|
float gate_biases[ncols_dst] = { 0.0f };
|
||||||
if constexpr (has_fusion) {
|
if constexpr (has_fusion) {
|
||||||
if (use_bias) {
|
if (use_bias) {
|
||||||
x_bias = x_bias + sample_dst*stride_sample_dst + channel_bias*stride_channel_dst + row0;
|
x_bias = x_bias + sample_dst*stride_sample_dst + channel_bias*stride_channel_dst + row0;
|
||||||
@@ -199,8 +199,9 @@ static __global__ void mul_mat_vec_q(
|
|||||||
// 2. load only on threads that won't die after partial sum calculation
|
// 2. load only on threads that won't die after partial sum calculation
|
||||||
if (threadIdx.x < rows_per_cuda_block && threadIdx.y == 0 &&
|
if (threadIdx.x < rows_per_cuda_block && threadIdx.y == 0 &&
|
||||||
(rows_per_cuda_block == 1 || uint32_t(row0 + threadIdx.x) < stride_col_dst)) {
|
(rows_per_cuda_block == 1 || uint32_t(row0 + threadIdx.x) < stride_col_dst)) {
|
||||||
|
#pragma unroll
|
||||||
for (int j = 0; j < ncols_dst; ++j) {
|
for (int j = 0; j < ncols_dst; ++j) {
|
||||||
x_biases[j][threadIdx.x] = x_bias[j * stride_col_dst + threadIdx.x];
|
x_biases[j] = x_bias[j * stride_col_dst + threadIdx.x];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -208,8 +209,9 @@ static __global__ void mul_mat_vec_q(
|
|||||||
gate_bias = gate_bias + sample_dst*stride_sample_dst + channel_bias*stride_channel_dst + row0;
|
gate_bias = gate_bias + sample_dst*stride_sample_dst + channel_bias*stride_channel_dst + row0;
|
||||||
if (threadIdx.x < rows_per_cuda_block && threadIdx.y == 0 &&
|
if (threadIdx.x < rows_per_cuda_block && threadIdx.y == 0 &&
|
||||||
(rows_per_cuda_block == 1 || uint32_t(row0 + threadIdx.x) < stride_col_dst)) {
|
(rows_per_cuda_block == 1 || uint32_t(row0 + threadIdx.x) < stride_col_dst)) {
|
||||||
|
#pragma unroll
|
||||||
for (int j = 0; j < ncols_dst; ++j) {
|
for (int j = 0; j < ncols_dst; ++j) {
|
||||||
gate_biases[j][threadIdx.x] = gate_bias[j * stride_col_dst + threadIdx.x];
|
gate_biases[j] = gate_bias[j * stride_col_dst + threadIdx.x];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -299,12 +301,12 @@ static __global__ void mul_mat_vec_q(
|
|||||||
float result = tmp[j][threadIdx.x];
|
float result = tmp[j][threadIdx.x];
|
||||||
if constexpr (has_fusion) {
|
if constexpr (has_fusion) {
|
||||||
if (use_bias) {
|
if (use_bias) {
|
||||||
result += x_biases[j][threadIdx.x];
|
result += x_biases[j];
|
||||||
}
|
}
|
||||||
if (use_gate) {
|
if (use_gate) {
|
||||||
float gate_value = tmp_gate[j][threadIdx.x];
|
float gate_value = tmp_gate[j][threadIdx.x];
|
||||||
if (use_gate_bias) {
|
if (use_gate_bias) {
|
||||||
gate_value += gate_biases[j][threadIdx.x];
|
gate_value += gate_biases[j];
|
||||||
}
|
}
|
||||||
switch (active_glu) {
|
switch (active_glu) {
|
||||||
case GGML_GLU_OP_SWIGLU:
|
case GGML_GLU_OP_SWIGLU:
|
||||||
|
|||||||
Reference in New Issue
Block a user