mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-10-30 08:42:00 +00:00 
			
		
		
		
	
							
								
								
									
										57
									
								
								ggml-metal.m
									
									
									
									
									
								
							
							
						
						
									
										57
									
								
								ggml-metal.m
									
									
									
									
									
								
							| @@ -1419,34 +1419,35 @@ void ggml_metal_graph_compute( | |||||||
|                                 default: GGML_ASSERT(false); |                                 default: GGML_ASSERT(false); | ||||||
|                             }; |                             }; | ||||||
|  |  | ||||||
|                             [encoder setBuffer:id_src0 offset:offs_src0        atIndex:0]; |                             [encoder setBuffer:id_src0     offset:offs_src0        atIndex:0]; | ||||||
|                             [encoder setBuffer:id_src1 offset:offs_src1        atIndex:1]; |                             [encoder setBuffer:id_src1     offset:offs_src1        atIndex:1]; | ||||||
|                             [encoder setBuffer:id_dst  offset:offs_dst         atIndex:2]; |                             [encoder setBuffer:id_dst      offset:offs_dst         atIndex:2]; | ||||||
|                             [encoder setBytes:&ne00    length:sizeof( int64_t) atIndex:3]; |                             [encoder setBytes:&ne00        length:sizeof( int64_t) atIndex:3]; | ||||||
|                             [encoder setBytes:&ne01    length:sizeof( int64_t) atIndex:4]; |                             [encoder setBytes:&ne01        length:sizeof( int64_t) atIndex:4]; | ||||||
|                             [encoder setBytes:&ne02    length:sizeof( int64_t) atIndex:5]; |                             [encoder setBytes:&ne02        length:sizeof( int64_t) atIndex:5]; | ||||||
|                             [encoder setBytes:&ne03    length:sizeof( int64_t) atIndex:6]; |                             [encoder setBytes:&ne03        length:sizeof( int64_t) atIndex:6]; | ||||||
|                             [encoder setBytes:&nb00    length:sizeof(uint64_t) atIndex:7]; |                             [encoder setBytes:&nb00        length:sizeof(uint64_t) atIndex:7]; | ||||||
|                             [encoder setBytes:&nb01    length:sizeof(uint64_t) atIndex:8]; |                             [encoder setBytes:&nb01        length:sizeof(uint64_t) atIndex:8]; | ||||||
|                             [encoder setBytes:&nb02    length:sizeof(uint64_t) atIndex:9]; |                             [encoder setBytes:&nb02        length:sizeof(uint64_t) atIndex:9]; | ||||||
|                             [encoder setBytes:&nb03    length:sizeof(uint64_t) atIndex:10]; |                             [encoder setBytes:&nb03        length:sizeof(uint64_t) atIndex:10]; | ||||||
|                             [encoder setBytes:&ne0     length:sizeof( int64_t) atIndex:11]; |                             [encoder setBytes:&ne0         length:sizeof( int64_t) atIndex:11]; | ||||||
|                             [encoder setBytes:&ne1     length:sizeof( int64_t) atIndex:12]; |                             [encoder setBytes:&ne1         length:sizeof( int64_t) atIndex:12]; | ||||||
|                             [encoder setBytes:&ne2     length:sizeof( int64_t) atIndex:13]; |                             [encoder setBytes:&ne2         length:sizeof( int64_t) atIndex:13]; | ||||||
|                             [encoder setBytes:&ne3     length:sizeof( int64_t) atIndex:14]; |                             [encoder setBytes:&ne3         length:sizeof( int64_t) atIndex:14]; | ||||||
|                             [encoder setBytes:&nb0     length:sizeof(uint64_t) atIndex:15]; |                             [encoder setBytes:&nb0         length:sizeof(uint64_t) atIndex:15]; | ||||||
|                             [encoder setBytes:&nb1     length:sizeof(uint64_t) atIndex:16]; |                             [encoder setBytes:&nb1         length:sizeof(uint64_t) atIndex:16]; | ||||||
|                             [encoder setBytes:&nb2     length:sizeof(uint64_t) atIndex:17]; |                             [encoder setBytes:&nb2         length:sizeof(uint64_t) atIndex:17]; | ||||||
|                             [encoder setBytes:&nb3     length:sizeof(uint64_t) atIndex:18]; |                             [encoder setBytes:&nb3         length:sizeof(uint64_t) atIndex:18]; | ||||||
|                             [encoder setBytes:&n_past  length:sizeof(     int) atIndex:19]; |                             [encoder setBytes:&n_past      length:sizeof(     int) atIndex:19]; | ||||||
|                             [encoder setBytes:&n_dims  length:sizeof(     int) atIndex:20]; |                             [encoder setBytes:&n_dims      length:sizeof(     int) atIndex:20]; | ||||||
|                             [encoder setBytes:&mode    length:sizeof(     int) atIndex:21]; |                             [encoder setBytes:&mode        length:sizeof(     int) atIndex:21]; | ||||||
|                             [encoder setBytes:&freq_base  length:sizeof(float) atIndex:22]; |                             [encoder setBytes:&n_orig_ctx  length:sizeof(     int) atIndex:22]; | ||||||
|                             [encoder setBytes:&freq_scale length:sizeof(float) atIndex:23]; |                             [encoder setBytes:&freq_base   length:sizeof(   float) atIndex:23]; | ||||||
|                             [encoder setBytes:&ext_factor  length:sizeof(float) atIndex:24]; |                             [encoder setBytes:&freq_scale  length:sizeof(   float) atIndex:24]; | ||||||
|                             [encoder setBytes:&attn_factor length:sizeof(float) atIndex:25]; |                             [encoder setBytes:&ext_factor  length:sizeof(   float) atIndex:25]; | ||||||
|                             [encoder setBytes:&beta_fast   length:sizeof(float) atIndex:26]; |                             [encoder setBytes:&attn_factor length:sizeof(   float) atIndex:26]; | ||||||
|                             [encoder setBytes:&beta_slow   length:sizeof(float) atIndex:27]; |                             [encoder setBytes:&beta_fast   length:sizeof(   float) atIndex:27]; | ||||||
|  |                             [encoder setBytes:&beta_slow   length:sizeof(   float) atIndex:28]; | ||||||
|  |  | ||||||
|                             [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; |                             [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; | ||||||
|                         } break; |                         } break; | ||||||
|   | |||||||
| @@ -1070,20 +1070,20 @@ static float rope_yarn_ramp(const float low, const float high, const int i0) { | |||||||
| // MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng. | // MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng. | ||||||
| static void rope_yarn( | static void rope_yarn( | ||||||
|     float theta_extrap, float freq_scale, float corr_dims[2], int64_t i0, float ext_factor, float mscale, |     float theta_extrap, float freq_scale, float corr_dims[2], int64_t i0, float ext_factor, float mscale, | ||||||
|     float * cos_theta, float * sin_theta |     thread float * cos_theta, thread float * sin_theta | ||||||
| ) { | ) { | ||||||
|     // Get n-d rotational scaling corrected for extrapolation |     // Get n-d rotational scaling corrected for extrapolation | ||||||
|     float theta_interp = freq_scale * theta_extrap; |     float theta_interp = freq_scale * theta_extrap; | ||||||
|     float theta = theta_interp; |     float theta = theta_interp; | ||||||
|     if (ext_factor != 0.0f) { |     if (ext_factor != 0.0f) { | ||||||
|         ramp_mix = rope_yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor; |         float ramp_mix = rope_yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor; | ||||||
|         theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix; |         theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix; | ||||||
|  |  | ||||||
|         // Get n-d magnitude scaling corrected for interpolation |         // Get n-d magnitude scaling corrected for interpolation | ||||||
|         mscale *= 1.0f + 0.1f * logf(1.0f / freq_scale); |         mscale *= 1.0f + 0.1f * log(1.0f / freq_scale); | ||||||
|     } |     } | ||||||
|     *cos_theta = cosf(theta) * mscale; |     *cos_theta = cos(theta) * mscale; | ||||||
|     *sin_theta = sinf(theta) * mscale; |     *sin_theta = sin(theta) * mscale; | ||||||
| } | } | ||||||
|  |  | ||||||
| // Apparently solving `n_rot = 2pi * x * base^((2 * max_pos_emb) / n_dims)` for x, we get | // Apparently solving `n_rot = 2pi * x * base^((2 * max_pos_emb) / n_dims)` for x, we get | ||||||
| @@ -1123,8 +1123,13 @@ typedef void (rope_t)( | |||||||
|         constant         int & n_past, |         constant         int & n_past, | ||||||
|         constant         int & n_dims, |         constant         int & n_dims, | ||||||
|         constant         int & mode, |         constant         int & mode, | ||||||
|  |         constant         int & n_orig_ctx, | ||||||
|         constant       float & freq_base, |         constant       float & freq_base, | ||||||
|         constant       float & freq_scale, |         constant       float & freq_scale, | ||||||
|  |         constant       float & ext_factor, | ||||||
|  |         constant       float & attn_factor, | ||||||
|  |         constant       float & beta_fast, | ||||||
|  |         constant       float & beta_slow, | ||||||
|         uint  tiitg[[thread_index_in_threadgroup]], |         uint  tiitg[[thread_index_in_threadgroup]], | ||||||
|         uint3 tptg[[threads_per_threadgroup]], |         uint3 tptg[[threads_per_threadgroup]], | ||||||
|         uint3 tgpig[[threadgroup_position_in_grid]]); |         uint3 tgpig[[threadgroup_position_in_grid]]); | ||||||
| @@ -1153,6 +1158,7 @@ kernel void kernel_rope( | |||||||
|         constant         int & n_past, |         constant         int & n_past, | ||||||
|         constant         int & n_dims, |         constant         int & n_dims, | ||||||
|         constant         int & mode, |         constant         int & mode, | ||||||
|  |         constant         int & n_orig_ctx, | ||||||
|         constant       float & freq_base, |         constant       float & freq_base, | ||||||
|         constant       float & freq_scale, |         constant       float & freq_scale, | ||||||
|         constant       float & ext_factor, |         constant       float & ext_factor, | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user
	 Georgi Gerganov
					Georgi Gerganov