1*4882a593Smuzhiyun // SPDX-License-Identifier: GPL-2.0
2*4882a593Smuzhiyun #include "comm.h"
3*4882a593Smuzhiyun #include <errno.h>
4*4882a593Smuzhiyun #include <stdlib.h>
5*4882a593Smuzhiyun #include <stdio.h>
6*4882a593Smuzhiyun #include <string.h>
7*4882a593Smuzhiyun #include <linux/refcount.h>
8*4882a593Smuzhiyun #include <linux/rbtree.h>
9*4882a593Smuzhiyun #include <linux/zalloc.h>
10*4882a593Smuzhiyun #include "rwsem.h"
11*4882a593Smuzhiyun
12*4882a593Smuzhiyun struct comm_str {
13*4882a593Smuzhiyun char *str;
14*4882a593Smuzhiyun struct rb_node rb_node;
15*4882a593Smuzhiyun refcount_t refcnt;
16*4882a593Smuzhiyun };
17*4882a593Smuzhiyun
18*4882a593Smuzhiyun /* Should perhaps be moved to struct machine */
19*4882a593Smuzhiyun static struct rb_root comm_str_root;
20*4882a593Smuzhiyun static struct rw_semaphore comm_str_lock = {.lock = PTHREAD_RWLOCK_INITIALIZER,};
21*4882a593Smuzhiyun
comm_str__get(struct comm_str * cs)22*4882a593Smuzhiyun static struct comm_str *comm_str__get(struct comm_str *cs)
23*4882a593Smuzhiyun {
24*4882a593Smuzhiyun if (cs && refcount_inc_not_zero(&cs->refcnt))
25*4882a593Smuzhiyun return cs;
26*4882a593Smuzhiyun
27*4882a593Smuzhiyun return NULL;
28*4882a593Smuzhiyun }
29*4882a593Smuzhiyun
comm_str__put(struct comm_str * cs)30*4882a593Smuzhiyun static void comm_str__put(struct comm_str *cs)
31*4882a593Smuzhiyun {
32*4882a593Smuzhiyun if (cs && refcount_dec_and_test(&cs->refcnt)) {
33*4882a593Smuzhiyun down_write(&comm_str_lock);
34*4882a593Smuzhiyun rb_erase(&cs->rb_node, &comm_str_root);
35*4882a593Smuzhiyun up_write(&comm_str_lock);
36*4882a593Smuzhiyun zfree(&cs->str);
37*4882a593Smuzhiyun free(cs);
38*4882a593Smuzhiyun }
39*4882a593Smuzhiyun }
40*4882a593Smuzhiyun
comm_str__alloc(const char * str)41*4882a593Smuzhiyun static struct comm_str *comm_str__alloc(const char *str)
42*4882a593Smuzhiyun {
43*4882a593Smuzhiyun struct comm_str *cs;
44*4882a593Smuzhiyun
45*4882a593Smuzhiyun cs = zalloc(sizeof(*cs));
46*4882a593Smuzhiyun if (!cs)
47*4882a593Smuzhiyun return NULL;
48*4882a593Smuzhiyun
49*4882a593Smuzhiyun cs->str = strdup(str);
50*4882a593Smuzhiyun if (!cs->str) {
51*4882a593Smuzhiyun free(cs);
52*4882a593Smuzhiyun return NULL;
53*4882a593Smuzhiyun }
54*4882a593Smuzhiyun
55*4882a593Smuzhiyun refcount_set(&cs->refcnt, 1);
56*4882a593Smuzhiyun
57*4882a593Smuzhiyun return cs;
58*4882a593Smuzhiyun }
59*4882a593Smuzhiyun
60*4882a593Smuzhiyun static
__comm_str__findnew(const char * str,struct rb_root * root)61*4882a593Smuzhiyun struct comm_str *__comm_str__findnew(const char *str, struct rb_root *root)
62*4882a593Smuzhiyun {
63*4882a593Smuzhiyun struct rb_node **p = &root->rb_node;
64*4882a593Smuzhiyun struct rb_node *parent = NULL;
65*4882a593Smuzhiyun struct comm_str *iter, *new;
66*4882a593Smuzhiyun int cmp;
67*4882a593Smuzhiyun
68*4882a593Smuzhiyun while (*p != NULL) {
69*4882a593Smuzhiyun parent = *p;
70*4882a593Smuzhiyun iter = rb_entry(parent, struct comm_str, rb_node);
71*4882a593Smuzhiyun
72*4882a593Smuzhiyun /*
73*4882a593Smuzhiyun * If we race with comm_str__put, iter->refcnt is 0
74*4882a593Smuzhiyun * and it will be removed within comm_str__put call
75*4882a593Smuzhiyun * shortly, ignore it in this search.
76*4882a593Smuzhiyun */
77*4882a593Smuzhiyun cmp = strcmp(str, iter->str);
78*4882a593Smuzhiyun if (!cmp && comm_str__get(iter))
79*4882a593Smuzhiyun return iter;
80*4882a593Smuzhiyun
81*4882a593Smuzhiyun if (cmp < 0)
82*4882a593Smuzhiyun p = &(*p)->rb_left;
83*4882a593Smuzhiyun else
84*4882a593Smuzhiyun p = &(*p)->rb_right;
85*4882a593Smuzhiyun }
86*4882a593Smuzhiyun
87*4882a593Smuzhiyun new = comm_str__alloc(str);
88*4882a593Smuzhiyun if (!new)
89*4882a593Smuzhiyun return NULL;
90*4882a593Smuzhiyun
91*4882a593Smuzhiyun rb_link_node(&new->rb_node, parent, p);
92*4882a593Smuzhiyun rb_insert_color(&new->rb_node, root);
93*4882a593Smuzhiyun
94*4882a593Smuzhiyun return new;
95*4882a593Smuzhiyun }
96*4882a593Smuzhiyun
comm_str__findnew(const char * str,struct rb_root * root)97*4882a593Smuzhiyun static struct comm_str *comm_str__findnew(const char *str, struct rb_root *root)
98*4882a593Smuzhiyun {
99*4882a593Smuzhiyun struct comm_str *cs;
100*4882a593Smuzhiyun
101*4882a593Smuzhiyun down_write(&comm_str_lock);
102*4882a593Smuzhiyun cs = __comm_str__findnew(str, root);
103*4882a593Smuzhiyun up_write(&comm_str_lock);
104*4882a593Smuzhiyun
105*4882a593Smuzhiyun return cs;
106*4882a593Smuzhiyun }
107*4882a593Smuzhiyun
/*
 * comm__new - allocate a comm for @str starting at @timestamp.
 * @str:       command name; interned through the shared comm_str table.
 * @timestamp: stored in comm->start.
 * @exec:      stored in comm->exec.
 *
 * Returns the new comm, or NULL if allocating the comm or the interned
 * string fails.  Release it with comm__free().
 */
