	ggml : fix YARN + add tests + add asserts (#7617)
* tests : add rope tests ggml-ci
* ggml : fixes (hopefully) ggml-ci
* tests : add non-cont tests ggml-ci
* cuda : add asserts for rope/norm + fix DS2 ggml-ci
* ggml : assert contiguousness
* tests : reduce RoPE tests ggml-ci
ggml-cuda/norm.cu

@@ -170,6 +170,8 @@ void ggml_cuda_op_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     float * dst_d = (float *)dst->data;
     cudaStream_t stream = ctx.stream();
 
+    GGML_ASSERT(ggml_is_contiguous(src0));
+
     GGML_ASSERT(src0->type == GGML_TYPE_F32);
     GGML_ASSERT( dst->type == GGML_TYPE_F32);
 
@@ -188,6 +190,8 @@ void ggml_cuda_op_group_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
     float * dst_d = (float *)dst->data;
     cudaStream_t stream = ctx.stream();
 
+    GGML_ASSERT(ggml_is_contiguous(src0));
+
     GGML_ASSERT(src0->type == GGML_TYPE_F32);
     GGML_ASSERT( dst->type == GGML_TYPE_F32);
 
@@ -202,6 +206,8 @@ void ggml_cuda_op_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     float * dst_d = (float *)dst->data;
     cudaStream_t stream = ctx.stream();
 
+    GGML_ASSERT(ggml_is_contiguous(src0));
+
     GGML_ASSERT(src0->type == GGML_TYPE_F32);
     GGML_ASSERT( dst->type == GGML_TYPE_F32);
 
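All three norm ops gain the same GGML_ASSERT(ggml_is_contiguous(src0)) guard because their kernels index src0->data as a flat, densely packed array; a transposed or sliced view would silently be read with the wrong strides. As a rough sketch of what the check means (paraphrased, not this commit's code: ggml_is_contiguous() lives in ggml.c), a tensor is contiguous when its byte strides nb[] describe a dense row-major layout:

    #include "ggml.h"

    // Sketch only: is_contiguous_sketch is a hypothetical name. The real
    // ggml_is_contiguous() also handles quantized block types, which is
    // why the row stride below divides ne[0] by ggml_blck_size().
    static bool is_contiguous_sketch(const struct ggml_tensor * t) {
        if (t->nb[0] != ggml_type_size(t->type)) {
            return false; // innermost dim must be densely packed
        }
        if (t->nb[1] != t->nb[0]*(size_t)(t->ne[0]/ggml_blck_size(t->type))) {
            return false; // row stride must span exactly one row
        }
        for (int i = 2; i < GGML_MAX_DIMS; i++) {
            if (t->nb[i] != t->nb[i-1]*(size_t) t->ne[i-1]) {
                return false; // each outer stride = size of all inner dims
            }
        }
        return true;
    }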
ggml-cuda/rope.cu

@@ -61,7 +61,7 @@ static __global__ void rope(
 template<typename T, bool has_pos, bool has_freq_facs>
 static __global__ void rope_neox(
     const T * x, T * dst, int ncols, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows,
-    float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, float inv_ndims, const float * freq_factors
+    float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, const float * freq_factors
 ) {
     const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y);
 
@@ -85,15 +85,13 @@ static __global__ void rope_neox(
     const int i  = row*ncols + ib*n_dims + ic/2;
     const int i2 = row/p_delta_rows;
 
-    float cur_rot = inv_ndims * ic - ib;
-
     const int p = has_pos ? pos[i2] : 0;
     const float freq_factor = has_freq_facs ? freq_factors[ic/2] : 1.0f;
 
-    const float theta_base = p*freq_scale*powf(theta_scale, col/2.0f)/freq_factor;
+    const float theta_base = p*powf(theta_scale, col/2.0f)/freq_factor;
 
     float cos_theta, sin_theta;
-    rope_yarn(theta_base, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor, &cos_theta, &sin_theta);
+    rope_yarn(theta_base, freq_scale, corr_dims, ic, ext_factor, attn_factor, &cos_theta, &sin_theta);
 
     const float x0 = x[i + 0];
     const float x1 = x[i + n_dims/2];
 
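This hunk is the core YaRN fix. Judging from the CPU reference, rope_yarn() already multiplies the extrapolated angle by freq_scale on its interpolation path, and its correction ramp expects the raw dimension index (here ic) measured against corr_dims. The old code therefore applied freq_scale twice and fed a rescaled index cur_rot = inv_ndims*ic - ib instead of ic. A minimal standalone illustration of the double scaling, with assumed values and a hypothetical helper name:

    #include <cmath>
    #include <cstdio>

    // Paraphrase of rope_yarn's interpolation path only: it scales the
    // extrapolated angle by freq_scale internally (the real function also
    // blends in an extrapolated angle via a ramp and corrects magnitude).
    static float yarn_interp_sketch(float theta_extrap, float freq_scale) {
        return freq_scale * theta_extrap;
    }

    int main(void) {
        const float freq_base  = 10000.0f;
        const int   n_dims     = 128;
        const int   p          = 100;   // token position (assumed)
        const float freq_scale = 0.25f; // 4x context extension (assumed)

        const float theta_scale = std::pow(freq_base, -2.0f/n_dims);
        const float theta       = p*std::pow(theta_scale, 8.0f); // col/2 == 8

        // before: freq_scale applied at the call site AND inside rope_yarn
        const float buggy = yarn_interp_sketch(freq_scale*theta, freq_scale);
        // after:  applied exactly once, inside rope_yarn
        const float fixed = yarn_interp_sketch(theta, freq_scale);

        std::printf("buggy = %f (freq_scale applied twice), fixed = %f\n", buggy, fixed);
        return 0;
    }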
@@ -174,30 +172,29 @@ static void rope_neox_cuda(
     const dim3 block_nums(nrows, num_blocks_x, 1);
 
     const float theta_scale = powf(freq_base, -2.0f/n_dims);
-    const float inv_ndims = -1.0f / n_dims;
 
     if (pos == nullptr) {
         if (freq_factors == nullptr) {
             rope_neox<T, false, false><<<block_nums, block_dims, 0, stream>>>(
                 x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
-                theta_scale, inv_ndims, freq_factors
+                theta_scale, freq_factors
                 );
         } else {
             rope_neox<T, false, true><<<block_nums, block_dims, 0, stream>>>(
                 x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
-                theta_scale, inv_ndims, freq_factors
+                theta_scale, freq_factors
                 );
         }
     } else {
         if (freq_factors == nullptr) {
             rope_neox<T, true, false><<<block_nums, block_dims, 0, stream>>>(
                 x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
-                theta_scale, inv_ndims, freq_factors
+                theta_scale, freq_factors
                 );
         } else {
             rope_neox<T, true, true><<<block_nums, block_dims, 0, stream>>>(
                 x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
-                theta_scale, inv_ndims, freq_factors
+                theta_scale, freq_factors
                 );
         }
     }
 
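Aside from dropping the now-unused inv_ndims, the launcher keeps its four-way dispatch: the runtime nullability of pos and freq_factors is folded into the has_pos/has_freq_facs template flags, so each kernel instantiation resolves those branches at compile time rather than per thread. A host-only sketch of the pattern, with hypothetical names:

    #include <cstdint>
    #include <cstdio>

    // Stand-in for the templated kernel: the two bools select a
    // compile-time specialization instead of a per-thread runtime branch.
    template <bool has_pos, bool has_freq_facs>
    static void rope_launch_sketch(const int32_t * pos, const float * freq_factors) {
        (void) pos; (void) freq_factors;
        std::printf("instantiated: has_pos=%d has_freq_facs=%d\n",
                    (int) has_pos, (int) has_freq_facs);
    }

    static void rope_dispatch_sketch(const int32_t * pos, const float * freq_factors) {
        if (pos == nullptr) {
            if (freq_factors == nullptr) { rope_launch_sketch<false, false>(pos, freq_factors); }
            else                         { rope_launch_sketch<false, true >(pos, freq_factors); }
        } else {
            if (freq_factors == nullptr) { rope_launch_sketch<true,  false>(pos, freq_factors); }
            else                         { rope_launch_sketch<true,  true >(pos, freq_factors); }
        }
    }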
@@ -254,6 +251,7 @@ void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     float * dst_d = (float *)dst->data;
     cudaStream_t stream = ctx.stream();
 
+    GGML_ASSERT(ggml_is_contiguous(src0));
     GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
     GGML_ASSERT( dst->type == GGML_TYPE_F32 ||  dst->type == GGML_TYPE_F16);
     GGML_ASSERT(src0->type == dst->type);
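ggml_cuda_op_rope now asserts contiguousness like the norm ops, and the surrounding type asserts make the constraint explicit: the rope kernels take a single template parameter T for both input and output, so only matching (F32, F32) or (F16, F16) pairs can be instantiated, hence src0->type == dst->type. A hypothetical sketch of that constraint:

    // Sketch with hypothetical names: with one element type T covering
    // both x and dst, there is no <float, half> mixed specialization to
    // dispatch to -- which is what the assert rules out up front.
    template <typename T>
    static void rope_apply_sketch(const T * x, T * dst, int n) {
        for (int i = 0; i < n; i++) {
            dst[i] = x[i]; // stand-in for the actual rotation
        }
    }

    static bool rope_types_supported_sketch(bool src_is_f32, bool dst_is_f32) {
        return src_is_f32 == dst_is_f32; // mirrors GGML_ASSERT(src0->type == dst->type)
    }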