mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-11-15 11:17:31 +00:00
musa: fix all warnings, re-enable -DLLAMA_FATAL_WARNINGS=ON in ci and update doc (#12611)
* musa: fix all warnings Signed-off-by: Xiaodong Ye <xiaodong.ye@mthreads.com> * musa: enable -DLLAMA_FATAL_WARNINGS=ON in run.sh Signed-off-by: Xiaodong Ye <xiaodong.ye@mthreads.com> * musa: update ci doc (install ccache) Signed-off-by: Xiaodong Ye <xiaodong.ye@mthreads.com> * fix Windows build issue Signed-off-by: Xiaodong Ye <xiaodong.ye@mthreads.com> * Address review comments Signed-off-by: Xiaodong Ye <xiaodong.ye@mthreads.com> * Address review comments Signed-off-by: Xiaodong Ye <xiaodong.ye@mthreads.com> --------- Signed-off-by: Xiaodong Ye <xiaodong.ye@mthreads.com>
This commit is contained in:
@@ -151,7 +151,7 @@ static __global__ void mul_mat_vec_q(
|
||||
constexpr int blocks_per_iter = vdr * nwarps*warp_size / qi;
|
||||
|
||||
// partial sum for each thread
|
||||
float tmp[ncols_y][rows_per_cuda_block] = {0.0f};
|
||||
float tmp[ncols_y][rows_per_cuda_block] = {{0.0f}};
|
||||
|
||||
const block_q8_1 * y = (const block_q8_1 *) vy;
|
||||
|
||||
@@ -197,10 +197,12 @@ static __global__ void mul_mat_vec_q(
|
||||
tmp[j][i] = warp_reduce_sum<warp_size>(tmp[j][i]);
|
||||
}
|
||||
|
||||
if (threadIdx.x < rows_per_cuda_block && (rows_per_cuda_block == 1 || row0 + threadIdx.x < nrows_dst)) {
|
||||
if (threadIdx.x < rows_per_cuda_block && (rows_per_cuda_block == 1 || row0 + threadIdx.x < (unsigned)nrows_dst)) {
|
||||
dst[j*nrows_dst + row0 + threadIdx.x] = tmp[j][threadIdx.x];
|
||||
}
|
||||
}
|
||||
|
||||
GGML_UNUSED(nrows_x);
|
||||
}
|
||||
|
||||
static std::pair<dim3, dim3> calc_launch_params(const int ncols_y, const int nrows_x, const int warp_size, const mmvq_parameter_table_id table_id) {
|
||||
|
||||
Reference in New Issue
Block a user