mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-10-29 08:41:22 +00:00
rpc : report actual free memory (#16616)
* rpc : report actual free memory Start reporting the free memory on every device instead of using fixed values. Now llama-cli users can get a nice memory breakdown when using RPC devices. * drop --mem in rpc-server
This commit is contained in:
committed by
GitHub
parent
3d4e86bbeb
commit
41386cf365
@@ -21,8 +21,7 @@ GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const c
|
|||||||
GGML_BACKEND_API void ggml_backend_rpc_get_device_memory(const char * endpoint, uint32_t device, size_t * free, size_t * total);
|
GGML_BACKEND_API void ggml_backend_rpc_get_device_memory(const char * endpoint, uint32_t device, size_t * free, size_t * total);
|
||||||
|
|
||||||
GGML_BACKEND_API void ggml_backend_rpc_start_server(const char * endpoint, const char * cache_dir,
|
GGML_BACKEND_API void ggml_backend_rpc_start_server(const char * endpoint, const char * cache_dir,
|
||||||
size_t n_threads, size_t n_devices,
|
size_t n_threads, size_t n_devices, ggml_backend_dev_t * devices);
|
||||||
ggml_backend_dev_t * devices, size_t * free_mem, size_t * total_mem);
|
|
||||||
|
|
||||||
GGML_BACKEND_API ggml_backend_reg_t ggml_backend_rpc_reg(void);
|
GGML_BACKEND_API ggml_backend_reg_t ggml_backend_rpc_reg(void);
|
||||||
GGML_BACKEND_API ggml_backend_reg_t ggml_backend_rpc_add_server(const char * endpoint);
|
GGML_BACKEND_API ggml_backend_reg_t ggml_backend_rpc_add_server(const char * endpoint);
|
||||||
|
|||||||
@@ -939,6 +939,7 @@ public:
|
|||||||
bool graph_compute(const std::vector<uint8_t> & input, rpc_msg_graph_compute_rsp & response);
|
bool graph_compute(const std::vector<uint8_t> & input, rpc_msg_graph_compute_rsp & response);
|
||||||
bool init_tensor(const rpc_msg_init_tensor_req & request);
|
bool init_tensor(const rpc_msg_init_tensor_req & request);
|
||||||
bool get_alloc_size(const rpc_msg_get_alloc_size_req & request, rpc_msg_get_alloc_size_rsp & response);
|
bool get_alloc_size(const rpc_msg_get_alloc_size_req & request, rpc_msg_get_alloc_size_rsp & response);
|
||||||
|
bool get_device_memory(const rpc_msg_get_device_memory_req & request, rpc_msg_get_device_memory_rsp & response);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
bool get_cached_file(uint64_t hash, std::vector<uint8_t> & data);
|
bool get_cached_file(uint64_t hash, std::vector<uint8_t> & data);
|
||||||
@@ -1458,6 +1459,20 @@ bool rpc_server::graph_compute(const std::vector<uint8_t> & input, rpc_msg_graph
|
|||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool rpc_server::get_device_memory(const rpc_msg_get_device_memory_req & request, rpc_msg_get_device_memory_rsp & response) {
|
||||||
|
uint32_t dev_id = request.device;
|
||||||
|
if (dev_id >= backends.size()) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
size_t free, total;
|
||||||
|
ggml_backend_dev_t dev = ggml_backend_get_device(backends[dev_id]);
|
||||||
|
ggml_backend_dev_memory(dev, &free, &total);
|
||||||
|
response.free_mem = free;
|
||||||
|
response.total_mem = total;
|
||||||
|
LOG_DBG("[%s] device: %u, free_mem: %" PRIu64 ", total_mem: %" PRIu64 "\n", __func__, dev_id, response.free_mem, response.total_mem);
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
rpc_server::~rpc_server() {
|
rpc_server::~rpc_server() {
|
||||||
for (auto buffer : buffers) {
|
for (auto buffer : buffers) {
|
||||||
ggml_backend_buffer_free(buffer);
|
ggml_backend_buffer_free(buffer);
|
||||||
@@ -1465,7 +1480,7 @@ rpc_server::~rpc_server() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
static void rpc_serve_client(const std::vector<ggml_backend_t> & backends, const char * cache_dir,
|
static void rpc_serve_client(const std::vector<ggml_backend_t> & backends, const char * cache_dir,
|
||||||
sockfd_t sockfd, const std::vector<size_t> & free_mem, const std::vector<size_t> & total_mem) {
|
sockfd_t sockfd) {
|
||||||
rpc_server server(backends, cache_dir);
|
rpc_server server(backends, cache_dir);
|
||||||
uint8_t cmd;
|
uint8_t cmd;
|
||||||
if (!recv_data(sockfd, &cmd, 1)) {
|
if (!recv_data(sockfd, &cmd, 1)) {
|
||||||
@@ -1689,15 +1704,10 @@ static void rpc_serve_client(const std::vector<ggml_backend_t> & backends, const
|
|||||||
if (!recv_msg(sockfd, &request, sizeof(request))) {
|
if (!recv_msg(sockfd, &request, sizeof(request))) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
auto dev_id = request.device;
|
rpc_msg_get_device_memory_rsp response;
|
||||||
if (dev_id >= backends.size()) {
|
if (!server.get_device_memory(request, response)) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
rpc_msg_get_device_memory_rsp response;
|
|
||||||
response.free_mem = free_mem[dev_id];
|
|
||||||
response.total_mem = total_mem[dev_id];
|
|
||||||
LOG_DBG("[get_device_mem] device: %u, free_mem: %" PRIu64 ", total_mem: %" PRIu64 "\n", dev_id,
|
|
||||||
response.free_mem, response.total_mem);
|
|
||||||
if (!send_msg(sockfd, &response, sizeof(response))) {
|
if (!send_msg(sockfd, &response, sizeof(response))) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
@@ -1712,15 +1722,12 @@ static void rpc_serve_client(const std::vector<ggml_backend_t> & backends, const
|
|||||||
}
|
}
|
||||||
|
|
||||||
void ggml_backend_rpc_start_server(const char * endpoint, const char * cache_dir,
|
void ggml_backend_rpc_start_server(const char * endpoint, const char * cache_dir,
|
||||||
size_t n_threads, size_t n_devices,
|
size_t n_threads, size_t n_devices, ggml_backend_dev_t * devices) {
|
||||||
ggml_backend_dev_t * devices, size_t * free_mem, size_t * total_mem) {
|
if (n_devices == 0 || devices == nullptr) {
|
||||||
if (n_devices == 0 || devices == nullptr || free_mem == nullptr || total_mem == nullptr) {
|
|
||||||
fprintf(stderr, "Invalid arguments to ggml_backend_rpc_start_server\n");
|
fprintf(stderr, "Invalid arguments to ggml_backend_rpc_start_server\n");
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
std::vector<ggml_backend_t> backends;
|
std::vector<ggml_backend_t> backends;
|
||||||
std::vector<size_t> free_mem_vec(free_mem, free_mem + n_devices);
|
|
||||||
std::vector<size_t> total_mem_vec(total_mem, total_mem + n_devices);
|
|
||||||
printf("Starting RPC server v%d.%d.%d\n",
|
printf("Starting RPC server v%d.%d.%d\n",
|
||||||
RPC_PROTO_MAJOR_VERSION,
|
RPC_PROTO_MAJOR_VERSION,
|
||||||
RPC_PROTO_MINOR_VERSION,
|
RPC_PROTO_MINOR_VERSION,
|
||||||
@@ -1730,8 +1737,10 @@ void ggml_backend_rpc_start_server(const char * endpoint, const char * cache_dir
|
|||||||
printf("Devices:\n");
|
printf("Devices:\n");
|
||||||
for (size_t i = 0; i < n_devices; i++) {
|
for (size_t i = 0; i < n_devices; i++) {
|
||||||
auto dev = devices[i];
|
auto dev = devices[i];
|
||||||
|
size_t free, total;
|
||||||
|
ggml_backend_dev_memory(dev, &free, &total);
|
||||||
printf(" %s: %s (%zu MiB, %zu MiB free)\n", ggml_backend_dev_name(dev), ggml_backend_dev_description(dev),
|
printf(" %s: %s (%zu MiB, %zu MiB free)\n", ggml_backend_dev_name(dev), ggml_backend_dev_description(dev),
|
||||||
total_mem[i] / 1024 / 1024, free_mem[i] / 1024 / 1024);
|
total / 1024 / 1024, free / 1024 / 1024);
|
||||||
auto backend = ggml_backend_dev_init(dev, nullptr);
|
auto backend = ggml_backend_dev_init(dev, nullptr);
|
||||||
if (!backend) {
|
if (!backend) {
|
||||||
fprintf(stderr, "Failed to create backend for device %s\n", dev->iface.get_name(dev));
|
fprintf(stderr, "Failed to create backend for device %s\n", dev->iface.get_name(dev));
|
||||||
@@ -1775,7 +1784,7 @@ void ggml_backend_rpc_start_server(const char * endpoint, const char * cache_dir
|
|||||||
}
|
}
|
||||||
printf("Accepted client connection\n");
|
printf("Accepted client connection\n");
|
||||||
fflush(stdout);
|
fflush(stdout);
|
||||||
rpc_serve_client(backends, cache_dir, client_socket->fd, free_mem_vec, total_mem_vec);
|
rpc_serve_client(backends, cache_dir, client_socket->fd);
|
||||||
printf("Client connection closed\n");
|
printf("Client connection closed\n");
|
||||||
fflush(stdout);
|
fflush(stdout);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -137,7 +137,6 @@ struct rpc_server_params {
|
|||||||
bool use_cache = false;
|
bool use_cache = false;
|
||||||
int n_threads = std::max(1U, std::thread::hardware_concurrency()/2);
|
int n_threads = std::max(1U, std::thread::hardware_concurrency()/2);
|
||||||
std::vector<std::string> devices;
|
std::vector<std::string> devices;
|
||||||
std::vector<size_t> dev_mem;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
static void print_usage(int /*argc*/, char ** argv, rpc_server_params params) {
|
static void print_usage(int /*argc*/, char ** argv, rpc_server_params params) {
|
||||||
@@ -148,7 +147,6 @@ static void print_usage(int /*argc*/, char ** argv, rpc_server_params params) {
|
|||||||
fprintf(stderr, " -d, --device <dev1,dev2,...> comma-separated list of devices\n");
|
fprintf(stderr, " -d, --device <dev1,dev2,...> comma-separated list of devices\n");
|
||||||
fprintf(stderr, " -H, --host HOST host to bind to (default: %s)\n", params.host.c_str());
|
fprintf(stderr, " -H, --host HOST host to bind to (default: %s)\n", params.host.c_str());
|
||||||
fprintf(stderr, " -p, --port PORT port to bind to (default: %d)\n", params.port);
|
fprintf(stderr, " -p, --port PORT port to bind to (default: %d)\n", params.port);
|
||||||
fprintf(stderr, " -m, --mem <M1,M2,...> memory size for each device (in MB)\n");
|
|
||||||
fprintf(stderr, " -c, --cache enable local file cache\n");
|
fprintf(stderr, " -c, --cache enable local file cache\n");
|
||||||
fprintf(stderr, "\n");
|
fprintf(stderr, "\n");
|
||||||
}
|
}
|
||||||
@@ -197,23 +195,6 @@ static bool rpc_server_params_parse(int argc, char ** argv, rpc_server_params &
|
|||||||
}
|
}
|
||||||
} else if (arg == "-c" || arg == "--cache") {
|
} else if (arg == "-c" || arg == "--cache") {
|
||||||
params.use_cache = true;
|
params.use_cache = true;
|
||||||
} else if (arg == "-m" || arg == "--mem") {
|
|
||||||
if (++i >= argc) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
const std::regex regex{ R"([,/]+)" };
|
|
||||||
std::string mem_str = argv[i];
|
|
||||||
std::sregex_token_iterator iter(mem_str.begin(), mem_str.end(), regex, -1);
|
|
||||||
std::sregex_token_iterator end;
|
|
||||||
for ( ; iter != end; ++iter) {
|
|
||||||
try {
|
|
||||||
size_t mem = std::stoul(*iter) * 1024 * 1024;
|
|
||||||
params.dev_mem.push_back(mem);
|
|
||||||
} catch (const std::exception & ) {
|
|
||||||
fprintf(stderr, "error: invalid memory size: %s\n", iter->str().c_str());
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else if (arg == "-h" || arg == "--help") {
|
} else if (arg == "-h" || arg == "--help") {
|
||||||
print_usage(argc, argv, params);
|
print_usage(argc, argv, params);
|
||||||
exit(0);
|
exit(0);
|
||||||
@@ -293,18 +274,6 @@ int main(int argc, char * argv[]) {
|
|||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
std::string endpoint = params.host + ":" + std::to_string(params.port);
|
std::string endpoint = params.host + ":" + std::to_string(params.port);
|
||||||
std::vector<size_t> free_mem, total_mem;
|
|
||||||
for (size_t i = 0; i < devices.size(); i++) {
|
|
||||||
if (i < params.dev_mem.size()) {
|
|
||||||
free_mem.push_back(params.dev_mem[i]);
|
|
||||||
total_mem.push_back(params.dev_mem[i]);
|
|
||||||
} else {
|
|
||||||
size_t free, total;
|
|
||||||
ggml_backend_dev_memory(devices[i], &free, &total);
|
|
||||||
free_mem.push_back(free);
|
|
||||||
total_mem.push_back(total);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
const char * cache_dir = nullptr;
|
const char * cache_dir = nullptr;
|
||||||
std::string cache_dir_str;
|
std::string cache_dir_str;
|
||||||
if (params.use_cache) {
|
if (params.use_cache) {
|
||||||
@@ -328,7 +297,6 @@ int main(int argc, char * argv[]) {
|
|||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
start_server_fn(endpoint.c_str(), cache_dir, params.n_threads, devices.size(),
|
start_server_fn(endpoint.c_str(), cache_dir, params.n_threads, devices.size(), devices.data());
|
||||||
devices.data(), free_mem.data(), total_mem.data());
|
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user