xref: /OK3568_Linux_fs/yocto/poky/bitbake/bin/bitbake-hashclient (revision 4882a59341e53eb6f0b4789bf948001014eff981)
1#! /usr/bin/env python3
2#
3# Copyright (C) 2019 Garmin Ltd.
4#
5# SPDX-License-Identifier: GPL-2.0-only
6#
7
8import argparse
9import hashlib
10import logging
11import os
12import pprint
13import sys
14import threading
15import time
16import warnings
# Apply the "default" warnings action to every category: show the first
# occurrence of each warning per location. This overrides Python's built-in
# filters, which would otherwise ignore e.g. DeprecationWarning triggered
# outside __main__ (such as inside imported library modules).
warnings.simplefilter("default")

try:
    import tqdm
    ProgressBar = tqdm.tqdm
except ImportError:
    class ProgressBar(object):
        """No-op stand-in for ``tqdm.tqdm``, used when tqdm is not installed.

        Accepts and ignores any constructor arguments, can be used as a
        context manager, and silently ignores progress updates.
        """

        def __init__(self, *args, **kwargs):
            pass

        def __enter__(self):
            return self

        def __exit__(self, *args, **kwargs):
            # Returning None (falsy) lets any exception propagate.
            pass

        def update(self, n=1):
            # Same optional increment argument as tqdm.tqdm.update(n=1), so
            # callers that pass a count work with or without tqdm installed.
            pass
35
# Make bitbake's 'lib' directory (a sibling of the directory holding this
# script) importable. Resolve __file__ to an absolute path first: on older
# Pythons __file__ can be relative (e.g. 'bitbake-hashclient' when run from
# inside bin/), in which case dirname(dirname(...)) collapses to '' and 'lib'
# would be looked up relative to the current working directory instead.
sys.path.insert(0, os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), 'lib'))

import hashserv

