Mirror of https://github.com/ggml-org/llama.cpp.git (synced 2025-10-31 08:51:55 +00:00)
			
		
		
		
	
		
			
				
	
	
		
			179 lines
		
	
	
		
			6.9 KiB
		
	
	
	
		
			Plaintext
		
	
	
	
	
	
			
		
		
	
	
			179 lines
		
	
	
		
			6.9 KiB
		
	
	
	
		
			Plaintext
		
	
	
	
	
	
| #include "getrows.cuh"
 | |
| #include "dequantize.cuh"
 | |
| 
 | |
// Gathers one row of quantized src0 per (i10, i11, i12) index and dequantizes it into dst.
//
// Thread layout, as launched by get_rows_cuda:
//   x : pairs of destination columns (each thread produces two values)
//   y : i10, the position inside the index tensor src1
//   z : i11*ne12 + i12
// The row of src0 that is read is the int32 value stored in src1 at (i10, i11, i12).
// dequantize_kernel yields two values per call; they are written
// (qr == 1 ? 1 : qk/2) columns apart in the destination row.
template<int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
static __global__ void k_get_rows(
            const void * src0, const int32_t * src1, dst_t * dst,
            int64_t ne00, /*int64_t ne01, int64_t ne02, int64_t ne03,*/
            /*int64_t ne10, int64_t ne11,*/ int64_t ne12, /*int64_t ne13,*/
            /*size_t s0,*/ size_t s1, size_t s2, size_t s3,
            /*size_t nb00,*/ size_t nb01, size_t nb02, size_t nb03,
            size_t s10, size_t s11, size_t s12/*, size_t s13*/) {

    // first of the two destination columns handled by this thread
    const int col = 2*(blockIdx.x*blockDim.x + threadIdx.x);
    if (col >= ne00) {
        return;
    }

    const int i10 = blockIdx.y*blockDim.y + threadIdx.y;

    // the z dimension packs (i11, i12) pairs
    const unsigned int z = blockIdx.z*blockDim.z + threadIdx.z;
    const int i11 = z / ne12;
    const int i12 = z % ne12;

    // row of src0 selected by the index tensor
    const int i01 = src1[i10*s10 + i11*s11 + i12*s12];

    const void * src0_row = (const char *)src0 + i01*nb01 + i11*nb02 + i12*nb03;
    dst_t      * dst_row  = dst + i10*s1 + i11*s2 + i12*s3;

    // locate `col` inside the quantized row
    const int block_idx   = col / qk;         // which quant block
    const int in_block    = col % qk;         // column offset inside that block
    const int quant_idx   = in_block / qr;    // quant index passed to the dequantizer
    const int block_start = col - in_block;   // first destination column of the block
    const int second_off  = qr == 1 ? 1 : qk/2;

    dfloat2 v;
    dequantize_kernel(src0_row, block_idx, quant_idx, v);

    dst_t * out = dst_row + block_start + quant_idx;
    out[0]          = v.x;
    out[second_off] = v.y;
}
 | |
| 
 | |
// Gathers one row of an unquantized src0 (element type src0_t) per (i10, i11, i12)
// index and stores it into dst, converting each element to dst_t by assignment.
//
// Thread layout, as launched by get_rows_cuda_float:
//   x : destination column (one element per thread)
//   y : i10, the position inside the index tensor src1
//   z : i11*ne12 + i12
template<typename src0_t, typename dst_t>
static __global__ void k_get_rows_float(
            const src0_t * src0, const int32_t * src1, dst_t * dst,
            int64_t ne00, /*int64_t ne01, int64_t ne02, int64_t ne03,*/
            /*int64_t ne10, int64_t ne11,*/ int64_t ne12, /*int64_t ne13,*/
            /*size_t s0,*/ size_t s1, size_t s2, size_t s3,
            /*size_t nb00,*/ size_t nb01, size_t nb02, size_t nb03,
            size_t s10, size_t s11, size_t s12/*, size_t s13*/) {

    const int col = blockIdx.x*blockDim.x + threadIdx.x;
    if (col >= ne00) {
        return;
    }

    const int i10 = blockIdx.y*blockDim.y + threadIdx.y;

    // the z dimension packs (i11, i12) pairs
    const unsigned int z = blockIdx.z*blockDim.z + threadIdx.z;
    const int i11 = z / ne12;
    const int i12 = z % ne12;

    // row of src0 selected by the index tensor
    const int i01 = src1[i10*s10 + i11*s11 + i12*s12];

    // src0 strides are in bytes, so step through a char pointer
    const char   * src0_bytes = (const char *)src0 + i01*nb01 + i11*nb02 + i12*nb03;
    const src0_t * src0_row   = (const src0_t *) src0_bytes;
    dst_t        * dst_row    = dst + i10*s1 + i11*s2 + i12*s3;

    dst_row[col] = src0_row[col];
}
 | |
| 
 | |
