CANN: fix RoPE cache issue on multi-device (#15629)

* CANN: fix RoPE cache issue on multi-device

RoPE cache only needs to be computed once per token.
However, in multi-device scenarios, not every device starts
computation from layer 0, which may lead to unallocated memory
issues and precision errors.

This commit records the first layer of each device to avoid
the above issues.

* CANN: Optimize first-layer detection method

* CANN: Remove trailing whitespace

* CANN: Only cache the data that can be determined as unchanged through the parameters.

* CANN: Update function comment
This commit is contained in:
hipudding
2025-09-01 08:57:00 +08:00
committed by GitHub
parent e92d53b29e
commit 3dc7397a27
3 changed files with 105 additions and 96 deletions

View File

@@ -964,8 +964,8 @@ void ggml_cann_rms_norm(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
}
aclTensor* acl_gamma = get_f32_cache_acl_tensor(
ctx,
&ctx.f32_one_cache,
ctx.f32_one_cache_element,
&ctx.rms_norm_one_tensor_cache.cache,
ctx.rms_norm_one_tensor_cache.size,
src->ne,
acl_gamma_nb,
1, // dims
@@ -980,8 +980,8 @@ void ggml_cann_rms_norm(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
}
aclTensor* acl_rstd = get_f32_cache_acl_tensor(
ctx,
&ctx.f32_zero_cache,
ctx.f32_zero_cache_element,
&ctx.rms_norm_zero_tensor_cache.cache,
ctx.rms_norm_zero_tensor_cache.size,
src->ne,
acl_rstd_nb,
GGML_MAX_DIMS,
@@ -2248,43 +2248,31 @@ static void aclnn_index_fill_tensor(ggml_backend_cann_context& ctx,
* 5. Compute sin(θ), cos(θ) and optionally scale by attn_factor.
* 6. Expand sin/cos values by repeat or repeat_interleave depending
* on whether @param is_neox is enabled.
* 7. Store the computed values into persistent buffers
* (ctx.rope_sin_ptr / ctx.rope_cos_ptr).
*
* @param ctx The CANN backend context, holding memory pool,
* stream, and persistent buffers for rope init/cache.
* @param dst The destination ggml_tensor whose computation
* depends on the cached RoPE values (usually Qcur/Kcur).
* @param theta_scale Scalar exponent base for computing theta scale values.
* @param freq_scale Frequency scaling factor, applied to theta scale.
* @param attn_factor Attention scaling factor, applied to sin/cos.
* @param is_neox Whether to use Neox-style repeat strategy
* (dim expansion vs repeat_interleave).
* @param ctx The CANN backend context, holding memory pool,
* stream, and persistent buffers for rope init/cache.
* @param dst The destination ggml_tensor whose computation
* depends on the RoPE values (usually Qcur/Kcur).
* @param sin_tensor_buffer Pre-allocated buffer for storing repeated sin values.
* @param cos_tensor_buffer Pre-allocated buffer for storing repeated cos values.
* @param theta_scale Scalar exponent base for computing theta scale values.
* @param freq_scale Frequency scaling factor, applied to theta scale.
* @param attn_factor Attention scaling factor, applied to sin/cos.
* @param is_neox Whether to use Neox-style repeat strategy
* (dim expansion vs repeat_interleave).
*/
static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst,
void* sin_tensor_buffer, void* cos_tensor_buffer,
float theta_scale, float freq_scale,
float attn_factor, bool is_neox) {
// int sin/cos cache, cache has different repeat method depond on
// @param.is_neox
bool is_q = (std::strncmp(dst->name, "Qcur-", 5) == 0);
bool is_k = (std::strncmp(dst->name, "Kcur-", 5) == 0);
// used for accuracy testing
bool is_attention = is_q || is_k;
// just compute in first layer in attention
bool is_fisrt_layer = (std::strncmp(dst->name, "Qcur-0", GGML_MAX_NAME) == 0);
if(is_attention && !is_fisrt_layer) {
return;
}
ggml_tensor* src0 = dst->src[0]; // input
ggml_tensor* src1 = dst->src[1]; // position
ggml_tensor* src2 = dst->src[2]; // freq_factors
GGML_TENSOR_BINARY_OP_LOCALS
int64_t theta_scale_length = ne00 / 2;
int64_t theta_scale_length = src0->ne[0] / 2;
int64_t theta_scale_ne[] = {theta_scale_length, 1, 1, 1};
size_t theta_scale_nb[] = {sizeof(float_t), sizeof(float_t), sizeof(float_t),
theta_scale_length * sizeof(float_t)};
@@ -2302,21 +2290,32 @@ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst,
theta_nb[i] = theta_nb[i - 1] * theta_ne[i - 1];
}
// init theta scale, just one time
if(ctx.rope_init_ptr == nullptr || !is_attention) {
// theta_scale arange, [0,1,...,ne00/2 - 1]
if(ctx.rope_init_ptr != nullptr){
ACL_CHECK(aclrtFree(ctx.rope_init_ptr));
}
ACL_CHECK(aclrtMalloc(&ctx.rope_init_ptr, theta_scale_length * sizeof(float_t), ACL_MEM_MALLOC_HUGE_FIRST));
// theta_scale arange, [0,1,...,ne00/2 - 1]
aclTensor* acl_theta_scale_tensor = nullptr;
// cache theta scale
if (ctx.rope_cache.theta_scale_length != theta_scale_length ||
// theta_scale and freq_scale should not change during the current token inference process,
// so we can directly use == here instead of comparing the absolute difference.
ctx.rope_cache.theta_scale != theta_scale ||
ctx.rope_cache.freq_scale != freq_scale) {
aclTensor* acl_theta_scale_tensor =
ggml_cann_create_tensor(ctx.rope_init_ptr, ACL_FLOAT, sizeof(float_t),
ctx.rope_cache.theta_scale_length = theta_scale_length;
ctx.rope_cache.theta_scale = theta_scale;
ctx.rope_cache.freq_scale = freq_scale;
if (ctx.rope_cache.theta_scale_cache != nullptr) {
ACL_CHECK(aclrtFree(ctx.rope_cache.theta_scale_cache));
}
ACL_CHECK(aclrtMalloc(&ctx.rope_cache.theta_scale_cache, theta_scale_length * sizeof(float_t), ACL_MEM_MALLOC_HUGE_FIRST));
acl_theta_scale_tensor =
ggml_cann_create_tensor(ctx.rope_cache.theta_scale_cache, ACL_FLOAT, sizeof(float_t),
theta_scale_ne, theta_scale_nb, GGML_MAX_DIMS);
float start = 0;
float step = 1;
float stop = ne00 / 2;
float n_elements = ne00 / 2;
float stop = theta_scale_length;
float n_elements = theta_scale_length;
aclnn_arange(ctx, acl_theta_scale_tensor, start, stop, step, n_elements);
// power
@@ -2328,34 +2327,29 @@ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst,
if (freq_scale != 1) {
aclnn_muls(ctx, acl_theta_scale_tensor, freq_scale, nullptr, true);
}
// freq_factors
if (src2) {
aclTensor* acl_freq_factors_tensor = ggml_cann_create_tensor(
src2->data, ggml_cann_type_mapping(src2->type),
ggml_type_size(src2->type), theta_scale_ne, theta_scale_nb, GGML_MAX_DIMS);
aclnn_div(ctx, acl_theta_scale_tensor, acl_freq_factors_tensor);
ggml_cann_release_resources(ctx, acl_freq_factors_tensor);
}
// release
ggml_cann_release_resources(ctx, acl_theta_scale_tensor,acl_theta_scale);
}
// init sin_repeat && cos_repeat, one token just init in 0 layer
if(position_length > ctx.max_prompt_length) {
ctx.max_prompt_length = position_length;
int64_t repeat_theta_length = theta_scale_length * ctx.max_prompt_length * 2;
if(ctx.rope_sin_ptr != nullptr) {
ACL_CHECK(aclrtFree(ctx.rope_sin_ptr));
ACL_CHECK(aclrtFree(ctx.rope_cos_ptr));
}
ACL_CHECK(aclrtMalloc(&ctx.rope_sin_ptr, repeat_theta_length * sizeof(float_t), ACL_MEM_MALLOC_HUGE_FIRST));
ACL_CHECK(aclrtMalloc(&ctx.rope_cos_ptr, repeat_theta_length * sizeof(float_t), ACL_MEM_MALLOC_HUGE_FIRST));
}
aclTensor* acl_theta_scale_tensor =
ggml_cann_create_tensor(ctx.rope_init_ptr, ACL_FLOAT, sizeof(float_t),
ggml_cann_release_resources(ctx, acl_theta_scale);
} else {
// use cache
acl_theta_scale_tensor =
ggml_cann_create_tensor(ctx.rope_cache.theta_scale_cache, ACL_FLOAT, sizeof(float_t),
theta_scale_ne, theta_scale_nb, GGML_MAX_DIMS);
}
ggml_cann_pool_alloc freq_fac_res_allocator(ctx.pool());
// freq_factors
if (src2) {
freq_fac_res_allocator.alloc(theta_scale_length * sizeof(float_t));
void* freq_fac_res_ptr = freq_fac_res_allocator.get();
aclTensor* acl_freq_factors_tensor = ggml_cann_create_tensor(
src2->data, ggml_cann_type_mapping(src2->type),
ggml_type_size(src2->type), theta_scale_ne, theta_scale_nb, GGML_MAX_DIMS);
aclTensor* acl_freq_fac_res_tensor = ggml_cann_create_tensor(
freq_fac_res_ptr, ACL_FLOAT, sizeof(float_t),
theta_scale_ne, theta_scale_nb, GGML_MAX_DIMS);
aclnn_div(ctx, acl_theta_scale_tensor, acl_freq_factors_tensor, acl_freq_fac_res_tensor);
std::swap(acl_theta_scale_tensor, acl_freq_fac_res_tensor);
ggml_cann_release_resources(ctx, acl_freq_factors_tensor, acl_freq_fac_res_tensor);
}
// position
aclTensor* acl_position_tensor = ggml_cann_create_tensor(
@@ -2397,17 +2391,17 @@ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst,
aclnn_muls(ctx, acl_cos_tensor, attn_factor, nullptr, true);
}
int64_t sin_reshape_ne[4] = {ne00, 1, ne02, 1};
int64_t sin_reshape_ne[4] = {src0->ne[0], 1, src0->ne[2], 1};
size_t sin_reshape_nb[GGML_MAX_DIMS];
sin_reshape_nb[0] = sizeof(float_t);
for (int i = 1; i < GGML_MAX_DIMS; i++) {
sin_reshape_nb[i] = sin_reshape_nb[i - 1] * sin_reshape_ne[i - 1];
}
aclTensor* acl_sin_repeat_tensor =
ggml_cann_create_tensor(ctx.rope_sin_ptr, ACL_FLOAT, sizeof(float_t),
ggml_cann_create_tensor(sin_tensor_buffer, ACL_FLOAT, sizeof(float_t),
sin_reshape_ne, sin_reshape_nb, GGML_MAX_DIMS);
aclTensor* acl_cos_repeat_tensor =
ggml_cann_create_tensor(ctx.rope_cos_ptr, ACL_FLOAT, sizeof(float_t),
ggml_cann_create_tensor(cos_tensor_buffer, ACL_FLOAT, sizeof(float_t),
sin_reshape_ne, sin_reshape_nb, GGML_MAX_DIMS);
// repeat
@@ -2449,6 +2443,7 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
// TODO: use ascendc
// Only test with LLAMA model.
ggml_tensor* src0 = dst->src[0]; // input
ggml_tensor* src1 = dst->src[1];
// param
float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
@@ -2481,8 +2476,16 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
// sin/cos tensor length.
int64_t repeat_theta_length = src0->ne[0] * src1->ne[0];
ggml_cann_pool_alloc sin_tensor_allocator(ctx.pool(), repeat_theta_length * sizeof(float));
ggml_cann_pool_alloc cos_tensor_allocator(ctx.pool(), repeat_theta_length * sizeof(float));
void *sin_tensor_buffer = sin_tensor_allocator.get();
void *cos_tensor_buffer = cos_tensor_allocator.get();
// init ctx.rope_cos/rope_sin cache
aclnn_cache_init(ctx, dst, theta_scale, freq_scale, attn_factor, is_neox);
aclnn_cache_init(ctx, dst, sin_tensor_buffer, cos_tensor_buffer,
theta_scale, freq_scale, attn_factor, is_neox);
int64_t sin_reshape_ne[4] = {ne00, 1, ne02, 1};
size_t sin_reshape_nb[GGML_MAX_DIMS];
@@ -2491,10 +2494,10 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
sin_reshape_nb[i] = sin_reshape_nb[i - 1] * sin_reshape_ne[i - 1];
}
aclTensor* acl_sin_reshape_tensor =
ggml_cann_create_tensor(ctx.rope_sin_ptr, ACL_FLOAT, sizeof(float_t),
ggml_cann_create_tensor(sin_tensor_buffer, ACL_FLOAT, sizeof(float_t),
sin_reshape_ne, sin_reshape_nb, GGML_MAX_DIMS);
aclTensor* acl_cos_reshape_tensor =
ggml_cann_create_tensor(ctx.rope_cos_ptr, ACL_FLOAT, sizeof(float_t),
ggml_cann_create_tensor(cos_tensor_buffer, ACL_FLOAT, sizeof(float_t),
sin_reshape_ne, sin_reshape_nb, GGML_MAX_DIMS);
aclTensor* acl_src = ggml_cann_create_tensor(src0);