Mirror of https://github.com/ggml-org/llama.cpp.git (synced 2025-10-28 08:31:25 +00:00)
CANN: Add ROPE sin/cos cache for reuse (#15912)
* CANN: Add ROPE sin/cos cache for reuse

Introduce a sin/cos caching mechanism in ROPE to avoid redundant computation
across layers. The cache is built on the first layer per device and reused by
subsequent layers if parameters match.

- Added sin_cache / cos_cache pointers and position_length tracking
- Introduced cache validity flags and properties:
  (ext_factor, theta_scale, freq_scale, attn_factor, is_neox)
- Accelerates ROPE by eliminating repeated sin/cos generation

This change reduces overhead in multi-layer scenarios while preserving
correctness by verifying parameter consistency.

Co-authored-by: hipudding <huafengchun@gmail.com>

* fix typo

Signed-off-by: noemotiovon <757486878@qq.com>

---------

Signed-off-by: noemotiovon <757486878@qq.com>
Co-authored-by: hipudding <huafengchun@gmail.com>
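To make the mechanism concrete, here is a minimal, self-contained sketch of the reuse pattern the commit message describes. Every name in it (RopeParams, RopeCache, rope_cache_init) is an illustrative invention, not a llama.cpp symbol, and the theta formula is deliberately simplified; the real logic lives in aclnn_cache_init() and ggml_cann_rope_cache in the diff below.

// Minimal sketch of parameter-keyed sin/cos reuse (illustrative names only).
#include <cmath>
#include <cstdio>
#include <vector>

struct RopeParams {
    float ext_factor, theta_scale, freq_scale, attn_factor;
    bool  is_neox;
    bool operator==(const RopeParams& o) const {
        return ext_factor == o.ext_factor && theta_scale == o.theta_scale &&
               freq_scale == o.freq_scale && attn_factor == o.attn_factor &&
               is_neox == o.is_neox;
    }
};

struct RopeCache {
    bool cached = false;          // valid only within one graph evaluation
    RopeParams params{};
    std::vector<float> sin_vals;  // stands in for the device-side sin_cache buffer
    std::vector<float> cos_vals;  // stands in for the device-side cos_cache buffer
};

// Build sin/cos once; later layers with identical parameters return early.
void rope_cache_init(RopeCache& cache, const RopeParams& p, int n_pos, int n_dims) {
    if (cache.cached && cache.params == p) {
        return;  // subsequent layers: reuse the cached values
    }
    cache.sin_vals.assign((size_t)n_pos * n_dims, 0.0f);
    cache.cos_vals.assign((size_t)n_pos * n_dims, 0.0f);
    for (int pos = 0; pos < n_pos; ++pos) {
        for (int d = 0; d < n_dims; ++d) {
            // simplified RoPE angle: theta_scale^d, scaled by freq_scale and position
            float theta = p.freq_scale * pos * std::pow(p.theta_scale, (float)d);
            cache.sin_vals[(size_t)pos * n_dims + d] = std::sin(theta) * p.attn_factor;
            cache.cos_vals[(size_t)pos * n_dims + d] = std::cos(theta) * p.attn_factor;
        }
    }
    cache.params = p;
    cache.cached = true;  // the first layer pays the cost, the rest reuse
}

int main() {
    RopeCache cache;
    RopeParams p{0.0f, std::pow(10000.0f, -2.0f / 64.0f), 1.0f, 1.0f, false};
    cache.cached = false;                       // reset once per graph evaluation
    for (int layer = 0; layer < 32; ++layer) {  // only layer 0 actually recomputes
        rope_cache_init(cache, p, /*n_pos=*/8, /*n_dims=*/64);
    }
    std::printf("cos[0] = %f\n", cache.cos_vals[0]);
    return 0;
}

The actual patch follows the same shape: a validity flag plus the parameter set as the cache key, an early return when they match, and device buffers instead of std::vector.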
@@ -2268,8 +2268,6 @@ static void aclnn_index_fill_tensor(ggml_backend_cann_context& ctx,
  * stream, and persistent buffers for rope init/cache.
  * @param dst The destination ggml_tensor whose computation
  * depends on the RoPE values (usually Qcur/Kcur).
- * @param sin_tensor_buffer Pre-allocated buffer for storing repeated sin values.
- * @param cos_tensor_buffer Pre-allocated buffer for storing repeated cos values.
  * @param theta_scale Scalar exponent base for computing theta scale values.
  * @param freq_scale Frequency scaling factor, applied to theta scale.
  * @param attn_factor Attention scaling factor, applied to sin/cos.
@@ -2277,17 +2275,23 @@ static void aclnn_index_fill_tensor(ggml_backend_cann_context& ctx,
  * (dim expansion vs repeat_interleave).
  */
 static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst,
-                             void* sin_tensor_buffer, void* cos_tensor_buffer,
                              float* corr_dims, float ext_factor,
                              float theta_scale, float freq_scale,
                              float attn_factor, bool is_neox) {
-    // int sin/cos cache, cache has different repeat method depond on
-    // @param.is_neox
-
     ggml_tensor* src0 = dst->src[0]; // input
     ggml_tensor* src1 = dst->src[1]; // position
     ggml_tensor* src2 = dst->src[2]; // freq_factors
 
+    if(src2 == nullptr && ctx.rope_cache.cached
+        && ctx.rope_cache.ext_factor == ext_factor
+        && ctx.rope_cache.theta_scale == theta_scale
+        && ctx.rope_cache.freq_scale == freq_scale
+        && ctx.rope_cache.attn_factor == attn_factor
+        && ctx.rope_cache.is_neox == is_neox) {
+        // use cache.
+        return;
+    }
+
     int64_t theta_scale_length = src0->ne[0] / 2;
     int64_t theta_scale_ne[] = {theta_scale_length, 1, 1, 1};
     size_t theta_scale_nb[] = {sizeof(float), sizeof(float), sizeof(float),
@@ -2316,8 +2320,6 @@ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst,
         ctx.rope_cache.freq_scale != freq_scale) {
 
         ctx.rope_cache.theta_scale_length = theta_scale_length;
-        ctx.rope_cache.theta_scale = theta_scale;
-        ctx.rope_cache.freq_scale = freq_scale;
 
         if (ctx.rope_cache.theta_scale_cache != nullptr) {
             ACL_CHECK(aclrtFree(ctx.rope_cache.theta_scale_cache));
@@ -2342,7 +2344,7 @@ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst,
         // return MIN(1, MAX(0, y)) - 1;
         yarn_ramp_allocator.alloc(theta_scale_length * sizeof(float));
         void* yarn_ramp_buffer = yarn_ramp_allocator.get();
-        acl_yarn_ramp_tensor = ggml_cann_create_tensor(yarn_ramp_buffer, ACL_FLOAT, sizeof(float_t),
+        acl_yarn_ramp_tensor = ggml_cann_create_tensor(yarn_ramp_buffer, ACL_FLOAT, sizeof(float),
                                                        theta_scale_ne, theta_scale_nb, GGML_MAX_DIMS);
         float zero_value = 0, one_value = 1;
         float denom_safe_value = MAX(0.001f, corr_dims[1] - corr_dims[0]);
@@ -2411,6 +2413,20 @@ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst,
         ggml_cann_release_resources(ctx, acl_freq_factors_tensor, acl_freq_fac_res_tensor);
     }
 
+    // init sin_repeat && cos_repeat, only to accelerate first layer on each device
+    if (position_length > ctx.rope_cache.position_length) {
+        ctx.rope_cache.position_length = position_length;
+        if (ctx.rope_cache.sin_cache != nullptr) {
+            ACL_CHECK(aclrtFree(ctx.rope_cache.sin_cache));
+        }
+        if (ctx.rope_cache.cos_cache != nullptr) {
+            ACL_CHECK(aclrtFree(ctx.rope_cache.cos_cache));
+        }
+        int64_t repeat_theta_length = theta_scale_length * position_length * 2;
+        ACL_CHECK(aclrtMalloc(&ctx.rope_cache.sin_cache, repeat_theta_length * sizeof(float), ACL_MEM_MALLOC_HUGE_FIRST));
+        ACL_CHECK(aclrtMalloc(&ctx.rope_cache.cos_cache, repeat_theta_length * sizeof(float), ACL_MEM_MALLOC_HUGE_FIRST));
+    }
+
     // position
     aclTensor* acl_position_tensor = ggml_cann_create_tensor(
         src1->data, ggml_cann_type_mapping(src1->type),
@@ -2462,10 +2478,10 @@ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst,
         sin_reshape_nb[i] = sin_reshape_nb[i - 1] * sin_reshape_ne[i - 1];
     }
     aclTensor* acl_sin_repeat_tensor =
-        ggml_cann_create_tensor(sin_tensor_buffer, ACL_FLOAT, sizeof(float),
+        ggml_cann_create_tensor(ctx.rope_cache.sin_cache, ACL_FLOAT, sizeof(float),
                                 sin_reshape_ne, sin_reshape_nb, GGML_MAX_DIMS);
     aclTensor* acl_cos_repeat_tensor =
-        ggml_cann_create_tensor(cos_tensor_buffer, ACL_FLOAT, sizeof(float),
+        ggml_cann_create_tensor(ctx.rope_cache.cos_cache, ACL_FLOAT, sizeof(float),
                                 sin_reshape_ne, sin_reshape_nb, GGML_MAX_DIMS);
 
     // repeat
@@ -2483,6 +2499,14 @@ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst,
                            num_repeats, output_size);
     }
 
+    // Other layers use cache except first layer.
+    ctx.rope_cache.cached = true;
+    ctx.rope_cache.ext_factor = ext_factor;
+    ctx.rope_cache.theta_scale = theta_scale;
+    ctx.rope_cache.freq_scale = freq_scale;
+    ctx.rope_cache.attn_factor = attn_factor;
+    ctx.rope_cache.is_neox = is_neox;
+
     ggml_cann_release_resources(ctx, acl_theta_scale_tensor, acl_position_tensor,
                                 acl_theta_tensor, acl_sin_tensor, acl_sin_repeat_tensor, acl_cos_tensor,
                                 acl_cos_repeat_tensor);
@@ -2504,10 +2528,7 @@ aclnnStatus aclnnRotaryPositionEmbedding(void* workspace,
 #endif
 
 void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
-    // TODO: use ascendc
-    // Only test with LLAMA model.
     ggml_tensor* src0 = dst->src[0]; // input
-    ggml_tensor* src1 = dst->src[1];
 
     // param
     float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
@@ -2538,15 +2559,8 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
 
     const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
 
-    // sin/cos tensor length.
-    int64_t repeat_theta_length = src0->ne[0] * src1->ne[0];
-    ggml_cann_pool_alloc sin_tensor_allocator(ctx.pool(), repeat_theta_length * sizeof(float));
-    ggml_cann_pool_alloc cos_tensor_allocator(ctx.pool(), repeat_theta_length * sizeof(float));
-    void *sin_tensor_buffer = sin_tensor_allocator.get();
-    void *cos_tensor_buffer = cos_tensor_allocator.get();
-
     // init ctx.rope_cos/rope_sin cache
-    aclnn_cache_init(ctx, dst, sin_tensor_buffer, cos_tensor_buffer, corr_dims, ext_factor,
+    aclnn_cache_init(ctx, dst, corr_dims, ext_factor,
                      theta_scale, freq_scale, attn_factor, is_neox);
 
     int64_t sin_reshape_ne[4] = {ne00, 1, ne02, 1};
@@ -2556,10 +2570,10 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
         sin_reshape_nb[i] = sin_reshape_nb[i - 1] * sin_reshape_ne[i - 1];
     }
     aclTensor* acl_sin_reshape_tensor =
-        ggml_cann_create_tensor(sin_tensor_buffer, ACL_FLOAT, sizeof(float),
+        ggml_cann_create_tensor(ctx.rope_cache.sin_cache, ACL_FLOAT, sizeof(float),
                                 sin_reshape_ne, sin_reshape_nb, GGML_MAX_DIMS);
     aclTensor* acl_cos_reshape_tensor =
-        ggml_cann_create_tensor(cos_tensor_buffer, ACL_FLOAT, sizeof(float),
+        ggml_cann_create_tensor(ctx.rope_cache.cos_cache, ACL_FLOAT, sizeof(float),
                                 sin_reshape_ne, sin_reshape_nb, GGML_MAX_DIMS);
 
     aclTensor* acl_src = ggml_cann_create_tensor(src0);
@@ -425,12 +425,27 @@ struct ggml_cann_rope_cache {
         if(theta_scale_cache != nullptr) {
             ACL_CHECK(aclrtFree(theta_scale_cache));
         }
+        if(sin_cache != nullptr) {
+            ACL_CHECK(aclrtFree(sin_cache));
+        }
+        if(cos_cache != nullptr) {
+            ACL_CHECK(aclrtFree(cos_cache));
+        }
     }
 
     void* theta_scale_cache = nullptr;
     int64_t theta_scale_length = 0;
+    // sin/cos cache, used only to accelerate first layer on each device
+    void* sin_cache = nullptr;
+    void* cos_cache = nullptr;
+    int64_t position_length = 0;
+    // Properties to check before reusing the sincos cache
+    bool cached = false;
+    float ext_factor = 0.0f;
     float theta_scale = 0.0f;
     float freq_scale = 0.0f;
+    float attn_factor = 0.0f;
+    bool is_neox = false;
 };
 
 struct ggml_cann_tensor_cache {
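Assembled from the hunk above, the rope cache struct reads roughly as follows after the patch. This is a reconstruction, not a verbatim copy of the header: the destructor signature is inferred from context, and any members not touched by this diff are omitted.

// Reconstructed view of ggml_cann_rope_cache after the patch (destructor name assumed).
struct ggml_cann_rope_cache {
    ~ggml_cann_rope_cache() {
        if(theta_scale_cache != nullptr) {
            ACL_CHECK(aclrtFree(theta_scale_cache));
        }
        if(sin_cache != nullptr) {
            ACL_CHECK(aclrtFree(sin_cache));
        }
        if(cos_cache != nullptr) {
            ACL_CHECK(aclrtFree(cos_cache));
        }
    }

    void* theta_scale_cache = nullptr;
    int64_t theta_scale_length = 0;
    // sin/cos cache, used only to accelerate first layer on each device
    void* sin_cache = nullptr;
    void* cos_cache = nullptr;
    int64_t position_length = 0;
    // Properties to check before reusing the sincos cache
    bool cached = false;
    float ext_factor = 0.0f;
    float theta_scale = 0.0f;
    float freq_scale = 0.0f;
    float attn_factor = 0.0f;
    bool is_neox = false;
};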
@@ -2353,6 +2353,9 @@ static enum ggml_status ggml_backend_cann_graph_compute(
     ggml_cann_set_device(cann_ctx->device);
     g_nz_workspaces[cann_ctx->device].clear();
 
+    // calculate rope cache for fist layer in current device.
+    cann_ctx->rope_cache.cached = false;
+
 #ifdef USE_ACL_GRAPH
     bool use_cann_graph = true;
     bool cann_graph_update_required = false;
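Resetting cached at the top of graph compute means the sin/cos buffers are rebuilt at most once per graph evaluation: the position tensor generally changes between decode steps, so the first ROPE node recomputes the values and every later layer takes the early-return path added in aclnn_cache_init. A rough sketch of the intended per-step flow (the node walk and dispatch below are simplified and hypothetical; the real backend routes ops through its own dispatcher):

// Illustrative outline only, not actual backend code.
cann_ctx->rope_cache.cached = false;          // once per ggml_backend_cann_graph_compute call
for (int i = 0; i < cgraph->n_nodes; ++i) {   // simplified graph walk
    ggml_tensor * node = cgraph->nodes[i];
    if (node->op == GGML_OP_ROPE) {
        // first ROPE node: aclnn_cache_init rebuilds sin_cache/cos_cache for the new positions
        // remaining ROPE nodes: parameters match, so they return early and reuse the buffers
        ggml_cann_rope(*cann_ctx, node);
    }
}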