// Host launcher for k_get_rows with quantization parameters (qk, qr, dq).
// Gathers the rows of src0 selected by the I32 indices in src1 and writes the
// dequantized values into dst (F32), asynchronously on `stream`.
//
// Grid layout: x covers ne00/2 threads (each produces two values, hence ne00
// must be even), y enumerates i10, and z enumerates (i11, i12) pairs.
// CUDA caps gridDim.y and gridDim.z at 65535. The work is therefore split into
// chunks along i10 and i11, and the src0/src1/dst base pointers are advanced for
// each chunk, so the kernel only sees chunk-local indices. Empty tensors are
// skipped because a zero-sized grid is an invalid launch configuration.
template<int qk, int qr, dequantize_kernel_t dq>
static void get_rows_cuda(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
                            const void * src0_dd, const int32_t * src1_dd, float * dst_dd, cudaStream_t stream) {

    GGML_TENSOR_BINARY_OP_LOCALS

    // each thread dequantizes a pair of values
    GGML_ASSERT(ne00 % 2 == 0);

    if (ne00 == 0 || ne10 == 0 || ne11 == 0 || ne12 == 0) {
        return; // nothing to write
    }

    // maximum gridDim.y / gridDim.z accepted by CUDA
    constexpr int64_t max_grid_yz = 65535;
    // z holds all ne12 values of i12 for every i11 in a chunk, so ne12 itself must fit
    GGML_ASSERT(ne12 <= max_grid_yz);

    const dim3 block_dims(CUDA_GET_ROWS_BLOCK_SIZE, 1, 1);
    const int block_num_x = (ne00 + 2*CUDA_GET_ROWS_BLOCK_SIZE - 1) / (2*CUDA_GET_ROWS_BLOCK_SIZE);

    // strides in elements
    //const size_t s0 = nb0 / ggml_element_size(dst);
    const size_t s1 = nb1 / ggml_element_size(dst);
    const size_t s2 = nb2 / ggml_element_size(dst);
    const size_t s3 = nb3 / ggml_element_size(dst);

    const size_t s10 = nb10 / ggml_element_size(src1);
    const size_t s11 = nb11 / ggml_element_size(src1);
    const size_t s12 = nb12 / ggml_element_size(src1);
    //const size_t s13 = nb13 / ggml_element_size(src1);

    // number of i11 values whose (i11, i12) pairs fit into one gridDim.z
    const int64_t n11_max = max_grid_yz / ne12;

    for (int64_t i11_0 = 0; i11_0 < ne11; i11_0 += n11_max) {
        const int64_t n11 = ne11 - i11_0 < n11_max ? ne11 - i11_0 : n11_max;

        for (int64_t i10_0 = 0; i10_0 < ne10; i10_0 += max_grid_yz) {
            const int64_t n10 = ne10 - i10_0 < max_grid_yz ? ne10 - i10_0 : max_grid_yz;

            // advance the base pointers to the first (i10_0, i11_0) of this chunk;
            // src0 is indexed by i11 (via nb02) but not by i10
            const void    * src0_chunk = (const char *) src0_dd + i11_0*nb02;
            const int32_t * src1_chunk = src1_dd + i10_0*s10 + i11_0*s11;
            float         * dst_chunk  = dst_dd  + i10_0*s1  + i11_0*s2;

            const dim3 block_nums(block_num_x, n10, n11*ne12);
            k_get_rows<qk, qr, dq><<<block_nums, block_dims, 0, stream>>>(
                    src0_chunk, src1_chunk, dst_chunk,
                    ne00, /*ne01, ne02, ne03,*/
                    /*ne10, ne11,*/ ne12, /*ne13,*/
                    /* s0,*/ s1, s2, s3,
                    /* nb00,*/ nb01, nb02, nb03,
                    s10, s11, s12/*, s13*/);
        }
    }
}
 | |
| 
 | |
// Host launcher for k_get_rows_float. Gathers the rows of an unquantized src0
// (element type src0_t) selected by the I32 indices in src1 into dst (F32),
// asynchronously on `stream`.
//
// Grid layout: x covers the ne00 columns of a row (one element per thread),
// y enumerates i10, and z enumerates (i11, i12) pairs.
// CUDA caps gridDim.y and gridDim.z at 65535. The work is therefore split into
// chunks along i10 and i11, and the src0/src1/dst base pointers are advanced for
// each chunk, so the kernel only sees chunk-local indices. Empty tensors are
// skipped because a zero-sized grid is an invalid launch configuration.
template<typename src0_t>
static void get_rows_cuda_float(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
                                const src0_t * src0_dd, const int32_t * src1_dd, float * dst_dd, cudaStream_t stream) {

    GGML_TENSOR_BINARY_OP_LOCALS

    if (ne00 == 0 || ne10 == 0 || ne11 == 0 || ne12 == 0) {
        return; // nothing to write
    }

    // maximum gridDim.y / gridDim.z accepted by CUDA
    constexpr int64_t max_grid_yz = 65535;
    // z holds all ne12 values of i12 for every i11 in a chunk, so ne12 itself must fit
    GGML_ASSERT(ne12 <= max_grid_yz);

    const dim3 block_dims(CUDA_GET_ROWS_BLOCK_SIZE, 1, 1);
    const int block_num_x = (ne00 + CUDA_GET_ROWS_BLOCK_SIZE - 1) / CUDA_GET_ROWS_BLOCK_SIZE;

    // strides in elements
    //const size_t s0 = nb0 / ggml_element_size(dst);
    const size_t s1 = nb1 / ggml_element_size(dst);
    const size_t s2 = nb2 / ggml_element_size(dst);
    const size_t s3 = nb3 / ggml_element_size(dst);

    const size_t s10 = nb10 / ggml_element_size(src1);
    const size_t s11 = nb11 / ggml_element_size(src1);
    const size_t s12 = nb12 / ggml_element_size(src1);
    //const size_t s13 = nb13 / ggml_element_size(src1);

    // number of i11 values whose (i11, i12) pairs fit into one gridDim.z
    const int64_t n11_max = max_grid_yz / ne12;

    for (int64_t i11_0 = 0; i11_0 < ne11; i11_0 += n11_max) {
        const int64_t n11 = ne11 - i11_0 < n11_max ? ne11 - i11_0 : n11_max;

        for (int64_t i10_0 = 0; i10_0 < ne10; i10_0 += max_grid_yz) {
            const int64_t n10 = ne10 - i10_0 < max_grid_yz ? ne10 - i10_0 : max_grid_yz;

            // advance the base pointers to the first (i10_0, i11_0) of this chunk;
            // src0 is indexed by i11 (via nb02, in bytes) but not by i10
            const src0_t  * src0_chunk = (const src0_t *) ((const char *) src0_dd + i11_0*nb02);
            const int32_t * src1_chunk = src1_dd + i10_0*s10 + i11_0*s11;
            float         * dst_chunk  = dst_dd  + i10_0*s1  + i11_0*s2;

            const dim3 block_nums(block_num_x, n10, n11*ne12);
            k_get_rows_float<<<block_nums, block_dims, 0, stream>>>(
                    src0_chunk, src1_chunk, dst_chunk,
                    ne00, /*ne01, ne02, ne03,*/
                    /*ne10, ne11,*/ ne12, /*ne13,*/
                    /* s0,*/ s1, s2, s3,
                    /* nb00,*/ nb01, nb02, nb03,
                    s10, s11, s12/*, s13*/);
        }
    }
}
 | |
| 
 | |
// GET_ROWS on the CUDA backend: dst (F32) receives the rows of dst->src[0]
// selected by the I32 indices stored in dst->src[1].
// F16/F32 sources are copied with conversion; Q4_0, Q4_1, Q5_0, Q5_1 and Q8_0
// sources are dequantized. Any other source type aborts.
void ggml_cuda_op_get_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0]; // data whose rows are gathered
    const ggml_tensor * src1 = dst->src[1]; // row indices

    // raw device pointers; src0 is reinterpreted per type in the dispatch below
    const void    * src0_d = src0->data;
    const int32_t * src1_d = (const int32_t *) src1->data;
    float         * dst_d  = (float *) dst->data;
    cudaStream_t    stream = ctx.stream();

    GGML_ASSERT(src1->type == GGML_TYPE_I32);
    GGML_ASSERT(dst->type  == GGML_TYPE_F32);

    // the kernels assume the innermost dimension of every tensor is contiguous
    GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));
    GGML_ASSERT(src1->nb[0] == ggml_type_size(src1->type));
    GGML_ASSERT(dst->nb[0]  == ggml_type_size(dst->type));

    switch (src0->type) {
        // unquantized sources: element-wise copy with conversion to float
        case GGML_TYPE_F16:
            get_rows_cuda_float(src0, src1, dst, (const half *) src0_d, src1_d, dst_d, stream);
            break;
        case GGML_TYPE_F32:
            get_rows_cuda_float(src0, src1, dst, (const float *) src0_d, src1_d, dst_d, stream);
            break;

        // quantized sources: dequantize while gathering
        case GGML_TYPE_Q4_0:
            get_rows_cuda<QK4_0, QR4_0, dequantize_q4_0>(src0, src1, dst, src0_d, src1_d, dst_d, stream);
            break;
        case GGML_TYPE_Q4_1:
            get_rows_cuda<QK4_1, QR4_1, dequantize_q4_1>(src0, src1, dst, src0_d, src1_d, dst_d, stream);
            break;
        case GGML_TYPE_Q5_0:
            get_rows_cuda<QK5_0, QR5_0, dequantize_q5_0>(src0, src1, dst, src0_d, src1_d, dst_d, stream);
            break;
        case GGML_TYPE_Q5_1:
            get_rows_cuda<QK5_1, QR5_1, dequantize_q5_1>(src0, src1, dst, src0_d, src1_d, dst_d, stream);
            break;
        case GGML_TYPE_Q8_0:
            get_rows_cuda<QK8_0, QR8_0, dequantize_q8_0>(src0, src1, dst, src0_d, src1_d, dst_d, stream);
            break;

        default:
            // TODO: k-quants
            fprintf(stderr, "%s: unsupported type: %s\n", __func__, ggml_type_name(src0->type));
            GGML_ASSERT(false);
            break;
    }
}
 | 
