mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-10-29 08:41:22 +00:00 
			
		
		
		
	
							
								
								
									
										128
									
								
								ggml.c
									
									
									
									
									
								
							
							
						
						
									
										128
									
								
								ggml.c
									
									
									
									
									
								
							| @@ -2069,6 +2069,7 @@ static const char * GGML_OP_LABEL[GGML_OP_COUNT] = { | |||||||
|     "GELU", |     "GELU", | ||||||
|     "SILU", |     "SILU", | ||||||
|     "NORM", |     "NORM", | ||||||
|  |     "RMS_NORM", | ||||||
|  |  | ||||||
|     "MUL_MAT", |     "MUL_MAT", | ||||||
|  |  | ||||||
| @@ -2089,7 +2090,7 @@ static const char * GGML_OP_LABEL[GGML_OP_COUNT] = { | |||||||
|     "FLASH_FF", |     "FLASH_FF", | ||||||
| }; | }; | ||||||
|  |  | ||||||
| static_assert(GGML_OP_COUNT == 34, "GGML_OP_COUNT != 34"); | static_assert(GGML_OP_COUNT == 35, "GGML_OP_COUNT != 35"); | ||||||
|  |  | ||||||
| static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { | static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { | ||||||
|     "none", |     "none", | ||||||
| @@ -2112,6 +2113,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { | |||||||
|     "gelu(x)", |     "gelu(x)", | ||||||
|     "silu(x)", |     "silu(x)", | ||||||
|     "norm(x)", |     "norm(x)", | ||||||
|  |     "rms_norm(x)", | ||||||
|  |  | ||||||
|     "X*Y", |     "X*Y", | ||||||
|  |  | ||||||
| @@ -2132,7 +2134,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { | |||||||
|     "flash_ff(x)", |     "flash_ff(x)", | ||||||
| }; | }; | ||||||
|  |  | ||||||
| static_assert(GGML_OP_COUNT == 34, "GGML_OP_COUNT != 34"); | static_assert(GGML_OP_COUNT == 35, "GGML_OP_COUNT != 35"); | ||||||
|  |  | ||||||
| // | // | ||||||
| // ggml object | // ggml object | ||||||
| @@ -3618,6 +3620,39 @@ struct ggml_tensor * ggml_norm_inplace( | |||||||
|     return ggml_norm_impl(ctx, a, true); |     return ggml_norm_impl(ctx, a, true); | ||||||
| } | } | ||||||
|  |  | ||||||
|  | struct ggml_tensor * ggml_rms_norm_impl( | ||||||
|  |         struct ggml_context * ctx, | ||||||
|  |         struct ggml_tensor  * a, | ||||||
|  |         bool inplace) { | ||||||
|  |     bool is_node = false; | ||||||
|  |  | ||||||
|  |     if (!inplace && (a->grad)) { | ||||||
|  |         GGML_ASSERT(false); // TODO: implement backward | ||||||
|  |         is_node = true; | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); | ||||||
|  |  | ||||||
|  |     result->op   = GGML_OP_RMS_NORM; | ||||||
|  |     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; | ||||||
|  |     result->src0 = a; | ||||||
|  |     result->src1 = NULL; // TODO: maybe store epsilon here? | ||||||
|  |  | ||||||
|  |     return result; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | struct ggml_tensor * ggml_rms_norm( | ||||||
|  |         struct ggml_context * ctx, | ||||||
|  |         struct ggml_tensor  * a) { | ||||||
|  |     return ggml_rms_norm_impl(ctx, a, false); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | struct ggml_tensor * ggml_rms_norm_inplace( | ||||||
|  |         struct ggml_context * ctx, | ||||||
|  |         struct ggml_tensor  * a) { | ||||||
|  |     return ggml_rms_norm_impl(ctx, a, true); | ||||||
|  | } | ||||||
|  |  | ||||||
| // ggml_mul_mat | // ggml_mul_mat | ||||||
|  |  | ||||||
| struct ggml_tensor * ggml_mul_mat( | struct ggml_tensor * ggml_mul_mat( | ||||||
| @@ -5406,6 +5441,87 @@ static void ggml_compute_forward_norm( | |||||||
|     } |     } | ||||||
| } | } | ||||||
|  |  | ||||||
|  | static void ggml_compute_forward_rms_norm_f32( | ||||||
|  |         const struct ggml_compute_params * params, | ||||||
|  |         const struct ggml_tensor * src0, | ||||||
|  |         struct ggml_tensor * dst) { | ||||||
|  |     GGML_ASSERT(ggml_are_same_shape(src0, dst)); | ||||||
|  |  | ||||||
|  |     if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { | ||||||
|  |         return; | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     GGML_ASSERT(src0->nb[0] == sizeof(float)); | ||||||
|  |  | ||||||
|  |     const int ith = params->ith; | ||||||
|  |     const int nth = params->nth; | ||||||
|  |  | ||||||
|  |     const int ne00 = src0->ne[0]; | ||||||
|  |     const int ne01 = src0->ne[1]; | ||||||
|  |     const int ne02 = src0->ne[2]; | ||||||
|  |     const int ne03 = src0->ne[3]; | ||||||
|  |  | ||||||
|  |     const size_t nb01 = src0->nb[1]; | ||||||
|  |     const size_t nb02 = src0->nb[2]; | ||||||
|  |     const size_t nb03 = src0->nb[3]; | ||||||
|  |  | ||||||
|  |     const size_t nb1 = dst->nb[1]; | ||||||
|  |     const size_t nb2 = dst->nb[2]; | ||||||
|  |     const size_t nb3 = dst->nb[3]; | ||||||
|  |  | ||||||
|  |     const ggml_float eps = 1e-5f; // TODO: make this a parameter | ||||||
|  |  | ||||||
|  |     // TODO: optimize | ||||||
|  |     for (int i03 = 0; i03 < ne03; i03++) { | ||||||
|  |         for (int i02 = 0; i02 < ne02; i02++) { | ||||||
|  |             for (int i01 = ith; i01 < ne01; i01 += nth) { | ||||||
|  |                 const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03); | ||||||
|  |  | ||||||
|  |                 ggml_float mean = 0.0; | ||||||
|  |                 for (int i00 = 0; i00 < ne00; i00++) { | ||||||
|  |                     mean += x[i00] * x[i00]; | ||||||
|  |                 } | ||||||
|  |  | ||||||
|  |                 mean /= ne00; | ||||||
|  |  | ||||||
|  |                 float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3); | ||||||
|  |                  | ||||||
|  |                 memcpy(y, x, ne00 * sizeof(float)); | ||||||
|  |                 // for (int i00 = 0; i00 < ne00; i00++) { | ||||||
|  |                 //     y[i00] = x[i00]; | ||||||
|  |                 // } | ||||||
|  |  | ||||||
|  |                 const float scale = 1.0/sqrt(mean + eps); | ||||||
|  |  | ||||||
|  |                 ggml_vec_scale_f32(ne00, y, scale); | ||||||
|  |             } | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  | } | ||||||
|  |  | ||||||
|  | static void ggml_compute_forward_rms_norm( | ||||||
|  |         const struct ggml_compute_params * params, | ||||||
|  |         const struct ggml_tensor * src0, | ||||||
|  |         struct ggml_tensor * dst) { | ||||||
|  |     switch (src0->type) { | ||||||
|  |         case GGML_TYPE_F32: | ||||||
|  |             { | ||||||
|  |                 ggml_compute_forward_rms_norm_f32(params, src0, dst); | ||||||
|  |             } break; | ||||||
|  |         case GGML_TYPE_Q4_0: | ||||||
|  |         case GGML_TYPE_Q4_1: | ||||||
|  |         case GGML_TYPE_I8: | ||||||
|  |         case GGML_TYPE_I16: | ||||||
|  |         case GGML_TYPE_I32: | ||||||
|  |         case GGML_TYPE_F16: | ||||||
|  |         case GGML_TYPE_COUNT: | ||||||
|  |             { | ||||||
|  |                 GGML_ASSERT(false); | ||||||
|  |             } break; | ||||||
|  |     } | ||||||
|  | } | ||||||
|  |  | ||||||
|  |  | ||||||
| // ggml_compute_forward_mul_mat | // ggml_compute_forward_mul_mat | ||||||
|  |  | ||||||
| #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) | #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) | ||||||
| @@ -8522,6 +8638,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm | |||||||
|             { |             { | ||||||
|                 ggml_compute_forward_norm(params, tensor->src0, tensor); |                 ggml_compute_forward_norm(params, tensor->src0, tensor); | ||||||
|             } break; |             } break; | ||||||
|  |         case GGML_OP_RMS_NORM: | ||||||
|  |             { | ||||||
|  |                 ggml_compute_forward_rms_norm(params, tensor->src0, tensor); | ||||||
|  |             } break; | ||||||
|         case GGML_OP_MUL_MAT: |         case GGML_OP_MUL_MAT: | ||||||
|             { |             { | ||||||
|                 ggml_compute_forward_mul_mat(params, tensor->src0, tensor->src1, tensor); |                 ggml_compute_forward_mul_mat(params, tensor->src0, tensor->src1, tensor); | ||||||
| @@ -8764,6 +8884,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor | |||||||
|             { |             { | ||||||
|                 GGML_ASSERT(false); // TODO: not implemented |                 GGML_ASSERT(false); // TODO: not implemented | ||||||
|             } break; |             } break; | ||||||
|  |         case GGML_OP_RMS_NORM: | ||||||
|  |             { | ||||||
|  |                 GGML_ASSERT(false); // TODO: not implemented | ||||||
|  |             } break; | ||||||
|         case GGML_OP_MUL_MAT: |         case GGML_OP_MUL_MAT: | ||||||
|             { |             { | ||||||
|                 if (src0->grad) { |                 if (src0->grad) { | ||||||
|   | |||||||
							
								
								
									
										5
									
								
								ggml.h
									
									
									
									
									
								
							
							
						
						
									
										5
									
								
								ggml.h
									
									
									
									
									
								
							| @@ -230,6 +230,7 @@ enum ggml_op { | |||||||
|     GGML_OP_GELU, |     GGML_OP_GELU, | ||||||
|     GGML_OP_SILU, |     GGML_OP_SILU, | ||||||
|     GGML_OP_NORM, // normalize |     GGML_OP_NORM, // normalize | ||||||
|  |     GGML_OP_RMS_NORM, | ||||||
|  |  | ||||||
|     GGML_OP_MUL_MAT, |     GGML_OP_MUL_MAT, | ||||||
|  |  | ||||||
| @@ -482,6 +483,10 @@ struct ggml_tensor * ggml_norm( | |||||||
|         struct ggml_context * ctx, |         struct ggml_context * ctx, | ||||||
|         struct ggml_tensor  * a); |         struct ggml_tensor  * a); | ||||||
|  |  | ||||||
|  | struct ggml_tensor * ggml_rms_norm( | ||||||
|  |         struct ggml_context * ctx, | ||||||
|  |         struct ggml_tensor  * a); | ||||||
|  |  | ||||||
| // A: m rows, n columns | // A: m rows, n columns | ||||||
| // B: p rows, n columns (i.e. we transpose it internally) | // B: p rows, n columns (i.e. we transpose it internally) | ||||||
| // result is m columns, p rows | // result is m columns, p rows | ||||||
|   | |||||||
							
								
								
									
										6
									
								
								main.cpp
									
									
									
									
									
								
							
							
						
						
									
										6
									
								
								main.cpp
									
									
									
									
									
								
							| @@ -588,7 +588,7 @@ bool llama_eval( | |||||||
|  |  | ||||||
|         // norm |         // norm | ||||||
|         { |         { | ||||||
|             cur = ggml_norm(ctx0, inpL); |             cur = ggml_rms_norm(ctx0, inpL); | ||||||
|  |  | ||||||
|             // cur = attention_norm*cur |             // cur = attention_norm*cur | ||||||
|             cur = ggml_mul(ctx0, |             cur = ggml_mul(ctx0, | ||||||
| @@ -678,7 +678,7 @@ bool llama_eval( | |||||||
|         { |         { | ||||||
|             // norm |             // norm | ||||||
|             { |             { | ||||||
|                 cur = ggml_norm(ctx0, inpFF); |                 cur = ggml_rms_norm(ctx0, inpFF); | ||||||
|  |  | ||||||
|                 // cur = ffn_norm*cur |                 // cur = ffn_norm*cur | ||||||
|                 cur = ggml_mul(ctx0, |                 cur = ggml_mul(ctx0, | ||||||
| @@ -713,7 +713,7 @@ bool llama_eval( | |||||||
|  |  | ||||||
|     // norm |     // norm | ||||||
|     { |     { | ||||||
|         inpL = ggml_norm(ctx0, inpL); |         inpL = ggml_rms_norm(ctx0, inpL); | ||||||
|  |  | ||||||
|         // inpL = norm*inpL |         // inpL = norm*inpL | ||||||
|         inpL = ggml_mul(ctx0, |         inpL = ggml_mul(ctx0, | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user
	 hoangmit
					hoangmit