# xref: /OK3568_Linux_fs/yocto/poky/bitbake/lib/bb/asyncrpc/serv.py (revision 4882a59341e53eb6f0b4789bf948001014eff981)
#
# Copyright BitBake Contributors
#
# SPDX-License-Identifier: GPL-2.0-only
#

import abc
import asyncio
import json
import os
import signal
import socket
import sys
import multiprocessing
from . import chunkify, DEFAULT_MAX_CHUNK


class ClientError(Exception):
    pass


class ServerError(Exception):
    pass


class AsyncServerConnection(object):
    def __init__(self, reader, writer, proto_name, logger):
        self.reader = reader
        self.writer = writer
        self.proto_name = proto_name
        self.max_chunk = DEFAULT_MAX_CHUNK
        # Base message handlers; protocol implementations extend this
        # mapping with their own commands in subclasses
        self.handlers = {
            'chunk-stream': self.handle_chunk,
            'ping': self.handle_ping,
        }
        self.logger = logger

    async def process_requests(self):
        try:
            self.addr = self.writer.get_extra_info('peername')
            self.logger.debug('Client %r connected' % (self.addr,))

            # Read protocol and version
            client_protocol = await self.reader.readline()
            # readline() returns b'' at EOF, never None
            if not client_protocol:
                return

            (client_proto_name, client_proto_version) = client_protocol.decode('utf-8').rstrip().split()
            if client_proto_name != self.proto_name:
                # Log the protocol name the client actually sent
                self.logger.debug('Rejecting invalid protocol %s' % client_proto_name)
                return

            self.proto_version = tuple(int(v) for v in client_proto_version.split('.'))
            if not self.validate_proto_version():
                self.logger.debug('Rejecting invalid protocol version %s' % client_proto_version)
                return

            # Read headers. Currently, no headers are implemented, so look for
            # an empty line to signal the end of the headers
            while True:
                line = await self.reader.readline()
                # An empty read means the client disconnected mid-headers
                if not line:
                    return

                line = line.decode('utf-8').rstrip()
                if not line:
                    break

            # Handle messages
            while True:
                d = await self.read_message()
                if d is None:
                    break
                await self.dispatch_message(d)
                await self.writer.drain()
        except ClientError as e:
            self.logger.error(str(e))
        finally:
            self.writer.close()

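    # Illustrative handshake (a sketch, not part of this module): before its
    # first message a client writes the protocol/version line, then a blank
    # line ending the (currently empty) headers. For a hypothetical protocol
    # named "example-proto" this would look like:
    #
    #   reader, writer = await asyncio.open_connection(host, port)
    #   writer.write('example-proto 1.0\n\n'.encode('utf-8'))
    #   await writer.drain()
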
    async def dispatch_message(self, msg):
        for k in self.handlers:
            if k in msg:
                self.logger.debug('Handling %s' % k)
                await self.handlers[k](msg[k])
                return

        raise ClientError("Unrecognized command %r" % msg)

    def write_message(self, msg):
        for c in chunkify(json.dumps(msg), self.max_chunk):
            self.writer.write(c.encode('utf-8'))

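    # Wire format sketch (assuming chunkify() from this package's __init__):
    # a message that fits in max_chunk goes out as a single JSON line; a
    # larger one is announced by a 'chunk-stream' header line, split across
    # multiple lines, and terminated by an empty line:
    #
    #   {"chunk-stream": null}
    #   <first slice of the JSON text>
    #   <next slice>
    #   <empty line>
    #
    # handle_chunk() below reassembles the slices and re-dispatches them.
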
    async def read_message(self):
        l = await self.reader.readline()
        if not l:
            return None

        try:
            message = l.decode('utf-8')

            if not message.endswith('\n'):
                return None

            return json.loads(message)
        except (json.JSONDecodeError, UnicodeDecodeError):
            # Log the raw bytes: 'message' is unbound if decode() itself fails
            self.logger.error('Bad message from client: %r' % l)
            raise

    async def handle_chunk(self, request):
        lines = []
        try:
            while True:
                l = await self.reader.readline()
                l = l.rstrip(b"\n").decode("utf-8")
                if not l:
                    break
                lines.append(l)

            msg = json.loads(''.join(lines))
        except (json.JSONDecodeError, UnicodeDecodeError):
            self.logger.error('Bad message from client: %r' % lines)
            raise

        if 'chunk-stream' in msg:
            raise ClientError("Nested chunks are not allowed")

        await self.dispatch_message(msg)

    async def handle_ping(self, request):
        response = {'alive': True}
        self.write_message(response)

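    # Example exchange for the built-in ping command, as handled above:
    #
    #   client -> {"ping": {}}
    #   server -> {"alive": true}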

class AsyncServer(object):
    def __init__(self, logger):
        self._cleanup_socket = None
        self.logger = logger
        self.start = None
        self.address = None
        self.loop = None

    def start_tcp_server(self, host, port):
        def start_tcp():
            self.server = self.loop.run_until_complete(
                asyncio.start_server(self.handle_client, host, port)
            )

            for s in self.server.sockets:
                self.logger.debug('Listening on %r' % (s.getsockname(),))
                # Newer Python does this automatically. Do it manually here
                # for maximum compatibility
                s.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1)
                # Note: TCP_QUICKACK is Linux-specific
                s.setsockopt(socket.SOL_TCP, socket.TCP_QUICKACK, 1)

            name = self.server.sockets[0].getsockname()
            if self.server.sockets[0].family == socket.AF_INET6:
                self.address = "[%s]:%d" % (name[0], name[1])
            else:
                self.address = "%s:%d" % (name[0], name[1])

        self.start = start_tcp

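    # Note: start_tcp_server() and start_unix_server() only record a closure
    # in self.start; the actual bind happens later, in whichever process ends
    # up serving (see serve_forever() and serve_as_process() below).
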
    def start_unix_server(self, path):
        def cleanup():
            os.unlink(path)

        def start_unix():
            cwd = os.getcwd()
            try:
                # Work around path length limits in AF_UNIX
                os.chdir(os.path.dirname(path))
                self.server = self.loop.run_until_complete(
                    asyncio.start_unix_server(self.handle_client, os.path.basename(path))
                )
            finally:
                os.chdir(cwd)

            self.logger.debug('Listening on %r' % path)

            self._cleanup_socket = cleanup
            self.address = "unix://%s" % os.path.abspath(path)

        self.start = start_unix

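    # The chdir workaround above exists because sun_path in sockaddr_un is
    # small (typically 108 bytes on Linux), so binding a long absolute path
    # can fail; binding the relative basename from the socket's directory
    # keeps the bound name short, while self.address still reports the
    # absolute path.
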
    @abc.abstractmethod
    def accept_client(self, reader, writer):
        pass

    async def handle_client(self, reader, writer):
        # writer.transport.set_write_buffer_limits(0)
        try:
            client = self.accept_client(reader, writer)
            await client.process_requests()
        except Exception as e:
            import traceback
            self.logger.error('Error from client: %s' % str(e), exc_info=True)
            traceback.print_exc()
            writer.close()
        self.logger.debug('Client disconnected')

    def run_loop_forever(self):
        try:
            self.loop.run_forever()
        except KeyboardInterrupt:
            pass

    def signal_handler(self):
        self.logger.debug("Got exit signal")
        self.loop.stop()

    def _serve_forever(self):
        try:
            self.loop.add_signal_handler(signal.SIGTERM, self.signal_handler)
            # SIGTERM was blocked before forking in serve_as_process(); now
            # that a handler is installed it is safe to receive it
            signal.pthread_sigmask(signal.SIG_UNBLOCK, [signal.SIGTERM])

            self.run_loop_forever()
            self.server.close()

            self.loop.run_until_complete(self.server.wait_closed())
            self.logger.debug('Server shutting down')
        finally:
            if self._cleanup_socket is not None:
                self._cleanup_socket()

    def serve_forever(self):
        """
        Serve requests in the current process
        """
        # Create loop and override any loop that may have existed in
        # a parent process.  It is possible that the use cases of
        # serve_forever might be constrained enough to allow using
        # get_event_loop here, but better safe than sorry for now.
        self.loop = asyncio.new_event_loop()
        asyncio.set_event_loop(self.loop)
        self.start()
        self._serve_forever()

    def serve_as_process(self, *, prefunc=None, args=()):
        """
        Serve requests in a child process
        """
        def run(queue):
            # Create loop and override any loop that may have existed
            # in a parent process.  Without doing this and instead
            # using get_event_loop, at the very minimum the hashserv
            # unit tests will hang when running the second test.
            # This happens since get_event_loop in the spawned server
            # process for the second testcase ends up with the loop
            # from the hashserv client created in the unit test process
            # when running the first testcase.  The problem is somewhat
            # more general, though, as any potential use of asyncio in
            # Cooker could create a loop that needs to be replaced in
            # this new process.
            self.loop = asyncio.new_event_loop()
            asyncio.set_event_loop(self.loop)
            try:
                self.start()
            finally:
                queue.put(self.address)
                queue.close()

            if prefunc is not None:
                prefunc(self, *args)

            self._serve_forever()

            if sys.version_info >= (3, 6):
                self.loop.run_until_complete(self.loop.shutdown_asyncgens())
            self.loop.close()

        queue = multiprocessing.Queue()

        # Temporarily block SIGTERM. The server process will inherit this
        # block, which ensures it doesn't receive the SIGTERM until the
        # handler is ready for it
        mask = signal.pthread_sigmask(signal.SIG_BLOCK, [signal.SIGTERM])
        try:
            self.process = multiprocessing.Process(target=run, args=(queue,))
            self.process.start()

            self.address = queue.get()
            queue.close()
            queue.join_thread()

            return self.process
        finally:
            signal.pthread_sigmask(signal.SIG_SETMASK, mask)
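
# Minimal usage sketch (illustrative only; the EchoConnection/EchoServer
# names below are hypothetical, not part of BitBake). A concrete server
# subclasses AsyncServer, returns an AsyncServerConnection subclass from
# accept_client(), and implements validate_proto_version():
#
#   class EchoConnection(AsyncServerConnection):
#       def validate_proto_version(self):
#           return self.proto_version == (1, 0)
#
#   class EchoServer(AsyncServer):
#       def accept_client(self, reader, writer):
#           return EchoConnection(reader, writer, 'echo-proto', self.logger)
#
#   server = EchoServer(logging.getLogger('echo'))
#   server.start_tcp_server('127.0.0.1', 0)
#   process = server.serve_as_process()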