From d5fabe3682de515fd09d6c981f7a0d1b75614455 Mon Sep 17 00:00:00 2001 From: Chenguang Li <757486878@qq.com> Date: Wed, 17 Sep 2025 14:33:08 +0800 Subject: [PATCH] CANN: Optimize ggml_cann_set_device (#15935) * CANN: Fix ggml_cann_set_device to avoid redundant device switches - Added a check to skip aclrtSetDevice if the current device is already set. - Prevents unnecessary context switches while keeping thread/device consistency. * CANN: add device default id --- ggml/src/ggml-cann/common.h | 5 ++++- ggml/src/ggml-cann/ggml-cann.cpp | 12 ++++++------ 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/ggml/src/ggml-cann/common.h b/ggml/src/ggml-cann/common.h index c5fce8dc91..b707b84359 100755 --- a/ggml/src/ggml-cann/common.h +++ b/ggml/src/ggml-cann/common.h @@ -526,7 +526,10 @@ struct ggml_backend_cann_context { */ aclrtStream stream(int stream) { if (streams[stream] == nullptr) { - ggml_cann_set_device(device); + // If the device is not set here, destroying the stream later may cause a mismatch + // between the thread contexts where the stream was created and destroyed. + // However, I printed the device_id, thread_id, and stream, and they are all consistent. + ACL_CHECK(aclrtSetDevice(device)); ACL_CHECK(aclrtCreateStream(&streams[stream])); } return streams[stream]; diff --git a/ggml/src/ggml-cann/ggml-cann.cpp b/ggml/src/ggml-cann/ggml-cann.cpp index 19a18a281d..56d82b4af3 100755 --- a/ggml/src/ggml-cann/ggml-cann.cpp +++ b/ggml/src/ggml-cann/ggml-cann.cpp @@ -75,13 +75,12 @@ * @param device The device ID to set. */ void ggml_cann_set_device(const int32_t device) { - // TODO: uncomment these lines after empty context has fixed. - // int current_device; - // ACL_CHECK(aclrtGetDevice(¤t_device)); + int current_device = -1; + aclrtGetDevice(¤t_device); - // if (device == current_device) { - // return; - // } + if (device == current_device) { + return; + } ACL_CHECK(aclrtSetDevice(device)); } @@ -1729,6 +1728,7 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context& ctx, ggml_cann_get_rows(ctx, dst); break; case GGML_OP_SET_ROWS: + std::cout << "lcg GGML_OP_SET_ROWS"<< std::endl; ggml_cann_set_rows(ctx, dst); break; case GGML_OP_DUP: