xref: /OK3568_Linux_fs/kernel/tools/testing/vsock/util.c (revision 4882a59341e53eb6f0b4789bf948001014eff981)
1*4882a593Smuzhiyun // SPDX-License-Identifier: GPL-2.0-only
2*4882a593Smuzhiyun /*
3*4882a593Smuzhiyun  * vsock test utilities
4*4882a593Smuzhiyun  *
5*4882a593Smuzhiyun  * Copyright (C) 2017 Red Hat, Inc.
6*4882a593Smuzhiyun  *
7*4882a593Smuzhiyun  * Author: Stefan Hajnoczi <stefanha@redhat.com>
8*4882a593Smuzhiyun  */
9*4882a593Smuzhiyun 
10*4882a593Smuzhiyun #include <errno.h>
11*4882a593Smuzhiyun #include <stdio.h>
12*4882a593Smuzhiyun #include <stdint.h>
13*4882a593Smuzhiyun #include <stdlib.h>
14*4882a593Smuzhiyun #include <signal.h>
15*4882a593Smuzhiyun #include <unistd.h>
16*4882a593Smuzhiyun #include <assert.h>
17*4882a593Smuzhiyun #include <sys/epoll.h>
18*4882a593Smuzhiyun 
19*4882a593Smuzhiyun #include "timeout.h"
20*4882a593Smuzhiyun #include "control.h"
21*4882a593Smuzhiyun #include "util.h"
22*4882a593Smuzhiyun 
23*4882a593Smuzhiyun /* Install signal handlers */
init_signals(void)24*4882a593Smuzhiyun void init_signals(void)
25*4882a593Smuzhiyun {
26*4882a593Smuzhiyun 	struct sigaction act = {
27*4882a593Smuzhiyun 		.sa_handler = sigalrm,
28*4882a593Smuzhiyun 	};
29*4882a593Smuzhiyun 
30*4882a593Smuzhiyun 	sigaction(SIGALRM, &act, NULL);
31*4882a593Smuzhiyun 	signal(SIGPIPE, SIG_IGN);
32*4882a593Smuzhiyun }
33*4882a593Smuzhiyun 
/* Parse a CID given as a decimal string.
 *
 * Returns the parsed value.  The process exits with EXIT_FAILURE when
 * @str is empty, has trailing non-digit characters, overflows
 * unsigned long, or holds a value that does not fit the unsigned int
 * return type.
 */
