mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-10-31 08:51:55 +00:00 
			
		
		
		
	Massive improvement for TG for fp16
This commit is contained in:
		| @@ -534,14 +534,27 @@ kernel void kernel_mul_mat_f16_f32_1row( | |||||||
|     device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12); |     device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12); | ||||||
|  |  | ||||||
|     float sumf = 0; |     float sumf = 0; | ||||||
|     for (int i = tiisg; i < ne00; i += 32) { |     if (ne00 < 128) { | ||||||
|         sumf += (float) x[i] * (float) y[i]; |         for (int i = tiisg; i < ne00; i += 32) { | ||||||
|  |             sumf += (float) x[i] * (float) y[i]; | ||||||
|  |         } | ||||||
|  |         float all_sum = simd_sum(sumf); | ||||||
|  |         if (tiisg == 0) { | ||||||
|  |             dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; | ||||||
|  |         } | ||||||
|  |     } else { | ||||||
|  |         device const half4  * x4 = (device const half4  *) x; | ||||||
|  |         device const float4 * y4 = (device const float4 *) y; | ||||||
|  |         for (int i = tiisg; i < ne00/4; i += 32) { | ||||||
|  |             for (int k = 0; k < 4; ++k) sumf += (float)x4[i][k] * y4[i][k]; | ||||||
|  |         } | ||||||
|  |         float all_sum = simd_sum(sumf); | ||||||
|  |         if (tiisg == 0) { | ||||||
|  |             for (int i = 4*(ne00/4); i < ne00; ++i) sumf += (float) x[i] * y[i]; | ||||||
|  |             dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; | ||||||
|  |         } | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     float all_sum = simd_sum(sumf); |  | ||||||
|     if (tiisg == 0) { |  | ||||||
|         dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; |  | ||||||
|     } |  | ||||||
| } | } | ||||||
|  |  | ||||||
| #define N_F16_F32 4 | #define N_F16_F32 4 | ||||||
| @@ -573,22 +586,46 @@ kernel void kernel_mul_mat_f16_f32( | |||||||
|  |  | ||||||
|     device const half * x = (device const half *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02); |     device const half * x = (device const half *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02); | ||||||
|  |  | ||||||
|     for (int row = 0; row < N_F16_F32; ++row) { |     if (ne00 < 128) { | ||||||
|         int r1 = rb + row; |         for (int row = 0; row < N_F16_F32; ++row) { | ||||||
|         if (r1 >= ne11) { |             int r1 = rb + row; | ||||||
|             break; |             if (r1 >= ne11) { | ||||||
|  |                 break; | ||||||
|  |             } | ||||||
|  |  | ||||||
|  |             device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12); | ||||||
|  |  | ||||||
|  |             float sumf = 0; | ||||||
|  |             for (int i = tiisg; i < ne00; i += 32) { | ||||||
|  |                 sumf += (float) x[i] * (float) y[i]; | ||||||
|  |             } | ||||||
|  |  | ||||||
|  |             float all_sum = simd_sum(sumf); | ||||||
|  |             if (tiisg == 0) { | ||||||
|  |                 dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; | ||||||
|  |             } | ||||||
|         } |         } | ||||||
|  |     } else { | ||||||
|  |         device const half4 * x4 = (device const half4 *)x; | ||||||
|  |         for (int row = 0; row < N_F16_F32; ++row) { | ||||||
|  |             int r1 = rb + row; | ||||||
|  |             if (r1 >= ne11) { | ||||||
|  |                 break; | ||||||
|  |             } | ||||||
|  |  | ||||||
|         device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12); |             device const float  * y  = (device const float  *) (src1 + r1*nb11 + im*nb12); | ||||||
|  |             device const float4 * y4 = (device const float4 *) y; | ||||||
|  |  | ||||||
|         float sumf = 0; |             float sumf = 0; | ||||||
|         for (int i = tiisg; i < ne00; i += 32) { |             for (int i = tiisg; i < ne00/4; i += 32) { | ||||||
|             sumf += (float) x[i] * (float) y[i]; |                 for (int k = 0; k < 4; ++k) sumf += (float) x4[i][k] * y4[i][k]; | ||||||
|         } |             } | ||||||
|  |  | ||||||
|         float all_sum = simd_sum(sumf); |             float all_sum = simd_sum(sumf); | ||||||
|         if (tiisg == 0) { |             if (tiisg == 0) { | ||||||
|             dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; |                 for (int i = 4*(ne00/4); i < ne00; ++i) sumf += (float) x[i] * y[i]; | ||||||
|  |                 dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; | ||||||
|  |             } | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
|  |  | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user
	 Iwan Kawrakow
					Iwan Kawrakow