mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-10-31 08:51:55 +00:00 
			
		
		
		
	rpc : resource management rework (#7562)
* rpc : resource management rework * address review comments
This commit is contained in:
		 Radoslav Gerganov
					Radoslav Gerganov
				
			
				
					committed by
					
						 GitHub
						GitHub
					
				
			
			
				
	
			
			
			 GitHub
						GitHub
					
				
			
						parent
						
							ee3dff6b8e
						
					
				
				
					commit
					2b737caae1
				
			
							
								
								
									
										127
									
								
								ggml-rpc.cpp
									
									
									
									
									
								
							
							
						
						
									
										127
									
								
								ggml-rpc.cpp
									
									
									
									
									
								
							| @@ -6,6 +6,7 @@ | ||||
| #include <string> | ||||
| #include <vector> | ||||
| #include <memory> | ||||
| #include <mutex> | ||||
| #include <unordered_map> | ||||
| #include <unordered_set> | ||||
| #ifdef _WIN32 | ||||
| @@ -47,6 +48,7 @@ struct socket_t { | ||||
|     sockfd_t fd; | ||||
|     socket_t(sockfd_t fd) : fd(fd) {} | ||||
|     ~socket_t() { | ||||
|         GGML_PRINT_DEBUG("[%s] closing socket %d\n", __func__, this->fd); | ||||
| #ifdef _WIN32 | ||||
|         closesocket(this->fd); | ||||
| #else | ||||
| @@ -97,7 +99,7 @@ static ggml_guid_t ggml_backend_rpc_guid() { | ||||
| } | ||||
|  | ||||
| struct ggml_backend_rpc_buffer_type_context { | ||||
|     std::shared_ptr<socket_t> sock; | ||||
|     std::string endpoint; | ||||
|     std::string name; | ||||
|     size_t alignment; | ||||
|     size_t max_size; | ||||
| @@ -106,8 +108,6 @@ struct ggml_backend_rpc_buffer_type_context { | ||||
| struct ggml_backend_rpc_context { | ||||
|     std::string endpoint; | ||||
|     std::string name; | ||||
|     std::shared_ptr<socket_t> sock; | ||||
|     ggml_backend_buffer_type_t buft; | ||||
| }; | ||||
|  | ||||
| struct ggml_backend_rpc_buffer_context { | ||||
| @@ -231,14 +231,13 @@ static bool recv_data(sockfd_t sockfd, void * data, size_t size) { | ||||
|     return true; | ||||
| } | ||||
|  | ||||
| static bool parse_endpoint(const char * endpoint, std::string & host, int & port) { | ||||
|     std::string str(endpoint); | ||||
|     size_t pos = str.find(':'); | ||||
| static bool parse_endpoint(const std::string & endpoint, std::string & host, int & port) { | ||||
|     size_t pos = endpoint.find(':'); | ||||
|     if (pos == std::string::npos) { | ||||
|         return false; | ||||
|     } | ||||
|     host = str.substr(0, pos); | ||||
|     port = std::stoi(str.substr(pos + 1)); | ||||
|     host = endpoint.substr(0, pos); | ||||
|     port = std::stoi(endpoint.substr(pos + 1)); | ||||
|     return true; | ||||
| } | ||||
|  | ||||
| @@ -273,6 +272,44 @@ static bool send_rpc_cmd(const std::shared_ptr<socket_t> & sock, enum rpc_cmd cm | ||||
|  | ||||
| // RPC client-side implementation | ||||
|  | ||||
| static std::shared_ptr<socket_t> get_socket(const std::string & endpoint) { | ||||
|     static std::mutex mutex; | ||||
|     std::lock_guard<std::mutex> lock(mutex); | ||||
|     static std::unordered_map<std::string, std::weak_ptr<socket_t>> sockets; | ||||
|     static bool initialized = false; | ||||
|  | ||||
|     auto it = sockets.find(endpoint); | ||||
|     if (it != sockets.end()) { | ||||
|         if (auto sock = it->second.lock()) { | ||||
|             return sock; | ||||
|         } | ||||
|     } | ||||
|     std::string host; | ||||
|     int port; | ||||
|     if (!parse_endpoint(endpoint, host, port)) { | ||||
|         return nullptr; | ||||
|     } | ||||
| #ifdef _WIN32 | ||||
|     if (!initialized) { | ||||
|         WSADATA wsaData; | ||||
|         int res = WSAStartup(MAKEWORD(2, 2), &wsaData); | ||||
|         if (res != 0) { | ||||
|             return nullptr; | ||||
|         } | ||||
|         initialized = true; | ||||
|     } | ||||
| #else | ||||
|     UNUSED(initialized); | ||||
| #endif | ||||
|     auto sock = socket_connect(host.c_str(), port); | ||||
|     if (sock == nullptr) { | ||||
|         return nullptr; | ||||
|     } | ||||
|     GGML_PRINT_DEBUG("[%s] connected to %s, sockfd=%d\n", __func__, endpoint.c_str(), sock->fd); | ||||
|     sockets[endpoint] = sock; | ||||
|     return sock; | ||||
| } | ||||
|  | ||||
| GGML_CALL static const char * ggml_backend_rpc_buffer_get_name(ggml_backend_buffer_t buffer) { | ||||
|     ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context; | ||||
|     return ctx->name.c_str(); | ||||
| @@ -442,7 +479,8 @@ GGML_CALL static ggml_backend_buffer_t ggml_backend_rpc_buffer_type_alloc_buffer | ||||
|     std::vector<uint8_t> input(input_size, 0); | ||||
|     memcpy(input.data(), &size, sizeof(size)); | ||||
|     std::vector<uint8_t> output; | ||||
|     bool status = send_rpc_cmd(buft_ctx->sock, ALLOC_BUFFER, input, output); | ||||
|     auto sock = get_socket(buft_ctx->endpoint); | ||||
|     bool status = send_rpc_cmd(sock, ALLOC_BUFFER, input, output); | ||||
|     GGML_ASSERT(status); | ||||
|     GGML_ASSERT(output.size() == 2*sizeof(uint64_t)); | ||||
|     // output serialization format: | remote_ptr (8 bytes) | remote_size (8 bytes) | | ||||
| @@ -453,7 +491,7 @@ GGML_CALL static ggml_backend_buffer_t ggml_backend_rpc_buffer_type_alloc_buffer | ||||
|     if (remote_ptr != 0) { | ||||
|         ggml_backend_buffer_t buffer = ggml_backend_buffer_init(buft, | ||||
|             ggml_backend_rpc_buffer_interface, | ||||
|             new ggml_backend_rpc_buffer_context{buft_ctx->sock, {}, remote_ptr, "RPC"}, | ||||
|             new ggml_backend_rpc_buffer_context{sock, {}, remote_ptr, "RPC"}, | ||||
|             remote_size); | ||||
|         return buffer; | ||||
|     } else { | ||||
| @@ -508,7 +546,7 @@ GGML_CALL static bool ggml_backend_rpc_buffer_type_supports_backend(ggml_backend | ||||
|     } | ||||
|     ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context; | ||||
|     ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context; | ||||
|     return buft_ctx->sock == rpc_ctx->sock; | ||||
|     return buft_ctx->endpoint == rpc_ctx->endpoint; | ||||
| } | ||||
|  | ||||
| static ggml_backend_buffer_type_i ggml_backend_rpc_buffer_type_interface = { | ||||
| @@ -521,7 +559,6 @@ static ggml_backend_buffer_type_i ggml_backend_rpc_buffer_type_interface = { | ||||
|     /* .is_host          = */ NULL, | ||||
| }; | ||||
|  | ||||
|  | ||||
| GGML_CALL static const char * ggml_backend_rpc_name(ggml_backend_t backend) { | ||||
|     ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context; | ||||
|  | ||||
| @@ -530,16 +567,13 @@ GGML_CALL static const char * ggml_backend_rpc_name(ggml_backend_t backend) { | ||||
|  | ||||
| GGML_CALL static void ggml_backend_rpc_free(ggml_backend_t backend) { | ||||
|     ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context; | ||||
|     ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)rpc_ctx->buft->context; | ||||
|     delete buft_ctx; | ||||
|     delete rpc_ctx->buft; | ||||
|     delete rpc_ctx; | ||||
|     delete backend; | ||||
| } | ||||
|  | ||||
| GGML_CALL static ggml_backend_buffer_type_t ggml_backend_rpc_get_default_buffer_type(ggml_backend_t backend) { | ||||
|     ggml_backend_rpc_context * ctx = (ggml_backend_rpc_context *)backend->context; | ||||
|     return ctx->buft; | ||||
|     return ggml_backend_rpc_buffer_type(ctx->endpoint.c_str()); | ||||
| } | ||||
|  | ||||
| GGML_CALL static void ggml_backend_rpc_synchronize(ggml_backend_t backend) { | ||||
| @@ -590,7 +624,8 @@ GGML_CALL static enum ggml_status ggml_backend_rpc_graph_compute(ggml_backend_t | ||||
|     std::vector<uint8_t> input; | ||||
|     serialize_graph(cgraph, input); | ||||
|     std::vector<uint8_t> output; | ||||
|     bool status = send_rpc_cmd(rpc_ctx->sock, GRAPH_COMPUTE, input, output); | ||||
|     auto sock = get_socket(rpc_ctx->endpoint); | ||||
|     bool status = send_rpc_cmd(sock, GRAPH_COMPUTE, input, output); | ||||
|     GGML_ASSERT(status); | ||||
|     GGML_ASSERT(output.size() == 1); | ||||
|     return (enum ggml_status)output[0]; | ||||
| @@ -624,42 +659,24 @@ static ggml_backend_i ggml_backend_rpc_interface = { | ||||
|     /* .event_synchronize       = */ NULL, | ||||
| }; | ||||
|  | ||||
| static std::unordered_map<std::string, ggml_backend_t> instances; | ||||
|  | ||||
| GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const char * endpoint) { | ||||
|     ggml_backend_t backend = ggml_backend_rpc_init(endpoint); | ||||
|     return backend != nullptr ? ggml_backend_rpc_get_default_buffer_type(backend) : nullptr; | ||||
|     static std::mutex mutex; | ||||
|     std::lock_guard<std::mutex> lock(mutex); | ||||
|     // NOTE: buffer types are allocated and never freed; this is by design | ||||
|     static std::unordered_map<std::string, ggml_backend_buffer_type_t> buft_map; | ||||
|     auto it = buft_map.find(endpoint); | ||||
|     if (it != buft_map.end()) { | ||||
|         return it->second; | ||||
|     } | ||||
|  | ||||
| GGML_CALL ggml_backend_t ggml_backend_rpc_init(const char * endpoint) { | ||||
|     std::string endpoint_str(endpoint); | ||||
|     if (instances.find(endpoint_str) != instances.end()) { | ||||
|         return instances[endpoint_str]; | ||||
|     } | ||||
| #ifdef _WIN32 | ||||
|     { | ||||
|         WSADATA wsaData; | ||||
|         int res = WSAStartup(MAKEWORD(2, 2), &wsaData); | ||||
|         if (res != 0) { | ||||
|             return nullptr; | ||||
|         } | ||||
|     } | ||||
| #endif | ||||
|     fprintf(stderr, "Connecting to %s\n", endpoint); | ||||
|     std::string host; | ||||
|     int port; | ||||
|     if (!parse_endpoint(endpoint, host, port)) { | ||||
|         return nullptr; | ||||
|     } | ||||
|     auto sock = socket_connect(host.c_str(), port); | ||||
|     auto sock = get_socket(endpoint); | ||||
|     if (sock == nullptr) { | ||||
|         return nullptr; | ||||
|     } | ||||
|     size_t alignment = get_alignment(sock); | ||||
|     size_t max_size = get_max_size(sock); | ||||
|     ggml_backend_rpc_buffer_type_context * buft_ctx = new ggml_backend_rpc_buffer_type_context { | ||||
|         /* .sock   = */ sock, | ||||
|         /* .name   = */ "RPC" + std::to_string(sock->fd), | ||||
|         /* .endpoint  = */ endpoint, | ||||
|         /* .name      = */ "RPC[" + std::string(endpoint) + "]", | ||||
|         /* .alignment = */ alignment, | ||||
|         /* .max_size  = */ max_size | ||||
|     }; | ||||
| @@ -668,21 +685,22 @@ GGML_CALL ggml_backend_t ggml_backend_rpc_init(const char * endpoint) { | ||||
|         /* .iface   = */ ggml_backend_rpc_buffer_type_interface, | ||||
|         /* .context = */ buft_ctx | ||||
|     }; | ||||
|     buft_map[endpoint] = buft; | ||||
|     return buft; | ||||
| } | ||||
|  | ||||
| GGML_CALL ggml_backend_t ggml_backend_rpc_init(const char * endpoint) { | ||||
|     ggml_backend_rpc_context * ctx = new ggml_backend_rpc_context { | ||||
|         /* .endpoint  = */ endpoint, | ||||
|         /* .name     = */ "RPC" + std::to_string(sock->fd), | ||||
|         /* .sock     = */ sock, | ||||
|         /* .buft     = */ buft | ||||
|         /* .name      = */ "RPC", | ||||
|     }; | ||||
|  | ||||
|     instances[endpoint] = new ggml_backend { | ||||
|     ggml_backend_t backend = new ggml_backend { | ||||
|         /* .guid      = */ ggml_backend_rpc_guid(), | ||||
|         /* .interface = */ ggml_backend_rpc_interface, | ||||
|         /* .context   = */ ctx | ||||
|     }; | ||||
|  | ||||
|     return instances[endpoint]; | ||||
|     return backend; | ||||
| } | ||||
|  | ||||
| GGML_API GGML_CALL bool ggml_backend_is_rpc(ggml_backend_t backend) { | ||||
| @@ -706,14 +724,13 @@ static void get_device_memory(const std::shared_ptr<socket_t> & sock, size_t * f | ||||
| } | ||||
|  | ||||
| GGML_API GGML_CALL void ggml_backend_rpc_get_device_memory(const char * endpoint, size_t * free, size_t * total) { | ||||
|     ggml_backend_t backend = ggml_backend_rpc_init(endpoint); | ||||
|     if (backend == nullptr) { | ||||
|     auto sock = get_socket(endpoint); | ||||
|     if (sock == nullptr) { | ||||
|         *free = 0; | ||||
|         *total = 0; | ||||
|         return; | ||||
|     } | ||||
|     ggml_backend_rpc_context * ctx = (ggml_backend_rpc_context *)backend->context; | ||||
|     get_device_memory(ctx->sock, free, total); | ||||
|     get_device_memory(sock, free, total); | ||||
| } | ||||
|  | ||||
| // RPC server-side implementation | ||||
|   | ||||
		Reference in New Issue
	
	Block a user