1*4882a593Smuzhiyun // SPDX-License-Identifier: GPL-2.0
2*4882a593Smuzhiyun #include <net/ip.h>
3*4882a593Smuzhiyun #include <net/udp.h>
4*4882a593Smuzhiyun #include <net/udplite.h>
5*4882a593Smuzhiyun #include <asm/checksum.h>
6*4882a593Smuzhiyun
7*4882a593Smuzhiyun #ifndef _HAVE_ARCH_IPV6_CSUM
csum_ipv6_magic(const struct in6_addr * saddr,const struct in6_addr * daddr,__u32 len,__u8 proto,__wsum csum)8*4882a593Smuzhiyun __sum16 csum_ipv6_magic(const struct in6_addr *saddr,
9*4882a593Smuzhiyun const struct in6_addr *daddr,
10*4882a593Smuzhiyun __u32 len, __u8 proto, __wsum csum)
11*4882a593Smuzhiyun {
12*4882a593Smuzhiyun
13*4882a593Smuzhiyun int carry;
14*4882a593Smuzhiyun __u32 ulen;
15*4882a593Smuzhiyun __u32 uproto;
16*4882a593Smuzhiyun __u32 sum = (__force u32)csum;
17*4882a593Smuzhiyun
18*4882a593Smuzhiyun sum += (__force u32)saddr->s6_addr32[0];
19*4882a593Smuzhiyun carry = (sum < (__force u32)saddr->s6_addr32[0]);
20*4882a593Smuzhiyun sum += carry;
21*4882a593Smuzhiyun
22*4882a593Smuzhiyun sum += (__force u32)saddr->s6_addr32[1];
23*4882a593Smuzhiyun carry = (sum < (__force u32)saddr->s6_addr32[1]);
24*4882a593Smuzhiyun sum += carry;
25*4882a593Smuzhiyun
26*4882a593Smuzhiyun sum += (__force u32)saddr->s6_addr32[2];
27*4882a593Smuzhiyun carry = (sum < (__force u32)saddr->s6_addr32[2]);
28*4882a593Smuzhiyun sum += carry;
29*4882a593Smuzhiyun
30*4882a593Smuzhiyun sum += (__force u32)saddr->s6_addr32[3];
31*4882a593Smuzhiyun carry = (sum < (__force u32)saddr->s6_addr32[3]);
32*4882a593Smuzhiyun sum += carry;
33*4882a593Smuzhiyun
34*4882a593Smuzhiyun sum += (__force u32)daddr->s6_addr32[0];
35*4882a593Smuzhiyun carry = (sum < (__force u32)daddr->s6_addr32[0]);
36*4882a593Smuzhiyun sum += carry;
37*4882a593Smuzhiyun
38*4882a593Smuzhiyun sum += (__force u32)daddr->s6_addr32[1];
39*4882a593Smuzhiyun carry = (sum < (__force u32)daddr->s6_addr32[1]);
40*4882a593Smuzhiyun sum += carry;
41*4882a593Smuzhiyun
42*4882a593Smuzhiyun sum += (__force u32)daddr->s6_addr32[2];
43*4882a593Smuzhiyun carry = (sum < (__force u32)daddr->s6_addr32[2]);
44*4882a593Smuzhiyun sum += carry;
45*4882a593Smuzhiyun
46*4882a593Smuzhiyun sum += (__force u32)daddr->s6_addr32[3];
47*4882a593Smuzhiyun carry = (sum < (__force u32)daddr->s6_addr32[3]);
48*4882a593Smuzhiyun sum += carry;
49*4882a593Smuzhiyun
50*4882a593Smuzhiyun ulen = (__force u32)htonl((__u32) len);
51*4882a593Smuzhiyun sum += ulen;
52*4882a593Smuzhiyun carry = (sum < ulen);
53*4882a593Smuzhiyun sum += carry;
54*4882a593Smuzhiyun
55*4882a593Smuzhiyun uproto = (__force u32)htonl(proto);
56*4882a593Smuzhiyun sum += uproto;
57*4882a593Smuzhiyun carry = (sum < uproto);
58*4882a593Smuzhiyun sum += carry;
59*4882a593Smuzhiyun
60*4882a593Smuzhiyun return csum_fold((__force __wsum)sum);
61*4882a593Smuzhiyun }
62*4882a593Smuzhiyun EXPORT_SYMBOL(csum_ipv6_magic);
63*4882a593Smuzhiyun #endif
64*4882a593Smuzhiyun
udp6_csum_init(struct sk_buff * skb,struct udphdr * uh,int proto)65*4882a593Smuzhiyun int udp6_csum_init(struct sk_buff *skb, struct udphdr *uh, int proto)
66*4882a593Smuzhiyun {
67*4882a593Smuzhiyun int err;
68*4882a593Smuzhiyun
69*4882a593Smuzhiyun UDP_SKB_CB(skb)->partial_cov = 0;
70*4882a593Smuzhiyun UDP_SKB_CB(skb)->cscov = skb->len;
71*4882a593Smuzhiyun
72*4882a593Smuzhiyun if (proto == IPPROTO_UDPLITE) {
73*4882a593Smuzhiyun err = udplite_checksum_init(skb, uh);
74*4882a593Smuzhiyun if (err)
75*4882a593Smuzhiyun return err;
76*4882a593Smuzhiyun
77*4882a593Smuzhiyun if (UDP_SKB_CB(skb)->partial_cov) {
78*4882a593Smuzhiyun skb->csum = ip6_compute_pseudo(skb, proto);
79*4882a593Smuzhiyun return 0;
80*4882a593Smuzhiyun }
81*4882a593Smuzhiyun }
82*4882a593Smuzhiyun
83*4882a593Smuzhiyun /* To support RFC 6936 (allow zero checksum in UDP/IPV6 for tunnels)
84*4882a593Smuzhiyun * we accept a checksum of zero here. When we find the socket
85*4882a593Smuzhiyun * for the UDP packet we'll check if that socket allows zero checksum
86*4882a593Smuzhiyun * for IPv6 (set by socket option).
87*4882a593Smuzhiyun *
88*4882a593Smuzhiyun * Note, we are only interested in != 0 or == 0, thus the
89*4882a593Smuzhiyun * force to int.
90*4882a593Smuzhiyun */
91*4882a593Smuzhiyun err = (__force int)skb_checksum_init_zero_check(skb, proto, uh->check,
92*4882a593Smuzhiyun ip6_compute_pseudo);
93*4882a593Smuzhiyun if (err)
94*4882a593Smuzhiyun return err;
95*4882a593Smuzhiyun
96*4882a593Smuzhiyun if (skb->ip_summed == CHECKSUM_COMPLETE && !skb->csum_valid) {
97*4882a593Smuzhiyun /* If SW calculated the value, we know it's bad */
98*4882a593Smuzhiyun if (skb->csum_complete_sw)
99*4882a593Smuzhiyun return 1;
100*4882a593Smuzhiyun
101*4882a593Smuzhiyun /* HW says the value is bad. Let's validate that.
102*4882a593Smuzhiyun * skb->csum is no longer the full packet checksum,
103*4882a593Smuzhiyun * so don't treat is as such.
104*4882a593Smuzhiyun */
105*4882a593Smuzhiyun skb_checksum_complete_unset(skb);
106*4882a593Smuzhiyun }
107*4882a593Smuzhiyun
108*4882a593Smuzhiyun return 0;
109*4882a593Smuzhiyun }
110*4882a593Smuzhiyun EXPORT_SYMBOL(udp6_csum_init);
111*4882a593Smuzhiyun
112*4882a593Smuzhiyun /* Function to set UDP checksum for an IPv6 UDP packet. This is intended
113*4882a593Smuzhiyun * for the simple case like when setting the checksum for a UDP tunnel.
114*4882a593Smuzhiyun */
udp6_set_csum(bool nocheck,struct sk_buff * skb,const struct in6_addr * saddr,const struct in6_addr * daddr,int len)115*4882a593Smuzhiyun void udp6_set_csum(bool nocheck, struct sk_buff *skb,
116*4882a593Smuzhiyun const struct in6_addr *saddr,
117*4882a593Smuzhiyun const struct in6_addr *daddr, int len)
118*4882a593Smuzhiyun {
119*4882a593Smuzhiyun struct udphdr *uh = udp_hdr(skb);
120*4882a593Smuzhiyun
121*4882a593Smuzhiyun if (nocheck)
122*4882a593Smuzhiyun uh->check = 0;
123*4882a593Smuzhiyun else if (skb_is_gso(skb))
124*4882a593Smuzhiyun uh->check = ~udp_v6_check(len, saddr, daddr, 0);
125*4882a593Smuzhiyun else if (skb->ip_summed == CHECKSUM_PARTIAL) {
126*4882a593Smuzhiyun uh->check = 0;
127*4882a593Smuzhiyun uh->check = udp_v6_check(len, saddr, daddr, lco_csum(skb));
128*4882a593Smuzhiyun if (uh->check == 0)
129*4882a593Smuzhiyun uh->check = CSUM_MANGLED_0;
130*4882a593Smuzhiyun } else {
131*4882a593Smuzhiyun skb->ip_summed = CHECKSUM_PARTIAL;
132*4882a593Smuzhiyun skb->csum_start = skb_transport_header(skb) - skb->head;
133*4882a593Smuzhiyun skb->csum_offset = offsetof(struct udphdr, check);
134*4882a593Smuzhiyun uh->check = ~udp_v6_check(len, saddr, daddr, 0);
135*4882a593Smuzhiyun }
136*4882a593Smuzhiyun }
137*4882a593Smuzhiyun EXPORT_SYMBOL(udp6_set_csum);
138