mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-10-31 08:51:55 +00:00 
			
		
		
		
	cuda : enable CUDA Graph on CUDA Toolkit < 12.x (#12394)
* Enable CUDA Graph on CTK < 12.x
  The `cudaGraphExecUpdate` API was changed in CUDA 12.x. For this reason, CUDA Graph support was disabled on older CUDA Toolkit versions. This change enables CUDA Graph support on CTK versions < 12.x by using the older API there.
* Fix compilation errors with MUSA
* Disable CUDA Graph for MUSA
This commit is contained in:
		| @@ -678,7 +678,7 @@ struct ggml_tensor_extra_gpu { | |||||||
| }; | }; | ||||||
|  |  | ||||||
|  |  | ||||||
| #if ((CUDART_VERSION >= 12000) && defined(GGML_CUDA_USE_GRAPHS)) || defined(GGML_HIP_GRAPHS) | #if (defined(GGML_CUDA_USE_GRAPHS) || defined(GGML_HIP_GRAPHS)) | ||||||
| #define USE_CUDA_GRAPH | #define USE_CUDA_GRAPH | ||||||
| #endif | #endif | ||||||
|  |  | ||||||
|   | |||||||
| @@ -2610,13 +2610,15 @@ static bool is_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx, | |||||||
|  |  | ||||||
| static void update_cuda_graph_executable(ggml_backend_cuda_context * cuda_ctx) { | static void update_cuda_graph_executable(ggml_backend_cuda_context * cuda_ctx) { | ||||||
|  |  | ||||||
|  | #if CUDART_VERSION >= 12000 | ||||||
|     cudaGraphExecUpdateResultInfo result_info; |     cudaGraphExecUpdateResultInfo result_info; | ||||||
| #ifdef __HIP_PLATFORM_AMD__ |  | ||||||
|     hipGraphNode_t errorNode; |  | ||||||
|     hipError_t stat = hipGraphExecUpdate(cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, &errorNode, &result_info); |  | ||||||
| #else |  | ||||||
|     cudaError_t stat = cudaGraphExecUpdate(cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, &result_info); |     cudaError_t stat = cudaGraphExecUpdate(cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, &result_info); | ||||||
| #endif | #else | ||||||
|  |     cudaGraphNode_t errorNode; | ||||||
|  |     cudaGraphExecUpdateResult result_info; | ||||||
|  |     cudaError_t stat = cudaGraphExecUpdate(cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, &errorNode, &result_info); | ||||||
|  | #endif // CUDART_VERSION >= 12000 | ||||||
|  |  | ||||||
|     if (stat == cudaErrorGraphExecUpdateFailure) { |     if (stat == cudaErrorGraphExecUpdateFailure) { | ||||||
| #ifndef NDEBUG | #ifndef NDEBUG | ||||||
|         GGML_LOG_DEBUG("%s: CUDA graph update failed\n", __func__); |         GGML_LOG_DEBUG("%s: CUDA graph update failed\n", __func__); | ||||||
|   | |||||||
							
								
								
									
										2
									
								
								ggml/src/ggml-cuda/vendors/hip.h
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										2
									
								
								ggml/src/ggml-cuda/vendors/hip.h
									
									
									
									
										vendored
									
									
								
							| @@ -112,7 +112,7 @@ | |||||||
| #define cudaGraphExecDestroy hipGraphExecDestroy | #define cudaGraphExecDestroy hipGraphExecDestroy | ||||||
| #define cudaGraphLaunch hipGraphLaunch | #define cudaGraphLaunch hipGraphLaunch | ||||||
| #define cudaErrorGraphExecUpdateFailure hipErrorGraphExecUpdateFailure | #define cudaErrorGraphExecUpdateFailure hipErrorGraphExecUpdateFailure | ||||||
| #define cudaGraphExecUpdateResultInfo hipGraphExecUpdateResult | #define cudaGraphExecUpdateResult hipGraphExecUpdateResult | ||||||
| #define cudaGraphNodeType hipGraphNodeType | #define cudaGraphNodeType hipGraphNodeType | ||||||
| #define cudaGraphNodeTypeKernel hipGraphNodeTypeKernel | #define cudaGraphNodeTypeKernel hipGraphNodeTypeKernel | ||||||
| #define cudaGraphInstantiate hipGraphInstantiate | #define cudaGraphInstantiate hipGraphInstantiate | ||||||
|   | |||||||
							
								
								
									
										3
									
								
								ggml/src/ggml-cuda/vendors/musa.h
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										3
									
								
								ggml/src/ggml-cuda/vendors/musa.h
									
									
									
									
										vendored
									
									
								
							| @@ -119,7 +119,7 @@ | |||||||
| #define cudaGraphExecDestroy musaGraphExecDestroy | #define cudaGraphExecDestroy musaGraphExecDestroy | ||||||
| #define cudaGraphExec_t musaGraphExec_t | #define cudaGraphExec_t musaGraphExec_t | ||||||
| #define cudaGraphExecUpdate musaGraphExecUpdate | #define cudaGraphExecUpdate musaGraphExecUpdate | ||||||
| #define cudaGraphExecUpdateResultInfo musaGraphExecUpdateResult | #define cudaGraphExecUpdateResult musaGraphExecUpdateResult | ||||||
| #define cudaGraphGetNodes musaGraphGetNodes | #define cudaGraphGetNodes musaGraphGetNodes | ||||||
| #define cudaGraphInstantiate musaGraphInstantiate | #define cudaGraphInstantiate musaGraphInstantiate | ||||||
| #define cudaGraphKernelNodeGetParams musaGraphKernelNodeGetParams | #define cudaGraphKernelNodeGetParams musaGraphKernelNodeGetParams | ||||||
| @@ -132,6 +132,7 @@ | |||||||
| #define cudaGraph_t musaGraph_t | #define cudaGraph_t musaGraph_t | ||||||
| #define cudaKernelNodeParams musaKernelNodeParams | #define cudaKernelNodeParams musaKernelNodeParams | ||||||
| #define cudaStreamCaptureModeRelaxed musaStreamCaptureModeRelaxed | #define cudaStreamCaptureModeRelaxed musaStreamCaptureModeRelaxed | ||||||
|  | #define cudaStreamBeginCapture musaStreamBeginCapture | ||||||
| #define cudaStreamEndCapture musaStreamEndCapture | #define cudaStreamEndCapture musaStreamEndCapture | ||||||
|  |  | ||||||
| typedef mt_bfloat16 nv_bfloat16; | typedef mt_bfloat16 nv_bfloat16; | ||||||
|   | |||||||
| @@ -67,10 +67,6 @@ if (MUSAToolkit_FOUND) | |||||||
|     add_compile_definitions(GGML_USE_MUSA) |     add_compile_definitions(GGML_USE_MUSA) | ||||||
|     add_compile_definitions(GGML_CUDA_PEER_MAX_BATCH_SIZE=${GGML_CUDA_PEER_MAX_BATCH_SIZE}) |     add_compile_definitions(GGML_CUDA_PEER_MAX_BATCH_SIZE=${GGML_CUDA_PEER_MAX_BATCH_SIZE}) | ||||||
|  |  | ||||||
|     if (GGML_CUDA_GRAPHS) |  | ||||||
|         add_compile_definitions(GGML_CUDA_USE_GRAPHS) |  | ||||||
|     endif() |  | ||||||
|  |  | ||||||
|     if (GGML_CUDA_FORCE_MMQ) |     if (GGML_CUDA_FORCE_MMQ) | ||||||
|         add_compile_definitions(GGML_CUDA_FORCE_MMQ) |         add_compile_definitions(GGML_CUDA_FORCE_MMQ) | ||||||
|     endif() |     endif() | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user
	 Gaurav Garg
					Gaurav Garg