mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-11-12 10:47:01 +00:00
CANN: Improve device ID handling and aclnnArange checks (#16752)
* cann: improve device ID handling and aclnnArange checks
  - Stop relying on CANN's internal device ID retrieval; use a global variable instead.
  - Enforce stricter dimension validation in aclnnArange for better compatibility across CANN versions.
* cann: use thread local var
This commit is contained in:
@@ -2234,7 +2234,7 @@ static void aclnn_cache_init(ggml_backend_cann_context & ctx,
|
|||||||
ACL_MEM_MALLOC_HUGE_FIRST));
|
ACL_MEM_MALLOC_HUGE_FIRST));
|
||||||
|
|
||||||
acl_theta_scale_tensor = ggml_cann_create_tensor(ctx.rope_cache.theta_scale_cache, ACL_FLOAT, sizeof(float),
|
acl_theta_scale_tensor = ggml_cann_create_tensor(ctx.rope_cache.theta_scale_cache, ACL_FLOAT, sizeof(float),
|
||||||
theta_scale_ne, theta_scale_nb, GGML_MAX_DIMS);
|
theta_scale_ne, theta_scale_nb, 1);
|
||||||
|
|
||||||
float start = 0;
|
float start = 0;
|
||||||
float step = 1;
|
float step = 1;
|
||||||
@@ -2251,7 +2251,7 @@ static void aclnn_cache_init(ggml_backend_cann_context & ctx,
|
|||||||
yarn_ramp_allocator.alloc(theta_scale_length * sizeof(float));
|
yarn_ramp_allocator.alloc(theta_scale_length * sizeof(float));
|
||||||
void * yarn_ramp_buffer = yarn_ramp_allocator.get();
|
void * yarn_ramp_buffer = yarn_ramp_allocator.get();
|
||||||
acl_yarn_ramp_tensor = ggml_cann_create_tensor(yarn_ramp_buffer, ACL_FLOAT, sizeof(float), theta_scale_ne,
|
acl_yarn_ramp_tensor = ggml_cann_create_tensor(yarn_ramp_buffer, ACL_FLOAT, sizeof(float), theta_scale_ne,
|
||||||
theta_scale_nb, GGML_MAX_DIMS);
|
theta_scale_nb, 1);
|
||||||
float zero_value = 0, one_value = 1;
|
float zero_value = 0, one_value = 1;
|
||||||
float denom_safe_value = MAX(0.001f, corr_dims[1] - corr_dims[0]);
|
float denom_safe_value = MAX(0.001f, corr_dims[1] - corr_dims[0]);
|
||||||
aclScalar * low = aclCreateScalar(&corr_dims[0], aclDataType::ACL_FLOAT);
|
aclScalar * low = aclCreateScalar(&corr_dims[0], aclDataType::ACL_FLOAT);
|
||||||
|
|||||||
@@ -67,19 +67,30 @@
|
|||||||
GGML_ABORT("CANN error");
|
GGML_ABORT("CANN error");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Thread-local variable to record the current device of this thread.
|
||||||
|
thread_local int g_current_cann_device = -1;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Sets the device to be used by CANN.
|
* @brief Set the CANN device to be used.
|
||||||
*
|
*
|
||||||
* @param device The device ID to set.
|
* @param device The target device ID to set.
|
||||||
*/
|
*/
|
||||||
void ggml_cann_set_device(const int32_t device) {
|
void ggml_cann_set_device(const int32_t device) {
|
||||||
int current_device = -1;
|
// int current_device = -1;
|
||||||
aclrtGetDevice(&current_device);
|
// Note: In some CANN versions, if no device has been set yet,
|
||||||
|
// aclrtGetDevice(&current_device) may return 0 by default.
|
||||||
|
// aclrtGetDevice(&current_device);
|
||||||
|
|
||||||
if (device == current_device) {
|
// If the current device is already the target one, no need to switch.
|
||||||
|
if (device == g_current_cann_device) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Switch to the new device.
|
||||||
ACL_CHECK(aclrtSetDevice(device));
|
ACL_CHECK(aclrtSetDevice(device));
|
||||||
|
|
||||||
|
// Update the global device record.
|
||||||
|
g_current_cann_device = device;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|||||||
Reference in New Issue
Block a user