CANN: fix RoPE cache issue on multi-device (#15629)

* CANN: fix RoPE cache issue on multi-device

The RoPE cache only needs to be computed once per token. However, in
multi-device scenarios, not every device starts computation from
layer 0, which can lead to reads of unallocated memory and to
precision errors.

This commit records the first layer processed on each device to
avoid these issues.

* CANN: Optimize first-layer detection method

* CANN: Remove trailing whitespace

* CANN: Only cache data that the parameters guarantee to be unchanged.

* CANN: Update function comment
Author:    hipudding
Date:      2025-09-01 08:57:00 +08:00
Committed: GitHub
Parent:    e92d53b29e
Commit:    3dc7397a27

3 changed files with 105 additions and 96 deletions
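
The final approach replaces name-based first-layer detection with parameter-keyed caching: the theta_scale table is recomputed only when the parameters that determine its contents change. Below is a minimal sketch of that invalidation rule, using simplified types and a hypothetical needs_update() helper (the real struct is ggml_cann_rope_cache, shown in the header diff further down):

    #include <cstdint>

    // Simplified mirror of the rope cache introduced by this commit.
    struct rope_cache {
        void*   theta_scale_cache  = nullptr; // device buffer holding the theta table
        int64_t theta_scale_length = 0;
        float   theta_scale        = 0.0f;
        float   freq_scale         = 0.0f;

        // Exact float comparison is deliberate: within one token's inference
        // these parameters are constant, so any difference means "recompute".
        bool needs_update(int64_t length, float scale, float freq) const {
            return theta_scale_length != length ||
                   theta_scale != scale ||
                   freq_scale != freq;
        }
    };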


@@ -964,8 +964,8 @@ void ggml_cann_rms_norm(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
     }
     aclTensor* acl_gamma = get_f32_cache_acl_tensor(
         ctx,
-        &ctx.f32_one_cache,
-        ctx.f32_one_cache_element,
+        &ctx.rms_norm_one_tensor_cache.cache,
+        ctx.rms_norm_one_tensor_cache.size,
         src->ne,
         acl_gamma_nb,
         1,  // dims
@@ -980,8 +980,8 @@ void ggml_cann_rms_norm(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
     }
     aclTensor* acl_rstd = get_f32_cache_acl_tensor(
         ctx,
-        &ctx.f32_zero_cache,
-        ctx.f32_zero_cache_element,
+        &ctx.rms_norm_zero_tensor_cache.cache,
+        ctx.rms_norm_zero_tensor_cache.size,
         src->ne,
         acl_rstd_nb,
         GGML_MAX_DIMS,
@@ -2248,43 +2248,31 @@ static void aclnn_index_fill_tensor(ggml_backend_cann_context& ctx,
  * 5. Compute sin(θ), cos(θ) and optionally scale by attn_factor.
  * 6. Expand sin/cos values by repeat or repeat_interleave depending
  *    on whether @param is_neox is enabled.
- * 7. Store the computed values into persistent buffers
- *    (ctx.rope_sin_ptr / ctx.rope_cos_ptr).
  *
  * @param ctx         The CANN backend context, holding memory pool,
  *                    stream, and persistent buffers for rope init/cache.
  * @param dst         The destination ggml_tensor whose computation
- *                    depends on the cached RoPE values (usually Qcur/Kcur).
- * @param theta_scale Scalar exponent base for computing theta scale values.
- * @param freq_scale  Frequency scaling factor, applied to theta scale.
- * @param attn_factor Attention scaling factor, applied to sin/cos.
- * @param is_neox     Whether to use Neox-style repeat strategy
- *                    (dim expansion vs repeat_interleave).
+ *                    depends on the RoPE values (usually Qcur/Kcur).
+ * @param sin_tensor_buffer Pre-allocated buffer for storing repeated sin values.
+ * @param cos_tensor_buffer Pre-allocated buffer for storing repeated cos values.
+ * @param theta_scale Scalar exponent base for computing theta scale values.
+ * @param freq_scale  Frequency scaling factor, applied to theta scale.
+ * @param attn_factor Attention scaling factor, applied to sin/cos.
+ * @param is_neox     Whether to use Neox-style repeat strategy
+ *                    (dim expansion vs repeat_interleave).
  */
 static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst,
+                             void* sin_tensor_buffer, void* cos_tensor_buffer,
                              float theta_scale, float freq_scale,
                              float attn_factor, bool is_neox) {
     // int sin/cos cache, cache has different repeat method depond on
     // @param.is_neox
-    bool is_q = (std::strncmp(dst->name, "Qcur-", 5) == 0);
-    bool is_k = (std::strncmp(dst->name, "Kcur-", 5) == 0);
-
-    // used for accuracy testing
-    bool is_attention = is_q || is_k;
-
-    // just compute in first layer in attention
-    bool is_fisrt_layer = (std::strncmp(dst->name, "Qcur-0", GGML_MAX_NAME) == 0);
-
-    if(is_attention && !is_fisrt_layer) {
-        return;
-    }
-
     ggml_tensor* src0 = dst->src[0];  // input
     ggml_tensor* src1 = dst->src[1];  // position
     ggml_tensor* src2 = dst->src[2];  // freq_factors
-    GGML_TENSOR_BINARY_OP_LOCALS
-
-    int64_t theta_scale_length = ne00 / 2;
+    int64_t theta_scale_length = src0->ne[0] / 2;
     int64_t theta_scale_ne[] = {theta_scale_length, 1, 1, 1};
     size_t theta_scale_nb[] = {sizeof(float_t), sizeof(float_t), sizeof(float_t),
                                theta_scale_length * sizeof(float_t)};
@@ -2302,21 +2290,32 @@ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst,
         theta_nb[i] = theta_nb[i - 1] * theta_ne[i - 1];
     }
-    // init theta scale, just one time
-    if(ctx.rope_init_ptr == nullptr || !is_attention) {
-        // theta_scale arange, [0,1,...,ne00/2 - 1]
-        if(ctx.rope_init_ptr != nullptr){
-            ACL_CHECK(aclrtFree(ctx.rope_init_ptr));
-        }
-        ACL_CHECK(aclrtMalloc(&ctx.rope_init_ptr, theta_scale_length * sizeof(float_t), ACL_MEM_MALLOC_HUGE_FIRST));
-
-        aclTensor* acl_theta_scale_tensor =
-            ggml_cann_create_tensor(ctx.rope_init_ptr, ACL_FLOAT, sizeof(float_t),
+    // theta_scale arange, [0,1,...,ne00/2 - 1]
+    aclTensor* acl_theta_scale_tensor = nullptr;
+    // cache theta scale
+    if (ctx.rope_cache.theta_scale_length != theta_scale_length ||
+        // theta_scale and freq_scale should not change during the current token inference process,
+        // so we can directly use == here instead of comparing the absolute difference.
+        ctx.rope_cache.theta_scale != theta_scale ||
+        ctx.rope_cache.freq_scale != freq_scale) {
+        ctx.rope_cache.theta_scale_length = theta_scale_length;
+        ctx.rope_cache.theta_scale = theta_scale;
+        ctx.rope_cache.freq_scale = freq_scale;
+
+        if (ctx.rope_cache.theta_scale_cache != nullptr) {
+            ACL_CHECK(aclrtFree(ctx.rope_cache.theta_scale_cache));
+        }
+        ACL_CHECK(aclrtMalloc(&ctx.rope_cache.theta_scale_cache, theta_scale_length * sizeof(float_t), ACL_MEM_MALLOC_HUGE_FIRST));
+
+        acl_theta_scale_tensor =
+            ggml_cann_create_tensor(ctx.rope_cache.theta_scale_cache, ACL_FLOAT, sizeof(float_t),
                                     theta_scale_ne, theta_scale_nb, GGML_MAX_DIMS);
         float start = 0;
         float step = 1;
-        float stop = ne00 / 2;
-        float n_elements = ne00 / 2;
+        float stop = theta_scale_length;
+        float n_elements = theta_scale_length;
         aclnn_arange(ctx, acl_theta_scale_tensor, start, stop, step, n_elements);
         // power
@@ -2328,34 +2327,29 @@ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst,
         if (freq_scale != 1) {
             aclnn_muls(ctx, acl_theta_scale_tensor, freq_scale, nullptr, true);
         }
-
-        // freq_factors
-        if (src2) {
-            aclTensor* acl_freq_factors_tensor = ggml_cann_create_tensor(
-                src2->data, ggml_cann_type_mapping(src2->type),
-                ggml_type_size(src2->type), theta_scale_ne, theta_scale_nb, GGML_MAX_DIMS);
-            aclnn_div(ctx, acl_theta_scale_tensor, acl_freq_factors_tensor);
-            ggml_cann_release_resources(ctx, acl_freq_factors_tensor);
-        }
-        // release
-        ggml_cann_release_resources(ctx, acl_theta_scale_tensor,acl_theta_scale);
-    }
-
-    // init sin_repeat && cos_repeat, one token just init in 0 layer
-    if(position_length > ctx.max_prompt_length) {
-        ctx.max_prompt_length = position_length;
-        int64_t repeat_theta_length = theta_scale_length * ctx.max_prompt_length * 2;
-        if(ctx.rope_sin_ptr != nullptr) {
-            ACL_CHECK(aclrtFree(ctx.rope_sin_ptr));
-            ACL_CHECK(aclrtFree(ctx.rope_cos_ptr));
-        }
-        ACL_CHECK(aclrtMalloc(&ctx.rope_sin_ptr, repeat_theta_length * sizeof(float_t), ACL_MEM_MALLOC_HUGE_FIRST));
-        ACL_CHECK(aclrtMalloc(&ctx.rope_cos_ptr, repeat_theta_length * sizeof(float_t), ACL_MEM_MALLOC_HUGE_FIRST));
-    }
-
-    aclTensor* acl_theta_scale_tensor =
-        ggml_cann_create_tensor(ctx.rope_init_ptr, ACL_FLOAT, sizeof(float_t),
+        ggml_cann_release_resources(ctx, acl_theta_scale);
+    } else {
+        // use cache
+        acl_theta_scale_tensor =
+            ggml_cann_create_tensor(ctx.rope_cache.theta_scale_cache, ACL_FLOAT, sizeof(float_t),
                                     theta_scale_ne, theta_scale_nb, GGML_MAX_DIMS);
+    }
+
+    ggml_cann_pool_alloc freq_fac_res_allocator(ctx.pool());
+    // freq_factors
+    if (src2) {
+        freq_fac_res_allocator.alloc(theta_scale_length * sizeof(float_t));
+        void* freq_fac_res_ptr = freq_fac_res_allocator.get();
+        aclTensor* acl_freq_factors_tensor = ggml_cann_create_tensor(
+            src2->data, ggml_cann_type_mapping(src2->type),
+            ggml_type_size(src2->type), theta_scale_ne, theta_scale_nb, GGML_MAX_DIMS);
+        aclTensor* acl_freq_fac_res_tensor = ggml_cann_create_tensor(
+            freq_fac_res_ptr, ACL_FLOAT, sizeof(float_t),
+            theta_scale_ne, theta_scale_nb, GGML_MAX_DIMS);
+        aclnn_div(ctx, acl_theta_scale_tensor, acl_freq_factors_tensor, acl_freq_fac_res_tensor);
+        std::swap(acl_theta_scale_tensor, acl_freq_fac_res_tensor);
+        ggml_cann_release_resources(ctx, acl_freq_factors_tensor, acl_freq_fac_res_tensor);
+    }
     // position
     aclTensor* acl_position_tensor = ggml_cann_create_tensor(
@@ -2397,17 +2391,17 @@ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst,
         aclnn_muls(ctx, acl_cos_tensor, attn_factor, nullptr, true);
     }
-    int64_t sin_reshape_ne[4] = {ne00, 1, ne02, 1};
+    int64_t sin_reshape_ne[4] = {src0->ne[0], 1, src0->ne[2], 1};
     size_t sin_reshape_nb[GGML_MAX_DIMS];
     sin_reshape_nb[0] = sizeof(float_t);
     for (int i = 1; i < GGML_MAX_DIMS; i++) {
         sin_reshape_nb[i] = sin_reshape_nb[i - 1] * sin_reshape_ne[i - 1];
     }
     aclTensor* acl_sin_repeat_tensor =
-        ggml_cann_create_tensor(ctx.rope_sin_ptr, ACL_FLOAT, sizeof(float_t),
+        ggml_cann_create_tensor(sin_tensor_buffer, ACL_FLOAT, sizeof(float_t),
                                 sin_reshape_ne, sin_reshape_nb, GGML_MAX_DIMS);
     aclTensor* acl_cos_repeat_tensor =
-        ggml_cann_create_tensor(ctx.rope_cos_ptr, ACL_FLOAT, sizeof(float_t),
+        ggml_cann_create_tensor(cos_tensor_buffer, ACL_FLOAT, sizeof(float_t),
                                 sin_reshape_ne, sin_reshape_nb, GGML_MAX_DIMS);
     // repeat
@@ -2449,6 +2443,7 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
     // TODO: use ascendc
     // Only test with LLAMA model.
     ggml_tensor* src0 = dst->src[0];  // input
+    ggml_tensor* src1 = dst->src[1];
 
     // param
     float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
@@ -2481,8 +2476,16 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
     const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
 
+    // sin/cos tensor length.
+    int64_t repeat_theta_length = src0->ne[0] * src1->ne[0];
+    ggml_cann_pool_alloc sin_tensor_allocator(ctx.pool(), repeat_theta_length * sizeof(float));
+    ggml_cann_pool_alloc cos_tensor_allocator(ctx.pool(), repeat_theta_length * sizeof(float));
+    void *sin_tensor_buffer = sin_tensor_allocator.get();
+    void *cos_tensor_buffer = cos_tensor_allocator.get();
+
     // init ctx.rope_cos/rope_sin cache
-    aclnn_cache_init(ctx, dst, theta_scale, freq_scale, attn_factor, is_neox);
+    aclnn_cache_init(ctx, dst, sin_tensor_buffer, cos_tensor_buffer,
+                     theta_scale, freq_scale, attn_factor, is_neox);
 
     int64_t sin_reshape_ne[4] = {ne00, 1, ne02, 1};
     size_t sin_reshape_nb[GGML_MAX_DIMS];
@@ -2491,10 +2494,10 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
         sin_reshape_nb[i] = sin_reshape_nb[i - 1] * sin_reshape_ne[i - 1];
     }
     aclTensor* acl_sin_reshape_tensor =
-        ggml_cann_create_tensor(ctx.rope_sin_ptr, ACL_FLOAT, sizeof(float_t),
+        ggml_cann_create_tensor(sin_tensor_buffer, ACL_FLOAT, sizeof(float_t),
                                 sin_reshape_ne, sin_reshape_nb, GGML_MAX_DIMS);
     aclTensor* acl_cos_reshape_tensor =
-        ggml_cann_create_tensor(ctx.rope_cos_ptr, ACL_FLOAT, sizeof(float_t),
+        ggml_cann_create_tensor(cos_tensor_buffer, ACL_FLOAT, sizeof(float_t),
                                 sin_reshape_ne, sin_reshape_nb, GGML_MAX_DIMS);
     aclTensor* acl_src = ggml_cann_create_tensor(src0);
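
With this change the repeated sin/cos values no longer live in persistent per-context buffers (the old ctx.rope_sin_ptr / ctx.rope_cos_ptr); they are scratch memory drawn from the context's pool for the duration of a single ggml_cann_rope call, so every device computes its own values regardless of which layer it starts from. A sketch of the pattern, assuming only the pool-allocator API visible in the hunk above (rope_scratch_example is a hypothetical helper):

    // Scratch buffers are returned to the pool when the allocators go out
    // of scope, so no stale sin/cos state can survive across ops or devices.
    void rope_scratch_example(ggml_backend_cann_context& ctx, int64_t n) {
        ggml_cann_pool_alloc sin_alloc(ctx.pool(), n * sizeof(float));
        ggml_cann_pool_alloc cos_alloc(ctx.pool(), n * sizeof(float));
        void* sin_buf = sin_alloc.get();
        void* cos_buf = cos_alloc.get();
        // ... fill and consume sin_buf / cos_buf within this op ...
    }   // pool memory released here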


@@ -360,6 +360,30 @@ struct ggml_cann_graph {
 };
 #endif  // USE_ACL_GRAPH
+struct ggml_cann_rope_cache {
+    ~ggml_cann_rope_cache() {
+        if(theta_scale_cache != nullptr) {
+            ACL_CHECK(aclrtFree(theta_scale_cache));
+        }
+    }
+
+    void* theta_scale_cache = nullptr;
+    int64_t theta_scale_length = 0;
+    float theta_scale = 0.0f;
+    float freq_scale = 0.0f;
+};
+
+struct ggml_cann_tensor_cache {
+    ~ggml_cann_tensor_cache() {
+        if(cache != nullptr) {
+            ACL_CHECK(aclrtFree(cache));
+        }
+    }
+
+    void* cache = nullptr;
+    int64_t size = 0;
+};
+
 /**
  * @brief Context for managing CANN backend operations.
  */
@@ -375,15 +399,11 @@ struct ggml_backend_cann_context {
     cann_task_queue task_queue;
     bool async_mode;
 
     // Rope Cache
-    void* rope_init_ptr = nullptr;
-    void* rope_sin_ptr = nullptr;
-    void* rope_cos_ptr = nullptr;
-    int64_t max_prompt_length = 0;
+    ggml_cann_rope_cache rope_cache;
     // Constant Pool
-    void* f32_zero_cache = nullptr;
-    void* f32_one_cache = nullptr;
-    int64_t f32_zero_cache_element = 0;
-    int64_t f32_one_cache_element = 0;
+    ggml_cann_tensor_cache rms_norm_one_tensor_cache;
+    ggml_cann_tensor_cache rms_norm_zero_tensor_cache;
 
     aclrtStream streams[GGML_CANN_MAX_STREAMS] = {nullptr}; /**< Array of streams for the device. */
@@ -415,21 +435,6 @@ struct ggml_backend_cann_context {
                 ACL_CHECK(aclrtDestroyStream(streams[i]));
             }
         }
-        if(rope_init_ptr != nullptr) {
-            ACL_CHECK(aclrtFree(rope_init_ptr));
-        }
-        if(rope_sin_ptr != nullptr) {
-            ACL_CHECK(aclrtFree(rope_sin_ptr));
-        }
-        if(rope_cos_ptr != nullptr) {
-            ACL_CHECK(aclrtFree(rope_cos_ptr));
-        }
-        if(f32_zero_cache != nullptr) {
-            ACL_CHECK(aclrtFree(f32_zero_cache));
-        }
-        if(f32_one_cache != nullptr) {
-            ACL_CHECK(aclrtFree(f32_one_cache));
-        }
     }
 
     /**
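
The new cache structs own their device buffers and free them in their destructors, which is why the hand-written if/aclrtFree chains could be deleted from the ~ggml_backend_cann_context destructor: once rope_cache and the two rms_norm caches are value members, their cleanup runs automatically when the context is destroyed. A stripped-down illustration of the same RAII pattern (device_buffer is a hypothetical name, not part of the patch):

    // Owning wrapper: freeing happens in exactly one place, the destructor.
    struct device_buffer {
        void* ptr = nullptr;
        ~device_buffer() {
            if (ptr != nullptr) {
                ACL_CHECK(aclrtFree(ptr));
            }
        }
    };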


@@ -2247,6 +2247,7 @@ static enum ggml_status ggml_backend_cann_graph_compute(
         (ggml_backend_cann_context*)backend->context;
     ggml_cann_set_device(cann_ctx->device);
     release_nz_workspace();
+
 #ifdef USE_ACL_GRAPH
     bool use_cann_graph = true;
     bool cann_graph_update_required = false;