mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-11-04 09:32:00 +00:00 
			
		
		
		
	ggml : fix backward rope after YaRN (#3974)
* fix backward process of rope: the rope backward process was broken after the YaRN RoPE (#2268) implementation, due to missing changes in the backward functions. The code for the backward process is nearly identical to the forward process: the only difference is the sign of the sin-values. To avoid future regressions, remove the near-duplicate backward functions and reuse the forward code: for this, a new function argument `bool forward` was added to `ggml_compute_forward_rope_f32` and `ggml_compute_forward_rope_f16`. The sin-values will be negated when forward is false. * fix finetune rope call to use correct default attn_factor of 1.0f * remove unused `ggml_rope_xpos_back`: it is better to have only one `ggml_rope_back` function that accepts all rope parameters, so that `ggml_compute_backward` can propagate all parameters without having to switch between different rope_back variants. * fix comments explaining the sine sign in ggml_forward_rope * add missing function arguments in declaration * fix function argument type in declaration
This commit is contained in:
		@@ -643,7 +643,7 @@ static struct ggml_tensor * llama_build_lora_finetune_graphs(
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
        return ggml_rope_custom(ctx,
 | 
					        return ggml_rope_custom(ctx,
 | 
				
			||||||
            t, KQ_pos, n_rot, rope_mode, n_ctx, 0,
 | 
					            t, KQ_pos, n_rot, rope_mode, n_ctx, 0,
 | 
				
			||||||
            rope_freq_base, rope_freq_scale, 0.0f, 0.0f, 0.0f, 0.0f
 | 
					            rope_freq_base, rope_freq_scale, 0.0f, 1.0f, 0.0f, 0.0f
 | 
				
			||||||
        );
 | 
					        );
 | 
				
			||||||
    };
 | 
					    };
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 
 | 
				
			|||||||
							
								
								
									
										330
									
								
								ggml.c
									
									
									
									
									
								
							
							
						
						
									
										330
									
								
								ggml.c
									
									
									
									
									
								
							@@ -4970,8 +4970,13 @@ struct ggml_tensor * ggml_rope_back(
 | 
				
			|||||||
        int                   n_dims,
 | 
					        int                   n_dims,
 | 
				
			||||||
        int                   mode,
 | 
					        int                   mode,
 | 
				
			||||||
        int                   n_ctx,
 | 
					        int                   n_ctx,
 | 
				
			||||||
 | 
					        int                   n_orig_ctx,
 | 
				
			||||||
        float                 freq_base,
 | 
					        float                 freq_base,
 | 
				
			||||||
        float                 freq_scale,
 | 
					        float                 freq_scale,
 | 
				
			||||||
 | 
					        float                 ext_factor,
 | 
				
			||||||
 | 
					        float                 attn_factor,
 | 
				
			||||||
 | 
					        float                 beta_fast,
 | 
				
			||||||
 | 
					        float                 beta_slow,
 | 
				
			||||||
        float                 xpos_base,
 | 
					        float                 xpos_base,
 | 
				
			||||||
        bool                  xpos_down) {
 | 
					        bool                  xpos_down) {
 | 
				
			||||||
    GGML_ASSERT(ggml_is_vector(b));
 | 
					    GGML_ASSERT(ggml_is_vector(b));
 | 
				
			||||||
@@ -4988,11 +4993,15 @@ struct ggml_tensor * ggml_rope_back(
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
 | 
					    struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    int32_t params[8] = { /*n_past*/ 0, n_dims, mode, n_ctx };
 | 
					    int32_t params[13] = { /*n_past*/ 0, n_dims, mode, n_ctx, n_orig_ctx };
 | 
				
			||||||
    memcpy(params + 4, &freq_base,  sizeof(float));
 | 
					    memcpy(params +  5, &freq_base,    sizeof(float));
 | 
				
			||||||
    memcpy(params + 5, &freq_scale, sizeof(float));
 | 
					    memcpy(params +  6, &freq_scale,   sizeof(float));
 | 
				
			||||||
    memcpy(params + 6, &xpos_base,  sizeof(float));
 | 
					    memcpy(params +  7, &ext_factor,   sizeof(float));
 | 
				
			||||||
    memcpy(params + 7, &xpos_down,  sizeof(bool));
 | 
					    memcpy(params +  8, &attn_factor,  sizeof(float));
 | 
				
			||||||
 | 
					    memcpy(params +  9, &beta_fast,    sizeof(float));
 | 
				
			||||||
 | 
					    memcpy(params + 10, &beta_slow,    sizeof(float));
 | 
				
			||||||
 | 
					    memcpy(params + 11, &xpos_base,    sizeof(float));
 | 
				
			||||||
 | 
					    memcpy(params + 12, &xpos_down,    sizeof(bool));
 | 
				
			||||||
    ggml_set_op_params(result, params, sizeof(params));
 | 
					    ggml_set_op_params(result, params, sizeof(params));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    result->op   = GGML_OP_ROPE_BACK;
 | 
					    result->op   = GGML_OP_ROPE_BACK;
 | 
				
			||||||
@@ -10974,7 +10983,8 @@ static void ggml_compute_forward_rope_f32(
 | 
				
			|||||||
        const struct ggml_compute_params * params,
 | 
					        const struct ggml_compute_params * params,
 | 
				
			||||||
        const struct ggml_tensor * src0,
 | 
					        const struct ggml_tensor * src0,
 | 
				
			||||||
        const struct ggml_tensor * src1,
 | 
					        const struct ggml_tensor * src1,
 | 
				
			||||||
        struct ggml_tensor * dst) {
 | 
					        struct ggml_tensor * dst,
 | 
				
			||||||
 | 
					        const bool forward) {
 | 
				
			||||||
    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
 | 
					    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
 | 
				
			||||||
        return;
 | 
					        return;
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
@@ -11033,6 +11043,11 @@ static void ggml_compute_forward_rope_f32(
 | 
				
			|||||||
    const bool is_neox = mode & 2;
 | 
					    const bool is_neox = mode & 2;
 | 
				
			||||||
    const bool is_glm  = mode & 4;
 | 
					    const bool is_glm  = mode & 4;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    // backward process uses inverse rotation by cos and sin.
 | 
				
			||||||
 | 
					    // cos and sin build a rotation matrix, where the inverse is the transpose.
 | 
				
			||||||
 | 
					    // this essentially just switches the sign of sin.
 | 
				
			||||||
 | 
					    const float sin_sign = forward ? 1.0f : -1.0f;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    const int32_t * pos = (const int32_t *) src1->data;
 | 
					    const int32_t * pos = (const int32_t *) src1->data;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    for (int64_t i3 = 0; i3 < ne3; i3++) {
 | 
					    for (int64_t i3 = 0; i3 < ne3; i3++) {
 | 
				
			||||||
@@ -11049,9 +11064,9 @@ static void ggml_compute_forward_rope_f32(
 | 
				
			|||||||
                    float block_theta = MAX(p - (n_ctx - 2), 0);
 | 
					                    float block_theta = MAX(p - (n_ctx - 2), 0);
 | 
				
			||||||
                    for (int64_t i0 = 0; i0 < ne0 / 4; i0++) {
 | 
					                    for (int64_t i0 = 0; i0 < ne0 / 4; i0++) {
 | 
				
			||||||
                        const float cos_theta = cosf(theta_base);
 | 
					                        const float cos_theta = cosf(theta_base);
 | 
				
			||||||
                        const float sin_theta = sinf(theta_base);
 | 
					                        const float sin_theta = sinf(theta_base) * sin_sign;
 | 
				
			||||||
                        const float cos_block_theta = cosf(block_theta);
 | 
					                        const float cos_block_theta = cosf(block_theta);
 | 
				
			||||||
                        const float sin_block_theta = sinf(block_theta);
 | 
					                        const float sin_block_theta = sinf(block_theta) * sin_sign;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
                        theta_base *= theta_scale;
 | 
					                        theta_base *= theta_scale;
 | 
				
			||||||
                        block_theta *= theta_scale;
 | 
					                        block_theta *= theta_scale;
 | 
				
			||||||
@@ -11075,6 +11090,7 @@ static void ggml_compute_forward_rope_f32(
 | 
				
			|||||||
                        rope_yarn(
 | 
					                        rope_yarn(
 | 
				
			||||||
                            theta_base, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta
 | 
					                            theta_base, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta
 | 
				
			||||||
                        );
 | 
					                        );
 | 
				
			||||||
 | 
					                        sin_theta *= sin_sign;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
                        // zeta scaling for xPos only:
 | 
					                        // zeta scaling for xPos only:
 | 
				
			||||||
                        float zeta = xpos_base != 0.0f ? powf((i0 + 0.4f * ne0) / (1.4f * ne0), p / xpos_base) : 1.0f;
 | 
					                        float zeta = xpos_base != 0.0f ? powf((i0 + 0.4f * ne0) / (1.4f * ne0), p / xpos_base) : 1.0f;
 | 
				
			||||||
@@ -11105,6 +11121,7 @@ static void ggml_compute_forward_rope_f32(
 | 
				
			|||||||
                                theta_base, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor,
 | 
					                                theta_base, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor,
 | 
				
			||||||
                                &cos_theta, &sin_theta
 | 
					                                &cos_theta, &sin_theta
 | 
				
			||||||
                            );
 | 
					                            );
 | 
				
			||||||
 | 
					                            sin_theta *= sin_sign;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
                            theta_base *= theta_scale;
 | 
					                            theta_base *= theta_scale;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -11130,7 +11147,8 @@ static void ggml_compute_forward_rope_f16(
 | 
				
			|||||||
        const struct ggml_compute_params * params,
 | 
					        const struct ggml_compute_params * params,
 | 
				
			||||||
        const struct ggml_tensor * src0,
 | 
					        const struct ggml_tensor * src0,
 | 
				
			||||||
        const struct ggml_tensor * src1,
 | 
					        const struct ggml_tensor * src1,
 | 
				
			||||||
        struct ggml_tensor * dst) {
 | 
					        struct ggml_tensor * dst,
 | 
				
			||||||
 | 
					        const bool forward) {
 | 
				
			||||||
    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
 | 
					    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
 | 
				
			||||||
        return;
 | 
					        return;
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
@@ -11182,6 +11200,11 @@ static void ggml_compute_forward_rope_f16(
 | 
				
			|||||||
    const bool is_neox = mode & 2;
 | 
					    const bool is_neox = mode & 2;
 | 
				
			||||||
    const bool is_glm  = mode & 4;
 | 
					    const bool is_glm  = mode & 4;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    // backward process uses inverse rotation by cos and sin.
 | 
				
			||||||
 | 
					    // cos and sin build a rotation matrix, where the inverse is the transpose.
 | 
				
			||||||
 | 
					    // this essentially just switches the sign of sin.
 | 
				
			||||||
 | 
					    const float sin_sign = forward ? 1.0f : -1.0f;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    const int32_t * pos = (const int32_t *) src1->data;
 | 
					    const int32_t * pos = (const int32_t *) src1->data;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    for (int64_t i3 = 0; i3 < ne3; i3++) {
 | 
					    for (int64_t i3 = 0; i3 < ne3; i3++) {
 | 
				
			||||||
@@ -11198,9 +11221,9 @@ static void ggml_compute_forward_rope_f16(
 | 
				
			|||||||
                    float block_theta = MAX(p - (n_ctx - 2), 0);
 | 
					                    float block_theta = MAX(p - (n_ctx - 2), 0);
 | 
				
			||||||
                    for (int64_t i0 = 0; i0 < ne0 / 4; i0++) {
 | 
					                    for (int64_t i0 = 0; i0 < ne0 / 4; i0++) {
 | 
				
			||||||
                        const float cos_theta = cosf(theta_base);
 | 
					                        const float cos_theta = cosf(theta_base);
 | 
				
			||||||
                        const float sin_theta = sinf(theta_base);
 | 
					                        const float sin_theta = sinf(theta_base) * sin_sign;
 | 
				
			||||||
                        const float cos_block_theta = cosf(block_theta);
 | 
					                        const float cos_block_theta = cosf(block_theta);
 | 
				
			||||||
                        const float sin_block_theta = sinf(block_theta);
 | 
					                        const float sin_block_theta = sinf(block_theta) * sin_sign;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
                        theta_base *= theta_scale;
 | 
					                        theta_base *= theta_scale;
 | 
				
			||||||
                        block_theta *= theta_scale;
 | 
					                        block_theta *= theta_scale;
 | 
				
			||||||
@@ -11224,6 +11247,7 @@ static void ggml_compute_forward_rope_f16(
 | 
				
			|||||||
                        rope_yarn(
 | 
					                        rope_yarn(
 | 
				
			||||||
                            theta_base, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta
 | 
					                            theta_base, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta
 | 
				
			||||||
                        );
 | 
					                        );
 | 
				
			||||||
 | 
					                        sin_theta *= sin_sign;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
                        theta_base *= theta_scale;
 | 
					                        theta_base *= theta_scale;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -11250,6 +11274,7 @@ static void ggml_compute_forward_rope_f16(
 | 
				
			|||||||
                                theta_base, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor,
 | 
					                                theta_base, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor,
 | 
				
			||||||
                                &cos_theta, &sin_theta
 | 
					                                &cos_theta, &sin_theta
 | 
				
			||||||
                            );
 | 
					                            );
 | 
				
			||||||
 | 
					                            sin_theta *= sin_sign;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
                            theta_base *= theta_scale;
 | 
					                            theta_base *= theta_scale;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -11279,11 +11304,11 @@ static void ggml_compute_forward_rope(
 | 
				
			|||||||
    switch (src0->type) {
 | 
					    switch (src0->type) {
 | 
				
			||||||
        case GGML_TYPE_F16:
 | 
					        case GGML_TYPE_F16:
 | 
				
			||||||
            {
 | 
					            {
 | 
				
			||||||
                ggml_compute_forward_rope_f16(params, src0, src1, dst);
 | 
					                ggml_compute_forward_rope_f16(params, src0, src1, dst, true);
 | 
				
			||||||
            } break;
 | 
					            } break;
 | 
				
			||||||
        case GGML_TYPE_F32:
 | 
					        case GGML_TYPE_F32:
 | 
				
			||||||
            {
 | 
					            {
 | 
				
			||||||
                ggml_compute_forward_rope_f32(params, src0, src1, dst);
 | 
					                ggml_compute_forward_rope_f32(params, src0, src1, dst, true);
 | 
				
			||||||
            } break;
 | 
					            } break;
 | 
				
			||||||
        default:
 | 
					        default:
 | 
				
			||||||
            {
 | 
					            {
 | 
				
			||||||
@@ -11294,216 +11319,6 @@ static void ggml_compute_forward_rope(
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
// ggml_compute_forward_rope_back
 | 
					// ggml_compute_forward_rope_back
 | 
				
			||||||
 | 
					
 | 
				
			||||||
static void ggml_compute_forward_rope_back_f32(
 | 
					 | 
				
			||||||
        const struct ggml_compute_params * params,
 | 
					 | 
				
			||||||
        const struct ggml_tensor * src0,
 | 
					 | 
				
			||||||
        const struct ggml_tensor * src1,
 | 
					 | 
				
			||||||
        struct ggml_tensor * dst) {
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
 | 
					 | 
				
			||||||
        return;
 | 
					 | 
				
			||||||
    }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    // y = rope(x, src1)
 | 
					 | 
				
			||||||
    // dx = rope_back(dy, src1)
 | 
					 | 
				
			||||||
    // src0 is dy, src1 contains options
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    float freq_base;
 | 
					 | 
				
			||||||
    float freq_scale;
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    // these two only relevant for xPos RoPE:
 | 
					 | 
				
			||||||
    float xpos_base;
 | 
					 | 
				
			||||||
    bool xpos_down;
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    //const int n_past = ((int32_t *) dst->op_params)[0];
 | 
					 | 
				
			||||||
    const int n_dims = ((int32_t *) dst->op_params)[1];
 | 
					 | 
				
			||||||
    const int mode   = ((int32_t *) dst->op_params)[2];
 | 
					 | 
				
			||||||
    const int n_ctx  = ((int32_t *) dst->op_params)[3]; UNUSED(n_ctx);
 | 
					 | 
				
			||||||
    memcpy(&freq_base,  (int32_t *) dst->op_params + 4, sizeof(float));
 | 
					 | 
				
			||||||
    memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float));
 | 
					 | 
				
			||||||
    memcpy(&xpos_base,  (int32_t *) dst->op_params + 6, sizeof(float));
 | 
					 | 
				
			||||||
    memcpy(&xpos_down,  (int32_t *) dst->op_params + 7, sizeof(bool));
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    GGML_TENSOR_UNARY_OP_LOCALS
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    //printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3);
 | 
					 | 
				
			||||||
    //printf("n_past = %d, ne2 = %d\n", n_past, ne2);
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    assert(nb0 == sizeof(float));
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    const int ith = params->ith;
 | 
					 | 
				
			||||||
    const int nth = params->nth;
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    const int nr = ggml_nrows(dst);
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    // rows per thread
 | 
					 | 
				
			||||||
    const int dr = (nr + nth - 1)/nth;
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    // row range for this thread
 | 
					 | 
				
			||||||
    const int ir0 = dr*ith;
 | 
					 | 
				
			||||||
    const int ir1 = MIN(ir0 + dr, nr);
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    // row index used to determine which thread to use
 | 
					 | 
				
			||||||
    int ir = 0;
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    const float theta_scale = powf(freq_base, -2.0f/n_dims);
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    const bool is_neox = mode & 2;
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    const int32_t * pos = (const int32_t *) src1->data;
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    for (int64_t i3 = 0; i3 < ne3; i3++) {
 | 
					 | 
				
			||||||
        for (int64_t i2 = 0; i2 < ne2; i2++) {
 | 
					 | 
				
			||||||
            const int64_t p = pos[i2];
 | 
					 | 
				
			||||||
            for (int64_t i1 = 0; i1 < ne1; i1++) {
 | 
					 | 
				
			||||||
                if (ir++ < ir0) continue;
 | 
					 | 
				
			||||||
                if (ir   > ir1) break;
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                float theta_base = freq_scale * (float)p;
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                if (!is_neox) {
 | 
					 | 
				
			||||||
                    for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
 | 
					 | 
				
			||||||
                        const float cos_theta = cosf(theta_base);
 | 
					 | 
				
			||||||
                        const float sin_theta = sinf(theta_base);
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                        // zeta scaling for xPos only:
 | 
					 | 
				
			||||||
                        float zeta = xpos_base != 0.0f ? powf((i0 + 0.4f * ne0) / (1.4f * ne0), p / xpos_base) : 1.0f;
 | 
					 | 
				
			||||||
                        if (xpos_down) zeta = 1.0f / zeta;
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                        theta_base *= theta_scale;
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                        const float * const dy  = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
 | 
					 | 
				
			||||||
                              float *       dx  = (float *)((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                        const float dy0 = dy[0];
 | 
					 | 
				
			||||||
                        const float dy1 = dy[1];
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                        dx[0] =   dy0*cos_theta*zeta + dy1*sin_theta*zeta;
 | 
					 | 
				
			||||||
                        dx[1] = - dy0*sin_theta*zeta + dy1*cos_theta*zeta;
 | 
					 | 
				
			||||||
                    }
 | 
					 | 
				
			||||||
                } else {
 | 
					 | 
				
			||||||
                    for (int64_t ib = 0; ib < ne0/n_dims; ++ib) {
 | 
					 | 
				
			||||||
                        for (int64_t ic = 0; ic < n_dims; ic += 2) {
 | 
					 | 
				
			||||||
                            const float cos_theta = cosf(theta_base);
 | 
					 | 
				
			||||||
                            const float sin_theta = sinf(theta_base);
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                            theta_base *= theta_scale;
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                            const int64_t i0 = ib*n_dims + ic/2;
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                            const float * const dy  = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
 | 
					 | 
				
			||||||
                                  float *       dx  = (float *)((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                            const float dy0 = dy[0];
 | 
					 | 
				
			||||||
                            const float dy1 = dy[n_dims/2];
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                            dx[0]        =   dy0*cos_theta + dy1*sin_theta;
 | 
					 | 
				
			||||||
                            dx[n_dims/2] = - dy0*sin_theta + dy1*cos_theta;
 | 
					 | 
				
			||||||
                        }
 | 
					 | 
				
			||||||
                    }
 | 
					 | 
				
			||||||
                }
 | 
					 | 
				
			||||||
            }
 | 
					 | 
				
			||||||
        }
 | 
					 | 
				
			||||||
    }
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
static void ggml_compute_forward_rope_back_f16(
 | 
					 | 
				
			||||||
        const struct ggml_compute_params * params,
 | 
					 | 
				
			||||||
        const struct ggml_tensor * src0,
 | 
					 | 
				
			||||||
        const struct ggml_tensor * src1,
 | 
					 | 
				
			||||||
        struct ggml_tensor * dst) {
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
 | 
					 | 
				
			||||||
        return;
 | 
					 | 
				
			||||||
    }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    // y = rope(x, src1)
 | 
					 | 
				
			||||||
    // dx = rope_back(dy, src1)
 | 
					 | 
				
			||||||
    // src0 is dy, src1 contains options
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    //const int n_past = ((int32_t *) dst->op_params)[0];
 | 
					 | 
				
			||||||
    const int n_dims = ((int32_t *) dst->op_params)[1];
 | 
					 | 
				
			||||||
    const int mode   = ((int32_t *) dst->op_params)[2];
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    GGML_TENSOR_UNARY_OP_LOCALS
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    //printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3);
 | 
					 | 
				
			||||||
    //printf("n_past = %d, ne2 = %d\n", n_past, ne2);
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    assert(nb0 == sizeof(ggml_fp16_t));
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    const int ith = params->ith;
 | 
					 | 
				
			||||||
    const int nth = params->nth;
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    const int nr = ggml_nrows(dst);
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    // rows per thread
 | 
					 | 
				
			||||||
    const int dr = (nr + nth - 1)/nth;
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    // row range for this thread
 | 
					 | 
				
			||||||
    const int ir0 = dr*ith;
 | 
					 | 
				
			||||||
    const int ir1 = MIN(ir0 + dr, nr);
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    // row index used to determine which thread to use
 | 
					 | 
				
			||||||
    int ir = 0;
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    const float theta_scale = powf(10000.0, -2.0f/n_dims);
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    const bool is_neox = mode & 2;
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    const int32_t * pos = (const int32_t *) src1->data;
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    for (int64_t i3 = 0; i3 < ne3; i3++) {
 | 
					 | 
				
			||||||
        for (int64_t i2 = 0; i2 < ne2; i2++) {
 | 
					 | 
				
			||||||
            const int64_t p = pos[i2];
 | 
					 | 
				
			||||||
            for (int64_t i1 = 0; i1 < ne1; i1++) {
 | 
					 | 
				
			||||||
                if (ir++ < ir0) continue;
 | 
					 | 
				
			||||||
                if (ir   > ir1) break;
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                float theta_base = (float)p;
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                if (!is_neox) {
 | 
					 | 
				
			||||||
                    for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
 | 
					 | 
				
			||||||
                        const float cos_theta = cosf(theta_base);
 | 
					 | 
				
			||||||
                        const float sin_theta = sinf(theta_base);
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                        theta_base *= theta_scale;
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                        const ggml_fp16_t * const dy  = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
 | 
					 | 
				
			||||||
                              ggml_fp16_t *       dx  = (ggml_fp16_t *)((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                        const float dy0 = GGML_FP16_TO_FP32(dy[0]);
 | 
					 | 
				
			||||||
                        const float dy1 = GGML_FP16_TO_FP32(dy[1]);
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                        dx[0] = GGML_FP32_TO_FP16( dy0*cos_theta + dy1*sin_theta);
 | 
					 | 
				
			||||||
                        dx[1] = GGML_FP32_TO_FP16(-dy0*sin_theta + dy1*cos_theta);
 | 
					 | 
				
			||||||
                    }
 | 
					 | 
				
			||||||
                } else {
 | 
					 | 
				
			||||||
                    for (int64_t ib = 0; ib < ne0/n_dims; ++ib) {
 | 
					 | 
				
			||||||
                        for (int64_t ic = 0; ic < n_dims; ic += 2) {
 | 
					 | 
				
			||||||
                            const float cos_theta = cosf(theta_base);
 | 
					 | 
				
			||||||
                            const float sin_theta = sinf(theta_base);
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                            theta_base *= theta_scale;
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                            const int64_t i0 = ib*n_dims + ic/2;
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                            const ggml_fp16_t * const dy  = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
 | 
					 | 
				
			||||||
                                  ggml_fp16_t *       dx  = (ggml_fp16_t *)((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                            const float dy0 = GGML_FP16_TO_FP32(dy[0]);
 | 
					 | 
				
			||||||
                            const float dy1 = GGML_FP16_TO_FP32(dy[n_dims/2]);
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                            dx[0]        = GGML_FP32_TO_FP16( dy0*cos_theta + dy1*sin_theta);
 | 
					 | 
				
			||||||
                            dx[n_dims/2] = GGML_FP32_TO_FP16(-dy0*sin_theta + dy1*cos_theta);
 | 
					 | 
				
			||||||
                        }
 | 
					 | 
				
			||||||
                    }
 | 
					 | 
				
			||||||
                }
 | 
					 | 
				
			||||||
            }
 | 
					 | 
				
			||||||
        }
 | 
					 | 
				
			||||||
    }
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
static void ggml_compute_forward_rope_back(
 | 
					static void ggml_compute_forward_rope_back(
 | 
				
			||||||
        const struct ggml_compute_params * params,
 | 
					        const struct ggml_compute_params * params,
 | 
				
			||||||
        const struct ggml_tensor * src0,
 | 
					        const struct ggml_tensor * src0,
 | 
				
			||||||
@@ -11512,11 +11327,11 @@ static void ggml_compute_forward_rope_back(
 | 
				
			|||||||
    switch (src0->type) {
 | 
					    switch (src0->type) {
 | 
				
			||||||
        case GGML_TYPE_F16:
 | 
					        case GGML_TYPE_F16:
 | 
				
			||||||
            {
 | 
					            {
 | 
				
			||||||
                ggml_compute_forward_rope_back_f16(params, src0, src1, dst);
 | 
					                ggml_compute_forward_rope_f16(params, src0, src1, dst, false);
 | 
				
			||||||
            } break;
 | 
					            } break;
 | 
				
			||||||
        case GGML_TYPE_F32:
 | 
					        case GGML_TYPE_F32:
 | 
				
			||||||
            {
 | 
					            {
 | 
				
			||||||
                ggml_compute_forward_rope_back_f32(params, src0, src1, dst);
 | 
					                ggml_compute_forward_rope_f32(params, src0, src1, dst, false);
 | 
				
			||||||
            } break;
 | 
					            } break;
 | 
				
			||||||
        default:
 | 
					        default:
 | 
				
			||||||
            {
 | 
					            {
 | 
				
			||||||
@@ -15559,17 +15374,20 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
 | 
				
			|||||||
                // necessary for llama
 | 
					                // necessary for llama
 | 
				
			||||||
                if (src0->grad) {
 | 
					                if (src0->grad) {
 | 
				
			||||||
                    //const int n_past = ((int32_t *) tensor->op_params)[0];
 | 
					                    //const int n_past = ((int32_t *) tensor->op_params)[0];
 | 
				
			||||||
                    const int n_dims = ((int32_t *) tensor->op_params)[1];
 | 
					                    const int n_dims     = ((int32_t *) tensor->op_params)[1];
 | 
				
			||||||
                    const int mode   = ((int32_t *) tensor->op_params)[2];
 | 
					                    const int mode       = ((int32_t *) tensor->op_params)[2];
 | 
				
			||||||
                    const int n_ctx  = ((int32_t *) tensor->op_params)[3];
 | 
					                    const int n_ctx      = ((int32_t *) tensor->op_params)[3];
 | 
				
			||||||
                    float freq_base;
 | 
					                    const int n_orig_ctx = ((int32_t *) tensor->op_params)[4];
 | 
				
			||||||
                    float freq_scale;
 | 
					                    float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow, xpos_base, xpos_down;
 | 
				
			||||||
                    float xpos_base;
 | 
					
 | 
				
			||||||
                    bool  xpos_down;
 | 
					                    memcpy(&freq_base,   (int32_t *) tensor->op_params +  5, sizeof(float));
 | 
				
			||||||
                    memcpy(&freq_base,  (int32_t *) tensor->op_params + 4, sizeof(float));
 | 
					                    memcpy(&freq_scale,  (int32_t *) tensor->op_params +  6, sizeof(float));
 | 
				
			||||||
                    memcpy(&freq_scale, (int32_t *) tensor->op_params + 5, sizeof(float));
 | 
					                    memcpy(&ext_factor,  (int32_t *) tensor->op_params +  7, sizeof(float));
 | 
				
			||||||
                    memcpy(&xpos_base,  (int32_t *) tensor->op_params + 6, sizeof(float));
 | 
					                    memcpy(&attn_factor, (int32_t *) tensor->op_params +  8, sizeof(float));
 | 
				
			||||||
                    memcpy(&xpos_down,  (int32_t *) tensor->op_params + 7, sizeof(bool));
 | 
					                    memcpy(&beta_fast,   (int32_t *) tensor->op_params +  9, sizeof(float));
 | 
				
			||||||
 | 
					                    memcpy(&beta_slow,   (int32_t *) tensor->op_params + 10, sizeof(float));
 | 
				
			||||||
 | 
					                    memcpy(&xpos_base,   (int32_t *) tensor->op_params + 11, sizeof(float));
 | 
				
			||||||
 | 
					                    memcpy(&xpos_down,   (int32_t *) tensor->op_params + 12, sizeof(bool));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
                    src0->grad = ggml_add_or_set(ctx,
 | 
					                    src0->grad = ggml_add_or_set(ctx,
 | 
				
			||||||
                            src0->grad,
 | 
					                            src0->grad,
 | 
				
			||||||
@@ -15579,8 +15397,13 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
 | 
				
			|||||||
                                n_dims,
 | 
					                                n_dims,
 | 
				
			||||||
                                mode,
 | 
					                                mode,
 | 
				
			||||||
                                n_ctx,
 | 
					                                n_ctx,
 | 
				
			||||||
 | 
					                                n_orig_ctx,
 | 
				
			||||||
                                freq_base,
 | 
					                                freq_base,
 | 
				
			||||||
                                freq_scale,
 | 
					                                freq_scale,
 | 
				
			||||||
 | 
					                                ext_factor,
 | 
				
			||||||
 | 
					                                attn_factor,
 | 
				
			||||||
 | 
					                                beta_fast,
 | 
				
			||||||
 | 
					                                beta_slow,
 | 
				
			||||||
                                xpos_base,
 | 
					                                xpos_base,
 | 
				
			||||||
                                xpos_down),
 | 
					                                xpos_down),
 | 
				
			||||||
                            zero_table);
 | 
					                            zero_table);
 | 
				
			||||||
@@ -15590,17 +15413,20 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
 | 
				
			|||||||
            {
 | 
					            {
 | 
				
			||||||
                if (src0->grad) {
 | 
					                if (src0->grad) {
 | 
				
			||||||
                    //const int n_past = ((int32_t *) tensor->op_params)[0];
 | 
					                    //const int n_past = ((int32_t *) tensor->op_params)[0];
 | 
				
			||||||
                    const int n_dims = ((int32_t *) tensor->op_params)[1];
 | 
					                    const int n_dims     = ((int32_t *) tensor->op_params)[1];
 | 
				
			||||||
                    const int mode   = ((int32_t *) tensor->op_params)[2];
 | 
					                    const int mode       = ((int32_t *) tensor->op_params)[2];
 | 
				
			||||||
                    const int n_ctx  = ((int32_t *) tensor->op_params)[3];
 | 
					                    const int n_ctx      = ((int32_t *) tensor->op_params)[3];
 | 
				
			||||||
                    float freq_base;
 | 
					                    const int n_orig_ctx = ((int32_t *) tensor->op_params)[4];
 | 
				
			||||||
                    float freq_scale;
 | 
					                    float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow, xpos_base, xpos_down;
 | 
				
			||||||
                    float xpos_base;
 | 
					
 | 
				
			||||||
                    bool  xpos_down;
 | 
					                    memcpy(&freq_base,   (int32_t *) tensor->op_params +  5, sizeof(float));
 | 
				
			||||||
                    memcpy(&freq_base,  (int32_t *) tensor->op_params + 4, sizeof(float));
 | 
					                    memcpy(&freq_scale,  (int32_t *) tensor->op_params +  6, sizeof(float));
 | 
				
			||||||
                    memcpy(&freq_scale, (int32_t *) tensor->op_params + 5, sizeof(float));
 | 
					                    memcpy(&ext_factor,  (int32_t *) tensor->op_params +  7, sizeof(float));
 | 
				
			||||||
                    memcpy(&xpos_base,  (int32_t *) tensor->op_params + 6, sizeof(float));
 | 
					                    memcpy(&attn_factor, (int32_t *) tensor->op_params +  8, sizeof(float));
 | 
				
			||||||
                    memcpy(&xpos_down,  (int32_t *) tensor->op_params + 7, sizeof(bool));
 | 
					                    memcpy(&beta_fast,   (int32_t *) tensor->op_params +  9, sizeof(float));
 | 
				
			||||||
 | 
					                    memcpy(&beta_slow,   (int32_t *) tensor->op_params + 10, sizeof(float));
 | 
				
			||||||
 | 
					                    memcpy(&xpos_base,   (int32_t *) tensor->op_params + 11, sizeof(float));
 | 
				
			||||||
 | 
					                    memcpy(&xpos_down,   (int32_t *) tensor->op_params + 12, sizeof(bool));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
                    src0->grad = ggml_add_or_set(ctx,
 | 
					                    src0->grad = ggml_add_or_set(ctx,
 | 
				
			||||||
                            src0->grad,
 | 
					                            src0->grad,
 | 
				
			||||||
@@ -15609,14 +15435,14 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
 | 
				
			|||||||
                                src1,
 | 
					                                src1,
 | 
				
			||||||
                                n_dims,
 | 
					                                n_dims,
 | 
				
			||||||
                                mode,
 | 
					                                mode,
 | 
				
			||||||
                                0,
 | 
					 | 
				
			||||||
                                n_ctx,
 | 
					                                n_ctx,
 | 
				
			||||||
 | 
					                                n_orig_ctx,
 | 
				
			||||||
                                freq_base,
 | 
					                                freq_base,
 | 
				
			||||||
                                freq_scale,
 | 
					                                freq_scale,
 | 
				
			||||||
                                0.0f,
 | 
					                                ext_factor,
 | 
				
			||||||
                                1.0f,
 | 
					                                attn_factor,
 | 
				
			||||||
                                0.0f,
 | 
					                                beta_fast,
 | 
				
			||||||
                                0.0f,
 | 
					                                beta_slow,
 | 
				
			||||||
                                xpos_base,
 | 
					                                xpos_base,
 | 
				
			||||||
                                xpos_down,
 | 
					                                xpos_down,
 | 
				
			||||||
                                false),
 | 
					                                false),
 | 
				
			||||||
 
 | 
				
			|||||||
							
								
								
									
										5
									
								
								ggml.h
									
									
									
									
									
								
							
							
						
						
									
										5
									
								
								ggml.h
									
									
									
									
									
								
							@@ -1372,8 +1372,13 @@ extern "C" {
 | 
				
			|||||||
            int                   n_dims,
 | 
					            int                   n_dims,
 | 
				
			||||||
            int                   mode,
 | 
					            int                   mode,
 | 
				
			||||||
            int                   n_ctx,
 | 
					            int                   n_ctx,
 | 
				
			||||||
 | 
					            int                   n_orig_ctx,
 | 
				
			||||||
            float                 freq_base,
 | 
					            float                 freq_base,
 | 
				
			||||||
            float                 freq_scale,
 | 
					            float                 freq_scale,
 | 
				
			||||||
 | 
					            float                 ext_factor,
 | 
				
			||||||
 | 
					            float                 attn_factor,
 | 
				
			||||||
 | 
					            float                 beta_fast,
 | 
				
			||||||
 | 
					            float                 beta_slow,
 | 
				
			||||||
            float                 xpos_base,
 | 
					            float                 xpos_base,
 | 
				
			||||||
            bool                  xpos_down);
 | 
					            bool                  xpos_down);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user