mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-10-31 08:51:55 +00:00 
			
		
		
		
	metal : simplify soft max kernel
ggml-ci
This commit is contained in:
		| @@ -203,9 +203,9 @@ kernel void kernel_soft_max( | |||||||
|     device       float * pdst  =        dst  + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; |     device       float * pdst  =        dst  + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; | ||||||
|  |  | ||||||
|     // parallel max |     // parallel max | ||||||
|     float lmax = (tpitg < ne00) ? (psrc0[tpitg]*scale + (pmask ? pmask[tpitg] : 0.0f)) : -INFINITY; |     float lmax = -INFINITY; | ||||||
|  |  | ||||||
|     for (int i00 = tpitg + ntg; i00 < ne00; i00 += ntg) { |     for (int i00 = tpitg; i00 < ne00; i00 += ntg) { | ||||||
|         lmax = MAX(lmax, psrc0[i00]*scale + (pmask ? pmask[i00] : 0.0f)); |         lmax = MAX(lmax, psrc0[i00]*scale + (pmask ? pmask[i00] : 0.0f)); | ||||||
|     } |     } | ||||||
|  |  | ||||||
| @@ -284,9 +284,9 @@ kernel void kernel_soft_max_4( | |||||||
|     device       float4 * pdst4 =        (device       float4 *)(dst  + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); |     device       float4 * pdst4 =        (device       float4 *)(dst  + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); | ||||||
|  |  | ||||||
|     // parallel max |     // parallel max | ||||||
|     float4 lmax4 = tpitg < ne00/4 ? (psrc4[tpitg]*scale + (pmask ? pmask[tpitg] : 0.0f)) : -INFINITY; |     float4 lmax4 = -INFINITY; | ||||||
|  |  | ||||||
|     for (int i00 = tpitg + ntg; i00 < ne00/4; i00 += ntg) { |     for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) { | ||||||
|         lmax4 = fmax(lmax4, psrc4[i00]*scale + (pmask ? pmask[i00] : 0.0f)); |         lmax4 = fmax(lmax4, psrc4[i00]*scale + (pmask ? pmask[i00] : 0.0f)); | ||||||
|     } |     } | ||||||
|  |  | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user
	 Georgi Gerganov
					Georgi Gerganov