12 Jul, 2024
It is kind of a pain to get Python's `multiprocessing` module to shut down correctly on `kill -2 PID` (SIGINT) and `kill PID` (SIGTERM).
Here's how you do it using normal signals, but there is a better way I'll show next:
import multiprocessing
import os
import asyncio
import time
import signal
import sys
def worker():
    """Child-process entry point: print a heartbeat every second until signalled.

    SIGTERM and SIGINT are both routed to one handler, which reports the
    signal that arrived and exits the process with ``sys.exit(0)``.
    """

    def sigterm_handler(signum, frame):
        # The same handler serves SIGTERM and SIGINT, so report the signal
        # that actually arrived instead of always claiming SIGTERM.
        print(f"Worker process {os.getpid()} received {signal.Signals(signum).name}")
        sys.exit(0)  # Raises SystemExit in the main thread, ending the loop below

    signal.signal(signal.SIGTERM, sigterm_handler)
    signal.signal(signal.SIGINT, sigterm_handler)
    print(f"Worker process started with PID: {os.getpid()}")
    while True:
        print("Working...")
        time.sleep(1)
async def main():
    """Start one worker process and stop it cleanly on SIGTERM/SIGINT.

    The handlers are installed with ``signal.signal`` rather than through the
    event loop. They terminate and join the worker, then exit via
    ``sys.exit(0)``. If no signal arrives before the sleep ends, the worker
    is still stopped and joined on the way out.
    """
    print(f"Started {os.getpid()}")
    # Create and start the worker process
    process = multiprocessing.Process(target=worker)
    process.start()
    print(f"Started worker, now we are {os.getpid()}")

    def sigterm_handler(signum, frame):
        if process.is_alive():
            process.terminate()  # Sends SIGTERM to the worker
            process.join()
            print("Process was alive")
        else:
            print("Process is not alive")
        print("Event loop has been shut down")
        sys.exit(0)

    signal.signal(signal.SIGTERM, sigterm_handler)
    signal.signal(signal.SIGINT, sigterm_handler)
    try:
        await asyncio.sleep(100)
    finally:
        # Without this, finishing the sleep leaves the non-daemon worker
        # (an endless loop) running, and interpreter exit would block
        # forever joining it.
        if process.is_alive():
            process.terminate()
        process.join()


if __name__ == '__main__':
    asyncio.run(main())
Now you probably want to gracefully shut down all the tasks too. The `asyncio` module can handle signals as part of the event loop, via `loop.add_signal_handler()`.
Here's the most sophisticated example I've come up with (ignoring multiprocessing for now):
import asyncio
import signal
async def looping_task(task_num):
    """Print a heartbeat every five seconds until cancelled.

    Cancellation is absorbed rather than propagated. The task reports it and
    returns the same message, so shutdown code can collect it as a result.
    """
    message = f"{task_num}: I was cancelled!"
    try:
        while True:
            print(f'{task_num}:in looping_task')
            await asyncio.sleep(5.0)
    except asyncio.CancelledError:
        print(message)
        return message
async def handle_client(reader, writer):
    """Echo everything a client sends back to it until the client disconnects.

    If the handler is cancelled (e.g. during a forced shutdown), it logs the
    fact and re-raises, so its task is correctly reported as cancelled. The
    writer is always closed on the way out.
    """
    try:
        while True:
            data = await reader.read(100)
            if not data:  # b"" means the peer closed its end
                break
            print(f"Received {data!r} from {writer.get_extra_info('peername')}")
            writer.write(data)
            await writer.drain()
    except asyncio.CancelledError:
        print("Connection handler was cancelled")
        # asyncio expects CancelledError to propagate; swallowing it makes
        # the task look as if it had finished normally.
        raise
    finally:
        writer.close()
        await writer.wait_closed()
async def shutdown(loop, main_task, stop_event, server, server_tasks):
    """Gracefully stop the server and every background task, then release main.

    Steps:
    1. Stop accepting new connections (``server.close()``).
    2. Cancel every task except ``main_task``, this task, and the client
       handler tasks in ``server_tasks``.
    3. At the same time, give open client connections one second to finish;
       if they don't, cancel their handler tasks.
    4. Set ``stop_event`` so ``main_task`` can return. This happens even if
       this coroutine is itself cancelled or fails.

    Args:
        loop: The running event loop whose tasks are cancelled.
        main_task: The task waiting on ``stop_event``; it is left running.
        stop_event: Event set when shutdown finishes.
        server: The server to close.
        server_tasks: Set of per-connection handler tasks.
    """
    print("Received exit signal, shutting down...")
    server.close()

    async def server_shutdown():
        # Wait for the server to close with a timeout of 1 second
        try:
            await asyncio.wait_for(server.wait_closed(), timeout=1.0)
            return 'Client connections closed'
        except asyncio.TimeoutError:
            print("Server did not close within 1 second, forcing shutdown...")
            for task in server_tasks:
                task.cancel()
            await asyncio.gather(*server_tasks, return_exceptions=True)
            return 'Client connections terminated'

    try:
        # Cancel all tasks except the main task, the current task, and server tasks
        current = asyncio.current_task()
        tasks = [
            t
            for t in asyncio.all_tasks(loop)
            if t is not main_task and t is not current and t not in server_tasks
        ]
        for task in tasks:
            task.cancel()
        # Wait for all tasks to finish; return_exceptions keeps one failure
        # from hiding the other results.
        results = await asyncio.gather(*tasks, server_shutdown(), return_exceptions=True)
        print(f"Finished awaiting cancelled tasks, results: {results}")
    finally:
        # Always release main_task, even if this shutdown is interrupted.
        stop_event.set()
def add_signal_handlers(loop, main_task, stop_event, server, server_tasks):
    """Install SIGINT/SIGTERM handlers on ``loop`` that start ``shutdown()`` once.

    The first signal schedules a single ``shutdown()`` task; any later
    signals are ignored while it runs.
    """
    # Strong reference to the shutdown task: the event loop keeps only weak
    # references to tasks, so an unreferenced task may be garbage collected.
    shutdown_task = None

    def on_signal():
        nonlocal shutdown_task
        # Ignore repeat signals (e.g. Ctrl-C pressed twice). Without this, a
        # second shutdown() task would cancel the first one, because it is
        # neither main_task nor a server task.
        if shutdown_task is not None:
            return
        shutdown_task = loop.create_task(
            shutdown(loop, main_task, stop_event, server, server_tasks)
        )

    for sig in [signal.SIGINT, signal.SIGTERM]:
        loop.add_signal_handler(sig, on_signal)
async def main():
    """Run five background loops and an echo server until SIGINT/SIGTERM.

    The signal handlers start ``shutdown()``. It cancels the background
    loops, winds down client connections and sets ``stop_event`` so this
    coroutine can return.
    """
    loop = asyncio.get_running_loop()
    stop_event = asyncio.Event()
    server_tasks = set()
    # The event loop keeps only weak references to tasks, so hold strong
    # ones until each task finishes (pattern from the create_task docs).
    background_tasks = set()
    for i in range(5):
        task = asyncio.create_task(looping_task(i))
        background_tasks.add(task)
        task.add_done_callback(background_tasks.discard)

    async def handle_client_wrapper(reader, writer):
        # Register this connection's task so shutdown() can find and cancel it.
        task = asyncio.current_task()
        server_tasks.add(task)
        try:
            await handle_client(reader, writer)
        finally:
            server_tasks.remove(task)

    server = await asyncio.start_server(handle_client_wrapper, '127.0.0.1', 8888)
    print("Serving on", server.sockets[0].getsockname())
    add_signal_handlers(loop, asyncio.current_task(), stop_event, server, server_tasks)
    try:
        await stop_event.wait()  # Wait for the stop_event to be set
    except asyncio.CancelledError:
        print("Main task cancelled")
    print("Cleaning up before exiting...")
    print('Shutdown complete')


if __name__ == "__main__":
    asyncio.run(main())
    print("Event loop has closed")
Here's what you get without any active connections (there is no one-second wait):
% python3 event12.py
0:in looping_task
1:in looping_task
2:in looping_task
3:in looping_task
4:in looping_task
Serving on ('127.0.0.1', 8888)
^CReceived exit signal, shutting down...
1: I was cancelled!
2: I was cancelled!
3: I was cancelled!
0: I was cancelled!
4: I was cancelled!
Finished awaiting cancelled tasks, results: ['1: I was cancelled!', '2: I was cancelled!', '3: I was cancelled!', '0: I was cancelled!', '4: I was cancelled!', 'Client connections closed']
Cleaning up before exiting...
Shutdown complete
Event loop has closed
And here's what you get with a connection (made with `nc 127.0.0.1 8888` in another terminal):
0:in looping_task
1:in looping_task
2:in looping_task
3:in looping_task
4:in looping_task
Serving on ('127.0.0.1', 8888)
Received b'asd\n' from ('127.0.0.1', 50531)
^CReceived exit signal, shutting down...
1: I was cancelled!
2: I was cancelled!
3: I was cancelled!
0: I was cancelled!
4: I was cancelled!
Server did not close within 1 second, forcing shutdown...
Connection handler was cancelled
Finished awaiting cancelled tasks, results: ['1: I was cancelled!', '2: I was cancelled!', '3: I was cancelled!', '0: I was cancelled!', '4: I was cancelled!', 'Client connections terminated']
Cleaning up before exiting...
Shutdown complete
Event loop has closed
When putting multiprocessing back in I often want the initial process to also run the same code as the worker, which slightly complicates things.
Here's an example:
import asyncio
import signal
import socket
from multiprocessing import Process, get_context, current_process
async def looping_task(task_num):
    """Print a process-tagged heartbeat every ten seconds until cancelled.

    Each line is prefixed with the current process name so output from
    different workers can be told apart. On cancellation the task reports it
    and returns the same message instead of re-raising.
    """
    try:
        while True:
            tag = current_process().name
            print(f"[{tag}] {task_num}: in looping_task")
            await asyncio.sleep(10)
    except asyncio.CancelledError:
        note = f"[{current_process().name}] {task_num}: I was cancelled!"
        print(note)
        return note
async def handle_client(reader, writer):
    """Echo client data back, tagging log lines with the process name.

    Runs until the client disconnects. On cancellation the handler logs the
    fact and re-raises so its task ends up cancelled. The writer is always
    closed.
    """
    try:
        while True:
            data = await reader.read(100)
            if not data:  # b"" means the peer closed its end
                break
            print(
                f"[{current_process().name}] Received {data!r} from {writer.get_extra_info('peername')}"
            )
            writer.write(data)
            await writer.drain()
    except asyncio.CancelledError:
        print(f"[{current_process().name}] Connection handler was cancelled")
        # asyncio expects CancelledError to propagate; swallowing it makes
        # the task look as if it had finished normally.
        raise
    finally:
        writer.close()
        await writer.wait_closed()
async def shutdown(loop, main_task, stop_event, server, server_tasks):
    """Stop the server and cancel the background tasks of this process.

    Closes ``server`` and cancels every task other than ``main_task``, the
    calling task and the client handlers in ``server_tasks``. At the same
    time, open connections get one second to finish before their handlers are
    cancelled. ``stop_event`` is accepted but not used by this function.
    """
    name = current_process().name
    print(f"[{name}] Shutting down...")
    server.close()

    async def drain_connections():
        try:
            await asyncio.wait_for(server.wait_closed(), timeout=1.0)
        except asyncio.TimeoutError:
            print(f"[{name}] Server did not close within 1 second, forcing shutdown...")
            for handler in server_tasks:
                handler.cancel()
            await asyncio.gather(*server_tasks, return_exceptions=True)
            return f"[{name}] Client connections terminated"
        return f"[{name}] Client connections closed"

    me = asyncio.current_task()
    doomed = [
        task
        for task in asyncio.all_tasks(loop)
        if not (task is main_task or task is me or task in server_tasks)
    ]
    for task in doomed:
        task.cancel()
    outcome = await asyncio.gather(*doomed, drain_connections(), return_exceptions=True)
    print(f"[{name}] Finished awaiting cancelled tasks, results: {outcome}")
def add_signal_handlers(loop, main_task, stop_event, server, server_tasks):
    """Install SIGINT/SIGTERM handlers on ``loop`` that run a single shutdown.

    The first signal starts a task that runs ``shutdown()`` and then sets
    ``stop_event``; any further signals are ignored.
    """
    # Holds the one shutdown task. The event loop keeps only weak references
    # to tasks, so this strong reference keeps it alive until it finishes.
    shutdown_task = None

    async def main_signal_handler():
        print(f"[{current_process().name}] Received signal.")
        try:
            await shutdown(loop, main_task, stop_event, server, server_tasks)
        finally:
            # Release the waiter even if shutdown() raised or was cancelled.
            stop_event.set()

    def on_signal():
        nonlocal shutdown_task
        # Decided synchronously in the signal callback, so a second signal
        # can never start a second shutdown.
        if shutdown_task is None:
            shutdown_task = loop.create_task(main_signal_handler())

    for sig in [signal.SIGINT, signal.SIGTERM]:
        loop.add_signal_handler(sig, on_signal)
async def run_worker(stop_event, sock):
    """Serve on the already-bound listening ``sock`` until shut down.

    Starts five background loops and an echo server on ``sock``, installs
    the SIGINT/SIGTERM handlers and waits for ``stop_event``.
    """
    loop = asyncio.get_running_loop()
    server_tasks = set()
    # The event loop keeps only weak references to tasks, so hold strong
    # ones until each task finishes (pattern from the create_task docs).
    background_tasks = set()
    for i in range(5):
        task = asyncio.create_task(looping_task(i))
        background_tasks.add(task)
        task.add_done_callback(background_tasks.discard)

    async def handle_client_wrapper(reader, writer):
        # Track this connection's task so shutdown() can cancel it.
        task = asyncio.current_task()
        server_tasks.add(task)
        try:
            await handle_client(reader, writer)
        finally:
            server_tasks.remove(task)

    server = await asyncio.start_server(handle_client_wrapper, sock=sock)
    print(f"[{current_process().name}] Serving on {server.sockets[0].getsockname()}")
    add_signal_handlers(loop, asyncio.current_task(), stop_event, server, server_tasks)
    try:
        await stop_event.wait()
    except asyncio.CancelledError:
        print(f"[{current_process().name}] Main task cancelled")
    print(f"[{current_process().name}] Cleaning up before exiting...")
    print(f"[{current_process().name}] Shutdown complete")
def worker(stop_event, sock):
    """Process entry point: drive ``run_worker`` in a fresh event loop."""
    coro = run_worker(stop_event, sock)
    asyncio.run(coro)
async def serve_multi(num_workers=1, host="127.0.0.1", port=8888):
    """Serve one listening socket from the main process and ``num_workers`` children.

    The socket is bound here and passed to each spawned worker process. The
    main process then runs a worker of its own via ``run_worker``. Once that
    in-process worker returns, the child workers are sent SIGTERM
    (``Process.terminate``) and joined.

    Args:
        num_workers: Number of extra worker processes (default 1, as before).
        host: Address to bind the shared socket to.
        port: Port to bind the shared socket to.
    """
    # Create a socket in the main process
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
        sock.bind((host, port))
        sock.listen(100)
        sock.set_inheritable(True)

        # Start worker processes
        ctx = get_context("spawn")
        workers = []
        for i in range(num_workers):
            stop_event = asyncio.Event()
            w = ctx.Process(target=worker, args=(stop_event, sock), name=f"Worker-{i}")
            w.start()
            workers.append(w)

        def terminate_workers():
            # Process.terminate() sends SIGTERM; each worker's own signal
            # handlers (installed by run_worker) then shut it down.
            for w in workers:
                w.terminate()
            print(f"[{current_process().name}] All workers terminated.")

        loop = asyncio.get_running_loop()
        # NOTE: run_worker() below installs its own SIGINT/SIGTERM handlers on
        # this loop, replacing these, so they only cover the startup window.
        for sig in [signal.SIGINT, signal.SIGTERM]:
            loop.add_signal_handler(sig, terminate_workers)

        # Run one worker in the main process; returns after it has shut down.
        await run_worker(asyncio.Event(), sock)

        # A signal sent to this process alone (e.g. `kill PID`) never reaches
        # the children, and the handlers above have been replaced, so forward
        # the shutdown explicitly; otherwise join() could block forever.
        terminate_workers()
        for w in workers:
            w.join()
        print(f"[{current_process().name}] All workers joined.")
    finally:
        # Release the listening socket on every path (e.g. if bind() fails).
        sock.close()
async def main():
    """Entry coroutine: serve until shutdown, then announce completion."""
    await serve_multi()
    who = current_process().name
    print(f"[{who}] Event loop has closed")


if __name__ == "__main__":
    asyncio.run(main())
I think in reality, you'd be very unlikely to want to start cancelling all running tasks like this. Instead you probably just want to track the long running ones in a variable (a bit like the server client connection tasks) and just cancel those ones.
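There is a subtle bug in the above code: `terminate_workers()` is never called, because `run_worker()` later installs its own SIGINT/SIGTERM handlers, which replace the ones registered in `serve_multi()`. With Ctrl-C this mostly doesn't matter: the terminal sends SIGINT to every process in the foreground process group, so the workers shut themselves down anyway. You just won't see the message printed. If you send SIGTERM to the parent process alone with `kill PID`, however, the workers never receive a signal and the final `join()` blocks. In a future version I'll fix this problem.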
Be the first to comment.
Copyright James Gardner 1996-2020 All Rights Reserved. Admin.