mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-10-31 08:51:55 +00:00 
			
		
		
		
	server : add more clean up when cancel_tasks is called (#11340)
* server : add more clean up when cancel_tasks is called * fix recv_with_timeout * std::remove_if * fix std::remove_if
This commit is contained in:
		| @@ -1433,6 +1433,10 @@ struct server_queue { | |||||||
|         } else { |         } else { | ||||||
|             queue_tasks.push_back(std::move(task)); |             queue_tasks.push_back(std::move(task)); | ||||||
|         } |         } | ||||||
|  |         // if this is cancel task make sure to clean up pending tasks | ||||||
|  |         if (task.type == SERVER_TASK_TYPE_CANCEL) { | ||||||
|  |             cleanup_pending_task(task.id_target); | ||||||
|  |         } | ||||||
|         condition_tasks.notify_one(); |         condition_tasks.notify_one(); | ||||||
|         return task.id; |         return task.id; | ||||||
|     } |     } | ||||||
| @@ -1450,6 +1454,10 @@ struct server_queue { | |||||||
|             } else { |             } else { | ||||||
|                 queue_tasks.push_back(std::move(task)); |                 queue_tasks.push_back(std::move(task)); | ||||||
|             } |             } | ||||||
|  |             // if this is cancel task make sure to clean up pending tasks | ||||||
|  |             if (task.type == SERVER_TASK_TYPE_CANCEL) { | ||||||
|  |                 cleanup_pending_task(task.id_target); | ||||||
|  |             } | ||||||
|         } |         } | ||||||
|         condition_tasks.notify_one(); |         condition_tasks.notify_one(); | ||||||
|         return 0; |         return 0; | ||||||
| @@ -1544,6 +1552,20 @@ struct server_queue { | |||||||
|             } |             } | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
|  |  | ||||||
|  | private: | ||||||
|  |     void cleanup_pending_task(int id_task) { | ||||||
|  |         // no need lock because this is called exclusively by post() | ||||||
|  |         auto rm_func = [id_task](const server_task & task) { | ||||||
|  |             return task.id_target == id_task; | ||||||
|  |         }; | ||||||
|  |         queue_tasks.erase( | ||||||
|  |             std::remove_if(queue_tasks.begin(),          queue_tasks.end(),          rm_func), | ||||||
|  |             queue_tasks.end()); | ||||||
|  |         queue_tasks_deferred.erase( | ||||||
|  |             std::remove_if(queue_tasks_deferred.begin(), queue_tasks_deferred.end(), rm_func), | ||||||
|  |             queue_tasks_deferred.end()); | ||||||
|  |     } | ||||||
| }; | }; | ||||||
|  |  | ||||||
| struct server_response { | struct server_response { | ||||||
| @@ -1579,6 +1601,12 @@ struct server_response { | |||||||
|  |  | ||||||
|         std::unique_lock<std::mutex> lock(mutex_results); |         std::unique_lock<std::mutex> lock(mutex_results); | ||||||
|         waiting_task_ids.erase(id_task); |         waiting_task_ids.erase(id_task); | ||||||
|  |         // make sure to clean up all pending results | ||||||
|  |         queue_results.erase( | ||||||
|  |             std::remove_if(queue_results.begin(), queue_results.end(), [id_task](const server_task_result_ptr & res) { | ||||||
|  |                 return res->id == id_task; | ||||||
|  |             }), | ||||||
|  |             queue_results.end()); | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     void remove_waiting_task_ids(const std::unordered_set<int> & id_tasks) { |     void remove_waiting_task_ids(const std::unordered_set<int> & id_tasks) { | ||||||
| @@ -1598,7 +1626,7 @@ struct server_response { | |||||||
|                 return !queue_results.empty(); |                 return !queue_results.empty(); | ||||||
|             }); |             }); | ||||||
|  |  | ||||||
|             for (int i = 0; i < (int) queue_results.size(); i++) { |             for (size_t i = 0; i < queue_results.size(); i++) { | ||||||
|                 if (id_tasks.find(queue_results[i]->id) != id_tasks.end()) { |                 if (id_tasks.find(queue_results[i]->id) != id_tasks.end()) { | ||||||
|                     server_task_result_ptr res = std::move(queue_results[i]); |                     server_task_result_ptr res = std::move(queue_results[i]); | ||||||
|                     queue_results.erase(queue_results.begin() + i); |                     queue_results.erase(queue_results.begin() + i); | ||||||
| @@ -1615,12 +1643,6 @@ struct server_response { | |||||||
|     server_task_result_ptr recv_with_timeout(const std::unordered_set<int> & id_tasks, int timeout) { |     server_task_result_ptr recv_with_timeout(const std::unordered_set<int> & id_tasks, int timeout) { | ||||||
|         while (true) { |         while (true) { | ||||||
|             std::unique_lock<std::mutex> lock(mutex_results); |             std::unique_lock<std::mutex> lock(mutex_results); | ||||||
|             bool cr_res = condition_results.wait_for(lock, std::chrono::seconds(timeout), [&]{ |  | ||||||
|                 return !queue_results.empty(); |  | ||||||
|             }); |  | ||||||
|             if (!cr_res) { |  | ||||||
|                 return nullptr; |  | ||||||
|             } |  | ||||||
|  |  | ||||||
|             for (int i = 0; i < (int) queue_results.size(); i++) { |             for (int i = 0; i < (int) queue_results.size(); i++) { | ||||||
|                 if (id_tasks.find(queue_results[i]->id) != id_tasks.end()) { |                 if (id_tasks.find(queue_results[i]->id) != id_tasks.end()) { | ||||||
| @@ -1629,6 +1651,11 @@ struct server_response { | |||||||
|                     return res; |                     return res; | ||||||
|                 } |                 } | ||||||
|             } |             } | ||||||
|  |  | ||||||
|  |             std::cv_status cr_res = condition_results.wait_for(lock, std::chrono::seconds(timeout)); | ||||||
|  |             if (cr_res == std::cv_status::timeout) { | ||||||
|  |                 return nullptr; | ||||||
|  |             } | ||||||
|         } |         } | ||||||
|  |  | ||||||
|         // should never reach here |         // should never reach here | ||||||
| @@ -2376,8 +2403,8 @@ struct server_context { | |||||||
|  |  | ||||||
|             server_task task(SERVER_TASK_TYPE_CANCEL); |             server_task task(SERVER_TASK_TYPE_CANCEL); | ||||||
|             task.id_target = id_task; |             task.id_target = id_task; | ||||||
|             cancel_tasks.push_back(task); |  | ||||||
|             queue_results.remove_waiting_task_id(id_task); |             queue_results.remove_waiting_task_id(id_task); | ||||||
|  |             cancel_tasks.push_back(task); | ||||||
|         } |         } | ||||||
|         // push to beginning of the queue, so it has highest priority |         // push to beginning of the queue, so it has highest priority | ||||||
|         queue_tasks.post(cancel_tasks, true); |         queue_tasks.post(cancel_tasks, true); | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user
	 Xuan Son Nguyen
					Xuan Son Nguyen