fix MUSA compiler warning (#12704)

* fix MUSA compiler warning

* replace (void) with GGML_UNUSED
This commit is contained in:
a3sh
2025-04-03 15:32:55 +08:00
committed by GitHub
parent 65cfe136a0
commit 193c3e03a6
2 changed files with 44 additions and 49 deletions

View File

@@ -1,10 +1,5 @@
#include "ssm-scan.cuh"
// #include <cuda_runtime.h>
// static __device__ void global_to_shared(const float *src, float *dst) {
// asm volatile("cp.async.");
// }
template <size_t splitD, size_t N>
__global__ void __launch_bounds__(splitD, 2)
ssm_scan_f32(const float * __restrict__ src0, const float * __restrict__ src1, const float * __restrict__ src2,
@@ -12,7 +7,9 @@ __global__ void __launch_bounds__(splitD, 2)
const int src0_nb1, const int src0_nb2, const int src1_nb0, const int src1_nb1, const int src1_nb2,
const int src1_nb3, const int src2_nb0, const int src2_nb1, const int src2_nb2, const int src3_nb1,
const int src4_nb1, const int src4_nb2, const int src5_nb1, const int src5_nb2,
float * __restrict__ dst, const int D, const int L, const int B) {
float * __restrict__ dst, const int64_t L) {
GGML_UNUSED(src1_nb0);
GGML_UNUSED(src2_nb0);
const int bidx = blockIdx.x; // split along B
const int bidy = blockIdx.y; // split along D
const int tid = threadIdx.x;
@@ -25,12 +22,12 @@ __global__ void __launch_bounds__(splitD, 2)
float * smem_A = smem;
float * smem_s0 = smem_A + splitD * stride_sA;
const float * s0_block = (const float *) ((char *) src0 + bidx * src0_nb2 + bidy * splitD * src0_nb1);
const float * x_block = (const float *) ((char *) src1 + (bidx * src1_nb2) + bidy * splitD * sizeof(float));
const float * dt_block = (const float *) ((char *) src2 + (bidx * src2_nb2) + bidy * splitD * sizeof(float));
const float * A_block = (const float *) ((char *) src3 + bidy * splitD * src3_nb1);
const float * B_block = (const float *) ((char *) src4 + (bidx * src4_nb2));
const float * C_block = (const float *) ((char *) src5 + (bidx * src5_nb2));
const float * s0_block = (const float *) ((const char *) src0 + bidx * src0_nb2 + bidy * splitD * src0_nb1);
const float * x_block = (const float *) ((const char *) src1 + (bidx * src1_nb2) + bidy * splitD * sizeof(float));
const float * dt_block = (const float *) ((const char *) src2 + (bidx * src2_nb2) + bidy * splitD * sizeof(float));
const float * A_block = (const float *) ((const char *) src3 + bidy * splitD * src3_nb1);
const float * B_block = (const float *) ((const char *) src4 + (bidx * src4_nb2));
const float * C_block = (const float *) ((const char *) src5 + (bidx * src5_nb2));
float * y_block = (float *) ((char *) dst + (bidx * src1_nb2) + bidy * splitD * sizeof(float));
float * s_block = (float *) ((char *) dst + src1_nb3 + bidx * src0_nb2 + bidy * splitD * src0_nb1);
@@ -46,7 +43,7 @@ __global__ void __launch_bounds__(splitD, 2)
// can N not be 16? for example 32?
if (N == 16) {
#pragma unroll
for (int i = 0; i < splitD / 4; i += 2) {
for (size_t i = 0; i < splitD / 4; i += 2) {
float value = A_block[(wid * warpSize + i) * stride_A + wtid];
// todo: bank conflict
// I am always confused with how to use the swizzling method to solve
@@ -54,7 +51,7 @@ __global__ void __launch_bounds__(splitD, 2)
smem_A[(wid * warpSize + i) * stride_sA + wtid + ((wtid / 16) > 0 ? 1 : 0)] = value;
}
#pragma unroll
for (int i = 0; i < splitD / 4; i += 2) {
for (size_t i = 0; i < splitD / 4; i += 2) {
float value = s0_block[(wid * warpSize + i) * stride_s0 + wtid];
smem_s0[(wid * warpSize + i) * stride_ss0 + wtid + ((wtid / 16) > 0 ? 1 : 0)] = value;
}
@@ -62,7 +59,7 @@ __global__ void __launch_bounds__(splitD, 2)
__syncthreads();
for (int i = 0; i < L; i++) {
for (int64_t i = 0; i < L; i++) {
float dt_soft_plus = dt_block[i * stride_dt + tid];
if (dt_soft_plus <= 20.0f) {
dt_soft_plus = log1pf(exp(dt_soft_plus));
@@ -70,7 +67,7 @@ __global__ void __launch_bounds__(splitD, 2)
float x_dt = x_block[i * stride_x + tid] * dt_soft_plus;
float sumf = 0.0f;
#pragma unroll
for (int j = 0; j < N; j++) {
for (size_t j = 0; j < N; j++) {
float state = (smem_s0[tid * stride_ss0 + j] * expf(dt_soft_plus * smem_A[tid * stride_sA + j])) +
(B_block[i * stride_B + j] * x_dt);
sumf += state * C_block[i * stride_C + j];
@@ -90,7 +87,8 @@ static void ssm_scan_f32_cuda(const float * src0, const float * src1, const floa
const int src1_nb0, const int src1_nb1, const int src1_nb2, const int src1_nb3,
const int src2_nb0, const int src2_nb1, const int src2_nb2, const int src3_nb1,
const int src4_nb1, const int src4_nb2, const int src5_nb1, const int src5_nb2,
float * dst, const int N, const int D, const int L, const int B, cudaStream_t stream) {
float * dst, const int64_t N, const int64_t D, const int64_t L, const int64_t B,
cudaStream_t stream) {
const int threads = 128;
// todo: consider D cannot be divided,does this situation exist?
GGML_ASSERT(D % threads == 0);
@@ -99,7 +97,7 @@ static void ssm_scan_f32_cuda(const float * src0, const float * src1, const floa
if (N == 16) {
ssm_scan_f32<128, 16><<<blocks, threads, smem_size, stream>>>(
src0, src1, src2, src3, src4, src5, src0_nb1, src0_nb2, src1_nb0, src1_nb1, src1_nb2, src1_nb3, src2_nb0,
src2_nb1, src2_nb2, src3_nb1, src4_nb1, src4_nb2, src5_nb1, src5_nb2, dst, D, L, B);
src2_nb1, src2_nb2, src3_nb1, src4_nb1, src4_nb2, src5_nb1, src5_nb2, dst, L);
} else {
GGML_ABORT("doesn't support N!=16.");
}