#
# Copyright BitBake Contributors
#
# SPDX-License-Identifier: GPL-2.0-only
#

import abc
import asyncio
import json
import os
import signal
import socket
import sys
import multiprocessing
from . import chunkify, DEFAULT_MAX_CHUNK


class ClientError(Exception):
    pass


class ServerError(Exception):
    pass


class AsyncServerConnection(object):
    """
    Serve one client connection over an asyncio (reader, writer) stream pair.

    The exchange is line oriented: a "<proto_name> <version>" greeting line,
    header lines terminated by an empty line, then one JSON message per line.
    Each message is dispatched to the entry of ``self.handlers`` whose key is
    present in the message. Subclasses are expected to register additional
    handlers and to provide ``validate_proto_version()``, which this class
    calls but does not define.
    """

    def __init__(self, reader, writer, proto_name, logger):
        self.reader = reader
        self.writer = writer
        self.proto_name = proto_name
        self.max_chunk = DEFAULT_MAX_CHUNK
        self.handlers = {
            'chunk-stream': self.handle_chunk,
            'ping': self.handle_ping,
        }
        self.logger = logger

    async def process_requests(self):
        """
        Perform the protocol handshake, then dispatch messages until EOF.

        A ClientError raised anywhere in the exchange is logged and ends the
        connection; other exceptions propagate to the caller. The writer is
        closed in every case.
        """
        try:
            self.addr = self.writer.get_extra_info('peername')
            self.logger.debug('Client %r connected' % (self.addr,))

            # Read protocol and version. StreamReader.readline() returns b''
            # (never None) at EOF, so test for emptiness.
            client_protocol = await self.reader.readline()
            if not client_protocol:
                return

            fields = client_protocol.decode('utf-8').rstrip().split()
            if len(fields) != 2:
                raise ClientError('Invalid protocol line %r' % client_protocol)

            (client_proto_name, client_proto_version) = fields
            if client_proto_name != self.proto_name:
                self.logger.debug('Rejecting invalid protocol %s' % (client_proto_name))
                return

            try:
                self.proto_version = tuple(int(v) for v in client_proto_version.split('.'))
            except ValueError:
                raise ClientError('Invalid protocol version %r' % client_proto_version)

            if not self.validate_proto_version():
                self.logger.debug('Rejecting invalid protocol version %s' % (client_proto_version))
                return

            # Read headers. Currently, no headers are implemented, so look for
            # an empty line to signal the end of the headers
            while True:
                line = await self.reader.readline()
                if not line:
                    # EOF before the end of the headers
                    return

                line = line.decode('utf-8').rstrip()
                if not line:
                    break

            # Handle messages
            while True:
                d = await self.read_message()
                if d is None:
                    break
                await self.dispatch_message(d)
                await self.writer.drain()
        except ClientError as e:
            self.logger.error(str(e))
        finally:
            self.writer.close()

    async def dispatch_message(self, msg):
        """
        Await the handler for the first registered key found in *msg*.

        The handler receives ``msg[key]``. Raises ClientError when no
        registered key is present.
        """
        for k in self.handlers.keys():
            if k in msg:
                self.logger.debug('Handling %s' % k)
                await self.handlers[k](msg[k])
                return

        raise ClientError("Unrecognized command %r" % msg)

    def write_message(self, msg):
        """Serialize *msg* as JSON and write it in the pieces produced by chunkify()."""
        for c in chunkify(json.dumps(msg), self.max_chunk):
            self.writer.write(c.encode('utf-8'))

    async def read_message(self):
        """
        Read and decode one newline-terminated JSON message.

        Returns None at EOF, or when the last line read is not terminated by a
        newline. JSON and UTF-8 decoding errors are logged and re-raised.
        """
        l = await self.reader.readline()
        if not l:
            return None

        try:
            message = l.decode('utf-8')

            if not message.endswith('\n'):
                return None

            return json.loads(message)
        except (json.JSONDecodeError, UnicodeDecodeError) as e:
            # Log the raw bytes: 'message' is unbound when decode() itself failed
            self.logger.error('Bad message from client: %r' % l)
            raise e

    async def handle_chunk(self, request):
        """
        Reassemble a chunked message and dispatch it.

        Lines are read and concatenated until an empty line (or EOF), then
        parsed as one JSON message. Nested 'chunk-stream' messages are
        rejected with ClientError.
        """
        lines = []
        try:
            while True:
                l = await self.reader.readline()
                l = l.rstrip(b"\n").decode("utf-8")
                if not l:
                    break
                lines.append(l)

            msg = json.loads(''.join(lines))
        except (json.JSONDecodeError, UnicodeDecodeError) as e:
            self.logger.error('Bad message from client: %r' % lines)
            raise e

        if 'chunk-stream' in msg:
            raise ClientError("Nested chunks are not allowed")

        await self.dispatch_message(msg)

    async def handle_ping(self, request):
        """Reply to a ping with {'alive': True}."""
        response = {'alive': True}
        self.write_message(response)


