A minimalist async socket server for multiprocessing in Python
The pattern I will show below is very useful in the following scenario:
You are processing a very large batch using multiple asyncio tasks.
You have a function or routine you will need to call multiple times (like once for each of your records). This function is self-contained and does not need to synchronize with other workers.
This routine is either CPU-intensive and you have many CPU cores to make use of, or it’s I/O-intensive but not async in nature (like reading/writing to the disk).
Let’s say you have a producer-consumer pattern
This is a very common pattern in asyncio programs:
import os
import time
import asyncio

async def main():
    start = time.time()
    num_workers = os.cpu_count()
    work_count = [0] * num_workers  # how many items each worker processed
    queue = asyncio.Queue(1024)
    producer = asyncio.create_task(my_producer(queue, num_workers))
    consumers = [asyncio.create_task(my_consumer(queue, work_count, i))
                 for i in range(num_workers)]
    await asyncio.gather(producer, *consumers)
    print(f'Finished in {time.time() - start:,.1f}s.')
    print('Items processed per worker:')
    print(', '.join(f'W{i}: {count}' for i, count in enumerate(work_count)))
async def my_producer(queue, num_workers):
    for i in range(2000):
        await queue.put(i)
    # Add terminators to make all workers exit:
    for i in range(num_workers):
        await queue.put(None)
The coroutine my_producer() will put items into the queue, and my_consumer() will get items from it and process them. If the consumers are very CPU-intensive, or don’t yield to the event loop, you may end up with a pattern that is close to serial processing. Let me show an example:
async def my_consumer(queue, work_count, worker_number):
    while True:
        row = await queue.get()
        try:
            if row is None:
                break
            time.sleep(0.01)  # simulates a CPU-intensive task that takes ~10ms
            work_count[worker_number] += 1
        finally:
            queue.task_done()
>>> asyncio.run(main())
Finished in 21.7s.
Items processed per worker:
W0: 2000, W1: 0, W2: 0, W3: 0, W4: 0, W5: 0, W6: 0, W7: 0, W8: 0
Regardless of how many CPU cores you have, this snippet will take just over 20 seconds to run (2000 * 0.01s). This is because the consumers never yield control to the event loop. Even calls to await queue.get() return immediately, because there are items waiting in the queue.
The same thing would occur if your consumers opened a file for writing using open() and file.write(). These functions are not asynchronous and don’t yield to the event loop. There is no benefit in using asyncio like this.
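To see this effect in isolation, here is a small self-contained sketch (the file names and the 0.1s delay are made up for illustration): three writer tasks that use plain open() and a blocking time.sleep() to mimic slow disk I/O end up running one after another, not concurrently.

```python
import time
import asyncio
import tempfile

async def write_chunk(path):
    # open() and write() never yield to the event loop, so while this
    # coroutine runs, no other task can make progress
    with open(path, 'w') as f:
        time.sleep(0.1)  # stand-in for a slow, blocking disk write
        f.write('some data\n')

async def demo():
    start = time.time()
    with tempfile.TemporaryDirectory() as tmp:
        await asyncio.gather(*(write_chunk(f'{tmp}/{i}.txt') for i in range(3)))
    return time.time() - start

elapsed = asyncio.run(demo())
print(f'{elapsed:.2f}s')  # ~0.30s: the three "concurrent" writers ran serially
```

If the writers actually yielded between items, the three tasks would overlap and the total would be close to 0.1s instead of 0.3s.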
Using multiprocessing
The cookbook suggestion is to use multiprocessing for cases like this. The Python documentation gives an example that can be adapted this way:
import os
import time
import asyncio
import concurrent.futures

async def main():
    start = time.time()
    worker_count = os.cpu_count()
    work_count = [0] * worker_count  # how many items each worker processed
    queue = asyncio.Queue(1024)
    # Create a process pool to execute the work (pass it to the consumers):
    with concurrent.futures.ProcessPoolExecutor() as pool:
        producer = asyncio.create_task(my_producer(queue, worker_count))
        consumers = [asyncio.create_task(my_consumer(queue, work_count, i, pool))
                     for i in range(worker_count)]
        await asyncio.gather(producer, *consumers)
    print(f'Finished in {time.time() - start:,.1f}s.')
    print('Items processed per worker:')
    print(', '.join(f'W{i}: {count}' for i, count in enumerate(work_count)))
async def my_producer(queue, num_workers):
    for i in range(2000):
        await queue.put(i)
    # Add terminators to make all workers exit:
    for i in range(num_workers):
        await queue.put(None)
async def my_consumer(queue, work_count, worker_number, pool):
    loop = asyncio.get_running_loop()
    while True:
        row = await queue.get()
        try:
            if row is None:
                break
            await loop.run_in_executor(pool, time.sleep, 0.01)
            work_count[worker_number] += 1
        finally:
            queue.task_done()
>>> asyncio.run(main())
Finished in 3.5s.
Items processed per worker:
W0: 250, W1: 250, W2: 251, W3: 250, W4: 250, W5: 250, W6: 250, W7: 249
Here, ProcessPoolExecutor spawns multiple worker processes to do the actual processing (of time.sleep(0.01), or your CPU-intensive routine). This approach, however, has some disadvantages:
Data sent to and from the subprocesses is serialized with pickle, which has improved a lot over time but is still not the fastest or most compact serialization method. With a few thousand records this is not a problem, but with several million it adds up.
Some objects cannot be pickled at all, which rules out this solution altogether.
I’ve had my share of frustration with gotchas from the different start methods (spawn, fork, forkserver), which I won’t detail here. As applications grow bigger and more complex, you can run into issues that you wouldn’t if you simply started a new, fresh subprocess.
Here is a more detailed article on why you wouldn’t want to use Pickle.
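As a concrete instance of the second point: functions are pickled by reference (module plus qualified name), so a lambda, a locally defined function, or an object holding an open connection cannot cross the process boundary at all. A quick sketch (the record-doubling lambda is a made-up stand-in for your processing function):

```python
import pickle

try:
    # A lambda has no importable name, so pickle cannot serialize it
    pickle.dumps(lambda record: record * 2)
    picklable = True
except (pickle.PicklingError, AttributeError) as exc:
    picklable = False
    print(f'{type(exc).__name__}: {exc}')
```

The same failure surfaces, less visibly, the moment a ProcessPoolExecutor tries to submit such an object to a worker.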
Using a minimalist socket server
Another approach is for every consumer to launch its own “mini-server” subprocess and communicate with it over a socket, using a faster serialization method, for example orjson.
Create a file called miniserver.py:
import sys
import time
import asyncio

import orjson

async def server_main(port, *options):
    # Perform start up with options
    try:
        server = await asyncio.start_server(handle_request, 'localhost', port)
        async with server:
            await server.serve_forever()
    finally:
        pass  # Perform tear down

async def handle_request(reader, writer):
    while True:
        # If an empty message is received, terminate the connection
        if not (raw := await reader.readline()):
            writer.close()
            break
        try:
            data = orjson.loads(raw)
            # Your CPU- or non-async I/O-intensive process
            sleep_time = data['sleep_time']
            time.sleep(sleep_time)
            result = {'slept': sleep_time}
        except (KeyboardInterrupt, asyncio.CancelledError):
            break
        except Exception as exc:
            result = {'error': exc.__class__.__name__, 'message': str(exc)}
        writer.write(orjson.dumps(result) + b'\n')
        await writer.drain()

if __name__ == '__main__':
    # Optional but recommended: if the parent dies, this process should die too
    try:
        import signal
        import prctl
        prctl.set_pdeathsig(signal.SIGTERM)
    except ImportError:
        pass
    try:
        port = int(sys.argv[1])
        options = sys.argv[2:]
    except (IndexError, TypeError, ValueError):
        sys.stderr.write('Usage: python miniserver.py PORT OPTIONS...\n')
        sys.exit(1)
    try:
        asyncio.run(server_main(port, *options))
    except KeyboardInterrupt:
        sys.exit(0)
Just for fun, you can launch this and test it with telnet. Write a line with a JSON object and see its output:
$ python miniserver.py 10999 &
$ telnet localhost 10999
Trying 127.0.0.1...
Connected to localhost.
Escape character is '^]'.
{"sleep_time": 1}
{"slept":1}
{"sleep_time":10}
{"slept":10}
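If you’d rather script that check than type into telnet, the same round trip is a few lines of asyncio. The sketch below is self-contained: it uses the standard library’s json instead of orjson (the bytes on the wire are plain JSON either way) and an inline stand-in for the miniserver.py handler, so nothing external needs to be running:

```python
import json
import asyncio

async def handle(reader, writer):
    # Same line protocol as miniserver.py: one JSON object per line
    while (raw := await reader.readline()):
        data = json.loads(raw)
        writer.write(json.dumps({'slept': data['sleep_time']}).encode() + b'\n')
        await writer.drain()
    writer.close()

async def demo():
    server = await asyncio.start_server(handle, 'localhost', 0)  # 0 = any free port
    port = server.sockets[0].getsockname()[1]
    async with server:
        reader, writer = await asyncio.open_connection('localhost', port)
        writer.write(json.dumps({'sleep_time': 1}).encode() + b'\n')
        await writer.drain()
        reply = json.loads(await reader.readline())
        writer.close()
        await writer.wait_closed()
        return reply

reply = asyncio.run(demo())
print(reply)  # {'slept': 1}
```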
Then you can make your consumer call it using a custom port (everything else stays the same):
import sys
import asyncio

import orjson

async def my_consumer(queue, work_count, worker_number):
    port = 11000 + worker_number
    subproc, reader, writer = await launch_subprocess(port)
    try:
        while True:
            row = await queue.get()
            try:
                if row is None:
                    break
                # Send the record to the socket...
                body = orjson.dumps({'sleep_time': 0.01})
                writer.write(body + b'\n')
                await writer.drain()
                # ...then wait for the response
                raw = await reader.readline()
                result = orjson.loads(raw)
                work_count[worker_number] += 1
            finally:
                queue.task_done()
    finally:
        writer.close()
        await writer.wait_closed()
        subproc.terminate()
        await subproc.wait()
async def launch_subprocess(port, *options):
    cmdargs = ['miniserver.py', str(port), *options]
    subproc = await asyncio.create_subprocess_exec(sys.executable, *cmdargs)
    # This may be tricky: wait for the port to be open
    # (especially if your startup is slow)
    for i in range(50):
        try:
            reader, writer = await asyncio.open_connection('localhost', port)
            break
        except ConnectionRefusedError:
            await asyncio.sleep(0.1)
    else:
        subproc.terminate()
        await subproc.wait()
        raise RuntimeError(f'Server did not start after 5s: {cmdargs}')
    return subproc, reader, writer
>>> asyncio.run(main())
Finished in 3.1s.
Items processed per worker:
W0: 91, W1: 305, W2: 296, W3: 290, W4: 277, W5: 260, W6: 249, W7: 232