Mirror of https://github.com/ggml-org/llama.cpp.git (synced 2025-11-04 09:32:00 +00:00)
			
		
		
		
	ggml: cache sin/cos for RoPE (#4908)
This commit is contained in:
		
							
								
								
									
										46
									
								
								ggml.c
									
									
									
									
									
								
							
							
						
						
									
										46
									
								
								ggml.c
									
									
									
									
									
								
							@@ -11638,6 +11638,21 @@ static float ggml_rope_yarn_corr_dim(int n_dims, int n_orig_ctx, float n_rot, fl
 | 
				
			|||||||
    return n_dims * logf(n_orig_ctx / (n_rot * 2 * (float)M_PI)) / (2 * logf(base));
 | 
					    return n_dims * logf(n_orig_ctx / (n_rot * 2 * (float)M_PI)) / (2 * logf(base));
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Precompute the per-position RoPE rotation table into `cache`.
					//
					// For every even index i0 in [0, ne0), rope_yarn() is called with the current
					// angle and writes its two outputs into cache[i0 + 0] and cache[i0 + 1]; the
					// rope_f32/rope_f16 loops read these back as cos_theta and sin_theta.
					// The cache[i0 + 1] entry is then multiplied by sin_sign.
					//
					// The angle starts at theta_base and is multiplied by theta_scale after each
					// pair, so `cache` must have room for at least ne0 floats (ne0 rounded up to
					// an even count, since entries are written two at a time).
					static void ggml_rope_cache_init(
					     float theta_base, float freq_scale, float corr_dims[2], int64_t ne0, float ext_factor, float mscale,
					     float * cache, float sin_sign, float theta_scale
					) {
					    float angle = theta_base;
					    int64_t i0 = 0;

					    while (i0 < ne0) {
					        float * const cos_slot = cache + i0;
					        float * const sin_slot = cos_slot + 1;

					        rope_yarn(angle, freq_scale, corr_dims, i0, ext_factor, mscale, cos_slot, sin_slot);
					        *sin_slot *= sin_sign;

					        // advance to the angle for the next (cos, sin) pair
					        angle *= theta_scale;
					        i0 += 2;
					    }
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
void ggml_rope_yarn_corr_dims(
 | 
					void ggml_rope_yarn_corr_dims(
 | 
				
			||||||
    int n_dims, int n_orig_ctx, float freq_base, float beta_fast, float beta_slow, float dims[2]
 | 
					    int n_dims, int n_orig_ctx, float freq_base, float beta_fast, float beta_slow, float dims[2]
 | 
				
			||||||
) {
 | 
					) {
 | 
				
			||||||
@@ -11720,6 +11735,12 @@ static void ggml_compute_forward_rope_f32(
 | 
				
			|||||||
    for (int64_t i3 = 0; i3 < ne3; i3++) {
 | 
					    for (int64_t i3 = 0; i3 < ne3; i3++) {
 | 
				
			||||||
        for (int64_t i2 = 0; i2 < ne2; i2++) {
 | 
					        for (int64_t i2 = 0; i2 < ne2; i2++) {
 | 
				
			||||||
            const int64_t p = pos[i2];
 | 
					            const int64_t p = pos[i2];
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            float * cache = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32)*ith;
 | 
				
			||||||
 | 
					            if (!is_glm && !is_neox) { // TODO: cache sin/cos for glm, neox
 | 
				
			||||||
 | 
					                ggml_rope_cache_init(p, freq_scale, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
 | 
				
			||||||
 | 
					            }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            for (int64_t i1 = 0; i1 < ne1; i1++) {
 | 
					            for (int64_t i1 = 0; i1 < ne1; i1++) {
 | 
				
			||||||
                if (ir++ < ir0) continue;
 | 
					                if (ir++ < ir0) continue;
 | 
				
			||||||
                if (ir   > ir1) break;
 | 
					                if (ir   > ir1) break;
 | 
				
			||||||
@@ -11753,18 +11774,13 @@ static void ggml_compute_forward_rope_f32(
 | 
				
			|||||||
                    }
 | 
					                    }
 | 
				
			||||||
                } else if (!is_neox) {
 | 
					                } else if (!is_neox) {
 | 
				
			||||||
                    for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
 | 
					                    for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
 | 
				
			||||||
                        float cos_theta, sin_theta;
 | 
					                        const float cos_theta = cache[i0 + 0];
 | 
				
			||||||
                        rope_yarn(
 | 
					                        const float sin_theta = cache[i0 + 1];
 | 
				
			||||||
                            theta_base, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta
 | 
					 | 
				
			||||||
                        );
 | 
					 | 
				
			||||||
                        sin_theta *= sin_sign;
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
                        // zeta scaling for xPos only:
 | 
					                        // zeta scaling for xPos only:
 | 
				
			||||||
                        float zeta = xpos_base != 0.0f ? powf((i0 + 0.4f * ne0) / (1.4f * ne0), p / xpos_base) : 1.0f;
 | 
					                        float zeta = xpos_base != 0.0f ? powf((i0 + 0.4f * ne0) / (1.4f * ne0), p / xpos_base) : 1.0f;
 | 
				
			||||||
                        if (xpos_down) zeta = 1.0f / zeta;
 | 
					                        if (xpos_down) zeta = 1.0f / zeta;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
                        theta_base *= theta_scale;
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                        const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
 | 
					                        const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
 | 
				
			||||||
                              float * dst_data  = (float *)((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
 | 
					                              float * dst_data  = (float *)((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -11888,6 +11904,12 @@ static void ggml_compute_forward_rope_f16(
 | 
				
			|||||||
    for (int64_t i3 = 0; i3 < ne3; i3++) {
 | 
					    for (int64_t i3 = 0; i3 < ne3; i3++) {
 | 
				
			||||||
        for (int64_t i2 = 0; i2 < ne2; i2++) {
 | 
					        for (int64_t i2 = 0; i2 < ne2; i2++) {
 | 
				
			||||||
            const int64_t p = pos[i2];
 | 
					            const int64_t p = pos[i2];
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            float * cache = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32)*ith;
 | 
				
			||||||
 | 
					            if (!is_glm && !is_neox) { // TODO: cache sin/cos for glm, neox
 | 
				
			||||||
 | 
					                ggml_rope_cache_init(p, freq_scale, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
 | 
				
			||||||
 | 
					            }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            for (int64_t i1 = 0; i1 < ne1; i1++) {
 | 
					            for (int64_t i1 = 0; i1 < ne1; i1++) {
 | 
				
			||||||
                if (ir++ < ir0) continue;
 | 
					                if (ir++ < ir0) continue;
 | 
				
			||||||
                if (ir   > ir1) break;
 | 
					                if (ir   > ir1) break;
 | 
				
			||||||
@@ -11921,13 +11943,8 @@ static void ggml_compute_forward_rope_f16(
 | 
				
			|||||||
                    }
 | 
					                    }
 | 
				
			||||||
                } else if (!is_neox) {
 | 
					                } else if (!is_neox) {
 | 
				
			||||||
                    for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
 | 
					                    for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
 | 
				
			||||||
                        float cos_theta, sin_theta;
 | 
					                        const float cos_theta = cache[i0 + 0];
 | 
				
			||||||
                        rope_yarn(
 | 
					                        const float sin_theta = cache[i0 + 1];
 | 
				
			||||||
                            theta_base, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta
 | 
					 | 
				
			||||||
                        );
 | 
					 | 
				
			||||||
                        sin_theta *= sin_sign;
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                        theta_base *= theta_scale;
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
                        const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
 | 
					                        const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
 | 
				
			||||||
                              ggml_fp16_t * dst_data  = (ggml_fp16_t *)((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
 | 
					                              ggml_fp16_t * dst_data  = (ggml_fp16_t *)((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
 | 
				
			||||||
@@ -16722,6 +16739,7 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa
 | 
				
			|||||||
                    }
 | 
					                    }
 | 
				
			||||||
                } break;
 | 
					                } break;
 | 
				
			||||||
            case GGML_OP_SOFT_MAX:
 | 
					            case GGML_OP_SOFT_MAX:
 | 
				
			||||||
 | 
					            case GGML_OP_ROPE:
 | 
				
			||||||
                {
 | 
					                {
 | 
				
			||||||
                    cur = ggml_type_size(GGML_TYPE_F32) * node->ne[0] * n_tasks;
 | 
					                    cur = ggml_type_size(GGML_TYPE_F32) * node->ne[0] * n_tasks;
 | 
				
			||||||
                } break;
 | 
					                } break;
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user