# Server address used when --address is not given: a 'unix://' address for
# ./hashserve.sock (presumably relative to the current working directory).
DEFAULT_ADDRESS = 'unix://./hashserve.sock'
# Method name under which the stress test queries and reports hashes.
METHOD = 'stress.test.method'
42
43
def main():
    """Parse the command line and run the selected subcommand.

    Subcommands:
      stats   Print the server statistics (``--reset`` resets them).
      stress  Run a concurrent get_unihash load test, optionally followed by
              reporting new hashes (``--report``).

    Returns:
        The subcommand's exit status (0 on success), or 0 if no subcommand
        was given.

    Raises:
        ValueError: if ``--log`` does not name a logging level.
    """
    def handle_stats(args, client):
        """Print client.reset_stats() if --reset was given, else client.get_stats()."""
        if args.reset:
            s = client.reset_stats()
        else:
            s = client.get_stats()
        pprint.pprint(s)
        return 0

    def handle_stress(args, client):
        """Load-test get_unihash from several concurrent connections.

        Starts ``args.clients`` threads, each opening its own connection and
        issuing ``args.requests`` lookups for deterministic task hashes, then
        prints throughput, latency and hit/miss counts. With ``--report`` it
        afterwards reports an outhash/unihash for each task hash via *client*.
        """
        def make_taskhash(i):
            # Deterministic task hash for request i. Every thread uses the
            # same sequence, so the --report pass below covers all of them.
            h = hashlib.sha256()
            h.update(args.taskhash_seed.encode('utf-8'))
            h.update(str(i).encode('utf-8'))
            return h.hexdigest()

        def thread_main(pbar, lock):
            nonlocal found_hashes
            nonlocal missed_hashes
            nonlocal max_time

            # Each worker gets its own connection; always close it so the
            # underlying socket is not leaked, even if a request fails.
            thread_client = hashserv.create_client(args.address)
            try:
                for i in range(args.requests):
                    taskhash = make_taskhash(i)

                    start_time = time.perf_counter()
                    unihash = thread_client.get_unihash(METHOD, taskhash)
                    elapsed = time.perf_counter() - start_time

                    with lock:
                        if unihash:
                            found_hashes += 1
                        else:
                            missed_hashes += 1

                        max_time = max(elapsed, max_time)
                        pbar.update()
            finally:
                thread_client.close()

        max_time = 0
        found_hashes = 0
        missed_hashes = 0
        lock = threading.Lock()
        total_requests = args.clients * args.requests
        start_time = time.perf_counter()
        with ProgressBar(total=total_requests) as pbar:
            threads = [threading.Thread(target=thread_main, args=(pbar, lock), daemon=False) for _ in range(args.clients)]
            for t in threads:
                t.start()

            for t in threads:
                t.join()

        elapsed = time.perf_counter() - start_time
        # All workers have been joined, so the counters can be read without
        # the lock. Guard the per-request figures against zero requests
        # (e.g. --clients 0 or --requests 0), which would divide by zero.
        if total_requests > 0:
            print("%d requests in %.1fs. %.1f requests per second" % (total_requests, elapsed, total_requests / elapsed))
            print("Average request time %.8fs" % (elapsed / total_requests))
        else:
            print("%d requests in %.1fs" % (total_requests, elapsed))
        print("Max request time was %.8fs" % max_time)
        print("Found %d hashes, missed %d" % (found_hashes, missed_hashes))

        if args.report:
            with ProgressBar(total=args.requests) as pbar:
                for i in range(args.requests):
                    taskhash = make_taskhash(i)

                    outhash = hashlib.sha256()
                    outhash.update(args.outhash_seed.encode('utf-8'))
                    outhash.update(str(i).encode('utf-8'))

                    # The task hash itself is offered as the unihash.
                    client.report_unihash(taskhash, METHOD, outhash.hexdigest(), taskhash)
                    pbar.update()

        return 0

    parser = argparse.ArgumentParser(description='Hash Equivalence Client')
    parser.add_argument('--address', default=DEFAULT_ADDRESS, help='Server address (default "%(default)s")')
    parser.add_argument('--log', default='WARNING', help='Set logging level')

    subparsers = parser.add_subparsers()

    stats_parser = subparsers.add_parser('stats', help='Show server stats')
    stats_parser.add_argument('--reset', action='store_true',
                              help='Reset server stats')
    stats_parser.set_defaults(func=handle_stats)

    stress_parser = subparsers.add_parser('stress', help='Run stress test')
    stress_parser.add_argument('--clients', type=int, default=10,
                               help='Number of simultaneous clients')
    stress_parser.add_argument('--requests', type=int, default=1000,
                               help='Number of requests each client will perform')
    stress_parser.add_argument('--report', action='store_true',
                               help='Report new hashes')
    stress_parser.add_argument('--taskhash-seed', default='',
                               help='Include string in taskhash')
    stress_parser.add_argument('--outhash-seed', default='',
                               help='Include string in outhash')
    stress_parser.set_defaults(func=handle_stress)

    args = parser.parse_args()

    logger = logging.getLogger('hashserv')

    level = getattr(logging, args.log.upper(), None)
    if not isinstance(level, int):
        raise ValueError('Invalid log level: %s' % args.log)

    logger.setLevel(level)
    console = logging.StreamHandler()
    console.setLevel(level)
    logger.addHandler(console)

    # No subcommand given: the parser sets no 'func', so there is nothing to do.
    func = getattr(args, 'func', None)
    if func:
        client = hashserv.create_client(args.address)
        try:
            return func(args, client)
        finally:
            # Release the connection whether the subcommand succeeds or raises.
            client.close()

    return 0
160
161
if __name__ == '__main__':
    import traceback

    # Any unexpected error is reported with a full traceback and turned into
    # exit status 1; otherwise the process exits with main()'s return value.
    try:
        exit_code = main()
    except Exception:
        traceback.print_exc()
        exit_code = 1
    sys.exit(exit_code)
170