mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-10-30 08:42:00 +00:00 
			
		
		
		
	rpc : resource management rework (#7562)
* rpc : resource management rework
* address review comments
This commit is contained in:
		 Radoslav Gerganov
					Radoslav Gerganov
				
			
				
					committed by
					
						 GitHub
						GitHub
					
				
			
			
				
	
			
			
			 GitHub
						GitHub
					
				
			
						parent
						
							ee3dff6b8e
						
					
				
				
					commit
					2b737caae1
				
			
							
								
								
									
										131
									
								
								ggml-rpc.cpp
									
									
									
									
									
								
							
							
						
						
									
										131
									
								
								ggml-rpc.cpp
									
									
									
									
									
								
							| @@ -6,6 +6,7 @@ | |||||||
| #include <string> | #include <string> | ||||||
| #include <vector> | #include <vector> | ||||||
| #include <memory> | #include <memory> | ||||||
|  | #include <mutex> | ||||||
| #include <unordered_map> | #include <unordered_map> | ||||||
| #include <unordered_set> | #include <unordered_set> | ||||||
| #ifdef _WIN32 | #ifdef _WIN32 | ||||||
| @@ -47,6 +48,7 @@ struct socket_t { | |||||||
|     sockfd_t fd; |     sockfd_t fd; | ||||||
|     socket_t(sockfd_t fd) : fd(fd) {} |     socket_t(sockfd_t fd) : fd(fd) {} | ||||||
|     ~socket_t() { |     ~socket_t() { | ||||||
|  |         GGML_PRINT_DEBUG("[%s] closing socket %d\n", __func__, this->fd); | ||||||
| #ifdef _WIN32 | #ifdef _WIN32 | ||||||
|         closesocket(this->fd); |         closesocket(this->fd); | ||||||
| #else | #else | ||||||
| @@ -97,7 +99,7 @@ static ggml_guid_t ggml_backend_rpc_guid() { | |||||||
| } | } | ||||||
|  |  | ||||||
| struct ggml_backend_rpc_buffer_type_context { | struct ggml_backend_rpc_buffer_type_context { | ||||||
|     std::shared_ptr<socket_t> sock; |     std::string endpoint; | ||||||
|     std::string name; |     std::string name; | ||||||
|     size_t alignment; |     size_t alignment; | ||||||
|     size_t max_size; |     size_t max_size; | ||||||
| @@ -106,8 +108,6 @@ struct ggml_backend_rpc_buffer_type_context { | |||||||
| struct ggml_backend_rpc_context { | struct ggml_backend_rpc_context { | ||||||
|     std::string endpoint; |     std::string endpoint; | ||||||
|     std::string name; |     std::string name; | ||||||
|     std::shared_ptr<socket_t> sock; |  | ||||||
|     ggml_backend_buffer_type_t buft; |  | ||||||
| }; | }; | ||||||
|  |  | ||||||
| struct ggml_backend_rpc_buffer_context { | struct ggml_backend_rpc_buffer_context { | ||||||
| @@ -231,14 +231,13 @@ static bool recv_data(sockfd_t sockfd, void * data, size_t size) { | |||||||
|     return true; |     return true; | ||||||
| } | } | ||||||
|  |  | ||||||
| static bool parse_endpoint(const char * endpoint, std::string & host, int & port) { | static bool parse_endpoint(const std::string & endpoint, std::string & host, int & port) { | ||||||
|     std::string str(endpoint); |     size_t pos = endpoint.find(':'); | ||||||
|     size_t pos = str.find(':'); |  | ||||||
|     if (pos == std::string::npos) { |     if (pos == std::string::npos) { | ||||||
|         return false; |         return false; | ||||||
|     } |     } | ||||||
|     host = str.substr(0, pos); |     host = endpoint.substr(0, pos); | ||||||
|     port = std::stoi(str.substr(pos + 1)); |     port = std::stoi(endpoint.substr(pos + 1)); | ||||||
|     return true; |     return true; | ||||||
| } | } | ||||||
|  |  | ||||||
| @@ -273,6 +272,44 @@ static bool send_rpc_cmd(const std::shared_ptr<socket_t> & sock, enum rpc_cmd cm | |||||||
|  |  | ||||||
| // RPC client-side implementation | // RPC client-side implementation | ||||||
|  |  | ||||||
|  | static std::shared_ptr<socket_t> get_socket(const std::string & endpoint) { | ||||||
|  |     static std::mutex mutex; | ||||||
|  |     std::lock_guard<std::mutex> lock(mutex); | ||||||
|  |     static std::unordered_map<std::string, std::weak_ptr<socket_t>> sockets; | ||||||
|  |     static bool initialized = false; | ||||||
|  |  | ||||||
|  |     auto it = sockets.find(endpoint); | ||||||
|  |     if (it != sockets.end()) { | ||||||
|  |         if (auto sock = it->second.lock()) { | ||||||
|  |             return sock; | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  |     std::string host; | ||||||
|  |     int port; | ||||||
|  |     if (!parse_endpoint(endpoint, host, port)) { | ||||||
|  |         return nullptr; | ||||||
|  |     } | ||||||
|  | #ifdef _WIN32 | ||||||
|  |     if (!initialized) { | ||||||
|  |         WSADATA wsaData; | ||||||
|  |         int res = WSAStartup(MAKEWORD(2, 2), &wsaData); | ||||||
|  |         if (res != 0) { | ||||||
|  |             return nullptr; | ||||||
|  |         } | ||||||
|  |         initialized = true; | ||||||
|  |     } | ||||||
|  | #else | ||||||
|  |     UNUSED(initialized); | ||||||
|  | #endif | ||||||
|  |     auto sock = socket_connect(host.c_str(), port); | ||||||
|  |     if (sock == nullptr) { | ||||||
|  |         return nullptr; | ||||||
|  |     } | ||||||
|  |     GGML_PRINT_DEBUG("[%s] connected to %s, sockfd=%d\n", __func__, endpoint.c_str(), sock->fd); | ||||||
|  |     sockets[endpoint] = sock; | ||||||
|  |     return sock; | ||||||
|  | } | ||||||
|  |  | ||||||
| GGML_CALL static const char * ggml_backend_rpc_buffer_get_name(ggml_backend_buffer_t buffer) { | GGML_CALL static const char * ggml_backend_rpc_buffer_get_name(ggml_backend_buffer_t buffer) { | ||||||
|     ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context; |     ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context; | ||||||
|     return ctx->name.c_str(); |     return ctx->name.c_str(); | ||||||
| @@ -442,7 +479,8 @@ GGML_CALL static ggml_backend_buffer_t ggml_backend_rpc_buffer_type_alloc_buffer | |||||||
|     std::vector<uint8_t> input(input_size, 0); |     std::vector<uint8_t> input(input_size, 0); | ||||||
|     memcpy(input.data(), &size, sizeof(size)); |     memcpy(input.data(), &size, sizeof(size)); | ||||||
|     std::vector<uint8_t> output; |     std::vector<uint8_t> output; | ||||||
|     bool status = send_rpc_cmd(buft_ctx->sock, ALLOC_BUFFER, input, output); |     auto sock = get_socket(buft_ctx->endpoint); | ||||||
|  |     bool status = send_rpc_cmd(sock, ALLOC_BUFFER, input, output); | ||||||
|     GGML_ASSERT(status); |     GGML_ASSERT(status); | ||||||
|     GGML_ASSERT(output.size() == 2*sizeof(uint64_t)); |     GGML_ASSERT(output.size() == 2*sizeof(uint64_t)); | ||||||
|     // output serialization format: | remote_ptr (8 bytes) | remote_size (8 bytes) | |     // output serialization format: | remote_ptr (8 bytes) | remote_size (8 bytes) | | ||||||
| @@ -453,7 +491,7 @@ GGML_CALL static ggml_backend_buffer_t ggml_backend_rpc_buffer_type_alloc_buffer | |||||||
|     if (remote_ptr != 0) { |     if (remote_ptr != 0) { | ||||||
|         ggml_backend_buffer_t buffer = ggml_backend_buffer_init(buft, |         ggml_backend_buffer_t buffer = ggml_backend_buffer_init(buft, | ||||||
|             ggml_backend_rpc_buffer_interface, |             ggml_backend_rpc_buffer_interface, | ||||||
|             new ggml_backend_rpc_buffer_context{buft_ctx->sock, {}, remote_ptr, "RPC"}, |             new ggml_backend_rpc_buffer_context{sock, {}, remote_ptr, "RPC"}, | ||||||
|             remote_size); |             remote_size); | ||||||
|         return buffer; |         return buffer; | ||||||
|     } else { |     } else { | ||||||
| @@ -508,7 +546,7 @@ GGML_CALL static bool ggml_backend_rpc_buffer_type_supports_backend(ggml_backend | |||||||
|     } |     } | ||||||
|     ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context; |     ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context; | ||||||
|     ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context; |     ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context; | ||||||
|     return buft_ctx->sock == rpc_ctx->sock; |     return buft_ctx->endpoint == rpc_ctx->endpoint; | ||||||
| } | } | ||||||
|  |  | ||||||
| static ggml_backend_buffer_type_i ggml_backend_rpc_buffer_type_interface = { | static ggml_backend_buffer_type_i ggml_backend_rpc_buffer_type_interface = { | ||||||
| @@ -521,7 +559,6 @@ static ggml_backend_buffer_type_i ggml_backend_rpc_buffer_type_interface = { | |||||||
|     /* .is_host          = */ NULL, |     /* .is_host          = */ NULL, | ||||||
| }; | }; | ||||||
|  |  | ||||||
|  |  | ||||||
| GGML_CALL static const char * ggml_backend_rpc_name(ggml_backend_t backend) { | GGML_CALL static const char * ggml_backend_rpc_name(ggml_backend_t backend) { | ||||||
|     ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context; |     ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context; | ||||||
|  |  | ||||||
| @@ -530,16 +567,13 @@ GGML_CALL static const char * ggml_backend_rpc_name(ggml_backend_t backend) { | |||||||
|  |  | ||||||
| GGML_CALL static void ggml_backend_rpc_free(ggml_backend_t backend) { | GGML_CALL static void ggml_backend_rpc_free(ggml_backend_t backend) { | ||||||
|     ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context; |     ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context; | ||||||
|     ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)rpc_ctx->buft->context; |  | ||||||
|     delete buft_ctx; |  | ||||||
|     delete rpc_ctx->buft; |  | ||||||
|     delete rpc_ctx; |     delete rpc_ctx; | ||||||
|     delete backend; |     delete backend; | ||||||
| } | } | ||||||
|  |  | ||||||
| GGML_CALL static ggml_backend_buffer_type_t ggml_backend_rpc_get_default_buffer_type(ggml_backend_t backend) { | GGML_CALL static ggml_backend_buffer_type_t ggml_backend_rpc_get_default_buffer_type(ggml_backend_t backend) { | ||||||
|     ggml_backend_rpc_context * ctx = (ggml_backend_rpc_context *)backend->context; |     ggml_backend_rpc_context * ctx = (ggml_backend_rpc_context *)backend->context; | ||||||
|     return ctx->buft; |     return ggml_backend_rpc_buffer_type(ctx->endpoint.c_str()); | ||||||
| } | } | ||||||
|  |  | ||||||
| GGML_CALL static void ggml_backend_rpc_synchronize(ggml_backend_t backend) { | GGML_CALL static void ggml_backend_rpc_synchronize(ggml_backend_t backend) { | ||||||
| @@ -590,7 +624,8 @@ GGML_CALL static enum ggml_status ggml_backend_rpc_graph_compute(ggml_backend_t | |||||||
|     std::vector<uint8_t> input; |     std::vector<uint8_t> input; | ||||||
|     serialize_graph(cgraph, input); |     serialize_graph(cgraph, input); | ||||||
|     std::vector<uint8_t> output; |     std::vector<uint8_t> output; | ||||||
|     bool status = send_rpc_cmd(rpc_ctx->sock, GRAPH_COMPUTE, input, output); |     auto sock = get_socket(rpc_ctx->endpoint); | ||||||
|  |     bool status = send_rpc_cmd(sock, GRAPH_COMPUTE, input, output); | ||||||
|     GGML_ASSERT(status); |     GGML_ASSERT(status); | ||||||
|     GGML_ASSERT(output.size() == 1); |     GGML_ASSERT(output.size() == 1); | ||||||
|     return (enum ggml_status)output[0]; |     return (enum ggml_status)output[0]; | ||||||
| @@ -624,65 +659,48 @@ static ggml_backend_i ggml_backend_rpc_interface = { | |||||||
|     /* .event_synchronize       = */ NULL, |     /* .event_synchronize       = */ NULL, | ||||||
| }; | }; | ||||||
|  |  | ||||||
| static std::unordered_map<std::string, ggml_backend_t> instances; |  | ||||||
|  |  | ||||||
| GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const char * endpoint) { | GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const char * endpoint) { | ||||||
|     ggml_backend_t backend = ggml_backend_rpc_init(endpoint); |     static std::mutex mutex; | ||||||
|     return backend != nullptr ? ggml_backend_rpc_get_default_buffer_type(backend) : nullptr; |     std::lock_guard<std::mutex> lock(mutex); | ||||||
| } |     // NOTE: buffer types are allocated and never freed; this is by design | ||||||
|  |     static std::unordered_map<std::string, ggml_backend_buffer_type_t> buft_map; | ||||||
| GGML_CALL ggml_backend_t ggml_backend_rpc_init(const char * endpoint) { |     auto it = buft_map.find(endpoint); | ||||||
|     std::string endpoint_str(endpoint); |     if (it != buft_map.end()) { | ||||||
|     if (instances.find(endpoint_str) != instances.end()) { |         return it->second; | ||||||
|         return instances[endpoint_str]; |  | ||||||
|     } |     } | ||||||
| #ifdef _WIN32 |     auto sock = get_socket(endpoint); | ||||||
|     { |  | ||||||
|         WSADATA wsaData; |  | ||||||
|         int res = WSAStartup(MAKEWORD(2, 2), &wsaData); |  | ||||||
|         if (res != 0) { |  | ||||||
|             return nullptr; |  | ||||||
|         } |  | ||||||
|     } |  | ||||||
| #endif |  | ||||||
|     fprintf(stderr, "Connecting to %s\n", endpoint); |  | ||||||
|     std::string host; |  | ||||||
|     int port; |  | ||||||
|     if (!parse_endpoint(endpoint, host, port)) { |  | ||||||
|         return nullptr; |  | ||||||
|     } |  | ||||||
|     auto sock = socket_connect(host.c_str(), port); |  | ||||||
|     if (sock == nullptr) { |     if (sock == nullptr) { | ||||||
|         return nullptr; |         return nullptr; | ||||||
|     } |     } | ||||||
|     size_t alignment = get_alignment(sock); |     size_t alignment = get_alignment(sock); | ||||||
|     size_t max_size = get_max_size(sock); |     size_t max_size = get_max_size(sock); | ||||||
|     ggml_backend_rpc_buffer_type_context * buft_ctx = new ggml_backend_rpc_buffer_type_context { |     ggml_backend_rpc_buffer_type_context * buft_ctx = new ggml_backend_rpc_buffer_type_context { | ||||||
|         /* .sock   = */ sock, |         /* .endpoint  = */ endpoint, | ||||||
|         /* .name   = */ "RPC" + std::to_string(sock->fd), |         /* .name      = */ "RPC[" + std::string(endpoint) + "]", | ||||||
|         /* .alignment = */ alignment, |         /* .alignment = */ alignment, | ||||||
|         /* .max_size = */ max_size |         /* .max_size  = */ max_size | ||||||
|     }; |     }; | ||||||
|  |  | ||||||
|     ggml_backend_buffer_type_t buft = new ggml_backend_buffer_type { |     ggml_backend_buffer_type_t buft = new ggml_backend_buffer_type { | ||||||
|         /* .iface   = */ ggml_backend_rpc_buffer_type_interface, |         /* .iface   = */ ggml_backend_rpc_buffer_type_interface, | ||||||
|         /* .context = */ buft_ctx |         /* .context = */ buft_ctx | ||||||
|     }; |     }; | ||||||
|  |     buft_map[endpoint] = buft; | ||||||
|  |     return buft; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | GGML_CALL ggml_backend_t ggml_backend_rpc_init(const char * endpoint) { | ||||||
|     ggml_backend_rpc_context * ctx = new ggml_backend_rpc_context { |     ggml_backend_rpc_context * ctx = new ggml_backend_rpc_context { | ||||||
|         /* .endpoint = */ endpoint, |         /* .endpoint  = */ endpoint, | ||||||
|         /* .name     = */ "RPC" + std::to_string(sock->fd), |         /* .name      = */ "RPC", | ||||||
|         /* .sock     = */ sock, |  | ||||||
|         /* .buft     = */ buft |  | ||||||
|     }; |     }; | ||||||
|  |  | ||||||
|     instances[endpoint] = new ggml_backend { |     ggml_backend_t backend = new ggml_backend { | ||||||
|         /* .guid      = */ ggml_backend_rpc_guid(), |         /* .guid      = */ ggml_backend_rpc_guid(), | ||||||
|         /* .interface = */ ggml_backend_rpc_interface, |         /* .interface = */ ggml_backend_rpc_interface, | ||||||
|         /* .context   = */ ctx |         /* .context   = */ ctx | ||||||
|     }; |     }; | ||||||
|  |     return backend; | ||||||
|     return instances[endpoint]; |  | ||||||
| } | } | ||||||
|  |  | ||||||
| GGML_API GGML_CALL bool ggml_backend_is_rpc(ggml_backend_t backend) { | GGML_API GGML_CALL bool ggml_backend_is_rpc(ggml_backend_t backend) { | ||||||
| @@ -706,14 +724,13 @@ static void get_device_memory(const std::shared_ptr<socket_t> & sock, size_t * f | |||||||
| } | } | ||||||
|  |  | ||||||
| GGML_API GGML_CALL void ggml_backend_rpc_get_device_memory(const char * endpoint, size_t * free, size_t * total) { | GGML_API GGML_CALL void ggml_backend_rpc_get_device_memory(const char * endpoint, size_t * free, size_t * total) { | ||||||
|     ggml_backend_t backend = ggml_backend_rpc_init(endpoint); |     auto sock = get_socket(endpoint); | ||||||
|     if (backend == nullptr) { |     if (sock == nullptr) { | ||||||
|         *free = 0; |         *free = 0; | ||||||
|         *total = 0; |         *total = 0; | ||||||
|         return; |         return; | ||||||
|     } |     } | ||||||
|     ggml_backend_rpc_context * ctx = (ggml_backend_rpc_context *)backend->context; |     get_device_memory(sock, free, total); | ||||||
|     get_device_memory(ctx->sock, free, total); |  | ||||||
| } | } | ||||||
|  |  | ||||||
| // RPC server-side implementation | // RPC server-side implementation | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user