unsigned int parse_cid(const char *str)
{
	char *endptr = NULL;
	unsigned long n;

	errno = 0;
	n = strtoul(str, &endptr, 10);
	/* endptr == str: nothing was parsed (e.g. ""), which strtoul()
	 * would otherwise report as a successful 0.
	 * (unsigned int)n != n: the value would be silently truncated
	 * when narrowed to the return type.
	 */
	if (errno || endptr == str || *endptr != '\0' ||
	    (unsigned int)n != n) {
		fprintf(stderr, "malformed CID \"%s\"\n", str);
		exit(EXIT_FAILURE);
	}
	return (unsigned int)n;
}
48*4882a593Smuzhiyun 
49*4882a593Smuzhiyun /* Wait for the remote to close the connection */
vsock_wait_remote_close(int fd)50*4882a593Smuzhiyun void vsock_wait_remote_close(int fd)
51*4882a593Smuzhiyun {
52*4882a593Smuzhiyun 	struct epoll_event ev;
53*4882a593Smuzhiyun 	int epollfd, nfds;
54*4882a593Smuzhiyun 
55*4882a593Smuzhiyun 	epollfd = epoll_create1(0);
56*4882a593Smuzhiyun 	if (epollfd == -1) {
57*4882a593Smuzhiyun 		perror("epoll_create1");
58*4882a593Smuzhiyun 		exit(EXIT_FAILURE);
59*4882a593Smuzhiyun 	}
60*4882a593Smuzhiyun 
61*4882a593Smuzhiyun 	ev.events = EPOLLRDHUP | EPOLLHUP;
62*4882a593Smuzhiyun 	ev.data.fd = fd;
63*4882a593Smuzhiyun 	if (epoll_ctl(epollfd, EPOLL_CTL_ADD, fd, &ev) == -1) {
64*4882a593Smuzhiyun 		perror("epoll_ctl");
65*4882a593Smuzhiyun 		exit(EXIT_FAILURE);
66*4882a593Smuzhiyun 	}
67*4882a593Smuzhiyun 
68*4882a593Smuzhiyun 	nfds = epoll_wait(epollfd, &ev, 1, TIMEOUT * 1000);
69*4882a593Smuzhiyun 	if (nfds == -1) {
70*4882a593Smuzhiyun 		perror("epoll_wait");
71*4882a593Smuzhiyun 		exit(EXIT_FAILURE);
72*4882a593Smuzhiyun 	}
73*4882a593Smuzhiyun 
74*4882a593Smuzhiyun 	if (nfds == 0) {
75*4882a593Smuzhiyun 		fprintf(stderr, "epoll_wait timed out\n");
76*4882a593Smuzhiyun 		exit(EXIT_FAILURE);
77*4882a593Smuzhiyun 	}
78*4882a593Smuzhiyun 
79*4882a593Smuzhiyun 	assert(nfds == 1);
80*4882a593Smuzhiyun 	assert(ev.events & (EPOLLRDHUP | EPOLLHUP));
81*4882a593Smuzhiyun 	assert(ev.data.fd == fd);
82*4882a593Smuzhiyun 
83*4882a593Smuzhiyun 	close(epollfd);
84*4882a593Smuzhiyun }
85*4882a593Smuzhiyun 
86*4882a593Smuzhiyun /* Connect to <cid, port> and return the file descriptor. */
vsock_stream_connect(unsigned int cid,unsigned int port)87*4882a593Smuzhiyun int vsock_stream_connect(unsigned int cid, unsigned int port)
88*4882a593Smuzhiyun {
89*4882a593Smuzhiyun 	union {
90*4882a593Smuzhiyun 		struct sockaddr sa;
91*4882a593Smuzhiyun 		struct sockaddr_vm svm;
92*4882a593Smuzhiyun 	} addr = {
93*4882a593Smuzhiyun 		.svm = {
94*4882a593Smuzhiyun 			.svm_family = AF_VSOCK,
95*4882a593Smuzhiyun 			.svm_port = port,
96*4882a593Smuzhiyun 			.svm_cid = cid,
97*4882a593Smuzhiyun 		},
98*4882a593Smuzhiyun 	};
99*4882a593Smuzhiyun 	int ret;
100*4882a593Smuzhiyun 	int fd;
101*4882a593Smuzhiyun 
102*4882a593Smuzhiyun 	control_expectln("LISTENING");
103*4882a593Smuzhiyun 
104*4882a593Smuzhiyun 	fd = socket(AF_VSOCK, SOCK_STREAM, 0);
105*4882a593Smuzhiyun 
106*4882a593Smuzhiyun 	timeout_begin(TIMEOUT);
107*4882a593Smuzhiyun 	do {
108*4882a593Smuzhiyun 		ret = connect(fd, &addr.sa, sizeof(addr.svm));
109*4882a593Smuzhiyun 		timeout_check("connect");
110*4882a593Smuzhiyun 	} while (ret < 0 && errno == EINTR);
111*4882a593Smuzhiyun 	timeout_end();
112*4882a593Smuzhiyun 
113*4882a593Smuzhiyun 	if (ret < 0) {
114*4882a593Smuzhiyun 		int old_errno = errno;
115*4882a593Smuzhiyun 
116*4882a593Smuzhiyun 		close(fd);
117*4882a593Smuzhiyun 		fd = -1;
118*4882a593Smuzhiyun 		errno = old_errno;
119*4882a593Smuzhiyun 	}
120*4882a593Smuzhiyun 	return fd;
121*4882a593Smuzhiyun }
122*4882a593Smuzhiyun 
123*4882a593Smuzhiyun /* Listen on <cid, port> and return the first incoming connection.  The remote
124*4882a593Smuzhiyun  * address is stored to clientaddrp.  clientaddrp may be NULL.
125*4882a593Smuzhiyun  */
vsock_stream_accept(unsigned int cid,unsigned int port,struct sockaddr_vm * clientaddrp)126*4882a593Smuzhiyun int vsock_stream_accept(unsigned int cid, unsigned int port,
127*4882a593Smuzhiyun 			struct sockaddr_vm *clientaddrp)
128*4882a593Smuzhiyun {
129*4882a593Smuzhiyun 	union {
130*4882a593Smuzhiyun 		struct sockaddr sa;
131*4882a593Smuzhiyun 		struct sockaddr_vm svm;
132*4882a593Smuzhiyun 	} addr = {
133*4882a593Smuzhiyun 		.svm = {
134*4882a593Smuzhiyun 			.svm_family = AF_VSOCK,
135*4882a593Smuzhiyun 			.svm_port = port,
136*4882a593Smuzhiyun 			.svm_cid = cid,
137*4882a593Smuzhiyun 		},
138*4882a593Smuzhiyun 	};
139*4882a593Smuzhiyun 	union {
140*4882a593Smuzhiyun 		struct sockaddr sa;
141*4882a593Smuzhiyun 		struct sockaddr_vm svm;
142*4882a593Smuzhiyun 	} clientaddr;
143*4882a593Smuzhiyun 	socklen_t clientaddr_len = sizeof(clientaddr.svm);
144*4882a593Smuzhiyun 	int fd;
145*4882a593Smuzhiyun 	int client_fd;
146*4882a593Smuzhiyun 	int old_errno;
147*4882a593Smuzhiyun 
148*4882a593Smuzhiyun 	fd = socket(AF_VSOCK, SOCK_STREAM, 0);
149*4882a593Smuzhiyun 
150*4882a593Smuzhiyun 	if (bind(fd, &addr.sa, sizeof(addr.svm)) < 0) {
151*4882a593Smuzhiyun 		perror("bind");
152*4882a593Smuzhiyun 		exit(EXIT_FAILURE);
153*4882a593Smuzhiyun 	}
154*4882a593Smuzhiyun 
155*4882a593Smuzhiyun 	if (listen(fd, 1) < 0) {
156*4882a593Smuzhiyun 		perror("listen");
157*4882a593Smuzhiyun 		exit(EXIT_FAILURE);
158*4882a593Smuzhiyun 	}
159*4882a593Smuzhiyun 
160*4882a593Smuzhiyun 	control_writeln("LISTENING");
161*4882a593Smuzhiyun 
162*4882a593Smuzhiyun 	timeout_begin(TIMEOUT);
163*4882a593Smuzhiyun 	do {
164*4882a593Smuzhiyun 		client_fd = accept(fd, &clientaddr.sa, &clientaddr_len);
165*4882a593Smuzhiyun 		timeout_check("accept");
166*4882a593Smuzhiyun 	} while (client_fd < 0 && errno == EINTR);
167*4882a593Smuzhiyun 	timeout_end();
168*4882a593Smuzhiyun 
169*4882a593Smuzhiyun 	old_errno = errno;
170*4882a593Smuzhiyun 	close(fd);
171*4882a593Smuzhiyun 	errno = old_errno;
172*4882a593Smuzhiyun 
173*4882a593Smuzhiyun 	if (client_fd < 0)
174*4882a593Smuzhiyun 		return client_fd;
175*4882a593Smuzhiyun 
176*4882a593Smuzhiyun 	if (clientaddr_len != sizeof(clientaddr.svm)) {
177*4882a593Smuzhiyun 		fprintf(stderr, "unexpected addrlen from accept(2), %zu\n",
178*4882a593Smuzhiyun 			(size_t)clientaddr_len);
179*4882a593Smuzhiyun 		exit(EXIT_FAILURE);
180*4882a593Smuzhiyun 	}
181*4882a593Smuzhiyun 	if (clientaddr.sa.sa_family != AF_VSOCK) {
182*4882a593Smuzhiyun 		fprintf(stderr, "expected AF_VSOCK from accept(2), got %d\n",
183*4882a593Smuzhiyun 			clientaddr.sa.sa_family);
184*4882a593Smuzhiyun 		exit(EXIT_FAILURE);
185*4882a593Smuzhiyun 	}
186*4882a593Smuzhiyun 
187*4882a593Smuzhiyun 	if (clientaddrp)
188*4882a593Smuzhiyun 		*clientaddrp = clientaddr.svm;
189*4882a593Smuzhiyun 	return client_fd;
190*4882a593Smuzhiyun }
191*4882a593Smuzhiyun 
192*4882a593Smuzhiyun /* Transmit one byte and check the return value.
193*4882a593Smuzhiyun  *
194*4882a593Smuzhiyun  * expected_ret:
195*4882a593Smuzhiyun  *  <0 Negative errno (for testing errors)
196*4882a593Smuzhiyun  *   0 End-of-file
197*4882a593Smuzhiyun  *   1 Success
198*4882a593Smuzhiyun  */
send_byte(int fd,int expected_ret,int flags)199*4882a593Smuzhiyun void send_byte(int fd, int expected_ret, int flags)
200*4882a593Smuzhiyun {
201*4882a593Smuzhiyun 	const uint8_t byte = 'A';
202*4882a593Smuzhiyun 	ssize_t nwritten;
203*4882a593Smuzhiyun 
204*4882a593Smuzhiyun 	timeout_begin(TIMEOUT);
205*4882a593Smuzhiyun 	do {
206*4882a593Smuzhiyun 		nwritten = send(fd, &byte, sizeof(byte), flags);
207*4882a593Smuzhiyun 		timeout_check("write");
208*4882a593Smuzhiyun 	} while (nwritten < 0 && errno == EINTR);
209*4882a593Smuzhiyun 	timeout_end();
210*4882a593Smuzhiyun 
211*4882a593Smuzhiyun 	if (expected_ret < 0) {
212*4882a593Smuzhiyun 		if (nwritten != -1) {
213*4882a593Smuzhiyun 			fprintf(stderr, "bogus send(2) return value %zd\n",
214*4882a593Smuzhiyun 				nwritten);
215*4882a593Smuzhiyun 			exit(EXIT_FAILURE);
216*4882a593Smuzhiyun 		}
217*4882a593Smuzhiyun 		if (errno != -expected_ret) {
218*4882a593Smuzhiyun 			perror("write");
219*4882a593Smuzhiyun 			exit(EXIT_FAILURE);
220*4882a593Smuzhiyun 		}
221*4882a593Smuzhiyun 		return;
222*4882a593Smuzhiyun 	}
223*4882a593Smuzhiyun 
224*4882a593Smuzhiyun 	if (nwritten < 0) {
225*4882a593Smuzhiyun 		perror("write");
226*4882a593Smuzhiyun 		exit(EXIT_FAILURE);
227*4882a593Smuzhiyun 	}
228*4882a593Smuzhiyun 	if (nwritten == 0) {
229*4882a593Smuzhiyun 		if (expected_ret == 0)
230*4882a593Smuzhiyun 			return;
231*4882a593Smuzhiyun 
232*4882a593Smuzhiyun 		fprintf(stderr, "unexpected EOF while sending byte\n");
233*4882a593Smuzhiyun 		exit(EXIT_FAILURE);
234*4882a593Smuzhiyun 	}
235*4882a593Smuzhiyun 	if (nwritten != sizeof(byte)) {
236*4882a593Smuzhiyun 		fprintf(stderr, "bogus send(2) return value %zd\n", nwritten);
237*4882a593Smuzhiyun 		exit(EXIT_FAILURE);
238*4882a593Smuzhiyun 	}
239*4882a593Smuzhiyun }
240*4882a593Smuzhiyun 
241*4882a593Smuzhiyun /* Receive one byte and check the return value.
242*4882a593Smuzhiyun  *
243*4882a593Smuzhiyun  * expected_ret:
244*4882a593Smuzhiyun  *  <0 Negative errno (for testing errors)
245*4882a593Smuzhiyun  *   0 End-of-file
246*4882a593Smuzhiyun  *   1 Success
247*4882a593Smuzhiyun  */
recv_byte(int fd,int expected_ret,int flags)248*4882a593Smuzhiyun void recv_byte(int fd, int expected_ret, int flags)
249*4882a593Smuzhiyun {
250*4882a593Smuzhiyun 	uint8_t byte;
251*4882a593Smuzhiyun 	ssize_t nread;
252*4882a593Smuzhiyun 
253*4882a593Smuzhiyun 	timeout_begin(TIMEOUT);
254*4882a593Smuzhiyun 	do {
255*4882a593Smuzhiyun 		nread = recv(fd, &byte, sizeof(byte), flags);
256*4882a593Smuzhiyun 		timeout_check("read");
257*4882a593Smuzhiyun 	} while (nread < 0 && errno == EINTR);
258*4882a593Smuzhiyun 	timeout_end();
259*4882a593Smuzhiyun 
260*4882a593Smuzhiyun 	if (expected_ret < 0) {
261*4882a593Smuzhiyun 		if (nread != -1) {
262*4882a593Smuzhiyun 			fprintf(stderr, "bogus recv(2) return value %zd\n",
263*4882a593Smuzhiyun 				nread);
264*4882a593Smuzhiyun 			exit(EXIT_FAILURE);
265*4882a593Smuzhiyun 		}
266*4882a593Smuzhiyun 		if (errno != -expected_ret) {
267*4882a593Smuzhiyun 			perror("read");
268*4882a593Smuzhiyun 			exit(EXIT_FAILURE);
269*4882a593Smuzhiyun 		}
270*4882a593Smuzhiyun 		return;
271*4882a593Smuzhiyun 	}
272*4882a593Smuzhiyun 
273*4882a593Smuzhiyun 	if (nread < 0) {
274*4882a593Smuzhiyun 		perror("read");
275*4882a593Smuzhiyun 		exit(EXIT_FAILURE);
276*4882a593Smuzhiyun 	}
277*4882a593Smuzhiyun 	if (nread == 0) {
278*4882a593Smuzhiyun 		if (expected_ret == 0)
279*4882a593Smuzhiyun 			return;
280*4882a593Smuzhiyun 
281*4882a593Smuzhiyun 		fprintf(stderr, "unexpected EOF while receiving byte\n");
282*4882a593Smuzhiyun 		exit(EXIT_FAILURE);
283*4882a593Smuzhiyun 	}
284*4882a593Smuzhiyun 	if (nread != sizeof(byte)) {
285*4882a593Smuzhiyun 		fprintf(stderr, "bogus recv(2) return value %zd\n", nread);
286*4882a593Smuzhiyun 		exit(EXIT_FAILURE);
287*4882a593Smuzhiyun 	}
288*4882a593Smuzhiyun 	if (byte != 'A') {
289*4882a593Smuzhiyun 		fprintf(stderr, "unexpected byte read %c\n", byte);
290*4882a593Smuzhiyun 		exit(EXIT_FAILURE);
291*4882a593Smuzhiyun 	}
292*4882a593Smuzhiyun }
293*4882a593Smuzhiyun 
294*4882a593Smuzhiyun /* Run test cases.  The program terminates if a failure occurs. */
run_tests(const struct test_case * test_cases,const struct test_opts * opts)295*4882a593Smuzhiyun void run_tests(const struct test_case *test_cases,
296*4882a593Smuzhiyun 	       const struct test_opts *opts)
297*4882a593Smuzhiyun {
298*4882a593Smuzhiyun 	int i;
299*4882a593Smuzhiyun 
300*4882a593Smuzhiyun 	for (i = 0; test_cases[i].name; i++) {
301*4882a593Smuzhiyun 		void (*run)(const struct test_opts *opts);
302*4882a593Smuzhiyun 		char *line;
303*4882a593Smuzhiyun 
304*4882a593Smuzhiyun 		printf("%d - %s...", i, test_cases[i].name);
305*4882a593Smuzhiyun 		fflush(stdout);
306*4882a593Smuzhiyun 
307*4882a593Smuzhiyun 		/* Full barrier before executing the next test.  This
308*4882a593Smuzhiyun 		 * ensures that client and server are executing the
309*4882a593Smuzhiyun 		 * same test case.  In particular, it means whoever is
310*4882a593Smuzhiyun 		 * faster will not see the peer still executing the
311*4882a593Smuzhiyun 		 * last test.  This is important because port numbers
312*4882a593Smuzhiyun 		 * can be used by multiple test cases.
313*4882a593Smuzhiyun 		 */
314*4882a593Smuzhiyun 		if (test_cases[i].skip)
315*4882a593Smuzhiyun 			control_writeln("SKIP");
316*4882a593Smuzhiyun 		else
317*4882a593Smuzhiyun 			control_writeln("NEXT");
318*4882a593Smuzhiyun 
319*4882a593Smuzhiyun 		line = control_readln();
320*4882a593Smuzhiyun 		if (control_cmpln(line, "SKIP", false) || test_cases[i].skip) {
321*4882a593Smuzhiyun 
322*4882a593Smuzhiyun 			printf("skipped\n");
323*4882a593Smuzhiyun 
324*4882a593Smuzhiyun 			free(line);
325*4882a593Smuzhiyun 			continue;
326*4882a593Smuzhiyun 		}
327*4882a593Smuzhiyun 
328*4882a593Smuzhiyun 		control_cmpln(line, "NEXT", true);
329*4882a593Smuzhiyun 		free(line);
330*4882a593Smuzhiyun 
331*4882a593Smuzhiyun 		if (opts->mode == TEST_MODE_CLIENT)
332*4882a593Smuzhiyun 			run = test_cases[i].run_client;
333*4882a593Smuzhiyun 		else
334*4882a593Smuzhiyun 			run = test_cases[i].run_server;
335*4882a593Smuzhiyun 
336*4882a593Smuzhiyun 		if (run)
337*4882a593Smuzhiyun 			run(opts);
338*4882a593Smuzhiyun 
339*4882a593Smuzhiyun 		printf("ok\n");
340*4882a593Smuzhiyun 	}
341*4882a593Smuzhiyun }
342*4882a593Smuzhiyun 
list_tests(const struct test_case * test_cases)343*4882a593Smuzhiyun void list_tests(const struct test_case *test_cases)
344*4882a593Smuzhiyun {
345*4882a593Smuzhiyun 	int i;
346*4882a593Smuzhiyun 
347*4882a593Smuzhiyun 	printf("ID\tTest name\n");
348*4882a593Smuzhiyun 
349*4882a593Smuzhiyun 	for (i = 0; test_cases[i].name; i++)
350*4882a593Smuzhiyun 		printf("%d\t%s\n", i, test_cases[i].name);
351*4882a593Smuzhiyun 
352*4882a593Smuzhiyun 	exit(EXIT_FAILURE);
353*4882a593Smuzhiyun }
354*4882a593Smuzhiyun 
skip_test(struct test_case * test_cases,size_t test_cases_len,const char * test_id_str)355*4882a593Smuzhiyun void skip_test(struct test_case *test_cases, size_t test_cases_len,
356*4882a593Smuzhiyun 	       const char *test_id_str)
357*4882a593Smuzhiyun {
358*4882a593Smuzhiyun 	unsigned long test_id;
359*4882a593Smuzhiyun 	char *endptr = NULL;
360*4882a593Smuzhiyun 
361*4882a593Smuzhiyun 	errno = 0;
362*4882a593Smuzhiyun 	test_id = strtoul(test_id_str, &endptr, 10);
363*4882a593Smuzhiyun 	if (errno || *endptr != '\0') {
364*4882a593Smuzhiyun 		fprintf(stderr, "malformed test ID \"%s\"\n", test_id_str);
365*4882a593Smuzhiyun 		exit(EXIT_FAILURE);
366*4882a593Smuzhiyun 	}
367*4882a593Smuzhiyun 
368*4882a593Smuzhiyun 	if (test_id >= test_cases_len) {
369*4882a593Smuzhiyun 		fprintf(stderr, "test ID (%lu) larger than the max allowed (%lu)\n",
370*4882a593Smuzhiyun 			test_id, test_cases_len - 1);
371*4882a593Smuzhiyun 		exit(EXIT_FAILURE);
372*4882a593Smuzhiyun 	}
373*4882a593Smuzhiyun 
374*4882a593Smuzhiyun 	test_cases[test_id].skip = true;
375*4882a593Smuzhiyun }
376