// SPDX-License-Identifier: GPL-2.0-only
/*
 * vsock test utilities
 *
 * Copyright (C) 2017 Red Hat, Inc.
 *
 * Author: Stefan Hajnoczi <stefanha@redhat.com>
 */
9*4882a593Smuzhiyun
10*4882a593Smuzhiyun #include <errno.h>
11*4882a593Smuzhiyun #include <stdio.h>
12*4882a593Smuzhiyun #include <stdint.h>
13*4882a593Smuzhiyun #include <stdlib.h>
14*4882a593Smuzhiyun #include <signal.h>
15*4882a593Smuzhiyun #include <unistd.h>
16*4882a593Smuzhiyun #include <assert.h>
17*4882a593Smuzhiyun #include <sys/epoll.h>
18*4882a593Smuzhiyun
19*4882a593Smuzhiyun #include "timeout.h"
20*4882a593Smuzhiyun #include "control.h"
21*4882a593Smuzhiyun #include "util.h"
22*4882a593Smuzhiyun
23*4882a593Smuzhiyun /* Install signal handlers */
init_signals(void)24*4882a593Smuzhiyun void init_signals(void)
25*4882a593Smuzhiyun {
26*4882a593Smuzhiyun struct sigaction act = {
27*4882a593Smuzhiyun .sa_handler = sigalrm,
28*4882a593Smuzhiyun };
29*4882a593Smuzhiyun
30*4882a593Smuzhiyun sigaction(SIGALRM, &act, NULL);
31*4882a593Smuzhiyun signal(SIGPIPE, SIG_IGN);
32*4882a593Smuzhiyun }
33*4882a593Smuzhiyun
/* Parse a CID given as a base-10 string.
 *
 * Returns the parsed value. Prints a diagnostic to stderr and terminates the
 * program with EXIT_FAILURE when str contains no digits, has trailing
 * characters, strtoul() reports an error, or the value does not fit in an
 * unsigned int.
 */
unsigned int parse_cid(const char *str)
{
	char *endptr = NULL;
	unsigned long n;

	errno = 0;
	n = strtoul(str, &endptr, 10);

	/* endptr == str means nothing was converted (e.g. an empty string);
	 * n != (unsigned int)n means the return type would truncate the value.
	 */
	if (errno || endptr == str || *endptr != '\0' ||
	    n != (unsigned int)n) {
		fprintf(stderr, "malformed CID \"%s\"\n", str);
		exit(EXIT_FAILURE);
	}
	return (unsigned int)n;
}
48*4882a593Smuzhiyun
49*4882a593Smuzhiyun /* Wait for the remote to close the connection */
vsock_wait_remote_close(int fd)50*4882a593Smuzhiyun void vsock_wait_remote_close(int fd)
51*4882a593Smuzhiyun {
52*4882a593Smuzhiyun struct epoll_event ev;
53*4882a593Smuzhiyun int epollfd, nfds;
54*4882a593Smuzhiyun
55*4882a593Smuzhiyun epollfd = epoll_create1(0);
56*4882a593Smuzhiyun if (epollfd == -1) {
57*4882a593Smuzhiyun perror("epoll_create1");
58*4882a593Smuzhiyun exit(EXIT_FAILURE);
59*4882a593Smuzhiyun }
60*4882a593Smuzhiyun
61*4882a593Smuzhiyun ev.events = EPOLLRDHUP | EPOLLHUP;
62*4882a593Smuzhiyun ev.data.fd = fd;
63*4882a593Smuzhiyun if (epoll_ctl(epollfd, EPOLL_CTL_ADD, fd, &ev) == -1) {
64*4882a593Smuzhiyun perror("epoll_ctl");
65*4882a593Smuzhiyun exit(EXIT_FAILURE);
66*4882a593Smuzhiyun }
67*4882a593Smuzhiyun
68*4882a593Smuzhiyun nfds = epoll_wait(epollfd, &ev, 1, TIMEOUT * 1000);
69*4882a593Smuzhiyun if (nfds == -1) {
70*4882a593Smuzhiyun perror("epoll_wait");
71*4882a593Smuzhiyun exit(EXIT_FAILURE);
72*4882a593Smuzhiyun }
73*4882a593Smuzhiyun
74*4882a593Smuzhiyun if (nfds == 0) {
75*4882a593Smuzhiyun fprintf(stderr, "epoll_wait timed out\n");
76*4882a593Smuzhiyun exit(EXIT_FAILURE);
77*4882a593Smuzhiyun }
78*4882a593Smuzhiyun
79*4882a593Smuzhiyun assert(nfds == 1);
80*4882a593Smuzhiyun assert(ev.events & (EPOLLRDHUP | EPOLLHUP));
81*4882a593Smuzhiyun assert(ev.data.fd == fd);
82*4882a593Smuzhiyun
83*4882a593Smuzhiyun close(epollfd);
84*4882a593Smuzhiyun }
85*4882a593Smuzhiyun
86*4882a593Smuzhiyun /* Connect to <cid, port> and return the file descriptor. */
vsock_stream_connect(unsigned int cid,unsigned int port)87*4882a593Smuzhiyun int vsock_stream_connect(unsigned int cid, unsigned int port)
88*4882a593Smuzhiyun {
89*4882a593Smuzhiyun union {
90*4882a593Smuzhiyun struct sockaddr sa;
91*4882a593Smuzhiyun struct sockaddr_vm svm;
92*4882a593Smuzhiyun } addr = {
93*4882a593Smuzhiyun .svm = {
94*4882a593Smuzhiyun .svm_family = AF_VSOCK,
95*4882a593Smuzhiyun .svm_port = port,
96*4882a593Smuzhiyun .svm_cid = cid,
97*4882a593Smuzhiyun },
98*4882a593Smuzhiyun };
99*4882a593Smuzhiyun int ret;
100*4882a593Smuzhiyun int fd;
101*4882a593Smuzhiyun
102*4882a593Smuzhiyun control_expectln("LISTENING");
103*4882a593Smuzhiyun
104*4882a593Smuzhiyun fd = socket(AF_VSOCK, SOCK_STREAM, 0);
105*4882a593Smuzhiyun
106*4882a593Smuzhiyun timeout_begin(TIMEOUT);
107*4882a593Smuzhiyun do {
108*4882a593Smuzhiyun ret = connect(fd, &addr.sa, sizeof(addr.svm));
109*4882a593Smuzhiyun timeout_check("connect");
110*4882a593Smuzhiyun } while (ret < 0 && errno == EINTR);
111*4882a593Smuzhiyun timeout_end();
112*4882a593Smuzhiyun
113*4882a593Smuzhiyun if (ret < 0) {
114*4882a593Smuzhiyun int old_errno = errno;
115*4882a593Smuzhiyun
116*4882a593Smuzhiyun close(fd);
117*4882a593Smuzhiyun fd = -1;
118*4882a593Smuzhiyun errno = old_errno;
119*4882a593Smuzhiyun }
120*4882a593Smuzhiyun return fd;
121*4882a593Smuzhiyun }
122*4882a593Smuzhiyun
123*4882a593Smuzhiyun /* Listen on <cid, port> and return the first incoming connection. The remote
124*4882a593Smuzhiyun * address is stored to clientaddrp. clientaddrp may be NULL.
125*4882a593Smuzhiyun */
vsock_stream_accept(unsigned int cid,unsigned int port,struct sockaddr_vm * clientaddrp)126*4882a593Smuzhiyun int vsock_stream_accept(unsigned int cid, unsigned int port,
127*4882a593Smuzhiyun struct sockaddr_vm *clientaddrp)
128*4882a593Smuzhiyun {
129*4882a593Smuzhiyun union {
130*4882a593Smuzhiyun struct sockaddr sa;
131*4882a593Smuzhiyun struct sockaddr_vm svm;
132*4882a593Smuzhiyun } addr = {
133*4882a593Smuzhiyun .svm = {
134*4882a593Smuzhiyun .svm_family = AF_VSOCK,
135*4882a593Smuzhiyun .svm_port = port,
136*4882a593Smuzhiyun .svm_cid = cid,
137*4882a593Smuzhiyun },
138*4882a593Smuzhiyun };
139*4882a593Smuzhiyun union {
140*4882a593Smuzhiyun struct sockaddr sa;
141*4882a593Smuzhiyun struct sockaddr_vm svm;
142*4882a593Smuzhiyun } clientaddr;
143*4882a593Smuzhiyun socklen_t clientaddr_len = sizeof(clientaddr.svm);
144*4882a593Smuzhiyun int fd;
145*4882a593Smuzhiyun int client_fd;
146*4882a593Smuzhiyun int old_errno;
147*4882a593Smuzhiyun
148*4882a593Smuzhiyun fd = socket(AF_VSOCK, SOCK_STREAM, 0);
149*4882a593Smuzhiyun
150*4882a593Smuzhiyun if (bind(fd, &addr.sa, sizeof(addr.svm)) < 0) {
151*4882a593Smuzhiyun perror("bind");
152*4882a593Smuzhiyun exit(EXIT_FAILURE);
153*4882a593Smuzhiyun }
154*4882a593Smuzhiyun
155*4882a593Smuzhiyun if (listen(fd, 1) < 0) {
156*4882a593Smuzhiyun perror("listen");
157*4882a593Smuzhiyun exit(EXIT_FAILURE);
158*4882a593Smuzhiyun }
159*4882a593Smuzhiyun
160*4882a593Smuzhiyun control_writeln("LISTENING");
161*4882a593Smuzhiyun
162*4882a593Smuzhiyun timeout_begin(TIMEOUT);
163*4882a593Smuzhiyun do {
164*4882a593Smuzhiyun client_fd = accept(fd, &clientaddr.sa, &clientaddr_len);
165*4882a593Smuzhiyun timeout_check("accept");
166*4882a593Smuzhiyun } while (client_fd < 0 && errno == EINTR);
167*4882a593Smuzhiyun timeout_end();
168*4882a593Smuzhiyun
169*4882a593Smuzhiyun old_errno = errno;
170*4882a593Smuzhiyun close(fd);
171*4882a593Smuzhiyun errno = old_errno;
172*4882a593Smuzhiyun
173*4882a593Smuzhiyun if (client_fd < 0)
174*4882a593Smuzhiyun return client_fd;
175*4882a593Smuzhiyun
176*4882a593Smuzhiyun if (clientaddr_len != sizeof(clientaddr.svm)) {
177*4882a593Smuzhiyun fprintf(stderr, "unexpected addrlen from accept(2), %zu\n",
178*4882a593Smuzhiyun (size_t)clientaddr_len);
179*4882a593Smuzhiyun exit(EXIT_FAILURE);
180*4882a593Smuzhiyun }
181*4882a593Smuzhiyun if (clientaddr.sa.sa_family != AF_VSOCK) {
182*4882a593Smuzhiyun fprintf(stderr, "expected AF_VSOCK from accept(2), got %d\n",
183*4882a593Smuzhiyun clientaddr.sa.sa_family);
184*4882a593Smuzhiyun exit(EXIT_FAILURE);
185*4882a593Smuzhiyun }
186*4882a593Smuzhiyun
187*4882a593Smuzhiyun if (clientaddrp)
188*4882a593Smuzhiyun *clientaddrp = clientaddr.svm;
189*4882a593Smuzhiyun return client_fd;
190*4882a593Smuzhiyun }
191*4882a593Smuzhiyun
192*4882a593Smuzhiyun /* Transmit one byte and check the return value.
193*4882a593Smuzhiyun *
194*4882a593Smuzhiyun * expected_ret:
195*4882a593Smuzhiyun * <0 Negative errno (for testing errors)
196*4882a593Smuzhiyun * 0 End-of-file
197*4882a593Smuzhiyun * 1 Success
198*4882a593Smuzhiyun */
send_byte(int fd,int expected_ret,int flags)199*4882a593Smuzhiyun void send_byte(int fd, int expected_ret, int flags)
200*4882a593Smuzhiyun {
201*4882a593Smuzhiyun const uint8_t byte = 'A';
202*4882a593Smuzhiyun ssize_t nwritten;
203*4882a593Smuzhiyun
204*4882a593Smuzhiyun timeout_begin(TIMEOUT);
205*4882a593Smuzhiyun do {
206*4882a593Smuzhiyun nwritten = send(fd, &byte, sizeof(byte), flags);
207*4882a593Smuzhiyun timeout_check("write");
208*4882a593Smuzhiyun } while (nwritten < 0 && errno == EINTR);
209*4882a593Smuzhiyun timeout_end();
210*4882a593Smuzhiyun
211*4882a593Smuzhiyun if (expected_ret < 0) {
212*4882a593Smuzhiyun if (nwritten != -1) {
213*4882a593Smuzhiyun fprintf(stderr, "bogus send(2) return value %zd\n",
214*4882a593Smuzhiyun nwritten);
215*4882a593Smuzhiyun exit(EXIT_FAILURE);
216*4882a593Smuzhiyun }
217*4882a593Smuzhiyun if (errno != -expected_ret) {
218*4882a593Smuzhiyun perror("write");
219*4882a593Smuzhiyun exit(EXIT_FAILURE);
220*4882a593Smuzhiyun }
221*4882a593Smuzhiyun return;
222*4882a593Smuzhiyun }
223*4882a593Smuzhiyun
224*4882a593Smuzhiyun if (nwritten < 0) {
225*4882a593Smuzhiyun perror("write");
226*4882a593Smuzhiyun exit(EXIT_FAILURE);
227*4882a593Smuzhiyun }
228*4882a593Smuzhiyun if (nwritten == 0) {
229*4882a593Smuzhiyun if (expected_ret == 0)
230*4882a593Smuzhiyun return;
231*4882a593Smuzhiyun
232*4882a593Smuzhiyun fprintf(stderr, "unexpected EOF while sending byte\n");
233*4882a593Smuzhiyun exit(EXIT_FAILURE);
234*4882a593Smuzhiyun }
235*4882a593Smuzhiyun if (nwritten != sizeof(byte)) {
236*4882a593Smuzhiyun fprintf(stderr, "bogus send(2) return value %zd\n", nwritten);
237*4882a593Smuzhiyun exit(EXIT_FAILURE);
238*4882a593Smuzhiyun }
239*4882a593Smuzhiyun }
240*4882a593Smuzhiyun
241*4882a593Smuzhiyun /* Receive one byte and check the return value.
242*4882a593Smuzhiyun *
243*4882a593Smuzhiyun * expected_ret:
244*4882a593Smuzhiyun * <0 Negative errno (for testing errors)
245*4882a593Smuzhiyun * 0 End-of-file
246*4882a593Smuzhiyun * 1 Success
247*4882a593Smuzhiyun */
recv_byte(int fd,int expected_ret,int flags)248*4882a593Smuzhiyun void recv_byte(int fd, int expected_ret, int flags)
249*4882a593Smuzhiyun {
250*4882a593Smuzhiyun uint8_t byte;
251*4882a593Smuzhiyun ssize_t nread;
252*4882a593Smuzhiyun
253*4882a593Smuzhiyun timeout_begin(TIMEOUT);
254*4882a593Smuzhiyun do {
255*4882a593Smuzhiyun nread = recv(fd, &byte, sizeof(byte), flags);
256*4882a593Smuzhiyun timeout_check("read");
257*4882a593Smuzhiyun } while (nread < 0 && errno == EINTR);
258*4882a593Smuzhiyun timeout_end();
259*4882a593Smuzhiyun
260*4882a593Smuzhiyun if (expected_ret < 0) {
261*4882a593Smuzhiyun if (nread != -1) {
262*4882a593Smuzhiyun fprintf(stderr, "bogus recv(2) return value %zd\n",
263*4882a593Smuzhiyun nread);
264*4882a593Smuzhiyun exit(EXIT_FAILURE);
265*4882a593Smuzhiyun }
266*4882a593Smuzhiyun if (errno != -expected_ret) {
267*4882a593Smuzhiyun perror("read");
268*4882a593Smuzhiyun exit(EXIT_FAILURE);
269*4882a593Smuzhiyun }
270*4882a593Smuzhiyun return;
271*4882a593Smuzhiyun }
272*4882a593Smuzhiyun
273*4882a593Smuzhiyun if (nread < 0) {
274*4882a593Smuzhiyun perror("read");
275*4882a593Smuzhiyun exit(EXIT_FAILURE);
276*4882a593Smuzhiyun }
277*4882a593Smuzhiyun if (nread == 0) {
278*4882a593Smuzhiyun if (expected_ret == 0)
279*4882a593Smuzhiyun return;
280*4882a593Smuzhiyun
281*4882a593Smuzhiyun fprintf(stderr, "unexpected EOF while receiving byte\n");
282*4882a593Smuzhiyun exit(EXIT_FAILURE);
283*4882a593Smuzhiyun }
284*4882a593Smuzhiyun if (nread != sizeof(byte)) {
285*4882a593Smuzhiyun fprintf(stderr, "bogus recv(2) return value %zd\n", nread);
286*4882a593Smuzhiyun exit(EXIT_FAILURE);
287*4882a593Smuzhiyun }
288*4882a593Smuzhiyun if (byte != 'A') {
289*4882a593Smuzhiyun fprintf(stderr, "unexpected byte read %c\n", byte);
290*4882a593Smuzhiyun exit(EXIT_FAILURE);
291*4882a593Smuzhiyun }
292*4882a593Smuzhiyun }
293*4882a593Smuzhiyun
294*4882a593Smuzhiyun /* Run test cases. The program terminates if a failure occurs. */
run_tests(const struct test_case * test_cases,const struct test_opts * opts)295*4882a593Smuzhiyun void run_tests(const struct test_case *test_cases,
296*4882a593Smuzhiyun const struct test_opts *opts)
297*4882a593Smuzhiyun {
298*4882a593Smuzhiyun int i;
299*4882a593Smuzhiyun
300*4882a593Smuzhiyun for (i = 0; test_cases[i].name; i++) {
301*4882a593Smuzhiyun void (*run)(const struct test_opts *opts);
302*4882a593Smuzhiyun char *line;
303*4882a593Smuzhiyun
304*4882a593Smuzhiyun printf("%d - %s...", i, test_cases[i].name);
305*4882a593Smuzhiyun fflush(stdout);
306*4882a593Smuzhiyun
307*4882a593Smuzhiyun /* Full barrier before executing the next test. This
308*4882a593Smuzhiyun * ensures that client and server are executing the
309*4882a593Smuzhiyun * same test case. In particular, it means whoever is
310*4882a593Smuzhiyun * faster will not see the peer still executing the
311*4882a593Smuzhiyun * last test. This is important because port numbers
312*4882a593Smuzhiyun * can be used by multiple test cases.
313*4882a593Smuzhiyun */
314*4882a593Smuzhiyun if (test_cases[i].skip)
315*4882a593Smuzhiyun control_writeln("SKIP");
316*4882a593Smuzhiyun else
317*4882a593Smuzhiyun control_writeln("NEXT");
318*4882a593Smuzhiyun
319*4882a593Smuzhiyun line = control_readln();
320*4882a593Smuzhiyun if (control_cmpln(line, "SKIP", false) || test_cases[i].skip) {
321*4882a593Smuzhiyun
322*4882a593Smuzhiyun printf("skipped\n");
323*4882a593Smuzhiyun
324*4882a593Smuzhiyun free(line);
325*4882a593Smuzhiyun continue;
326*4882a593Smuzhiyun }
327*4882a593Smuzhiyun
328*4882a593Smuzhiyun control_cmpln(line, "NEXT", true);
329*4882a593Smuzhiyun free(line);
330*4882a593Smuzhiyun
331*4882a593Smuzhiyun if (opts->mode == TEST_MODE_CLIENT)
332*4882a593Smuzhiyun run = test_cases[i].run_client;
333*4882a593Smuzhiyun else
334*4882a593Smuzhiyun run = test_cases[i].run_server;
335*4882a593Smuzhiyun
336*4882a593Smuzhiyun if (run)
337*4882a593Smuzhiyun run(opts);
338*4882a593Smuzhiyun
339*4882a593Smuzhiyun printf("ok\n");
340*4882a593Smuzhiyun }
341*4882a593Smuzhiyun }
342*4882a593Smuzhiyun
list_tests(const struct test_case * test_cases)343*4882a593Smuzhiyun void list_tests(const struct test_case *test_cases)
344*4882a593Smuzhiyun {
345*4882a593Smuzhiyun int i;
346*4882a593Smuzhiyun
347*4882a593Smuzhiyun printf("ID\tTest name\n");
348*4882a593Smuzhiyun
349*4882a593Smuzhiyun for (i = 0; test_cases[i].name; i++)
350*4882a593Smuzhiyun printf("%d\t%s\n", i, test_cases[i].name);
351*4882a593Smuzhiyun
352*4882a593Smuzhiyun exit(EXIT_FAILURE);
353*4882a593Smuzhiyun }
354*4882a593Smuzhiyun
skip_test(struct test_case * test_cases,size_t test_cases_len,const char * test_id_str)355*4882a593Smuzhiyun void skip_test(struct test_case *test_cases, size_t test_cases_len,
356*4882a593Smuzhiyun const char *test_id_str)
357*4882a593Smuzhiyun {
358*4882a593Smuzhiyun unsigned long test_id;
359*4882a593Smuzhiyun char *endptr = NULL;
360*4882a593Smuzhiyun
361*4882a593Smuzhiyun errno = 0;
362*4882a593Smuzhiyun test_id = strtoul(test_id_str, &endptr, 10);
363*4882a593Smuzhiyun if (errno || *endptr != '\0') {
364*4882a593Smuzhiyun fprintf(stderr, "malformed test ID \"%s\"\n", test_id_str);
365*4882a593Smuzhiyun exit(EXIT_FAILURE);
366*4882a593Smuzhiyun }
367*4882a593Smuzhiyun
368*4882a593Smuzhiyun if (test_id >= test_cases_len) {
369*4882a593Smuzhiyun fprintf(stderr, "test ID (%lu) larger than the max allowed (%lu)\n",
370*4882a593Smuzhiyun test_id, test_cases_len - 1);
371*4882a593Smuzhiyun exit(EXIT_FAILURE);
372*4882a593Smuzhiyun }
373*4882a593Smuzhiyun
374*4882a593Smuzhiyun test_cases[test_id].skip = true;
375*4882a593Smuzhiyun }
376