Fix more int overflow during quant (PPL/CUDA). (#6563)
* Fix more int overflow during quant.
* Fix some more int overflow in softmax.
* Revert back to int64_t.
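The overflow happens because an index like rowx*ncols + col is evaluated in 32-bit int arithmetic, which wraps once the row offset exceeds INT_MAX (2,147,483,647) elements; casting one operand to int64_t promotes the whole expression to 64-bit before the multiply. A minimal host-side sketch of the effect (the concrete sizes are hypothetical and the program is not part of the repository):

// Hedged illustration (not from the commit): how the 32-bit index wraps
// and how the (int64_t) cast avoids it.
#include <cstdint>
#include <cstdio>

int main() {
    // Hypothetical sizes; any row offset beyond 2^31-1 elements triggers the bug.
    int rowx  = 70000;
    int ncols = 32768;   // rowx*ncols = 2,293,760,000 > INT_MAX (2,147,483,647)
    int col   = 0;

    // Old pattern: the multiply is done in 32-bit int and overflows
    // (formally undefined behaviour; in practice it wraps to a negative value).
    int ix32 = rowx*ncols + col;

    // New pattern: casting one operand promotes the expression to int64_t
    // before the multiply, so the index stays valid.
    int64_t ix64 = (int64_t)rowx*ncols + col;

    printf("32-bit index: %d\n",   ix32);             // wrapped (negative)
    printf("64-bit index: %lld\n", (long long)ix64);  // 2293760000
    return 0;
}

The same cast-before-multiply pattern is applied to ix, iy, idst, and the vals pointer offset in the hunks below.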
@@ -28,7 +28,7 @@ static __global__ void soft_max_f32(const float * x, const float * mask, const f
     extern __shared__ float data_soft_max_f32[];
     float * buf_iw = data_soft_max_f32; // shared memory buffer for inter-warp communication
     // shared memory buffer to cache values between iterations:
-    float * vals = vals_smem ? buf_iw + WARP_SIZE : dst + rowx*ncols;
+    float * vals = vals_smem ? buf_iw + WARP_SIZE : dst + (int64_t)rowx*ncols;
 
     float max_val = -INFINITY;
 
@@ -40,8 +40,8 @@ static __global__ void soft_max_f32(const float * x, const float * mask, const f
             break;
         }
 
-        const int ix = rowx*ncols + col;
-        const int iy = rowy*ncols + col;
+        const int64_t ix = (int64_t)rowx*ncols + col;
+        const int64_t iy = (int64_t)rowy*ncols + col;
 
         const float val = x[ix]*scale + (mask ? mask[iy] : 0.0f) + (pos ? slope*pos[col] : 0.0f);
 
@@ -109,7 +109,7 @@ static __global__ void soft_max_f32(const float * x, const float * mask, const f
             return;
         }
 
-        const int idst = rowx*ncols + col;
+        const int64_t idst = (int64_t)rowx*ncols + col;
         dst[idst] = vals[col] * inv_sum;
     }
 }