#
# Copyright BitBake Contributors
#
# SPDX-License-Identifier: GPL-2.0-only
#

import abc
import asyncio
import json
import os
import signal
import socket
import sys
import multiprocessing
from . import chunkify, DEFAULT_MAX_CHUNK


class ClientError(Exception):
    """Raised when a connected client sends something the server rejects."""
    pass


class ServerError(Exception):
    """Error type for server-side failures (not raised within this module)."""
    pass


class AsyncServerConnection(object):
    """
    Server side of a single client connection.

    The client first sends a "<proto_name> <major.minor...>" line, then a
    block of header lines terminated by an empty line, then newline-terminated
    JSON messages. Each message is dispatched to the handler registered in
    ``self.handlers`` under the first matching top-level key.

    ``process_requests`` calls ``self.validate_proto_version()``, which is not
    defined here; subclasses are expected to provide it.
    """

    def __init__(self, reader, writer, proto_name, logger):
        self.reader = reader
        self.writer = writer
        self.proto_name = proto_name
        self.max_chunk = DEFAULT_MAX_CHUNK
        # Top-level message key -> coroutine that receives msg[key]
        self.handlers = {
            'chunk-stream': self.handle_chunk,
            'ping': self.handle_ping,
        }
        self.logger = logger

    async def process_requests(self):
        """
        Run the connection: protocol handshake, headers, then message loop.

        ClientError is logged and swallowed; any other exception propagates.
        The writer is closed on every exit path.
        """
        try:
            self.addr = self.writer.get_extra_info('peername')
            self.logger.debug('Client %r connected' % (self.addr,))

            # Read protocol and version. StreamReader.readline() returns b''
            # (not None) at EOF, so test for an empty result.
            client_protocol = await self.reader.readline()
            if not client_protocol:
                return

            try:
                (client_proto_name, client_proto_version) = client_protocol.decode('utf-8').rstrip().split()
            except (UnicodeDecodeError, ValueError) as e:
                # Not exactly "<name> <version>", or not valid UTF-8
                raise ClientError('Invalid protocol line %r' % (client_protocol,)) from e

            if client_proto_name != self.proto_name:
                self.logger.debug('Rejecting invalid protocol %s' % (client_proto_name))
                return

            try:
                self.proto_version = tuple(int(v) for v in client_proto_version.split('.'))
            except ValueError:
                # A non-numeric version component is just an invalid version
                self.logger.debug('Rejecting invalid protocol version %s' % (client_proto_version))
                return

            if not self.validate_proto_version():
                self.logger.debug('Rejecting invalid protocol version %s' % (client_proto_version))
                return

            # Read headers. Currently, no headers are implemented, so look for
            # an empty line to signal the end of the headers
            while True:
                line = await self.reader.readline()
                if not line:
                    # EOF before the end of the headers
                    return

                line = line.decode('utf-8').rstrip()
                if not line:
                    break

            # Handle messages
            while True:
                d = await self.read_message()
                if d is None:
                    break
                await self.dispatch_message(d)
                await self.writer.drain()
        except ClientError as e:
            self.logger.error(str(e))
        finally:
            self.writer.close()

    async def dispatch_message(self, msg):
        """
        Await the first handler (in registration order) whose key is in msg,
        passing it msg[key].

        Raises ClientError if no registered key is present in msg.
        """
        for k, handler in self.handlers.items():
            if k in msg:
                self.logger.debug('Handling %s' % k)
                await handler(msg[k])
                return

        raise ClientError("Unrecognized command %r" % msg)

    def write_message(self, msg):
        """
        Serialize msg to JSON and write it, split via chunkify() into pieces
        bounded by self.max_chunk. The writer is not drained here.
        """
        for c in chunkify(json.dumps(msg), self.max_chunk):
            self.writer.write(c.encode('utf-8'))

    async def read_message(self):
        """
        Read one newline-terminated JSON message from the client.

        Returns the decoded object, or None at EOF or when the line read is
        not newline-terminated (a truncated final line). JSON/Unicode decode
        errors are logged and re-raised.
        """
        l = await self.reader.readline()
        if not l:
            return None

        try:
            message = l.decode('utf-8')

            if not message.endswith('\n'):
                return None

            return json.loads(message)
        except (json.JSONDecodeError, UnicodeDecodeError):
            # Log the raw bytes: if decode() failed, `message` was never bound
            self.logger.error('Bad message from client: %r' % l)
            raise

    async def handle_chunk(self, request):
        """
        Reassemble a message sent as a chunk stream and dispatch it.

        Lines are read and concatenated until an empty line (or EOF), and the
        result is parsed as JSON. A nested 'chunk-stream' is rejected with
        ClientError. The ``request`` value itself is not used.
        """
        lines = []
        try:
            while True:
                l = await self.reader.readline()
                l = l.rstrip(b"\n").decode("utf-8")
                if not l:
                    break
                lines.append(l)

            msg = json.loads(''.join(lines))
        except (json.JSONDecodeError, UnicodeDecodeError):
            self.logger.error('Bad message from client: %r' % lines)
            raise

        if 'chunk-stream' in msg:
            raise ClientError("Nested chunks are not allowed")

        await self.dispatch_message(msg)

    async def handle_ping(self, request):
        """Reply to a 'ping' message with {'alive': True}."""
        response = {'alive': True}
        self.write_message(response)


