// SPDX-License-Identifier: GPL-2.0-only
// Copyright (C) 2019-2020 Arm Ltd.

#include <linux/compiler.h>
#include <linux/kasan-checks.h>
#include <linux/kernel.h>

#include <net/checksum.h>

/* Looks dumb, but generates nice-ish code */
static u64 accumulate(u64 sum, u64 data)
{
	__uint128_t tmp = (__uint128_t)sum + data;
	return tmp + (tmp >> 64);
}
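
/*
 * The widening add above is a 64-bit one's-complement addition: any carry out
 * of bit 63 lands in the high half of tmp and is folded straight back in
 * (end-around carry). The fold itself cannot carry again, since the 128-bit
 * sum of two u64s is at most 0x1fffffffffffffffe. For instance,
 * accumulate(~0ULL, 2) gives tmp = 0x1'0000000000000001 and returns 2.
 */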

/*
 * We over-read the buffer and this makes KASAN unhappy. Instead, disable
 * instrumentation and call kasan explicitly.
 */
unsigned int __no_sanitize_address do_csum(const unsigned char *buff, int len)
{
	unsigned int offset, shift, sum;
	const u64 *ptr;
	u64 data, sum64 = 0;

	if (unlikely(len == 0))
		return 0;

	offset = (unsigned long)buff & 7;
	/*
	 * This is to all intents and purposes safe, since rounding down cannot
	 * result in a different page or cache line being accessed, and @buff
	 * should absolutely not be pointing to anything read-sensitive. We do,
	 * however, have to be careful not to piss off KASAN, which means using
	 * unchecked reads to accommodate the head and tail, for which we'll
	 * compensate with an explicit check up-front.
	 */
	kasan_check_read(buff, len);
	ptr = (u64 *)(buff - offset);
	len = len + offset - 8;
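	/*
	 * After rounding the pointer down, @len + @offset covers the buffer
	 * from the aligned base; the -8 accounts for the head dword loaded
	 * unconditionally below, so from here on @len counts the bytes
	 * remaining beyond that first load (and may go zero or negative when
	 * the whole buffer fits inside it).
	 */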

	/*
	 * Head: zero out any excess leading bytes. Shifting back by the same
	 * amount should be at least as fast as any other way of handling the
	 * odd/even alignment, and means we can ignore it until the very end.
	 */
	shift = offset * 8;
	data = *ptr++;
#ifdef __LITTLE_ENDIAN
	data = (data >> shift) << shift;
#else
	data = (data << shift) >> shift;
#endif
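	/*
	 * For example, with @offset == 3 on a little-endian kernel, the three
	 * bytes preceding @buff sit in bits 0-23 of the loaded dword; the
	 * shift pair above clears exactly those bits while keeping the five
	 * valid bytes in their original lanes.
	 */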

	/*
	 * Body: straightforward aligned loads from here on (the paired loads
	 * underlying the quadword type still only need dword alignment). The
	 * main loop strictly excludes the tail, so the second loop will always
	 * run at least once.
	 */
	while (unlikely(len > 64)) {
		__uint128_t tmp1, tmp2, tmp3, tmp4;

		tmp1 = *(__uint128_t *)ptr;
		tmp2 = *(__uint128_t *)(ptr + 2);
		tmp3 = *(__uint128_t *)(ptr + 4);
		tmp4 = *(__uint128_t *)(ptr + 6);

		len -= 64;
		ptr += 8;

		/* This is the "don't dump the carry flag into a GPR" idiom */
		tmp1 += (tmp1 >> 64) | (tmp1 << 64);
		tmp2 += (tmp2 >> 64) | (tmp2 << 64);
		tmp3 += (tmp3 >> 64) | (tmp3 << 64);
		tmp4 += (tmp4 >> 64) | (tmp4 << 64);
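		/*
		 * Each add-of-rotated-halves above leaves the 64-bit
		 * one's-complement sum of a 16-byte chunk in the *upper* half
		 * of its 128-bit register, so the carry stays in the flags
		 * rather than needing a separate move. The steps below merge
		 * the four partial sums pairwise (1+2, 3+4, then the two
		 * results) and finally fold in the running @sum64, using the
		 * same trick at every level.
		 */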
		tmp1 = ((tmp1 >> 64) << 64) | (tmp2 >> 64);
		tmp1 += (tmp1 >> 64) | (tmp1 << 64);
		tmp3 = ((tmp3 >> 64) << 64) | (tmp4 >> 64);
		tmp3 += (tmp3 >> 64) | (tmp3 << 64);
		tmp1 = ((tmp1 >> 64) << 64) | (tmp3 >> 64);
		tmp1 += (tmp1 >> 64) | (tmp1 << 64);
		tmp1 = ((tmp1 >> 64) << 64) | sum64;
		tmp1 += (tmp1 >> 64) | (tmp1 << 64);
		sum64 = tmp1 >> 64;
	}
	while (len > 8) {
		__uint128_t tmp;

		sum64 = accumulate(sum64, data);
		tmp = *(__uint128_t *)ptr;

		len -= 16;
		ptr += 2;

#ifdef __LITTLE_ENDIAN
		data = tmp >> 64;
		sum64 = accumulate(sum64, tmp);
#else
		data = tmp;
		sum64 = accumulate(sum64, tmp >> 64);
#endif
	}
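	/*
	 * On exit from both loops @data holds one loaded-but-not-yet-
	 * accumulated dword, and @len is the number of valid bytes remaining
	 * beyond it: at most 8, and zero or negative when the buffer already
	 * ends at or inside that pending dword.
	 */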
	if (len > 0) {
		sum64 = accumulate(sum64, data);
		data = *ptr;
		len -= 8;
	}
	/*
	 * Tail: zero any over-read bytes similarly to the head, again
	 * preserving odd/even alignment.
	 */
	shift = len * -8;
#ifdef __LITTLE_ENDIAN
	data = (data << shift) >> shift;
#else
	data = (data >> shift) << shift;
#endif
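	/*
	 * Here @len is in the range (-8, 0], so @shift is 0-56. For instance,
	 * if only three bytes of the final dword are valid, @len == -5 and
	 * the shift pair clears the five over-read bytes: the top five on
	 * little-endian, the bottom five on big-endian.
	 */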
	sum64 = accumulate(sum64, data);

	/* Finally, folding */
	sum64 += (sum64 >> 32) | (sum64 << 32);
	sum = sum64 >> 32;
	sum += (sum >> 16) | (sum << 16);
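	/*
	 * The same rotate-and-add trick folds 64 bits down to 32 and then to
	 * 16, leaving the 16-bit one's-complement result in the top half of
	 * @sum. If @buff started on an odd address, every byte was summed in
	 * the opposite lane, so the result must be byte-swapped to match what
	 * a byte-aligned sum would have produced.
	 */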
	if (offset & 1)
		return (u16)swab32(sum);

	return sum >> 16;
}

__sum16 csum_ipv6_magic(const struct in6_addr *saddr,
			const struct in6_addr *daddr,
			__u32 len, __u8 proto, __wsum csum)
{
	__uint128_t src, dst;
	u64 sum = (__force u64)csum;

	src = *(const __uint128_t *)saddr->s6_addr;
	dst = *(const __uint128_t *)daddr->s6_addr;

	sum += (__force u32)htonl(len);
#ifdef __LITTLE_ENDIAN
	sum += (u32)proto << 24;
#else
	sum += proto;
#endif
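	/*
	 * @len and @proto stand in for the upper-layer length and next-header
	 * fields of the IPv6 pseudo-header, added here as a native load of
	 * the network-byte-order pseudo-header words would present them: the
	 * next-header byte is the last byte of its 32-bit word, which a
	 * little-endian load sees in bits 24-31 and a big-endian load sees as
	 * the low byte.
	 */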
	src += (src >> 64) | (src << 64);
	dst += (dst >> 64) | (dst << 64);

	sum = accumulate(sum, src >> 64);
	sum = accumulate(sum, dst >> 64);

	sum += ((sum >> 32) | (sum << 32));
	return csum_fold((__force __wsum)(sum >> 32));
}
EXPORT_SYMBOL(csum_ipv6_magic);