mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-11-04 09:32:00 +00:00 
			
		
		
		
	cuda : deduplicated dequantization code (#1453)
This commit is contained in:
		
							
								
								
									
										154
									
								
								ggml-cuda.cu
									
									
									
									
									
								
							
							
						
						
									
										154
									
								
								ggml-cuda.cu
									
									
									
									
									
								
							@@ -83,7 +83,8 @@ typedef struct {
 | 
			
		||||
} block_q8_0;
 | 
			
		||||
static_assert(sizeof(block_q8_0) == sizeof(float) + QK8_0, "wrong q8_0 block size/padding");
 | 
			
		||||
 | 
			
		||||
#define CUDA_DMMV_BLOCK_SIZE 32
 | 
			
		||||
#define CUDA_DEQUANTIZE_BLOCK_SIZE 256
 | 
			
		||||
#define CUDA_DMMV_BLOCK_SIZE 32 // dmmv = dequantize_mul_mat_vec
 | 
			
		||||
 | 
			
		||||
static __device__ void dequantize_q4_0(const void * vx, const int ib, const int iqs, float & v0, float & v1){
 | 
			
		||||
    const block_q4_0 * x = (const block_q4_0 *) vx;
 | 
			
		||||
@@ -170,104 +171,23 @@ static __device__ void convert_f16(const void * vx, const int ib, const int iqs,
 | 
			
		||||
    v1 = __half2float(x[ib + 1]);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
static __global__ void dequantize_block_q4_0(const void * vx, float * y) {
 | 
			
		||||
    static const int qk = QK4_0;
 | 
			
		||||
template <int qk, int qr, dequantize_kernel_t dequantize_kernel>
 | 
			
		||||
static __global__ void dequantize_block(const void * vx, float * y, const int k) {
 | 
			
		||||
    const int i = blockDim.x*blockIdx.x + 2*threadIdx.x;
 | 
			
		||||
 | 
			
		||||
    const block_q4_0 * x = (const block_q4_0 *) vx;
 | 
			
		||||
 | 
			
		||||
    const int i = blockIdx.x;
 | 
			
		||||
 | 
			
		||||
    const float d = x[i].d;
 | 
			
		||||
 | 
			
		||||
    for (int j = 0; j < qk/2; ++j) {
 | 
			
		||||
        const int x0 = (x[i].qs[j] & 0xf) - 8;
 | 
			
		||||
        const int x1 = (x[i].qs[j] >>  4) - 8;
 | 
			
		||||
 | 
			
		||||
        y[i*qk + j + 0   ] = x0*d;
 | 
			
		||||
        y[i*qk + j + qk/2] = x1*d;
 | 
			
		||||
    if (i >= k) {
 | 
			
		||||
        return;
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
static __global__ void dequantize_block_q4_1(const void * vx, float * y) {
 | 
			
		||||
    static const int qk = QK4_1;
 | 
			
		||||
    const int ib = i/qk; // block index
 | 
			
		||||
    const int iqs = (i%qk)/qr; // quant index
 | 
			
		||||
    const int iybs = i - i%qk; // y block start index
 | 
			
		||||
    const int y_offset = qr == 1 ? 1 : qk/2;
 | 
			
		||||
 | 
			
		||||
    const block_q4_1 * x = (const block_q4_1 *) vx;
 | 
			
		||||
 | 
			
		||||
    const int i = blockIdx.x;
 | 
			
		||||
 | 
			
		||||
    const float d = x[i].d;
 | 
			
		||||
    const float m = x[i].m;
 | 
			
		||||
 | 
			
		||||
    for (int j = 0; j < qk/2; ++j) {
 | 
			
		||||
        const int x0 = (x[i].qs[j] & 0xf);
 | 
			
		||||
        const int x1 = (x[i].qs[j] >>  4);
 | 
			
		||||
 | 
			
		||||
        y[i*qk + j + 0   ] = x0*d + m;
 | 
			
		||||
        y[i*qk + j + qk/2] = x1*d + m;
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
static __global__ void dequantize_block_q5_0(const void * vx, float * y) {
 | 
			
		||||
    static const int qk = QK5_0;
 | 
			
		||||
 | 
			
		||||
    const block_q5_0 * x = (const block_q5_0 *) vx;
 | 
			
		||||
 | 
			
		||||
    const int i = blockIdx.x;
 | 
			
		||||
 | 
			
		||||
    const float d = x[i].d;
 | 
			
		||||
 | 
			
		||||
    uint32_t qh;
 | 
			
		||||
    memcpy(&qh, x[i].qh, sizeof(qh));
 | 
			
		||||
 | 
			
		||||
    for (int j = 0; j < qk/2; ++j) {
 | 
			
		||||
        const uint8_t xh_0 = ((qh >> (j +  0)) << 4) & 0x10;
 | 
			
		||||
        const uint8_t xh_1 = ((qh >> (j + 12))     ) & 0x10;
 | 
			
		||||
 | 
			
		||||
        const int32_t x0 = ((x[i].qs[j] & 0xf) | xh_0) - 16;
 | 
			
		||||
        const int32_t x1 = ((x[i].qs[j] >>  4) | xh_1) - 16;
 | 
			
		||||
 | 
			
		||||
        y[i*qk + j + 0   ] = x0*d;
 | 
			
		||||
        y[i*qk + j + qk/2] = x1*d;
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
static __global__ void dequantize_block_q5_1(const void * vx, float * y) {
 | 
			
		||||
    static const int qk = QK5_1;
 | 
			
		||||
 | 
			
		||||
    const block_q5_1 * x = (const block_q5_1 *) vx;
 | 
			
		||||
 | 
			
		||||
    const int i = blockIdx.x;
 | 
			
		||||
 | 
			
		||||
    const float d = x[i].d;
 | 
			
		||||
    const float m = x[i].m;
 | 
			
		||||
 | 
			
		||||
    uint32_t qh;
 | 
			
		||||
    memcpy(&qh, x[i].qh, sizeof(qh));
 | 
			
		||||
 | 
			
		||||
    for (int j = 0; j < qk/2; ++j) {
 | 
			
		||||
        const uint8_t xh_0 = ((qh >> (j +  0)) << 4) & 0x10;
 | 
			
		||||
        const uint8_t xh_1 = ((qh >> (j + 12))     ) & 0x10;
 | 
			
		||||
 | 
			
		||||
        const int x0 = (x[i].qs[j] & 0xf) | xh_0;
 | 
			
		||||
        const int x1 = (x[i].qs[j] >>  4) | xh_1;
 | 
			
		||||
 | 
			
		||||
        y[i*qk + j + 0   ] = x0*d + m;
 | 
			
		||||
        y[i*qk + j + qk/2] = x1*d + m;
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
static __global__ void dequantize_block_q8_0(const void * vx, float * y) {
 | 
			
		||||
    static const int qk = QK8_0;
 | 
			
		||||
 | 
			
		||||
    const block_q8_0 * x = (const block_q8_0 *) vx;
 | 
			
		||||
 | 
			
		||||
    const int i = blockIdx.x;
 | 
			
		||||
 | 
			
		||||
    const float d = x[i].d;
 | 
			
		||||
 | 
			
		||||
    for (int j = 0; j < qk; ++j) {
 | 
			
		||||
        y[i*qk + j] = x[i].qs[j]*d;
 | 
			
		||||
    }
 | 
			
		||||
    // dequantize
 | 
			
		||||
    float & v0 = y[iybs + iqs + 0];
 | 
			
		||||
    float & v1 = y[iybs + iqs + y_offset];
 | 
			
		||||
    dequantize_kernel(vx, ib, iqs, v0, v1);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <int block_size, int qk, int qr, dequantize_kernel_t dequantize_kernel>
 | 
			
		||||
@@ -308,29 +228,29 @@ static __global__ void dequantize_mul_mat_vec(const void * vx, const float * y,
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
static void dequantize_row_q4_0_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
 | 
			
		||||
    const int nb = k / QK4_0;
 | 
			
		||||
    dequantize_block_q4_0<<<nb, 1, 0, stream>>>(vx, y);
 | 
			
		||||
static void dequantize_row_q4_0_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
 | 
			
		||||
    const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
 | 
			
		||||
    dequantize_block<QK4_0, QR4_0, dequantize_q4_0><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
static void dequantize_row_q4_1_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
 | 
			
		||||
    const int nb = k / QK4_1;
 | 
			
		||||
    dequantize_block_q4_1<<<nb, 1, 0, stream>>>(vx, y);
 | 
			
		||||
static void dequantize_row_q4_1_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
 | 
			
		||||
    const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
 | 
			
		||||
    dequantize_block<QK4_1, QR4_1, dequantize_q4_1><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
static void dequantize_row_q5_0_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
 | 
			
		||||
    const int nb = k / QK5_0;
 | 
			
		||||
    dequantize_block_q5_0<<<nb, 1, 0, stream>>>(vx, y);
 | 
			
		||||
static void dequantize_row_q5_0_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
 | 
			
		||||
    const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
 | 
			
		||||
    dequantize_block<QK5_0, QR5_0, dequantize_q5_0><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
static void dequantize_row_q5_1_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
 | 
			
		||||
    const int nb = k / QK5_1;
 | 
			
		||||
    dequantize_block_q5_1<<<nb, 1, 0, stream>>>(vx, y);
 | 
			
		||||
static void dequantize_row_q5_1_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
 | 
			
		||||
    const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
 | 
			
		||||
    dequantize_block<QK5_1, QR5_1, dequantize_q5_1><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
static void dequantize_row_q8_0_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
 | 
			
		||||
    const int nb = k / QK8_0;
 | 
			
		||||
    dequantize_block_q8_0<<<nb, 1, 0, stream>>>(vx, y);
 | 
			
		||||
static void dequantize_row_q8_0_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
 | 
			
		||||
    const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
 | 
			
		||||
    dequantize_block<QK8_0, QR8_0, dequantize_q8_0><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
static void dequantize_mul_mat_vec_q4_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
 | 
			
		||||
@@ -363,17 +283,9 @@ static void dequantize_mul_mat_vec_q8_0_cuda(const void * vx, const float * y, f
 | 
			
		||||
        <<<nrows, CUDA_DMMV_BLOCK_SIZE, 0, stream>>>(vx, y, dst, ncols);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// TODO: optimize
 | 
			
		||||
static __global__ void convert_fp16_to_fp32(const void * vx, float * y) {
 | 
			
		||||
    const half * x = (const half *) vx;
 | 
			
		||||
 | 
			
		||||
    const int i = blockIdx.x;
 | 
			
		||||
 | 
			
		||||
    y[i] = __half2float(x[i]);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
static void convert_fp16_to_fp32_cuda(const void * x, float * y, int k, cudaStream_t stream) {
 | 
			
		||||
    convert_fp16_to_fp32<<<k, 1, 0, stream>>>(x, y);
 | 
			
		||||
static void convert_fp16_to_fp32_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
 | 
			
		||||
    const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
 | 
			
		||||
    dequantize_block<32, 1, convert_f16><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
static void convert_mul_mat_vec_f16_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user