class AsyncServer(object):
    """
    Base class for an asyncio server listening on TCP or a UNIX socket.

    Select a transport with start_tcp_server() or start_unix_server(), then
    run it with serve_forever() (current process) or serve_as_process()
    (child process). Subclasses implement accept_client() to build the
    per-connection object.
    """

    def __init__(self, logger):
        self._cleanup_socket = None
        self.logger = logger
        # Callable set by start_tcp_server()/start_unix_server(); it creates
        # the listening server once an event loop exists.
        self.start = None
        self.address = None
        self.loop = None

    def start_tcp_server(self, host, port):
        """Configure the server to listen on TCP *host*:*port* when started."""
        def start_tcp():
            self.server = self.loop.run_until_complete(
                asyncio.start_server(self.handle_client, host, port)
            )

            for s in self.server.sockets:
                self.logger.debug('Listening on %r' % (s.getsockname(),))
                # Newer python does this automatically. Do it manually here for
                # maximum compatibility
                s.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1)
                # TCP_QUICKACK is Linux-specific; skip it where unavailable
                if hasattr(socket, 'TCP_QUICKACK'):
                    s.setsockopt(socket.SOL_TCP, socket.TCP_QUICKACK, 1)

            name = self.server.sockets[0].getsockname()
            if self.server.sockets[0].family == socket.AF_INET6:
                self.address = "[%s]:%d" % (name[0], name[1])
            else:
                self.address = "%s:%d" % (name[0], name[1])

        self.start = start_tcp

    def start_unix_server(self, path):
        """
        Configure the server to listen on the UNIX socket *path* when started.

        The socket file is unlinked when the server shuts down.
        """
        def cleanup():
            os.unlink(path)

        def start_unix():
            cwd = os.getcwd()
            try:
                # Work around path length limits in AF_UNIX by binding to the
                # basename from inside the socket's directory. A bare filename
                # has an empty dirname (os.chdir('') would fail), so stay put.
                dirname = os.path.dirname(path)
                if dirname:
                    os.chdir(dirname)
                self.server = self.loop.run_until_complete(
                    asyncio.start_unix_server(self.handle_client, os.path.basename(path))
                )
            finally:
                os.chdir(cwd)

            self.logger.debug('Listening on %r' % path)

            self._cleanup_socket = cleanup
            # Resolved after restoring cwd, so relative paths stay correct
            self.address = "unix://%s" % os.path.abspath(path)

        self.start = start_unix

    @abc.abstractmethod
    def accept_client(self, reader, writer):
        """
        Return the connection object that will serve (reader, writer).

        The returned object must provide an awaitable process_requests().
        NOTE(review): this class does not derive from abc.ABC, so the
        abstractmethod decorator is not enforced at instantiation time.
        """
        pass

    async def handle_client(self, reader, writer):
        """Serve one connection; log any exception and close the writer."""
        # writer.transport.set_write_buffer_limits(0)
        try:
            client = self.accept_client(reader, writer)
            await client.process_requests()
        except Exception as e:
            import traceback
            self.logger.error('Error from client: %s' % str(e), exc_info=True)
            traceback.print_exc()
            writer.close()
        self.logger.debug('Client disconnected')

    def run_loop_forever(self):
        """Run the event loop until stopped; Ctrl-C ends it quietly."""
        try:
            self.loop.run_forever()
        except KeyboardInterrupt:
            pass

    def signal_handler(self):
        """SIGTERM handler: stop the event loop so _serve_forever() can shut down."""
        self.logger.debug("Got exit signal")
        self.loop.stop()

    def _serve_forever(self):
        """
        Run until SIGTERM (or Ctrl-C), then close the server.

        The UNIX socket, if any, is removed even if shutdown fails.
        """
        try:
            self.loop.add_signal_handler(signal.SIGTERM, self.signal_handler)
            # SIGTERM may have been blocked by serve_as_process() so that it
            # could not arrive before the handler above was installed.
            signal.pthread_sigmask(signal.SIG_UNBLOCK, [signal.SIGTERM])

            self.run_loop_forever()
            self.server.close()

            self.loop.run_until_complete(self.server.wait_closed())
            self.logger.debug('Server shutting down')
        finally:
            if self._cleanup_socket is not None:
                self._cleanup_socket()

    def serve_forever(self):
        """
        Serve requests in the current process
        """
        # Create loop and override any loop that may have existed in
        # a parent process. It is possible that the usecases of
        # serve_forever might be constrained enough to allow using
        # get_event_loop here, but better safe than sorry for now.
        self.loop = asyncio.new_event_loop()
        asyncio.set_event_loop(self.loop)
        self.start()
        self._serve_forever()

    def serve_as_process(self, *, prefunc=None, args=()):
        """
        Serve requests in a child process

        Blocks until the child has started listening (or failed to), sets
        self.address from the child, and returns the multiprocessing.Process.
        *prefunc*, if given, is called in the child as prefunc(self, *args)
        before serving.
        """
        def run(queue):
            # Create loop and override any loop that may have existed
            # in a parent process. Without doing this and instead
            # using get_event_loop, at the very minimum the hashserv
            # unit tests will hang when running the second test.
            # This happens since get_event_loop in the spawned server
            # process for the second testcase ends up with the loop
            # from the hashserv client created in the unit test process
            # when running the first testcase. The problem is somewhat
            # more general, though, as any potential use of asyncio in
            # Cooker could create a loop that needs to be replaced in this
            # new process.
            self.loop = asyncio.new_event_loop()
            asyncio.set_event_loop(self.loop)
            try:
                self.start()
            finally:
                # Always report back (possibly None) so the parent's
                # queue.get() does not block forever if start() fails.
                queue.put(self.address)
                queue.close()

            if prefunc is not None:
                prefunc(self, *args)

            self._serve_forever()

            if sys.version_info >= (3, 6):
                self.loop.run_until_complete(self.loop.shutdown_asyncgens())
            self.loop.close()

        queue = multiprocessing.Queue()

        # Temporarily block SIGTERM. The server process will inherit this
        # block which will ensure it doesn't receive the SIGTERM until the
        # handler is ready for it
        mask = signal.pthread_sigmask(signal.SIG_BLOCK, [signal.SIGTERM])
        try:
            self.process = multiprocessing.Process(target=run, args=(queue,))
            self.process.start()

            self.address = queue.get()
            queue.close()
            queue.join_thread()

            return self.process
        finally:
            signal.pthread_sigmask(signal.SIG_SETMASK, mask)