CUDA: some micro-optimizations in mmf.cuh for mul_mat_id (#15926)
@@ -57,31 +57,33 @@ static __global__ void mul_mat_f(
     T * tile_xy = (T *) compute_base + threadIdx.y*(tile_A::I * tile_k_padded);
 
     if constexpr (has_ids) {
-        __shared__ int has_any;
-        if (threadIdx.y == 0) {
-            int local_has_any = 0;
-            for (int j = threadIdx.x; j < cols_per_block; j += warp_size) {
-                int slot = -1;
-                for (int k = 0; k < nchannels_dst; ++k) {
-                    const int idv = ids[j*stride_row_id + k*stride_col_id];
-                    if (idv == expert_idx) {
-                        slot = k;
-                        break;
-                    }
-                }
-                if (j < cols_per_block) {
-                    local_has_any |= (slot >= 0);
-                    slot_map[j] = slot;
+        int found = 0;
+
+        for (int j0 = 0; j0 < cols_per_block; j0 += nwarps) {
+            const int j = j0 + threadIdx.y;
+
+            const int32_t * __restrict__ id_row = ids + j*stride_row_id;
+
+            if (threadIdx.x == 0) {
+                slot_map[j] = -1;
+            }
+
+            for (int k = threadIdx.x; k < nchannels_dst; k += warp_size) {
+                int match = id_row[k*stride_col_id] == expert_idx;
+
+                if (match) {
+                    slot_map[j] = k;
+                    found = 1;
+                    break;
                 }
             }
-            has_any = warp_reduce_any(local_has_any);
         }
-        __syncthreads();
-        if (has_any == 0) {
+
+        if (!__syncthreads_or(found)) {
             return;
         }
     }
 
     for (int col = threadIdx.y*warp_size + threadIdx.x; col < ncols; col += nwarps*warp_size) {
         tile_A A[ntA][warp_size / tile_A::J];
 #pragma unroll
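The hunk above spreads the per-column expert search over all nwarps warps (threadIdx.y picks the column, the warp's lanes scan the id channels), and it replaces the shared has_any flag, the warp_reduce_any() call and the separate __syncthreads() with a single __syncthreads_or(), which acts as the block barrier and ORs the per-thread found flags at the same time. A minimal, self-contained sketch of that pattern follows (hypothetical kernel and parameter names, not the mmf.cuh code):

#include <cuda_runtime.h>

// Minimal sketch of a block-wide early exit via __syncthreads_or():
// each thread sets a private flag, the intrinsic both acts as a barrier
// and ORs the flags across the block, so all threads agree on whether
// any match was found without a shared-memory flag or a warp reduction.
__global__ void find_any(const int * vals, int n, int target, int * out) {
    int found = 0;
    // Each thread scans a strided slice of the input.
    for (int i = threadIdx.x; i < n; i += blockDim.x) {
        if (vals[i] == target) {
            found = 1;
            break;
        }
    }
    if (!__syncthreads_or(found)) {
        return; // uniform across the block: no thread found the target
    }
    if (threadIdx.x == 0) {
        *out = 1;
    }
}

__syncthreads_or() returns the same value to every thread in the block, so the early return stays uniform and no thread is left waiting at a barrier.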
@@ -106,14 +108,7 @@ static __global__ void mul_mat_f(
                 if constexpr (!has_ids) {
                     tile_xy[j0*tile_k_padded + threadIdx.x] = j < cols_per_block ? y[j*stride_col_y + col] : 0.0f;
                 } else {
-                    float val = 0.0f;
-                    if (j < cols_per_block) {
-                        const int slot = slot_map[j];
-                        if (slot >= 0) {
-                            val = y[slot*stride_channel_y + j*stride_col_y + col];
-                        }
-                    }
-                    tile_xy[j0*tile_k_padded + threadIdx.x] = val;
+                    tile_xy[j0*tile_k_padded + threadIdx.x] = j < cols_per_block ? y[slot_map[j]*stride_channel_y + j*stride_col_y + col] : 0.0f;
                 }
             }
         } else if constexpr (std::is_same_v<T, half2> || std::is_same_v<T, nv_bfloat162>) {
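The guarded scalar load above collapses into one ternary that indexes y directly through slot_map[j], the per-column expert slot filled in the first hunk. For reference, the mapping it relies on can be restated sequentially as follows (illustrative host-side helper, not part of llama.cpp):

#include <cstdint>
#include <vector>

// Host-side restatement of what slot_map encodes (illustrative only):
// for each tile column j, the first channel k whose id equals expert_idx,
// or -1 if this expert is not selected for that column at all.
static std::vector<int> build_slot_map(const int32_t * ids, int cols_per_block,
                                       int nchannels_dst, int expert_idx,
                                       int64_t stride_row_id, int64_t stride_col_id) {
    std::vector<int> slot_map(cols_per_block, -1);
    for (int j = 0; j < cols_per_block; ++j) {
        for (int k = 0; k < nchannels_dst; ++k) {
            if (ids[j*stride_row_id + k*stride_col_id] == expert_idx) {
                slot_map[j] = k;
                break;
            }
        }
    }
    return slot_map;
}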
@@ -125,14 +120,7 @@ static __global__ void mul_mat_f(
                     const float2 tmp = j < cols_per_block ? y2[j*stride_col_y + col] : make_float2(0.0f, 0.0f);
                     tile_xy[j0*tile_k_padded + threadIdx.x] = {tmp.x, tmp.y};
                 } else {
-                    float2 tmp = make_float2(0.0f, 0.0f);
-                    if (j < cols_per_block) {
-                        const int slot = slot_map[j];
-                        if (slot >= 0) {
-                            const float2 * y2_slot = (const float2 *)(y + slot*stride_channel_y);
-                            tmp = y2_slot[j*stride_col_y + col];
-                        }
-                    }
+                    float2 tmp = j < cols_per_block && slot_map[j] >= 0 ? *(const float2*) &y[slot_map[j]*stride_channel_y + 2*(j*stride_col_y + col)] : make_float2(0.0f, 0.0f);
                     tile_xy[j0*tile_k_padded + threadIdx.x] = {tmp.x, tmp.y};
                 }
             }
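The half2/nv_bfloat162 path above likewise folds the nested guards into a single expression and keeps the two-floats-per-thread load by reinterpreting the address as a const float2 pointer. A small sketch of that vectorized-load idiom (hypothetical helper; it assumes the computed address is 8-byte aligned, which holds when the buffer is float2-aligned and the element index is doubled as in the diff):

#include <cuda_runtime.h>
#include <cstdint>

// Illustrative device helper (not from mmf.cuh): reading two consecutive
// floats through a float2* issues one 8-byte load instead of two 4-byte
// loads. Only valid when &y[base + 2*idx] is 8-byte aligned.
__device__ __forceinline__ float2 load_float_pair(const float * y, int64_t base, int64_t idx) {
    return *(const float2 *) &y[base + 2*idx];
}

In the diff, the base offset corresponds to slot_map[j]*stride_channel_y and the pair index to j*stride_col_y + col.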
@@ -221,7 +209,7 @@ static inline void mul_mat_f_switch_ids(
         const dim3 & block_nums, const dim3 & block_dims, const int nbytes_shared_total, cudaStream_t stream) {
     if (ids) {
         mul_mat_f<T, MMF_ROWS_PER_BLOCK, cols_per_block, nwarps, true><<<block_nums, block_dims, nbytes_shared_total, stream>>>
             (x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
              stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
              sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
     } else {