rpc : add support for multiple devices (#16276)

* rpc : add support for multiple devices

Allow rpc-server to expose multiple devices from a single endpoint.
Change RPC protocol to include device identifier where needed.

closes: #15210

* fixes

* use ggml_backend_reg_t

* address review comments

* fix llama-bench backend report

* address review comments, change device naming

* fix cmd order
This commit is contained in:
Radoslav Gerganov
2025-10-04 12:49:16 +03:00
committed by GitHub
parent e29acf74fe
commit 898acba681
7 changed files with 403 additions and 245 deletions

View File

@@ -1615,18 +1615,14 @@ static void add_rpc_devices(const std::string & servers) {
if (!rpc_reg) { if (!rpc_reg) {
throw std::invalid_argument("failed to find RPC backend"); throw std::invalid_argument("failed to find RPC backend");
} }
typedef ggml_backend_dev_t (*ggml_backend_rpc_add_device_t)(const char * endpoint); typedef ggml_backend_reg_t (*ggml_backend_rpc_add_server_t)(const char * endpoint);
ggml_backend_rpc_add_device_t ggml_backend_rpc_add_device_fn = (ggml_backend_rpc_add_device_t) ggml_backend_reg_get_proc_address(rpc_reg, "ggml_backend_rpc_add_device"); ggml_backend_rpc_add_server_t ggml_backend_rpc_add_server_fn = (ggml_backend_rpc_add_server_t) ggml_backend_reg_get_proc_address(rpc_reg, "ggml_backend_rpc_add_server");
if (!ggml_backend_rpc_add_device_fn) { if (!ggml_backend_rpc_add_server_fn) {
throw std::invalid_argument("failed to find RPC device add function"); throw std::invalid_argument("failed to find RPC add server function");
} }
for (const auto & server : rpc_servers) { for (const auto & server : rpc_servers) {
ggml_backend_dev_t dev = ggml_backend_rpc_add_device_fn(server.c_str()); auto reg = ggml_backend_rpc_add_server_fn(server.c_str());
if (dev) { ggml_backend_register(reg);
ggml_backend_device_register(dev);
} else {
throw std::invalid_argument("failed to register RPC device");
}
} }
} }

View File

@@ -215,6 +215,8 @@ extern "C" {
// Backend registry // Backend registry
// //
GGML_API void ggml_backend_register(ggml_backend_reg_t reg);
GGML_API void ggml_backend_device_register(ggml_backend_dev_t device); GGML_API void ggml_backend_device_register(ggml_backend_dev_t device);
// Backend (reg) enumeration // Backend (reg) enumeration

View File

@@ -7,26 +7,25 @@
extern "C" { extern "C" {
#endif #endif
#define RPC_PROTO_MAJOR_VERSION 2 #define RPC_PROTO_MAJOR_VERSION 3
#define RPC_PROTO_MINOR_VERSION 0 #define RPC_PROTO_MINOR_VERSION 0
#define RPC_PROTO_PATCH_VERSION 0 #define RPC_PROTO_PATCH_VERSION 0
#define GGML_RPC_MAX_SERVERS 16 #define GGML_RPC_MAX_SERVERS 16
// backend API // backend API
GGML_BACKEND_API ggml_backend_t ggml_backend_rpc_init(const char * endpoint); GGML_BACKEND_API ggml_backend_t ggml_backend_rpc_init(const char * endpoint, uint32_t device);
GGML_BACKEND_API bool ggml_backend_is_rpc(ggml_backend_t backend); GGML_BACKEND_API bool ggml_backend_is_rpc(ggml_backend_t backend);
GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const char * endpoint); GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const char * endpoint, uint32_t device);
GGML_BACKEND_API void ggml_backend_rpc_get_device_memory(const char * endpoint, size_t * free, size_t * total); GGML_BACKEND_API void ggml_backend_rpc_get_device_memory(const char * endpoint, uint32_t device, size_t * free, size_t * total);
GGML_BACKEND_API void ggml_backend_rpc_start_server(ggml_backend_t backend, const char * endpoint, GGML_BACKEND_API void ggml_backend_rpc_start_server(const char * endpoint, const char * cache_dir,
const char * cache_dir, size_t n_threads, size_t n_devices,
size_t free_mem, size_t total_mem); ggml_backend_dev_t * devices, size_t * free_mem, size_t * total_mem);
GGML_BACKEND_API ggml_backend_reg_t ggml_backend_rpc_reg(void); GGML_BACKEND_API ggml_backend_reg_t ggml_backend_rpc_reg(void);
GGML_BACKEND_API ggml_backend_reg_t ggml_backend_rpc_add_server(const char * endpoint);
GGML_BACKEND_API ggml_backend_dev_t ggml_backend_rpc_add_device(const char * endpoint);
#ifdef __cplusplus #ifdef __cplusplus
} }

View File

@@ -209,9 +209,6 @@ extern "C" {
void * context; void * context;
}; };
// Internal backend registry API
GGML_API void ggml_backend_register(ggml_backend_reg_t reg);
// Add backend dynamic loading support to the backend // Add backend dynamic loading support to the backend
// Initialize the backend // Initialize the backend

View File

