mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-10-27 08:21:30 +00:00
rpc : add support for multiple devices (#16276)
* rpc : add support for multiple devices

  Allow rpc-server to expose multiple devices from a single endpoint.
  Change RPC protocol to include device identifier where needed.

  closes: #15210

* fixes
* use ggml_backend_reg_t
* address review comments
* fix llama-bench backend report
* address review comments, change device naming
* fix cmd order
This commit is contained in:
committed by
GitHub
parent
e29acf74fe
commit
898acba681
@@ -168,7 +168,7 @@ static std::vector<ggml_backend_dev_t> parse_devices_arg(const std::string & val
|
||||
return devices;
|
||||
}
|
||||
|
||||
static std::vector<ggml_backend_dev_t> register_rpc_device_list(const std::string & servers) {
|
||||
static void register_rpc_server_list(const std::string & servers) {
|
||||
auto rpc_servers = string_split<std::string>(servers, ',');
|
||||
if (rpc_servers.empty()) {
|
||||
throw std::invalid_argument("no RPC servers specified");
|
||||
@@ -179,36 +179,15 @@ static std::vector<ggml_backend_dev_t> register_rpc_device_list(const std::strin
|
||||
throw std::invalid_argument("failed to find RPC backend");
|
||||
}
|
||||
|
||||
using add_rpc_device_fn = ggml_backend_dev_t (*)(const char * endpoint);
|
||||
auto * ggml_backend_rpc_add_device_fn = (add_rpc_device_fn) ggml_backend_reg_get_proc_address(rpc_reg, "ggml_backend_rpc_add_device");
|
||||
if (!ggml_backend_rpc_add_device_fn) {
|
||||
throw std::invalid_argument("failed to find RPC device add function");
|
||||
using add_rpc_server_fn = ggml_backend_reg_t (*)(const char * endpoint);
|
||||
auto * ggml_backend_rpc_add_server_fn = (add_rpc_server_fn) ggml_backend_reg_get_proc_address(rpc_reg, "ggml_backend_rpc_add_server");
|
||||
if (!ggml_backend_rpc_add_server_fn) {
|
||||
throw std::invalid_argument("failed to find RPC add server function");
|
||||
}
|
||||
|
||||
static std::unordered_set<std::string> registered;
|
||||
std::vector<ggml_backend_dev_t> devices;
|
||||
for (const auto & server : rpc_servers) {
|
||||
ggml_backend_dev_t dev = nullptr;
|
||||
|
||||
std::string name = string_format("RPC[%s]", server.c_str());
|
||||
|
||||
if (registered.find(server) != registered.end()) {
|
||||
dev = ggml_backend_dev_by_name(name.c_str());
|
||||
}
|
||||
|
||||
if (!dev) {
|
||||
dev = ggml_backend_rpc_add_device_fn(server.c_str());
|
||||
if (!dev) {
|
||||
throw std::invalid_argument(string_format("failed to add RPC device for server '%s'", server.c_str()));
|
||||
}
|
||||
ggml_backend_device_register(dev);
|
||||
registered.insert(server);
|
||||
}
|
||||
|
||||
devices.push_back(dev);
|
||||
auto reg = ggml_backend_rpc_add_server_fn(server.c_str());
|
||||
ggml_backend_register(reg);
|
||||
}
|
||||
|
||||
return devices;
|
||||
}
|
||||
|
||||
static std::string devices_to_string(const std::vector<ggml_backend_dev_t> & devices) {
|
||||
@@ -714,7 +693,7 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
|
||||
break;
|
||||
}
|
||||
try {
|
||||
register_rpc_device_list(argv[i]);
|
||||
register_rpc_server_list(argv[i]);
|
||||
} catch (const std::exception & e) {
|
||||
fprintf(stderr, "error: %s\n", e.what());
|
||||
invalid_param = true;
|
||||
@@ -1368,13 +1347,23 @@ struct test {
|
||||
|
||||
static std::string get_backend() {
|
||||
std::vector<std::string> backends;
|
||||
bool rpc_used = false;
|
||||
for (size_t i = 0; i < ggml_backend_reg_count(); i++) {
|
||||
auto * reg = ggml_backend_reg_get(i);
|
||||
std::string name = ggml_backend_reg_name(reg);
|
||||
if (name != "CPU") {
|
||||
backends.push_back(ggml_backend_reg_name(reg));
|
||||
if (string_starts_with(name, "RPC")) {
|
||||
if (ggml_backend_reg_dev_count(reg) > 0) {
|
||||
rpc_used = true;
|
||||
}
|
||||
} else {
|
||||
if (name != "CPU") {
|
||||
backends.push_back(ggml_backend_reg_name(reg));
|
||||
}
|
||||
}
|
||||
}
|
||||
if (rpc_used) {
|
||||
backends.push_back("RPC");
|
||||
}
|
||||
return backends.empty() ? "CPU" : join(backends, ",");
|
||||
}
|
||||
|
||||
|
||||
@@ -22,6 +22,7 @@
|
||||
#include <filesystem>
|
||||
#include <algorithm>
|
||||
#include <thread>
|
||||
#include <regex>
|
||||
|
||||
namespace fs = std::filesystem;
|
||||
|
||||
@@ -131,24 +132,24 @@ static std::string fs_get_cache_directory() {
|
||||
}
|
||||
|
||||
struct rpc_server_params {
|
||||
std::string host = "127.0.0.1";
|
||||
int port = 50052;
|
||||
size_t backend_mem = 0;
|
||||
bool use_cache = false;
|
||||
int n_threads = std::max(1U, std::thread::hardware_concurrency()/2);
|
||||
std::string device;
|
||||
std::string host = "127.0.0.1";
|
||||
int port = 50052;
|
||||
bool use_cache = false;
|
||||
int n_threads = std::max(1U, std::thread::hardware_concurrency()/2);
|
||||
std::vector<std::string> devices;
|
||||
std::vector<size_t> dev_mem;
|
||||
};
|
||||
|
||||
static void print_usage(int /*argc*/, char ** argv, rpc_server_params params) {
|
||||
fprintf(stderr, "Usage: %s [options]\n\n", argv[0]);
|
||||
fprintf(stderr, "options:\n");
|
||||
fprintf(stderr, " -h, --help show this help message and exit\n");
|
||||
fprintf(stderr, " -t, --threads number of threads for the CPU backend (default: %d)\n", params.n_threads);
|
||||
fprintf(stderr, " -d DEV, --device device to use\n");
|
||||
fprintf(stderr, " -H HOST, --host HOST host to bind to (default: %s)\n", params.host.c_str());
|
||||
fprintf(stderr, " -p PORT, --port PORT port to bind to (default: %d)\n", params.port);
|
||||
fprintf(stderr, " -m MEM, --mem MEM backend memory size (in MB)\n");
|
||||
fprintf(stderr, " -c, --cache enable local file cache\n");
|
||||
fprintf(stderr, " -h, --help show this help message and exit\n");
|
||||
fprintf(stderr, " -t, --threads N number of threads for the CPU device (default: %d)\n", params.n_threads);
|
||||
fprintf(stderr, " -d, --device <dev1,dev2,...> comma-separated list of devices\n");
|
||||
fprintf(stderr, " -H, --host HOST host to bind to (default: %s)\n", params.host.c_str());
|
||||
fprintf(stderr, " -p, --port PORT port to bind to (default: %d)\n", params.port);
|
||||
fprintf(stderr, " -m, --mem <M1,M2,...> memory size for each device (in MB)\n");
|
||||
fprintf(stderr, " -c, --cache enable local file cache\n");
|
||||
fprintf(stderr, "\n");
|
||||
}
|
||||
|
||||
@@ -174,17 +175,17 @@ static bool rpc_server_params_parse(int argc, char ** argv, rpc_server_params &
|
||||
if (++i >= argc) {
|
||||
return false;
|
||||
}
|
||||
params.device = argv[i];
|
||||
if (ggml_backend_dev_by_name(params.device.c_str()) == nullptr) {
|
||||
fprintf(stderr, "error: unknown device: %s\n", params.device.c_str());
|
||||
fprintf(stderr, "available devices:\n");
|
||||
for (size_t i = 0; i < ggml_backend_dev_count(); i++) {
|
||||
auto * dev = ggml_backend_dev_get(i);
|
||||
size_t free, total;
|
||||
ggml_backend_dev_memory(dev, &free, &total);
|
||||
printf(" %s: %s (%zu MiB, %zu MiB free)\n", ggml_backend_dev_name(dev), ggml_backend_dev_description(dev), total / 1024 / 1024, free / 1024 / 1024);
|
||||
const std::regex regex{ R"([,/]+)" };
|
||||
std::string dev_str = argv[i];
|
||||
std::sregex_token_iterator iter(dev_str.begin(), dev_str.end(), regex, -1);
|
||||
std::sregex_token_iterator end;
|
||||
for ( ; iter != end; ++iter) {
|
||||
try {
|
||||
params.devices.push_back(*iter);
|
||||
} catch (const std::exception & ) {
|
||||
fprintf(stderr, "error: invalid device: %s\n", iter->str().c_str());
|
||||
return false;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
} else if (arg == "-p" || arg == "--port") {
|
||||
if (++i >= argc) {
|
||||
@@ -200,7 +201,19 @@ static bool rpc_server_params_parse(int argc, char ** argv, rpc_server_params &
|
||||
if (++i >= argc) {
|
||||
return false;
|
||||
}
|
||||
params.backend_mem = std::stoul(argv[i]) * 1024 * 1024;
|
||||
const std::regex regex{ R"([,/]+)" };
|
||||
std::string mem_str = argv[i];
|
||||
std::sregex_token_iterator iter(mem_str.begin(), mem_str.end(), regex, -1);
|
||||
std::sregex_token_iterator end;
|
||||
for ( ; iter != end; ++iter) {
|
||||
try {
|
||||
size_t mem = std::stoul(*iter) * 1024 * 1024;
|
||||
params.dev_mem.push_back(mem);
|
||||
} catch (const std::exception & ) {
|
||||
fprintf(stderr, "error: invalid memory size: %s\n", iter->str().c_str());
|
||||
return false;
|
||||
}
|
||||
}
|
||||
} else if (arg == "-h" || arg == "--help") {
|
||||
print_usage(argc, argv, params);
|
||||
exit(0);
|
||||
@@ -213,45 +226,46 @@ static bool rpc_server_params_parse(int argc, char ** argv, rpc_server_params &
|
||||
return true;
|
||||
}
|
||||
|
||||
static ggml_backend_t create_backend(const rpc_server_params & params) {
|
||||
ggml_backend_t backend = nullptr;
|
||||
static std::vector<ggml_backend_dev_t> get_devices(const rpc_server_params & params) {
|
||||
std::vector<ggml_backend_dev_t> devices;
|
||||
if (!params.devices.empty()) {
|
||||
for (auto device : params.devices) {
|
||||
ggml_backend_dev_t dev = ggml_backend_dev_by_name(device.c_str());
|
||||
if (dev) {
|
||||
devices.push_back(dev);
|
||||
} else {
|
||||
fprintf(stderr, "error: unknown device: %s\n", device.c_str());
|
||||
fprintf(stderr, "available devices:\n");
|
||||
for (size_t i = 0; i < ggml_backend_dev_count(); i++) {
|
||||
auto * dev = ggml_backend_dev_get(i);
|
||||
size_t free, total;
|
||||
ggml_backend_dev_memory(dev, &free, &total);
|
||||
printf(" %s: %s (%zu MiB, %zu MiB free)\n", ggml_backend_dev_name(dev), ggml_backend_dev_description(dev), total / 1024 / 1024, free / 1024 / 1024);
|
||||
}
|
||||
return {};
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (!params.device.empty()) {
|
||||
ggml_backend_dev_t dev = ggml_backend_dev_by_name(params.device.c_str());
|
||||
// Try non-CPU devices first
|
||||
if (devices.empty()) {
|
||||
for (size_t i = 0; i < ggml_backend_dev_count(); i++) {
|
||||
ggml_backend_dev_t dev = ggml_backend_dev_get(i);
|
||||
if (ggml_backend_dev_type(dev) != GGML_BACKEND_DEVICE_TYPE_CPU) {
|
||||
devices.push_back(dev);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If there are no accelerators, fallback to CPU device
|
||||
if (devices.empty()) {
|
||||
ggml_backend_dev_t dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
|
||||
if (dev) {
|
||||
backend = ggml_backend_dev_init(dev, nullptr);
|
||||
if (!backend) {
|
||||
fprintf(stderr, "Failed to create backend for device %s\n", params.device.c_str());
|
||||
return nullptr;
|
||||
}
|
||||
devices.push_back(dev);
|
||||
}
|
||||
}
|
||||
|
||||
if (!backend) {
|
||||
backend = ggml_backend_init_best();
|
||||
}
|
||||
|
||||
if (backend) {
|
||||
fprintf(stderr, "%s: using %s backend\n", __func__, ggml_backend_name(backend));
|
||||
|
||||
// set the number of threads
|
||||
ggml_backend_dev_t dev = ggml_backend_get_device(backend);
|
||||
ggml_backend_reg_t reg = dev ? ggml_backend_dev_backend_reg(dev) : nullptr;
|
||||
if (reg) {
|
||||
auto ggml_backend_set_n_threads_fn = (ggml_backend_set_n_threads_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_n_threads");
|
||||
if (ggml_backend_set_n_threads_fn) {
|
||||
ggml_backend_set_n_threads_fn(backend, params.n_threads);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return backend;
|
||||
}
|
||||
|
||||
static void get_backend_memory(ggml_backend_t backend, size_t * free_mem, size_t * total_mem) {
|
||||
ggml_backend_dev_t dev = ggml_backend_get_device(backend);
|
||||
GGML_ASSERT(dev != nullptr);
|
||||
ggml_backend_dev_memory(dev, free_mem, total_mem);
|
||||
return devices;
|
||||
}
|
||||
|
||||
int main(int argc, char * argv[]) {
|
||||
@@ -273,18 +287,23 @@ int main(int argc, char * argv[]) {
|
||||
fprintf(stderr, "\n");
|
||||
}
|
||||
|
||||
ggml_backend_t backend = create_backend(params);
|
||||
if (!backend) {
|
||||
fprintf(stderr, "Failed to create backend\n");
|
||||
auto devices = get_devices(params);
|
||||
if (devices.empty()) {
|
||||
fprintf(stderr, "No devices found\n");
|
||||
return 1;
|
||||
}
|
||||
std::string endpoint = params.host + ":" + std::to_string(params.port);
|
||||
size_t free_mem, total_mem;
|
||||
if (params.backend_mem > 0) {
|
||||
free_mem = params.backend_mem;
|
||||
total_mem = params.backend_mem;
|
||||
} else {
|
||||
get_backend_memory(backend, &free_mem, &total_mem);
|
||||
std::vector<size_t> free_mem, total_mem;
|
||||
for (size_t i = 0; i < devices.size(); i++) {
|
||||
if (i < params.dev_mem.size()) {
|
||||
free_mem.push_back(params.dev_mem[i]);
|
||||
total_mem.push_back(params.dev_mem[i]);
|
||||
} else {
|
||||
size_t free, total;
|
||||
ggml_backend_dev_memory(devices[i], &free, &total);
|
||||
free_mem.push_back(free);
|
||||
total_mem.push_back(total);
|
||||
}
|
||||
}
|
||||
const char * cache_dir = nullptr;
|
||||
std::string cache_dir_str;
|
||||
@@ -309,8 +328,7 @@ int main(int argc, char * argv[]) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
start_server_fn(backend, endpoint.c_str(), cache_dir, free_mem, total_mem);
|
||||
|
||||
ggml_backend_free(backend);
|
||||
start_server_fn(endpoint.c_str(), cache_dir, params.n_threads, devices.size(),
|
||||
devices.data(), free_mem.data(), total_mem.data());
|
||||
return 0;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user