mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-10-31 08:51:55 +00:00 
			
		
		
		
	initial implementation of delayed graph allocation
This commit is contained in:
		
							
								
								
									
										116
									
								
								ggml-backend.c
									
									
									
									
									
								
							
							
						
						
									
										116
									
								
								ggml-backend.c
									
									
									
									
									
								
							| @@ -373,29 +373,29 @@ void ggml_graph_splits_add_n_va(struct ggml_graph_splits * splits, struct ggml_t | ||||
|  | ||||
|     struct ggml_graph_split * split = &splits->splits[splits->n_splits]; | ||||
|  | ||||
|     // check if the split is on the same backend as the previous one | ||||
|     // FIXME: need to check all the inputs | ||||
|     if ((*inputs[0])->backend == ggml_get_ctx_backend(ctx)) { | ||||
|         if (splits->n_splits == 0) { | ||||
|             // always add the first split | ||||
|             int i = 0; | ||||
|             while (inputs[i] != NULL) { | ||||
|                 GGML_ASSERT(i < GGML_MAX_SPLIT_INPUTS); | ||||
|                 split->src_inputs[i] = *inputs[i]; | ||||
|                 split->dst_inputs[i] = *inputs[i]; | ||||
|                 i++; | ||||
|             } | ||||
|             split->src_inputs[i] = NULL; | ||||
|             split->dst_inputs[i] = NULL; | ||||
|         } else { | ||||
|             // add to the previous split | ||||
|             char name[GGML_MAX_NAME - 2]; | ||||
|             int n = vsnprintf(name, sizeof(name), fmt, args); | ||||
|             char new_name[GGML_MAX_NAME]; | ||||
|             snprintf(new_name, sizeof(new_name), "%.*s,%s", GGML_MAX_NAME - n - 2, splits->splits[splits->n_splits - 1].name, name); | ||||
|             strcpy(splits->splits[splits->n_splits - 1].name, new_name); | ||||
|             return; | ||||
|  | ||||
|     if (splits->n_splits == 0) { | ||||
|         // always add the first split | ||||
|         int i = 0; | ||||
|         while (inputs[i] != NULL) { | ||||
|             GGML_ASSERT(i < GGML_MAX_SPLIT_INPUTS); | ||||
|             split->src_inputs[i] = *inputs[i]; | ||||
|             split->dst_inputs[i] = *inputs[i]; | ||||
|             i++; | ||||
|         } | ||||
|         split->src_inputs[i] = NULL; | ||||
|         split->dst_inputs[i] = NULL; | ||||
|         split->ctx = ctx; | ||||
|     } | ||||
|     // check if the split is on the same context as the previous one | ||||
|     else if (splits->n_splits > 0 && splits->splits[splits->n_splits - 1].ctx == ctx) { | ||||
|         // add to the previous split | ||||
|         char name[GGML_MAX_NAME - 2]; | ||||
|         int n = vsnprintf(name, sizeof(name), fmt, args); | ||||
|         char new_name[GGML_MAX_NAME]; | ||||
|         snprintf(new_name, sizeof(new_name), "%.*s,%s", GGML_MAX_NAME - n - 2, splits->splits[splits->n_splits - 1].name, name); | ||||
|         strcpy(splits->splits[splits->n_splits - 1].name, new_name); | ||||
|         return; | ||||
|     } else { | ||||
|         // add a new split | ||||
|         int i = 0; | ||||
| @@ -403,6 +403,7 @@ void ggml_graph_splits_add_n_va(struct ggml_graph_splits * splits, struct ggml_t | ||||
|             GGML_ASSERT(i < GGML_MAX_SPLIT_INPUTS); | ||||
|             split->src_inputs[i] = *inputs[i]; | ||||
|             split->dst_inputs[i] = ggml_dup_tensor(ctx, *inputs[i]); | ||||
|             ggml_format_name(split->dst_inputs[i], "%s (split output)", split->src_inputs[i]->name); | ||||
|             // TODO: maybe support different layings in ggml_backend_cpy_tensor instead | ||||
|             for (int j = 0; j < GGML_MAX_DIMS; j++) { | ||||
|                 split->dst_inputs[i]->nb[j] = split->src_inputs[i]->nb[j]; | ||||
| @@ -413,6 +414,7 @@ void ggml_graph_splits_add_n_va(struct ggml_graph_splits * splits, struct ggml_t | ||||
|         } | ||||
|         split->src_inputs[i] = NULL; | ||||
|         split->dst_inputs[i] = NULL; | ||||
|         split->ctx = ctx; | ||||
|     } | ||||
|  | ||||
|     vsnprintf(split->name, GGML_MAX_NAME, fmt, args); | ||||
| @@ -493,7 +495,8 @@ void ggml_graph_splits_compute(struct ggml_graph_splits * splits) { | ||||
|         // copy the input tensor to the backend | ||||
|         uint64_t copy_start_us = ggml_time_us(); | ||||
|         for (int j = 0; split->src_inputs[j] != NULL; j++) { | ||||
|             //printf("\tcopying tensor %d (%s) (%lu bytes)\n", j, split->src_inputs[j]->name, ggml_nbytes(split->src_inputs[j])); | ||||
|             //printf("\tcopying tensor %d (%s) (%s -> %s) (%lu bytes)\n", j, split->src_inputs[j]->name, ggml_backend_name(split->src_inputs[j]->backend), ggml_backend_name(split->dst_inputs[j]->backend), ggml_nbytes(split->src_inputs[j])); | ||||
|             //printf("%p %p\n", split->src_inputs[j], split->dst_inputs[j]); | ||||
|             ggml_backend_tensor_copy(split->src_inputs[j], split->dst_inputs[j]); | ||||
|         } | ||||
|         // ggml_backend_synchronize(split->dst_inputs[0]->backend); | ||||
| @@ -705,32 +708,83 @@ void allocate_graph(struct ggml_cgraph * gf, struct ggml_buffer * buffer) { | ||||
|  | ||||
| #endif | ||||
|  | ||||
| void ggml_graph_allocate_tensors(struct ggml_cgraph * graph) { | ||||
|     ggml_graph_allocate_tensors_n(&graph, 1); | ||||
| void ggml_graph_allocate_tensors(struct ggml_cgraph * graph, struct ggml_context * ctx) { | ||||
|     ggml_graph_allocate_tensors_n(&graph, 1, ctx); | ||||
| } | ||||
|  | ||||
| void ggml_graph_allocate_tensors_n(struct ggml_cgraph ** graphs, int n_graphs) { | ||||
| static bool ggml_is_view(struct ggml_tensor * t) { | ||||
|     return t->op == GGML_OP_RESHAPE || t->op == GGML_OP_VIEW || t->op == GGML_OP_TRANSPOSE || | ||||
|            t->op == GGML_OP_PERMUTE || t->op == GGML_OP_CPY; | ||||
| } | ||||
|  | ||||
| void ggml_graph_allocate_tensors_n(struct ggml_cgraph ** graphs, int n_graphs, struct ggml_context * ctx) { | ||||
|     struct ggml_buffer * buffer = ggml_get_buffer(ctx); | ||||
|     for (int i = 0; i < n_graphs; i++) { | ||||
|         struct ggml_cgraph * graph = graphs[i]; | ||||
|         for (int j = 0; j < graph->n_leafs; j++) { | ||||
|             struct ggml_tensor * leaf = graph->leafs[j]; | ||||
|             GGML_ASSERT(leaf->backend == buffer->backend_buffer->backend); | ||||
|             if (leaf->data == NULL) { | ||||
|                 //printf("allocating leaf %s\n", leaf->name); | ||||
|                 ggml_backend_buffer_tensor_alloc(buffer->backend_buffer, leaf); | ||||
|             } | ||||
|         } | ||||
|  | ||||
|         for (int j = 0; j < graph->n_nodes; j++) { | ||||
|             struct ggml_tensor * node = graph->nodes[j]; | ||||
|             GGML_ASSERT(node->backend == buffer->backend_buffer->backend); | ||||
|             if (node->data == NULL) { | ||||
|                 if (ggml_is_view(node)) { | ||||
|                     size_t offset; | ||||
|                     memcpy(&offset, node->op_params, sizeof(size_t)); | ||||
|                     switch(node->op) { | ||||
|                         case GGML_OP_VIEW: | ||||
|                             //printf("view %s (%s), offset %zu\n", node->name, ggml_op_name(node->op), offset); | ||||
|                             node->data = (char *) node->src[0]->data + offset; | ||||
|                             break; | ||||
|                         case GGML_OP_RESHAPE: | ||||
|                         case GGML_OP_TRANSPOSE: | ||||
|                         case GGML_OP_PERMUTE: | ||||
|                             node->data = node->src[0]->data; | ||||
|                             break; | ||||
|                         case GGML_OP_CPY: | ||||
|                             node->data = node->src[1]->data; | ||||
|                             break; | ||||
|                         default: | ||||
|                             GGML_ASSERT(!"unknown view op"); | ||||
|                             break; | ||||
|                     } | ||||
|                 } else { | ||||
|                     //printf("allocating tensor %s\n", node->name); | ||||
|                     ggml_backend_buffer_tensor_alloc(buffer->backend_buffer, node); | ||||
|                 } | ||||
|             } | ||||
|         } | ||||
|     } | ||||
|     //printf("\n\n\n"); | ||||
| } | ||||
|  | ||||
| void ggml_graph_splits_allocate_tensors(struct ggml_graph_splits * splits) { | ||||
|     bool visited[GGML_MAX_SPLITS] = {false}; | ||||
|     for (int i = 0; i < splits->n_splits; i++) { | ||||
|         if (!visited[i]) { | ||||
|             struct ggml_graph_split * split = &splits->splits[i]; | ||||
|             struct ggml_backend * backend = split->dst_inputs[0]->backend; // not great | ||||
|             struct ggml_context * ctx = split->ctx; | ||||
|             struct ggml_cgraph * backend_graphs[GGML_MAX_SPLITS]; | ||||
|             int num_graphs = 0; | ||||
|             for (int j = i; j < splits->n_splits; j++) { | ||||
|                 if (splits->splits[j].dst_inputs[0]->backend == backend) { | ||||
|                     backend_graphs[num_graphs++] = splits->splits[j].graph; | ||||
|                 if (splits->splits[j].ctx == ctx) { | ||||
|                     backend_graphs[num_graphs] = splits->splits[j].graph; | ||||
|                     visited[j] = true; | ||||
|                     num_graphs++; | ||||
|                     // TODO: need to ensure that the output tensors are never freed | ||||
|                     // maybe this can be done automatically in ggml_graph_calc_compute_buffer_size by assuming that n_childs == 0 => output tensor | ||||
|                     // maybe this can be done automatically in ggml_graph_allocate_tensors_n by assuming that n_childs == 0 => output tensor | ||||
|                 } | ||||
|             } | ||||
|             ggml_graph_allocate_tensors_n(backend_graphs, num_graphs); | ||||
|             //printf("allocating tensors for %s [%d graphs/%d splits]\n", ggml_backend_name(ggml_get_buffer(ctx)->backend_buffer->backend), num_graphs, splits->n_splits); | ||||
|             ggml_graph_allocate_tensors_n(backend_graphs, num_graphs, ctx); | ||||
|         } | ||||
|     } | ||||
|     //printf("done allocating tensors\n"); | ||||
| } | ||||
|  | ||||
|   | ||||
| @@ -126,9 +126,10 @@ extern "C" { | ||||
|  | ||||
|     struct ggml_graph_split { | ||||
|         char name[GGML_MAX_NAME]; | ||||
|         struct ggml_tensor * src_inputs[GGML_MAX_SPLIT_INPUTS + 1]; | ||||
|         struct ggml_tensor * dst_inputs[GGML_MAX_SPLIT_INPUTS + 1]; | ||||
|         struct ggml_cgraph * graph; | ||||
|         struct ggml_context * ctx; | ||||
|         struct ggml_tensor  * src_inputs[GGML_MAX_SPLIT_INPUTS + 1]; | ||||
|         struct ggml_tensor  * dst_inputs[GGML_MAX_SPLIT_INPUTS + 1]; | ||||
|         struct ggml_cgraph  * graph; | ||||
|     }; | ||||
|  | ||||
|     // TODO: this shouldn't be fixed size, allocate from ggml_context | ||||
| @@ -153,8 +154,8 @@ extern "C" { | ||||
|     GGML_API void ggml_graph_splits_compute(struct ggml_graph_splits * splits); | ||||
|  | ||||
|     // graph tensor allocator | ||||
|     GGML_API void ggml_graph_allocate_tensors(struct ggml_cgraph * graph); | ||||
|     GGML_API void ggml_graph_allocate_tensors_n(struct ggml_cgraph ** graphs, int n_graphs); | ||||
|     GGML_API void ggml_graph_allocate_tensors(struct ggml_cgraph * graph, struct ggml_context * ctx); | ||||
|     GGML_API void ggml_graph_allocate_tensors_n(struct ggml_cgraph ** graphs, int n_graphs, struct ggml_context * ctx); | ||||
|     GGML_API void ggml_graph_splits_allocate_tensors(struct ggml_graph_splits * splits); | ||||
|  | ||||
| #ifdef  __cplusplus | ||||
|   | ||||
| @@ -1752,6 +1752,8 @@ static void ggml_backend_cuda_get_tensor_async(ggml_backend * backend, const ggm | ||||
|  | ||||
|     //ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context; | ||||
|  | ||||
|     //printf("get tensor %s %p\n", tensor->name, tensor->data); | ||||
|  | ||||
|     CUDA_CHECK(cudaMemcpyAsync(data, (const char*)tensor->data + offset, size, cudaMemcpyDeviceToHost, g_cudaStream_main)); | ||||
|  | ||||
|     UNUSED(backend); | ||||
|   | ||||
							
								
								
									
										46
									
								
								ggml.c
									
									
									
									
									
								
							
							
						
						
									
										46
									
								
								ggml.c
									
									
									
									
									
								
							| @@ -3936,7 +3936,7 @@ struct ggml_context { | ||||
|  | ||||
|     struct ggml_buffer * buffer; | ||||
|  | ||||
|     bool   no_alloc; | ||||
|     enum ggml_alloc_mode alloc_mode; | ||||
|  | ||||
|     int    n_objects; | ||||
|  | ||||
| @@ -4292,7 +4292,7 @@ static inline int ggml_up(int n, int m) { | ||||
| struct ggml_init_params ggml_init_params_default(void) { | ||||
|     struct ggml_init_params default_params = { | ||||
|         /*.buffer       =*/ NULL, | ||||
|         /*.no_alloc     =*/ false, | ||||
|         /*.alloc_mode   =*/ GGML_ALLOC_IMMEDIATE, | ||||
|         /*.compute_type =*/ GGML_TYPE_F32 | ||||
|     }; | ||||
|     return default_params; | ||||
| @@ -4386,7 +4386,7 @@ struct ggml_context * ggml_init(struct ggml_init_params params) { | ||||
|         /*.mem_size           =*/ params.buffer->mem_size, | ||||
|         /*.mem_buffer         =*/ params.buffer->mem_buffer, | ||||
|         /*.buffer             =*/ params.buffer, | ||||
|         /*.no_alloc           =*/ params.no_alloc, | ||||
|         /*.alloc_mode         =*/ params.alloc_mode, | ||||
|         /*.n_objects          =*/ 0, | ||||
|         /*.objects_begin      =*/ NULL, | ||||
|         /*.objects_end        =*/ NULL, | ||||
| @@ -4435,8 +4435,8 @@ size_t ggml_used_mem(const struct ggml_context * ctx) { | ||||
|     return ctx->objects_end == NULL ? 0 : ctx->objects_end->offs + ctx->objects_end->size; | ||||
| } | ||||
|  | ||||
| void ggml_set_no_alloc(struct ggml_context * ctx, bool no_alloc) { | ||||
|     ctx->no_alloc = no_alloc; | ||||
| void ggml_set_alloc_mode(struct ggml_context * ctx, enum ggml_alloc_mode alloc_mode) { | ||||
|     ctx->alloc_mode = alloc_mode; | ||||
| } | ||||
|  | ||||
| void * ggml_get_mem_buffer(const struct ggml_context * ctx) { | ||||
| @@ -4467,8 +4467,8 @@ size_t ggml_get_max_tensor_size(const struct ggml_context * ctx) { | ||||
|     return max_size; | ||||
| } | ||||
|  | ||||
| struct ggml_backend * ggml_get_ctx_backend(struct ggml_context * ctx) { | ||||
|     return ctx->buffer->backend_buffer->backend; | ||||
| struct ggml_buffer * ggml_get_buffer(const struct ggml_context * ctx) { | ||||
|     return ctx->buffer; | ||||
| } | ||||
|  | ||||
| //////////////////////////////////////////////////////////////////////////////// | ||||
| @@ -4520,7 +4520,7 @@ struct ggml_tensor * ggml_new_tensor_impl( | ||||
|     ggml_assert_aligned(result); | ||||
|  | ||||
|     *result = (struct ggml_tensor) { | ||||
|         /*.backend      =*/ ggml_get_ctx_backend(ctx), | ||||
|         /*.backend      =*/ ctx->buffer->backend_buffer->backend, | ||||
|         /*.type         =*/ type, | ||||
|         /*.n_dims       =*/ n_dims, | ||||
|         /*.ne           =*/ { 1, 1, 1, 1 }, | ||||
| @@ -4537,7 +4537,7 @@ struct ggml_tensor * ggml_new_tensor_impl( | ||||
|         /*.data         =*/ data, | ||||
|         /*.name         =*/ { 0 }, | ||||
|         /*.extra        =*/ NULL, | ||||
|         /*.pad          =*/ { 0 }, | ||||
|         /*.padding      =*/ { 0 }, | ||||
|     }; | ||||
|  | ||||
|     for (int i = 0; i < n_dims; i++) { | ||||
| @@ -4550,14 +4550,10 @@ struct ggml_tensor * ggml_new_tensor_impl( | ||||
|         result->nb[i] = result->nb[i - 1]*result->ne[i - 1]; | ||||
|     } | ||||
|  | ||||
|     if (data == NULL && !ctx->no_alloc) { | ||||
|          ggml_backend_buffer_tensor_alloc(ctx->buffer->backend_buffer, result); | ||||
|     if (data == NULL && ctx->alloc_mode == GGML_ALLOC_IMMEDIATE) { | ||||
|         ggml_backend_buffer_tensor_alloc(ctx->buffer->backend_buffer, result); | ||||
|     } | ||||
|  | ||||
|     // TODO: this should not be needed as long as we don't rely on aligned SIMD loads | ||||
|     //ggml_assert_aligned(result->data); | ||||
|  | ||||
|  | ||||
|     ctx->n_objects++; | ||||
|  | ||||
|     return result; | ||||
| @@ -6387,7 +6383,7 @@ struct ggml_tensor * ggml_view_1d( | ||||
|         is_node = true; | ||||
|     } | ||||
|  | ||||
|     struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 1, &ne0, (char *) a->data + offset); | ||||
|     struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 1, &ne0, a->data ? (char *) a->data + offset : NULL); | ||||
|     ggml_format_name(result, "%s (view)", a->name); | ||||
|  | ||||
|     ggml_set_op_params(result, &offset, sizeof(offset)); | ||||
| @@ -6418,7 +6414,7 @@ struct ggml_tensor * ggml_view_2d( | ||||
|  | ||||
|     const int64_t ne[GGML_MAX_DIMS] = { ne0, ne1, 1, 1 }; | ||||
|  | ||||
|     struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 2, ne, (char *) a->data + offset); | ||||
|     struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 2, ne, a->data ? (char *) a->data + offset : NULL); | ||||
|     ggml_format_name(result, "%s (view)", a->name); | ||||
|  | ||||
|     ggml_set_op_params(result, &offset, sizeof(offset)); | ||||
| @@ -6455,7 +6451,7 @@ struct ggml_tensor * ggml_view_3d( | ||||
|  | ||||
|     const int64_t ne[GGML_MAX_DIMS] = { ne0, ne1, ne2, 1 }; | ||||
|  | ||||
|     struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 3, ne, (char *) a->data + offset); | ||||
|     struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 3, ne, a->data ? (char *) a->data + offset : NULL); | ||||
|     ggml_format_name(result, "%s (view)", a->name); | ||||
|  | ||||
|     ggml_set_op_params(result, &offset, sizeof(offset)); | ||||
| @@ -6494,7 +6490,7 @@ struct ggml_tensor * ggml_view_4d( | ||||
|  | ||||
|     const int64_t ne[GGML_MAX_DIMS] = { ne0, ne1, ne2, ne3 }; | ||||
|  | ||||
|     struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 4, ne, (char *) a->data + offset); | ||||
|     struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 4, ne, a->data ? (char *) a->data + offset : NULL); | ||||
|     ggml_format_name(result, "%s (view)", a->name); | ||||
|  | ||||
|     ggml_set_op_params(result, &offset, sizeof(offset)); | ||||
| @@ -6885,6 +6881,18 @@ struct ggml_tensor * ggml_rope_inplace( | ||||
|     return ggml_rope_impl(ctx, a, n_past, n_dims, mode, 10000.0f, 1.0f, n_ctx, true); | ||||
| } | ||||
|  | ||||
| struct ggml_tensor * ggml_rope_custom( | ||||
|         struct ggml_context * ctx, | ||||
|         struct ggml_tensor  * a, | ||||
|         int                   n_past, | ||||
|         int                   n_dims, | ||||
|         int                   mode, | ||||
|         float                 freq_base, | ||||
|         float                 freq_scale, | ||||
|         int                   n_ctx) { | ||||
|     return ggml_rope_impl(ctx, a, n_past, n_dims, mode, freq_base, freq_scale, n_ctx, false); | ||||
| } | ||||
|  | ||||
| struct ggml_tensor * ggml_rope_custom_inplace( | ||||
|         struct ggml_context * ctx, | ||||
|         struct ggml_tensor  * a, | ||||
|   | ||||
							
								
								
									
										37
									
								
								ggml.h
									
									
									
									
									
								
							
							
						
						
									
										37
									
								
								ggml.h
									
									
									
									
									
								
							| @@ -474,24 +474,18 @@ extern "C" { | ||||
|         int64_t perf_time_us; | ||||
|     }; | ||||
|  | ||||
|     /* | ||||
|     TODO | ||||
|     enum ggml_alloc_mode { | ||||
|         GGML_ALLOC_IMMEDIATE, | ||||
|         GGML_ALLOC_NONE, | ||||
|         GGML_ALLOC_COMPUTE_SEQ, | ||||
|         GGML_ALLOC_COMPUTE_PAR, | ||||
|         GGML_ALLOC_NONE,            // do not allocate tensors | ||||
|         GGML_ALLOC_IMMEDIATE,       // allocate tensors immediately | ||||
|         GGML_ALLOC_COMPUTE_SEQ,     // delay allocation until graph build time, allocate tensors for sequential graph computation | ||||
|         //GGML_ALLOC_COMPUTE_PAR,     // allocate tensors for parallel graph computation | ||||
|     }; | ||||
|     */ | ||||
|  | ||||
|     // context parameters | ||||
|     struct ggml_init_params { | ||||
|         struct ggml_buffer * buffer; | ||||
|  | ||||
|         bool   no_alloc;   // don't allocate memory for the tensor data | ||||
|         //enum ggml_alloc_mode alloc_mode; // TODO: replace the above with this | ||||
|  | ||||
|         enum ggml_type compute_type;         // type of intermediate results | ||||
|         enum ggml_alloc_mode alloc_mode;   // tensor allocation mode | ||||
|         enum ggml_type       compute_type; // type of intermediate results | ||||
|     }; | ||||
|  | ||||
|     // task types | ||||
| @@ -559,15 +553,15 @@ extern "C" { | ||||
|     GGML_API struct ggml_context *   ggml_init(struct ggml_init_params params); | ||||
|     GGML_API void                    ggml_free(struct ggml_context * ctx); | ||||
|  | ||||
|     GGML_API void    ggml_set_alloc_mode(struct ggml_context * ctx, enum ggml_alloc_mode mode); | ||||
|  | ||||
|     // TODO: update for ggml_buffer | ||||
|     GGML_API size_t  ggml_used_mem(const struct ggml_context * ctx); | ||||
|  | ||||
|     GGML_API void    ggml_set_no_alloc(struct ggml_context * ctx, bool no_alloc); | ||||
|  | ||||
|     GGML_API void *  ggml_get_mem_buffer     (const struct ggml_context * ctx); | ||||
|     GGML_API size_t  ggml_get_mem_size       (const struct ggml_context * ctx); | ||||
|     GGML_API size_t  ggml_get_max_tensor_size(const struct ggml_context * ctx); | ||||
|  | ||||
|     GGML_API struct ggml_backend * ggml_get_ctx_backend(struct ggml_context * ctx); | ||||
|     GGML_API struct ggml_buffer * ggml_get_buffer(const struct ggml_context * ctx); | ||||
|  | ||||
|     GGML_API struct ggml_tensor * ggml_new_tensor( | ||||
|             struct ggml_context * ctx, | ||||
| @@ -1130,6 +1124,17 @@ extern "C" { | ||||
|             int                   mode, | ||||
|             int                   n_ctx); | ||||
|  | ||||
|     // custom RoPE | ||||
|     GGML_API struct ggml_tensor * ggml_rope_custom( | ||||
|             struct ggml_context * ctx, | ||||
|             struct ggml_tensor  * a, | ||||
|             int                   n_past, | ||||
|             int                   n_dims, | ||||
|             int                   mode, | ||||
|             float                 freq_base, | ||||
|             float                 freq_scale, | ||||
|             int                   n_ctx); | ||||
|  | ||||
|     // custom RoPE, in-place, returns view(a) | ||||
|     GGML_API struct ggml_tensor * ggml_rope_custom_inplace( | ||||
|             struct ggml_context * ctx, | ||||
|   | ||||
							
								
								
									
										40
									
								
								llama.cpp
									
									
									
									
									
								
							
							
						
						
									
										40
									
								
								llama.cpp
									
									
									
									
									
								
							| @@ -1008,7 +1008,9 @@ static void llama_model_load_internal( | ||||
|         backend_data.buf = ggml_buffer_alloc(backend, ctx_size, num_tensors); | ||||
|         struct ggml_init_params params = ggml_init_params_default(); | ||||
|         params.buffer   = backend_data.buf; | ||||
|         params.no_alloc = backend == model.backend_cpu && ml->use_mmap; | ||||
|         if (backend == model.backend_cpu && ml->use_mmap) { | ||||
|             params.alloc_mode = GGML_ALLOC_NONE; | ||||
|         } | ||||
|         backend_data.ctx = ggml_init(params); | ||||
|         if (!backend_data.ctx) { | ||||
|             throw std::runtime_error(format("ggml_init() failed for backend context")); | ||||
| @@ -1184,6 +1186,8 @@ static ggml_graph_splits llama_build_graph( | ||||
|     for (ggml_buffer * buf_compute : lctx.bufs_compute) { | ||||
|         struct ggml_init_params params = ggml_init_params_default(); | ||||
|         params.buffer = buf_compute; | ||||
|         params.alloc_mode = GGML_ALLOC_COMPUTE_SEQ; | ||||
|         //params.alloc_mode = GGML_ALLOC_IMMEDIATE; | ||||
|         params.compute_type = compute_type; | ||||
|         ggml_context * ctx_buf = ggml_init(params); | ||||
|         ctxs.push_back(ctx_buf); | ||||
| @@ -1198,15 +1202,19 @@ static ggml_graph_splits llama_build_graph( | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     bool measuring = lctx.bufs_compute[0]->backend_buffer->measure; | ||||
|  | ||||
|     struct ggml_tensor * inpL; | ||||
|  | ||||
|     // reuse the scale tensor for all layers since it requires a memory transfer | ||||
|     //struct ggml_tensor * KQ_scale = ggml_new_f32(ctx_kv, 1.0f/sqrtf(float(n_embd)/n_head)); | ||||
|     // struct ggml_tensor * KQ_scale = ggml_new_f32(ctx_kv, 1.0f/sqrtf(float(n_embd)/n_head)); | ||||
|     // TODO: this shouldn't be necessary | ||||
|     bool measuring = lctx.bufs_compute[0]->backend_buffer->measure; | ||||
|     struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx_kv, GGML_TYPE_F32, 1); | ||||
|     if (!measuring) { | ||||
|         // this should be automatic | ||||
|         if (KQ_scale->data == NULL) { | ||||
|             ggml_backend_buffer_tensor_alloc(ggml_get_buffer(ctx_kv)->backend_buffer, KQ_scale); | ||||
|         } | ||||
|         ggml_set_f32(KQ_scale, 1.0f/sqrtf(float(n_embd)/n_head)); | ||||
|     } | ||||
|     ggml_set_name(KQ_scale, "1/sqrt(n_embd/n_head)"); | ||||
| @@ -1254,10 +1262,10 @@ static ggml_graph_splits llama_build_graph( | ||||
|             struct ggml_tensor * tmpv = ggml_mul_mat(ctx_l, model.layers[il].wv, cur); | ||||
|             ggml_set_name(tmpv, "tmpv"); | ||||
|  | ||||
|             struct ggml_tensor * Kcur = ggml_rope_custom_inplace(ctx_l, ggml_reshape_3d(ctx_l, tmpk, n_embd/n_head, n_head, N), n_past, n_rot, 0, freq_base, freq_scale, 0); | ||||
|             struct ggml_tensor * Kcur = ggml_rope_custom(ctx_l, ggml_reshape_3d(ctx_l, tmpk, n_embd/n_head, n_head, N), n_past, n_rot, 0, freq_base, freq_scale, 0); | ||||
|             ggml_set_name(Kcur, "Kcur"); | ||||
|  | ||||
|             struct ggml_tensor * Qcur = ggml_rope_custom_inplace(ctx_l, ggml_reshape_3d(ctx_l, tmpq, n_embd/n_head, n_head, N), n_past, n_rot, 0, freq_base, freq_scale, 0); | ||||
|             struct ggml_tensor * Qcur = ggml_rope_custom(ctx_l, ggml_reshape_3d(ctx_l, tmpq, n_embd/n_head, n_head, N), n_past, n_rot, 0, freq_base, freq_scale, 0); | ||||
|             ggml_set_name(Qcur, "Qcur"); | ||||
|  | ||||
|             struct ggml_tensor * Vcur = ggml_transpose(ctx_l, ggml_reshape_2d(ctx_l, tmpv, n_embd, N)); | ||||
| @@ -1310,15 +1318,15 @@ static ggml_graph_splits llama_build_graph( | ||||
|  | ||||
|             // KQ_scaled = KQ / sqrt(n_embd/n_head) | ||||
|             // KQ_scaled shape [n_past + N, N, n_head, 1] | ||||
|             struct ggml_tensor * KQ_scaled = ggml_scale_inplace(ctx_kv, KQ, KQ_scale); | ||||
|             struct ggml_tensor * KQ_scaled = ggml_scale(ctx_kv, KQ, KQ_scale); | ||||
|             ggml_set_name(KQ_scaled, "KQ_scaled"); | ||||
|  | ||||
|             // KQ_masked = mask_past(KQ_scaled) | ||||
|             struct ggml_tensor * KQ_masked = ggml_diag_mask_inf_inplace(ctx_kv, KQ_scaled, n_past); | ||||
|             struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx_kv, KQ_scaled, n_past); | ||||
|             ggml_set_name(KQ_masked, "KQ_masked"); | ||||
|  | ||||
|             // KQ = soft_max(KQ_masked) | ||||
|             struct ggml_tensor * KQ_soft_max = ggml_soft_max_inplace(ctx_kv, KQ_masked); | ||||
|             struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx_kv, KQ_masked); | ||||
|             ggml_set_name(KQ_soft_max, "KQ_soft_max"); | ||||
|  | ||||
|             // split cached V into n_head heads | ||||
| @@ -1349,10 +1357,11 @@ static ggml_graph_splits llama_build_graph( | ||||
|  | ||||
|             // cur = KQV_merged.contiguous().view(n_embd, N) | ||||
|             cur = ggml_cpy(ctx_l, | ||||
|                     KQV_merged, | ||||
|                     //ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N)); | ||||
|                     //ggml_new_tensor_2d(ctx0, GGML_TYPE_F16, n_embd, N)); | ||||
|                     ggml_new_tensor_2d(ctx_l, compute_type, n_embd, N)); // support both automatically? | ||||
|                     KQV_merged, ggml_set_name( | ||||
|                     //ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N), | ||||
|                     //ggml_new_tensor_2d(ctx0, GGML_TYPE_F16, n_embd, N), | ||||
|                     ggml_new_tensor_2d(ctx_l, compute_type, n_embd, N), // support both automatically? | ||||
|                     "KQV_merged_contiguous_dst")); | ||||
|             ggml_set_name(cur, "KQV_merged_contiguous"); | ||||
|  | ||||
|             // projection (no bias) | ||||
| @@ -2676,17 +2685,16 @@ struct llama_context * llama_new_context_with_model( | ||||
|         int n_past = hparams.n_ctx - n_tokens; | ||||
|         /*ggml_graph_splits splits =*/ llama_build_graph(*ctx, n_tokens, n_past); | ||||
|  | ||||
|         fprintf(stderr, "%s: compute ctx sizes:\n", __func__); | ||||
|         fprintf(stderr, "%s: compute buffer sizes:\n", __func__); | ||||
|         for (size_t i = 0; i < ctx->bufs_compute.size(); ++i) { | ||||
|             ggml_buffer * buf = ctx->bufs_compute[i]; | ||||
|             ggml_backend * backend = buf->backend_buffer->backend; | ||||
|             size_t size = buf->backend_buffer->max_size; | ||||
|             fprintf(stderr, "%8s = %7.2f MB\n", ggml_backend_name(backend), size / 1024.0 / 1024.0); | ||||
|             ggml_buffer_free(buf); | ||||
|  | ||||
|             // reallocate with the correct size | ||||
|             buf = ggml_buffer_alloc(buf->backend_buffer->backend, size, 2048); | ||||
|             ctx->bufs_compute[i] = buf; | ||||
|             ggml_buffer_free(buf); | ||||
|             ctx->bufs_compute[i] = ggml_buffer_alloc(buf->backend_buffer->backend, size, 2048); | ||||
|         } | ||||
|  | ||||
|         // TODO: use pinned memory for faster host-device transfers | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 slaren
					slaren