mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-10-30 08:42:00 +00:00 
			
		
		
		
	rpc : report actual free memory (#16616)
* rpc : report actual free memory Start reporting the free memory on every device instead of using fixed values. Now llama-cli users can get a nice memory breakdown when using RPC devices. * drop --mem in rpc-server
This commit is contained in:
		 Radoslav Gerganov
					Radoslav Gerganov
				
			
				
					committed by
					
						 GitHub
						GitHub
					
				
			
			
				
	
			
			
			 GitHub
						GitHub
					
				
			
						parent
						
							3d4e86bbeb
						
					
				
				
					commit
					41386cf365
				
			| @@ -21,8 +21,7 @@ GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const c | |||||||
| GGML_BACKEND_API void ggml_backend_rpc_get_device_memory(const char * endpoint, uint32_t device, size_t * free, size_t * total); | GGML_BACKEND_API void ggml_backend_rpc_get_device_memory(const char * endpoint, uint32_t device, size_t * free, size_t * total); | ||||||
|  |  | ||||||
| GGML_BACKEND_API void ggml_backend_rpc_start_server(const char * endpoint, const char * cache_dir, | GGML_BACKEND_API void ggml_backend_rpc_start_server(const char * endpoint, const char * cache_dir, | ||||||
|                                                     size_t n_threads, size_t n_devices, |                                                     size_t n_threads, size_t n_devices, ggml_backend_dev_t * devices); | ||||||
|                                                     ggml_backend_dev_t * devices, size_t * free_mem, size_t * total_mem); |  | ||||||
|  |  | ||||||
| GGML_BACKEND_API ggml_backend_reg_t ggml_backend_rpc_reg(void); | GGML_BACKEND_API ggml_backend_reg_t ggml_backend_rpc_reg(void); | ||||||
| GGML_BACKEND_API ggml_backend_reg_t ggml_backend_rpc_add_server(const char * endpoint); | GGML_BACKEND_API ggml_backend_reg_t ggml_backend_rpc_add_server(const char * endpoint); | ||||||
|   | |||||||
| @@ -939,6 +939,7 @@ public: | |||||||
|     bool graph_compute(const std::vector<uint8_t> & input, rpc_msg_graph_compute_rsp & response); |     bool graph_compute(const std::vector<uint8_t> & input, rpc_msg_graph_compute_rsp & response); | ||||||
|     bool init_tensor(const rpc_msg_init_tensor_req & request); |     bool init_tensor(const rpc_msg_init_tensor_req & request); | ||||||
|     bool get_alloc_size(const rpc_msg_get_alloc_size_req & request, rpc_msg_get_alloc_size_rsp & response); |     bool get_alloc_size(const rpc_msg_get_alloc_size_req & request, rpc_msg_get_alloc_size_rsp & response); | ||||||
|  |     bool get_device_memory(const rpc_msg_get_device_memory_req & request, rpc_msg_get_device_memory_rsp & response); | ||||||
|  |  | ||||||
| private: | private: | ||||||
|     bool get_cached_file(uint64_t hash, std::vector<uint8_t> & data); |     bool get_cached_file(uint64_t hash, std::vector<uint8_t> & data); | ||||||
| @@ -1458,6 +1459,20 @@ bool rpc_server::graph_compute(const std::vector<uint8_t> & input, rpc_msg_graph | |||||||
|     return true; |     return true; | ||||||
| } | } | ||||||
|  |  | ||||||
|  | bool rpc_server::get_device_memory(const rpc_msg_get_device_memory_req & request, rpc_msg_get_device_memory_rsp & response) { | ||||||
|  |     uint32_t dev_id = request.device; | ||||||
|  |     if (dev_id >= backends.size()) { | ||||||
|  |         return false; | ||||||
|  |     } | ||||||
|  |     size_t free, total; | ||||||
|  |     ggml_backend_dev_t dev = ggml_backend_get_device(backends[dev_id]); | ||||||
|  |     ggml_backend_dev_memory(dev, &free, &total); | ||||||
|  |     response.free_mem = free; | ||||||
|  |     response.total_mem = total; | ||||||
|  |     LOG_DBG("[%s] device: %u, free_mem: %" PRIu64 ", total_mem: %" PRIu64 "\n", __func__, dev_id, response.free_mem, response.total_mem); | ||||||
|  |     return true; | ||||||
|  | } | ||||||
|  |  | ||||||
| rpc_server::~rpc_server() { | rpc_server::~rpc_server() { | ||||||
|     for (auto buffer : buffers) { |     for (auto buffer : buffers) { | ||||||
|         ggml_backend_buffer_free(buffer); |         ggml_backend_buffer_free(buffer); | ||||||
| @@ -1465,7 +1480,7 @@ rpc_server::~rpc_server() { | |||||||
| } | } | ||||||
|  |  | ||||||
| static void rpc_serve_client(const std::vector<ggml_backend_t> & backends, const char * cache_dir, | static void rpc_serve_client(const std::vector<ggml_backend_t> & backends, const char * cache_dir, | ||||||
|                              sockfd_t sockfd, const std::vector<size_t> & free_mem, const std::vector<size_t> & total_mem) { |                              sockfd_t sockfd) { | ||||||
|     rpc_server server(backends, cache_dir); |     rpc_server server(backends, cache_dir); | ||||||
|     uint8_t cmd; |     uint8_t cmd; | ||||||
|     if (!recv_data(sockfd, &cmd, 1)) { |     if (!recv_data(sockfd, &cmd, 1)) { | ||||||
| @@ -1689,15 +1704,10 @@ static void rpc_serve_client(const std::vector<ggml_backend_t> & backends, const | |||||||
|                 if (!recv_msg(sockfd, &request, sizeof(request))) { |                 if (!recv_msg(sockfd, &request, sizeof(request))) { | ||||||
|                     return; |                     return; | ||||||
|                 } |                 } | ||||||
|                 auto dev_id = request.device; |                 rpc_msg_get_device_memory_rsp response; | ||||||
|                 if (dev_id >= backends.size()) { |                 if (!server.get_device_memory(request, response)) { | ||||||
|                     return; |                     return; | ||||||
|                 } |                 } | ||||||
|                 rpc_msg_get_device_memory_rsp response; |  | ||||||
|                 response.free_mem = free_mem[dev_id]; |  | ||||||
|                 response.total_mem = total_mem[dev_id]; |  | ||||||
|                 LOG_DBG("[get_device_mem] device: %u, free_mem: %" PRIu64 ", total_mem: %" PRIu64 "\n", dev_id, |  | ||||||
|                     response.free_mem, response.total_mem); |  | ||||||
|                 if (!send_msg(sockfd, &response, sizeof(response))) { |                 if (!send_msg(sockfd, &response, sizeof(response))) { | ||||||
|                     return; |                     return; | ||||||
|                 } |                 } | ||||||
| @@ -1712,15 +1722,12 @@ static void rpc_serve_client(const std::vector<ggml_backend_t> & backends, const | |||||||
| } | } | ||||||
|  |  | ||||||
| void ggml_backend_rpc_start_server(const char * endpoint, const char * cache_dir, | void ggml_backend_rpc_start_server(const char * endpoint, const char * cache_dir, | ||||||
|                                    size_t n_threads, size_t n_devices, |                                    size_t n_threads, size_t n_devices, ggml_backend_dev_t * devices) { | ||||||
|                                    ggml_backend_dev_t * devices, size_t * free_mem, size_t * total_mem) { |     if (n_devices == 0 || devices == nullptr) { | ||||||
|     if (n_devices == 0 || devices == nullptr || free_mem == nullptr || total_mem == nullptr) { |  | ||||||
|         fprintf(stderr, "Invalid arguments to ggml_backend_rpc_start_server\n"); |         fprintf(stderr, "Invalid arguments to ggml_backend_rpc_start_server\n"); | ||||||
|         return; |         return; | ||||||
|     } |     } | ||||||
|     std::vector<ggml_backend_t> backends; |     std::vector<ggml_backend_t> backends; | ||||||
|     std::vector<size_t> free_mem_vec(free_mem, free_mem + n_devices); |  | ||||||
|     std::vector<size_t> total_mem_vec(total_mem, total_mem + n_devices); |  | ||||||
|     printf("Starting RPC server v%d.%d.%d\n", |     printf("Starting RPC server v%d.%d.%d\n", | ||||||
|         RPC_PROTO_MAJOR_VERSION, |         RPC_PROTO_MAJOR_VERSION, | ||||||
|         RPC_PROTO_MINOR_VERSION, |         RPC_PROTO_MINOR_VERSION, | ||||||
| @@ -1730,8 +1737,10 @@ void ggml_backend_rpc_start_server(const char * endpoint, const char * cache_dir | |||||||
|     printf("Devices:\n"); |     printf("Devices:\n"); | ||||||
|     for (size_t i = 0; i < n_devices; i++) { |     for (size_t i = 0; i < n_devices; i++) { | ||||||
|         auto dev = devices[i]; |         auto dev = devices[i]; | ||||||
|  |         size_t free, total; | ||||||
|  |         ggml_backend_dev_memory(dev, &free, &total); | ||||||
|         printf("  %s: %s (%zu MiB, %zu MiB free)\n", ggml_backend_dev_name(dev), ggml_backend_dev_description(dev), |         printf("  %s: %s (%zu MiB, %zu MiB free)\n", ggml_backend_dev_name(dev), ggml_backend_dev_description(dev), | ||||||
|                total_mem[i] / 1024 / 1024, free_mem[i] / 1024 / 1024); |                total / 1024 / 1024, free / 1024 / 1024); | ||||||
|         auto backend = ggml_backend_dev_init(dev, nullptr); |         auto backend = ggml_backend_dev_init(dev, nullptr); | ||||||
|         if (!backend) { |         if (!backend) { | ||||||
|             fprintf(stderr, "Failed to create backend for device %s\n", dev->iface.get_name(dev)); |             fprintf(stderr, "Failed to create backend for device %s\n", dev->iface.get_name(dev)); | ||||||
| @@ -1775,7 +1784,7 @@ void ggml_backend_rpc_start_server(const char * endpoint, const char * cache_dir | |||||||
|         } |         } | ||||||
|         printf("Accepted client connection\n"); |         printf("Accepted client connection\n"); | ||||||
|         fflush(stdout); |         fflush(stdout); | ||||||
|         rpc_serve_client(backends, cache_dir, client_socket->fd, free_mem_vec, total_mem_vec); |         rpc_serve_client(backends, cache_dir, client_socket->fd); | ||||||
|         printf("Client connection closed\n"); |         printf("Client connection closed\n"); | ||||||
|         fflush(stdout); |         fflush(stdout); | ||||||
|     } |     } | ||||||
|   | |||||||
| @@ -137,7 +137,6 @@ struct rpc_server_params { | |||||||
|     bool                     use_cache   = false; |     bool                     use_cache   = false; | ||||||
|     int                      n_threads   = std::max(1U, std::thread::hardware_concurrency()/2); |     int                      n_threads   = std::max(1U, std::thread::hardware_concurrency()/2); | ||||||
|     std::vector<std::string> devices; |     std::vector<std::string> devices; | ||||||
|     std::vector<size_t>      dev_mem; |  | ||||||
| }; | }; | ||||||
|  |  | ||||||
| static void print_usage(int /*argc*/, char ** argv, rpc_server_params params) { | static void print_usage(int /*argc*/, char ** argv, rpc_server_params params) { | ||||||
| @@ -148,7 +147,6 @@ static void print_usage(int /*argc*/, char ** argv, rpc_server_params params) { | |||||||
|     fprintf(stderr, "  -d, --device <dev1,dev2,...>     comma-separated list of devices\n"); |     fprintf(stderr, "  -d, --device <dev1,dev2,...>     comma-separated list of devices\n"); | ||||||
|     fprintf(stderr, "  -H, --host HOST                  host to bind to (default: %s)\n", params.host.c_str()); |     fprintf(stderr, "  -H, --host HOST                  host to bind to (default: %s)\n", params.host.c_str()); | ||||||
|     fprintf(stderr, "  -p, --port PORT                  port to bind to (default: %d)\n", params.port); |     fprintf(stderr, "  -p, --port PORT                  port to bind to (default: %d)\n", params.port); | ||||||
|     fprintf(stderr, "  -m, --mem <M1,M2,...>            memory size for each device (in MB)\n"); |  | ||||||
|     fprintf(stderr, "  -c, --cache                      enable local file cache\n"); |     fprintf(stderr, "  -c, --cache                      enable local file cache\n"); | ||||||
|     fprintf(stderr, "\n"); |     fprintf(stderr, "\n"); | ||||||
| } | } | ||||||
| @@ -197,23 +195,6 @@ static bool rpc_server_params_parse(int argc, char ** argv, rpc_server_params & | |||||||
|             } |             } | ||||||
|         } else if (arg == "-c" || arg == "--cache") { |         } else if (arg == "-c" || arg == "--cache") { | ||||||
|             params.use_cache = true; |             params.use_cache = true; | ||||||
|         } else if (arg == "-m" || arg == "--mem") { |  | ||||||
|             if (++i >= argc) { |  | ||||||
|                 return false; |  | ||||||
|             } |  | ||||||
|             const std::regex regex{ R"([,/]+)" }; |  | ||||||
|             std::string mem_str = argv[i]; |  | ||||||
|             std::sregex_token_iterator iter(mem_str.begin(), mem_str.end(), regex, -1); |  | ||||||
|             std::sregex_token_iterator end; |  | ||||||
|             for ( ; iter != end; ++iter) { |  | ||||||
|                 try { |  | ||||||
|                     size_t mem = std::stoul(*iter) * 1024 * 1024; |  | ||||||
|                     params.dev_mem.push_back(mem); |  | ||||||
|                 } catch (const std::exception & ) { |  | ||||||
|                     fprintf(stderr, "error: invalid memory size: %s\n", iter->str().c_str()); |  | ||||||
|                     return false; |  | ||||||
|                 } |  | ||||||
|             } |  | ||||||
|         } else if (arg == "-h" || arg == "--help") { |         } else if (arg == "-h" || arg == "--help") { | ||||||
|             print_usage(argc, argv, params); |             print_usage(argc, argv, params); | ||||||
|             exit(0); |             exit(0); | ||||||
| @@ -293,18 +274,6 @@ int main(int argc, char * argv[]) { | |||||||
|         return 1; |         return 1; | ||||||
|     } |     } | ||||||
|     std::string endpoint = params.host + ":" + std::to_string(params.port); |     std::string endpoint = params.host + ":" + std::to_string(params.port); | ||||||
|     std::vector<size_t> free_mem, total_mem; |  | ||||||
|     for (size_t i = 0; i < devices.size(); i++) { |  | ||||||
|         if (i < params.dev_mem.size()) { |  | ||||||
|             free_mem.push_back(params.dev_mem[i]); |  | ||||||
|             total_mem.push_back(params.dev_mem[i]); |  | ||||||
|         } else { |  | ||||||
|             size_t free, total; |  | ||||||
|             ggml_backend_dev_memory(devices[i], &free, &total); |  | ||||||
|             free_mem.push_back(free); |  | ||||||
|             total_mem.push_back(total); |  | ||||||
|         } |  | ||||||
|     } |  | ||||||
|     const char * cache_dir = nullptr; |     const char * cache_dir = nullptr; | ||||||
|     std::string cache_dir_str; |     std::string cache_dir_str; | ||||||
|     if (params.use_cache) { |     if (params.use_cache) { | ||||||
| @@ -328,7 +297,6 @@ int main(int argc, char * argv[]) { | |||||||
|         return 1; |         return 1; | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     start_server_fn(endpoint.c_str(), cache_dir, params.n_threads, devices.size(), |     start_server_fn(endpoint.c_str(), cache_dir, params.n_threads, devices.size(), devices.data()); | ||||||
|         devices.data(), free_mem.data(), total_mem.data()); |  | ||||||
|     return 0; |     return 0; | ||||||
| } | } | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user