mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-11-03 09:22:01 +00:00 
			
		
		
		
	metal : warp-based reduce for rms_norm
This commit is contained in:
		@@ -1358,7 +1358,11 @@ void ggml_metal_graph_compute(
 | 
				
			|||||||
                            float eps;
 | 
					                            float eps;
 | 
				
			||||||
                            memcpy(&eps, dst->op_params, sizeof(float));
 | 
					                            memcpy(&eps, dst->op_params, sizeof(float));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
                            const int nth = MIN(512, ne00);
 | 
					                            int nth = 32; // SIMD width
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                            while (nth < ne00/4 && nth < 1024) {
 | 
				
			||||||
 | 
					                                nth *= 2;
 | 
				
			||||||
 | 
					                            }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
                            [encoder setComputePipelineState:ctx->pipeline_rms_norm];
 | 
					                            [encoder setComputePipelineState:ctx->pipeline_rms_norm];
 | 
				
			||||||
                            [encoder setBuffer:id_src0 offset:offs_src0        atIndex:0];
 | 
					                            [encoder setBuffer:id_src0 offset:offs_src0        atIndex:0];
 | 
				
			||||||
@@ -1366,7 +1370,7 @@ void ggml_metal_graph_compute(
 | 
				
			|||||||
                            [encoder setBytes:&ne00    length:sizeof( int64_t) atIndex:2];
 | 
					                            [encoder setBytes:&ne00    length:sizeof( int64_t) atIndex:2];
 | 
				
			||||||
                            [encoder setBytes:&nb01    length:sizeof(uint64_t) atIndex:3];
 | 
					                            [encoder setBytes:&nb01    length:sizeof(uint64_t) atIndex:3];
 | 
				
			||||||
                            [encoder setBytes:&eps     length:sizeof(   float) atIndex:4];
 | 
					                            [encoder setBytes:&eps     length:sizeof(   float) atIndex:4];
 | 
				
			||||||
                            [encoder setThreadgroupMemoryLength:GGML_PAD(nth/32*sizeof(float), 16) atIndex:0];
 | 
					                            [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
 | 
				
			||||||
 | 
					
 | 
				
			||||||
                            const int64_t nrows = ggml_nrows(src0);
 | 
					                            const int64_t nrows = ggml_nrows(src0);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -447,14 +447,13 @@ kernel void kernel_rms_norm(
 | 
				
			|||||||
        constant   int64_t & ne00,
 | 
					        constant   int64_t & ne00,
 | 
				
			||||||
        constant  uint64_t & nb01,
 | 
					        constant  uint64_t & nb01,
 | 
				
			||||||
        constant     float & eps,
 | 
					        constant     float & eps,
 | 
				
			||||||
        threadgroup float  * sum [[threadgroup(0)]],
 | 
					        threadgroup float  * buf [[threadgroup(0)]],
 | 
				
			||||||
        uint tgpig[[threadgroup_position_in_grid]],
 | 
					        uint tgpig[[threadgroup_position_in_grid]],
 | 
				
			||||||
        uint tpitg[[thread_position_in_threadgroup]],
 | 
					        uint tpitg[[thread_position_in_threadgroup]],
 | 
				
			||||||
        uint sgitg[[simdgroup_index_in_threadgroup]],
 | 
					        uint sgitg[[simdgroup_index_in_threadgroup]],
 | 
				
			||||||
        uint tiisg[[thread_index_in_simdgroup]],
 | 
					        uint tiisg[[thread_index_in_simdgroup]],
 | 
				
			||||||
        uint   ntg[[threads_per_threadgroup]]) {
 | 
					        uint   ntg[[threads_per_threadgroup]]) {
 | 
				
			||||||
    device const float4 * x = (device const float4 *) ((device const char *) src0 + tgpig*nb01);
 | 
					    device const float4 * x = (device const float4 *) ((device const char *) src0 + tgpig*nb01);
 | 
				
			||||||
    device const float  * x_scalar = (device const float  *) x;
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    float4 sumf = 0;
 | 
					    float4 sumf = 0;
 | 
				
			||||||
    float all_sum = 0;
 | 
					    float all_sum = 0;
 | 
				
			||||||
@@ -465,40 +464,30 @@ kernel void kernel_rms_norm(
 | 
				
			|||||||
    }
 | 
					    }
 | 
				
			||||||
    all_sum = sumf[0] + sumf[1] + sumf[2] + sumf[3];
 | 
					    all_sum = sumf[0] + sumf[1] + sumf[2] + sumf[3];
 | 
				
			||||||
    all_sum = simd_sum(all_sum);
 | 
					    all_sum = simd_sum(all_sum);
 | 
				
			||||||
 | 
					    if (ntg > N_SIMDWIDTH) {
 | 
				
			||||||
 | 
					        if (sgitg == 0) {
 | 
				
			||||||
 | 
					            buf[tiisg] = 0.0f;
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        threadgroup_barrier(mem_flags::mem_threadgroup);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        if (tiisg == 0) {
 | 
					        if (tiisg == 0) {
 | 
				
			||||||
        sum[sgitg] = all_sum;
 | 
					            buf[sgitg] = all_sum;
 | 
				
			||||||
        }
 | 
					        }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        threadgroup_barrier(mem_flags::mem_threadgroup);
 | 
					        threadgroup_barrier(mem_flags::mem_threadgroup);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    // broadcast, simd group number is ntg / 32
 | 
					        all_sum = buf[tiisg];
 | 
				
			||||||
    for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
 | 
					        all_sum = simd_sum(all_sum);
 | 
				
			||||||
       if (tpitg < i) {
 | 
					 | 
				
			||||||
           sum[tpitg] += sum[tpitg + i];
 | 
					 | 
				
			||||||
       }
 | 
					 | 
				
			||||||
    }
 | 
					 | 
				
			||||||
    if (tpitg == 0) {
 | 
					 | 
				
			||||||
        for (int i = 4 * (ne00 / 4); i < ne00; i++) {
 | 
					 | 
				
			||||||
            sum[0] += x_scalar[i];
 | 
					 | 
				
			||||||
        }
 | 
					 | 
				
			||||||
        sum[0] /= ne00;
 | 
					 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    threadgroup_barrier(mem_flags::mem_threadgroup);
 | 
					    const float mean  = all_sum/ne00;
 | 
				
			||||||
 | 
					 | 
				
			||||||
    const float mean  = sum[0];
 | 
					 | 
				
			||||||
    const float scale = 1.0f/sqrt(mean + eps);
 | 
					    const float scale = 1.0f/sqrt(mean + eps);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    device float4 * y = (device float4 *) (dst + tgpig*ne00);
 | 
					    device float4 * y = (device float4 *) (dst + tgpig*ne00);
 | 
				
			||||||
    device float * y_scalar = (device float *) y;
 | 
					 | 
				
			||||||
    for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
 | 
					    for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
 | 
				
			||||||
        y[i00] = x[i00] * scale;
 | 
					        y[i00] = x[i00] * scale;
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
    if (tpitg == 0) {
 | 
					 | 
				
			||||||
        for (int i00 = 4 * (ne00 / 4); i00 < ne00; i00++) {
 | 
					 | 
				
			||||||
            y_scalar[i00] = x_scalar[i00] * scale;
 | 
					 | 
				
			||||||
        }
 | 
					 | 
				
			||||||
    }
 | 
					 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// function for calculate inner product between half a q4_0 block and 16 floats (yl), sumy is SUM(yl[i])
 | 
					// function for calculate inner product between half a q4_0 block and 16 floats (yl), sumy is SUM(yl[i])
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user