mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-10-30 08:42:00 +00:00 
			
		
		
		
	metal : extend ggml_soft_max_ext() to support n_seq dim
This commit is contained in:
		| @@ -454,6 +454,8 @@ typedef struct { | |||||||
|     int64_t  ne00; |     int64_t  ne00; | ||||||
|     int64_t  ne01; |     int64_t  ne01; | ||||||
|     int64_t  ne02; |     int64_t  ne02; | ||||||
|  |     uint64_t nb11; | ||||||
|  |     uint64_t nb12; | ||||||
|     float    scale; |     float    scale; | ||||||
|     float    max_bias; |     float    max_bias; | ||||||
|     float    m0; |     float    m0; | ||||||
|   | |||||||
| @@ -2562,10 +2562,7 @@ static bool ggml_metal_encode_node( | |||||||
|                 memcpy(&scale,    ((const int32_t *) dst->op_params) + 0, sizeof(scale)); |                 memcpy(&scale,    ((const int32_t *) dst->op_params) + 0, sizeof(scale)); | ||||||
|                 memcpy(&max_bias, ((const int32_t *) dst->op_params) + 1, sizeof(max_bias)); |                 memcpy(&max_bias, ((const int32_t *) dst->op_params) + 1, sizeof(max_bias)); | ||||||
|  |  | ||||||
|                 const int64_t nrows_x = ggml_nrows(src0); |                 const uint32_t n_head      = src0->ne[2]; | ||||||
|                 const int64_t nrows_y = src0->ne[1]; |  | ||||||
|  |  | ||||||
|                 const uint32_t n_head      = nrows_x/nrows_y; |  | ||||||
|                 const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head)); |                 const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head)); | ||||||
|  |  | ||||||
|                 const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2); |                 const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2); | ||||||
| @@ -2625,6 +2622,8 @@ static bool ggml_metal_encode_node( | |||||||
|                     /*.ne00        =*/ ne00, |                     /*.ne00        =*/ ne00, | ||||||
|                     /*.ne01        =*/ ne01, |                     /*.ne01        =*/ ne01, | ||||||
|                     /*.ne02        =*/ ne02, |                     /*.ne02        =*/ ne02, | ||||||
|  |                     /*.nb11        =*/ nb11, | ||||||
|  |                     /*.nb12        =*/ nb12, | ||||||
|                     /*.scale       =*/ scale, |                     /*.scale       =*/ scale, | ||||||
|                     /*.max_bias    =*/ max_bias, |                     /*.max_bias    =*/ max_bias, | ||||||
|                     /*.m0          =*/ m0, |                     /*.m0          =*/ m0, | ||||||
|   | |||||||
| @@ -1263,7 +1263,7 @@ kernel void kernel_soft_max( | |||||||
|     const int64_t i01 = (tgpig - i03*args.ne02*args.ne01 - i02*args.ne01); |     const int64_t i01 = (tgpig - i03*args.ne02*args.ne01 - i02*args.ne01); | ||||||
|  |  | ||||||
|     device const float * psrc0 = (device const float *) src0 + (i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00); |     device const float * psrc0 = (device const float *) src0 + (i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00); | ||||||
|     device const     T * pmask = src1 != src0 ? (device const    T *) src1         + i01*args.ne00 : nullptr; |     device const     T * pmask = src1 != src0 ? (device const T *) (src1 + i01*args.nb11 + i03*args.nb12) : nullptr; | ||||||
|     device       float * pdst  = (device       float *) dst  + (i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00); |     device       float * pdst  = (device       float *) dst  + (i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00); | ||||||
|  |  | ||||||
|     float slope = 1.0f; |     float slope = 1.0f; | ||||||
| @@ -1359,7 +1359,7 @@ kernel void kernel_soft_max_4( | |||||||
|     const int64_t i01 = (tgpig - i03*args.ne02*args.ne01 - i02*args.ne01); |     const int64_t i01 = (tgpig - i03*args.ne02*args.ne01 - i02*args.ne01); | ||||||
|  |  | ||||||
|     device const float4 * psrc4 = (device const float4 *) src0 + (i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00)/4; |     device const float4 * psrc4 = (device const float4 *) src0 + (i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00)/4; | ||||||
|     device const      T * pmask = src1 != src0 ? (device const     T *) src1         + i01*args.ne00/4 : nullptr; |     device const      T * pmask = src1 != src0 ? (device const T *) (src1 + i01*args.nb11 + i03*args.nb12) : nullptr; | ||||||
|     device       float4 * pdst4 = (device       float4 *) dst  + (i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00)/4; |     device       float4 * pdst4 = (device       float4 *) dst  + (i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00)/4; | ||||||
|  |  | ||||||
|     float slope = 1.0f; |     float slope = 1.0f; | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user
	 Georgi Gerganov
					Georgi Gerganov