mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-11-04 09:32:00 +00:00 
			
		
		
		
	rpc : check src buffer when copying tensor (#16421)
Only the dst buffer is guaranteed to be an RPC buffer, so add a check that the src buffer is an RPC buffer as well.
This commit is contained in:
		
				
					committed by
					
						
						GitHub
					
				
			
			
				
	
			
			
			
						parent
						
							898acba681
						
					
				
				
					commit
					f39283960b
				
			@@ -631,23 +631,30 @@ static void ggml_backend_rpc_buffer_get_tensor(ggml_backend_buffer_t buffer, con
 | 
			
		||||
    RPC_STATUS_ASSERT(status);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Reports whether `buffer` is an RPC buffer by comparing its free_buffer
// callback against the RPC backend's own free_buffer implementation.
static bool ggml_backend_buffer_is_rpc(ggml_backend_buffer_t buffer) {
    const bool uses_rpc_free_fn = (ggml_backend_rpc_buffer_free_buffer == buffer->iface.free_buffer);
    return uses_rpc_free_fn;
}
 | 
			
		||||
 | 
			
		||||
// Copies tensor `src` into tensor `dst` by sending RPC_CMD_COPY_TENSOR over
// the socket of `buffer`.
//
// Only the dst buffer is guaranteed to be an RPC buffer, so the src buffer is
// checked with ggml_backend_buffer_is_rpc() before its context is interpreted
// as a ggml_backend_rpc_buffer_context.
//
// Returns false without sending anything when src is not in an RPC buffer or
// when src and dst use different sockets; otherwise returns the result
// reported in the RPC response.
static bool ggml_backend_rpc_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * src, ggml_tensor * dst) {
    // src may belong to another backend; its context must not be cast
    // to the RPC context type unless it really is an RPC buffer
    if (!ggml_backend_buffer_is_rpc(src->buffer)) {
        return false;
    }

    // check if src and dst are on the same server
    ggml_backend_buffer_t src_buffer = src->buffer;
    ggml_backend_rpc_buffer_context * src_ctx = (ggml_backend_rpc_buffer_context *)src_buffer->context;
    ggml_backend_buffer_t dst_buffer = dst->buffer;
    ggml_backend_rpc_buffer_context * dst_ctx = (ggml_backend_rpc_buffer_context *)dst_buffer->context;
    if (src_ctx->sock != dst_ctx->sock) {
        return false;
    }

    ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
    rpc_msg_copy_tensor_req request;
    request.src = serialize_tensor(src);
    request.dst = serialize_tensor(dst);
    rpc_msg_copy_tensor_rsp response;
    bool status = send_rpc_cmd(ctx->sock, RPC_CMD_COPY_TENSOR, &request, sizeof(request), &response, sizeof(response));
    RPC_STATUS_ASSERT(status);
    return response.result;
}
 | 
			
		||||
 | 
			
		||||
static void ggml_backend_rpc_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user