mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-11-03 09:22:01 +00:00 
			
		
		
		
	cuda: use simpler loop in get_rows
This commit is contained in:
		@@ -11,19 +11,14 @@ static __global__ void k_get_rows(
 | 
				
			|||||||
        /*const int64_t ne10, const int64_t ne11,*/ const int64_t ne12, /*const int64_t ne13,*/
 | 
					        /*const int64_t ne10, const int64_t ne11,*/ const int64_t ne12, /*const int64_t ne13,*/
 | 
				
			||||||
        /*const size_t s0,*/ const size_t s1, const size_t s2, const size_t s3,
 | 
					        /*const size_t s0,*/ const size_t s1, const size_t s2, const size_t s3,
 | 
				
			||||||
        /*const size_t nb00,*/ const size_t nb01, const size_t nb02, const size_t nb03,
 | 
					        /*const size_t nb00,*/ const size_t nb01, const size_t nb02, const size_t nb03,
 | 
				
			||||||
        const size_t s10, const size_t s11, const size_t s12/*, const size_t s13*/, const size_t block_num_y) {
 | 
					        const size_t s10, const size_t s11, const size_t s12/*, const size_t s13*/) {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    for (int64_t iy = blockIdx.y; iy < block_num_y; iy+=MAX_GRIDDIM_Y) {
 | 
					    for (int64_t i00 = 2*(blockIdx.y*blockDim.x + threadIdx.x); i00 < ne00; i00 += gridDim.y*blockDim.x) {
 | 
				
			||||||
        // The x and y dimensions of the grid are swapped because the maximum allowed grid size for x is higher.
 | 
					        // The x and y dimensions of the grid are swapped because the maximum allowed grid size for x is higher.
 | 
				
			||||||
        const int i00 = (iy * blockDim.x + threadIdx.x)*2;
 | 
					 | 
				
			||||||
        const int i10 =  blockIdx.x;
 | 
					        const int i10 =  blockIdx.x;
 | 
				
			||||||
        const int i11 =  blockIdx.z / ne12;
 | 
					        const int i11 =  blockIdx.z / ne12;
 | 
				
			||||||
        const int i12 =  blockIdx.z % ne12;
 | 
					        const int i12 =  blockIdx.z % ne12;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        if (i00 >= ne00) {
 | 
					 | 
				
			||||||
            return;
 | 
					 | 
				
			||||||
        }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        const int i01 = src1[i10*s10 + i11*s11 + i12*s12];
 | 
					        const int i01 = src1[i10*s10 + i11*s11 + i12*s12];
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3;
 | 
					        dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3;
 | 
				
			||||||
@@ -35,7 +30,7 @@ static __global__ void k_get_rows(
 | 
				
			|||||||
        const int y_offset = qr == 1 ? 1 : qk/2;
 | 
					        const int y_offset = qr == 1 ? 1 : qk/2;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        // dequantize
 | 
					        // dequantize
 | 
				
			||||||
        float2 v;
 | 
					        dfloat2 v;
 | 
				
			||||||
        dequantize_kernel(src0_row, ib, iqs, v);
 | 
					        dequantize_kernel(src0_row, ib, iqs, v);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        dst_row[iybs + iqs + 0]        = ggml_cuda_cast<dst_t>(v.x);
 | 
					        dst_row[iybs + iqs + 0]        = ggml_cuda_cast<dst_t>(v.x);
 | 
				
			||||||
@@ -50,12 +45,10 @@ static __global__ void k_get_rows_float(
 | 
				
			|||||||
        /*const int64_t ne10, const int64_t ne11,*/ const int64_t ne12, /*const int64_t ne13,*/
 | 
					        /*const int64_t ne10, const int64_t ne11,*/ const int64_t ne12, /*const int64_t ne13,*/
 | 
				
			||||||
        /*const size_t s0,*/ const size_t s1, const size_t s2, const size_t s3,
 | 
					        /*const size_t s0,*/ const size_t s1, const size_t s2, const size_t s3,
 | 
				
			||||||
        /*const size_t nb00,*/ const size_t nb01, const size_t nb02, const size_t nb03,
 | 
					        /*const size_t nb00,*/ const size_t nb01, const size_t nb02, const size_t nb03,
 | 
				
			||||||
        const size_t s10, const size_t s11, const size_t s12/*, const size_t s13*/, const size_t block_num_y) {
 | 
					        const size_t s10, const size_t s11, const size_t s12/*, const size_t s13*/) {
 | 
				
			||||||
 | 
					 | 
				
			||||||
    for (int64_t iy = blockIdx.y; iy < block_num_y; iy+=MAX_GRIDDIM_Y) {
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    for (int64_t i00 = blockIdx.y*blockDim.x + threadIdx.x; i00 < ne00; i00 += gridDim.y*blockDim.x) {
 | 
				
			||||||
        // The x and y dimensions of the grid are swapped because the maximum allowed grid size for x is higher.
 | 
					        // The x and y dimensions of the grid are swapped because the maximum allowed grid size for x is higher.
 | 
				
			||||||
        const int i00 = iy * blockDim.x + threadIdx.x;
 | 
					 | 
				
			||||||
        const int i10 = blockIdx.x;
 | 
					        const int i10 = blockIdx.x;
 | 
				
			||||||
        const int i11 = blockIdx.z / ne12;
 | 
					        const int i11 = blockIdx.z / ne12;
 | 
				
			||||||
        const int i12 = blockIdx.z % ne12;
 | 
					        const int i12 = blockIdx.z % ne12;
 | 
				
			||||||
@@ -126,7 +119,7 @@ static void get_rows_cuda_q(
 | 
				
			|||||||
        /*ne10, ne11,*/ ne12, /*ne13,*/
 | 
					        /*ne10, ne11,*/ ne12, /*ne13,*/
 | 
				
			||||||
        /* s0,*/ s1, s2, s3,
 | 
					        /* s0,*/ s1, s2, s3,
 | 
				
			||||||
        /* nb00,*/ nb01, nb02, nb03,
 | 
					        /* nb00,*/ nb01, nb02, nb03,
 | 
				
			||||||
        s10, s11, s12/*, s13*/, block_num_y);
 | 
					        s10, s11, s12/*, s13*/);
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
template<typename src0_t, typename dst_t>
 | 
					template<typename src0_t, typename dst_t>
 | 
				
			||||||
@@ -157,7 +150,7 @@ static void get_rows_cuda_float(
 | 
				
			|||||||
        /*ne10, ne11,*/ ne12, /*ne13,*/
 | 
					        /*ne10, ne11,*/ ne12, /*ne13,*/
 | 
				
			||||||
        /* s0,*/ s1, s2, s3,
 | 
					        /* s0,*/ s1, s2, s3,
 | 
				
			||||||
        /* nb00,*/ nb01, nb02, nb03,
 | 
					        /* nb00,*/ nb01, nb02, nb03,
 | 
				
			||||||
        s10, s11, s12/*, s13*/, block_num_y);
 | 
					        s10, s11, s12/*, s13*/);
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
template <typename dst_t>
 | 
					template <typename dst_t>
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user