mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-10-30 08:42:00 +00:00 
			
		
		
		
	metal : implement q5_0 and q5_1 kernels (#3648)
* metal : implement dequantize_q5_0
* metal : block_q_n_dot_y for block_q5_0 (broken)
* metal : revert unnecessary change
* metal : implement dequantize_q5_1
* metal : block_q_n_dot_y for q5_1 (broken)
* metal : fix block_q_n_dot_y
* minor : spaces / formatting
---------
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
This commit is contained in:
		
							
								
								
									
										47
									
								
								ggml-metal.m
									
									
									
									
									
								
							
							
						
						
									
										47
									
								
								ggml-metal.m
									
									
									
									
									
								
							| @@ -73,6 +73,8 @@ struct ggml_metal_context { | |||||||
|     GGML_METAL_DECL_KERNEL(get_rows_f16); |     GGML_METAL_DECL_KERNEL(get_rows_f16); | ||||||
|     GGML_METAL_DECL_KERNEL(get_rows_q4_0); |     GGML_METAL_DECL_KERNEL(get_rows_q4_0); | ||||||
|     GGML_METAL_DECL_KERNEL(get_rows_q4_1); |     GGML_METAL_DECL_KERNEL(get_rows_q4_1); | ||||||
|  |     GGML_METAL_DECL_KERNEL(get_rows_q5_0); | ||||||
|  |     GGML_METAL_DECL_KERNEL(get_rows_q5_1); | ||||||
|     GGML_METAL_DECL_KERNEL(get_rows_q8_0); |     GGML_METAL_DECL_KERNEL(get_rows_q8_0); | ||||||
|     GGML_METAL_DECL_KERNEL(get_rows_q2_K); |     GGML_METAL_DECL_KERNEL(get_rows_q2_K); | ||||||
|     GGML_METAL_DECL_KERNEL(get_rows_q3_K); |     GGML_METAL_DECL_KERNEL(get_rows_q3_K); | ||||||
| @@ -87,6 +89,8 @@ struct ggml_metal_context { | |||||||
|     GGML_METAL_DECL_KERNEL(mul_mv_f16_f32_l4); |     GGML_METAL_DECL_KERNEL(mul_mv_f16_f32_l4); | ||||||
|     GGML_METAL_DECL_KERNEL(mul_mv_q4_0_f32); |     GGML_METAL_DECL_KERNEL(mul_mv_q4_0_f32); | ||||||
|     GGML_METAL_DECL_KERNEL(mul_mv_q4_1_f32); |     GGML_METAL_DECL_KERNEL(mul_mv_q4_1_f32); | ||||||
|  |     GGML_METAL_DECL_KERNEL(mul_mv_q5_0_f32); | ||||||
|  |     GGML_METAL_DECL_KERNEL(mul_mv_q5_1_f32); | ||||||
|     GGML_METAL_DECL_KERNEL(mul_mv_q8_0_f32); |     GGML_METAL_DECL_KERNEL(mul_mv_q8_0_f32); | ||||||
|     GGML_METAL_DECL_KERNEL(mul_mv_q2_K_f32); |     GGML_METAL_DECL_KERNEL(mul_mv_q2_K_f32); | ||||||
|     GGML_METAL_DECL_KERNEL(mul_mv_q3_K_f32); |     GGML_METAL_DECL_KERNEL(mul_mv_q3_K_f32); | ||||||
| @@ -97,6 +101,8 @@ struct ggml_metal_context { | |||||||
|     GGML_METAL_DECL_KERNEL(mul_mm_f16_f32); |     GGML_METAL_DECL_KERNEL(mul_mm_f16_f32); | ||||||
|     GGML_METAL_DECL_KERNEL(mul_mm_q4_0_f32); |     GGML_METAL_DECL_KERNEL(mul_mm_q4_0_f32); | ||||||
|     GGML_METAL_DECL_KERNEL(mul_mm_q4_1_f32); |     GGML_METAL_DECL_KERNEL(mul_mm_q4_1_f32); | ||||||
|  |     GGML_METAL_DECL_KERNEL(mul_mm_q5_0_f32); | ||||||
|  |     GGML_METAL_DECL_KERNEL(mul_mm_q5_1_f32); | ||||||
|     GGML_METAL_DECL_KERNEL(mul_mm_q8_0_f32); |     GGML_METAL_DECL_KERNEL(mul_mm_q8_0_f32); | ||||||
|     GGML_METAL_DECL_KERNEL(mul_mm_q2_K_f32); |     GGML_METAL_DECL_KERNEL(mul_mm_q2_K_f32); | ||||||
|     GGML_METAL_DECL_KERNEL(mul_mm_q3_K_f32); |     GGML_METAL_DECL_KERNEL(mul_mm_q3_K_f32); | ||||||
| @@ -254,6 +260,8 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) { | |||||||
|         GGML_METAL_ADD_KERNEL(get_rows_f16); |         GGML_METAL_ADD_KERNEL(get_rows_f16); | ||||||
|         GGML_METAL_ADD_KERNEL(get_rows_q4_0); |         GGML_METAL_ADD_KERNEL(get_rows_q4_0); | ||||||
|         GGML_METAL_ADD_KERNEL(get_rows_q4_1); |         GGML_METAL_ADD_KERNEL(get_rows_q4_1); | ||||||
|  |         GGML_METAL_ADD_KERNEL(get_rows_q5_0); | ||||||
|  |         GGML_METAL_ADD_KERNEL(get_rows_q5_1); | ||||||
|         GGML_METAL_ADD_KERNEL(get_rows_q8_0); |         GGML_METAL_ADD_KERNEL(get_rows_q8_0); | ||||||
|         GGML_METAL_ADD_KERNEL(get_rows_q2_K); |         GGML_METAL_ADD_KERNEL(get_rows_q2_K); | ||||||
|         GGML_METAL_ADD_KERNEL(get_rows_q3_K); |         GGML_METAL_ADD_KERNEL(get_rows_q3_K); | ||||||
| @@ -268,6 +276,8 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) { | |||||||
|         GGML_METAL_ADD_KERNEL(mul_mv_f16_f32_l4); |         GGML_METAL_ADD_KERNEL(mul_mv_f16_f32_l4); | ||||||
|         GGML_METAL_ADD_KERNEL(mul_mv_q4_0_f32); |         GGML_METAL_ADD_KERNEL(mul_mv_q4_0_f32); | ||||||
|         GGML_METAL_ADD_KERNEL(mul_mv_q4_1_f32); |         GGML_METAL_ADD_KERNEL(mul_mv_q4_1_f32); | ||||||
|  |         GGML_METAL_ADD_KERNEL(mul_mv_q5_0_f32); | ||||||
|  |         GGML_METAL_ADD_KERNEL(mul_mv_q5_1_f32); | ||||||
|         GGML_METAL_ADD_KERNEL(mul_mv_q8_0_f32); |         GGML_METAL_ADD_KERNEL(mul_mv_q8_0_f32); | ||||||
|         GGML_METAL_ADD_KERNEL(mul_mv_q2_K_f32); |         GGML_METAL_ADD_KERNEL(mul_mv_q2_K_f32); | ||||||
|         GGML_METAL_ADD_KERNEL(mul_mv_q3_K_f32); |         GGML_METAL_ADD_KERNEL(mul_mv_q3_K_f32); | ||||||
| @@ -278,8 +288,10 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) { | |||||||
|             GGML_METAL_ADD_KERNEL(mul_mm_f32_f32); |             GGML_METAL_ADD_KERNEL(mul_mm_f32_f32); | ||||||
|             GGML_METAL_ADD_KERNEL(mul_mm_f16_f32); |             GGML_METAL_ADD_KERNEL(mul_mm_f16_f32); | ||||||
|             GGML_METAL_ADD_KERNEL(mul_mm_q4_0_f32); |             GGML_METAL_ADD_KERNEL(mul_mm_q4_0_f32); | ||||||
|             GGML_METAL_ADD_KERNEL(mul_mm_q8_0_f32); |  | ||||||
|             GGML_METAL_ADD_KERNEL(mul_mm_q4_1_f32); |             GGML_METAL_ADD_KERNEL(mul_mm_q4_1_f32); | ||||||
|  |             GGML_METAL_ADD_KERNEL(mul_mm_q5_0_f32); | ||||||
|  |             GGML_METAL_ADD_KERNEL(mul_mm_q5_1_f32); | ||||||
|  |             GGML_METAL_ADD_KERNEL(mul_mm_q8_0_f32); | ||||||
|             GGML_METAL_ADD_KERNEL(mul_mm_q2_K_f32); |             GGML_METAL_ADD_KERNEL(mul_mm_q2_K_f32); | ||||||
|             GGML_METAL_ADD_KERNEL(mul_mm_q3_K_f32); |             GGML_METAL_ADD_KERNEL(mul_mm_q3_K_f32); | ||||||
|             GGML_METAL_ADD_KERNEL(mul_mm_q4_K_f32); |             GGML_METAL_ADD_KERNEL(mul_mm_q4_K_f32); | ||||||
| @@ -346,6 +358,8 @@ void ggml_metal_free(struct ggml_metal_context * ctx) { | |||||||
|     GGML_METAL_DEL_KERNEL(get_rows_f16); |     GGML_METAL_DEL_KERNEL(get_rows_f16); | ||||||
|     GGML_METAL_DEL_KERNEL(get_rows_q4_0); |     GGML_METAL_DEL_KERNEL(get_rows_q4_0); | ||||||
|     GGML_METAL_DEL_KERNEL(get_rows_q4_1); |     GGML_METAL_DEL_KERNEL(get_rows_q4_1); | ||||||
|  |     GGML_METAL_DEL_KERNEL(get_rows_q5_0); | ||||||
|  |     GGML_METAL_DEL_KERNEL(get_rows_q5_1); | ||||||
|     GGML_METAL_DEL_KERNEL(get_rows_q8_0); |     GGML_METAL_DEL_KERNEL(get_rows_q8_0); | ||||||
|     GGML_METAL_DEL_KERNEL(get_rows_q2_K); |     GGML_METAL_DEL_KERNEL(get_rows_q2_K); | ||||||
|     GGML_METAL_DEL_KERNEL(get_rows_q3_K); |     GGML_METAL_DEL_KERNEL(get_rows_q3_K); | ||||||
| @@ -360,6 +374,8 @@ void ggml_metal_free(struct ggml_metal_context * ctx) { | |||||||
|     GGML_METAL_DEL_KERNEL(mul_mv_f16_f32_l4); |     GGML_METAL_DEL_KERNEL(mul_mv_f16_f32_l4); | ||||||
|     GGML_METAL_DEL_KERNEL(mul_mv_q4_0_f32); |     GGML_METAL_DEL_KERNEL(mul_mv_q4_0_f32); | ||||||
|     GGML_METAL_DEL_KERNEL(mul_mv_q4_1_f32); |     GGML_METAL_DEL_KERNEL(mul_mv_q4_1_f32); | ||||||
|  |     GGML_METAL_DEL_KERNEL(mul_mv_q5_0_f32); | ||||||
|  |     GGML_METAL_DEL_KERNEL(mul_mv_q5_1_f32); | ||||||
|     GGML_METAL_DEL_KERNEL(mul_mv_q8_0_f32); |     GGML_METAL_DEL_KERNEL(mul_mv_q8_0_f32); | ||||||
|     GGML_METAL_DEL_KERNEL(mul_mv_q2_K_f32); |     GGML_METAL_DEL_KERNEL(mul_mv_q2_K_f32); | ||||||
|     GGML_METAL_DEL_KERNEL(mul_mv_q3_K_f32); |     GGML_METAL_DEL_KERNEL(mul_mv_q3_K_f32); | ||||||
| @@ -370,8 +386,10 @@ void ggml_metal_free(struct ggml_metal_context * ctx) { | |||||||
|         GGML_METAL_DEL_KERNEL(mul_mm_f32_f32); |         GGML_METAL_DEL_KERNEL(mul_mm_f32_f32); | ||||||
|         GGML_METAL_DEL_KERNEL(mul_mm_f16_f32); |         GGML_METAL_DEL_KERNEL(mul_mm_f16_f32); | ||||||
|         GGML_METAL_DEL_KERNEL(mul_mm_q4_0_f32); |         GGML_METAL_DEL_KERNEL(mul_mm_q4_0_f32); | ||||||
|         GGML_METAL_DEL_KERNEL(mul_mm_q8_0_f32); |  | ||||||
|         GGML_METAL_DEL_KERNEL(mul_mm_q4_1_f32); |         GGML_METAL_DEL_KERNEL(mul_mm_q4_1_f32); | ||||||
|  |         GGML_METAL_DEL_KERNEL(mul_mm_q5_0_f32); | ||||||
|  |         GGML_METAL_DEL_KERNEL(mul_mm_q5_1_f32); | ||||||
|  |         GGML_METAL_DEL_KERNEL(mul_mm_q8_0_f32); | ||||||
|         GGML_METAL_DEL_KERNEL(mul_mm_q2_K_f32); |         GGML_METAL_DEL_KERNEL(mul_mm_q2_K_f32); | ||||||
|         GGML_METAL_DEL_KERNEL(mul_mm_q3_K_f32); |         GGML_METAL_DEL_KERNEL(mul_mm_q3_K_f32); | ||||||
|         GGML_METAL_DEL_KERNEL(mul_mm_q4_K_f32); |         GGML_METAL_DEL_KERNEL(mul_mm_q4_K_f32); | ||||||
| @@ -1052,6 +1070,8 @@ void ggml_metal_graph_compute( | |||||||
|                                     case GGML_TYPE_F16:  [encoder setComputePipelineState:ctx->pipeline_mul_mm_f16_f32];  break; |                                     case GGML_TYPE_F16:  [encoder setComputePipelineState:ctx->pipeline_mul_mm_f16_f32];  break; | ||||||
|                                     case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_0_f32]; break; |                                     case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_0_f32]; break; | ||||||
|                                     case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_1_f32]; break; |                                     case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_1_f32]; break; | ||||||
|  |                                     case GGML_TYPE_Q5_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q5_0_f32]; break; | ||||||
|  |                                     case GGML_TYPE_Q5_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q5_1_f32]; break; | ||||||
|                                     case GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q8_0_f32]; break; |                                     case GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q8_0_f32]; break; | ||||||
|                                     case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q2_K_f32]; break; |                                     case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q2_K_f32]; break; | ||||||
|                                     case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q3_K_f32]; break; |                                     case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q3_K_f32]; break; | ||||||
| @@ -1121,6 +1141,24 @@ void ggml_metal_graph_compute( | |||||||
|                                             nth1 = 8; |                                             nth1 = 8; | ||||||
|                                             [encoder setComputePipelineState:ctx->pipeline_mul_mv_q4_1_f32]; |                                             [encoder setComputePipelineState:ctx->pipeline_mul_mv_q4_1_f32]; | ||||||
|                                         } break; |                                         } break; | ||||||
|  |                                     case GGML_TYPE_Q5_0: | ||||||
|  |                                         { | ||||||
|  |                                             GGML_ASSERT(ne02 == 1); | ||||||
|  |                                             GGML_ASSERT(ne12 == 1); | ||||||
|  |  | ||||||
|  |                                             nth0 = 8; | ||||||
|  |                                             nth1 = 8; | ||||||
|  |                                             [encoder setComputePipelineState:ctx->pipeline_mul_mv_q5_0_f32]; | ||||||
|  |                                         } break; | ||||||
|  |                                     case GGML_TYPE_Q5_1: | ||||||
|  |                                         { | ||||||
|  |                                             GGML_ASSERT(ne02 == 1); | ||||||
|  |                                             GGML_ASSERT(ne12 == 1); | ||||||
|  |  | ||||||
|  |                                             nth0 = 8; | ||||||
|  |                                             nth1 = 8; | ||||||
|  |                                             [encoder setComputePipelineState:ctx->pipeline_mul_mv_q5_1_f32]; | ||||||
|  |                                         } break; | ||||||
|                                     case GGML_TYPE_Q8_0: |                                     case GGML_TYPE_Q8_0: | ||||||
|                                         { |                                         { | ||||||
|                                             GGML_ASSERT(ne02 == 1); |                                             GGML_ASSERT(ne02 == 1); | ||||||
| @@ -1201,7 +1239,8 @@ void ggml_metal_graph_compute( | |||||||
|                                 [encoder setBytes:&ne1  length:sizeof(ne1)  atIndex:16]; |                                 [encoder setBytes:&ne1  length:sizeof(ne1)  atIndex:16]; | ||||||
|                                 [encoder setBytes:&gqa  length:sizeof(gqa)  atIndex:17]; |                                 [encoder setBytes:&gqa  length:sizeof(gqa)  atIndex:17]; | ||||||
|  |  | ||||||
|                                 if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q8_0 || |                                 if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || | ||||||
|  |                                     src0t == GGML_TYPE_Q5_0 || src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 || | ||||||
|                                     src0t == GGML_TYPE_Q2_K) { // || src0t == GGML_TYPE_Q4_K) { |                                     src0t == GGML_TYPE_Q2_K) { // || src0t == GGML_TYPE_Q4_K) { | ||||||
|                                     [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; |                                     [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; | ||||||
|                                 } |                                 } | ||||||
| @@ -1233,6 +1272,8 @@ void ggml_metal_graph_compute( | |||||||
|                                 case GGML_TYPE_F16:  [encoder setComputePipelineState:ctx->pipeline_get_rows_f16];  break; |                                 case GGML_TYPE_F16:  [encoder setComputePipelineState:ctx->pipeline_get_rows_f16];  break; | ||||||
|                                 case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_0]; break; |                                 case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_0]; break; | ||||||
|                                 case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_1]; break; |                                 case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_1]; break; | ||||||
|  |                                 case GGML_TYPE_Q5_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q5_0]; break; | ||||||
|  |                                 case GGML_TYPE_Q5_1: [encoder setComputePipelineState:ctx->pipeline_get_rows_q5_1]; break; | ||||||
|                                 case GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q8_0]; break; |                                 case GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q8_0]; break; | ||||||
|                                 case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q2_K]; break; |                                 case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q2_K]; break; | ||||||
|                                 case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q3_K]; break; |                                 case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q3_K]; break; | ||||||
|   | |||||||
							
								
								
									
										163
									
								
								ggml-metal.metal
									
									
									
									
									
								
							
							
						
						
									
										163
									
								
								ggml-metal.metal
									
									
									
									
									
								
							| @@ -18,6 +18,21 @@ typedef struct { | |||||||
|     uint8_t qs[QK4_1 / 2];  // nibbles / quants |     uint8_t qs[QK4_1 / 2];  // nibbles / quants | ||||||
| } block_q4_1; | } block_q4_1; | ||||||
|  |  | ||||||
|  | // q5_0 block: QK5_0 5-bit quants that share a single scale d. | ||||||
|  | // A quant q dequantizes as d * (q - 16) (see dequantize_q5_0). | ||||||
|  | // d and qh together take the first 6 bytes, which is why the kernels address qs | ||||||
|  | // at a uint16_t offset of 3 from the start of the block. | ||||||
|  | // NOTE(review): qh follows the 2-byte d, yet it is read through a uint32_t pointer | ||||||
|  | // by the q5_0 routines - confirm that this 2-byte-aligned 32-bit load is acceptable. | ||||||
|  | #define QK5_0 32 | ||||||
|  | typedef struct { | ||||||
|  |     half d;                // delta (scale shared by the whole block) | ||||||
|  |     uint8_t qh[4];         // 5th (high) bit of every quant, one bit per value | ||||||
|  |     uint8_t qs[QK5_0 / 2]; // low 4 bits of the quants, two nibbles per byte | ||||||
|  | } block_q5_0; | ||||||
|  |  | ||||||
|  | // q5_1 block: QK5_1 5-bit quants that share a scale d and a minimum m. | ||||||
|  | // A quant q dequantizes as d * q + m (see dequantize_q5_1). | ||||||
|  | // d, m and qh together take the first 8 bytes, which is why the kernels address qs | ||||||
|  | // at a uint16_t offset of 4 from the start of the block. | ||||||
|  | #define QK5_1 32 | ||||||
|  | typedef struct { | ||||||
|  |     half d;                 // delta (scale shared by the whole block) | ||||||
|  |     half m;                 // min (offset added to every dequantized value) | ||||||
|  |     uint8_t qh[4];          // 5th (high) bit of every quant, one bit per value | ||||||
|  |     uint8_t qs[QK5_1 / 2];  // low 4 bits of the quants, two nibbles per byte | ||||||
|  | } block_q5_1; | ||||||
|  |  | ||||||
| #define QK8_0 32 | #define QK8_0 32 | ||||||
| typedef struct { | typedef struct { | ||||||
|     half    d;         // delta |     half    d;         // delta | ||||||
| @@ -399,8 +414,11 @@ kernel void kernel_rms_norm( | |||||||
| // that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096) | // that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096) | ||||||
| inline float block_q_n_dot_y(device const block_q4_0 * qb_curr, float sumy, thread float * yl, int il) { | inline float block_q_n_dot_y(device const block_q4_0 * qb_curr, float sumy, thread float * yl, int il) { | ||||||
|     float d = qb_curr->d; |     float d = qb_curr->d; | ||||||
|  |  | ||||||
|     float2 acc = 0.f; |     float2 acc = 0.f; | ||||||
|  |  | ||||||
|     device const uint16_t * qs = ((device const uint16_t *)qb_curr + 1 + il/2); |     device const uint16_t * qs = ((device const uint16_t *)qb_curr + 1 + il/2); | ||||||
|  |  | ||||||
|     for (int i = 0; i < 8; i+=2) { |     for (int i = 0; i < 8; i+=2) { | ||||||
|         acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F) |         acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F) | ||||||
|                 + yl[i + 1] * (qs[i / 2] & 0x0F00); |                 + yl[i + 1] * (qs[i / 2] & 0x0F00); | ||||||
| @@ -417,8 +435,11 @@ inline float block_q_n_dot_y(device const block_q4_0 * qb_curr, float sumy, thre | |||||||
| inline float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thread float * yl, int il) { | inline float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thread float * yl, int il) { | ||||||
|     float d = qb_curr->d; |     float d = qb_curr->d; | ||||||
|     float m = qb_curr->m; |     float m = qb_curr->m; | ||||||
|     device const uint16_t * qs = ((device const uint16_t *)qb_curr + 2 + il/2); |  | ||||||
|     float2 acc = 0.f; |     float2 acc = 0.f; | ||||||
|  |  | ||||||
|  |     device const uint16_t * qs = ((device const uint16_t *)qb_curr + 2 + il/2); | ||||||
|  |  | ||||||
|     for (int i = 0; i < 8; i+=2) { |     for (int i = 0; i < 8; i+=2) { | ||||||
|         acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F) |         acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F) | ||||||
|                 + yl[i + 1] * (qs[i / 2] & 0x0F00); |                 + yl[i + 1] * (qs[i / 2] & 0x0F00); | ||||||
| @@ -428,6 +449,49 @@ inline float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thre | |||||||
|     return d * (acc[0] + acc[1]) + sumy * m; |     return d * (acc[0] + acc[1]) + sumy * m; | ||||||
| } | } | ||||||
|  |  | ||||||
|  | // Calculates the inner product between half a q5_0 block and 16 floats (yl); sumy is SUM(yl[i]). | ||||||
|  | // il indicates where the q5 quants begin (0 or QK5_0/4). | ||||||
|  | // We assume that the yl's have been multiplied by the appropriate scale factor | ||||||
|  | // that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096). | ||||||
|  | // | ||||||
|  | // Each nibble is left at its original position inside the 16-bit word (masks 0x000F, 0x0F00, | ||||||
|  | // 0x00F0, 0xF000) and the matching bit of qh is OR-ed in directly above it (bits 4, 12, 8, 16), | ||||||
|  | // so the 5-bit quant is assembled without shifting the nibble down. | ||||||
|  | inline float block_q_n_dot_y(device const block_q5_0 * qb_curr, float sumy, thread float * yl, int il) { | ||||||
|  |     float d = qb_curr->d; | ||||||
|  |  | ||||||
|  |     float2 acc = 0.f; | ||||||
|  |  | ||||||
|  |     // skip d (1 x uint16_t) and qh (2 x uint16_t), then advance il bytes into qs | ||||||
|  |     device const uint16_t * qs =  ((device const uint16_t *)qb_curr + 3 + il/2); | ||||||
|  |     // all 32 high bits of the block, loaded as one word | ||||||
|  |            const uint32_t   qh = *((device const uint32_t *)qb_curr->qh); | ||||||
|  |  | ||||||
|  |     for (int i = 0; i < 8; i+=2) { | ||||||
|  |         // acc[0]: low nibbles of both bytes, high bits taken from qh bits il+i and il+i+1 | ||||||
|  |         acc[0] += yl[i + 0] * ((qs[i / 2] & 0x000F) | ((qh >> (i+0+il        ) << 4 ) & 0x00010)) | ||||||
|  |                 + yl[i + 1] * ((qs[i / 2] & 0x0F00) | ((qh >> (i+1+il        ) << 12) & 0x01000)); | ||||||
|  |         // acc[1]: high nibbles of both bytes, high bits taken QK5_0/2 positions further up in qh | ||||||
|  |         acc[1] += yl[i + 8] * ((qs[i / 2] & 0x00F0) | ((qh >> (i+0+il+QK5_0/2) << 8 ) & 0x00100)) | ||||||
|  |                 + yl[i + 9] * ((qs[i / 2] & 0xF000) | ((qh >> (i+1+il+QK5_0/2) << 16) & 0x10000)); | ||||||
|  |     } | ||||||
|  |     // the constant -16 offset of every quant is folded into a single sumy * -16 term | ||||||
|  |     return d * (sumy * -16.f + acc[0] + acc[1]); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Calculates the inner product between half a q5_1 block and 16 floats (yl); sumy is SUM(yl[i]). | ||||||
|  | // il indicates where the q5 quants begin (0 or QK5_1/4). | ||||||
|  | // We assume that the yl's have been multiplied by the appropriate scale factor | ||||||
|  | // that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096). | ||||||
|  | // | ||||||
|  | // Each nibble is left at its original position inside the 16-bit word (masks 0x000F, 0x0F00, | ||||||
|  | // 0x00F0, 0xF000) and the matching bit of qh is OR-ed in directly above it (bits 4, 12, 8, 16), | ||||||
|  | // so the 5-bit quant is assembled without shifting the nibble down. | ||||||
|  | inline float block_q_n_dot_y(device const block_q5_1 * qb_curr, float sumy, thread float * yl, int il) { | ||||||
|  |     float d = qb_curr->d; | ||||||
|  |     float m = qb_curr->m; | ||||||
|  |  | ||||||
|  |     float2 acc = 0.f; | ||||||
|  |  | ||||||
|  |     // skip d, m (1 x uint16_t each) and qh (2 x uint16_t), then advance il bytes into qs | ||||||
|  |     device const uint16_t * qs =  ((device const uint16_t *)qb_curr + 4 + il/2); | ||||||
|  |     // all 32 high bits of the block, loaded as one word | ||||||
|  |            const uint32_t   qh = *((device const uint32_t *)qb_curr->qh); | ||||||
|  |  | ||||||
|  |     for (int i = 0; i < 8; i+=2) { | ||||||
|  |         // acc[0]: low nibbles of both bytes, high bits taken from qh bits il+i and il+i+1 | ||||||
|  |         acc[0] += yl[i + 0] * ((qs[i / 2] & 0x000F) | ((qh >> (i+0+il        ) << 4 ) & 0x00010)) | ||||||
|  |                 + yl[i + 1] * ((qs[i / 2] & 0x0F00) | ((qh >> (i+1+il        ) << 12) & 0x01000)); | ||||||
|  |         // acc[1]: high nibbles of both bytes, high bits taken QK5_1/2 positions further up in qh | ||||||
|  |         // (uses this block type's own constant rather than QK5_0) | ||||||
|  |         acc[1] += yl[i + 8] * ((qs[i / 2] & 0x00F0) | ((qh >> (i+0+il+QK5_1/2) << 8 ) & 0x00100)) | ||||||
|  |                 + yl[i + 9] * ((qs[i / 2] & 0xF000) | ((qh >> (i+1+il+QK5_1/2) << 16) & 0x10000)); | ||||||
|  |     } | ||||||
|  |     // q5_1 has no fixed offset; the per-block minimum m contributes sumy * m instead | ||||||
|  |     return d * (acc[0] + acc[1]) + sumy * m; | ||||||
|  | } | ||||||
|  |  | ||||||
| // putting them in the kernel cause a significant performance penalty | // putting them in the kernel cause a significant performance penalty | ||||||
| #define N_DST 4        // each SIMD group works on 4 rows | #define N_DST 4        // each SIMD group works on 4 rows | ||||||
| #define N_SIMDGROUP 2  // number of SIMD groups in a thread group | #define N_SIMDGROUP 2  // number of SIMD groups in a thread group | ||||||
| @@ -525,6 +589,43 @@ kernel void kernel_mul_mv_q4_1_f32( | |||||||
|      mul_vec_q_n_f32<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg); |      mul_vec_q_n_f32<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg); | ||||||
| } | } | ||||||
|  |  | ||||||
|  | // q5_0 matrix-vector entry point: forwards every argument unchanged to the shared | ||||||
|  | // mul_vec_q_n_f32 template, instantiated for block_q5_0 with the file-wide | ||||||
|  | // N_DST / N_SIMDGROUP / N_SIMDWIDTH settings. | ||||||
|  | // The explicit [[buffer(n)]] indices must stay in step with the atIndex values used by | ||||||
|  | // the host-side encoder (for instance ne1 at 16 and gqa at 17). | ||||||
|  | kernel void kernel_mul_mv_q5_0_f32( | ||||||
|  |         device const  void * src0, | ||||||
|  |         device const float * src1, | ||||||
|  |         device       float * dst, | ||||||
|  |         constant   int64_t & ne00, | ||||||
|  |         constant   int64_t & ne01[[buffer(4)]], | ||||||
|  |         constant   int64_t & ne02[[buffer(5)]], | ||||||
|  |         constant   int64_t & ne10[[buffer(9)]], | ||||||
|  |         constant   int64_t & ne12[[buffer(11)]], | ||||||
|  |         constant   int64_t & ne0[[buffer(15)]], | ||||||
|  |         constant   int64_t & ne1[[buffer(16)]], | ||||||
|  |         constant   uint    & gqa[[buffer(17)]], | ||||||
|  |         uint3 tgpig[[threadgroup_position_in_grid]], | ||||||
|  |         uint  tiisg[[thread_index_in_simdgroup]], | ||||||
|  |         uint  sgitg[[simdgroup_index_in_threadgroup]]) { | ||||||
|  |     mul_vec_q_n_f32<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // q5_1 matrix-vector entry point: forwards every argument unchanged to the shared | ||||||
|  | // mul_vec_q_n_f32 template, instantiated for block_q5_1 with the file-wide | ||||||
|  | // N_DST / N_SIMDGROUP / N_SIMDWIDTH settings. | ||||||
|  | // The explicit [[buffer(n)]] indices must stay in step with the atIndex values used by | ||||||
|  | // the host-side encoder (for instance ne1 at 16 and gqa at 17). | ||||||
|  | kernel void kernel_mul_mv_q5_1_f32( | ||||||
|  |         device const  void * src0, | ||||||
|  |         device const float * src1, | ||||||
|  |         device       float * dst, | ||||||
|  |         constant   int64_t & ne00, | ||||||
|  |         constant   int64_t & ne01[[buffer(4)]], | ||||||
|  |         constant   int64_t & ne02[[buffer(5)]], | ||||||
|  |         constant   int64_t & ne10[[buffer(9)]], | ||||||
|  |         constant   int64_t & ne12[[buffer(11)]], | ||||||
|  |         constant   int64_t & ne0[[buffer(15)]], | ||||||
|  |         constant   int64_t & ne1[[buffer(16)]], | ||||||
|  |         constant   uint    & gqa[[buffer(17)]], | ||||||
|  |         uint3 tgpig[[threadgroup_position_in_grid]], | ||||||
|  |         uint  tiisg[[thread_index_in_simdgroup]], | ||||||
|  |         uint  sgitg[[simdgroup_index_in_threadgroup]]) { | ||||||
|  |     mul_vec_q_n_f32<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg); | ||||||
|  | } | ||||||
|  |  | ||||||
|  |  | ||||||
| #define NB_Q8_0 8 | #define NB_Q8_0 8 | ||||||
|  |  | ||||||
| kernel void kernel_mul_mv_q8_0_f32( | kernel void kernel_mul_mv_q8_0_f32( | ||||||
| @@ -2149,6 +2250,62 @@ void dequantize_q4_1(device const block_q4_1 *xb, short il, thread type4x4 & reg | |||||||
|     } |     } | ||||||
| } | } | ||||||
|  |  | ||||||
|  | // Dequantizes 16 values of a q5_0 block into the 4x4 register tile reg. | ||||||
|  | //   xb  - block to read | ||||||
|  | //   il  - 0 selects the low nibble of every qs byte, non-zero selects the high nibble | ||||||
|  | //   reg - output tile; element k of the selected 16 values lands in reg[k/4][k%4] | ||||||
|  | // Every output is d * q + md with md = -16 * d, i.e. d * (q - 16). | ||||||
|  | template <typename type4x4> | ||||||
|  | void dequantize_q5_0(device const block_q5_0 *xb, short il, thread type4x4 & reg) { | ||||||
|  |     // qs starts 3 x uint16_t into the block (after d and qh); each element holds two bytes | ||||||
|  |     device const uint16_t * qs = ((device const uint16_t *)xb + 3); | ||||||
|  |     const float d = xb->d; | ||||||
|  |     const float md = -16.h * xb->d; | ||||||
|  |     const ushort mask = il ? 0x00F0 : 0x000F; | ||||||
|  |  | ||||||
|  |     // all 32 high bits of the block, loaded as one word | ||||||
|  |     const uint32_t qh = *((device const uint32_t *)xb->qh); | ||||||
|  |  | ||||||
|  |     // shift that brings the selected nibble down to bits 0..3 | ||||||
|  |     const int x_mv = il ? 4 : 0; | ||||||
|  |  | ||||||
|  |     // shift pair that moves the wanted qh bit to bit 4: | ||||||
|  |     // il == 0 uses qh bit 2*i (+1), il != 0 uses qh bit 16 + 2*i (+1) | ||||||
|  |     const int gh_mv = il ? 12 : 0; | ||||||
|  |     const int gh_bk = il ?  0 : 4; | ||||||
|  |  | ||||||
|  |     for (int i = 0; i < 8; i++) { | ||||||
|  |         // extract the 5-th bits for x0 and x1 | ||||||
|  |         const uint8_t xh_0 = ((qh >> (gh_mv + 2*i  )) << gh_bk) & 0x10; | ||||||
|  |         const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10; | ||||||
|  |  | ||||||
|  |         // combine the 4-bits from qs with the 5th bit | ||||||
|  |         // (x0 comes from the low byte of qs[i], x1 from its high byte) | ||||||
|  |         const int32_t x0 = ((((qs[i]     ) & mask) >> x_mv) | xh_0); | ||||||
|  |         const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1); | ||||||
|  |  | ||||||
|  |         // two consecutive outputs per iteration, filling the tile row by row | ||||||
|  |         reg[i/2][2*(i%2)+0] = d * x0 + md; | ||||||
|  |         reg[i/2][2*(i%2)+1] = d * x1 + md; | ||||||
|  |     } | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Dequantizes 16 values of a q5_1 block into the 4x4 register tile reg. | ||||||
|  | //   xb  - block to read | ||||||
|  | //   il  - 0 selects the low nibble of every qs byte, non-zero selects the high nibble | ||||||
|  | //   reg - output tile; element k of the selected 16 values lands in reg[k/4][k%4] | ||||||
|  | // Every output is d * q + m. | ||||||
|  | template <typename type4x4> | ||||||
|  | void dequantize_q5_1(device const block_q5_1 *xb, short il, thread type4x4 & reg) { | ||||||
|  |     // qs starts 4 x uint16_t into the block (after d, m and qh); each element holds two bytes | ||||||
|  |     device const uint16_t * qs = ((device const uint16_t *)xb + 4); | ||||||
|  |     const float d = xb->d; | ||||||
|  |     const float m = xb->m; | ||||||
|  |     const ushort mask = il ? 0x00F0 : 0x000F; | ||||||
|  |  | ||||||
|  |     // all 32 high bits of the block, loaded as one word | ||||||
|  |     const uint32_t qh = *((device const uint32_t *)xb->qh); | ||||||
|  |  | ||||||
|  |     // shift that brings the selected nibble down to bits 0..3 | ||||||
|  |     const int x_mv = il ? 4 : 0; | ||||||
|  |  | ||||||
|  |     // shift pair that moves the wanted qh bit to bit 4: | ||||||
|  |     // il == 0 uses qh bit 2*i (+1), il != 0 uses qh bit 16 + 2*i (+1) | ||||||
|  |     const int gh_mv = il ? 12 : 0; | ||||||
|  |     const int gh_bk = il ?  0 : 4; | ||||||
|  |  | ||||||
|  |     for (int i = 0; i < 8; i++) { | ||||||
|  |         // extract the 5-th bits for x0 and x1 | ||||||
|  |         const uint8_t xh_0 = ((qh >> (gh_mv + 2*i  )) << gh_bk) & 0x10; | ||||||
|  |         const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10; | ||||||
|  |  | ||||||
|  |         // combine the 4-bits from qs with the 5th bit | ||||||
|  |         // (x0 comes from the low byte of qs[i], x1 from its high byte) | ||||||
|  |         const int32_t x0 = ((((qs[i]     ) & mask) >> x_mv) | xh_0); | ||||||
|  |         const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1); | ||||||
|  |  | ||||||
|  |         // two consecutive outputs per iteration, filling the tile row by row | ||||||
|  |         reg[i/2][2*(i%2)+0] = d * x0 + m; | ||||||
|  |         reg[i/2][2*(i%2)+1] = d * x1 + m; | ||||||
|  |     } | ||||||
|  | } | ||||||
|  |  | ||||||
| template <typename type4x4> | template <typename type4x4> | ||||||
| void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg) { | void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg) { | ||||||
|     device const int8_t * qs = ((device const int8_t *)xb->qs); |     device const int8_t * qs = ((device const int8_t *)xb->qs); | ||||||
| @@ -2490,6 +2647,8 @@ template [[host_name("kernel_get_rows_f32")]]  kernel get_rows_t kernel_get_rows | |||||||
| template [[host_name("kernel_get_rows_f16")]]  kernel get_rows_t kernel_get_rows<half4x4,    1, dequantize_f16>; | template [[host_name("kernel_get_rows_f16")]]  kernel get_rows_t kernel_get_rows<half4x4,    1, dequantize_f16>; | ||||||
| template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_t kernel_get_rows<block_q4_0, 2, dequantize_q4_0>; | template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_t kernel_get_rows<block_q4_0, 2, dequantize_q4_0>; | ||||||
| template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_t kernel_get_rows<block_q4_1, 2, dequantize_q4_1>; | template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_t kernel_get_rows<block_q4_1, 2, dequantize_q4_1>; | ||||||
|  | template [[host_name("kernel_get_rows_q5_0")]] kernel get_rows_t kernel_get_rows<block_q5_0, 2, dequantize_q5_0>; | ||||||
|  | template [[host_name("kernel_get_rows_q5_1")]] kernel get_rows_t kernel_get_rows<block_q5_1, 2, dequantize_q5_1>; | ||||||
| template [[host_name("kernel_get_rows_q8_0")]] kernel get_rows_t kernel_get_rows<block_q8_0, 2, dequantize_q8_0>; | template [[host_name("kernel_get_rows_q8_0")]] kernel get_rows_t kernel_get_rows<block_q8_0, 2, dequantize_q8_0>; | ||||||
| template [[host_name("kernel_get_rows_q2_K")]] kernel get_rows_t kernel_get_rows<block_q2_K, QK_NL, dequantize_q2_K>; | template [[host_name("kernel_get_rows_q2_K")]] kernel get_rows_t kernel_get_rows<block_q2_K, QK_NL, dequantize_q2_K>; | ||||||
| template [[host_name("kernel_get_rows_q3_K")]] kernel get_rows_t kernel_get_rows<block_q3_K, QK_NL, dequantize_q3_K>; | template [[host_name("kernel_get_rows_q3_K")]] kernel get_rows_t kernel_get_rows<block_q3_K, QK_NL, dequantize_q3_K>; | ||||||
| @@ -2518,6 +2677,8 @@ template [[host_name("kernel_mul_mm_f32_f32")]]  kernel mat_mm_t kernel_mul_mm<f | |||||||
| template [[host_name("kernel_mul_mm_f16_f32")]]  kernel mat_mm_t kernel_mul_mm<half4x4,    1,     dequantize_f16>; | template [[host_name("kernel_mul_mm_f16_f32")]]  kernel mat_mm_t kernel_mul_mm<half4x4,    1,     dequantize_f16>; | ||||||
| template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_0, 2,     dequantize_q4_0>; | template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_0, 2,     dequantize_q4_0>; | ||||||
| template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_1, 2,     dequantize_q4_1>; | template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_1, 2,     dequantize_q4_1>; | ||||||
|  | template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_0, 2,     dequantize_q5_0>; | ||||||
|  | template [[host_name("kernel_mul_mm_q5_1_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_1, 2,     dequantize_q5_1>; | ||||||
| template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q8_0, 2,     dequantize_q8_0>; | template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q8_0, 2,     dequantize_q8_0>; | ||||||
| template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q2_K, QK_NL, dequantize_q2_K>; | template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q2_K, QK_NL, dequantize_q2_K>; | ||||||
| template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q3_K, QK_NL, dequantize_q3_K>; | template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q3_K, QK_NL, dequantize_q3_K>; | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user
	 Jhen-Jie Hong
					Jhen-Jie Hong