#! /usr/bin/env python3
#
# Copyright (C) 2019 Garmin Ltd.
#
# SPDX-License-Identifier: GPL-2.0-only
#

import argparse
import hashlib
import logging
import os
import pprint
import sys
import threading
import time
import warnings
warnings.simplefilter("default")

try:
    import tqdm
    ProgressBar = tqdm.tqdm
except ImportError:
    class ProgressBar(object):
        """No-op stand-in used when tqdm is not installed.

        Implements only the subset of the tqdm interface this script uses:
        the context-manager protocol and a no-argument ``update()``.
        """

        def __init__(self, *args, **kwargs):
            pass

        def __enter__(self):
            return self

        def __exit__(self, *args, **kwargs):
            pass

        def update(self):
            pass

# Make the 'lib' directory that sits next to this script's parent directory
# importable, so 'hashserv' resolves from the source tree without installation.
sys.path.insert(0, os.path.join(os.path.dirname(os.path.dirname(__file__)), 'lib'))

import hashserv

DEFAULT_ADDRESS = 'unix://./hashserve.sock'
METHOD = 'stress.test.method'


def main():
    """Parse the command line and dispatch to the selected sub-command.

    Returns the process exit status (0 when no sub-command is given or the
    sub-command succeeds).
    """
    def handle_stats(args, client):
        """Print the server's statistics, resetting them first if --reset was given."""
        if args.reset:
            s = client.reset_stats()
        else:
            s = client.get_stats()
        pprint.pprint(s)
        return 0

    def handle_stress(args, client):
        """Issue concurrent get_unihash() lookups and print timing statistics.

        Starts ``args.clients`` threads, each opening its own connection to
        ``args.address`` and performing ``args.requests`` lookups. The lookup
        taskhashes are derived from ``args.taskhash_seed`` and the request
        index, so every thread queries the same set of hashes. When
        ``args.report`` is set, an outhash for each of those taskhashes is then
        reported through ``client``.
        """
        def thread_main(pbar, lock):
            nonlocal found_hashes
            nonlocal missed_hashes
            nonlocal max_time

            # Each worker uses a private connection so requests run concurrently.
            client = hashserv.create_client(args.address)
            try:
                for i in range(args.requests):
                    taskhash = hashlib.sha256()
                    taskhash.update(args.taskhash_seed.encode('utf-8'))
                    taskhash.update(str(i).encode('utf-8'))

                    start_time = time.perf_counter()
                    l = client.get_unihash(METHOD, taskhash.hexdigest())
                    elapsed = time.perf_counter() - start_time

                    # Shared counters and the progress bar are updated from
                    # several threads, so serialize access.
                    with lock:
                        if l:
                            found_hashes += 1
                        else:
                            missed_hashes += 1

                        max_time = max(elapsed, max_time)
                        pbar.update()
            finally:
                # Release this worker's connection rather than leaving it to GC.
                client.close()

        max_time = 0
        found_hashes = 0
        missed_hashes = 0
        lock = threading.Lock()
        total_requests = args.clients * args.requests
        start_time = time.perf_counter()
        with ProgressBar(total=total_requests) as pbar:
            threads = [threading.Thread(target=thread_main, args=(pbar, lock), daemon=False) for _ in range(args.clients)]
            for t in threads:
                t.start()

            for t in threads:
                t.join()

        elapsed = time.perf_counter() - start_time

        # Every worker has been joined, so the counters are stable here.
        # Guard the rate/average computations: with --clients 0 or
        # --requests 0 there is nothing to divide by.
        if total_requests > 0 and elapsed > 0:
            print("%d requests in %.1fs. %.1f requests per second" % (total_requests, elapsed, total_requests / elapsed))
            print("Average request time %.8fs" % (elapsed / total_requests))
        else:
            print("%d requests in %.1fs." % (total_requests, elapsed))
        print("Max request time was %.8fs" % max_time)
        print("Found %d hashes, missed %d" % (found_hashes, missed_hashes))

        if args.report:
            # Single-threaded: report an outhash for each taskhash queried above.
            with ProgressBar(total=args.requests) as pbar:
                for i in range(args.requests):
                    taskhash = hashlib.sha256()
                    taskhash.update(args.taskhash_seed.encode('utf-8'))
                    taskhash.update(str(i).encode('utf-8'))

                    outhash = hashlib.sha256()
                    outhash.update(args.outhash_seed.encode('utf-8'))
                    outhash.update(str(i).encode('utf-8'))

                    client.report_unihash(taskhash.hexdigest(), METHOD, outhash.hexdigest(), taskhash.hexdigest())

                    pbar.update()

        return 0

    parser = argparse.ArgumentParser(description='Hash Equivalence Client')
    parser.add_argument('--address', default=DEFAULT_ADDRESS, help='Server address (default "%(default)s")')
    parser.add_argument('--log', default='WARNING', help='Set logging level')

    subparsers = parser.add_subparsers()

    stats_parser = subparsers.add_parser('stats', help='Show server stats')
    stats_parser.add_argument('--reset', action='store_true',
                              help='Reset server stats')
    stats_parser.set_defaults(func=handle_stats)

    stress_parser = subparsers.add_parser('stress', help='Run stress test')
    stress_parser.add_argument('--clients', type=int, default=10,
                               help='Number of simultaneous clients')
    stress_parser.add_argument('--requests', type=int, default=1000,
                               help='Number of requests each client will perform')
    stress_parser.add_argument('--report', action='store_true',
                               help='Report new hashes')
    stress_parser.add_argument('--taskhash-seed', default='',
                               help='Include string in taskhash')
    stress_parser.add_argument('--outhash-seed', default='',
                               help='Include string in outhash')
    stress_parser.set_defaults(func=handle_stress)

    args = parser.parse_args()

    logger = logging.getLogger('hashserv')

    level = getattr(logging, args.log.upper(), None)
    if not isinstance(level, int):
        raise ValueError('Invalid log level: %s' % args.log)

    logger.setLevel(level)
    console = logging.StreamHandler()
    console.setLevel(level)
    logger.addHandler(console)

    # 'func' is only set when a sub-command was selected.
    func = getattr(args, 'func', None)
    if func:
        client = hashserv.create_client(args.address)
        try:
            return func(args, client)
        finally:
            client.close()

    return 0


if __name__ == '__main__':
    try:
        ret = main()
    except Exception:
        ret = 1
        import traceback
        traceback.print_exc()
    sys.exit(ret)