CANN: Refactor ND to NZ workspace to be per-device (#15763)

* CANN:Refactor ND to NZ workspace to be per-device in Ascend backend

- Replaced the previous single global ND→NZ workspace with a per-device
  cache using unordered_map keyed by device ID.
- Functions `release_nz_workspace`, `relloc_nz_workspace`, and
  `get_nz_workspace` now manage workspace independently for each device,
  preventing memory conflicts in multi-device / pipeline parallel scenarios.
- This change fixes potential precision issues caused by workspace
  overwrites when multiple devices perform ND→NZ conversions concurrently.

Co-authored-by: hipudding <huafengchun@gmail.com>

* refactor

Signed-off-by: noemotiovon <757486878@qq.com>

* rename

Signed-off-by: noemotiovon <757486878@qq.com>

* fix review comments

Signed-off-by: noemotiovon <757486878@qq.com>

---------

Signed-off-by: noemotiovon <757486878@qq.com>
Co-authored-by: hipudding <huafengchun@gmail.com>
This commit is contained in:
Chenguang Li
2025-09-04 20:20:14 +08:00
committed by GitHub
parent a68d914426
commit c1c354e44c

View File

@@ -1116,30 +1116,65 @@ static enum ggml_status ggml_backend_cann_buffer_init_tensor(
return GGML_STATUS_SUCCESS; return GGML_STATUS_SUCCESS;
} }
// ND to NZ Workspace Cache Management. Thread-safety: Not guaranteed /**
namespace { * @brief Workspace for caching NZ buffers per device.
void* g_nz_workspace = nullptr; *
size_t g_nz_workspace_allocated = 0; * This struct manages a device buffer used in NZ computations. It supports
* allocation, reallocation, and clearing of cached memory. The struct is
* designed to be used with a global array, one per device.
*/
struct ggml_cann_nz_workspace {
void* ptr; // Pointer to allocated device buffer
size_t allocated; // Size of currently allocated buffer in bytes
void release_nz_workspace() { /**
if (g_nz_workspace) { * @brief Constructor. Initializes the workspace with no allocated memory.
aclrtFree(g_nz_workspace); */
g_nz_workspace = nullptr; ggml_cann_nz_workspace() : ptr(nullptr), allocated(0) {}
g_nz_workspace_allocated = 0;
/**
* @brief Free cached memory and reset the workspace.
*
* If a buffer has been allocated, this function releases it using
* aclrtFree and resets internal state.
*/
void clear() {
if (ptr) {
ACL_CHECK(aclrtFree(ptr));
ptr = nullptr;
allocated = 0;
} }
} }
void relloc_nz_workspace(size_t new_size) { /**
if (new_size > g_nz_workspace_allocated) { * @brief Allocate or reallocate the workspace buffer.
if (g_nz_workspace) { *
aclrtFree(g_nz_workspace); * If the requested size is larger than the currently allocated size,
g_nz_workspace = nullptr; * the old buffer will be freed and a new buffer of the requested size
* will be allocated on the device.
*
* @param new_size Size in bytes to allocate for the workspace.
*/
void realloc(size_t new_size) {
if (new_size > allocated) {
clear();
ACL_CHECK(aclrtMalloc(&ptr, new_size, ACL_MEM_MALLOC_HUGE_FIRST));
allocated = new_size;
} }
ACL_CHECK(aclrtMalloc(&g_nz_workspace, new_size, ACL_MEM_MALLOC_HUGE_FIRST));
g_nz_workspace_allocated = new_size;
} }
}
} /**
* @brief Get the device buffer pointer.
*
* @return Pointer to the allocated buffer, or nullptr if not allocated.
*/
void* get() const { return ptr; }
};
/**
* @brief Global array of NZ workspaces, one per device.
*/
static ggml_cann_nz_workspace g_nz_workspaces[GGML_CANN_MAX_DEVICES];
/** /**
* @brief Convert tensor weights to NZ format using Ascend CANN API. * @brief Convert tensor weights to NZ format using Ascend CANN API.
@@ -1149,13 +1184,13 @@ namespace {
* improve performance on certain hardware. * improve performance on certain hardware.
* *
* @param tensor Pointer to the input ggml_tensor containing the weights. * @param tensor Pointer to the input ggml_tensor containing the weights.
* @param data Pointer to the raw data buffer for the tensor weights.
* @param offset Byte offset within the tensor data buffer where weights start. * @param offset Byte offset within the tensor data buffer where weights start.
* @param device device id.
* *
* @note The workspace buffer used in this function is managed globally and reused * @note The workspace buffer used in this function is managed globally and reused
* across calls. This reduces overhead from repeated memory allocation and deallocation. * across calls. This reduces overhead from repeated memory allocation and deallocation.
*/ */
static void weight_format_to_nz(ggml_tensor *tensor, size_t offset) { static void weight_format_to_nz(ggml_tensor *tensor, size_t offset, int device) {
aclTensor* weightTransposed = ggml_cann_create_tensor(tensor, tensor->ne, aclTensor* weightTransposed = ggml_cann_create_tensor(tensor, tensor->ne,
tensor->nb, 2, ACL_FORMAT_ND, offset); tensor->nb, 2, ACL_FORMAT_ND, offset);
uint64_t workspaceSize = 0; uint64_t workspaceSize = 0;
@@ -1165,7 +1200,9 @@ static void weight_format_to_nz(ggml_tensor *tensor, size_t offset) {
ACL_CHECK(aclnnTransMatmulWeightGetWorkspaceSize(weightTransposed, ACL_CHECK(aclnnTransMatmulWeightGetWorkspaceSize(weightTransposed,
&workspaceSize, &executor)); &workspaceSize, &executor));
// Avoid frequent malloc/free of the workspace. // Avoid frequent malloc/free of the workspace.
relloc_nz_workspace(workspaceSize); g_nz_workspaces[device].realloc(workspaceSize);
void* g_nz_workspace = g_nz_workspaces[device].get();
ACL_CHECK(aclnnTransMatmulWeight(g_nz_workspace, workspaceSize, executor, nullptr)); ACL_CHECK(aclnnTransMatmulWeight(g_nz_workspace, workspaceSize, executor, nullptr));
ACL_CHECK(aclDestroyTensor(weightTransposed)); ACL_CHECK(aclDestroyTensor(weightTransposed));
@@ -1203,7 +1240,7 @@ static void ggml_backend_cann_buffer_set_tensor(
if (weight_to_nz && is_matmul_weight((const ggml_tensor*)tensor)) { if (weight_to_nz && is_matmul_weight((const ggml_tensor*)tensor)) {
GGML_ASSERT(tensor->ne[2] == 1); GGML_ASSERT(tensor->ne[2] == 1);
GGML_ASSERT(tensor->ne[3] == 1); GGML_ASSERT(tensor->ne[3] == 1);
weight_format_to_nz(tensor, offset); weight_format_to_nz(tensor, offset, ctx->device);
} }
} else { } else {
void *transform_buffer = malloc(size); void *transform_buffer = malloc(size);
@@ -2262,7 +2299,7 @@ static enum ggml_status ggml_backend_cann_graph_compute(
ggml_backend_cann_context* cann_ctx = ggml_backend_cann_context* cann_ctx =
(ggml_backend_cann_context*)backend->context; (ggml_backend_cann_context*)backend->context;
ggml_cann_set_device(cann_ctx->device); ggml_cann_set_device(cann_ctx->device);
release_nz_workspace(); g_nz_workspaces[cann_ctx->device].clear();
#ifdef USE_ACL_GRAPH #ifdef USE_ACL_GRAPH
bool use_cann_graph = true; bool use_cann_graph = true;