Mirror of https://github.com/ggml-org/llama.cpp.git (synced 2025-11-04 09:32:00 +00:00)
			
		
		
		
	metal : batch rows copy in a single threadgroup (#14384)
* metal : batch rows copy in a single threadgroup

  ggml-ci

* metal : handle some edge cases when threadgroup size is not a power of 2

  ggml-ci
This commit is contained in:
		@@ -2450,6 +2450,7 @@ static bool ggml_metal_encode_node(
 | 
				
			|||||||
                    nth *= 2;
 | 
					                    nth *= 2;
 | 
				
			||||||
                }
 | 
					                }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                nth = MIN(nth, (int) pipeline.maxTotalThreadsPerThreadgroup);
 | 
				
			||||||
                nth = MIN(nth, ne00);
 | 
					                nth = MIN(nth, ne00);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
                ggml_metal_kargs_sum_rows args = {
 | 
					                ggml_metal_kargs_sum_rows args = {
 | 
				
			||||||
@@ -3780,6 +3781,7 @@ static bool ggml_metal_encode_node(
 | 
				
			|||||||
                    nth *= 2;
 | 
					                    nth *= 2;
 | 
				
			||||||
                }
 | 
					                }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                nth = MIN(nth, (int) pipeline.maxTotalThreadsPerThreadgroup);
 | 
				
			||||||
                nth = MIN(nth, ne00/4);
 | 
					                nth = MIN(nth, ne00/4);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
                ggml_metal_kargs_rms_norm args = {
 | 
					                ggml_metal_kargs_rms_norm args = {
 | 
				
			||||||
@@ -3816,6 +3818,7 @@ static bool ggml_metal_encode_node(
 | 
				
			|||||||
                    nth *= 2;
 | 
					                    nth *= 2;
 | 
				
			||||||
                }
 | 
					                }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                nth = MIN(nth, (int) pipeline.maxTotalThreadsPerThreadgroup);
 | 
				
			||||||
                nth = MIN(nth, ne00/4);
 | 
					                nth = MIN(nth, ne00/4);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
                ggml_metal_kargs_l2_norm args = {
 | 
					                ggml_metal_kargs_l2_norm args = {
 | 
				
			||||||
@@ -3888,6 +3891,7 @@ static bool ggml_metal_encode_node(
 | 
				
			|||||||
                    nth *= 2;
 | 
					                    nth *= 2;
 | 
				
			||||||
                }
 | 
					                }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                nth = MIN(nth, (int) pipeline.maxTotalThreadsPerThreadgroup);
 | 
				
			||||||
                nth = MIN(nth, ne00/4);
 | 
					                nth = MIN(nth, ne00/4);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
                ggml_metal_kargs_norm args = {
 | 
					                ggml_metal_kargs_norm args = {
 | 
				
			||||||
@@ -4974,8 +4978,39 @@ static bool ggml_metal_encode_node(
 | 
				
			|||||||
                    default: GGML_ABORT("not implemented");
 | 
					                    default: GGML_ABORT("not implemented");
 | 
				
			||||||
                }
 | 
					                }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                GGML_ASSERT(ne00 % ggml_blck_size(src0->type) == 0);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                // TODO: support
 | 
				
			||||||
 | 
					                //const int32_t nk00 = ne00/ggml_blck_size(dst->type);
 | 
				
			||||||
 | 
					                const int32_t nk00 = ne00;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                int nth = 32; // SIMD width
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                while (nth < nk00 && nth < (int) pipeline.maxTotalThreadsPerThreadgroup) {
 | 
				
			||||||
 | 
					                    nth *= 2;
 | 
				
			||||||
 | 
					                }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                nth = MIN(nth, (int) pipeline.maxTotalThreadsPerThreadgroup);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                // when rows are small, we can batch them together in a single threadgroup
 | 
				
			||||||
 | 
					                int nrptg = 1;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                // TODO: relax this constraint in the future
 | 
				
			||||||
 | 
					                if (ggml_blck_size(src0->type) == 1 && ggml_blck_size(dst->type) == 1) {
 | 
				
			||||||
 | 
					                    if (nth > nk00) {
 | 
				
			||||||
 | 
					                        nrptg = (nth + nk00 - 1)/nk00;
 | 
				
			||||||
 | 
					                        nth   = nk00;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                        if (nrptg*nth > (int) pipeline.maxTotalThreadsPerThreadgroup) {
 | 
				
			||||||
 | 
					                            nrptg--;
 | 
				
			||||||
 | 
					                        }
 | 
				
			||||||
 | 
					                    }
 | 
				
			||||||
 | 
					                }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                nth = MIN(nth, nk00);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
                ggml_metal_kargs_cpy args = {
 | 
					                ggml_metal_kargs_cpy args = {
 | 
				
			||||||
                    /*.ne00 =*/ ne00,
 | 
					                    /*.ne00 =*/ nk00,
 | 
				
			||||||
                    /*.ne01 =*/ ne01,
 | 
					                    /*.ne01 =*/ ne01,
 | 
				
			||||||
                    /*.ne02 =*/ ne02,
 | 
					                    /*.ne02 =*/ ne02,
 | 
				
			||||||
                    /*.ne03 =*/ ne03,
 | 
					                    /*.ne03 =*/ ne03,
 | 
				
			||||||
@@ -4998,11 +5033,7 @@ static bool ggml_metal_encode_node(
 | 
				
			|||||||
                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
 | 
					                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
 | 
				
			||||||
                [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];
 | 
					                [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];
 | 
				
			||||||
 | 
					
 | 
				
			||||||
                GGML_ASSERT(ne00 % ggml_blck_size(src0->type) == 0);
 | 
					                [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nrptg - 1)/nrptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, nrptg, 1)];
 | 
				
			||||||
                int nth = MIN(1024, ne00/ggml_blck_size(src0->type));
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            } break;
 | 
					            } break;
 | 
				
			||||||
        case GGML_OP_SET:
 | 
					        case GGML_OP_SET:
 | 
				
			||||||
            {
 | 
					            {
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -4306,11 +4306,16 @@ kernel void kernel_cpy(
 | 
				
			|||||||
        device  const char * src0,
 | 
					        device  const char * src0,
 | 
				
			||||||
        device        char * dst,
 | 
					        device        char * dst,
 | 
				
			||||||
        uint3   tgpig[[threadgroup_position_in_grid]],
 | 
					        uint3   tgpig[[threadgroup_position_in_grid]],
 | 
				
			||||||
 | 
					        uint    tiitg[[thread_index_in_threadgroup]],
 | 
				
			||||||
        ushort3 tpitg[[thread_position_in_threadgroup]],
 | 
					        ushort3 tpitg[[thread_position_in_threadgroup]],
 | 
				
			||||||
        ushort3   ntg[[threads_per_threadgroup]]) {
 | 
					        ushort3  tptg[[threads_per_threadgroup]]) {
 | 
				
			||||||
    const int i03 = tgpig[2];
 | 
					    const int i03 = tgpig[2];
 | 
				
			||||||
    const int i02 = tgpig[1];
 | 
					    const int i02 = tgpig[1];
 | 
				
			||||||
    const int i01 = tgpig[0];
 | 
					    const int i01 = tgpig[0]*tptg.y + tiitg/tptg.x;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if (i01 >= args.ne01) {
 | 
				
			||||||
 | 
					        return;
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00;
 | 
					    const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -4321,7 +4326,7 @@ kernel void kernel_cpy(
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    device T1 * dst_data = (device T1 *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
 | 
					    device T1 * dst_data = (device T1 *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    for (int64_t i00 = tpitg.x; i00 < args.ne00; i00 += ntg.x) {
 | 
					    for (int64_t i00 = tiitg%tptg.x; i00 < args.ne00; i00 += tptg.x) {
 | 
				
			||||||
        device const T0 * src = (device T0 *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
 | 
					        device const T0 * src = (device T0 *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
 | 
				
			||||||
        dst_data[i00] = (T1) src[0];
 | 
					        dst_data[i00] = (T1) src[0];
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user