mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-11-04 09:32:00 +00:00 
			
		
		
		
	rpc : add support for multiple devices (#16276)
* rpc : add support for multiple devices Allow rpc-server to expose multiple devices from a single endpoint. Change RPC protocol to include device identifier where needed. closes: #15210 * fixes * use ggml_backend_reg_t * address review comments * fix llama-bench backend report * address review comments, change device naming * fix cmd order
This commit is contained in:
		
				
					committed by
					
						
						GitHub
					
				
			
			
				
	
			
			
			
						parent
						
							e29acf74fe
						
					
				
				
					commit
					898acba681
				
			@@ -105,9 +105,12 @@ enum rpc_cmd {
 | 
			
		||||
    RPC_CMD_INIT_TENSOR,
 | 
			
		||||
    RPC_CMD_GET_ALLOC_SIZE,
 | 
			
		||||
    RPC_CMD_HELLO,
 | 
			
		||||
    RPC_CMD_DEVICE_COUNT,
 | 
			
		||||
    RPC_CMD_COUNT,
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
static_assert(RPC_CMD_HELLO == 14, "RPC_CMD_HELLO must be always 14");
 | 
			
		||||
 | 
			
		||||
// Try RPC_CMD_SET_TENSOR_HASH first when data size is larger than this threshold
 | 
			
		||||
const size_t HASH_THRESHOLD = 10 * 1024 * 1024;
 | 
			
		||||
 | 
			
		||||
@@ -117,7 +120,12 @@ struct rpc_msg_hello_rsp {
 | 
			
		||||
    uint8_t patch;
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
struct rpc_msg_device_count_rsp {
 | 
			
		||||
    uint32_t device_count;
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
struct rpc_msg_get_alloc_size_req {
 | 
			
		||||
    uint32_t   device;
 | 
			
		||||
    rpc_tensor tensor;
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
@@ -130,6 +138,7 @@ struct rpc_msg_init_tensor_req {
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
struct rpc_msg_alloc_buffer_req {
 | 
			
		||||
    uint32_t device;
 | 
			
		||||
    uint64_t size;
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
@@ -138,10 +147,18 @@ struct rpc_msg_alloc_buffer_rsp {
 | 
			
		||||
    uint64_t remote_size;
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
struct rpc_msg_get_alignment_req {
 | 
			
		||||
    uint32_t device;
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
struct rpc_msg_get_alignment_rsp {
 | 
			
		||||
    uint64_t alignment;
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
struct rpc_msg_get_max_size_req {
 | 
			
		||||
    uint32_t device;
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
struct rpc_msg_get_max_size_rsp {
 | 
			
		||||
    uint64_t max_size;
 | 
			
		||||
};
 | 
			
		||||
@@ -192,6 +209,10 @@ struct rpc_msg_graph_compute_rsp {
 | 
			
		||||
    uint8_t result;
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
struct rpc_msg_get_device_memory_req {
 | 
			
		||||
    uint32_t device;
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
struct rpc_msg_get_device_memory_rsp {
 | 
			
		||||
    uint64_t free_mem;
 | 
			
		||||
    uint64_t total_mem;
 | 
			
		||||
@@ -207,13 +228,15 @@ static ggml_guid_t ggml_backend_rpc_guid() {
 | 
			
		||||
 | 
			
		||||
struct ggml_backend_rpc_buffer_type_context {
 | 
			
		||||
    std::string endpoint;
 | 
			
		||||
    uint32_t    device;
 | 
			
		||||
    std::string name;
 | 
			
		||||
    size_t alignment;
 | 
			
		||||
    size_t max_size;
 | 
			
		||||
    size_t      alignment;
 | 
			
		||||
    size_t      max_size;
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
struct ggml_backend_rpc_context {
 | 
			
		||||
    std::string endpoint;
 | 
			
		||||
    uint32_t    device;
 | 
			
		||||
    std::string name;
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
@@ -653,7 +676,7 @@ static const char * ggml_backend_rpc_buffer_type_name(ggml_backend_buffer_type_t
 | 
			
		||||
 | 
			
		||||
static ggml_backend_buffer_t ggml_backend_rpc_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
 | 
			
		||||
    ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context;
 | 
			
		||||
    rpc_msg_alloc_buffer_req request = {size};
 | 
			
		||||
    rpc_msg_alloc_buffer_req request = {buft_ctx->device, size};
 | 
			
		||||
    rpc_msg_alloc_buffer_rsp response;
 | 
			
		||||
    auto sock = get_socket(buft_ctx->endpoint);
 | 
			
		||||
    bool status = send_rpc_cmd(sock, RPC_CMD_ALLOC_BUFFER, &request, sizeof(request), &response, sizeof(response));
 | 
			
		||||
@@ -669,9 +692,10 @@ static ggml_backend_buffer_t ggml_backend_rpc_buffer_type_alloc_buffer(ggml_back
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
static size_t get_alignment(const std::shared_ptr<socket_t> & sock) {
 | 
			
		||||
static size_t get_alignment(const std::shared_ptr<socket_t> & sock, uint32_t device) {
 | 
			
		||||
    rpc_msg_get_alignment_req request = {device};
 | 
			
		||||
    rpc_msg_get_alignment_rsp response;
 | 
			
		||||
    bool status = send_rpc_cmd(sock, RPC_CMD_GET_ALIGNMENT, nullptr, 0, &response, sizeof(response));
 | 
			
		||||
    bool status = send_rpc_cmd(sock, RPC_CMD_GET_ALIGNMENT, &request, sizeof(request), &response, sizeof(response));
 | 
			
		||||
    RPC_STATUS_ASSERT(status);
 | 
			
		||||
    return response.alignment;
 | 
			
		||||
}
 | 
			
		||||
@@ -681,9 +705,10 @@ static size_t ggml_backend_rpc_buffer_type_get_alignment(ggml_backend_buffer_typ
 | 
			
		||||
    return buft_ctx->alignment;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
static size_t get_max_size(const std::shared_ptr<socket_t> & sock) {
 | 
			
		||||
static size_t get_max_size(const std::shared_ptr<socket_t> & sock, uint32_t device) {
 | 
			
		||||
    rpc_msg_get_max_size_req request = {device};
 | 
			
		||||
    rpc_msg_get_max_size_rsp response;
 | 
			
		||||
    bool status = send_rpc_cmd(sock, RPC_CMD_GET_MAX_SIZE, nullptr, 0, &response, sizeof(response));
 | 
			
		||||
    bool status = send_rpc_cmd(sock, RPC_CMD_GET_MAX_SIZE, &request, sizeof(request), &response, sizeof(response));
 | 
			
		||||
    RPC_STATUS_ASSERT(status);
 | 
			
		||||
    return response.max_size;
 | 
			
		||||
}
 | 
			
		||||
@@ -700,7 +725,7 @@ static size_t ggml_backend_rpc_buffer_type_get_alloc_size(ggml_backend_buffer_ty
 | 
			
		||||
        auto sock = get_socket(buft_ctx->endpoint);
 | 
			
		||||
 | 
			
		||||
        rpc_msg_get_alloc_size_req request;
 | 
			
		||||
 | 
			
		||||
        request.device = buft_ctx->device;
 | 
			
		||||
        request.tensor = serialize_tensor(tensor);
 | 
			
		||||
 | 
			
		||||
        rpc_msg_get_alloc_size_rsp response;
 | 
			
		||||
@@ -754,7 +779,7 @@ static void add_tensor(ggml_tensor * tensor, std::vector<rpc_tensor> & tensors,
 | 
			
		||||
    tensors.push_back(serialize_tensor(tensor));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
static void serialize_graph(const ggml_cgraph * cgraph, std::vector<uint8_t> & output) {
 | 
			
		||||
static void serialize_graph(uint32_t device, const ggml_cgraph * cgraph, std::vector<uint8_t> & output) {
 | 
			
		||||
    uint32_t n_nodes = cgraph->n_nodes;
 | 
			
		||||
    std::vector<rpc_tensor> tensors;
 | 
			
		||||
    std::unordered_set<ggml_tensor*> visited;
 | 
			
		||||
@@ -762,24 +787,29 @@ static void serialize_graph(const ggml_cgraph * cgraph, std::vector<uint8_t> & o
 | 
			
		||||
        add_tensor(cgraph->nodes[i], tensors, visited);
 | 
			
		||||
    }
 | 
			
		||||
    // serialization format:
 | 
			
		||||
    // | n_nodes (4 bytes) | nodes (n_nodes * sizeof(uint64_t) | n_tensors (4 bytes) | tensors (n_tensors * sizeof(rpc_tensor)) |
 | 
			
		||||
    // | device (4 bytes) | n_nodes (4 bytes) | nodes (n_nodes * sizeof(uint64_t) | n_tensors (4 bytes) | tensors (n_tensors * sizeof(rpc_tensor)) |
 | 
			
		||||
    uint32_t n_tensors = tensors.size();
 | 
			
		||||
    int output_size = sizeof(uint32_t) + n_nodes * sizeof(uint64_t) + sizeof(uint32_t) + n_tensors * sizeof(rpc_tensor);
 | 
			
		||||
    int output_size = 2*sizeof(uint32_t) + n_nodes * sizeof(uint64_t) + sizeof(uint32_t) + n_tensors * sizeof(rpc_tensor);
 | 
			
		||||
    output.resize(output_size, 0);
 | 
			
		||||
    memcpy(output.data(), &n_nodes, sizeof(n_nodes));
 | 
			
		||||
    uint8_t * dest = output.data();
 | 
			
		||||
    memcpy(dest, &device, sizeof(device));
 | 
			
		||||
    dest += sizeof(device);
 | 
			
		||||
    memcpy(dest, &n_nodes, sizeof(n_nodes));
 | 
			
		||||
    dest += sizeof(n_nodes);
 | 
			
		||||
    for (uint32_t i = 0; i < n_nodes; i++) {
 | 
			
		||||
        memcpy(output.data() + sizeof(n_nodes) + i * sizeof(uint64_t), &cgraph->nodes[i], sizeof(uint64_t));
 | 
			
		||||
        memcpy(dest + i * sizeof(uint64_t), &cgraph->nodes[i], sizeof(uint64_t));
 | 
			
		||||
    }
 | 
			
		||||
    uint32_t * out_ntensors = (uint32_t *)(output.data() + sizeof(n_nodes) + n_nodes * sizeof(uint64_t));
 | 
			
		||||
    *out_ntensors = n_tensors;
 | 
			
		||||
    rpc_tensor * out_tensors = (rpc_tensor *)(output.data() + sizeof(n_nodes) + n_nodes * sizeof(uint64_t) + sizeof(uint32_t));
 | 
			
		||||
    dest += n_nodes * sizeof(uint64_t);
 | 
			
		||||
    memcpy(dest, &n_tensors, sizeof(n_tensors));
 | 
			
		||||
    dest += sizeof(n_tensors);
 | 
			
		||||
    rpc_tensor * out_tensors = (rpc_tensor *)dest;
 | 
			
		||||
    memcpy(out_tensors, tensors.data(), n_tensors * sizeof(rpc_tensor));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
static enum ggml_status ggml_backend_rpc_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
 | 
			
		||||
    ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context;
 | 
			
		||||
    std::vector<uint8_t> input;
 | 
			
		||||
    serialize_graph(cgraph, input);
 | 
			
		||||
    serialize_graph(rpc_ctx->device, cgraph, input);
 | 
			
		||||
    rpc_msg_graph_compute_rsp response;
 | 
			
		||||
    auto sock = get_socket(rpc_ctx->endpoint);
 | 
			
		||||
    bool status = send_rpc_cmd(sock, RPC_CMD_GRAPH_COMPUTE, input.data(), input.size(), &response, sizeof(response));
 | 
			
		||||
@@ -804,12 +834,13 @@ static ggml_backend_i ggml_backend_rpc_interface = {
 | 
			
		||||
    /* .graph_optimize          = */ NULL,
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const char * endpoint) {
 | 
			
		||||
ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const char * endpoint, uint32_t device) {
 | 
			
		||||
    static std::mutex mutex;
 | 
			
		||||
    std::lock_guard<std::mutex> lock(mutex);
 | 
			
		||||
    std::string buft_name = "RPC" + std::to_string(device) + "[" + std::string(endpoint) + "]";
 | 
			
		||||
    // NOTE: buffer types are allocated and never freed; this is by design
 | 
			
		||||
    static std::unordered_map<std::string, ggml_backend_buffer_type_t> buft_map;
 | 
			
		||||
    auto it = buft_map.find(endpoint);
 | 
			
		||||
    auto it = buft_map.find(buft_name);
 | 
			
		||||
    if (it != buft_map.end()) {
 | 
			
		||||
        return it->second;
 | 
			
		||||
    }
 | 
			
		||||
@@ -818,34 +849,37 @@ ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const char * endpoint) {
 | 
			
		||||
        GGML_LOG_ERROR("Failed to connect to %s\n", endpoint);
 | 
			
		||||
        return nullptr;
 | 
			
		||||
    }
 | 
			
		||||
    size_t alignment = get_alignment(sock);
 | 
			
		||||
    size_t max_size = get_max_size(sock);
 | 
			
		||||
    size_t alignment = get_alignment(sock, device);
 | 
			
		||||
    size_t max_size = get_max_size(sock, device);
 | 
			
		||||
    ggml_backend_rpc_buffer_type_context * buft_ctx = new ggml_backend_rpc_buffer_type_context {
 | 
			
		||||
        /* .endpoint  = */ endpoint,
 | 
			
		||||
        /* .name      = */ "RPC[" + std::string(endpoint) + "]",
 | 
			
		||||
        /* .device    = */ device,
 | 
			
		||||
        /* .name      = */ buft_name,
 | 
			
		||||
        /* .alignment = */ alignment,
 | 
			
		||||
        /* .max_size  = */ max_size
 | 
			
		||||
    };
 | 
			
		||||
 | 
			
		||||
    auto reg = ggml_backend_rpc_add_server(endpoint);
 | 
			
		||||
    ggml_backend_buffer_type_t buft = new ggml_backend_buffer_type {
 | 
			
		||||
        /* .iface   = */ ggml_backend_rpc_buffer_type_interface,
 | 
			
		||||
        /* .device  = */ ggml_backend_rpc_add_device(endpoint),
 | 
			
		||||
        /* .device  = */ ggml_backend_reg_dev_get(reg, device),
 | 
			
		||||
        /* .context = */ buft_ctx
 | 
			
		||||
    };
 | 
			
		||||
    buft_map[endpoint] = buft;
 | 
			
		||||
    buft_map[buft_name] = buft;
 | 
			
		||||
    return buft;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
ggml_backend_t ggml_backend_rpc_init(const char * endpoint) {
 | 
			
		||||
ggml_backend_t ggml_backend_rpc_init(const char * endpoint, uint32_t device) {
 | 
			
		||||
    std::string dev_name = "RPC" + std::to_string(device) + "[" + std::string(endpoint) + "]";
 | 
			
		||||
    ggml_backend_rpc_context * ctx = new ggml_backend_rpc_context {
 | 
			
		||||
        /* .endpoint  = */ endpoint,
 | 
			
		||||
        /* .name      = */ "RPC[" + std::string(endpoint) + "]",
 | 
			
		||||
        /* .endpoint = */ endpoint,
 | 
			
		||||
        /* .device   = */ device,
 | 
			
		||||
        /* .name     = */ dev_name
 | 
			
		||||
    };
 | 
			
		||||
 | 
			
		||||
    auto reg = ggml_backend_rpc_add_server(endpoint);
 | 
			
		||||
    ggml_backend_t backend = new ggml_backend {
 | 
			
		||||
        /* .guid    = */ ggml_backend_rpc_guid(),
 | 
			
		||||
        /* .iface   = */ ggml_backend_rpc_interface,
 | 
			
		||||
        /* .device  = */ ggml_backend_rpc_add_device(endpoint),
 | 
			
		||||
        /* .device  = */ ggml_backend_reg_dev_get(reg, device),
 | 
			
		||||
        /* .context = */ ctx
 | 
			
		||||
    };
 | 
			
		||||
    return backend;
 | 
			
		||||
@@ -855,37 +889,39 @@ bool ggml_backend_is_rpc(ggml_backend_t backend) {
 | 
			
		||||
    return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_rpc_guid());
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
static void get_device_memory(const std::shared_ptr<socket_t> & sock, size_t * free, size_t * total) {
 | 
			
		||||
static void get_device_memory(const std::shared_ptr<socket_t> & sock, uint32_t device, size_t * free, size_t * total) {
 | 
			
		||||
    rpc_msg_get_device_memory_req request;
 | 
			
		||||
    request.device = device;
 | 
			
		||||
    rpc_msg_get_device_memory_rsp response;
 | 
			
		||||
    bool status = send_rpc_cmd(sock, RPC_CMD_GET_DEVICE_MEMORY, nullptr, 0, &response, sizeof(response));
 | 
			
		||||
    bool status = send_rpc_cmd(sock, RPC_CMD_GET_DEVICE_MEMORY, &request, sizeof(request), &response, sizeof(response));
 | 
			
		||||
    RPC_STATUS_ASSERT(status);
 | 
			
		||||
    *free = response.free_mem;
 | 
			
		||||
    *total = response.total_mem;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
void ggml_backend_rpc_get_device_memory(const char * endpoint, size_t * free, size_t * total) {
 | 
			
		||||
void ggml_backend_rpc_get_device_memory(const char * endpoint, uint32_t device, size_t * free, size_t * total) {
 | 
			
		||||
    auto sock = get_socket(endpoint);
 | 
			
		||||
    if (sock == nullptr) {
 | 
			
		||||
        *free = 0;
 | 
			
		||||
        *total = 0;
 | 
			
		||||
        return;
 | 
			
		||||
    }
 | 
			
		||||
    get_device_memory(sock, free, total);
 | 
			
		||||
    get_device_memory(sock, device, free, total);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// RPC server-side implementation
 | 
			
		||||
 | 
			
		||||
class rpc_server {
 | 
			
		||||
public:
 | 
			
		||||
    rpc_server(ggml_backend_t backend, const char * cache_dir)
 | 
			
		||||
        : backend(backend), cache_dir(cache_dir) {
 | 
			
		||||
    rpc_server(std::vector<ggml_backend_t> backends, const char * cache_dir)
 | 
			
		||||
        : backends(std::move(backends)), cache_dir(cache_dir) {
 | 
			
		||||
    }
 | 
			
		||||
    ~rpc_server();
 | 
			
		||||
 | 
			
		||||
    void hello(rpc_msg_hello_rsp & response);
 | 
			
		||||
    void alloc_buffer(const rpc_msg_alloc_buffer_req & request, rpc_msg_alloc_buffer_rsp & response);
 | 
			
		||||
    void get_alignment(rpc_msg_get_alignment_rsp & response);
 | 
			
		||||
    void get_max_size(rpc_msg_get_max_size_rsp & response);
 | 
			
		||||
    bool alloc_buffer(const rpc_msg_alloc_buffer_req & request, rpc_msg_alloc_buffer_rsp & response);
 | 
			
		||||
    bool get_alignment(const rpc_msg_get_alignment_req & request, rpc_msg_get_alignment_rsp & response);
 | 
			
		||||
    bool get_max_size(const rpc_msg_get_max_size_req & request, rpc_msg_get_max_size_rsp & response);
 | 
			
		||||
    bool buffer_get_base(const rpc_msg_buffer_get_base_req & request, rpc_msg_buffer_get_base_rsp & response);
 | 
			
		||||
    bool free_buffer(const rpc_msg_free_buffer_req & request);
 | 
			
		||||
    bool buffer_clear(const rpc_msg_buffer_clear_req & request);
 | 
			
		||||
@@ -906,7 +942,7 @@ private:
 | 
			
		||||
                              std::unordered_map<uint64_t, struct ggml_tensor*> & tensor_map);
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
    ggml_backend_t backend;
 | 
			
		||||
    std::vector<ggml_backend_t> backends;
 | 
			
		||||
    const char * cache_dir;
 | 
			
		||||
    std::unordered_set<ggml_backend_buffer_t> buffers;
 | 
			
		||||
};
 | 
			
		||||
@@ -919,6 +955,10 @@ void rpc_server::hello(rpc_msg_hello_rsp & response) {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
bool rpc_server::get_alloc_size(const rpc_msg_get_alloc_size_req & request, rpc_msg_get_alloc_size_rsp & response) {
 | 
			
		||||
    uint32_t dev_id = request.device;
 | 
			
		||||
    if (dev_id >= backends.size()) {
 | 
			
		||||
        return false;
 | 
			
		||||
    }
 | 
			
		||||
    ggml_backend_buffer_type_t buft;
 | 
			
		||||
    struct ggml_init_params params {
 | 
			
		||||
        /*.mem_size   =*/ ggml_tensor_overhead(),
 | 
			
		||||
@@ -935,10 +975,10 @@ bool rpc_server::get_alloc_size(const rpc_msg_get_alloc_size_req & request, rpc_
 | 
			
		||||
        GGML_LOG_ERROR("Null tensor pointer passed to server get_alloc_size function.\n");
 | 
			
		||||
        return false;
 | 
			
		||||
    }
 | 
			
		||||
    LOG_DBG("[%s] buffer: %p, data: %p\n", __func__, (void*)tensor->buffer, tensor->data);
 | 
			
		||||
    LOG_DBG("[%s] device: %d, buffer: %p, data: %p\n", __func__, dev_id, (void*)tensor->buffer, tensor->data);
 | 
			
		||||
    if (tensor->buffer == nullptr) {
 | 
			
		||||
        //No buffer allocated.
 | 
			
		||||
        buft = ggml_backend_get_default_buffer_type(backend);
 | 
			
		||||
        buft = ggml_backend_get_default_buffer_type(backends[dev_id]);
 | 
			
		||||
    } else {
 | 
			
		||||
        buft = tensor->buffer->buft;
 | 
			
		||||
    }
 | 
			
		||||
@@ -948,33 +988,49 @@ bool rpc_server::get_alloc_size(const rpc_msg_get_alloc_size_req & request, rpc_
 | 
			
		||||
    return true;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
void rpc_server::alloc_buffer(const rpc_msg_alloc_buffer_req & request, rpc_msg_alloc_buffer_rsp & response) {
 | 
			
		||||
    ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backend);
 | 
			
		||||
bool rpc_server::alloc_buffer(const rpc_msg_alloc_buffer_req & request, rpc_msg_alloc_buffer_rsp & response) {
 | 
			
		||||
    uint32_t dev_id = request.device;
 | 
			
		||||
    if (dev_id >= backends.size()) {
 | 
			
		||||
        return false;
 | 
			
		||||
    }
 | 
			
		||||
    ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backends[dev_id]);
 | 
			
		||||
    ggml_backend_buffer_t buffer = ggml_backend_buft_alloc_buffer(buft, request.size);
 | 
			
		||||
    response.remote_ptr = 0;
 | 
			
		||||
    response.remote_size = 0;
 | 
			
		||||
    if (buffer != nullptr) {
 | 
			
		||||
        response.remote_ptr = reinterpret_cast<uint64_t>(buffer);
 | 
			
		||||
        response.remote_size = buffer->size;
 | 
			
		||||
        LOG_DBG("[%s] size: %" PRIu64 " -> remote_ptr: %" PRIx64 ", remote_size: %" PRIu64 "\n", __func__, request.size, response.remote_ptr, response.remote_size);
 | 
			
		||||
        LOG_DBG("[%s] device: %d, size: %" PRIu64 " -> remote_ptr: %" PRIx64 ", remote_size: %" PRIu64 "\n",
 | 
			
		||||
            __func__, dev_id, request.size, response.remote_ptr, response.remote_size);
 | 
			
		||||
        buffers.insert(buffer);
 | 
			
		||||
    } else {
 | 
			
		||||
        LOG_DBG("[%s] size: %" PRIu64 " -> failed\n", __func__, request.size);
 | 
			
		||||
        LOG_DBG("[%s] device: %d, size: %" PRIu64 " -> failed\n", __func__, dev_id, request.size);
 | 
			
		||||
    }
 | 
			
		||||
    return true;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
void rpc_server::get_alignment(rpc_msg_get_alignment_rsp & response) {
 | 
			
		||||
    ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backend);
 | 
			
		||||
bool rpc_server::get_alignment(const rpc_msg_get_alignment_req & request, rpc_msg_get_alignment_rsp & response) {
 | 
			
		||||
    uint32_t dev_id = request.device;
 | 
			
		||||
    if (dev_id >= backends.size()) {
 | 
			
		||||
        return false;
 | 
			
		||||
    }
 | 
			
		||||
    ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backends[dev_id]);
 | 
			
		||||
    size_t alignment = ggml_backend_buft_get_alignment(buft);
 | 
			
		||||
    LOG_DBG("[%s] alignment: %lu\n", __func__, alignment);
 | 
			
		||||
    LOG_DBG("[%s] device: %d, alignment: %lu\n", __func__, dev_id, alignment);
 | 
			
		||||
    response.alignment = alignment;
 | 
			
		||||
    return true;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
void rpc_server::get_max_size(rpc_msg_get_max_size_rsp & response) {
 | 
			
		||||
    ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backend);
 | 
			
		||||
bool rpc_server::get_max_size(const rpc_msg_get_max_size_req & request, rpc_msg_get_max_size_rsp & response) {
 | 
			
		||||
    uint32_t dev_id = request.device;
 | 
			
		||||
    if (dev_id >= backends.size()) {
 | 
			
		||||
        return false;
 | 
			
		||||
    }
 | 
			
		||||
    ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backends[dev_id]);
 | 
			
		||||
    size_t max_size = ggml_backend_buft_get_max_size(buft);
 | 
			
		||||
    LOG_DBG("[%s] max_size: %lu\n", __func__, max_size);
 | 
			
		||||
    LOG_DBG("[%s] device: %d, max_size: %lu\n", __func__, dev_id, max_size);
 | 
			
		||||
    response.max_size = max_size;
 | 
			
		||||
    return true;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
bool rpc_server::buffer_get_base(const rpc_msg_buffer_get_base_req & request, rpc_msg_buffer_get_base_rsp & response) {
 | 
			
		||||
@@ -1332,23 +1388,33 @@ ggml_tensor * rpc_server::create_node(uint64_t id,
 | 
			
		||||
 | 
			
		||||
bool rpc_server::graph_compute(const std::vector<uint8_t> & input, rpc_msg_graph_compute_rsp & response) {
 | 
			
		||||
    // serialization format:
 | 
			
		||||
    // | n_nodes (4 bytes) | nodes (n_nodes * sizeof(uint64_t) | n_tensors (4 bytes) | tensors (n_tensors * sizeof(rpc_tensor)) |
 | 
			
		||||
    if (input.size() < sizeof(uint32_t)) {
 | 
			
		||||
    // | device (4 bytes) | n_nodes (4 bytes) | nodes (n_nodes * sizeof(uint64_t) | n_tensors (4 bytes) | tensors (n_tensors * sizeof(rpc_tensor)) |
 | 
			
		||||
    if (input.size() < 2*sizeof(uint32_t)) {
 | 
			
		||||
        return false;
 | 
			
		||||
    }
 | 
			
		||||
    const uint8_t * src = input.data();
 | 
			
		||||
    uint32_t device;
 | 
			
		||||
    memcpy(&device, src, sizeof(device));
 | 
			
		||||
    src += sizeof(device);
 | 
			
		||||
    if (device >= backends.size()) {
 | 
			
		||||
        return false;
 | 
			
		||||
    }
 | 
			
		||||
    uint32_t n_nodes;
 | 
			
		||||
    memcpy(&n_nodes, input.data(), sizeof(n_nodes));
 | 
			
		||||
    if (input.size() < sizeof(uint32_t) + n_nodes*sizeof(uint64_t) + sizeof(uint32_t)) {
 | 
			
		||||
    memcpy(&n_nodes, src, sizeof(n_nodes));
 | 
			
		||||
    src += sizeof(n_nodes);
 | 
			
		||||
    if (input.size() < 2*sizeof(uint32_t) + n_nodes*sizeof(uint64_t) + sizeof(uint32_t)) {
 | 
			
		||||
        return false;
 | 
			
		||||
    }
 | 
			
		||||
    const uint64_t * nodes = (const uint64_t *)(input.data() + sizeof(n_nodes));
 | 
			
		||||
    const uint64_t * nodes = (const uint64_t *)src;
 | 
			
		||||
    src += n_nodes*sizeof(uint64_t);
 | 
			
		||||
    uint32_t n_tensors;
 | 
			
		||||
    memcpy(&n_tensors, input.data() + sizeof(n_nodes) + n_nodes*sizeof(uint64_t), sizeof(n_tensors));
 | 
			
		||||
    if (input.size() < sizeof(uint32_t) + n_nodes*sizeof(uint64_t) + sizeof(uint32_t) + n_tensors*sizeof(rpc_tensor)) {
 | 
			
		||||
    memcpy(&n_tensors, src, sizeof(n_tensors));
 | 
			
		||||
    src += sizeof(n_tensors);
 | 
			
		||||
    if (input.size() < 2*sizeof(uint32_t) + n_nodes*sizeof(uint64_t) + sizeof(uint32_t) + n_tensors*sizeof(rpc_tensor)) {
 | 
			
		||||
        return false;
 | 
			
		||||
    }
 | 
			
		||||
    const rpc_tensor * tensors = (const rpc_tensor *)(input.data() + sizeof(n_nodes) + n_nodes*sizeof(uint64_t) + sizeof(n_tensors));
 | 
			
		||||
    LOG_DBG("[%s] n_nodes: %u, n_tensors: %u\n", __func__, n_nodes, n_tensors);
 | 
			
		||||
    const rpc_tensor * tensors = (const rpc_tensor *)src;
 | 
			
		||||
    LOG_DBG("[%s] device: %u, n_nodes: %u, n_tensors: %u\n", __func__, device, n_nodes, n_tensors);
 | 
			
		||||
 | 
			
		||||
    size_t buf_size = ggml_tensor_overhead()*(n_nodes + n_tensors) + ggml_graph_overhead_custom(n_nodes, false);
 | 
			
		||||
 | 
			
		||||
@@ -1380,7 +1446,7 @@ bool rpc_server::graph_compute(const std::vector<uint8_t> & input, rpc_msg_graph
 | 
			
		||||
            return false;
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
    ggml_status status = ggml_backend_graph_compute(backend, graph);
 | 
			
		||||
    ggml_status status = ggml_backend_graph_compute(backends[device], graph);
 | 
			
		||||
    response.result = status;
 | 
			
		||||
    return true;
 | 
			
		||||
}
 | 
			
		||||
@@ -1391,9 +1457,9 @@ rpc_server::~rpc_server() {
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
static void rpc_serve_client(ggml_backend_t backend, const char * cache_dir,
 | 
			
		||||
                             sockfd_t sockfd, size_t free_mem, size_t total_mem) {
 | 
			
		||||
    rpc_server server(backend, cache_dir);
 | 
			
		||||
static void rpc_serve_client(const std::vector<ggml_backend_t> & backends, const char * cache_dir,
 | 
			
		||||
                             sockfd_t sockfd, const std::vector<size_t> & free_mem, const std::vector<size_t> & total_mem) {
 | 
			
		||||
    rpc_server server(backends, cache_dir);
 | 
			
		||||
    uint8_t cmd;
 | 
			
		||||
    if (!recv_data(sockfd, &cmd, 1)) {
 | 
			
		||||
        return;
 | 
			
		||||
@@ -1425,13 +1491,26 @@ static void rpc_serve_client(ggml_backend_t backend, const char * cache_dir,
 | 
			
		||||
                // HELLO command is handled above
 | 
			
		||||
                return;
 | 
			
		||||
            }
 | 
			
		||||
            case RPC_CMD_DEVICE_COUNT: {
 | 
			
		||||
                if (!recv_msg(sockfd, nullptr, 0)) {
 | 
			
		||||
                    return;
 | 
			
		||||
                }
 | 
			
		||||
                rpc_msg_device_count_rsp response;
 | 
			
		||||
                response.device_count = backends.size();
 | 
			
		||||
                if (!send_msg(sockfd, &response, sizeof(response))) {
 | 
			
		||||
                    return;
 | 
			
		||||
                }
 | 
			
		||||
                break;
 | 
			
		||||
            }
 | 
			
		||||
            case RPC_CMD_ALLOC_BUFFER: {
 | 
			
		||||
                rpc_msg_alloc_buffer_req request;
 | 
			
		||||
                if (!recv_msg(sockfd, &request, sizeof(request))) {
 | 
			
		||||
                    return;
 | 
			
		||||
                }
 | 
			
		||||
                rpc_msg_alloc_buffer_rsp response;
 | 
			
		||||
                server.alloc_buffer(request, response);
 | 
			
		||||
                if (!server.alloc_buffer(request, response)) {
 | 
			
		||||
                    return;
 | 
			
		||||
                }
 | 
			
		||||
                if (!send_msg(sockfd, &response, sizeof(response))) {
 | 
			
		||||
                    return;
 | 
			
		||||
                }
 | 
			
		||||
@@ -1452,22 +1531,28 @@ static void rpc_serve_client(ggml_backend_t backend, const char * cache_dir,
 | 
			
		||||
                break;
 | 
			
		||||
            }
 | 
			
		||||
            case RPC_CMD_GET_ALIGNMENT: {
 | 
			
		||||
                if (!recv_msg(sockfd, nullptr, 0)) {
 | 
			
		||||
                rpc_msg_get_alignment_req request;
 | 
			
		||||
                if (!recv_msg(sockfd, &request, sizeof(request))) {
 | 
			
		||||
                    return;
 | 
			
		||||
                }
 | 
			
		||||
                rpc_msg_get_alignment_rsp response;
 | 
			
		||||
                server.get_alignment(response);
 | 
			
		||||
                if (!server.get_alignment(request, response)) {
 | 
			
		||||
                    return;
 | 
			
		||||
                }
 | 
			
		||||
                if (!send_msg(sockfd, &response, sizeof(response))) {
 | 
			
		||||
                    return;
 | 
			
		||||
                }
 | 
			
		||||
                break;
 | 
			
		||||
            }
 | 
			
		||||
            case RPC_CMD_GET_MAX_SIZE: {
 | 
			
		||||
                if (!recv_msg(sockfd, nullptr, 0)) {
 | 
			
		||||
                rpc_msg_get_max_size_req request;
 | 
			
		||||
                if (!recv_msg(sockfd, &request, sizeof(request))) {
 | 
			
		||||
                    return;
 | 
			
		||||
                }
 | 
			
		||||
                rpc_msg_get_max_size_rsp response;
 | 
			
		||||
                server.get_max_size(response);
 | 
			
		||||
                if (!server.get_max_size(request, response)) {
 | 
			
		||||
                    return;
 | 
			
		||||
                }
 | 
			
		||||
                if (!send_msg(sockfd, &response, sizeof(response))) {
 | 
			
		||||
                    return;
 | 
			
		||||
                }
 | 
			
		||||
@@ -1593,12 +1678,19 @@ static void rpc_serve_client(ggml_backend_t backend, const char * cache_dir,
 | 
			
		||||
                break;
 | 
			
		||||
            }
 | 
			
		||||
            case RPC_CMD_GET_DEVICE_MEMORY: {
 | 
			
		||||
                if (!recv_msg(sockfd, nullptr, 0)) {
 | 
			
		||||
                rpc_msg_get_device_memory_req request;
 | 
			
		||||
                if (!recv_msg(sockfd, &request, sizeof(request))) {
 | 
			
		||||
                    return;
 | 
			
		||||
                }
 | 
			
		||||
                auto dev_id = request.device;
 | 
			
		||||
                if (dev_id >= backends.size()) {
 | 
			
		||||
                    return;
 | 
			
		||||
                }
 | 
			
		||||
                rpc_msg_get_device_memory_rsp response;
 | 
			
		||||
                response.free_mem = free_mem;
 | 
			
		||||
                response.total_mem = total_mem;
 | 
			
		||||
                response.free_mem = free_mem[dev_id];
 | 
			
		||||
                response.total_mem = total_mem[dev_id];
 | 
			
		||||
                LOG_DBG("[get_device_mem] device: %u, free_mem: %" PRIu64 ", total_mem: %" PRIu64 "\n", dev_id,
 | 
			
		||||
                    response.free_mem, response.total_mem);
 | 
			
		||||
                if (!send_msg(sockfd, &response, sizeof(response))) {
 | 
			
		||||
                    return;
 | 
			
		||||
                }
 | 
			
		||||
@@ -1612,16 +1704,41 @@ static void rpc_serve_client(ggml_backend_t backend, const char * cache_dir,
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
void ggml_backend_rpc_start_server(ggml_backend_t backend, const char * endpoint,
 | 
			
		||||
                                   const char * cache_dir,
 | 
			
		||||
                                   size_t free_mem, size_t total_mem) {
 | 
			
		||||
void ggml_backend_rpc_start_server(const char * endpoint, const char * cache_dir,
 | 
			
		||||
                                   size_t n_threads, size_t n_devices,
 | 
			
		||||
                                   ggml_backend_dev_t * devices, size_t * free_mem, size_t * total_mem) {
 | 
			
		||||
    if (n_devices == 0 || devices == nullptr || free_mem == nullptr || total_mem == nullptr) {
 | 
			
		||||
        fprintf(stderr, "Invalid arguments to ggml_backend_rpc_start_server\n");
 | 
			
		||||
        return;
 | 
			
		||||
    }
 | 
			
		||||
    std::vector<ggml_backend_t> backends;
 | 
			
		||||
    std::vector<size_t> free_mem_vec(free_mem, free_mem + n_devices);
 | 
			
		||||
    std::vector<size_t> total_mem_vec(total_mem, total_mem + n_devices);
 | 
			
		||||
    printf("Starting RPC server v%d.%d.%d\n",
 | 
			
		||||
        RPC_PROTO_MAJOR_VERSION,
 | 
			
		||||
        RPC_PROTO_MINOR_VERSION,
 | 
			
		||||
        RPC_PROTO_PATCH_VERSION);
 | 
			
		||||
    printf("  endpoint       : %s\n", endpoint);
 | 
			
		||||
    printf("  local cache    : %s\n", cache_dir ? cache_dir : "n/a");
 | 
			
		||||
    printf("  backend memory : %zu MB\n", free_mem / (1024 * 1024));
 | 
			
		||||
    printf("Devices:\n");
 | 
			
		||||
    for (size_t i = 0; i < n_devices; i++) {
 | 
			
		||||
        auto dev = devices[i];
 | 
			
		||||
        printf("  %s: %s (%zu MiB, %zu MiB free)\n", ggml_backend_dev_name(dev), ggml_backend_dev_description(dev),
 | 
			
		||||
               total_mem[i] / 1024 / 1024, free_mem[i] / 1024 / 1024);
 | 
			
		||||
        auto backend = ggml_backend_dev_init(dev, nullptr);
 | 
			
		||||
        if (!backend) {
 | 
			
		||||
            fprintf(stderr, "Failed to create backend for device %s\n", dev->iface.get_name(dev));
 | 
			
		||||
            return;
 | 
			
		||||
        }
 | 
			
		||||
        backends.push_back(backend);
 | 
			
		||||
        ggml_backend_reg_t reg = dev ? ggml_backend_dev_backend_reg(dev) : nullptr;
 | 
			
		||||
        if (reg) {
 | 
			
		||||
            auto ggml_backend_set_n_threads_fn = (ggml_backend_set_n_threads_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_n_threads");
 | 
			
		||||
            if (ggml_backend_set_n_threads_fn) {
 | 
			
		||||
                ggml_backend_set_n_threads_fn(backend, n_threads);
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    std::string host;
 | 
			
		||||
    int port;
 | 
			
		||||
@@ -1649,22 +1766,27 @@ void ggml_backend_rpc_start_server(ggml_backend_t backend, const char * endpoint
 | 
			
		||||
            fprintf(stderr, "Failed to accept client connection\n");
 | 
			
		||||
            return;
 | 
			
		||||
        }
 | 
			
		||||
        printf("Accepted client connection, free_mem=%zu, total_mem=%zu\n", free_mem, total_mem);
 | 
			
		||||
        printf("Accepted client connection\n");
 | 
			
		||||
        fflush(stdout);
 | 
			
		||||
        rpc_serve_client(backend, cache_dir, client_socket->fd, free_mem, total_mem);
 | 
			
		||||
        rpc_serve_client(backends, cache_dir, client_socket->fd, free_mem_vec, total_mem_vec);
 | 
			
		||||
        printf("Client connection closed\n");
 | 
			
		||||
        fflush(stdout);
 | 
			
		||||
    }
 | 
			
		||||
#ifdef _WIN32
 | 
			
		||||
    WSACleanup();
 | 
			
		||||
#endif
 | 
			
		||||
    for (auto backend : backends) {
 | 
			
		||||
        ggml_backend_free(backend);
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// device interface
 | 
			
		||||
 | 
			
		||||
struct ggml_backend_rpc_device_context {
 | 
			
		||||
    std::string endpoint;
 | 
			
		||||
    uint32_t    device;
 | 
			
		||||
    std::string name;
 | 
			
		||||
    std::string description;
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
static const char * ggml_backend_rpc_device_get_name(ggml_backend_dev_t dev) {
 | 
			
		||||
@@ -1676,15 +1798,13 @@ static const char * ggml_backend_rpc_device_get_name(ggml_backend_dev_t dev) {
 | 
			
		||||
static const char * ggml_backend_rpc_device_get_description(ggml_backend_dev_t dev) {
 | 
			
		||||
    ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context;
 | 
			
		||||
 | 
			
		||||
    return ctx->name.c_str();
 | 
			
		||||
    return ctx->description.c_str();
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
static void ggml_backend_rpc_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
 | 
			
		||||
    ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context;
 | 
			
		||||
 | 
			
		||||
    ggml_backend_rpc_get_device_memory(ctx->endpoint.c_str(), free, total);
 | 
			
		||||
 | 
			
		||||
    GGML_UNUSED(dev);
 | 
			
		||||
    ggml_backend_rpc_get_device_memory(ctx->endpoint.c_str(), ctx->device, free, total);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
static enum ggml_backend_dev_type ggml_backend_rpc_device_get_type(ggml_backend_dev_t dev) {
 | 
			
		||||
@@ -1710,7 +1830,7 @@ static void ggml_backend_rpc_device_get_props(ggml_backend_dev_t dev, struct ggm
 | 
			
		||||
static ggml_backend_t ggml_backend_rpc_device_init(ggml_backend_dev_t dev, const char * params) {
 | 
			
		||||
    ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context;
 | 
			
		||||
 | 
			
		||||
    return ggml_backend_rpc_init(ctx->endpoint.c_str());
 | 
			
		||||
    return ggml_backend_rpc_init(ctx->endpoint.c_str(), ctx->device);
 | 
			
		||||
 | 
			
		||||
    GGML_UNUSED(params);
 | 
			
		||||
}
 | 
			
		||||
@@ -1718,7 +1838,7 @@ static ggml_backend_t ggml_backend_rpc_device_init(ggml_backend_dev_t dev, const
 | 
			
		||||
static ggml_backend_buffer_type_t ggml_backend_rpc_device_get_buffer_type(ggml_backend_dev_t dev) {
 | 
			
		||||
    ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context;
 | 
			
		||||
 | 
			
		||||
    return ggml_backend_rpc_buffer_type(ctx->endpoint.c_str());
 | 
			
		||||
    return ggml_backend_rpc_buffer_type(ctx->endpoint.c_str(), ctx->device);
 | 
			
		||||
 | 
			
		||||
    GGML_UNUSED(dev);
 | 
			
		||||
}
 | 
			
		||||
@@ -1736,7 +1856,7 @@ static bool ggml_backend_rpc_device_supports_buft(ggml_backend_dev_t dev, ggml_b
 | 
			
		||||
    }
 | 
			
		||||
    ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context;
 | 
			
		||||
    ggml_backend_rpc_device_context * dev_ctx = (ggml_backend_rpc_device_context *)dev->context;
 | 
			
		||||
    return buft_ctx->endpoint == dev_ctx->endpoint;
 | 
			
		||||
    return buft_ctx->endpoint == dev_ctx->endpoint && buft_ctx->device == dev_ctx->device;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
static const struct ggml_backend_device_i ggml_backend_rpc_device_i = {
 | 
			
		||||
@@ -1759,28 +1879,34 @@ static const struct ggml_backend_device_i ggml_backend_rpc_device_i = {
 | 
			
		||||
 | 
			
		||||
// backend reg interface
 | 
			
		||||
 | 
			
		||||
static const char * ggml_backend_rpc_reg_get_name(ggml_backend_reg_t reg) {
 | 
			
		||||
    return "RPC";
 | 
			
		||||
struct ggml_backend_rpc_reg_context {
 | 
			
		||||
    std::string                     name;
 | 
			
		||||
    std::vector<ggml_backend_dev_t> devices;
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
    GGML_UNUSED(reg);
 | 
			
		||||
static const char * ggml_backend_rpc_reg_get_name(ggml_backend_reg_t reg) {
 | 
			
		||||
    ggml_backend_rpc_reg_context * ctx = (ggml_backend_rpc_reg_context *)reg->context;
 | 
			
		||||
    return ctx ? ctx->name.c_str() : "RPC";
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
static size_t ggml_backend_rpc_reg_get_device_count(ggml_backend_reg_t reg) {
 | 
			
		||||
    return 0;
 | 
			
		||||
 | 
			
		||||
    GGML_UNUSED(reg);
 | 
			
		||||
    ggml_backend_rpc_reg_context * ctx = (ggml_backend_rpc_reg_context *)reg->context;
 | 
			
		||||
    return ctx ? ctx->devices.size() : 0;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
static ggml_backend_dev_t ggml_backend_rpc_reg_get_device(ggml_backend_reg_t reg, size_t index) {
 | 
			
		||||
    GGML_ABORT("The RPC backend does not have enumerated devices - use ggml_backend_add_device instead");
 | 
			
		||||
 | 
			
		||||
    GGML_UNUSED(reg);
 | 
			
		||||
    GGML_UNUSED(index);
 | 
			
		||||
    ggml_backend_rpc_reg_context * ctx = (ggml_backend_rpc_reg_context *)reg->context;
 | 
			
		||||
    if (ctx == nullptr) {
 | 
			
		||||
        GGML_ABORT("The RPC backend does not have enumerated devices - use ggml_backend_rpc_add_server instead");
 | 
			
		||||
    } else {
 | 
			
		||||
        GGML_ASSERT(index < ctx->devices.size());
 | 
			
		||||
        return ctx->devices[index];
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
static void * ggml_backend_rpc_get_proc_address(ggml_backend_reg_t reg, const char * name) {
 | 
			
		||||
    if (std::strcmp(name, "ggml_backend_rpc_add_device") == 0) {
 | 
			
		||||
        return (void *)ggml_backend_rpc_add_device;
 | 
			
		||||
    if (std::strcmp(name, "ggml_backend_rpc_add_server") == 0) {
 | 
			
		||||
        return (void *)ggml_backend_rpc_add_server;
 | 
			
		||||
    }
 | 
			
		||||
    if (std::strcmp(name, "ggml_backend_rpc_start_server") == 0) {
 | 
			
		||||
        return (void *)ggml_backend_rpc_start_server;
 | 
			
		||||
@@ -1807,30 +1933,61 @@ ggml_backend_reg_t ggml_backend_rpc_reg(void) {
 | 
			
		||||
    return &ggml_backend_rpc_reg;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
ggml_backend_dev_t ggml_backend_rpc_add_device(const char * endpoint) {
 | 
			
		||||
    static std::unordered_map<std::string, ggml_backend_dev_t> dev_map;
 | 
			
		||||
 | 
			
		||||
    static std::mutex mutex;
 | 
			
		||||
    std::lock_guard<std::mutex> lock(mutex);
 | 
			
		||||
 | 
			
		||||
    if (dev_map.find(endpoint) != dev_map.end()) {
 | 
			
		||||
        return dev_map[endpoint];
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    ggml_backend_rpc_device_context * ctx = new ggml_backend_rpc_device_context {
 | 
			
		||||
        /* .endpoint = */ endpoint,
 | 
			
		||||
        /* .name     = */ "RPC[" + std::string(endpoint) + "]",
 | 
			
		||||
    };
 | 
			
		||||
 | 
			
		||||
    ggml_backend_dev_t dev = new ggml_backend_device {
 | 
			
		||||
        /* .iface   = */ ggml_backend_rpc_device_i,
 | 
			
		||||
        /* .reg     = */ ggml_backend_rpc_reg(),
 | 
			
		||||
        /* .context = */ ctx,
 | 
			
		||||
    };
 | 
			
		||||
 | 
			
		||||
    dev_map[endpoint] = dev;
 | 
			
		||||
 | 
			
		||||
    return dev;
 | 
			
		||||
static uint32_t ggml_backend_rpc_get_device_count(const char * endpoint) {
 | 
			
		||||
    auto sock = get_socket(endpoint);
 | 
			
		||||
    rpc_msg_device_count_rsp response;
 | 
			
		||||
    bool status = send_rpc_cmd(sock, RPC_CMD_DEVICE_COUNT, nullptr, 0, &response, sizeof(response));
 | 
			
		||||
    RPC_STATUS_ASSERT(status);
 | 
			
		||||
    return response.device_count;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
static const ggml_backend_reg_i ggml_backend_rpc_reg_interface = {
 | 
			
		||||
    /* .get_name          = */ ggml_backend_rpc_reg_get_name,
 | 
			
		||||
    /* .get_device_count  = */ ggml_backend_rpc_reg_get_device_count,
 | 
			
		||||
    /* .get_device        = */ ggml_backend_rpc_reg_get_device,
 | 
			
		||||
    /* .get_proc_address  = */ ggml_backend_rpc_get_proc_address,
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
ggml_backend_reg_t ggml_backend_rpc_add_server(const char * endpoint) {
 | 
			
		||||
    static std::unordered_map<std::string, ggml_backend_reg_t> reg_map;
 | 
			
		||||
    static std::mutex mutex;
 | 
			
		||||
    static uint32_t dev_id = 0;
 | 
			
		||||
    std::lock_guard<std::mutex> lock(mutex);
 | 
			
		||||
    if (reg_map.find(endpoint) != reg_map.end()) {
 | 
			
		||||
        return reg_map[endpoint];
 | 
			
		||||
    }
 | 
			
		||||
    uint32_t dev_count = ggml_backend_rpc_get_device_count(endpoint);
 | 
			
		||||
    if (dev_count == 0) {
 | 
			
		||||
        return nullptr;
 | 
			
		||||
    }
 | 
			
		||||
    ggml_backend_rpc_reg_context * ctx = new ggml_backend_rpc_reg_context;
 | 
			
		||||
    ctx->name = "RPC[" + std::string(endpoint) + "]";
 | 
			
		||||
    for (uint32_t ind = 0; ind < dev_count; ind++) {
 | 
			
		||||
        std::string dev_name = "RPC" + std::to_string(dev_id);
 | 
			
		||||
        std::string dev_desc = std::string(endpoint);
 | 
			
		||||
        ggml_backend_rpc_device_context * dev_ctx = new ggml_backend_rpc_device_context {
 | 
			
		||||
            /* .endpoint    = */ endpoint,
 | 
			
		||||
            /* .device      = */ ind,
 | 
			
		||||
            /* .name        = */ dev_name,
 | 
			
		||||
            /* .description = */ dev_desc
 | 
			
		||||
        };
 | 
			
		||||
 | 
			
		||||
        ggml_backend_dev_t dev = new ggml_backend_device {
 | 
			
		||||
            /* .iface   = */ ggml_backend_rpc_device_i,
 | 
			
		||||
            /* .reg     = */ ggml_backend_rpc_reg(),
 | 
			
		||||
            /* .context = */ dev_ctx,
 | 
			
		||||
        };
 | 
			
		||||
        ctx->devices.push_back(dev);
 | 
			
		||||
        dev_id++;
 | 
			
		||||
    }
 | 
			
		||||
    ggml_backend_reg_t reg = new ggml_backend_reg {
 | 
			
		||||
        /* .api_version = */ GGML_BACKEND_API_VERSION,
 | 
			
		||||
        /* .iface       = */ ggml_backend_rpc_reg_interface,
 | 
			
		||||
        /* .context     = */ ctx
 | 
			
		||||
    };
 | 
			
		||||
    reg_map[endpoint] = reg;
 | 
			
		||||
    return reg;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
GGML_BACKEND_DL_IMPL(ggml_backend_rpc_reg)
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user