Mirror of https://github.com/ggml-org/llama.cpp.git
	ggml : sync (ggml_conv_2d, fix mul_mat bug, CUDA GLM rope)
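This sync brings in three changes from ggml:

- ggml.c: ggml_conv_2d is generalized: ggml_compute_forward_conv_2d_sk_p0 becomes ggml_compute_forward_conv_2d and reads stride (s0, s1), padding (p0, p1) and dilation (d0, d1) from an opt0 parameter tensor, lifting the old restriction that the stride equal the kernel size with no padding or dilation.
- ggml.c: a mul_mat indexing bug is fixed: when src1 is non-contiguous and already has the vec_dot type, a column of src1 must be located through its byte strides nb11/nb12/nb13 rather than a dense row offset.
- ggml-cuda.cu: a rope_glm_f32 kernel, its launcher and a dispatch path are added, implementing the two-part (in-block position plus block position) rotary embedding used by GLM-style models (mode & 4).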
ggml-cuda.cu (54 changed lines):
@@ -1667,6 +1667,40 @@ static __global__ void rope_f32(const float * x, float * dst, const int ncols, c
     dst[i + 1] = x0*sin_theta + x1*cos_theta;
 }
 
+static __global__ void rope_glm_f32(const float * x, float * dst, const int ncols, const float p, const float block_p, const float theta_scale) {
+    const int col = blockDim.x*blockIdx.x + threadIdx.x;
+    const int half_n_dims = ncols/4;
+
+    if (col >= half_n_dims) {
+        return;
+    }
+
+    const int row = blockDim.y*blockIdx.y + threadIdx.y;
+    const int i = row*ncols + col;
+
+    const float col_theta_scale = powf(theta_scale, col);
+
+    const float theta = p*col_theta_scale;
+    const float sin_theta = sinf(theta);
+    const float cos_theta = cosf(theta);
+
+    const float x0 = x[i + 0];
+    const float x1 = x[i + half_n_dims];
+
+    dst[i + 0]           = x0*cos_theta - x1*sin_theta;
+    dst[i + half_n_dims] = x0*sin_theta + x1*cos_theta;
+
+    const float block_theta = block_p*col_theta_scale;
+    const float sin_block_theta = sinf(block_theta);
+    const float cos_block_theta = cosf(block_theta);
+
+    const float x2 = x[i + half_n_dims * 2];
+    const float x3 = x[i + half_n_dims * 3];
+
+    dst[i + half_n_dims * 2] = x2*cos_block_theta - x3*sin_block_theta;
+    dst[i + half_n_dims * 3] = x2*sin_block_theta + x3*cos_block_theta;
+}
+
 static __global__ void diag_mask_inf_f32(const float * x, float * dst, const int ncols, const int rows_per_channel, const int n_past) {
     const int col = blockDim.x*blockIdx.x + threadIdx.x;
     const int row = blockDim.y*blockIdx.y + threadIdx.y;
@@ -2064,6 +2098,14 @@ static void rope_f32_cuda(const float * x, float * dst, const int ncols, const i
     rope_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, p, theta_scale);
 }
 
+static void rope_glm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float p, const float block_p, const float theta_scale, cudaStream_t stream) {
+    GGML_ASSERT(nrows % 4 == 0);
+    const dim3 block_dims(4*CUDA_ROPE_BLOCK_SIZE, 1, 1);
+    const int num_blocks_x = (ncols + 4*CUDA_ROPE_BLOCK_SIZE - 1) / (4*CUDA_ROPE_BLOCK_SIZE);
+    const dim3 block_nums(num_blocks_x, nrows, 1);
+    rope_glm_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, p, block_p, theta_scale);
+}
+
 static void diag_mask_inf_f32_cuda(const float * x, float * dst, const int ncols_x, const int nrows_x, const int rows_per_channel, const int n_past, cudaStream_t stream) {
     const dim3 block_dims(CUDA_DIAG_MASK_INF_BLOCK_SIZE, 1, 1);
     const int block_num_x = (ncols_x + CUDA_DIAG_MASK_INF_BLOCK_SIZE - 1) / CUDA_DIAG_MASK_INF_BLOCK_SIZE;
@@ -2618,13 +2660,21 @@ inline void ggml_cuda_op_rope(
     const int n_past = ((int32_t *) src1->data)[0];
     const int n_dims = ((int32_t *) src1->data)[1];
     const int mode   = ((int32_t *) src1->data)[2];
-    GGML_ASSERT(mode == 0);
+    const int n_ctx  = ((int32_t *) src1->data)[3];
 
     const float theta_scale = powf(10000.0, -2.0f/n_dims);
     const float p = ((mode & 1) == 0 ? n_past + i02 : i02);
 
+    bool is_glm = mode & 4;
+
     // compute
-    rope_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, p, theta_scale, cudaStream_main);
+    if (is_glm) {
+        const float id_p = min(p, n_ctx - 2.f);
+        const float block_p = max(p - (n_ctx - 2.f), 0.f);
+        rope_glm_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, id_p, block_p, theta_scale, cudaStream_main);
+    } else {
+        rope_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, p, theta_scale, cudaStream_main);
+    }
 
     (void) dst;
     (void) src0_ddq_i;
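For readers who want to check the rotation math without a GPU, the following plain-C reference mirrors what rope_glm_f32 computes per row. It is a sketch for illustration only: glm_rope_ref and the tiny main are made-up names, not part of the commit, and it assumes ncols is divisible by 4, as the kernel's quarter-splitting indexing does.

#include <math.h>
#include <stdio.h>

// CPU reference for rope_glm_f32: each row of ncols floats is split into four
// quarters; the pair (q0, q2) is rotated by the in-block angle p*theta_scale^col,
// the pair from (q2, q3) by the block angle block_p*theta_scale^col.
static void glm_rope_ref(const float * x, float * dst, int ncols, int nrows,
                         float p, float block_p, float theta_scale) {
    const int half = ncols/4;
    for (int row = 0; row < nrows; ++row) {
        for (int col = 0; col < half; ++col) {
            const int   i     = row*ncols + col;
            const float scale = powf(theta_scale, (float) col);

            // first and second quarter: rotated by the in-block position p
            const float th = p*scale;
            const float x0 = x[i], x1 = x[i + half];
            dst[i]        = x0*cosf(th) - x1*sinf(th);
            dst[i + half] = x0*sinf(th) + x1*cosf(th);

            // third and fourth quarter: rotated by the block position block_p
            const float bth = block_p*scale;
            const float x2 = x[i + 2*half], x3 = x[i + 3*half];
            dst[i + 2*half] = x2*cosf(bth) - x3*sinf(bth);
            dst[i + 3*half] = x2*sinf(bth) + x3*cosf(bth);
        }
    }
}

int main(void) {
    // one row of 8 values; block_p = 0 leaves the second half unrotated
    float x[8] = {1, 2, 3, 4, 5, 6, 7, 8};
    float out[8];
    glm_rope_ref(x, out, 8, 1, /*p=*/3.0f, /*block_p=*/0.0f, /*theta_scale=*/0.5f);
    for (int k = 0; k < 8; ++k) printf("%.4f ", out[k]);
    printf("\n");
    return 0;
}

Note that the launcher sizes its grid from ncols rather than ncols/4, so each block of 4*CUDA_ROPE_BLOCK_SIZE threads is overprovisioned and surplus threads simply take the col >= half_n_dims early return.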
							
								
								
									
ggml.c (99 changed lines):
@@ -10684,6 +10684,8 @@ static void ggml_compute_forward_mul_mat(
 
     const enum ggml_type type = src0->type;
 
+    const bool src1_cont = ggml_is_contiguous(src1);
+
     ggml_vec_dot_t    const vec_dot               = type_traits[type].vec_dot;
     enum ggml_type    const vec_dot_type          = type_traits[type].vec_dot_type;
     ggml_from_float_t const from_float_to_vec_dot = type_traits[vec_dot_type].from_float;
@@ -10747,7 +10749,7 @@ static void ggml_compute_forward_mul_mat(
                 float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
 
                 if (type != GGML_TYPE_F32) {
-                    float * const wdata = params->wdata;
+                    float * const wdata    = params->wdata;
                     ggml_to_float_t const to_float = type_traits[type].to_float;
 
                     size_t id = 0;
@@ -10805,7 +10807,7 @@ static void ggml_compute_forward_mul_mat(
     // src1 rows
     const int64_t nr1 = ne11*ne12*ne13;
 
-    void * wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata;
+    const void * wdata    = (src1->type == vec_dot_type) ? src1->data : params->wdata;
     const size_t row_size = ne10*GGML_TYPE_SIZE[vec_dot_type]/GGML_BLCK_SIZE[vec_dot_type];
 
     for (int64_t ir1 = 0; ir1 < nr1; ++ir1) {
@@ -10828,7 +10830,15 @@ static void ggml_compute_forward_mul_mat(
         const int64_t i3 = i13;
 
         const char * src0_row = (const char *) src0->data + (  0 + i02*nb02 + i03*nb03     );
-        const char * src1_col = (const char *)      wdata + (i11 + i12*ne11 + i13*ne12*ne11)*row_size;
+
+        // desc: when src1 is not a contiguous memory block we have to calculate the offset using the strides
+        //       if it is, then we have either copied the data to params->wdata and made it contiguous or we are using
+        //       the original src1 data pointer, so we should index using the indices directly
+        // TODO: this is a bit of a hack, we should probably have a better way to handle this
+        const char * src1_col = (const char *) wdata +
+            (src1_cont || src1->type != vec_dot_type
+             ? (i11      + i12*ne11 + i13*ne12*ne11)*row_size
+             : (i11*nb11 + i12*nb12 + i13*nb13));
 
         float * dst_col = (float *) ((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3));
 
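The crux of the fix is visible in the src1_col expression above: there are two valid addressing schemes for a src1 column, and the old code always used the dense one. A standalone sketch of the decision follows; the helper name src1_col_offset and its flag parameters are illustrative, not ggml API.

#include <stddef.h>
#include <stdint.h>

// Pick the byte offset of src1 row (i11, i12, i13) inside wdata.
// If src1 was contiguous, or it had to be converted into params->wdata
// (src1->type != vec_dot_type), rows are densely packed and a flat row
// index times row_size is correct. Otherwise wdata aliases src1->data
// directly and the tensor's own byte strides nb11/nb12/nb13 apply.
static size_t src1_col_offset(int64_t i11, int64_t i12, int64_t i13,
                              int64_t ne11, int64_t ne12,
                              size_t row_size,
                              size_t nb11, size_t nb12, size_t nb13,
                              int src1_cont, int converted) {
    if (src1_cont || converted) {
        return (size_t)(i11 + i12*ne11 + i13*ne12*ne11)*row_size;
    }
    return (size_t)i11*nb11 + (size_t)i12*nb12 + (size_t)i13*nb13;
}

Before this change the dense row-index formula was applied unconditionally, which mis-addressed rows whenever wdata aliased a non-contiguous src1 (e.g. a transposed view) whose element type already matched vec_dot_type.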
@@ -12982,12 +12992,13 @@ static void ggml_compute_forward_conv_1d(
     };
 }
 
-// ggml_compute_forward_conv_2d_sk_p0
+// ggml_compute_forward_conv_2d
 
-static void ggml_compute_forward_conv_2d_sk_p0_f16_f32(
+static void ggml_compute_forward_conv_2d_f16_f32(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         const struct ggml_tensor * src1,
+        const struct ggml_tensor * opt0,
               struct ggml_tensor * dst) {
     GGML_ASSERT(src0->type == GGML_TYPE_F16);
     GGML_ASSERT(src1->type == GGML_TYPE_F32);
@@ -13007,28 +13018,37 @@ static void ggml_compute_forward_conv_2d_sk_p0_f16_f32(
     // size of the convolution row - the kernel size unrolled across all channels
     const int ew0 = nk0*nk1*ne02;
 
+    const int32_t s0 = ((const int32_t*)(opt0->data))[0];
+    const int32_t s1 = ((const int32_t*)(opt0->data))[1];
+    const int32_t p0 = ((const int32_t*)(opt0->data))[2];
+    const int32_t p1 = ((const int32_t*)(opt0->data))[3];
+    const int32_t d0 = ((const int32_t*)(opt0->data))[4];
+    const int32_t d1 = ((const int32_t*)(opt0->data))[5];
+
     GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
     GGML_ASSERT(nb10 == sizeof(float));
 
     if (params->type == GGML_TASK_INIT) {
         // TODO: fix this memset (wsize is overestimated)
         memset(params->wdata, 0, params->wsize);
 
         // prepare source data (src1)
         {
             ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0;
 
-            for (int i13 = 0; i13 < ne13; i13++) {
-                for (int i12 = 0; i12 < ne12; i12++) {
-                    const float * const src = (float *)((char *) src1->data + i13*nb13 + i12*nb12);
-                    ggml_fp16_t * dst_data = wdata + i13*(ne1*ne0*ew0);
+            for (int i12 = 0; i12 < ne12; i12++) {
+                const float * const src = (float *)((char *) src1->data + i12*nb12);
+                ggml_fp16_t * dst_data = wdata;
 
-                    for (int i1 = 0; i1 < ne1; i1++) {
-                        for (int i0 = 0; i0 < ne0; i0++) {
-                            for (int ik1 = 0; ik1 < nk1; ik1++) {
-                                for (int ik0 = 0; ik0 < nk0; ik0++) {
+                for (int i1 = 0; i1 < ne1; i1++) {
+                    for (int i0 = 0; i0 < ne0; i0++) {
+                        for (int ik1 = 0; ik1 < nk1; ik1++) {
+                            for (int ik0 = 0; ik0 < nk0; ik0++) {
+                                const int idx0 = i0*s0 + ik0*d0 - p0;
+                                const int idx1 = i1*s1 + ik1*d1 - p1;
+
+                                if (!(idx1 < 0 || idx1 >= ne11 || idx0 < 0 || idx0 >= ne10)) {
                                     dst_data[(i1*ne0 + i0)*ew0 + i12*(nk0*nk1) + ik1*nk0 + ik0] =
-                                        GGML_FP32_TO_FP16(src[(i1*nk1 + ik1)*ne10 + (i0*nk0 + ik0)]);
+                                        GGML_FP32_TO_FP16(src[idx1*ne10 + idx0]);
+                                }
                             }
                         }
@@ -13071,19 +13091,21 @@ static void ggml_compute_forward_conv_2d_sk_p0_f16_f32(
     }
 }
 
-static void ggml_compute_forward_conv_2d_sk_p0(
+static void ggml_compute_forward_conv_2d(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         const struct ggml_tensor * src1,
-        struct ggml_tensor * dst) {
+        const struct ggml_tensor * opt0,
+        struct ggml_tensor * dst
+        ) {
     switch (src0->type) {
         case GGML_TYPE_F16:
             {
-                ggml_compute_forward_conv_2d_sk_p0_f16_f32(params, src0, src1, dst);
+                ggml_compute_forward_conv_2d_f16_f32(params, src0, src1, opt0, dst);
             } break;
         case GGML_TYPE_F32:
             {
-                //ggml_compute_forward_conv_2d_sk_p0_f32(params, src0, src1, dst);
+                //ggml_compute_forward_conv_2d_f32(params, src0, src1, opt0, dst);
                 GGML_ASSERT(false);
             } break;
         default:
@@ -13093,32 +13115,6 @@ static void ggml_compute_forward_conv_2d_sk_p0(
     }
 }
 
-// ggml_compute_forward_conv_2d
-
-static void ggml_compute_forward_conv_2d(
-    const struct ggml_compute_params* params,
-    const struct ggml_tensor* src0,
-    const struct ggml_tensor* src1,
-    const struct ggml_tensor* opt0,
-    struct ggml_tensor* dst) {
-    const int32_t s0 = ((const int32_t*)(opt0->data))[0];
-    const int32_t s1 = ((const int32_t*)(opt0->data))[1];
-    const int32_t p0 = ((const int32_t*)(opt0->data))[2];
-    const int32_t p1 = ((const int32_t*)(opt0->data))[3];
-    const int32_t d0 = ((const int32_t*)(opt0->data))[4];
-    const int32_t d1 = ((const int32_t*)(opt0->data))[5];
-    GGML_ASSERT(d0 == 1); // dilation not supported
-    GGML_ASSERT(d1 == 1);
-    GGML_ASSERT(p0 == 0); // padding not supported
-    GGML_ASSERT(p1 == 0);
-
-    if (s0 == src0->ne[0] && s1 == src0->ne[1]) {
-        ggml_compute_forward_conv_2d_sk_p0(params, src0, src1, dst);
-    } else {
-        GGML_ASSERT(false); // only stride equal to kernel size is supported
-    }
-}
-
 // ggml_compute_forward_pool_1d_sk_p0
 
 static void ggml_compute_forward_pool_1d_sk_p0(
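The gather loop above is the whole generalization: instead of assuming each window starts at i0*nk0, every kernel tap is mapped back through stride, dilation and padding, and taps landing in the border are skipped so the zeros written by the memset remain in place. A minimal sketch of just that index mapping (conv2d_src_index is a hypothetical helper, not in the commit):

#include <stdbool.h>

// Map output pixel (i0, i1) and kernel tap (ik0, ik1) to a source pixel,
// honoring stride (s0, s1), padding (p0, p1) and dilation (d0, d1).
// Returns false when the tap lands in the zero-padding border.
static bool conv2d_src_index(int i0, int i1, int ik0, int ik1,
                             int s0, int s1, int p0, int p1, int d0, int d1,
                             int ne10, int ne11, int * idx0, int * idx1) {
    *idx0 = i0*s0 + ik0*d0 - p0; // source column
    *idx1 = i1*s1 + ik1*d1 - p1; // source row
    return 0 <= *idx0 && *idx0 < ne10 && 0 <= *idx1 && *idx1 < ne11;
}

For example, with s0 = 1, p0 = 1, d0 = 1 and a 3-wide kernel, output pixel i0 = 0 reads source columns -1, 0 and 1; the first tap fails the bounds check and contributes zero padding.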
@@ -16575,19 +16571,22 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
                     const int64_t ne11 = node->src[1]->ne[1]; // H
                     const int64_t ne12 = node->src[1]->ne[2]; // C
 
+                    const int64_t ne0 = node->ne[0];
+                    const int64_t ne1 = node->ne[1];
+                    const int64_t ne2 = node->ne[2];
                     const int64_t nk = ne00*ne01;
+                    const int64_t ew0 = nk * ne02;
 
-                    UNUSED(ne02);
                     UNUSED(ne03);
-                    UNUSED(nk);
+                    UNUSED(ne2);
 
                     size_t cur = 0;
 
                     if (node->src[0]->type == GGML_TYPE_F16 &&
-                            node->src[1]->type == GGML_TYPE_F32) {
-                        cur = sizeof(ggml_fp16_t)*(ne10*ne11*ne12);
+                        node->src[1]->type == GGML_TYPE_F32) {
+                        cur = sizeof(ggml_fp16_t)*(ne0*ne1*ew0);
                     } else if (node->src[0]->type == GGML_TYPE_F32 &&
-                            node->src[1]->type == GGML_TYPE_F32) {
+                               node->src[1]->type == GGML_TYPE_F32) {
                         cur = sizeof(float)*      (ne10*ne11*ne12);
                     } else {
                         GGML_ASSERT(false);
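The planner change above matches the new layout: the F16 scratch buffer now holds ne0*ne1 unrolled rows of ew0 = nk0*nk1*ne02 half-precision values each, rather than an input-sized copy. As a worked example, with sizes assumed purely for illustration: a 3x3 kernel over 64 input channels producing a 32x32 output needs sizeof(ggml_fp16_t) * (32*32) * (3*3*64) = 2 * 1024 * 576 = 1,179,648 bytes, about 1.1 MiB. Once the stride can be smaller than the kernel or padding is nonzero, neighboring windows overlap, so the unrolled buffer can exceed the raw input size and the old ne10*ne11*ne12 formula would under-allocate.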
Author: Georgi Gerganov