mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-10-30 08:42:00 +00:00 
			
		
		
		
	metal : alibi for arbitrary number of heads (#3426)
This commit is contained in:
		| @@ -1213,12 +1213,9 @@ void ggml_metal_graph_compute( | |||||||
|                             float max_bias; |                             float max_bias; | ||||||
|                             memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float)); |                             memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float)); | ||||||
|  |  | ||||||
|                             if (__builtin_popcount(n_head) != 1) { |  | ||||||
|                                 GGML_ASSERT(false && "only power-of-two n_head implemented"); |  | ||||||
|                             } |  | ||||||
|  |  | ||||||
|                             const int n_heads_log2_floor = 1 << (int) floor(log2(n_head)); |                             const int n_heads_log2_floor = 1 << (int) floor(log2(n_head)); | ||||||
|                             const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor); |                             const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor); | ||||||
|  |                             const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor); | ||||||
|  |  | ||||||
|                             [encoder setComputePipelineState:ctx->pipeline_alibi_f32]; |                             [encoder setComputePipelineState:ctx->pipeline_alibi_f32]; | ||||||
|                             [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; |                             [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; | ||||||
| @@ -1239,7 +1236,9 @@ void ggml_metal_graph_compute( | |||||||
|                             [encoder setBytes:&nb1  length:sizeof(uint64_t) atIndex:15]; |                             [encoder setBytes:&nb1  length:sizeof(uint64_t) atIndex:15]; | ||||||
|                             [encoder setBytes:&nb2  length:sizeof(uint64_t) atIndex:16]; |                             [encoder setBytes:&nb2  length:sizeof(uint64_t) atIndex:16]; | ||||||
|                             [encoder setBytes:&nb3  length:sizeof(uint64_t) atIndex:17]; |                             [encoder setBytes:&nb3  length:sizeof(uint64_t) atIndex:17]; | ||||||
|                             [encoder setBytes:&m0  length:sizeof(    float) atIndex:18]; |                             [encoder setBytes:&m0   length:sizeof(   float) atIndex:18]; | ||||||
|  |                             [encoder setBytes:&m1   length:sizeof(   float) atIndex:19]; | ||||||
|  |                             [encoder setBytes:&n_heads_log2_floor   length:sizeof(int) atIndex:20]; | ||||||
|  |  | ||||||
|                             [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; |                             [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; | ||||||
|                         } break; |                         } break; | ||||||
|   | |||||||
| @@ -830,7 +830,9 @@ kernel void kernel_alibi_f32( | |||||||
|         constant  uint64_t & nb1, |         constant  uint64_t & nb1, | ||||||
|         constant  uint64_t & nb2, |         constant  uint64_t & nb2, | ||||||
|         constant  uint64_t & nb3, |         constant  uint64_t & nb3, | ||||||
|         constant      float & m0, |         constant     float & m0, | ||||||
|  |         constant     float & m1, | ||||||
|  |         constant       int & n_heads_log2_floor, | ||||||
|         uint3 tgpig[[threadgroup_position_in_grid]], |         uint3 tgpig[[threadgroup_position_in_grid]], | ||||||
|         uint3 tpitg[[thread_position_in_threadgroup]], |         uint3 tpitg[[thread_position_in_threadgroup]], | ||||||
|         uint3   ntg[[threads_per_threadgroup]]) { |         uint3   ntg[[threads_per_threadgroup]]) { | ||||||
| @@ -846,7 +848,12 @@ kernel void kernel_alibi_f32( | |||||||
|     const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0); |     const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0); | ||||||
|  |  | ||||||
|     device float * dst_data = (device float *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); |     device float * dst_data = (device float *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); | ||||||
|     float m_k = pow(m0, i2 + 1); |     float m_k; | ||||||
|  |     if (i2 < n_heads_log2_floor) { | ||||||
|  |         m_k = pow(m0, i2 + 1); | ||||||
|  |     } else { | ||||||
|  |         m_k = pow(m1, 2 * (i2 - n_heads_log2_floor) + 1); | ||||||
|  |     } | ||||||
|     for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) { |     for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) { | ||||||
|         device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00); |         device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00); | ||||||
|         dst_data[i00] = src[0] + m_k * (i00 - ne00 + 1); |         dst_data[i00] = src[0] + m_k * (i00 - ne00 + 1); | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user
	 Jiahao Li
					Jiahao Li