/*
 * Copyright (c) 2006 Intel Corporation.  All rights reserved.
 *
 * This software is available to you under a choice of one of two
 * licenses.  You may choose to be licensed under the terms of the GNU
 * General Public License (GPL) Version 2, available from the file
 * COPYING in the main directory of this source tree, or the
 * OpenIB.org BSD license below:
 *
 *     Redistribution and use in source and binary forms, with or
 *     without modification, are permitted provided that the following
 *     conditions are met:
 *
 *      - Redistributions of source code must retain the above
 *        copyright notice, this list of conditions and the following
 *        disclaimer.
 *
 *      - Redistributions in binary form must reproduce the above
 *        copyright notice, this list of conditions and the following
 *        disclaimer in the documentation and/or other materials
 *        provided with the distribution.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
 * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
 * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
 * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
 * BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
 * ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
 * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */

#include <linux/completion.h>
#include <linux/dma-mapping.h>
#include <linux/err.h>
#include <linux/interrupt.h>
#include <linux/export.h>
#include <linux/slab.h>
#include <linux/bitops.h>
#include <linux/random.h>

#include <rdma/ib_cache.h>
#include "sa.h"

static int mcast_add_one(struct ib_device *device);
static void mcast_remove_one(struct ib_device *device, void *client_data);

static struct ib_client mcast_client = {
	.name   = "ib_multicast",
	.add    = mcast_add_one,
	.remove = mcast_remove_one
};

static struct ib_sa_client	sa_client;
static struct workqueue_struct	*mcast_wq;
static union ib_gid mgid0;

struct mcast_device;

struct mcast_port {
	struct mcast_device	*dev;
	spinlock_t		lock;
	struct rb_root		table;
	atomic_t		refcount;
	struct completion	comp;
	u8			port_num;
};

struct mcast_device {
	struct ib_device	*device;
	struct ib_event_handler	event_handler;
	int			start_port;
	int			end_port;
	struct mcast_port	port[];
};

enum mcast_state {
	MCAST_JOINING,
	MCAST_MEMBER,
	MCAST_ERROR,
};

enum mcast_group_state {
	MCAST_IDLE,
	MCAST_BUSY,
	MCAST_GROUP_ERROR,
	MCAST_PKEY_EVENT
};

enum {
	MCAST_INVALID_PKEY_INDEX = 0xFFFF
};

struct mcast_member;

struct mcast_group {
	struct ib_sa_mcmember_rec rec;
	struct rb_node		node;
	struct mcast_port	*port;
	spinlock_t		lock;
	struct work_struct	work;
	struct list_head	pending_list;
	struct list_head	active_list;
	struct mcast_member	*last_join;
	int			members[NUM_JOIN_MEMBERSHIP_TYPES];
	atomic_t		refcount;
	enum mcast_group_state	state;
	struct ib_sa_query	*query;
	u16			pkey_index;
	u8			leave_state;
	int			retries;
};

struct mcast_member {
	struct ib_sa_multicast	multicast;
	struct ib_sa_client	*client;
	struct mcast_group	*group;
	struct list_head	list;
	enum mcast_state	state;
	atomic_t		refcount;
	struct completion	comp;
};

static void join_handler(int status, struct ib_sa_mcmember_rec *rec,
			 void *context);
static void leave_handler(int status, struct ib_sa_mcmember_rec *rec,
			  void *context);

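/*
 * A sketch of the ownership model, as far as it can be read from the
 * structures above:
 *
 *	mcast_device (one per IB device)
 *	  -> mcast_port (one per port with IB multicast support)
 *	       -> rb-tree of mcast_group, keyed by MGID
 *	            -> pending_list / active_list of mcast_member
 *
 * Each mcast_port holds one reference per group in its table, and each
 * mcast_group holds one reference per attached member plus one per queued
 * work item.  Teardown pairs atomic_dec_and_test() with a struct
 * completion so the last reference wakes whoever is waiting to free.
 */
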
static struct mcast_group *mcast_find(struct mcast_port *port,
				      union ib_gid *mgid)
{
	struct rb_node *node = port->table.rb_node;
	struct mcast_group *group;
	int ret;

	while (node) {
		group = rb_entry(node, struct mcast_group, node);
		ret = memcmp(mgid->raw, group->rec.mgid.raw, sizeof *mgid);
		if (!ret)
			return group;

		if (ret < 0)
			node = node->rb_left;
		else
			node = node->rb_right;
	}
	return NULL;
}

static struct mcast_group *mcast_insert(struct mcast_port *port,
					struct mcast_group *group,
					int allow_duplicates)
{
	struct rb_node **link = &port->table.rb_node;
	struct rb_node *parent = NULL;
	struct mcast_group *cur_group;
	int ret;

	while (*link) {
		parent = *link;
		cur_group = rb_entry(parent, struct mcast_group, node);

		ret = memcmp(group->rec.mgid.raw, cur_group->rec.mgid.raw,
			     sizeof group->rec.mgid);
		if (ret < 0)
			link = &(*link)->rb_left;
		else if (ret > 0)
			link = &(*link)->rb_right;
		else if (allow_duplicates)
			link = &(*link)->rb_left;
		else
			return cur_group;
	}
	rb_link_node(&group->node, parent, link);
	rb_insert_color(&group->node, &port->table);
	return NULL;
}

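/*
 * The per-port table is an rb-tree ordered by raw memcmp() over the
 * 16-byte MGID.  mcast_insert() returns NULL on success and the existing
 * group on collision; allow_duplicates lets several groups share the
 * all-zero MGID (mgid0) until the SA assigns real MGIDs.  A minimal
 * sketch of the lookup-or-create pattern built on these helpers
 * (illustrative caller, not code from this file):
 *
 *	spin_lock_irqsave(&port->lock, flags);
 *	group = mcast_find(port, mgid);
 *	if (!group)
 *		group = mcast_insert(port, new_group, 0) ?: new_group;
 *	spin_unlock_irqrestore(&port->lock, flags);
 */
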
static void deref_port(struct mcast_port *port)
{
	if (atomic_dec_and_test(&port->refcount))
		complete(&port->comp);
}

static void release_group(struct mcast_group *group)
{
	struct mcast_port *port = group->port;
	unsigned long flags;

	spin_lock_irqsave(&port->lock, flags);
	if (atomic_dec_and_test(&group->refcount)) {
		rb_erase(&group->node, &port->table);
		spin_unlock_irqrestore(&port->lock, flags);
		kfree(group);
		deref_port(port);
	} else
		spin_unlock_irqrestore(&port->lock, flags);
}

static void deref_member(struct mcast_member *member)
{
	if (atomic_dec_and_test(&member->refcount))
		complete(&member->comp);
}

static void queue_join(struct mcast_member *member)
{
	struct mcast_group *group = member->group;
	unsigned long flags;

	spin_lock_irqsave(&group->lock, flags);
	list_add_tail(&member->list, &group->pending_list);
	if (group->state == MCAST_IDLE) {
		group->state = MCAST_BUSY;
		atomic_inc(&group->refcount);
		queue_work(mcast_wq, &group->work);
	}
	spin_unlock_irqrestore(&group->lock, flags);
}

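/*
 * Only one work item runs per group at a time: queue_join() schedules
 * mcast_work_handler() only on the MCAST_IDLE -> MCAST_BUSY transition,
 * and the extra group reference taken here is dropped by the handler when
 * it returns the group to MCAST_IDLE.  Later joiners simply append to
 * pending_list and are picked up by the already-running worker.
 */
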
/*
 * A multicast group has four types of members: full member, non member,
 * send-only non member, and send-only full member.
 * We need to keep track of the number of members of each
 * type based on their join state.  Adjust the number of members that
 * belong to the specified join states.
 */
static void adjust_membership(struct mcast_group *group, u8 join_state, int inc)
{
	int i;

	for (i = 0; i < NUM_JOIN_MEMBERSHIP_TYPES; i++, join_state >>= 1)
		if (join_state & 0x1)
			group->members[i] += inc;
}

/*
 * If a multicast group has zero members left for a particular join state, but
 * the group is still a member with the SA, we need to leave that join state.
 * Determine which join states we still belong to, but that do not have any
 * active members.
 */
static u8 get_leave_state(struct mcast_group *group)
{
	u8 leave_state = 0;
	int i;

	for (i = 0; i < NUM_JOIN_MEMBERSHIP_TYPES; i++)
		if (!group->members[i])
			leave_state |= (0x1 << i);

	return leave_state & group->rec.join_state;
}

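/*
 * join_state is a bitmask, one bit per membership type, with bit i
 * counted in members[i].  A worked example (values illustrative):
 *
 *	members[] = { 2, 0, 1, 0 }	two full members, one send-only
 *	rec.join_state = 0x5		bits 0 and 2 joined at the SA
 *	get_leave_state()		-> 0xA & 0x5 = 0x0, nothing to leave
 *
 * If the last send-only member leaves, members[] becomes { 2, 0, 0, 0 },
 * the empty-slot mask becomes 0xE, and 0xE & 0x5 = 0x4: the group must
 * send a leave for that one join state while remaining a full member.
 */
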
static int check_selector(ib_sa_comp_mask comp_mask,
			  ib_sa_comp_mask selector_mask,
			  ib_sa_comp_mask value_mask,
			  u8 selector, u8 src_value, u8 dst_value)
{
	int err;

	if (!(comp_mask & selector_mask) || !(comp_mask & value_mask))
		return 0;

	switch (selector) {
	case IB_SA_GT:
		err = (src_value <= dst_value);
		break;
	case IB_SA_LT:
		err = (src_value >= dst_value);
		break;
	case IB_SA_EQ:
		err = (src_value != dst_value);
		break;
	default:
		err = 0;
		break;
	}

	return err;
}

static int cmp_rec(struct ib_sa_mcmember_rec *src,
		   struct ib_sa_mcmember_rec *dst, ib_sa_comp_mask comp_mask)
{
	/* MGID must already match */

	if (comp_mask & IB_SA_MCMEMBER_REC_PORT_GID &&
	    memcmp(&src->port_gid, &dst->port_gid, sizeof src->port_gid))
		return -EINVAL;
	if (comp_mask & IB_SA_MCMEMBER_REC_QKEY && src->qkey != dst->qkey)
		return -EINVAL;
	if (comp_mask & IB_SA_MCMEMBER_REC_MLID && src->mlid != dst->mlid)
		return -EINVAL;
	if (check_selector(comp_mask, IB_SA_MCMEMBER_REC_MTU_SELECTOR,
			   IB_SA_MCMEMBER_REC_MTU, dst->mtu_selector,
			   src->mtu, dst->mtu))
		return -EINVAL;
	if (comp_mask & IB_SA_MCMEMBER_REC_TRAFFIC_CLASS &&
	    src->traffic_class != dst->traffic_class)
		return -EINVAL;
	if (comp_mask & IB_SA_MCMEMBER_REC_PKEY && src->pkey != dst->pkey)
		return -EINVAL;
	if (check_selector(comp_mask, IB_SA_MCMEMBER_REC_RATE_SELECTOR,
			   IB_SA_MCMEMBER_REC_RATE, dst->rate_selector,
			   src->rate, dst->rate))
		return -EINVAL;
	if (check_selector(comp_mask,
			   IB_SA_MCMEMBER_REC_PACKET_LIFE_TIME_SELECTOR,
			   IB_SA_MCMEMBER_REC_PACKET_LIFE_TIME,
			   dst->packet_life_time_selector,
			   src->packet_life_time, dst->packet_life_time))
		return -EINVAL;
	if (comp_mask & IB_SA_MCMEMBER_REC_SL && src->sl != dst->sl)
		return -EINVAL;
	if (comp_mask & IB_SA_MCMEMBER_REC_FLOW_LABEL &&
	    src->flow_label != dst->flow_label)
		return -EINVAL;
	if (comp_mask & IB_SA_MCMEMBER_REC_HOP_LIMIT &&
	    src->hop_limit != dst->hop_limit)
		return -EINVAL;
	if (comp_mask & IB_SA_MCMEMBER_REC_SCOPE && src->scope != dst->scope)
		return -EINVAL;

	/* join_state checked separately, proxy_join ignored */

	return 0;
}

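/*
 * check_selector() returns nonzero when the existing group value (src)
 * cannot satisfy a new joiner's request (dst).  For example, with the
 * comp_mask naming both the MTU selector and the MTU value, a request of
 *
 *	dst->mtu_selector = IB_SA_GT, dst->mtu = IB_MTU_1024
 *
 * is compatible only if the group's src->mtu is strictly greater than
 * IB_MTU_1024; IB_SA_LT requires strictly less, and IB_SA_EQ an exact
 * match.  If either mask bit is absent the check is skipped entirely.
 */
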
static int send_join(struct mcast_group *group, struct mcast_member *member)
{
	struct mcast_port *port = group->port;
	int ret;

	group->last_join = member;
	ret = ib_sa_mcmember_rec_query(&sa_client, port->dev->device,
				       port->port_num, IB_MGMT_METHOD_SET,
				       &member->multicast.rec,
				       member->multicast.comp_mask,
				       3000, GFP_KERNEL, join_handler, group,
				       &group->query);
	return (ret > 0) ? 0 : ret;
}

static int send_leave(struct mcast_group *group, u8 leave_state)
{
	struct mcast_port *port = group->port;
	struct ib_sa_mcmember_rec rec;
	int ret;

	rec = group->rec;
	rec.join_state = leave_state;
	group->leave_state = leave_state;

	ret = ib_sa_mcmember_rec_query(&sa_client, port->dev->device,
				       port->port_num, IB_SA_METHOD_DELETE, &rec,
				       IB_SA_MCMEMBER_REC_MGID     |
				       IB_SA_MCMEMBER_REC_PORT_GID |
				       IB_SA_MCMEMBER_REC_JOIN_STATE,
				       3000, GFP_KERNEL, leave_handler,
				       group, &group->query);
	return (ret > 0) ? 0 : ret;
}

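/*
 * ib_sa_mcmember_rec_query() returns a positive query id when the MAD was
 * queued and a negative errno otherwise, so both senders above normalize
 * "queued" to 0.  The 3000 ms value is the per-query SA timeout; the
 * completion path continues asynchronously in join_handler() and
 * leave_handler().
 */
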
static void join_group(struct mcast_group *group, struct mcast_member *member,
		       u8 join_state)
{
	member->state = MCAST_MEMBER;
	adjust_membership(group, join_state, 1);
	group->rec.join_state |= join_state;
	member->multicast.rec = group->rec;
	member->multicast.rec.join_state = join_state;
	list_move(&member->list, &group->active_list);
}

static int fail_join(struct mcast_group *group, struct mcast_member *member,
		     int status)
{
	spin_lock_irq(&group->lock);
	list_del_init(&member->list);
	spin_unlock_irq(&group->lock);
	return member->multicast.callback(status, &member->multicast);
}

static void process_group_error(struct mcast_group *group)
{
	struct mcast_member *member;
	int ret = 0;
	u16 pkey_index;

	if (group->state == MCAST_PKEY_EVENT)
		ret = ib_find_pkey(group->port->dev->device,
				   group->port->port_num,
				   be16_to_cpu(group->rec.pkey), &pkey_index);

	spin_lock_irq(&group->lock);
	if (group->state == MCAST_PKEY_EVENT && !ret &&
	    group->pkey_index == pkey_index)
		goto out;

	while (!list_empty(&group->active_list)) {
		member = list_entry(group->active_list.next,
				    struct mcast_member, list);
		atomic_inc(&member->refcount);
		list_del_init(&member->list);
		adjust_membership(group, member->multicast.rec.join_state, -1);
		member->state = MCAST_ERROR;
		spin_unlock_irq(&group->lock);

		ret = member->multicast.callback(-ENETRESET,
						 &member->multicast);
		deref_member(member);
		if (ret)
			ib_sa_free_multicast(&member->multicast);
		spin_lock_irq(&group->lock);
	}

	group->rec.join_state = 0;
out:
	group->state = MCAST_BUSY;
	spin_unlock_irq(&group->lock);
}

static void mcast_work_handler(struct work_struct *work)
{
	struct mcast_group *group;
	struct mcast_member *member;
	struct ib_sa_multicast *multicast;
	int status, ret;
	u8 join_state;

	group = container_of(work, typeof(*group), work);
retest:
	spin_lock_irq(&group->lock);
	while (!list_empty(&group->pending_list) ||
	       (group->state != MCAST_BUSY)) {

		if (group->state != MCAST_BUSY) {
			spin_unlock_irq(&group->lock);
			process_group_error(group);
			goto retest;
		}

		member = list_entry(group->pending_list.next,
				    struct mcast_member, list);
		multicast = &member->multicast;
		join_state = multicast->rec.join_state;
		atomic_inc(&member->refcount);

		if (join_state == (group->rec.join_state & join_state)) {
			status = cmp_rec(&group->rec, &multicast->rec,
					 multicast->comp_mask);
			if (!status)
				join_group(group, member, join_state);
			else
				list_del_init(&member->list);
			spin_unlock_irq(&group->lock);
			ret = multicast->callback(status, multicast);
		} else {
			spin_unlock_irq(&group->lock);
			status = send_join(group, member);
			if (!status) {
				deref_member(member);
				return;
			}
			ret = fail_join(group, member, status);
		}

		deref_member(member);
		if (ret)
			ib_sa_free_multicast(&member->multicast);
		spin_lock_irq(&group->lock);
	}

	join_state = get_leave_state(group);
	if (join_state) {
		group->rec.join_state &= ~join_state;
		spin_unlock_irq(&group->lock);
		if (send_leave(group, join_state))
			goto retest;
	} else {
		group->state = MCAST_IDLE;
		spin_unlock_irq(&group->lock);
		release_group(group);
	}
}

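/*
 * The worker drains pending_list one member at a time.  If the group's
 * current SA join state already covers the requested bits, the member is
 * completed locally against the cached record; otherwise a join is sent
 * to the SA and the worker returns, to be re-entered from join_handler().
 * Leaving MCAST_BUSY (group error or pkey event) reroutes the loop into
 * process_group_error() before normal processing resumes.
 */
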
/*
 * Fail a join request if it is still active - at the head of the pending queue.
 */
static void process_join_error(struct mcast_group *group, int status)
{
	struct mcast_member *member;
	int ret;

	spin_lock_irq(&group->lock);
	member = list_entry(group->pending_list.next,
			    struct mcast_member, list);
	if (group->last_join == member) {
		atomic_inc(&member->refcount);
		list_del_init(&member->list);
		spin_unlock_irq(&group->lock);
		ret = member->multicast.callback(status, &member->multicast);
		deref_member(member);
		if (ret)
			ib_sa_free_multicast(&member->multicast);
	} else
		spin_unlock_irq(&group->lock);
}

static void join_handler(int status, struct ib_sa_mcmember_rec *rec,
			 void *context)
{
	struct mcast_group *group = context;
	u16 pkey_index = MCAST_INVALID_PKEY_INDEX;

	if (status)
		process_join_error(group, status);
	else {
		int mgids_changed, is_mgid0;

		if (ib_find_pkey(group->port->dev->device,
				 group->port->port_num, be16_to_cpu(rec->pkey),
				 &pkey_index))
			pkey_index = MCAST_INVALID_PKEY_INDEX;

		spin_lock_irq(&group->port->lock);
		if (group->state == MCAST_BUSY &&
		    group->pkey_index == MCAST_INVALID_PKEY_INDEX)
			group->pkey_index = pkey_index;
		mgids_changed = memcmp(&rec->mgid, &group->rec.mgid,
				       sizeof(group->rec.mgid));
		group->rec = *rec;
		if (mgids_changed) {
			rb_erase(&group->node, &group->port->table);
			is_mgid0 = !memcmp(&mgid0, &group->rec.mgid,
					   sizeof(mgid0));
			mcast_insert(group->port, group, is_mgid0);
		}
		spin_unlock_irq(&group->port->lock);
	}
	mcast_work_handler(&group->work);
}

static void leave_handler(int status, struct ib_sa_mcmember_rec *rec,
			  void *context)
{
	struct mcast_group *group = context;

	if (status && group->retries > 0 &&
	    !send_leave(group, group->leave_state))
		group->retries--;
	else
		mcast_work_handler(&group->work);
}

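/*
 * On a successful join the SA may have assigned a real MGID to a group
 * created with the zero MGID, so join_handler() re-keys the rb-tree entry
 * before resuming the worker.  leave_handler() retries a failed leave up
 * to group->retries times (initialized to 3 in acquire_group()) before
 * giving up and letting the worker finish.
 */
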
static struct mcast_group *acquire_group(struct mcast_port *port,
					 union ib_gid *mgid, gfp_t gfp_mask)
{
	struct mcast_group *group, *cur_group;
	unsigned long flags;
	int is_mgid0;

	is_mgid0 = !memcmp(&mgid0, mgid, sizeof mgid0);
	if (!is_mgid0) {
		spin_lock_irqsave(&port->lock, flags);
		group = mcast_find(port, mgid);
		if (group)
			goto found;
		spin_unlock_irqrestore(&port->lock, flags);
	}

	group = kzalloc(sizeof *group, gfp_mask);
	if (!group)
		return NULL;

	group->retries = 3;
	group->port = port;
	group->rec.mgid = *mgid;
	group->pkey_index = MCAST_INVALID_PKEY_INDEX;
	INIT_LIST_HEAD(&group->pending_list);
	INIT_LIST_HEAD(&group->active_list);
	INIT_WORK(&group->work, mcast_work_handler);
	spin_lock_init(&group->lock);

	spin_lock_irqsave(&port->lock, flags);
	cur_group = mcast_insert(port, group, is_mgid0);
	if (cur_group) {
		kfree(group);
		group = cur_group;
	} else
		atomic_inc(&port->refcount);
found:
	atomic_inc(&group->refcount);
	spin_unlock_irqrestore(&port->lock, flags);
	return group;
}

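/*
 * acquire_group() is the classic optimistic lookup-or-create: search
 * under the lock, drop the lock to allocate, then re-insert and fall back
 * to the group that raced us in if mcast_insert() reports a duplicate.
 * Requests for the zero MGID skip the lookup so each such join gets its
 * own group until the SA hands back a concrete MGID.
 */
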
/*
 * We serialize all join requests to a single group to make our lives much
 * easier.  Otherwise, two users could try to join the same group
 * simultaneously, with different configurations, one could leave while the
 * join is in progress, etc., which makes locking around error recovery
 * difficult.
 */
struct ib_sa_multicast *
ib_sa_join_multicast(struct ib_sa_client *client,
		     struct ib_device *device, u8 port_num,
		     struct ib_sa_mcmember_rec *rec,
		     ib_sa_comp_mask comp_mask, gfp_t gfp_mask,
		     int (*callback)(int status,
				     struct ib_sa_multicast *multicast),
		     void *context)
{
	struct mcast_device *dev;
	struct mcast_member *member;
	struct ib_sa_multicast *multicast;
	int ret;

	dev = ib_get_client_data(device, &mcast_client);
	if (!dev)
		return ERR_PTR(-ENODEV);

	member = kmalloc(sizeof *member, gfp_mask);
	if (!member)
		return ERR_PTR(-ENOMEM);

	ib_sa_client_get(client);
	member->client = client;
	member->multicast.rec = *rec;
	member->multicast.comp_mask = comp_mask;
	member->multicast.callback = callback;
	member->multicast.context = context;
	init_completion(&member->comp);
	atomic_set(&member->refcount, 1);
	member->state = MCAST_JOINING;

	member->group = acquire_group(&dev->port[port_num - dev->start_port],
				      &rec->mgid, gfp_mask);
	if (!member->group) {
		ret = -ENOMEM;
		goto err;
	}

	/*
	 * The user will get the multicast structure in their callback.  They
	 * could then free the multicast structure before we can return from
	 * this routine.  So we save the pointer to return before queuing
	 * any callback.
	 */
	multicast = &member->multicast;
	queue_join(member);
	return multicast;

err:
	ib_sa_client_put(client);
	kfree(member);
	return ERR_PTR(ret);
}
EXPORT_SYMBOL(ib_sa_join_multicast);

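/*
 * A minimal caller sketch (illustrative only; my_sa_client, my_ctx, and
 * my_join_done are placeholder names, and my_sa_client is assumed to be
 * registered via ib_sa_register_client()).  Real users such as ipoib fill
 * in many more record fields:
 *
 *	static int my_join_done(int status, struct ib_sa_multicast *mc)
 *	{
 *		if (status)
 *			pr_warn("mcast join failed: %d\n", status);
 *		// Returning nonzero tells the core to free this
 *		// membership via ib_sa_free_multicast() on our behalf.
 *		return status;
 *	}
 *
 *	struct ib_sa_mcmember_rec rec = {};
 *	struct ib_sa_multicast *mc;
 *
 *	rec.mgid = mgid;		// group to join
 *	rec.port_gid = port_gid;	// our port GID
 *	rec.join_state = 1;		// full member
 *	mc = ib_sa_join_multicast(&my_sa_client, device, port_num, &rec,
 *				  IB_SA_MCMEMBER_REC_MGID |
 *				  IB_SA_MCMEMBER_REC_PORT_GID |
 *				  IB_SA_MCMEMBER_REC_JOIN_STATE,
 *				  GFP_KERNEL, my_join_done, my_ctx);
 *	if (IS_ERR(mc))
 *		return PTR_ERR(mc);
 */
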
void ib_sa_free_multicast(struct ib_sa_multicast *multicast)
{
	struct mcast_member *member;
	struct mcast_group *group;

	member = container_of(multicast, struct mcast_member, multicast);
	group = member->group;

	spin_lock_irq(&group->lock);
	if (member->state == MCAST_MEMBER)
		adjust_membership(group, multicast->rec.join_state, -1);

	list_del_init(&member->list);

	if (group->state == MCAST_IDLE) {
		group->state = MCAST_BUSY;
		spin_unlock_irq(&group->lock);
		/* Continue to hold reference on group until callback */
		queue_work(mcast_wq, &group->work);
	} else {
		spin_unlock_irq(&group->lock);
		release_group(group);
	}

	deref_member(member);
	wait_for_completion(&member->comp);
	ib_sa_client_put(member->client);
	kfree(member);
}
EXPORT_SYMBOL(ib_sa_free_multicast);

int ib_sa_get_mcmember_rec(struct ib_device *device, u8 port_num,
			   union ib_gid *mgid, struct ib_sa_mcmember_rec *rec)
{
	struct mcast_device *dev;
	struct mcast_port *port;
	struct mcast_group *group;
	unsigned long flags;
	int ret = 0;

	dev = ib_get_client_data(device, &mcast_client);
	if (!dev)
		return -ENODEV;

	port = &dev->port[port_num - dev->start_port];
	spin_lock_irqsave(&port->lock, flags);
	group = mcast_find(port, mgid);
	if (group)
		*rec = group->rec;
	else
		ret = -EADDRNOTAVAIL;
	spin_unlock_irqrestore(&port->lock, flags);

	return ret;
}
EXPORT_SYMBOL(ib_sa_get_mcmember_rec);

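/*
 * ib_sa_get_mcmember_rec() only consults the local cache: it returns the
 * record of a group this module has already joined on the given port and
 * -EADDRNOTAVAIL otherwise; no SA query is issued.  Callers such as the
 * RDMA CM use it to recover the full member record from just an MGID.
 */
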
/**
 * ib_init_ah_from_mcmember - Initialize AH attribute from multicast
 * member record and gid of the device.
 * @device:	RDMA device
 * @port_num:	Port of the rdma device to consider
 * @rec:	Multicast member record from which to initialize
 * @ndev:	Optional netdevice, applicable only for RoCE
 * @gid_type:	GID type to consider
 * @ah_attr:	AH attribute to fill in on successful completion
 *
 * ib_init_ah_from_mcmember() initializes AH attribute based on multicast
 * member record and other device properties.  On success the caller is
 * responsible for calling rdma_destroy_ah_attr on the ah_attr.  Returns 0
 * on success or an appropriate error code.
 */
int ib_init_ah_from_mcmember(struct ib_device *device, u8 port_num,
			     struct ib_sa_mcmember_rec *rec,
			     struct net_device *ndev,
			     enum ib_gid_type gid_type,
			     struct rdma_ah_attr *ah_attr)
{
	const struct ib_gid_attr *sgid_attr;

	/* GID table is not based on the netdevice for IB link layer,
	 * so ignore ndev during search.
	 */
	if (rdma_protocol_ib(device, port_num))
		ndev = NULL;
	else if (!rdma_protocol_roce(device, port_num))
		return -EINVAL;

	sgid_attr = rdma_find_gid_by_port(device, &rec->port_gid,
					  gid_type, port_num, ndev);
	if (IS_ERR(sgid_attr))
		return PTR_ERR(sgid_attr);

	memset(ah_attr, 0, sizeof(*ah_attr));
	ah_attr->type = rdma_ah_find_type(device, port_num);

	rdma_ah_set_dlid(ah_attr, be16_to_cpu(rec->mlid));
	rdma_ah_set_sl(ah_attr, rec->sl);
	rdma_ah_set_port_num(ah_attr, port_num);
	rdma_ah_set_static_rate(ah_attr, rec->rate);
	rdma_move_grh_sgid_attr(ah_attr, &rec->mgid,
				be32_to_cpu(rec->flow_label),
				rec->hop_limit, rec->traffic_class,
				sgid_attr);
	return 0;
}
EXPORT_SYMBOL(ib_init_ah_from_mcmember);

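/*
 * Typical use (a sketch, assuming a successful join): the join callback
 * receives the resolved record and turns it into an address handle:
 *
 *	struct rdma_ah_attr ah_attr;
 *	struct ib_ah *ah;
 *
 *	ret = ib_init_ah_from_mcmember(device, port_num, &multicast->rec,
 *				       ndev, gid_type, &ah_attr);
 *	if (!ret) {
 *		ah = rdma_create_ah(pd, &ah_attr, 0);
 *		rdma_destroy_ah_attr(&ah_attr);
 *	}
 */
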
static void mcast_groups_event(struct mcast_port *port,
			       enum mcast_group_state state)
{
	struct mcast_group *group;
	struct rb_node *node;
	unsigned long flags;

	spin_lock_irqsave(&port->lock, flags);
	for (node = rb_first(&port->table); node; node = rb_next(node)) {
		group = rb_entry(node, struct mcast_group, node);
		spin_lock(&group->lock);
		if (group->state == MCAST_IDLE) {
			atomic_inc(&group->refcount);
			queue_work(mcast_wq, &group->work);
		}
		if (group->state != MCAST_GROUP_ERROR)
			group->state = state;
		spin_unlock(&group->lock);
	}
	spin_unlock_irqrestore(&port->lock, flags);
}

static void mcast_event_handler(struct ib_event_handler *handler,
				struct ib_event *event)
{
	struct mcast_device *dev;
	int index;

	dev = container_of(handler, struct mcast_device, event_handler);
	if (!rdma_cap_ib_mcast(dev->device, event->element.port_num))
		return;

	index = event->element.port_num - dev->start_port;

	switch (event->event) {
	case IB_EVENT_PORT_ERR:
	case IB_EVENT_LID_CHANGE:
	case IB_EVENT_CLIENT_REREGISTER:
		mcast_groups_event(&dev->port[index], MCAST_GROUP_ERROR);
		break;
	case IB_EVENT_PKEY_CHANGE:
		mcast_groups_event(&dev->port[index], MCAST_PKEY_EVENT);
		break;
	default:
		break;
	}
}

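/*
 * Port-wide events fan out to every group on the port.  MCAST_GROUP_ERROR
 * forces all members through process_group_error() with -ENETRESET so
 * they rejoin from scratch; MCAST_PKEY_EVENT is gentler and only resets
 * groups whose cached pkey_index no longer matches the new pkey table.
 */
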
static int mcast_add_one(struct ib_device *device)
{
	struct mcast_device *dev;
	struct mcast_port *port;
	int i;
	int count = 0;

	dev = kmalloc(struct_size(dev, port, device->phys_port_cnt),
		      GFP_KERNEL);
	if (!dev)
		return -ENOMEM;

	dev->start_port = rdma_start_port(device);
	dev->end_port = rdma_end_port(device);

	for (i = 0; i <= dev->end_port - dev->start_port; i++) {
		if (!rdma_cap_ib_mcast(device, dev->start_port + i))
			continue;
		port = &dev->port[i];
		port->dev = dev;
		port->port_num = dev->start_port + i;
		spin_lock_init(&port->lock);
		port->table = RB_ROOT;
		init_completion(&port->comp);
		atomic_set(&port->refcount, 1);
		++count;
	}

	if (!count) {
		kfree(dev);
		return -EOPNOTSUPP;
	}

	dev->device = device;
	ib_set_client_data(device, &mcast_client, dev);

	INIT_IB_EVENT_HANDLER(&dev->event_handler, device, mcast_event_handler);
	ib_register_event_handler(&dev->event_handler);
	return 0;
}

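/*
 * One mcast_device covers all ports of a device: struct_size() sizes the
 * flexible port[] array for phys_port_cnt entries in a single allocation.
 * Only ports with IB multicast capability are initialized, and the event
 * and removal paths re-check rdma_cap_ib_mcast() before touching a port.
 */
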
static void mcast_remove_one(struct ib_device *device, void *client_data)
{
	struct mcast_device *dev = client_data;
	struct mcast_port *port;
	int i;

	ib_unregister_event_handler(&dev->event_handler);
	flush_workqueue(mcast_wq);

	for (i = 0; i <= dev->end_port - dev->start_port; i++) {
		if (rdma_cap_ib_mcast(device, dev->start_port + i)) {
			port = &dev->port[i];
			deref_port(port);
			wait_for_completion(&port->comp);
		}
	}

	kfree(dev);
}

int mcast_init(void)
{
	int ret;

	mcast_wq = alloc_ordered_workqueue("ib_mcast", WQ_MEM_RECLAIM);
	if (!mcast_wq)
		return -ENOMEM;

	ib_sa_register_client(&sa_client);

	ret = ib_register_client(&mcast_client);
	if (ret)
		goto err;
	return 0;

err:
	ib_sa_unregister_client(&sa_client);
	destroy_workqueue(mcast_wq);
	return ret;
}

void mcast_cleanup(void)
{
	ib_unregister_client(&mcast_client);
	ib_sa_unregister_client(&sa_client);
	destroy_workqueue(mcast_wq);
}