mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-11-04 09:32:00 +00:00 
			
		
		
		
	ggml : rms_norm in chunks
This commit is contained in:
		
							
								
								
									
										33
									
								
								ggml.c
									
									
									
									
									
								
							
							
						
						
									
										33
									
								
								ggml.c
									
									
									
									
									
								
							@@ -9033,18 +9033,20 @@ static void ggml_compute_forward_rms_norm_f32(
 | 
				
			|||||||
    GGML_ASSERT(ggml_are_same_shape(src0, dst));
 | 
					    GGML_ASSERT(ggml_are_same_shape(src0, dst));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
 | 
					    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
 | 
				
			||||||
 | 
					        atomic_store(params->aic, 0);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        return;
 | 
					        return;
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    GGML_ASSERT(src0->nb[0] == sizeof(float));
 | 
					    GGML_ASSERT(src0->nb[0] == sizeof(float));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    const int ith = params->ith;
 | 
					    const int ith = params->ith; UNUSED(ith);
 | 
				
			||||||
    const int nth = params->nth;
 | 
					    const int nth = params->nth;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    const int64_t ne00 = src0->ne[0];
 | 
					    const int64_t ne00 = src0->ne[0];
 | 
				
			||||||
    const int64_t ne01 = src0->ne[1];
 | 
					    const int64_t ne01 = src0->ne[1];
 | 
				
			||||||
    const int64_t ne02 = src0->ne[2];
 | 
					    const int64_t ne02 = src0->ne[2];
 | 
				
			||||||
    const int64_t ne03 = src0->ne[3];
 | 
					    const int64_t ne03 = src0->ne[3]; UNUSED(ne03);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    const size_t nb01 = src0->nb[1];
 | 
					    const size_t nb01 = src0->nb[1];
 | 
				
			||||||
    const size_t nb02 = src0->nb[2];
 | 
					    const size_t nb02 = src0->nb[2];
 | 
				
			||||||
@@ -9056,10 +9058,22 @@ static void ggml_compute_forward_rms_norm_f32(
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    const float eps = 1e-6f; // TODO: make this a parameter
 | 
					    const float eps = 1e-6f; // TODO: make this a parameter
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    // TODO: optimize
 | 
					    const int nr = ggml_nrows(src0);
 | 
				
			||||||
    for (int64_t i03 = 0; i03 < ne03; i03++) {
 | 
					    const int dr = (nr + 8*nth - 1)/(8*nth);
 | 
				
			||||||
        for (int64_t i02 = 0; i02 < ne02; i02++) {
 | 
					
 | 
				
			||||||
            for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
 | 
					    while (true) {
 | 
				
			||||||
 | 
					        const int ir0 = atomic_fetch_add(params->aic, dr);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        for (int ir = ir0; ir < ir0 + dr; ++ir) {
 | 
				
			||||||
 | 
					            if (ir >= nr) {
 | 
				
			||||||
 | 
					                break;
 | 
				
			||||||
 | 
					            }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            // src0 indices
 | 
				
			||||||
 | 
					            const int i03 = ir/(ne02*ne01);
 | 
				
			||||||
 | 
					            const int i02 = (ir - i03*ne02*ne01)/ne01;
 | 
				
			||||||
 | 
					            const int i01 = (ir - i03*ne02*ne01 - i02*ne01);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
 | 
					            const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            ggml_float sum = 0.0;
 | 
					            ggml_float sum = 0.0;
 | 
				
			||||||
@@ -9080,6 +9094,9 @@ static void ggml_compute_forward_rms_norm_f32(
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
            ggml_vec_scale_f32(ne00, y, scale);
 | 
					            ggml_vec_scale_f32(ne00, y, scale);
 | 
				
			||||||
        }
 | 
					        }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if (ir0 + dr >= nr) {
 | 
				
			||||||
 | 
					            break;
 | 
				
			||||||
        }
 | 
					        }
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
@@ -9754,11 +9771,9 @@ static void ggml_compute_forward_mul_mat_q_f32(
 | 
				
			|||||||
    const int nb2  = dst->nb[2];
 | 
					    const int nb2  = dst->nb[2];
 | 
				
			||||||
    const int nb3  = dst->nb[3];
 | 
					    const int nb3  = dst->nb[3];
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    const int ith = params->ith;
 | 
					    const int ith = params->ith; UNUSED(ith);
 | 
				
			||||||
    const int nth = params->nth;
 | 
					    const int nth = params->nth;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    UNUSED(ith);
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    GGML_ASSERT(ne02 == ne12);
 | 
					    GGML_ASSERT(ne02 == ne12);
 | 
				
			||||||
    GGML_ASSERT(ne03 == ne13);
 | 
					    GGML_ASSERT(ne03 == ne13);
 | 
				
			||||||
    GGML_ASSERT(ne2  == ne12);
 | 
					    GGML_ASSERT(ne2  == ne12);
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user