mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-10-29 08:41:22 +00:00 
			
		
		
		
	cuda : ignore peer access already enabled errors (#5597)
* cuda : ignore peer access already enabled errors
* fix hip
This commit is contained in:
		
							
								
								
									
										22
									
								
								ggml-cuda.cu
									
									
									
									
									
								
							
							
						
						
									
										22
									
								
								ggml-cuda.cu
									
									
									
									
									
								
							| @@ -54,6 +54,8 @@ | |||||||
| #define cudaDeviceProp hipDeviceProp_t | #define cudaDeviceProp hipDeviceProp_t | ||||||
| #define cudaDeviceSynchronize hipDeviceSynchronize | #define cudaDeviceSynchronize hipDeviceSynchronize | ||||||
| #define cudaError_t hipError_t | #define cudaError_t hipError_t | ||||||
|  | #define cudaErrorPeerAccessAlreadyEnabled hipErrorPeerAccessAlreadyEnabled | ||||||
|  | #define cudaErrorPeerAccessNotEnabled hipErrorPeerAccessNotEnabled | ||||||
| #define cudaEventCreateWithFlags hipEventCreateWithFlags | #define cudaEventCreateWithFlags hipEventCreateWithFlags | ||||||
| #define cudaEventDisableTiming hipEventDisableTiming | #define cudaEventDisableTiming hipEventDisableTiming | ||||||
| #define cudaEventRecord hipEventRecord | #define cudaEventRecord hipEventRecord | ||||||
| @@ -9325,9 +9327,15 @@ static void ggml_cuda_set_peer_access(const int n_tokens) { | |||||||
|             CUDA_CHECK(cudaDeviceCanAccessPeer(&can_access_peer, id, id_other)); |             CUDA_CHECK(cudaDeviceCanAccessPeer(&can_access_peer, id, id_other)); | ||||||
|             if (can_access_peer) { |             if (can_access_peer) { | ||||||
|                 if (enable_peer_access) { |                 if (enable_peer_access) { | ||||||
|                     CUDA_CHECK(cudaDeviceEnablePeerAccess(id_other, 0)); |                     cudaError_t err = cudaDeviceEnablePeerAccess(id_other, 0); | ||||||
|  |                     if (err != cudaErrorPeerAccessAlreadyEnabled) { | ||||||
|  |                         CUDA_CHECK(err); | ||||||
|  |                     } | ||||||
|                 } else { |                 } else { | ||||||
|                     CUDA_CHECK(cudaDeviceDisablePeerAccess(id_other)); |                     cudaError_t err = cudaDeviceDisablePeerAccess(id_other); | ||||||
|  |                     if (err != cudaErrorPeerAccessNotEnabled) { | ||||||
|  |                         CUDA_CHECK(err); | ||||||
|  |                     } | ||||||
|                 } |                 } | ||||||
|             } |             } | ||||||
|         } |         } | ||||||
| @@ -10999,10 +11007,10 @@ GGML_CALL static const char * ggml_backend_cuda_split_buffer_get_name(ggml_backe | |||||||
|     UNUSED(buffer); |     UNUSED(buffer); | ||||||
| } | } | ||||||
|  |  | ||||||
| // unused at the moment | static bool ggml_backend_buffer_is_cuda_split(ggml_backend_buffer_t buffer) { | ||||||
| //static bool ggml_backend_buffer_is_cuda_split(ggml_backend_buffer_t buffer) { |     return buffer->iface.get_name == ggml_backend_cuda_split_buffer_get_name; | ||||||
| //    return buffer->iface.get_name == ggml_backend_cuda_split_buffer_get_name; |     UNUSED(ggml_backend_buffer_is_cuda_split); // only used in debug builds currently, avoid unused function warning in release builds | ||||||
| //} | } | ||||||
|  |  | ||||||
| GGML_CALL static void ggml_backend_cuda_split_buffer_free_buffer(ggml_backend_buffer_t buffer) { | GGML_CALL static void ggml_backend_cuda_split_buffer_free_buffer(ggml_backend_buffer_t buffer) { | ||||||
|     ggml_backend_cuda_split_buffer_context * ctx = (ggml_backend_cuda_split_buffer_context *)buffer->context; |     ggml_backend_cuda_split_buffer_context * ctx = (ggml_backend_cuda_split_buffer_context *)buffer->context; | ||||||
| @@ -11390,7 +11398,7 @@ GGML_CALL static bool ggml_backend_cuda_graph_compute(ggml_backend_t backend, gg | |||||||
|         for (int j = 0; j < GGML_MAX_SRC; j++) { |         for (int j = 0; j < GGML_MAX_SRC; j++) { | ||||||
|             if (node->src[j] != nullptr) { |             if (node->src[j] != nullptr) { | ||||||
|                 assert(node->src[j]->backend == GGML_BACKEND_GPU || node->src[j]->backend == GGML_BACKEND_GPU_SPLIT); |                 assert(node->src[j]->backend == GGML_BACKEND_GPU || node->src[j]->backend == GGML_BACKEND_GPU_SPLIT); | ||||||
|                 assert(node->src[j]->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device)); |                 assert(node->src[j]->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device) || ggml_backend_buffer_is_cuda_split(node->src[j]->buffer)); | ||||||
|                 assert(node->src[j]->extra != nullptr); |                 assert(node->src[j]->extra != nullptr); | ||||||
|             } |             } | ||||||
|         } |         } | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user
	 slaren
					slaren