xref: /OK3568_Linux_fs/yocto/poky/bitbake/lib/bb/asyncrpc/serv.py (revision 4882a59341e53eb6f0b4789bf948001014eff981)
1*4882a593Smuzhiyun#
2*4882a593Smuzhiyun# Copyright BitBake Contributors
3*4882a593Smuzhiyun#
4*4882a593Smuzhiyun# SPDX-License-Identifier: GPL-2.0-only
5*4882a593Smuzhiyun#
6*4882a593Smuzhiyun
7*4882a593Smuzhiyunimport abc
8*4882a593Smuzhiyunimport asyncio
9*4882a593Smuzhiyunimport json
10*4882a593Smuzhiyunimport os
11*4882a593Smuzhiyunimport signal
12*4882a593Smuzhiyunimport socket
13*4882a593Smuzhiyunimport sys
14*4882a593Smuzhiyunimport multiprocessing
15*4882a593Smuzhiyunfrom . import chunkify, DEFAULT_MAX_CHUNK
16*4882a593Smuzhiyun
17*4882a593Smuzhiyun
class ClientError(Exception):
    """Raised when a connected client sends a request the server rejects."""
20*4882a593Smuzhiyun
21*4882a593Smuzhiyun
class ServerError(Exception):
    """Server-side asyncrpc error.

    Nothing in this module raises it; presumably it is raised by users of
    the package.
    """
24*4882a593Smuzhiyun
25*4882a593Smuzhiyun
class AsyncServerConnection(object):
    """
    Server side of one asyncrpc client connection.

    The client first sends a "<protocol-name> <version>" line followed by
    header lines terminated by an empty line; after that it sends
    newline-terminated JSON messages. Each message is passed to the first
    entry of ``self.handlers`` whose key appears in it.

    ``validate_proto_version()`` is called during the handshake but is not
    defined here; presumably subclasses provide it (and may add handlers).
    """

    def __init__(self, reader, writer, proto_name, logger):
        # Stream reader/writer for this connection
        self.reader = reader
        self.writer = writer
        # Protocol name the client must announce in its handshake line
        self.proto_name = proto_name
        # Passed to chunkify() by write_message(); presumably the maximum
        # size of each chunk written to the client
        self.max_chunk = DEFAULT_MAX_CHUNK
        # Message key -> coroutine handler, called with msg[key]
        self.handlers = {
            'chunk-stream': self.handle_chunk,
            'ping': self.handle_ping,
        }
        self.logger = logger

    async def process_requests(self):
        """
        Perform the handshake, then serve messages until the client hangs up.

        Invalid or malformed handshakes are logged at debug level and the
        connection is dropped. A ClientError raised while handling messages
        is logged and ends the connection; other exceptions propagate. The
        writer is always closed on return.
        """
        try:
            self.addr = self.writer.get_extra_info('peername')
            self.logger.debug('Client %r connected' % (self.addr,))

            # Read protocol and version. StreamReader.readline() returns
            # b'' (never None) at EOF, so test for an empty result.
            client_protocol = await self.reader.readline()
            if not client_protocol:
                return

            try:
                # UnicodeDecodeError is a ValueError subclass; a wrong
                # number of fields also raises ValueError on unpacking
                (client_proto_name, client_proto_version) = client_protocol.decode('utf-8').rstrip().split()
            except ValueError:
                self.logger.debug('Rejecting malformed protocol line %r' % (client_protocol,))
                return

            if client_proto_name != self.proto_name:
                self.logger.debug('Rejecting invalid protocol %s' % (client_proto_name,))
                return

            try:
                self.proto_version = tuple(int(v) for v in client_proto_version.split('.'))
            except ValueError:
                self.logger.debug('Rejecting invalid protocol version %s' % (client_proto_version,))
                return

            if not self.validate_proto_version():
                self.logger.debug('Rejecting invalid protocol version %s' % (client_proto_version,))
                return

            # Read headers. Currently, no headers are implemented, so look for
            # an empty line to signal the end of the headers
            while True:
                line = await self.reader.readline()
                if not line:
                    # EOF before the end of the headers
                    return

                line = line.decode('utf-8').rstrip()
                if not line:
                    break

            # Handle messages
            while True:
                d = await self.read_message()
                if d is None:
                    break
                await self.dispatch_message(d)
                await self.writer.drain()
        except ClientError as e:
            self.logger.error(str(e))
        finally:
            self.writer.close()

    async def dispatch_message(self, msg):
        """
        Await the first handler whose key is present in *msg*.

        The handler receives ``msg[key]``. Raises ClientError if no handler
        key matches.
        """
        for k, handler in self.handlers.items():
            if k in msg:
                self.logger.debug('Handling %s' % k)
                await handler(msg[k])
                return

        raise ClientError("Unrecognized command %r" % (msg,))

    def write_message(self, msg):
        """Serialize *msg* as JSON and write it out in chunks via chunkify()."""
        for c in chunkify(json.dumps(msg), self.max_chunk):
            self.writer.write(c.encode('utf-8'))

    async def read_message(self):
        """
        Read and decode one newline-terminated JSON message.

        Returns the decoded object, or None at EOF or when the last line is
        not newline-terminated. Undecodable bytes or invalid JSON are logged
        and the original exception is re-raised.
        """
        l = await self.reader.readline()
        if not l:
            return None

        try:
            message = l.decode('utf-8')
        except UnicodeDecodeError:
            # ``message`` was never bound, so log the raw bytes instead
            self.logger.error('Bad message from client: %r' % (l,))
            raise

        if not message.endswith('\n'):
            return None

        try:
            return json.loads(message)
        except json.JSONDecodeError:
            self.logger.error('Bad message from client: %r' % (message,))
            raise

    async def handle_chunk(self, request):
        """
        Reassemble a message sent as a series of chunk lines and dispatch it.

        Lines are read until an empty line (or EOF), joined, and decoded as
        JSON. *request* (the value of the 'chunk-stream' key) is not used.
        Raises ClientError if the reassembled message is itself a chunk
        stream.
        """
        lines = []
        try:
            while True:
                l = await self.reader.readline()
                l = l.rstrip(b"\n").decode("utf-8")
                if not l:
                    break
                lines.append(l)

            msg = json.loads(''.join(lines))
        except (json.JSONDecodeError, UnicodeDecodeError):
            self.logger.error('Bad message from client: %r' % (lines,))
            raise

        if 'chunk-stream' in msg:
            raise ClientError("Nested chunks are not allowed")

        await self.dispatch_message(msg)

    async def handle_ping(self, request):
        """Answer a 'ping' request with {'alive': True}."""
        response = {'alive': True}
        self.write_message(response)
