rpc : add support for multiple devices (#16276)

* rpc : add support for multiple devices

Allow rpc-server to expose multiple devices from a single endpoint.
Change RPC protocol to include device identifier where needed.

closes: #15210

* fixes

* use ggml_backend_reg_t

* address review comments

* fix llama-bench backend report

* address review comments, change device naming

* fix cmd order
This commit is contained in:
Radoslav Gerganov
2025-10-04 12:49:16 +03:00
committed by GitHub
parent e29acf74fe
commit 898acba681
7 changed files with 403 additions and 245 deletions

View File

@@ -168,7 +168,7 @@ static std::vector<ggml_backend_dev_t> parse_devices_arg(const std::string & val
return devices;
}
static std::vector<ggml_backend_dev_t> register_rpc_device_list(const std::string & servers) {
static void register_rpc_server_list(const std::string & servers) {
auto rpc_servers = string_split<std::string>(servers, ',');
if (rpc_servers.empty()) {
throw std::invalid_argument("no RPC servers specified");
@@ -179,36 +179,15 @@ static std::vector<ggml_backend_dev_t> register_rpc_device_list(const std::strin
throw std::invalid_argument("failed to find RPC backend");
}
using add_rpc_device_fn = ggml_backend_dev_t (*)(const char * endpoint);
auto * ggml_backend_rpc_add_device_fn = (add_rpc_device_fn) ggml_backend_reg_get_proc_address(rpc_reg, "ggml_backend_rpc_add_device");
if (!ggml_backend_rpc_add_device_fn) {
throw std::invalid_argument("failed to find RPC device add function");
using add_rpc_server_fn = ggml_backend_reg_t (*)(const char * endpoint);
auto * ggml_backend_rpc_add_server_fn = (add_rpc_server_fn) ggml_backend_reg_get_proc_address(rpc_reg, "ggml_backend_rpc_add_server");
if (!ggml_backend_rpc_add_server_fn) {
throw std::invalid_argument("failed to find RPC add server function");
}
static std::unordered_set<std::string> registered;
std::vector<ggml_backend_dev_t> devices;
for (const auto & server : rpc_servers) {
ggml_backend_dev_t dev = nullptr;
std::string name = string_format("RPC[%s]", server.c_str());
if (registered.find(server) != registered.end()) {
dev = ggml_backend_dev_by_name(name.c_str());
}
if (!dev) {
dev = ggml_backend_rpc_add_device_fn(server.c_str());
if (!dev) {
throw std::invalid_argument(string_format("failed to add RPC device for server '%s'", server.c_str()));
}
ggml_backend_device_register(dev);
registered.insert(server);
}
devices.push_back(dev);
auto reg = ggml_backend_rpc_add_server_fn(server.c_str());
ggml_backend_register(reg);
}
return devices;
}
static std::string devices_to_string(const std::vector<ggml_backend_dev_t> & devices) {
@@ -714,7 +693,7 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
break;
}
try {
register_rpc_device_list(argv[i]);
register_rpc_server_list(argv[i]);
} catch (const std::exception & e) {
fprintf(stderr, "error: %s\n", e.what());
invalid_param = true;
@@ -1368,13 +1347,23 @@ struct test {
static std::string get_backend() {
std::vector<std::string> backends;
bool rpc_used = false;
for (size_t i = 0; i < ggml_backend_reg_count(); i++) {
auto * reg = ggml_backend_reg_get(i);
std::string name = ggml_backend_reg_name(reg);
if (name != "CPU") {
backends.push_back(ggml_backend_reg_name(reg));
if (string_starts_with(name, "RPC")) {
if (ggml_backend_reg_dev_count(reg) > 0) {
rpc_used = true;
}
} else {
if (name != "CPU") {
backends.push_back(ggml_backend_reg_name(reg));
}
}
}
if (rpc_used) {
backends.push_back("RPC");
}
return backends.empty() ? "CPU" : join(backends, ",");
}

View File

@@ -22,6 +22,7 @@
#include <filesystem>
#include <algorithm>
#include <thread>
#include <regex>
namespace fs = std::filesystem;
@@ -131,24 +132,24 @@ static std::string fs_get_cache_directory() {
}
struct rpc_server_params {
std::string host = "127.0.0.1";
int port = 50052;
size_t backend_mem = 0;
bool use_cache = false;
int n_threads = std::max(1U, std::thread::hardware_concurrency()/2);
std::string device;
std::string host = "127.0.0.1";
int port = 50052;
bool use_cache = false;
int n_threads = std::max(1U, std::thread::hardware_concurrency()/2);
std::vector<std::string> devices;
std::vector<size_t> dev_mem;
};
static void print_usage(int /*argc*/, char ** argv, rpc_server_params params) {
fprintf(stderr, "Usage: %s [options]\n\n", argv[0]);
fprintf(stderr, "options:\n");
fprintf(stderr, " -h, --help show this help message and exit\n");
fprintf(stderr, " -t, --threads number of threads for the CPU backend (default: %d)\n", params.n_threads);
fprintf(stderr, " -d DEV, --device device to use\n");
fprintf(stderr, " -H HOST, --host HOST host to bind to (default: %s)\n", params.host.c_str());
fprintf(stderr, " -p PORT, --port PORT port to bind to (default: %d)\n", params.port);
fprintf(stderr, " -m MEM, --mem MEM backend memory size (in MB)\n");
fprintf(stderr, " -c, --cache enable local file cache\n");
fprintf(stderr, " -h, --help show this help message and exit\n");
fprintf(stderr, " -t, --threads N number of threads for the CPU device (default: %d)\n", params.n_threads);
fprintf(stderr, " -d, --device <dev1,dev2,...> comma-separated list of devices\n");
fprintf(stderr, " -H, --host HOST host to bind to (default: %s)\n", params.host.c_str());
fprintf(stderr, " -p, --port PORT port to bind to (default: %d)\n", params.port);
fprintf(stderr, " -m, --mem <M1,M2,...> memory size for each device (in MB)\n");
fprintf(stderr, " -c, --cache enable local file cache\n");
fprintf(stderr, "\n");
}
@@ -174,17 +175,17 @@ static bool rpc_server_params_parse(int argc, char ** argv, rpc_server_params &
if (++i >= argc) {
return false;
}
params.device = argv[i];
if (ggml_backend_dev_by_name(params.device.c_str()) == nullptr) {
fprintf(stderr, "error: unknown device: %s\n", params.device.c_str());
fprintf(stderr, "available devices:\n");
for (size_t i = 0; i < ggml_backend_dev_count(); i++) {
auto * dev = ggml_backend_dev_get(i);
size_t free, total;
ggml_backend_dev_memory(dev, &free, &total);
printf(" %s: %s (%zu MiB, %zu MiB free)\n", ggml_backend_dev_name(dev), ggml_backend_dev_description(dev), total / 1024 / 1024, free / 1024 / 1024);
const std::regex regex{ R"([,/]+)" };
std::string dev_str = argv[i];
std::sregex_token_iterator iter(dev_str.begin(), dev_str.end(), regex, -1);
std::sregex_token_iterator end;
for ( ; iter != end; ++iter) {
try {
params.devices.push_back(*iter);
} catch (const std::exception & ) {
fprintf(stderr, "error: invalid device: %s\n", iter->str().c_str());
return false;
}
return false;
}
} else if (arg == "-p" || arg == "--port") {
if (++i >= argc) {
@@ -200,7 +201,19 @@ static bool rpc_server_params_parse(int argc, char ** argv, rpc_server_params &
if (++i >= argc) {
return false;
}
params.backend_mem = std::stoul(argv[i]) * 1024 * 1024;
const std::regex regex{ R"([,/]+)" };
std::string mem_str = argv[i];
std::sregex_token_iterator iter(mem_str.begin(), mem_str.end(), regex, -1);
std::sregex_token_iterator end;
for ( ; iter != end; ++iter) {
try {
size_t mem = std::stoul(*iter) * 1024 * 1024;
params.dev_mem.push_back(mem);
} catch (const std::exception & ) {
fprintf(stderr, "error: invalid memory size: %s\n", iter->str().c_str());
return false;
}
}
} else if (arg == "-h" || arg == "--help") {
print_usage(argc, argv, params);
exit(0);
@@ -213,45 +226,46 @@ static bool rpc_server_params_parse(int argc, char ** argv, rpc_server_params &
return true;
}
static ggml_backend_t create_backend(const rpc_server_params & params) {
ggml_backend_t backend = nullptr;
static std::vector<ggml_backend_dev_t> get_devices(const rpc_server_params & params) {
std::vector<ggml_backend_dev_t> devices;
if (!params.devices.empty()) {
for (auto device : params.devices) {
ggml_backend_dev_t dev = ggml_backend_dev_by_name(device.c_str());
if (dev) {
devices.push_back(dev);
} else {
fprintf(stderr, "error: unknown device: %s\n", device.c_str());
fprintf(stderr, "available devices:\n");
for (size_t i = 0; i < ggml_backend_dev_count(); i++) {
auto * dev = ggml_backend_dev_get(i);
size_t free, total;
ggml_backend_dev_memory(dev, &free, &total);
printf(" %s: %s (%zu MiB, %zu MiB free)\n", ggml_backend_dev_name(dev), ggml_backend_dev_description(dev), total / 1024 / 1024, free / 1024 / 1024);
}
return {};
}
}
}
if (!params.device.empty()) {
ggml_backend_dev_t dev = ggml_backend_dev_by_name(params.device.c_str());
// Try non-CPU devices first
if (devices.empty()) {
for (size_t i = 0; i < ggml_backend_dev_count(); i++) {
ggml_backend_dev_t dev = ggml_backend_dev_get(i);
if (ggml_backend_dev_type(dev) != GGML_BACKEND_DEVICE_TYPE_CPU) {
devices.push_back(dev);
}
}
}
// If there are no accelerators, fallback to CPU device
if (devices.empty()) {
ggml_backend_dev_t dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
if (dev) {
backend = ggml_backend_dev_init(dev, nullptr);
if (!backend) {
fprintf(stderr, "Failed to create backend for device %s\n", params.device.c_str());
return nullptr;
}
devices.push_back(dev);
}
}
if (!backend) {
backend = ggml_backend_init_best();
}
if (backend) {
fprintf(stderr, "%s: using %s backend\n", __func__, ggml_backend_name(backend));
// set the number of threads
ggml_backend_dev_t dev = ggml_backend_get_device(backend);
ggml_backend_reg_t reg = dev ? ggml_backend_dev_backend_reg(dev) : nullptr;
if (reg) {
auto ggml_backend_set_n_threads_fn = (ggml_backend_set_n_threads_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_n_threads");
if (ggml_backend_set_n_threads_fn) {
ggml_backend_set_n_threads_fn(backend, params.n_threads);
}
}
}
return backend;
}
static void get_backend_memory(ggml_backend_t backend, size_t * free_mem, size_t * total_mem) {
ggml_backend_dev_t dev = ggml_backend_get_device(backend);
GGML_ASSERT(dev != nullptr);
ggml_backend_dev_memory(dev, free_mem, total_mem);
return devices;
}
int main(int argc, char * argv[]) {
@@ -273,18 +287,23 @@ int main(int argc, char * argv[]) {
fprintf(stderr, "\n");
}
ggml_backend_t backend = create_backend(params);
if (!backend) {
fprintf(stderr, "Failed to create backend\n");
auto devices = get_devices(params);
if (devices.empty()) {
fprintf(stderr, "No devices found\n");
return 1;
}
std::string endpoint = params.host + ":" + std::to_string(params.port);
size_t free_mem, total_mem;
if (params.backend_mem > 0) {
free_mem = params.backend_mem;
total_mem = params.backend_mem;
} else {
get_backend_memory(backend, &free_mem, &total_mem);
std::vector<size_t> free_mem, total_mem;
for (size_t i = 0; i < devices.size(); i++) {
if (i < params.dev_mem.size()) {
free_mem.push_back(params.dev_mem[i]);
total_mem.push_back(params.dev_mem[i]);
} else {
size_t free, total;
ggml_backend_dev_memory(devices[i], &free, &total);
free_mem.push_back(free);
total_mem.push_back(total);
}
}
const char * cache_dir = nullptr;
std::string cache_dir_str;
@@ -309,8 +328,7 @@ int main(int argc, char * argv[]) {
return 1;
}
start_server_fn(backend, endpoint.c_str(), cache_dir, free_mem, total_mem);
ggml_backend_free(backend);
start_server_fn(endpoint.c_str(), cache_dir, params.n_threads, devices.size(),
devices.data(), free_mem.data(), total_mem.data());
return 0;
}