mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-10-31 08:51:55 +00:00 
			
		
		
		
	metal : extend ggml_soft_max_ext() to support n_seq dim
This commit is contained in:
		| @@ -454,6 +454,8 @@ typedef struct { | ||||
|     int64_t  ne00; | ||||
|     int64_t  ne01; | ||||
|     int64_t  ne02; | ||||
|     uint64_t nb11; | ||||
|     uint64_t nb12; | ||||
|     float    scale; | ||||
|     float    max_bias; | ||||
|     float    m0; | ||||
|   | ||||
| @@ -2562,10 +2562,7 @@ static bool ggml_metal_encode_node( | ||||
|                 memcpy(&scale,    ((const int32_t *) dst->op_params) + 0, sizeof(scale)); | ||||
|                 memcpy(&max_bias, ((const int32_t *) dst->op_params) + 1, sizeof(max_bias)); | ||||
|  | ||||
|                 const int64_t nrows_x = ggml_nrows(src0); | ||||
|                 const int64_t nrows_y = src0->ne[1]; | ||||
|  | ||||
|                 const uint32_t n_head      = nrows_x/nrows_y; | ||||
|                 const uint32_t n_head      = src0->ne[2]; | ||||
|                 const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head)); | ||||
|  | ||||
|                 const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2); | ||||
| @@ -2625,6 +2622,8 @@ static bool ggml_metal_encode_node( | ||||
|                     /*.ne00        =*/ ne00, | ||||
|                     /*.ne01        =*/ ne01, | ||||
|                     /*.ne02        =*/ ne02, | ||||
|                     /*.nb11        =*/ nb11, | ||||
|                     /*.nb12        =*/ nb12, | ||||
|                     /*.scale       =*/ scale, | ||||
|                     /*.max_bias    =*/ max_bias, | ||||
|                     /*.m0          =*/ m0, | ||||
|   | ||||
| @@ -1263,7 +1263,7 @@ kernel void kernel_soft_max( | ||||
|     const int64_t i01 = (tgpig - i03*args.ne02*args.ne01 - i02*args.ne01); | ||||
|  | ||||
|     device const float * psrc0 = (device const float *) src0 + (i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00); | ||||
|     device const     T * pmask = src1 != src0 ? (device const    T *) src1         + i01*args.ne00 : nullptr; | ||||
|     device const     T * pmask = src1 != src0 ? (device const T *) (src1 + i01*args.nb11 + i03*args.nb12) : nullptr; | ||||
|     device       float * pdst  = (device       float *) dst  + (i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00); | ||||
|  | ||||
|     float slope = 1.0f; | ||||
| @@ -1359,7 +1359,7 @@ kernel void kernel_soft_max_4( | ||||
|     const int64_t i01 = (tgpig - i03*args.ne02*args.ne01 - i02*args.ne01); | ||||
|  | ||||
|     device const float4 * psrc4 = (device const float4 *) src0 + (i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00)/4; | ||||
|     device const      T * pmask = src1 != src0 ? (device const     T *) src1         + i01*args.ne00/4 : nullptr; | ||||
|     device const      T * pmask = src1 != src0 ? (device const T *) (src1 + i01*args.nb11 + i03*args.nb12) : nullptr; | ||||
|     device       float4 * pdst4 = (device       float4 *) dst  + (i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00)/4; | ||||
|  | ||||
|     float slope = 1.0f; | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Georgi Gerganov
					Georgi Gerganov