134*4882a593Smuzhiyun
class AsyncServer(object):
    """
    Base class for an asyncio-based asyncrpc server.

    Call start_tcp_server() or start_unix_server() to select a transport,
    then serve_forever() (in this process) or serve_as_process() (in a
    child process). Subclasses implement accept_client() to create the
    object that serves each connection.
    """

    def __init__(self, logger):
        # Callable that removes the listening socket file (UNIX sockets only)
        self._cleanup_socket = None
        self.logger = logger
        # Set by start_tcp_server()/start_unix_server(); creates the
        # listening server on self.loop when called
        self.start = None
        # Listening address string, filled in once self.start() has run
        self.address = None
        self.loop = None
        # asyncio server object, created by self.start()
        self.server = None

    def start_tcp_server(self, host, port):
        """Configure the server to listen on TCP *host*:*port* when started."""
        def start_tcp():
            self.server = self.loop.run_until_complete(
                asyncio.start_server(self.handle_client, host, port)
            )

            for s in self.server.sockets:
                self.logger.debug('Listening on %r' % (s.getsockname(),))
                # Newer python does this automatically. Do it manually here for
                # maximum compatibility
                s.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1)
                s.setsockopt(socket.SOL_TCP, socket.TCP_QUICKACK, 1)

            name = self.server.sockets[0].getsockname()
            if self.server.sockets[0].family == socket.AF_INET6:
                self.address = "[%s]:%d" % (name[0], name[1])
            else:
                self.address = "%s:%d" % (name[0], name[1])

        self.start = start_tcp

    def start_unix_server(self, path):
        """
        Configure the server to listen on the AF_UNIX socket *path* when started.

        A relative *path* is resolved against the working directory at the
        time self.start() runs. The socket file is removed when serving ends.
        """
        def start_unix():
            abs_path = os.path.abspath(path)

            def cleanup():
                # Use the absolute path so a later chdir cannot redirect the
                # unlink. Tolerate the socket already being gone (newer Python
                # versions may remove it themselves when the server closes).
                try:
                    os.unlink(abs_path)
                except FileNotFoundError:
                    pass

            cwd = os.getcwd()
            try:
                # Work around path length limits in AF_UNIX by binding to the
                # basename from inside the socket's directory. dirname() of an
                # absolute path is never empty, unlike dirname('sock') == ''
                # which would make os.chdir() fail.
                os.chdir(os.path.dirname(abs_path))
                self.server = self.loop.run_until_complete(
                    asyncio.start_unix_server(self.handle_client, os.path.basename(abs_path))
                )
            finally:
                os.chdir(cwd)

            self.logger.debug('Listening on %r' % path)

            self._cleanup_socket = cleanup
            self.address = "unix://%s" % abs_path

        self.start = start_unix

    @abc.abstractmethod
    def accept_client(self, reader, writer):
        # NOTE: AsyncServer does not use abc.ABCMeta, so this decorator is not
        # enforced. Subclasses must override it and return an object with an
        # async process_requests() method (awaited by handle_client()).
        pass

    async def handle_client(self, reader, writer):
        """
        Serve one connection, logging any exception it raises.

        On error the writer is closed here, since the connection object may
        not have been created.
        """
        try:
            client = self.accept_client(reader, writer)
            await client.process_requests()
        except Exception as e:
            import traceback
            self.logger.error('Error from client: %s' % str(e), exc_info=True)
            traceback.print_exc()
            writer.close()
        self.logger.debug('Client disconnected')

    def run_loop_forever(self):
        """Run self.loop until stopped; Ctrl-C also ends it quietly."""
        try:
            self.loop.run_forever()
        except KeyboardInterrupt:
            pass

    def signal_handler(self):
        """SIGTERM handler: stop the event loop so serving winds down."""
        self.logger.debug("Got exit signal")
        self.loop.stop()

    def _serve_forever(self):
        """
        Serve until stopped, then close the server and remove the socket file.

        SIGTERM is unblocked here, after its handler is installed (see
        serve_as_process(), which blocks it around process creation).
        """
        try:
            self.loop.add_signal_handler(signal.SIGTERM, self.signal_handler)
            signal.pthread_sigmask(signal.SIG_UNBLOCK, [signal.SIGTERM])

            self.run_loop_forever()
            self.server.close()

            self.loop.run_until_complete(self.server.wait_closed())
            self.logger.debug('Server shutting down')
        finally:
            if self._cleanup_socket is not None:
                self._cleanup_socket()

    def serve_forever(self):
        """
        Serve requests in the current process
        """
        # Create loop and override any loop that may have existed in
        # a parent process.  It is possible that the usecases of
        # serve_forever might be constrained enough to allow using
        # get_event_loop here, but better safe than sorry for now.
        self.loop = asyncio.new_event_loop()
        asyncio.set_event_loop(self.loop)
        self.start()
        self._serve_forever()

    def serve_as_process(self, *, prefunc=None, args=()):
        """
        Serve requests in a child process

        *prefunc*, if given, is called in the child as prefunc(self, *args)
        after the server has started listening. Returns the
        multiprocessing.Process; self.address is set to the address the
        child reported (None if the child failed to start).
        """
        def run(queue):
            # Create loop and override any loop that may have existed
            # in a parent process.  Without doing this and instead
            # using get_event_loop, at the very minimum the hashserv
            # unit tests will hang when running the second test.
            # This happens since get_event_loop in the spawned server
            # process for the second testcase ends up with the loop
            # from the hashserv client created in the unit test process
            # when running the first testcase.  The problem is somewhat
            # more general, though, as any potential use of asyncio in
            # Cooker could create a loop that needs to replaced in this
            # new process.
            self.loop = asyncio.new_event_loop()
            asyncio.set_event_loop(self.loop)
            try:
                self.start()
            finally:
                # Always report back so the parent's queue.get() returns
                queue.put(self.address)
                queue.close()

            if prefunc is not None:
                prefunc(self, *args)

            self._serve_forever()

            if sys.version_info >= (3, 6):
                self.loop.run_until_complete(self.loop.shutdown_asyncgens())
            self.loop.close()

        queue = multiprocessing.Queue()

        # Temporarily block SIGTERM. The server process will inherit this
        # block which will ensure it doesn't receive the SIGTERM until the
        # handler is ready for it
        mask = signal.pthread_sigmask(signal.SIG_BLOCK, [signal.SIGTERM])
        try:
            self.process = multiprocessing.Process(target=run, args=(queue,))
            self.process.start()

            self.address = queue.get()
            queue.close()
            queue.join_thread()

            return self.process
        finally:
            signal.pthread_sigmask(signal.SIG_SETMASK, mask)
289