mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-10-31 08:51:55 +00:00 
			
		
		
		
	ggml: add support for float16 input tensors in pooling operations (ggml/895)
* Add support for float16 tensors in 1d pooling operations
* Add support for float16 input tensors in 2d pooling operations
* Code cleanup: remove unnecessary casting during srow ptr initialization

---------

Co-authored-by: vanaka11 <vanaka1189@gmail.com>
This commit is contained in:
		 Ivan Filipov
					Ivan Filipov
				
			
				
					committed by
					
						 Georgi Gerganov
						Georgi Gerganov
					
				
			
			
				
	
			
			
			 Georgi Gerganov
						Georgi Gerganov
					
				
			
						parent
						
							203b7f1531
						
					
				
				
					commit
					9f77d899b7
				
			| @@ -14746,7 +14746,7 @@ static void ggml_compute_forward_pool_1d_sk_p0( | ||||
|  | ||||
|     const struct ggml_tensor * src = dst->src[0]; | ||||
|  | ||||
|     assert(src->type == GGML_TYPE_F32); | ||||
|     assert(src->type == GGML_TYPE_F32 || src->type == GGML_TYPE_F16); | ||||
|  | ||||
|     if (params->ith != 0) { | ||||
|         return; | ||||
| @@ -14759,10 +14759,8 @@ static void ggml_compute_forward_pool_1d_sk_p0( | ||||
|     const int64_t rs = dst->ne[0]; | ||||
|  | ||||
|     while (cdata < data_end) { | ||||
|         const float * const srow = (const float *)cdata; | ||||
|  | ||||
|         const void * srow = (const void *)cdata; | ||||
|         int j = 0; | ||||
|  | ||||
|         for (int64_t i = 0; i < rs; ++i) { | ||||
|             switch (op) { | ||||
|                 case GGML_OP_POOL_AVG:   drow[i] = 0;        break; | ||||
| @@ -14770,10 +14768,11 @@ static void ggml_compute_forward_pool_1d_sk_p0( | ||||
|                 case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error"); | ||||
|             } | ||||
|             for (int ki = 0; ki < k; ++ki) { | ||||
|                 const float srow_j = (src->type == GGML_TYPE_F32) ? ((const float*)srow)[j] : GGML_FP16_TO_FP32(((const ggml_fp16_t*)srow)[j]); | ||||
|                 switch (op) { | ||||
|                     case GGML_OP_POOL_AVG:                          drow[i] += srow[j]; break; | ||||
|                     case GGML_OP_POOL_MAX:   if (srow[j] > drow[i]) drow[i]  = srow[j]; break; | ||||
|                     case GGML_OP_POOL_COUNT:                        GGML_ABORT("fatal error"); | ||||
|                     case GGML_OP_POOL_AVG:                         drow[i] += srow_j; break; | ||||
|                     case GGML_OP_POOL_MAX:   if (srow_j > drow[i]) drow[i]  = srow_j; break; | ||||
|                     case GGML_OP_POOL_COUNT:                       GGML_ABORT("fatal error"); | ||||
|                 } | ||||
|                 ++j; | ||||
|             } | ||||
| @@ -14814,7 +14813,7 @@ static void ggml_compute_forward_pool_2d( | ||||
|  | ||||
|     const struct ggml_tensor * src = dst->src[0]; | ||||
|  | ||||
|     GGML_ASSERT(src->type == GGML_TYPE_F32); | ||||
|     assert(src->type == GGML_TYPE_F32 || src->type == GGML_TYPE_F16); | ||||
|  | ||||
|     if (params->ith != 0) { | ||||
|         return; | ||||
| @@ -14857,14 +14856,15 @@ static void ggml_compute_forward_pool_2d( | ||||
|  | ||||
|                 for (int ky = 0; ky < k1; ++ky) { | ||||
|                     if (iy + ky < 0 || iy + ky >= src->ne[1]) continue; | ||||
|                     const float * const srow = (const float *)(cdata + src->nb[1] * (iy + ky)); | ||||
|                     const void * srow = (const void *)(cdata + src->nb[1] * (iy + ky)); | ||||
|                     for (int kx = 0; kx < k0; ++kx) { | ||||
|                         int j = ix + kx; | ||||
|                         if (j < 0 || j >= src->ne[0]) continue; | ||||
|                         const float srow_j = (src->type == GGML_TYPE_F32) ? ((const float*)srow)[j] : GGML_FP16_TO_FP32(((const ggml_fp16_t*)srow)[j]); | ||||
|                         switch (op) { | ||||
|                             case GGML_OP_POOL_AVG:                     *out += srow[j]; break; | ||||
|                             case GGML_OP_POOL_MAX: if (srow[j] > *out) *out  = srow[j]; break; | ||||
|                             case GGML_OP_POOL_COUNT:                GGML_ABORT("fatal error"); | ||||
|                             case GGML_OP_POOL_AVG:                     *out += srow_j; break; | ||||
|                             case GGML_OP_POOL_MAX: if (srow_j > *out)  *out  = srow_j; break; | ||||
|                             case GGML_OP_POOL_COUNT:               GGML_ABORT("fatal error"); | ||||
|                         } | ||||
|                     } | ||||
|                 } | ||||
|   | ||||
		Reference in New Issue
	
	Block a user