mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-10-31 08:51:55 +00:00 
			
		
		
		
	metal : fix out-of-bounds access + inc concurrency nodes (#2416)
* metal : fix out-of-bounds access + style changes
* metal : increase concurrency nodes to 2*GGML_MAX_NODES
This commit is contained in:
		
							
								
								
									
										49
									
								
								ggml-metal.m
									
									
									
									
									
								
							
							
						
						
									
										49
									
								
								ggml-metal.m
									
									
									
									
									
								
							| @@ -7,6 +7,11 @@ | |||||||
| #import <Metal/Metal.h> | #import <Metal/Metal.h> | ||||||
| #import <MetalPerformanceShaders/MetalPerformanceShaders.h> | #import <MetalPerformanceShaders/MetalPerformanceShaders.h> | ||||||
|  |  | ||||||
|  | #undef MIN | ||||||
|  | #undef MAX | ||||||
|  | #define MIN(a, b) ((a) < (b) ? (a) : (b)) | ||||||
|  | #define MAX(a, b) ((a) > (b) ? (a) : (b)) | ||||||
|  |  | ||||||
| #ifdef GGML_METAL_NDEBUG | #ifdef GGML_METAL_NDEBUG | ||||||
| #define metal_printf(...) | #define metal_printf(...) | ||||||
| #else | #else | ||||||
| @@ -15,6 +20,8 @@ | |||||||
|  |  | ||||||
| #define UNUSED(x) (void)(x) | #define UNUSED(x) (void)(x) | ||||||
|  |  | ||||||
|  | #define GGML_MAX_CONCUR (2*GGML_MAX_NODES) | ||||||
|  |  | ||||||
| struct ggml_metal_buffer { | struct ggml_metal_buffer { | ||||||
|     const char * name; |     const char * name; | ||||||
|  |  | ||||||
| @@ -36,7 +43,7 @@ struct ggml_metal_context { | |||||||
|     int n_buffers; |     int n_buffers; | ||||||
|     struct ggml_metal_buffer buffers[GGML_METAL_MAX_BUFFERS]; |     struct ggml_metal_buffer buffers[GGML_METAL_MAX_BUFFERS]; | ||||||
|  |  | ||||||
|     int concur_list[GGML_MAX_NODES]; |     int concur_list[GGML_MAX_CONCUR]; | ||||||
|     int concur_list_len; |     int concur_list_len; | ||||||
|  |  | ||||||
|     // custom kernels |     // custom kernels | ||||||
| @@ -370,10 +377,10 @@ void ggml_metal_graph_find_concurrency( | |||||||
|         struct ggml_metal_context * ctx, |         struct ggml_metal_context * ctx, | ||||||
|         struct ggml_cgraph * gf) { |         struct ggml_cgraph * gf) { | ||||||
|     int search_depth = gf->n_nodes; //we only find concurrency in this range to avoid wasting too much time |     int search_depth = gf->n_nodes; //we only find concurrency in this range to avoid wasting too much time | ||||||
|     int nodes_unused[GGML_MAX_NODES]; |     int nodes_unused[GGML_MAX_CONCUR]; | ||||||
|  |  | ||||||
|     for (int i = 0; i < GGML_MAX_NODES; i++) {ctx->concur_list[i] = 0;} |     for (int i = 0; i < GGML_MAX_CONCUR; i++) { ctx->concur_list[i] = 0; } | ||||||
|     for (int i = 0; i < gf->n_nodes; i++) {nodes_unused[i] = 1;} |     for (int i = 0; i < gf->n_nodes;     i++) { nodes_unused[i]     = 1; } | ||||||
|     ctx->concur_list_len = 0; |     ctx->concur_list_len = 0; | ||||||
|  |  | ||||||
|     int n_left    = gf->n_nodes; |     int n_left    = gf->n_nodes; | ||||||
| @@ -386,21 +393,33 @@ void ggml_metal_graph_find_concurrency( | |||||||
|         for (int i = n_start; i < ((n_start + search_depth > gf->n_nodes) ? gf->n_nodes : n_start + search_depth); i++) { |         for (int i = n_start; i < ((n_start + search_depth > gf->n_nodes) ? gf->n_nodes : n_start + search_depth); i++) { | ||||||
|             if (nodes_unused[i]) { |             if (nodes_unused[i]) { | ||||||
|                 // if the requirements for gf->nodes[i] are satisfied |                 // if the requirements for gf->nodes[i] are satisfied | ||||||
|                 int exe_flag=1; |                 int exe_flag = 1; | ||||||
|  |  | ||||||
|                 // scan all srcs |                 // scan all srcs | ||||||
|                 for (int src_ind = 0; src_ind < GGML_MAX_SRC; src_ind++) { |                 for (int src_ind = 0; src_ind < GGML_MAX_SRC; src_ind++) { | ||||||
|                     struct ggml_tensor * src_cur = gf->nodes[i]->src[src_ind]; |                     struct ggml_tensor * src_cur = gf->nodes[i]->src[src_ind]; | ||||||
|                     if (src_cur) { |                     if (src_cur) { | ||||||
|                         // if is leaf nodes it's satisfied. |                         // if is leaf nodes it's satisfied. | ||||||
|                         if (src_cur->op == GGML_OP_NONE && src_cur->grad == NULL) {continue;} |                         // TODO: ggml_is_leaf() | ||||||
|  |                         if (src_cur->op == GGML_OP_NONE && src_cur->grad == NULL) { | ||||||
|  |                             continue; | ||||||
|  |                         } | ||||||
|  |  | ||||||
|                         // otherwise this src should be the output from previous nodes. |                         // otherwise this src should be the output from previous nodes. | ||||||
|                         int is_found = 0; |                         int is_found = 0; | ||||||
|  |  | ||||||
|                         // scan 2*search_depth back because we inserted barrier. |                         // scan 2*search_depth back because we inserted barrier. | ||||||
|                         for (int j = ((level_pos - 2*search_depth) < 0 ? 0 : (level_pos - 2*search_depth)); j < level_pos; j++) { |                         //for (int j = ((level_pos - 2*search_depth) < 0 ? 0 : (level_pos - 2*search_depth)); j < level_pos; j++) { | ||||||
|                             if (gf->nodes[ctx->concur_list[j]] == src_cur) {is_found = 1; break;} |                         for (int j = MAX(0, level_pos - 2*search_depth); j < level_pos; j++) { | ||||||
|  |                             if (ctx->concur_list[j] >= 0 && gf->nodes[ctx->concur_list[j]] == src_cur) { | ||||||
|  |                                 is_found = 1; | ||||||
|  |                                 break; | ||||||
|  |                             } | ||||||
|  |                         } | ||||||
|  |                         if (is_found == 0) { | ||||||
|  |                             exe_flag = 0; | ||||||
|  |                             break; | ||||||
|                         } |                         } | ||||||
|                         if (is_found == 0) {exe_flag = 0; break;} |  | ||||||
|                     } |                     } | ||||||
|                 } |                 } | ||||||
|                 if (exe_flag) { |                 if (exe_flag) { | ||||||
| @@ -416,9 +435,9 @@ void ggml_metal_graph_find_concurrency( | |||||||
|                             if (((int64_t)gf->nodes[j]->data) >= data_start + length || \ |                             if (((int64_t)gf->nodes[j]->data) >= data_start + length || \ | ||||||
|                                 ((int64_t)gf->nodes[j]->data) + (int64_t) ggml_nbytes(gf->nodes[j]) <= data_start) { |                                 ((int64_t)gf->nodes[j]->data) + (int64_t) ggml_nbytes(gf->nodes[j]) <= data_start) { | ||||||
|                                 continue; |                                 continue; | ||||||
|                             } else { |  | ||||||
|                                 exe_flag = 0; |  | ||||||
|                             } |                             } | ||||||
|  |  | ||||||
|  |                             exe_flag = 0; | ||||||
|                         } |                         } | ||||||
|                     } |                     } | ||||||
|                 } |                 } | ||||||
| @@ -435,11 +454,13 @@ void ggml_metal_graph_find_concurrency( | |||||||
|         ctx->concur_list[level_pos + concurrency] = -1; |         ctx->concur_list[level_pos + concurrency] = -1; | ||||||
|         ctx->concur_list_len++; |         ctx->concur_list_len++; | ||||||
|         // jump all sorted nodes at nodes_bak |         // jump all sorted nodes at nodes_bak | ||||||
|         while (!nodes_unused[n_start]) {n_start++;} |         while (!nodes_unused[n_start]) { | ||||||
|  |             n_start++; | ||||||
|  |         } | ||||||
|         level_pos += concurrency + 1; |         level_pos += concurrency + 1; | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     if (ctx->concur_list_len > GGML_MAX_NODES) { |     if (ctx->concur_list_len > GGML_MAX_CONCUR) { | ||||||
|         fprintf(stderr, "%s: too many elements for metal ctx->concur_list!\n", __func__); |         fprintf(stderr, "%s: too many elements for metal ctx->concur_list!\n", __func__); | ||||||
|     } |     } | ||||||
| } | } | ||||||
| @@ -453,7 +474,7 @@ void ggml_metal_graph_compute( | |||||||
|     // else fallback to serial dispatch |     // else fallback to serial dispatch | ||||||
|     MTLComputePassDescriptor * edesc = MTLComputePassDescriptor.computePassDescriptor; |     MTLComputePassDescriptor * edesc = MTLComputePassDescriptor.computePassDescriptor; | ||||||
|  |  | ||||||
|     const bool has_concur = ctx->concur_list_len && ctx->concur_list_len <= GGML_MAX_NODES; |     const bool has_concur = ctx->concur_list_len && ctx->concur_list_len <= GGML_MAX_CONCUR; | ||||||
|  |  | ||||||
|     const int n_nodes  = has_concur ? ctx->concur_list_len      : gf->n_nodes; |     const int n_nodes  = has_concur ? ctx->concur_list_len      : gf->n_nodes; | ||||||
|     edesc.dispatchType = has_concur ? MTLDispatchTypeConcurrent : MTLDispatchTypeSerial; |     edesc.dispatchType = has_concur ? MTLDispatchTypeConcurrent : MTLDispatchTypeSerial; | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user
	 Georgi Gerganov
					Georgi Gerganov