class AsyncServer(object):
    """
    Listening server that hands each accepted connection to accept_client().

    Call start_tcp_server() or start_unix_server() to choose the transport,
    then serve_forever() (current process) or serve_as_process() (child).
    """

    def __init__(self, logger):
        # Callable that removes the UNIX socket file, if one was created
        self._cleanup_socket = None
        self.logger = logger
        # Set by start_tcp_server()/start_unix_server() to a zero-argument
        # callable that creates the listening server on self.loop
        self.start = None
        # Human-readable listening address, filled in once the server starts
        self.address = None
        self.loop = None

    def start_tcp_server(self, host, port):
        """Configure the server to listen on TCP host:port when started."""
        def start_tcp():
            self.server = self.loop.run_until_complete(
                asyncio.start_server(self.handle_client, host, port)
            )

            for s in self.server.sockets:
                self.logger.debug('Listening on %r' % (s.getsockname(),))
                # Newer python does this automatically. Do it manually here for
                # maximum compatibility
                s.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1)
                # TCP_QUICKACK is not available on every platform
                if hasattr(socket, 'TCP_QUICKACK'):
                    s.setsockopt(socket.SOL_TCP, socket.TCP_QUICKACK, 1)

            name = self.server.sockets[0].getsockname()
            if self.server.sockets[0].family == socket.AF_INET6:
                self.address = "[%s]:%d" % (name[0], name[1])
            else:
                self.address = "%s:%d" % (name[0], name[1])

        self.start = start_tcp

    def start_unix_server(self, path):
        """
        Configure the server to listen on the UNIX socket at path when started.

        The socket file is unlinked when _serve_forever() finishes.
        """
        def cleanup():
            os.unlink(path)

        def start_unix():
            cwd = os.getcwd()
            try:
                # Work around path length limits in AF_UNIX by binding to the
                # basename from inside the socket's directory. A bare filename
                # has no directory part (dirname == ''), and os.chdir('')
                # would fail, so stay in the current directory in that case.
                sockdir = os.path.dirname(path)
                if sockdir:
                    os.chdir(sockdir)
                self.server = self.loop.run_until_complete(
                    asyncio.start_unix_server(self.handle_client, os.path.basename(path))
                )
            finally:
                os.chdir(cwd)

            self.logger.debug('Listening on %r' % path)

            self._cleanup_socket = cleanup
            self.address = "unix://%s" % os.path.abspath(path)

        self.start = start_unix

    # NOTE: AsyncServer does not use abc.ABCMeta, so this decorator is not
    # enforced at instantiation; subclasses must still override it. It is
    # called with (reader, writer) and must return an object whose
    # process_requests() coroutine serves the connection.
    @abc.abstractmethod
    def accept_client(self, reader, writer):
        pass

    async def handle_client(self, reader, writer):
        """
        Per-connection callback passed to asyncio.start_server /
        start_unix_server. Any exception is logged; the writer is closed
        afterwards.
        """
        # writer.transport.set_write_buffer_limits(0)
        try:
            client = self.accept_client(reader, writer)
            await client.process_requests()
        except Exception as e:
            import traceback
            self.logger.error('Error from client: %s' % str(e), exc_info=True)
            traceback.print_exc()
            writer.close()
        self.logger.debug('Client disconnected')

    def run_loop_forever(self):
        """Run self.loop until stopped; KeyboardInterrupt ends it quietly."""
        try:
            self.loop.run_forever()
        except KeyboardInterrupt:
            pass

    def signal_handler(self):
        """SIGTERM handler: stop the event loop."""
        self.logger.debug("Got exit signal")
        self.loop.stop()

    def _serve_forever(self):
        """
        Install the SIGTERM handler, run the loop until stopped, then close
        the server. The UNIX socket cleanup (if any) always runs.
        """
        try:
            self.loop.add_signal_handler(signal.SIGTERM, self.signal_handler)
            # SIGTERM may have been blocked by serve_as_process() until the
            # handler was installed
            signal.pthread_sigmask(signal.SIG_UNBLOCK, [signal.SIGTERM])

            self.run_loop_forever()
            self.server.close()

            self.loop.run_until_complete(self.server.wait_closed())
            self.logger.debug('Server shutting down')
        finally:
            if self._cleanup_socket is not None:
                self._cleanup_socket()

    def serve_forever(self):
        """
        Serve requests in the current process
        """
        # Create loop and override any loop that may have existed in
        # a parent process.  It is possible that the usecases of
        # serve_forever might be constrained enough to allow using
        # get_event_loop here, but better safe than sorry for now.
        self.loop = asyncio.new_event_loop()
        asyncio.set_event_loop(self.loop)
        self.start()
        self._serve_forever()

    def serve_as_process(self, *, prefunc=None, args=()):
        """
        Serve requests in a child process

        The child starts the server and sends self.address back to the
        parent, which stores it in self.address. If given, prefunc is called
        in the child as prefunc(self, *args) after the server starts and
        before serving. Returns the started multiprocessing.Process.
        """
        def run(queue):
            # Create loop and override any loop that may have existed
            # in a parent process.  Without doing this and instead
            # using get_event_loop, at the very minimum the hashserv
            # unit tests will hang when running the second test.
            # This happens since get_event_loop in the spawned server
            # process for the second testcase ends up with the loop
            # from the hashserv client created in the unit test process
            # when running the first testcase.  The problem is somewhat
            # more general, though, as any potential use of asyncio in
            # Cooker could create a loop that needs to replaced in this
            # new process.
            self.loop = asyncio.new_event_loop()
            asyncio.set_event_loop(self.loop)
            try:
                self.start()
            finally:
                # Always report back (even if start() raised) so the
                # parent's queue.get() does not wait forever
                queue.put(self.address)
                queue.close()

            if prefunc is not None:
                prefunc(self, *args)

            self._serve_forever()

            if sys.version_info >= (3, 6):
                self.loop.run_until_complete(self.loop.shutdown_asyncgens())
            self.loop.close()

        queue = multiprocessing.Queue()

        # Temporarily block SIGTERM. The server process will inherit this
        # block which will ensure it doesn't receive the SIGTERM until the
        # handler is ready for it
        mask = signal.pthread_sigmask(signal.SIG_BLOCK, [signal.SIGTERM])
        try:
            self.process = multiprocessing.Process(target=run, args=(queue,))
            self.process.start()

            self.address = queue.get()
            queue.close()
            queue.join_thread()

            return self.process
        finally:
            signal.pthread_sigmask(signal.SIG_SETMASK, mask)