@@ -105,9 +105,12 @@ enum rpc_cmd {
RPC_CMD_INIT_TENSOR, RPC_CMD_INIT_TENSOR,
RPC_CMD_GET_ALLOC_SIZE, RPC_CMD_GET_ALLOC_SIZE,
RPC_CMD_HELLO, RPC_CMD_HELLO,
RPC_CMD_DEVICE_COUNT,
RPC_CMD_COUNT, RPC_CMD_COUNT,
}; };
static_assert(RPC_CMD_HELLO == 14, "RPC_CMD_HELLO must be always 14");
// Try RPC_CMD_SET_TENSOR_HASH first when data size is larger than this threshold // Try RPC_CMD_SET_TENSOR_HASH first when data size is larger than this threshold
const size_t HASH_THRESHOLD = 10 * 1024 * 1024; const size_t HASH_THRESHOLD = 10 * 1024 * 1024;
@@ -117,7 +120,12 @@ struct rpc_msg_hello_rsp {
uint8_t patch; uint8_t patch;
}; };
struct rpc_msg_device_count_rsp {
uint32_t device_count;
};
struct rpc_msg_get_alloc_size_req { struct rpc_msg_get_alloc_size_req {
uint32_t device;
rpc_tensor tensor; rpc_tensor tensor;
}; };
@@ -130,6 +138,7 @@ struct rpc_msg_init_tensor_req {
}; };
struct rpc_msg_alloc_buffer_req { struct rpc_msg_alloc_buffer_req {
uint32_t device;
uint64_t size; uint64_t size;
}; };
@@ -138,10 +147,18 @@ struct rpc_msg_alloc_buffer_rsp {
uint64_t remote_size; uint64_t remote_size;
}; };
struct rpc_msg_get_alignment_req {
uint32_t device;
};
struct rpc_msg_get_alignment_rsp { struct rpc_msg_get_alignment_rsp {
uint64_t alignment; uint64_t alignment;
}; };
struct rpc_msg_get_max_size_req {
uint32_t device;
};
struct rpc_msg_get_max_size_rsp { struct rpc_msg_get_max_size_rsp {
uint64_t max_size; uint64_t max_size;
}; };
@@ -192,6 +209,10 @@ struct rpc_msg_graph_compute_rsp {
uint8_t result; uint8_t result;
}; };
struct rpc_msg_get_device_memory_req {
uint32_t device;
};
struct rpc_msg_get_device_memory_rsp { struct rpc_msg_get_device_memory_rsp {
uint64_t free_mem; uint64_t free_mem;
uint64_t total_mem; uint64_t total_mem;
@@ -207,6 +228,7 @@ static ggml_guid_t ggml_backend_rpc_guid() {
struct ggml_backend_rpc_buffer_type_context { struct ggml_backend_rpc_buffer_type_context {
std::string endpoint; std::string endpoint;
uint32_t device;
std::string name; std::string name;
size_t alignment; size_t alignment;
size_t max_size; size_t max_size;
@@ -214,6 +236,7 @@ struct ggml_backend_rpc_buffer_type_context {
struct ggml_backend_rpc_context { struct ggml_backend_rpc_context {
std::string endpoint; std::string endpoint;
uint32_t device;
std::string name; std::string name;
}; };
@@ -653,7 +676,7 @@ static const char * ggml_backend_rpc_buffer_type_name(ggml_backend_buffer_type_t
static ggml_backend_buffer_t ggml_backend_rpc_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) { static ggml_backend_buffer_t ggml_backend_rpc_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context; ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context;
rpc_msg_alloc_buffer_req request = {size}; rpc_msg_alloc_buffer_req request = {buft_ctx->device, size};
rpc_msg_alloc_buffer_rsp response; rpc_msg_alloc_buffer_rsp response;
auto sock = get_socket(buft_ctx->endpoint); auto sock = get_socket(buft_ctx->endpoint);
bool status = send_rpc_cmd(sock, RPC_CMD_ALLOC_BUFFER, &request, sizeof(request), &response, sizeof(response)); bool status = send_rpc_cmd(sock, RPC_CMD_ALLOC_BUFFER, &request, sizeof(request), &response, sizeof(response));
@@ -669,9 +692,10 @@ static ggml_backend_buffer_t ggml_backend_rpc_buffer_type_alloc_buffer(ggml_back
} }
} }
static size_t get_alignment(const std::shared_ptr<socket_t> & sock) { static size_t get_alignment(const std::shared_ptr<socket_t> & sock, uint32_t device) {
rpc_msg_get_alignment_req request = {device};
rpc_msg_get_alignment_rsp response; rpc_msg_get_alignment_rsp response;
bool status = send_rpc_cmd(sock, RPC_CMD_GET_ALIGNMENT, nullptr, 0, &response, sizeof(response)); bool status = send_rpc_cmd(sock, RPC_CMD_GET_ALIGNMENT, &request, sizeof(request), &response, sizeof(response));
RPC_STATUS_ASSERT(status); RPC_STATUS_ASSERT(status);
return response.alignment; return response.alignment;
} }
@@ -681,9 +705,10 @@ static size_t ggml_backend_rpc_buffer_type_get_alignment(ggml_backend_buffer_typ
return buft_ctx->alignment; return buft_ctx->alignment;
} }
static size_t get_max_size(const std::shared_ptr<socket_t> & sock) { static size_t get_max_size(const std::shared_ptr<socket_t> & sock, uint32_t device) {
rpc_msg_get_max_size_req request = {device};
rpc_msg_get_max_size_rsp response; rpc_msg_get_max_size_rsp response;
bool status = send_rpc_cmd(sock, RPC_CMD_GET_MAX_SIZE, nullptr, 0, &response, sizeof(response)); bool status = send_rpc_cmd(sock, RPC_CMD_GET_MAX_SIZE, &request, sizeof(request), &response, sizeof(response));
RPC_STATUS_ASSERT(status); RPC_STATUS_ASSERT(status);
return response.max_size; return response.max_size;
} }
@@ -700,7 +725,7 @@ static size_t ggml_backend_rpc_buffer_type_get_alloc_size(ggml_backend_buffer_ty
auto sock = get_socket(buft_ctx->endpoint); auto sock = get_socket(buft_ctx->endpoint);
rpc_msg_get_alloc_size_req request; rpc_msg_get_alloc_size_req request;
request.device = buft_ctx->device;
request.tensor = serialize_tensor(tensor); request.tensor = serialize_tensor(tensor);
rpc_msg_get_alloc_size_rsp response; rpc_msg_get_alloc_size_rsp response;
@@ -754,7 +779,7 @@ static void add_tensor(ggml_tensor * tensor, std::vector<rpc_tensor> & tensors,
tensors.push_back(serialize_tensor(tensor)); tensors.push_back(serialize_tensor(tensor));
} }
static void serialize_graph(const ggml_cgraph * cgraph, std::vector<uint8_t> & output) { static void serialize_graph(uint32_t device, const ggml_cgraph * cgraph, std::vector<uint8_t> & output) {
uint32_t n_nodes = cgraph->n_nodes; uint32_t n_nodes = cgraph->n_nodes;
std::vector<rpc_tensor> tensors; std::vector<rpc_tensor> tensors;
std::unordered_set<ggml_tensor*> visited; std::unordered_set<ggml_tensor*> visited;
@@ -762,24 +787,29 @@ static void serialize_graph(const ggml_cgraph * cgraph, std::vector<uint8_t> & o
add_tensor(cgraph->nodes[i], tensors, visited); add_tensor(cgraph->nodes[i], tensors, visited);
} }
// serialization format: // serialization format:
// | n_nodes (4 bytes) | nodes (n_nodes * sizeof(uint64_t) | n_tensors (4 bytes) | tensors (n_tensors * sizeof(rpc_tensor)) | // | device (4 bytes) | n_nodes (4 bytes) | nodes (n_nodes * sizeof(uint64_t) | n_tensors (4 bytes) | tensors (n_tensors * sizeof(rpc_tensor)) |
uint32_t n_tensors = tensors.size(); uint32_t n_tensors = tensors.size();
int output_size = sizeof(uint32_t) + n_nodes * sizeof(uint64_t) + sizeof(uint32_t) + n_tensors * sizeof(rpc_tensor); int output_size = 2*sizeof(uint32_t) + n_nodes * sizeof(uint64_t) + sizeof(uint32_t) + n_tensors * sizeof(rpc_tensor);
output.resize(output_size, 0); output.resize(output_size, 0);
memcpy(output.data(), &n_nodes, sizeof(n_nodes)); uint8_t * dest = output.data();
memcpy(dest, &device, sizeof(device));
dest += sizeof(device);
memcpy(dest, &n_nodes, sizeof(n_nodes));
dest += sizeof(n_nodes);
for (uint32_t i = 0; i < n_nodes; i++) { for (uint32_t i = 0; i < n_nodes; i++) {
memcpy(output.data() + sizeof(n_nodes) + i * sizeof(uint64_t), &cgraph->nodes[i], sizeof(uint64_t)); memcpy(dest + i * sizeof(uint64_t), &cgraph->nodes[i], sizeof(uint64_t));
} }
uint32_t * out_ntensors = (uint32_t *)(output.data() + sizeof(n_nodes) + n_nodes * sizeof(uint64_t)); dest += n_nodes * sizeof(uint64_t);
*out_ntensors = n_tensors; memcpy(dest, &n_tensors, sizeof(n_tensors));
rpc_tensor * out_tensors = (rpc_tensor *)(output.data() + sizeof(n_nodes) + n_nodes * sizeof(uint64_t) + sizeof(uint32_t)); dest += sizeof(n_tensors);
rpc_tensor * out_tensors = (rpc_tensor *)dest;
memcpy(out_tensors, tensors.data(), n_tensors * sizeof(rpc_tensor)); memcpy(out_tensors, tensors.data(), n_tensors * sizeof(rpc_tensor));
} }
static enum ggml_status ggml_backend_rpc_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) { static enum ggml_status ggml_backend_rpc_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context; ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context;
std::vector<uint8_t> input; std::vector<uint8_t> input;
serialize_graph(cgraph, input); serialize_graph(rpc_ctx->device, cgraph, input);
rpc_msg_graph_compute_rsp response; rpc_msg_graph_compute_rsp response;
auto sock = get_socket(rpc_ctx->endpoint); auto sock = get_socket(rpc_ctx->endpoint);
bool status = send_rpc_cmd(sock, RPC_CMD_GRAPH_COMPUTE, input.data(), input.size(), &response, sizeof(response)); bool status = send_rpc_cmd(sock, RPC_CMD_GRAPH_COMPUTE, input.data(), input.size(), &response, sizeof(response));
@@ -804,12 +834,13 @@ static ggml_backend_i ggml_backend_rpc_interface = {
/* .graph_optimize = */ NULL, /* .graph_optimize = */ NULL,
}; };
ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const char * endpoint) { ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const char * endpoint, uint32_t device) {
static std::mutex mutex; static std::mutex mutex;
std::lock_guard<std::mutex> lock(mutex); std::lock_guard<std::mutex> lock(mutex);
std::string buft_name = "RPC" + std::to_string(device) + "[" + std::string(endpoint) + "]";
// NOTE: buffer types are allocated and never freed; this is by design // NOTE: buffer types are allocated and never freed; this is by design
static std::unordered_map<std::string, ggml_backend_buffer_type_t> buft_map; static std::unordered_map<std::string, ggml_backend_buffer_type_t> buft_map;
auto it = buft_map.find(endpoint); auto it = buft_map.find(buft_name);
if (it != buft_map.end()) { if (it != buft_map.end()) {
return it->second; return it->second;
} }
@@ -818,34 +849,37 @@ ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const char * endpoint) {
GGML_LOG_ERROR("Failed to connect to %s\n", endpoint); GGML_LOG_ERROR("Failed to connect to %s\n", endpoint);
return nullptr; return nullptr;
} }
size_t alignment = get_alignment(sock); size_t alignment = get_alignment(sock, device);
size_t max_size = get_max_size(sock); size_t max_size = get_max_size(sock, device);
ggml_backend_rpc_buffer_type_context * buft_ctx = new ggml_backend_rpc_buffer_type_context { ggml_backend_rpc_buffer_type_context * buft_ctx = new ggml_backend_rpc_buffer_type_context {
/* .endpoint = */ endpoint, /* .endpoint = */ endpoint,
/* .name = */ "RPC[" + std::string(endpoint) + "]", /* .device = */ device,
/* .name = */ buft_name,
/* .alignment = */ alignment, /* .alignment = */ alignment,
/* .max_size = */ max_size /* .max_size = */ max_size
}; };
auto reg = ggml_backend_rpc_add_server(endpoint);
ggml_backend_buffer_type_t buft = new ggml_backend_buffer_type { ggml_backend_buffer_type_t buft = new ggml_backend_buffer_type {
/* .iface = */ ggml_backend_rpc_buffer_type_interface, /* .iface = */ ggml_backend_rpc_buffer_type_interface,
/* .device = */ ggml_backend_rpc_add_device(endpoint), /* .device = */ ggml_backend_reg_dev_get(reg, device),
/* .context = */ buft_ctx /* .context = */ buft_ctx
}; };
buft_map[endpoint] = buft; buft_map[buft_name] = buft;
return buft; return buft;
} }
ggml_backend_t ggml_backend_rpc_init(const char * endpoint) { ggml_backend_t ggml_backend_rpc_init(const char * endpoint, uint32_t device) {
std::string dev_name = "RPC" + std::to_string(device) + "[" + std::string(endpoint) + "]";
ggml_backend_rpc_context * ctx = new ggml_backend_rpc_context { ggml_backend_rpc_context * ctx = new ggml_backend_rpc_context {
/* .endpoint = */ endpoint, /* .endpoint = */ endpoint,
/* .name = */ "RPC[" + std::string(endpoint) + "]", /* .device = */ device,
/* .name = */ dev_name
}; };
auto reg = ggml_backend_rpc_add_server(endpoint);
ggml_backend_t backend = new ggml_backend { ggml_backend_t backend = new ggml_backend {
/* .guid = */ ggml_backend_rpc_guid(), /* .guid = */ ggml_backend_rpc_guid(),
/* .iface = */ ggml_backend_rpc_interface, /* .iface = */ ggml_backend_rpc_interface,
/* .device = */ ggml_backend_rpc_add_device(endpoint), /* .device = */ ggml_backend_reg_dev_get(reg, device),
/* .context = */ ctx /* .context = */ ctx
}; };
return backend; return backend;
@@ -855,37 +889,39 @@ bool ggml_backend_is_rpc(ggml_backend_t backend) {
return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_rpc_guid()); return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_rpc_guid());
} }
static void get_device_memory(const std::shared_ptr<socket_t> & sock, size_t * free, size_t * total) { static void get_device_memory(const std::shared_ptr<socket_t> & sock, uint32_t device, size_t * free, size_t * total) {
rpc_msg_get_device_memory_req request;
request.device = device;
rpc_msg_get_device_memory_rsp response; rpc_msg_get_device_memory_rsp response;
bool status = send_rpc_cmd(sock, RPC_CMD_GET_DEVICE_MEMORY, nullptr, 0, &response, sizeof(response)); bool status = send_rpc_cmd(sock, RPC_CMD_GET_DEVICE_MEMORY, &request, sizeof(request), &response, sizeof(response));
RPC_STATUS_ASSERT(status); RPC_STATUS_ASSERT(status);
*free = response.free_mem; *free = response.free_mem;
*total = response.total_mem; *total = response.total_mem;
} }
void ggml_backend_rpc_get_device_memory(const char * endpoint, size_t * free, size_t * total) { void ggml_backend_rpc_get_device_memory(const char * endpoint, uint32_t device, size_t * free, size_t * total) {
auto sock = get_socket(endpoint); auto sock = get_socket(endpoint);
if (sock == nullptr) { if (sock == nullptr) {
*free = 0; *free = 0;
*total = 0; *total = 0;
return; return;
} }
get_device_memory(sock, free, total); get_device_memory(sock, device, free, total);
} }
// RPC server-side implementation // RPC server-side implementation
class rpc_server { class rpc_server {
public: public:
rpc_server(ggml_backend_t backend, const char * cache_dir) rpc_server(std::vector<ggml_backend_t> backends, const char * cache_dir)
: backend(backend), cache_dir(cache_dir) { : backends(std::move(backends)), cache_dir(cache_dir) {
} }
~rpc_server(); ~rpc_server();
void hello(rpc_msg_hello_rsp & response); void hello(rpc_msg_hello_rsp & response);
void alloc_buffer(const rpc_msg_alloc_buffer_req & request, rpc_msg_alloc_buffer_rsp & response); bool alloc_buffer(const rpc_msg_alloc_buffer_req & request, rpc_msg_alloc_buffer_rsp & response);
void get_alignment(rpc_msg_get_alignment_rsp & response); bool get_alignment(const rpc_msg_get_alignment_req & request, rpc_msg_get_alignment_rsp & response);
void get_max_size(rpc_msg_get_max_size_rsp & response); bool get_max_size(const rpc_msg_get_max_size_req & request, rpc_msg_get_max_size_rsp & response);
bool buffer_get_base(const rpc_msg_buffer_get_base_req & request, rpc_msg_buffer_get_base_rsp & response); bool buffer_get_base(const rpc_msg_buffer_get_base_req & request, rpc_msg_buffer_get_base_rsp & response);
bool free_buffer(const rpc_msg_free_buffer_req & request); bool free_buffer(const rpc_msg_free_buffer_req & request);
bool buffer_clear(const rpc_msg_buffer_clear_req & request); bool buffer_clear(const rpc_msg_buffer_clear_req & request);
@@ -906,7 +942,7 @@ private:
std::unordered_map<uint64_t, struct ggml_tensor*> & tensor_map); std::unordered_map<uint64_t, struct ggml_tensor*> & tensor_map);
ggml_backend_t backend; std::vector<ggml_backend_t> backends;
const char * cache_dir; const char * cache_dir;
std::unordered_set<ggml_backend_buffer_t> buffers; std::unordered_set<ggml_backend_buffer_t> buffers;
}; };
@@ -919,6 +955,10 @@ void rpc_server::hello(rpc_msg_hello_rsp & response) {
} }
bool rpc_server::get_alloc_size(const rpc_msg_get_alloc_size_req & request, rpc_msg_get_alloc_size_rsp & response) { bool rpc_server::get_alloc_size(const rpc_msg_get_alloc_size_req & request, rpc_msg_get_alloc_size_rsp & response) {
uint32_t dev_id = request.device;
if (dev_id >= backends.size()) {
return false;
}
ggml_backend_buffer_type_t buft; ggml_backend_buffer_type_t buft;
struct ggml_init_params params { struct ggml_init_params params {
/*.mem_size =*/ ggml_tensor_overhead(), /*.mem_size =*/ ggml_tensor_overhead(),
@@ -935,10 +975,10 @@ bool rpc_server::get_alloc_size(const rpc_msg_get_alloc_size_req & request, rpc_
GGML_LOG_ERROR("Null tensor pointer passed to server get_alloc_size function.\n"); GGML_LOG_ERROR("Null tensor pointer passed to server get_alloc_size function.\n");
return false; return false;
} }
LOG_DBG("[%s] buffer: %p, data: %p\n", __func__, (void*)tensor->buffer, tensor->data); LOG_DBG("[%s] device: %d, buffer: %p, data: %p\n", __func__, dev_id, (void*)tensor->buffer, tensor->data);
if (tensor->buffer == nullptr) { if (tensor->buffer == nullptr) {
//No buffer allocated. //No buffer allocated.
buft = ggml_backend_get_default_buffer_type(backend); buft = ggml_backend_get_default_buffer_type(backends[dev_id]);
} else { } else {
buft = tensor->buffer->buft; buft = tensor->buffer->buft;
} }
@@ -948,33 +988,49 @@ bool rpc_server::get_alloc_size(const rpc_msg_get_alloc_size_req & request, rpc_
return true; return true;
} }
void rpc_server::alloc_buffer(const rpc_msg_alloc_buffer_req & request, rpc_msg_alloc_buffer_rsp & response) { bool rpc_server::alloc_buffer(const rpc_msg_alloc_buffer_req & request, rpc_msg_alloc_buffer_rsp & response) {
ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backend); uint32_t dev_id = request.device;
if (dev_id >= backends.size()) {
return false;
}
ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backends[dev_id]);
ggml_backend_buffer_t buffer = ggml_backend_buft_alloc_buffer(buft, request.size); ggml_backend_buffer_t buffer = ggml_backend_buft_alloc_buffer(buft, request.size);
response.remote_ptr = 0; response.remote_ptr = 0;
response.remote_size = 0; response.remote_size = 0;
if (buffer != nullptr) { if (buffer != nullptr) {
response.remote_ptr = reinterpret_cast<uint64_t>(buffer); response.remote_ptr = reinterpret_cast<uint64_t>(buffer);
response.remote_size = buffer->size; response.remote_size = buffer->size;
LOG_DBG("[%s] size: %" PRIu64 " -> remote_ptr: %" PRIx64 ", remote_size: %" PRIu64 "\n", __func__, request.size, response.remote_ptr, response.remote_size); LOG_DBG("[%s] device: %d, size: %" PRIu64 " -> remote_ptr: %" PRIx64 ", remote_size: %" PRIu64 "\n",
__func__, dev_id, request.size, response.remote_ptr, response.remote_size);
buffers.insert(buffer); buffers.insert(buffer);
} else { } else {
LOG_DBG("[%s] size: %" PRIu64 " -> failed\n", __func__, request.size); LOG_DBG("[%s] device: %d, size: %" PRIu64 " -> failed\n", __func__, dev_id, request.size);
} }
return true;
} }
void rpc_server::get_alignment(rpc_msg_get_alignment_rsp & response) { bool rpc_server::get_alignment(const rpc_msg_get_alignment_req & request, rpc_msg_get_alignment_rsp & response) {
ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backend); uint32_t dev_id = request.device;
if (dev_id >= backends.size()) {
return false;
}
ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backends[dev_id]);
size_t alignment = ggml_backend_buft_get_alignment(buft); size_t alignment = ggml_backend_buft_get_alignment(buft);
LOG_DBG("[%s] alignment: %lu\n", __func__, alignment); LOG_DBG("[%s] device: %d, alignment: %lu\n", __func__, dev_id, alignment);
response.alignment = alignment; response.alignment = alignment;
return true;
} }
void rpc_server::get_max_size(rpc_msg_get_max_size_rsp & response) { bool rpc_server::get_max_size(const rpc_msg_get_max_size_req & request, rpc_msg_get_max_size_rsp & response) {
ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backend); uint32_t dev_id = request.device;
if (dev_id >= backends.size()) {
return false;
}
ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backends[dev_id]);
size_t max_size = ggml_backend_buft_get_max_size(buft); size_t max_size = ggml_backend_buft_get_max_size(buft);
LOG_DBG("[%s] max_size: %lu\n", __func__, max_size); LOG_DBG("[%s] device: %d, max_size: %lu\n", __func__, dev_id, max_size);
response.max_size = max_size; response.max_size = max_size;
return true;
} }
bool rpc_server::buffer_get_base(const rpc_msg_buffer_get_base_req & request, rpc_msg_buffer_get_base_rsp & response) { bool rpc_server::buffer_get_base(const rpc_msg_buffer_get_base_req & request, rpc_msg_buffer_get_base_rsp & response) {
@@ -1332,23 +1388,33 @@ ggml_tensor * rpc_server::create_node(uint64_t id,
bool rpc_server::graph_compute(const std::vector<uint8_t> & input, rpc_msg_graph_compute_rsp & response) { bool rpc_server::graph_compute(const std::vector<uint8_t> & input, rpc_msg_graph_compute_rsp & response) {
// serialization format: // serialization format:
// | n_nodes (4 bytes) | nodes (n_nodes * sizeof(uint64_t) | n_tensors (4 bytes) | tensors (n_tensors * sizeof(rpc_tensor)) | // | device (4 bytes) | n_nodes (4 bytes) | nodes (n_nodes * sizeof(uint64_t) | n_tensors (4 bytes) | tensors (n_tensors * sizeof(rpc_tensor)) |
if (input.size() < sizeof(uint32_t)) { if (input.size() < 2*sizeof(uint32_t)) {
return false;
}
const uint8_t * src = input.data();
uint32_t device;
memcpy(&device, src, sizeof(device));
src += sizeof(device);
if (device >= backends.size()) {
return false; return false;
} }
uint32_t n_nodes; uint32_t n_nodes;
memcpy(&n_nodes, input.data(), sizeof(n_nodes)); memcpy(&n_nodes, src, sizeof(n_nodes));
if (input.size() < sizeof(uint32_t) + n_nodes*sizeof(uint64_t) + sizeof(uint32_t)) { src += sizeof(n_nodes);
if (input.size() < 2*sizeof(uint32_t) + n_nodes*sizeof(uint64_t) + sizeof(uint32_t)) {
return false; return false;
} }
const uint64_t * nodes = (const uint64_t *)(input.data() + sizeof(n_nodes)); const uint64_t * nodes = (const uint64_t *)src;
src += n_nodes*sizeof(uint64_t);
uint32_t n_tensors; uint32_t n_tensors;
memcpy(&n_tensors, input.data() + sizeof(n_nodes) + n_nodes*sizeof(uint64_t), sizeof(n_tensors)); memcpy(&n_tensors, src, sizeof(n_tensors));
if (input.size() < sizeof(uint32_t) + n_nodes*sizeof(uint64_t) + sizeof(uint32_t) + n_tensors*sizeof(rpc_tensor)) { src += sizeof(n_tensors);
if (input.size() < 2*sizeof(uint32_t) + n_nodes*sizeof(uint64_t) + sizeof(uint32_t) + n_tensors*sizeof(rpc_tensor)) {
return false; return false;
} }
const rpc_tensor * tensors = (const rpc_tensor *)(input.data() + sizeof(n_nodes) + n_nodes*sizeof(uint64_t) + sizeof(n_tensors)); const rpc_tensor * tensors = (const rpc_tensor *)src;
LOG_DBG("[%s] n_nodes: %u, n_tensors: %u\n", __func__, n_nodes, n_tensors); LOG_DBG("[%s] device: %u, n_nodes: %u, n_tensors: %u\n", __func__, device, n_nodes, n_tensors);
size_t buf_size = ggml_tensor_overhead()*(n_nodes + n_tensors) + ggml_graph_overhead_custom(n_nodes, false); size_t buf_size = ggml_tensor_overhead()*(n_nodes + n_tensors) + ggml_graph_overhead_custom(n_nodes, false);
@@ -1380,7 +1446,7 @@ bool rpc_server::graph_compute(const std::vector<uint8_t> & input, rpc_msg_graph
return false; return false;
} }
} }
ggml_status status = ggml_backend_graph_compute(backend, graph); ggml_status status = ggml_backend_graph_compute(backends[device], graph);
response.result = status; response.result = status;
return true; return true;
} }
@@ -1391,9 +1457,9 @@ rpc_server::~rpc_server() {
} }
} }
static void rpc_serve_client(ggml_backend_t backend, const char * cache_dir, static void rpc_serve_client(const std::vector<ggml_backend_t> & backends, const char * cache_dir,
sockfd_t sockfd, size_t free_mem, size_t total_mem) { sockfd_t sockfd, const std::vector<size_t> & free_mem, const std::vector<size_t> & total_mem) {
rpc_server server(backend, cache_dir); rpc_server server(backends, cache_dir);
uint8_t cmd; uint8_t cmd;
if (!recv_data(sockfd, &cmd, 1)) { if (!recv_data(sockfd, &cmd, 1)) {
return; return;
@@ -1425,13 +1491,26 @@ static void rpc_serve_client(ggml_backend_t backend, const char * cache_dir,
// HELLO command is handled above // HELLO command is handled above
return; return;
} }
case RPC_CMD_DEVICE_COUNT: {
if (!recv_msg(sockfd, nullptr, 0)) {
return;
}
rpc_msg_device_count_rsp response;
response.device_count = backends.size();
if (!send_msg(sockfd, &response, sizeof(response))) {
return;
}
break;
}
case RPC_CMD_ALLOC_BUFFER: { case RPC_CMD_ALLOC_BUFFER: {
rpc_msg_alloc_buffer_req request; rpc_msg_alloc_buffer_req request;
if (!recv_msg(sockfd, &request, sizeof(request))) { if (!recv_msg(sockfd, &request, sizeof(request))) {
return; return;
} }
rpc_msg_alloc_buffer_rsp response; rpc_msg_alloc_buffer_rsp response;
server.alloc_buffer(request, response); if (!server.alloc_buffer(request, response)) {
return;
}
if (!send_msg(sockfd, &response, sizeof(response))) { if (!send_msg(sockfd, &response, sizeof(response))) {
return; return;
} }
@@ -1452,22 +1531,28 @@ static void rpc_serve_client(ggml_backend_t backend, const char * cache_dir,
break; break;
} }
case RPC_CMD_GET_ALIGNMENT: { case RPC_CMD_GET_ALIGNMENT: {
if (!recv_msg(sockfd, nullptr, 0)) { rpc_msg_get_alignment_req request;
if (!recv_msg(sockfd, &request, sizeof(request))) {
return; return;
} }
rpc_msg_get_alignment_rsp response; rpc_msg_get_alignment_rsp response;
server.get_alignment(response); if (!server.get_alignment(request, response)) {
return;
}
if (!send_msg(sockfd, &response, sizeof(response))) { if (!send_msg(sockfd, &response, sizeof(response))) {
return; return;
} }
break; break;
} }
case RPC_CMD_GET_MAX_SIZE: { case RPC_CMD_GET_MAX_SIZE: {
if (!recv_msg(sockfd, nullptr, 0)) { rpc_msg_get_max_size_req request;
if (!recv_msg(sockfd, &request, sizeof(request))) {
return; return;
} }
rpc_msg_get_max_size_rsp response; rpc_msg_get_max_size_rsp response;
server.get_max_size(response); if (!server.get_max_size(request, response)) {
return;
}
if (!send_msg(sockfd, &response, sizeof(response))) { if (!send_msg(sockfd, &response, sizeof(response))) {
return; return;
} }
@@ -1593,12 +1678,19 @@ static void rpc_serve_client(ggml_backend_t backend, const char * cache_dir,
break; break;
} }
case RPC_CMD_GET_DEVICE_MEMORY: { case RPC_CMD_GET_DEVICE_MEMORY: {
if (!recv_msg(sockfd, nullptr, 0)) { rpc_msg_get_device_memory_req request;
if (!recv_msg(sockfd, &request, sizeof(request))) {
return;
}
auto dev_id = request.device;
if (dev_id >= backends.size()) {
return; return;
} }
rpc_msg_get_device_memory_rsp response; rpc_msg_get_device_memory_rsp response;
response.free_mem = free_mem; response.free_mem = free_mem[dev_id];
response.total_mem = total_mem; response.total_mem = total_mem[dev_id];
LOG_DBG("[get_device_mem] device: %u, free_mem: %" PRIu64 ", total_mem: %" PRIu64 "\n", dev_id,
response.free_mem, response.total_mem);
if (!send_msg(sockfd, &response, sizeof(response))) { if (!send_msg(sockfd, &response, sizeof(response))) {
return; return;
} }
@@ -1612,16 +1704,41 @@ static void rpc_serve_client(ggml_backend_t backend, const char * cache_dir,
} }
} }
void ggml_backend_rpc_start_server(ggml_backend_t backend, const char * endpoint, void ggml_backend_rpc_start_server(const char * endpoint, const char * cache_dir,
const char * cache_dir, size_t n_threads, size_t n_devices,
size_t free_mem, size_t total_mem) { ggml_backend_dev_t * devices, size_t * free_mem, size_t * total_mem) {
if (n_devices == 0 || devices == nullptr || free_mem == nullptr || total_mem == nullptr) {
fprintf(stderr, "Invalid arguments to ggml_backend_rpc_start_server\n");
return;
}
std::vector<ggml_backend_t> backends;
std::vector<size_t> free_mem_vec(free_mem, free_mem + n_devices);
std::vector<size_t> total_mem_vec(total_mem, total_mem + n_devices);
printf("Starting RPC server v%d.%d.%d\n", printf("Starting RPC server v%d.%d.%d\n",
RPC_PROTO_MAJOR_VERSION, RPC_PROTO_MAJOR_VERSION,
RPC_PROTO_MINOR_VERSION, RPC_PROTO_MINOR_VERSION,
RPC_PROTO_PATCH_VERSION); RPC_PROTO_PATCH_VERSION);
printf(" endpoint : %s\n", endpoint); printf(" endpoint : %s\n", endpoint);
printf(" local cache : %s\n", cache_dir ? cache_dir : "n/a"); printf(" local cache : %s\n", cache_dir ? cache_dir : "n/a");
printf(" backend memory : %zu MB\n", free_mem / (1024 * 1024)); printf("Devices:\n");
for (size_t i = 0; i < n_devices; i++) {
auto dev = devices[i];
printf(" %s: %s (%zu MiB, %zu MiB free)\n", ggml_backend_dev_name(dev), ggml_backend_dev_description(dev),
total_mem[i] / 1024 / 1024, free_mem[i] / 1024 / 1024);
auto backend = ggml_backend_dev_init(dev, nullptr);
if (!backend) {
fprintf(stderr, "Failed to create backend for device %s\n", dev->iface.get_name(dev));
return;
}
backends.push_back(backend);
ggml_backend_reg_t reg = dev ? ggml_backend_dev_backend_reg(dev) : nullptr;
if (reg) {
auto ggml_backend_set_n_threads_fn = (ggml_backend_set_n_threads_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_n_threads");
if (ggml_backend_set_n_threads_fn) {
ggml_backend_set_n_threads_fn(backend, n_threads);
}
}
}
std::string host; std::string host;
int port; int port;
@@ -1649,22 +1766,27 @@ void ggml_backend_rpc_start_server(ggml_backend_t backend, const char * endpoint
fprintf(stderr, "Failed to accept client connection\n"); fprintf(stderr, "Failed to accept client connection\n");
return; return;
} }
printf("Accepted client connection, free_mem=%zu, total_mem=%zu\n", free_mem, total_mem); printf("Accepted client connection\n");
fflush(stdout); fflush(stdout);
rpc_serve_client(backend, cache_dir, client_socket->fd, free_mem, total_mem); rpc_serve_client(backends, cache_dir, client_socket->fd, free_mem_vec, total_mem_vec);
printf("Client connection closed\n"); printf("Client connection closed\n");
fflush(stdout); fflush(stdout);
} }
#ifdef _WIN32 #ifdef _WIN32
WSACleanup(); WSACleanup();
#endif #endif
for (auto backend : backends) {
ggml_backend_free(backend);
}
} }
// device interface // device interface
struct ggml_backend_rpc_device_context { struct ggml_backend_rpc_device_context {
std::string endpoint; std::string endpoint;
uint32_t device;
std::string name; std::string name;
std::string description;
}; };
static const char * ggml_backend_rpc_device_get_name(ggml_backend_dev_t dev) { static const char * ggml_backend_rpc_device_get_name(ggml_backend_dev_t dev) {
@@ -1676,15 +1798,13 @@ static const char * ggml_backend_rpc_device_get_name(ggml_backend_dev_t dev) {
static const char * ggml_backend_rpc_device_get_description(ggml_backend_dev_t dev) { static const char * ggml_backend_rpc_device_get_description(ggml_backend_dev_t dev) {
ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context; ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context;
return ctx->name.c_str(); return ctx->description.c_str();
} }
static void ggml_backend_rpc_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) { static void ggml_backend_rpc_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context; ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context;
ggml_backend_rpc_get_device_memory(ctx->endpoint.c_str(), free, total); ggml_backend_rpc_get_device_memory(ctx->endpoint.c_str(), ctx->device, free, total);
GGML_UNUSED(dev);
} }
static enum ggml_backend_dev_type ggml_backend_rpc_device_get_type(ggml_backend_dev_t dev) { static enum ggml_backend_dev_type ggml_backend_rpc_device_get_type(ggml_backend_dev_t dev) {
@@ -1710,7 +1830,7 @@ static void ggml_backend_rpc_device_get_props(ggml_backend_dev_t dev, struct ggm
static ggml_backend_t ggml_backend_rpc_device_init(ggml_backend_dev_t dev, const char * params) { static ggml_backend_t ggml_backend_rpc_device_init(ggml_backend_dev_t dev, const char * params) {
ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context; ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context;
return ggml_backend_rpc_init(ctx->endpoint.c_str()); return ggml_backend_rpc_init(ctx->endpoint.c_str(), ctx->device);
GGML_UNUSED(params); GGML_UNUSED(params);
} }
@@ -1718,7 +1838,7 @@ static ggml_backend_t ggml_backend_rpc_device_init(ggml_backend_dev_t dev, const
static ggml_backend_buffer_type_t ggml_backend_rpc_device_get_buffer_type(ggml_backend_dev_t dev) { static ggml_backend_buffer_type_t ggml_backend_rpc_device_get_buffer_type(ggml_backend_dev_t dev) {
ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context; ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context;
return ggml_backend_rpc_buffer_type(ctx->endpoint.c_str()); return ggml_backend_rpc_buffer_type(ctx->endpoint.c_str(), ctx->device);
GGML_UNUSED(dev); GGML_UNUSED(dev);
} }
@@ -1736,7 +1856,7 @@ static bool ggml_backend_rpc_device_supports_buft(ggml_backend_dev_t dev, ggml_b
} }
ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context; ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context;
ggml_backend_rpc_device_context * dev_ctx = (ggml_backend_rpc_device_context *)dev->context; ggml_backend_rpc_device_context * dev_ctx = (ggml_backend_rpc_device_context *)dev->context;
return buft_ctx->endpoint == dev_ctx->endpoint; return buft_ctx->endpoint == dev_ctx->endpoint && buft_ctx->device == dev_ctx->device;
} }
static const struct ggml_backend_device_i ggml_backend_rpc_device_i = { static const struct ggml_backend_device_i ggml_backend_rpc_device_i = {
@@ -1759,28 +1879,34 @@ static const struct ggml_backend_device_i ggml_backend_rpc_device_i = {
// backend reg interface // backend reg interface
static const char * ggml_backend_rpc_reg_get_name(ggml_backend_reg_t reg) { struct ggml_backend_rpc_reg_context {
return "RPC"; std::string name;
std::vector<ggml_backend_dev_t> devices;
};
GGML_UNUSED(reg); static const char * ggml_backend_rpc_reg_get_name(ggml_backend_reg_t reg) {
ggml_backend_rpc_reg_context * ctx = (ggml_backend_rpc_reg_context *)reg->context;
return ctx ? ctx->name.c_str() : "RPC";
} }
static size_t ggml_backend_rpc_reg_get_device_count(ggml_backend_reg_t reg) { static size_t ggml_backend_rpc_reg_get_device_count(ggml_backend_reg_t reg) {
return 0; ggml_backend_rpc_reg_context * ctx = (ggml_backend_rpc_reg_context *)reg->context;
return ctx ? ctx->devices.size() : 0;
GGML_UNUSED(reg);
} }
static ggml_backend_dev_t ggml_backend_rpc_reg_get_device(ggml_backend_reg_t reg, size_t index) { static ggml_backend_dev_t ggml_backend_rpc_reg_get_device(ggml_backend_reg_t reg, size_t index) {
GGML_ABORT("The RPC backend does not have enumerated devices - use ggml_backend_add_device instead"); ggml_backend_rpc_reg_context * ctx = (ggml_backend_rpc_reg_context *)reg->context;
if (ctx == nullptr) {
GGML_UNUSED(reg); GGML_ABORT("The RPC backend does not have enumerated devices - use ggml_backend_rpc_add_server instead");
GGML_UNUSED(index); } else {
GGML_ASSERT(index < ctx->devices.size());
return ctx->devices[index];
}
} }
static void * ggml_backend_rpc_get_proc_address(ggml_backend_reg_t reg, const char * name) { static void * ggml_backend_rpc_get_proc_address(ggml_backend_reg_t reg, const char * name) {
if (std::strcmp(name, "ggml_backend_rpc_add_device") == 0) { if (std::strcmp(name, "ggml_backend_rpc_add_server") == 0) {
return (void *)ggml_backend_rpc_add_device; return (void *)ggml_backend_rpc_add_server;
} }
if (std::strcmp(name, "ggml_backend_rpc_start_server") == 0) { if (std::strcmp(name, "ggml_backend_rpc_start_server") == 0) {
return (void *)ggml_backend_rpc_start_server; return (void *)ggml_backend_rpc_start_server;
@@ -1807,30 +1933,61 @@ ggml_backend_reg_t ggml_backend_rpc_reg(void) {
return &ggml_backend_rpc_reg; return &ggml_backend_rpc_reg;
} }
ggml_backend_dev_t ggml_backend_rpc_add_device(const char * endpoint) { static uint32_t ggml_backend_rpc_get_device_count(const char * endpoint) {
static std::unordered_map<std::string, ggml_backend_dev_t> dev_map; auto sock = get_socket(endpoint);
rpc_msg_device_count_rsp response;
bool status = send_rpc_cmd(sock, RPC_CMD_DEVICE_COUNT, nullptr, 0, &response, sizeof(response));
RPC_STATUS_ASSERT(status);
return response.device_count;
}
static const ggml_backend_reg_i ggml_backend_rpc_reg_interface = {
/* .get_name = */ ggml_backend_rpc_reg_get_name,
/* .get_device_count = */ ggml_backend_rpc_reg_get_device_count,
/* .get_device = */ ggml_backend_rpc_reg_get_device,
/* .get_proc_address = */ ggml_backend_rpc_get_proc_address,
};
ggml_backend_reg_t ggml_backend_rpc_add_server(const char * endpoint) {
static std::unordered_map<std::string, ggml_backend_reg_t> reg_map;
static std::mutex mutex; static std::mutex mutex;
static uint32_t dev_id = 0;
std::lock_guard<std::mutex> lock(mutex); std::lock_guard<std::mutex> lock(mutex);
if (reg_map.find(endpoint) != reg_map.end()) {
if (dev_map.find(endpoint) != dev_map.end()) { return reg_map[endpoint];
return dev_map[endpoint];
} }
uint32_t dev_count = ggml_backend_rpc_get_device_count(endpoint);
ggml_backend_rpc_device_context * ctx = new ggml_backend_rpc_device_context { if (dev_count == 0) {
return nullptr;
}
ggml_backend_rpc_reg_context * ctx = new ggml_backend_rpc_reg_context;
ctx->name = "RPC[" + std::string(endpoint) + "]";
for (uint32_t ind = 0; ind < dev_count; ind++) {
std::string dev_name = "RPC" + std::to_string(dev_id);
std::string dev_desc = std::string(endpoint);
ggml_backend_rpc_device_context * dev_ctx = new ggml_backend_rpc_device_context {
/* .endpoint = */ endpoint, /* .endpoint = */ endpoint,
/* .name = */ "RPC[" + std::string(endpoint) + "]", /* .device = */ ind,
/* .name = */ dev_name,
/* .description = */ dev_desc
}; };
ggml_backend_dev_t dev = new ggml_backend_device { ggml_backend_dev_t dev = new ggml_backend_device {
/* .iface = */ ggml_backend_rpc_device_i, /* .iface = */ ggml_backend_rpc_device_i,
/* .reg = */ ggml_backend_rpc_reg(), /* .reg = */ ggml_backend_rpc_reg(),
/* .context = */ ctx, /* .context = */ dev_ctx,
}; };
ctx->devices.push_back(dev);
dev_map[endpoint] = dev; dev_id++;
}
return dev; ggml_backend_reg_t reg = new ggml_backend_reg {
/* .api_version = */ GGML_BACKEND_API_VERSION,
/* .iface = */ ggml_backend_rpc_reg_interface,
/* .context = */ ctx
};
reg_map[endpoint] = reg;
return reg;
} }
GGML_BACKEND_DL_IMPL(ggml_backend_rpc_reg) GGML_BACKEND_DL_IMPL(ggml_backend_rpc_reg)

View File

@@ -168,7 +168,7 @@ static std::vector<ggml_backend_dev_t> parse_devices_arg(const std::string & val
return devices; return devices;
} }
static std::vector<ggml_backend_dev_t> register_rpc_device_list(const std::string & servers) { static void register_rpc_server_list(const std::string & servers) {
auto rpc_servers = string_split<std::string>(servers, ','); auto rpc_servers = string_split<std::string>(servers, ',');
if (rpc_servers.empty()) { if (rpc_servers.empty()) {
throw std::invalid_argument("no RPC servers specified"); throw std::invalid_argument("no RPC servers specified");
@@ -179,36 +179,15 @@ static std::vector<ggml_backend_dev_t> register_rpc_device_list(const std::strin
throw std::invalid_argument("failed to find RPC backend"); throw std::invalid_argument("failed to find RPC backend");
} }
using add_rpc_device_fn = ggml_backend_dev_t (*)(const char * endpoint); using add_rpc_server_fn = ggml_backend_reg_t (*)(const char * endpoint);
auto * ggml_backend_rpc_add_device_fn = (add_rpc_device_fn) ggml_backend_reg_get_proc_address(rpc_reg, "ggml_backend_rpc_add_device"); auto * ggml_backend_rpc_add_server_fn = (add_rpc_server_fn) ggml_backend_reg_get_proc_address(rpc_reg, "ggml_backend_rpc_add_server");
if (!ggml_backend_rpc_add_device_fn) { if (!ggml_backend_rpc_add_server_fn) {
throw std::invalid_argument("failed to find RPC device add function"); throw std::invalid_argument("failed to find RPC add server function");
} }
static std::unordered_set<std::string> registered;
std::vector<ggml_backend_dev_t> devices;
for (const auto & server : rpc_servers) { for (const auto & server : rpc_servers) {
ggml_backend_dev_t dev = nullptr; auto reg = ggml_backend_rpc_add_server_fn(server.c_str());
ggml_backend_register(reg);
std::string name = string_format("RPC[%s]", server.c_str());
if (registered.find(server) != registered.end()) {
dev = ggml_backend_dev_by_name(name.c_str());
} }
if (!dev) {
dev = ggml_backend_rpc_add_device_fn(server.c_str());
if (!dev) {
throw std::invalid_argument(string_format("failed to add RPC device for server '%s'", server.c_str()));
}
ggml_backend_device_register(dev);
registered.insert(server);
}
devices.push_back(dev);
}
return devices;
} }
static std::string devices_to_string(const std::vector<ggml_backend_dev_t> & devices) { static std::string devices_to_string(const std::vector<ggml_backend_dev_t> & devices) {
@@ -714,7 +693,7 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
break; break;
} }
try { try {
register_rpc_device_list(argv[i]); register_rpc_server_list(argv[i]);
} catch (const std::exception & e) { } catch (const std::exception & e) {
fprintf(stderr, "error: %s\n", e.what()); fprintf(stderr, "error: %s\n", e.what());
invalid_param = true; invalid_param = true;
@@ -1368,13 +1347,23 @@ struct test {
static std::string get_backend() { static std::string get_backend() {
std::vector<std::string> backends; std::vector<std::string> backends;
bool rpc_used = false;
for (size_t i = 0; i < ggml_backend_reg_count(); i++) { for (size_t i = 0; i < ggml_backend_reg_count(); i++) {
auto * reg = ggml_backend_reg_get(i); auto * reg = ggml_backend_reg_get(i);
std::string name = ggml_backend_reg_name(reg); std::string name = ggml_backend_reg_name(reg);
if (string_starts_with(name, "RPC")) {
if (ggml_backend_reg_dev_count(reg) > 0) {
rpc_used = true;
}
} else {
if (name != "CPU") { if (name != "CPU") {
backends.push_back(ggml_backend_reg_name(reg)); backends.push_back(ggml_backend_reg_name(reg));
} }
} }
}
if (rpc_used) {
backends.push_back("RPC");
}
return backends.empty() ? "CPU" : join(backends, ","); return backends.empty() ? "CPU" : join(backends, ",");
} }

View File

@@ -22,6 +22,7 @@
#include <filesystem> #include <filesystem>
#include <algorithm> #include <algorithm>
#include <thread> #include <thread>
#include <regex>
namespace fs = std::filesystem; namespace fs = std::filesystem;
@@ -133,21 +134,21 @@ static std::string fs_get_cache_directory() {
struct rpc_server_params { struct rpc_server_params {
std::string host = "127.0.0.1"; std::string host = "127.0.0.1";
int port = 50052; int port = 50052;
size_t backend_mem = 0;
bool use_cache = false; bool use_cache = false;
int n_threads = std::max(1U, std::thread::hardware_concurrency()/2); int n_threads = std::max(1U, std::thread::hardware_concurrency()/2);
std::string device; std::vector<std::string> devices;
std::vector<size_t> dev_mem;
}; };
static void print_usage(int /*argc*/, char ** argv, rpc_server_params params) { static void print_usage(int /*argc*/, char ** argv, rpc_server_params params) {
fprintf(stderr, "Usage: %s [options]\n\n", argv[0]); fprintf(stderr, "Usage: %s [options]\n\n", argv[0]);
fprintf(stderr, "options:\n"); fprintf(stderr, "options:\n");
fprintf(stderr, " -h, --help show this help message and exit\n"); fprintf(stderr, " -h, --help show this help message and exit\n");
fprintf(stderr, " -t, --threads number of threads for the CPU backend (default: %d)\n", params.n_threads); fprintf(stderr, " -t, --threads N number of threads for the CPU device (default: %d)\n", params.n_threads);
fprintf(stderr, " -d DEV, --device device to use\n"); fprintf(stderr, " -d, --device <dev1,dev2,...> comma-separated list of devices\n");
fprintf(stderr, " -H HOST, --host HOST host to bind to (default: %s)\n", params.host.c_str()); fprintf(stderr, " -H, --host HOST host to bind to (default: %s)\n", params.host.c_str());
fprintf(stderr, " -p PORT, --port PORT port to bind to (default: %d)\n", params.port); fprintf(stderr, " -p, --port PORT port to bind to (default: %d)\n", params.port);
fprintf(stderr, " -m MEM, --mem MEM backend memory size (in MB)\n"); fprintf(stderr, " -m, --mem <M1,M2,...> memory size for each device (in MB)\n");
fprintf(stderr, " -c, --cache enable local file cache\n"); fprintf(stderr, " -c, --cache enable local file cache\n");
fprintf(stderr, "\n"); fprintf(stderr, "\n");
} }
@@ -174,18 +175,18 @@ static bool rpc_server_params_parse(int argc, char ** argv, rpc_server_params &
if (++i >= argc) { if (++i >= argc) {
return false; return false;
} }
params.device = argv[i]; const std::regex regex{ R"([,/]+)" };
if (ggml_backend_dev_by_name(params.device.c_str()) == nullptr) { std::string dev_str = argv[i];
fprintf(stderr, "error: unknown device: %s\n", params.device.c_str()); std::sregex_token_iterator iter(dev_str.begin(), dev_str.end(), regex, -1);
fprintf(stderr, "available devices:\n"); std::sregex_token_iterator end;
for (size_t i = 0; i < ggml_backend_dev_count(); i++) { for ( ; iter != end; ++iter) {
auto * dev = ggml_backend_dev_get(i); try {
size_t free, total; params.devices.push_back(*iter);
ggml_backend_dev_memory(dev, &free, &total); } catch (const std::exception & ) {
printf(" %s: %s (%zu MiB, %zu MiB free)\n", ggml_backend_dev_name(dev), ggml_backend_dev_description(dev), total / 1024 / 1024, free / 1024 / 1024); fprintf(stderr, "error: invalid device: %s\n", iter->str().c_str());
}
return false; return false;
} }
}
} else if (arg == "-p" || arg == "--port") { } else if (arg == "-p" || arg == "--port") {
if (++i >= argc) { if (++i >= argc) {
return false; return false;
@@ -200,7 +201,19 @@ static bool rpc_server_params_parse(int argc, char ** argv, rpc_server_params &
if (++i >= argc) { if (++i >= argc) {
return false; return false;
} }
params.backend_mem = std::stoul(argv[i]) * 1024 * 1024; const std::regex regex{ R"([,/]+)" };
std::string mem_str = argv[i];
std::sregex_token_iterator iter(mem_str.begin(), mem_str.end(), regex, -1);
std::sregex_token_iterator end;
for ( ; iter != end; ++iter) {
try {
size_t mem = std::stoul(*iter) * 1024 * 1024;
params.dev_mem.push_back(mem);
} catch (const std::exception & ) {
fprintf(stderr, "error: invalid memory size: %s\n", iter->str().c_str());
return false;
}
}
} else if (arg == "-h" || arg == "--help") { } else if (arg == "-h" || arg == "--help") {
print_usage(argc, argv, params); print_usage(argc, argv, params);
exit(0); exit(0);
@@ -213,45 +226,46 @@ static bool rpc_server_params_parse(int argc, char ** argv, rpc_server_params &
return true; return true;
} }
static ggml_backend_t create_backend(const rpc_server_params & params) { static std::vector<ggml_backend_dev_t> get_devices(const rpc_server_params & params) {
ggml_backend_t backend = nullptr; std::vector<ggml_backend_dev_t> devices;
if (!params.devices.empty()) {
if (!params.device.empty()) { for (auto device : params.devices) {
ggml_backend_dev_t dev = ggml_backend_dev_by_name(params.device.c_str()); ggml_backend_dev_t dev = ggml_backend_dev_by_name(device.c_str());
if (dev) { if (dev) {
backend = ggml_backend_dev_init(dev, nullptr); devices.push_back(dev);
if (!backend) { } else {
fprintf(stderr, "Failed to create backend for device %s\n", params.device.c_str()); fprintf(stderr, "error: unknown device: %s\n", device.c_str());
return nullptr; fprintf(stderr, "available devices:\n");
for (size_t i = 0; i < ggml_backend_dev_count(); i++) {
auto * dev = ggml_backend_dev_get(i);
size_t free, total;
ggml_backend_dev_memory(dev, &free, &total);
printf(" %s: %s (%zu MiB, %zu MiB free)\n", ggml_backend_dev_name(dev), ggml_backend_dev_description(dev), total / 1024 / 1024, free / 1024 / 1024);
}
return {};
} }
} }
} }
if (!backend) { // Try non-CPU devices first
backend = ggml_backend_init_best(); if (devices.empty()) {
} for (size_t i = 0; i < ggml_backend_dev_count(); i++) {
ggml_backend_dev_t dev = ggml_backend_dev_get(i);
if (backend) { if (ggml_backend_dev_type(dev) != GGML_BACKEND_DEVICE_TYPE_CPU) {
fprintf(stderr, "%s: using %s backend\n", __func__, ggml_backend_name(backend)); devices.push_back(dev);
// set the number of threads
ggml_backend_dev_t dev = ggml_backend_get_device(backend);
ggml_backend_reg_t reg = dev ? ggml_backend_dev_backend_reg(dev) : nullptr;
if (reg) {
auto ggml_backend_set_n_threads_fn = (ggml_backend_set_n_threads_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_n_threads");
if (ggml_backend_set_n_threads_fn) {
ggml_backend_set_n_threads_fn(backend, params.n_threads);
} }
} }
} }
return backend; // If there are no accelerators, fallback to CPU device
} if (devices.empty()) {
ggml_backend_dev_t dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
if (dev) {
devices.push_back(dev);
}
}
static void get_backend_memory(ggml_backend_t backend, size_t * free_mem, size_t * total_mem) { return devices;
ggml_backend_dev_t dev = ggml_backend_get_device(backend);
GGML_ASSERT(dev != nullptr);
ggml_backend_dev_memory(dev, free_mem, total_mem);
} }
int main(int argc, char * argv[]) { int main(int argc, char * argv[]) {
@@ -273,18 +287,23 @@ int main(int argc, char * argv[]) {
fprintf(stderr, "\n"); fprintf(stderr, "\n");
} }
ggml_backend_t backend = create_backend(params); auto devices = get_devices(params);
if (!backend) { if (devices.empty()) {
fprintf(stderr, "Failed to create backend\n"); fprintf(stderr, "No devices found\n");
return 1; return 1;
} }
std::string endpoint = params.host + ":" + std::to_string(params.port); std::string endpoint = params.host + ":" + std::to_string(params.port);
size_t free_mem, total_mem; std::vector<size_t> free_mem, total_mem;
if (params.backend_mem > 0) { for (size_t i = 0; i < devices.size(); i++) {
free_mem = params.backend_mem; if (i < params.dev_mem.size()) {
total_mem = params.backend_mem; free_mem.push_back(params.dev_mem[i]);
total_mem.push_back(params.dev_mem[i]);
} else { } else {
get_backend_memory(backend, &free_mem, &total_mem); size_t free, total;
ggml_backend_dev_memory(devices[i], &free, &total);
free_mem.push_back(free);
total_mem.push_back(total);
}
} }
const char * cache_dir = nullptr; const char * cache_dir = nullptr;
std::string cache_dir_str; std::string cache_dir_str;
@@ -309,8 +328,7 @@ int main(int argc, char * argv[]) {
return 1; return 1;
} }
start_server_fn(backend, endpoint.c_str(), cache_dir, free_mem, total_mem); start_server_fn(endpoint.c_str(), cache_dir, params.n_threads, devices.size(),
devices.data(), free_mem.data(), total_mem.data());
ggml_backend_free(backend);
return 0; return 0;
} }