mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-10-30 08:42:00 +00:00 
			
		
		
		
	cuda : enable CUDA Graph on CUDA Toolkit < 12.x (#12394)
* Enable CUDA Graph on CTK < 12.x

  The `cudaGraphExecUpdate` API changed in CUDA 12.x, and for this reason CUDA Graph support was disabled on older CUDA Toolkits. This change enables CUDA Graph support on CTK versions < 12.x by using the older API there.

* Fix compilation errors with MUSA
* Disable CUDA Graph for MUSA
This commit is contained in:
		| @@ -678,7 +678,7 @@ struct ggml_tensor_extra_gpu { | ||||
| }; | ||||
|  | ||||
|  | ||||
| #if ((CUDART_VERSION >= 12000) && defined(GGML_CUDA_USE_GRAPHS)) || defined(GGML_HIP_GRAPHS) | ||||
| #if (defined(GGML_CUDA_USE_GRAPHS) || defined(GGML_HIP_GRAPHS)) | ||||
| #define USE_CUDA_GRAPH | ||||
| #endif | ||||
|  | ||||
|   | ||||
| @@ -2610,13 +2610,15 @@ static bool is_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx, | ||||
|  | ||||
| static void update_cuda_graph_executable(ggml_backend_cuda_context * cuda_ctx) { | ||||
|  | ||||
| #if CUDART_VERSION >= 12000 | ||||
|     cudaGraphExecUpdateResultInfo result_info; | ||||
| #ifdef __HIP_PLATFORM_AMD__ | ||||
|     hipGraphNode_t errorNode; | ||||
|     hipError_t stat = hipGraphExecUpdate(cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, &errorNode, &result_info); | ||||
| #else | ||||
|     cudaError_t stat = cudaGraphExecUpdate(cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, &result_info); | ||||
| #endif | ||||
| #else | ||||
|     cudaGraphNode_t errorNode; | ||||
|     cudaGraphExecUpdateResult result_info; | ||||
|     cudaError_t stat = cudaGraphExecUpdate(cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, &errorNode, &result_info); | ||||
| #endif // CUDART_VERSION >= 12000 | ||||
|  | ||||
|     if (stat == cudaErrorGraphExecUpdateFailure) { | ||||
| #ifndef NDEBUG | ||||
|         GGML_LOG_DEBUG("%s: CUDA graph update failed\n", __func__); | ||||
|   | ||||
							
								
								
									
										2
									
								
								ggml/src/ggml-cuda/vendors/hip.h
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										2
									
								
								ggml/src/ggml-cuda/vendors/hip.h
									
									
									
									
										vendored
									
									
								
							| @@ -112,7 +112,7 @@ | ||||
| #define cudaGraphExecDestroy hipGraphExecDestroy | ||||
| #define cudaGraphLaunch hipGraphLaunch | ||||
| #define cudaErrorGraphExecUpdateFailure hipErrorGraphExecUpdateFailure | ||||
| #define cudaGraphExecUpdateResultInfo hipGraphExecUpdateResult | ||||
| #define cudaGraphExecUpdateResult hipGraphExecUpdateResult | ||||
| #define cudaGraphNodeType hipGraphNodeType | ||||
| #define cudaGraphNodeTypeKernel hipGraphNodeTypeKernel | ||||
| #define cudaGraphInstantiate hipGraphInstantiate | ||||
|   | ||||
							
								
								
									
										3
									
								
								ggml/src/ggml-cuda/vendors/musa.h
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										3
									
								
								ggml/src/ggml-cuda/vendors/musa.h
									
									
									
									
										vendored
									
									
								
							| @@ -119,7 +119,7 @@ | ||||
| #define cudaGraphExecDestroy musaGraphExecDestroy | ||||
| #define cudaGraphExec_t musaGraphExec_t | ||||
| #define cudaGraphExecUpdate musaGraphExecUpdate | ||||
| #define cudaGraphExecUpdateResultInfo musaGraphExecUpdateResult | ||||
| #define cudaGraphExecUpdateResult musaGraphExecUpdateResult | ||||
| #define cudaGraphGetNodes musaGraphGetNodes | ||||
| #define cudaGraphInstantiate musaGraphInstantiate | ||||
| #define cudaGraphKernelNodeGetParams musaGraphKernelNodeGetParams | ||||
| @@ -132,6 +132,7 @@ | ||||
| #define cudaGraph_t musaGraph_t | ||||
| #define cudaKernelNodeParams musaKernelNodeParams | ||||
| #define cudaStreamCaptureModeRelaxed musaStreamCaptureModeRelaxed | ||||
| #define cudaStreamBeginCapture musaStreamBeginCapture | ||||
| #define cudaStreamEndCapture musaStreamEndCapture | ||||
|  | ||||
| typedef mt_bfloat16 nv_bfloat16; | ||||
|   | ||||
| @@ -67,10 +67,6 @@ if (MUSAToolkit_FOUND) | ||||
|     add_compile_definitions(GGML_USE_MUSA) | ||||
|     add_compile_definitions(GGML_CUDA_PEER_MAX_BATCH_SIZE=${GGML_CUDA_PEER_MAX_BATCH_SIZE}) | ||||
|  | ||||
|     if (GGML_CUDA_GRAPHS) | ||||
|         add_compile_definitions(GGML_CUDA_USE_GRAPHS) | ||||
|     endif() | ||||
|  | ||||
|     if (GGML_CUDA_FORCE_MMQ) | ||||
|         add_compile_definitions(GGML_CUDA_FORCE_MMQ) | ||||
|     endif() | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Gaurav Garg
					Gaurav Garg