CANN: Optimize ggml_cann_set_device (#15935)

* CANN: Fix ggml_cann_set_device to avoid redundant device switches

- Added a check to skip aclrtSetDevice if the current device is already set.
- Prevents unnecessary context switches while keeping thread/device consistency.

* CANN: Add a default device ID

This commit is contained in:
Chenguang Li
2025-09-17 14:33:08 +08:00
committed by GitHub
parent 8ff206097c
commit d5fabe3682
2 changed files with 10 additions and 7 deletions

View File

@@ -526,7 +526,10 @@ struct ggml_backend_cann_context {
*/ */
aclrtStream stream(int stream) { aclrtStream stream(int stream) {
if (streams[stream] == nullptr) { if (streams[stream] == nullptr) {
ggml_cann_set_device(device); // If the device is not set here, destroying the stream later may cause a mismatch
// between the thread contexts where the stream was created and destroyed.
// However, I printed the device_id, thread_id, and stream, and they are all consistent.
ACL_CHECK(aclrtSetDevice(device));
ACL_CHECK(aclrtCreateStream(&streams[stream])); ACL_CHECK(aclrtCreateStream(&streams[stream]));
} }
return streams[stream]; return streams[stream];

View File

@@ -75,13 +75,12 @@
* @param device The device ID to set. * @param device The device ID to set.
*/ */
void ggml_cann_set_device(const int32_t device) { void ggml_cann_set_device(const int32_t device) {
// TODO: uncomment these lines after empty context has fixed. int current_device = -1;
// int current_device; aclrtGetDevice(&current_device);
// ACL_CHECK(aclrtGetDevice(&current_device));
// if (device == current_device) { if (device == current_device) {
// return; return;
// } }
ACL_CHECK(aclrtSetDevice(device)); ACL_CHECK(aclrtSetDevice(device));
} }
@@ -1729,6 +1728,7 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context& ctx,
ggml_cann_get_rows(ctx, dst); ggml_cann_get_rows(ctx, dst);
break; break;
case GGML_OP_SET_ROWS: case GGML_OP_SET_ROWS:
std::cout << "lcg GGML_OP_SET_ROWS"<< std::endl;
ggml_cann_set_rows(ctx, dst); ggml_cann_set_rows(ctx, dst);
break; break;
case GGML_OP_DUP: case GGML_OP_DUP: