rpc : add support for multiple devices (#16276)

* rpc : add support for multiple devices

Allow rpc-server to expose multiple devices from a single endpoint.
Change RPC protocol to include device identifier where needed.

closes: #15210

* fixes

* use ggml_backend_reg_t

* address review comments

* fix llama-bench backend report

* address review comments, change device naming

* fix cmd order
This commit is contained in:
Radoslav Gerganov
2025-10-04 12:49:16 +03:00
committed by GitHub
parent e29acf74fe
commit 898acba681
7 changed files with 403 additions and 245 deletions

View File

@@ -168,7 +168,7 @@ static std::vector<ggml_backend_dev_t> parse_devices_arg(const std::string & val
return devices;
}
static std::vector<ggml_backend_dev_t> register_rpc_device_list(const std::string & servers) {
static void register_rpc_server_list(const std::string & servers) {
auto rpc_servers = string_split<std::string>(servers, ',');
if (rpc_servers.empty()) {
throw std::invalid_argument("no RPC servers specified");
@@ -179,36 +179,15 @@ static std::vector<ggml_backend_dev_t> register_rpc_device_list(const std::strin
throw std::invalid_argument("failed to find RPC backend");
}
using add_rpc_device_fn = ggml_backend_dev_t (*)(const char * endpoint);
auto * ggml_backend_rpc_add_device_fn = (add_rpc_device_fn) ggml_backend_reg_get_proc_address(rpc_reg, "ggml_backend_rpc_add_device");
if (!ggml_backend_rpc_add_device_fn) {
throw std::invalid_argument("failed to find RPC device add function");
using add_rpc_server_fn = ggml_backend_reg_t (*)(const char * endpoint);
auto * ggml_backend_rpc_add_server_fn = (add_rpc_server_fn) ggml_backend_reg_get_proc_address(rpc_reg, "ggml_backend_rpc_add_server");
if (!ggml_backend_rpc_add_server_fn) {
throw std::invalid_argument("failed to find RPC add server function");
}
static std::unordered_set<std::string> registered;
std::vector<ggml_backend_dev_t> devices;
for (const auto & server : rpc_servers) {
ggml_backend_dev_t dev = nullptr;
std::string name = string_format("RPC[%s]", server.c_str());
if (registered.find(server) != registered.end()) {
dev = ggml_backend_dev_by_name(name.c_str());
}
if (!dev) {
dev = ggml_backend_rpc_add_device_fn(server.c_str());
if (!dev) {
throw std::invalid_argument(string_format("failed to add RPC device for server '%s'", server.c_str()));
}
ggml_backend_device_register(dev);
registered.insert(server);
}
devices.push_back(dev);
auto reg = ggml_backend_rpc_add_server_fn(server.c_str());
ggml_backend_register(reg);
}
return devices;
}
static std::string devices_to_string(const std::vector<ggml_backend_dev_t> & devices) {
@@ -714,7 +693,7 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
break;
}
try {
register_rpc_device_list(argv[i]);
register_rpc_server_list(argv[i]);
} catch (const std::exception & e) {
fprintf(stderr, "error: %s\n", e.what());
invalid_param = true;
@@ -1368,13 +1347,23 @@ struct test {
static std::string get_backend() {
std::vector<std::string> backends;
bool rpc_used = false;
for (size_t i = 0; i < ggml_backend_reg_count(); i++) {
auto * reg = ggml_backend_reg_get(i);
std::string name = ggml_backend_reg_name(reg);
if (name != "CPU") {
backends.push_back(ggml_backend_reg_name(reg));
if (string_starts_with(name, "RPC")) {
if (ggml_backend_reg_dev_count(reg) > 0) {
rpc_used = true;
}
} else {
if (name != "CPU") {
backends.push_back(ggml_backend_reg_name(reg));
}
}
}
if (rpc_used) {
backends.push_back("RPC");
}
return backends.empty() ? "CPU" : join(backends, ",");
}