diff --git a/ggml/src/ggml-cann/common.h b/ggml/src/ggml-cann/common.h index c5fce8dc91..b707b84359 100755 --- a/ggml/src/ggml-cann/common.h +++ b/ggml/src/ggml-cann/common.h @@ -526,7 +526,10 @@ struct ggml_backend_cann_context { */ aclrtStream stream(int stream) { if (streams[stream] == nullptr) { - ggml_cann_set_device(device); + // If the device is not set here, destroying the stream later may cause a mismatch + // between the thread contexts where the stream was created and destroyed. + // However, I printed the device_id, thread_id, and stream, and they are all consistent. + ACL_CHECK(aclrtSetDevice(device)); ACL_CHECK(aclrtCreateStream(&streams[stream])); } return streams[stream]; diff --git a/ggml/src/ggml-cann/ggml-cann.cpp b/ggml/src/ggml-cann/ggml-cann.cpp index 19a18a281d..56d82b4af3 100755 --- a/ggml/src/ggml-cann/ggml-cann.cpp +++ b/ggml/src/ggml-cann/ggml-cann.cpp @@ -75,13 +75,12 @@ * @param device The device ID to set. */ void ggml_cann_set_device(const int32_t device) { - // TODO: uncomment these lines after empty context has fixed. - // int current_device; - // ACL_CHECK(aclrtGetDevice(¤t_device)); + int current_device = -1; + aclrtGetDevice(¤t_device); - // if (device == current_device) { - // return; - // } + if (device == current_device) { + return; + } ACL_CHECK(aclrtSetDevice(device)); } @@ -1729,6 +1728,7 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context& ctx, ggml_cann_get_rows(ctx, dst); break; case GGML_OP_SET_ROWS: + std::cout << "lcg GGML_OP_SET_ROWS"<< std::endl; ggml_cann_set_rows(ctx, dst); break; case GGML_OP_DUP: