1*4882a593Smuzhiyun // SPDX-License-Identifier: GPL-2.0-only
2*4882a593Smuzhiyun
3*4882a593Smuzhiyun #include <sys/socket.h>
4*4882a593Smuzhiyun #include <linux/bpf.h>
5*4882a593Smuzhiyun #include <bpf/bpf_helpers.h>
6*4882a593Smuzhiyun
/* Counters shared by both programs in this file.
 *
 * invocations: bumped atomically by sock() and sock_release() each time
 *              they handle a SOCK_DGRAM socket that has storage in sk_map.
 * in_use:      incremented when sock() admits a socket and decremented
 *              when sock_release() sees a marked socket go away.
 *
 * NOTE(review): presumably read from user space by the test harness -
 * confirm against the loader.
 */
int invocations = 0, in_use = 0;
8*4882a593Smuzhiyun
/* Socket-local storage holding one int per socket.
 *
 * sock() creates the entry and writes the 0xdeadbeef marker into it;
 * sock_release() looks the entry up again and checks that marker.
 */
struct {
	__uint(type, BPF_MAP_TYPE_SK_STORAGE);
	__uint(map_flags, BPF_F_NO_PREALLOC);
	__type(key, int);
	__type(value, int);
} sk_map SEC(".maps");
15*4882a593Smuzhiyun
16*4882a593Smuzhiyun SEC("cgroup/sock_create")
sock(struct bpf_sock * ctx)17*4882a593Smuzhiyun int sock(struct bpf_sock *ctx)
18*4882a593Smuzhiyun {
19*4882a593Smuzhiyun int *sk_storage;
20*4882a593Smuzhiyun __u32 key;
21*4882a593Smuzhiyun
22*4882a593Smuzhiyun if (ctx->type != SOCK_DGRAM)
23*4882a593Smuzhiyun return 1;
24*4882a593Smuzhiyun
25*4882a593Smuzhiyun sk_storage = bpf_sk_storage_get(&sk_map, ctx, 0,
26*4882a593Smuzhiyun BPF_SK_STORAGE_GET_F_CREATE);
27*4882a593Smuzhiyun if (!sk_storage)
28*4882a593Smuzhiyun return 0;
29*4882a593Smuzhiyun *sk_storage = 0xdeadbeef;
30*4882a593Smuzhiyun
31*4882a593Smuzhiyun __sync_fetch_and_add(&invocations, 1);
32*4882a593Smuzhiyun
33*4882a593Smuzhiyun if (in_use > 0) {
34*4882a593Smuzhiyun /* BPF_CGROUP_INET_SOCK_RELEASE is _not_ called
35*4882a593Smuzhiyun * when we return an error from the BPF
36*4882a593Smuzhiyun * program!
37*4882a593Smuzhiyun */
38*4882a593Smuzhiyun return 0;
39*4882a593Smuzhiyun }
40*4882a593Smuzhiyun
41*4882a593Smuzhiyun __sync_fetch_and_add(&in_use, 1);
42*4882a593Smuzhiyun return 1;
43*4882a593Smuzhiyun }
44*4882a593Smuzhiyun
45*4882a593Smuzhiyun SEC("cgroup/sock_release")
sock_release(struct bpf_sock * ctx)46*4882a593Smuzhiyun int sock_release(struct bpf_sock *ctx)
47*4882a593Smuzhiyun {
48*4882a593Smuzhiyun int *sk_storage;
49*4882a593Smuzhiyun __u32 key;
50*4882a593Smuzhiyun
51*4882a593Smuzhiyun if (ctx->type != SOCK_DGRAM)
52*4882a593Smuzhiyun return 1;
53*4882a593Smuzhiyun
54*4882a593Smuzhiyun sk_storage = bpf_sk_storage_get(&sk_map, ctx, 0, 0);
55*4882a593Smuzhiyun if (!sk_storage || *sk_storage != 0xdeadbeef)
56*4882a593Smuzhiyun return 0;
57*4882a593Smuzhiyun
58*4882a593Smuzhiyun __sync_fetch_and_add(&invocations, 1);
59*4882a593Smuzhiyun __sync_fetch_and_add(&in_use, -1);
60*4882a593Smuzhiyun return 1;
61*4882a593Smuzhiyun }
62