struct comm *comm__new(const char *str, u64 timestamp, bool exec)
{
	struct comm *comm;

	comm = zalloc(sizeof(*comm));
	if (comm == NULL)
		return NULL;

	comm->start = timestamp;
	comm->exec  = exec;

	comm->comm_str = comm_str__findnew(str, &comm_str_root);
	if (comm->comm_str != NULL)
		return comm;

	free(comm);
	return NULL;
}
126*4882a593Smuzhiyun
/*
 * comm__override - replace the string and start time of an existing comm.
 * @comm:      comm to update in place.
 * @str:       new command name.
 * @timestamp: new value for comm->start.
 * @exec:      if true, comm->exec is set.  A false value never clears a
 *             previously set flag.
 *
 * Returns 0 on success, or -ENOMEM if interning @str fails.  On failure
 * @comm is left unchanged.
 */
int comm__override(struct comm *comm, const char *str, u64 timestamp, bool exec)
{
	struct comm_str *previous = comm->comm_str;
	struct comm_str *replacement;

	/*
	 * Take the new reference before dropping the old one.  When @str
	 * equals the current string, the node can then never hit zero in
	 * between.
	 */
	replacement = comm_str__findnew(str, &comm_str_root);
	if (replacement == NULL)
		return -ENOMEM;

	comm_str__put(previous);

	comm->comm_str = replacement;
	comm->start = timestamp;
	if (exec)
		comm->exec = true;

	return 0;
}
143*4882a593Smuzhiyun
comm__free(struct comm * comm)144*4882a593Smuzhiyun void comm__free(struct comm *comm)
145*4882a593Smuzhiyun {
146*4882a593Smuzhiyun comm_str__put(comm->comm_str);
147*4882a593Smuzhiyun free(comm);
148*4882a593Smuzhiyun }
149*4882a593Smuzhiyun
comm__str(const struct comm * comm)150*4882a593Smuzhiyun const char *comm__str(const struct comm *comm)
151*4882a593Smuzhiyun {
152*4882a593Smuzhiyun return comm->comm_str->str;
153*4882a593Smuzhiyun }
154