mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-10-31 08:51:55 +00:00 
			
		
		
		
	| @@ -1001,11 +1001,15 @@ void ggml_metal_graph_compute( | ||||
|                         } break; | ||||
|                     case GGML_OP_SOFT_MAX: | ||||
|                         { | ||||
|                             const int nth = MIN(32, ne00); | ||||
|                             int nth = 32; // SIMD width | ||||
|  | ||||
|                             if (ne00%4 == 0) { | ||||
|                                 [encoder setComputePipelineState:ctx->pipeline_soft_max_4]; | ||||
|                             } else { | ||||
|                                 do { | ||||
|                                     nth *= 2; | ||||
|                                 } while (nth <= ne00 && nth <= 1024); | ||||
|                                 nth /= 2; | ||||
|                                 [encoder setComputePipelineState:ctx->pipeline_soft_max]; | ||||
|                             } | ||||
|                             [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; | ||||
| @@ -1013,8 +1017,9 @@ void ggml_metal_graph_compute( | ||||
|                             [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2]; | ||||
|                             [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3]; | ||||
|                             [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4]; | ||||
|                             [encoder setThreadgroupMemoryLength:nth/32*sizeof(float) atIndex:0]; | ||||
|  | ||||
|                             [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; | ||||
|                             [encoder dispatchThreadgroups:MTLSizeMake(ne01*ne02*ne03, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; | ||||
|                         } break; | ||||
|                     case GGML_OP_DIAG_MASK_INF: | ||||
|                         { | ||||
|   | ||||
							
								
								
									
										129
									
								
								ggml-metal.metal
									
									
									
									
									
								
							
							
						
						
									
										129
									
								
								ggml-metal.metal
									
									
									
									
									
								
							| @@ -184,36 +184,73 @@ kernel void kernel_soft_max( | ||||
|         constant   int64_t & ne00, | ||||
|         constant   int64_t & ne01, | ||||
|         constant   int64_t & ne02, | ||||
|         uint3 tgpig[[threadgroup_position_in_grid]], | ||||
|         uint3 tpitg[[thread_position_in_threadgroup]], | ||||
|         uint3   ntg[[threads_per_threadgroup]]) { | ||||
|     const int64_t i03 = tgpig[2]; | ||||
|     const int64_t i02 = tgpig[1]; | ||||
|     const int64_t i01 = tgpig[0]; | ||||
|         threadgroup float  * buf [[threadgroup(0)]], | ||||
|         uint  tgpig[[threadgroup_position_in_grid]], | ||||
|         uint  tpitg[[thread_position_in_threadgroup]], | ||||
|         uint  sgitg[[simdgroup_index_in_threadgroup]], | ||||
|         uint  tiisg[[thread_index_in_simdgroup]], | ||||
|         uint    ntg[[threads_per_threadgroup]]) { | ||||
|     const int64_t i03 = (tgpig) / (ne02*ne01); | ||||
|     const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01; | ||||
|     const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01); | ||||
|  | ||||
|     device const float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; | ||||
|     device       float * pdst  = dst  + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; | ||||
|  | ||||
|     // parallel max | ||||
|     float lmax = tpitg[0] < ne00 ? psrc0[tpitg[0]] : -INFINITY; | ||||
|     for (int i00 = tpitg[0] + ntg[0]; i00 < ne00; i00 += ntg[0]) { | ||||
|     float lmax = tpitg < ne00 ? psrc0[tpitg] : -INFINITY; | ||||
|  | ||||
|     for (int i00 = tpitg + ntg; i00 < ne00; i00 += ntg) { | ||||
|         lmax = MAX(lmax, psrc0[i00]); | ||||
|     } | ||||
|     const float max = simd_max(lmax); | ||||
|  | ||||
|     float max = simd_max(lmax); | ||||
|     if (tiisg == 0) { | ||||
|         buf[sgitg] = max; | ||||
|     } | ||||
|  | ||||
|     threadgroup_barrier(mem_flags::mem_threadgroup); | ||||
|  | ||||
|     // broadcast, simd group number is ntg / 32 | ||||
|     for (uint i = ntg / 32 / 2; i > 0; i /= 2) { | ||||
|        if (tpitg < i) { | ||||
|            buf[tpitg] = MAX(buf[tpitg], buf[tpitg + i]); | ||||
|        } | ||||
|     } | ||||
|  | ||||
|     threadgroup_barrier(mem_flags::mem_threadgroup); | ||||
|  | ||||
|     max = buf[0]; | ||||
|  | ||||
|     // parallel sum | ||||
|     float lsum = 0.0f; | ||||
|     for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) { | ||||
|     for (int i00 = tpitg; i00 < ne00; i00 += ntg) { | ||||
|         const float exp_psrc0 = exp(psrc0[i00] - max); | ||||
|         lsum += exp_psrc0; | ||||
|         // Remember the result of exp here. exp is expensive, so we really do not | ||||
|         // whish to compute it twice. | ||||
|         // wish to compute it twice. | ||||
|         pdst[i00] = exp_psrc0; | ||||
|     } | ||||
|  | ||||
|     const float sum = simd_sum(lsum); | ||||
|     float sum = simd_sum(lsum); | ||||
|     if (tiisg == 0) { | ||||
|         buf[sgitg] = sum; | ||||
|     } | ||||
|  | ||||
|     for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) { | ||||
|     threadgroup_barrier(mem_flags::mem_threadgroup); | ||||
|  | ||||
|     // broadcast, simd group number is ntg / 32 | ||||
|     for (uint i = ntg / 32 / 2; i > 0; i /= 2) { | ||||
|        if (tpitg < i) { | ||||
|            buf[tpitg] += buf[tpitg + i]; | ||||
|        } | ||||
|     } | ||||
|  | ||||
|     threadgroup_barrier(mem_flags::mem_threadgroup); | ||||
|  | ||||
|     sum = buf[0]; | ||||
|  | ||||
|     for (int i00 = tpitg; i00 < ne00; i00 += ntg) { | ||||
|         pdst[i00] /= sum; | ||||
|     } | ||||
| } | ||||
| @@ -224,37 +261,73 @@ kernel void kernel_soft_max_4( | ||||
|         constant   int64_t & ne00, | ||||
|         constant   int64_t & ne01, | ||||
|         constant   int64_t & ne02, | ||||
|         uint3 tgpig[[threadgroup_position_in_grid]], | ||||
|         uint3 tpitg[[thread_position_in_threadgroup]], | ||||
|         uint3   ntg[[threads_per_threadgroup]]) { | ||||
|     const int64_t i03 = tgpig[2]; | ||||
|     const int64_t i02 = tgpig[1]; | ||||
|     const int64_t i01 = tgpig[0]; | ||||
|         threadgroup float  * buf [[threadgroup(0)]], | ||||
|         uint  tgpig[[threadgroup_position_in_grid]], | ||||
|         uint  tpitg[[thread_position_in_threadgroup]], | ||||
|         uint  sgitg[[simdgroup_index_in_threadgroup]], | ||||
|         uint  tiisg[[thread_index_in_simdgroup]], | ||||
|         uint    ntg[[threads_per_threadgroup]]) { | ||||
|     const int64_t i03 = (tgpig) / (ne02*ne01); | ||||
|     const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01; | ||||
|     const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01); | ||||
|  | ||||
|     device const float4 * psrc4 = (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); | ||||
|     device       float4 * pdst4 = (device       float4 *)(dst  + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); | ||||
|  | ||||
|     // parallel max | ||||
|     float4 lmax4 = tpitg[0] < ne00/4 ? psrc4[tpitg[0]] : -INFINITY; | ||||
|     for (int i00 = tpitg[0] + ntg[0]; i00 < ne00/4; i00 += ntg[0]) { | ||||
|     float4 lmax4 = tpitg < ne00/4 ? psrc4[tpitg] : -INFINITY; | ||||
|  | ||||
|     for (int i00 = tpitg + ntg; i00 < ne00/4; i00 += ntg) { | ||||
|         lmax4 = fmax(lmax4, psrc4[i00]); | ||||
|     } | ||||
|     float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3])); | ||||
|  | ||||
|     const float max = simd_max(lmax); | ||||
|     const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3])); | ||||
|     float max = simd_max(lmax); | ||||
|     if (tiisg == 0) { | ||||
|         buf[sgitg] = max; | ||||
|     } | ||||
|  | ||||
|     threadgroup_barrier(mem_flags::mem_threadgroup); | ||||
|  | ||||
|     // broadcast, simd group number is ntg / 32 | ||||
|     for (uint i = ntg / 32 / 2; i > 0; i /= 2) { | ||||
|        if (tpitg < i) { | ||||
|            buf[tpitg] = MAX(buf[tpitg], buf[tpitg + i]); | ||||
|        } | ||||
|     } | ||||
|  | ||||
|     threadgroup_barrier(mem_flags::mem_threadgroup); | ||||
|  | ||||
|     max = buf[0]; | ||||
|  | ||||
|     // parallel sum | ||||
|     float4 lsum4 = 0.0f; | ||||
|     for (int i00 = tpitg[0]; i00 < ne00/4; i00 += ntg[0]) { | ||||
|     for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) { | ||||
|         const float4 exp_psrc4 = exp(psrc4[i00] - max); | ||||
|         lsum4 += exp_psrc4; | ||||
|         pdst4[i00] = exp_psrc4; | ||||
|     } | ||||
|     float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3]; | ||||
|  | ||||
|     const float sum = simd_sum(lsum); | ||||
|     const float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3]; | ||||
|     float sum = simd_sum(lsum); | ||||
|     if (tiisg == 0) { | ||||
|         buf[sgitg] = sum; | ||||
|     } | ||||
|  | ||||
|     for (int i00 = tpitg[0]; i00 < ne00/4; i00 += ntg[0]) { | ||||
|     threadgroup_barrier(mem_flags::mem_threadgroup); | ||||
|  | ||||
|     // broadcast, simd group number is ntg / 32 | ||||
|     for (uint i = ntg / 32 / 2; i > 0; i /= 2) { | ||||
|        if (tpitg < i) { | ||||
|            buf[tpitg] += buf[tpitg + i]; | ||||
|        } | ||||
|     } | ||||
|  | ||||
|     threadgroup_barrier(mem_flags::mem_threadgroup); | ||||
|  | ||||
|     sum = buf[0]; | ||||
|  | ||||
|     for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) { | ||||
|         pdst4[i00] /= sum; | ||||
|     } | ||||
| } | ||||
| @@ -274,7 +347,7 @@ kernel void kernel_diag_mask_inf( | ||||
|         dst[i02*ne01*ne00 + i01*ne00 + i00] = -INFINITY; | ||||
|     } else { | ||||
|         dst[i02*ne01*ne00 + i01*ne00 + i00] = src0[i02*ne01*ne00 + i01*ne00 + i00]; | ||||
|      } | ||||
|     } | ||||
| } | ||||
|  | ||||
| kernel void kernel_diag_mask_inf_8( | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Georgi Gerganov
					Georgi Gerganov