xref: /optee_os/core/arch/riscv/kernel/sbi_mpxy.c (revision 30feb38ac712e49dfe9a7bc1d2db3b3458534e69)
1 // SPDX-License-Identifier: BSD-2-Clause
2 /*
3  * Copyright 2025 NXP
4  */
5 
6 #include <kernel/misc.h>
7 #include <mm/core_memprot.h>
8 #include <sbi.h>
9 #include <sbi_mpxy.h>
10 #include <string.h>
11 
12 /*
13  * struct mpxy_core_local - MPXY per-hart local context
14  * @shmem:       Virtual base address of MPXY shared memory
15  * @shmem_pa:    Physical base address of MPXY shared memory
16  * @shmem_active:Indicates whether shared memory is active for this hart
17  *
18  * Holds MPXY-related per-hart data required for message exchange via
19  * the SBI MPXY extension.
20  */
21 struct mpxy_core_local {
22 	void *shmem;
23 	paddr_t shmem_pa;
24 	bool shmem_active;
25 };
26 
27 static struct mpxy_core_local mpxy_core_local_array[CFG_TEE_CORE_NB_CORE];
28 
mpxy_get_core_local(void)29 static struct mpxy_core_local *mpxy_get_core_local(void)
30 {
31 	struct mpxy_core_local *mpxy = NULL;
32 	size_t pos = 0;
33 	uint32_t hart_id = 0;
34 
35 	assert((thread_get_exceptions() & THREAD_EXCP_ALL) == THREAD_EXCP_ALL);
36 
37 	pos = get_core_pos();
38 	hart_id = thread_get_hartid_by_hartindex(pos);
39 
40 	mpxy = &mpxy_core_local_array[hart_id];
41 
42 	return mpxy;
43 }
44 
45 /**
46  * sbi_mpxy_get_shmem_size - Retrieve the MPXY shared memory size
47  * @shmem_size: Pointer to store the shared memory size in bytes
48  *
49  * Makes an SBI call to query the shared memory size used for
50  * sending and receiving messages via the MPXY extension.
51  *
52  * Return: 0 on success, negative SBI error code on failure.
53  */
sbi_mpxy_get_shmem_size(unsigned long * shmem_size)54 int sbi_mpxy_get_shmem_size(unsigned long *shmem_size)
55 {
56 	struct sbiret sbiret = {};
57 
58 	sbiret = sbi_ecall(SBI_EXT_MPXY, SBI_EXT_MPXY_GET_SHMEM_SIZE, 0, 0, 0,
59 			   0, 0, 0);
60 	if (sbiret.error) {
61 		EMSG("MPXY SBI call failed: error=%ld value=%ld", sbiret.error,
62 		     sbiret.value);
63 		return sbiret.error;
64 	}
65 
66 	if (shmem_size)
67 		*shmem_size = sbiret.value;
68 
69 	return SBI_SUCCESS;
70 }
71 
72 /**
73  * sbi_mpxy_set_shmem - Set up MPXY shared memory on the current hart
74  *
75  * Allocates and registers a 4 KiB shared memory region, aligned to 4 KiB,
76  * as required by the MPXY extension. This memory is used for sending and
77  * receiving messages. Registers the shared memory with the SBI MPXY extension.
78  *
79  * Return: SBI_SUCCESS on success, negative SBI error code on failure.
80  */
sbi_mpxy_set_shmem(void)81 int sbi_mpxy_set_shmem(void)
82 {
83 	struct mpxy_core_local *mpxy = NULL;
84 	struct sbiret sbiret = {};
85 	void *shmem = NULL;
86 	uint32_t exceptions = 0;
87 	int ret = SBI_ERR_FAILURE;
88 
89 	exceptions = thread_mask_exceptions(THREAD_EXCP_ALL);
90 
91 	mpxy = mpxy_get_core_local();
92 	if (mpxy->shmem_active)
93 		goto out;
94 
95 	shmem = memalign(SMALL_PAGE_SIZE, SMALL_PAGE_SIZE);
96 	if (!shmem)
97 		goto out;
98 
99 	mpxy->shmem = shmem;
100 	mpxy->shmem_pa = virt_to_phys(shmem);
101 
102 	sbiret = sbi_ecall(SBI_EXT_MPXY, SBI_EXT_MPXY_SET_SHMEM, mpxy->shmem_pa,
103 			   0, 0);
104 	if (sbiret.error) {
105 		EMSG("MPXY SBI call failed: error=%ld", sbiret.error);
106 		free(shmem);
107 		ret = sbiret.error;
108 		goto out;
109 	}
110 
111 	mpxy->shmem_active = true;
112 
113 	ret = SBI_SUCCESS;
114 
115 out:
116 	thread_unmask_exceptions(exceptions);
117 	return ret;
118 }
119 
120 /**
121  * sbi_mpxy_get_channel_ids - Retrieve MPXY channel IDs
122  * @channel_count: Number of channels expected
123  * @channel_ids: Buffer to store the retrieved channel IDs
124  *
125  * Uses the SBI MPXY extension to query the list of available channel IDs
126  * into the provided buffer.
127  *
128  * Return: SBI_SUCCESS on success, negative SBI error code on failure.
129  */
sbi_mpxy_get_channel_ids(uint32_t channel_count,uint32_t * channel_ids)130 int sbi_mpxy_get_channel_ids(uint32_t channel_count, uint32_t *channel_ids)
131 {
132 	struct mpxy_core_local *mpxy = NULL;
133 	struct sbi_mpxy_channel_ids_data *data = NULL;
134 	uint32_t remaining = 0;
135 	uint32_t returned = 0;
136 	uint32_t count = 0;
137 	uint32_t start_index = 0;
138 	struct sbiret sbiret = {};
139 	uint32_t exceptions = 0;
140 
141 	if (!channel_count || !channel_ids)
142 		return SBI_ERR_INVALID_PARAM;
143 
144 	exceptions = thread_mask_exceptions(THREAD_EXCP_ALL);
145 
146 	mpxy = mpxy_get_core_local();
147 
148 	if (!mpxy->shmem_active) {
149 		sbiret.error = SBI_ERR_NO_SHMEM;
150 		goto out;
151 	}
152 
153 	data = mpxy->shmem;
154 
155 	do {
156 		sbiret = sbi_ecall(SBI_EXT_MPXY, SBI_EXT_MPXY_GET_CHANNEL_IDS,
157 				   start_index, 0, 0, 0, 0, 0);
158 		if (sbiret.error) {
159 			EMSG("MPXY SBI call failed: error=%ld", sbiret.error);
160 			goto out;
161 		}
162 
163 		remaining = data->remaining;
164 		returned = data->returned;
165 
166 		count = returned < (channel_count - start_index) ?
167 				returned :
168 				(channel_count - start_index);
169 		memcpy(&channel_ids[start_index], data->channel_array,
170 		       count * sizeof(uint32_t));
171 		start_index += count;
172 	} while (remaining && start_index < channel_count);
173 
174 out:
175 	thread_unmask_exceptions(exceptions);
176 	return sbiret.error;
177 }
178 
179 /**
180  * sbi_mpxy_read_attributes - Read attributes from an MPXY channel
181  * @channel_id: ID of the channel
182  * @base_attribute_id: Starting attribute ID
183  * @attribute_count: Number of attributes to read
184  * @attribute_buf: Buffer to store the read attribute values
185  *
186  * Makes an SBI call to read attributes from the specified channel and copies
187  * the values from shared memory into the provided buffer.
188  *
189  * Return: SBI_SUCCESS on success, negative SBI error code on failure.
190  */
sbi_mpxy_read_attributes(uint32_t channel_id,uint32_t base_attribute_id,uint32_t attribute_count,void * attribute_buf)191 int sbi_mpxy_read_attributes(uint32_t channel_id, uint32_t base_attribute_id,
192 			     uint32_t attribute_count, void *attribute_buf)
193 {
194 	struct mpxy_core_local *mpxy = NULL;
195 	struct sbiret sbiret = {};
196 	uint32_t exceptions = 0;
197 	int ret = SBI_ERR_FAILURE;
198 
199 	if (!attribute_count || !attribute_buf)
200 		return SBI_ERR_INVALID_PARAM;
201 
202 	exceptions = thread_mask_exceptions(THREAD_EXCP_ALL);
203 
204 	mpxy = mpxy_get_core_local();
205 
206 	if (!mpxy->shmem_active) {
207 		ret = SBI_ERR_NO_SHMEM;
208 		goto out;
209 	}
210 
211 	sbiret = sbi_ecall(SBI_EXT_MPXY, SBI_EXT_MPXY_READ_ATTRS, channel_id,
212 			   base_attribute_id, attribute_count, 0, 0, 0);
213 	if (!sbiret.error)
214 		memcpy(attribute_buf, (void *)mpxy->shmem,
215 		       attribute_count * sizeof(uint32_t));
216 	else
217 		EMSG("MPXY SBI call failed: error=%ld", sbiret.error);
218 
219 	ret = sbiret.error;
220 out:
221 	thread_unmask_exceptions(exceptions);
222 	return ret;
223 }
224 
225 /**
226  * sbi_mpxy_write_attributes - Write attributes to an MPXY channel
227  * @channel_id: ID of the channel to write attributes to
228  * @base_attribute_id: Starting attribute ID
229  * @attribute_count: Number of attributes to write
230  * @attributes_buf: Buffer containing the attribute values
231  *
232  * Copies the attribute values into shared memory and makes an SBI call to
233  * write them to the specified channel.
234  *
235  * Return: SBI_SUCCESS on success, negative SBI error code on failure.
236  */
sbi_mpxy_write_attributes(uint32_t channel_id,uint32_t base_attribute_id,uint32_t attribute_count,uint32_t * attributes_buf)237 int sbi_mpxy_write_attributes(uint32_t channel_id, uint32_t base_attribute_id,
238 			      uint32_t attribute_count,
239 			      uint32_t *attributes_buf)
240 {
241 	struct mpxy_core_local *mpxy = NULL;
242 	struct sbiret sbiret = {};
243 	uint32_t exceptions = 0;
244 	int ret = SBI_ERR_FAILURE;
245 
246 	if (!attribute_count || !attributes_buf)
247 		return SBI_ERR_INVALID_PARAM;
248 
249 	exceptions = thread_mask_exceptions(THREAD_EXCP_ALL);
250 
251 	mpxy = mpxy_get_core_local();
252 
253 	if (!mpxy->shmem_active) {
254 		ret = SBI_ERR_NO_SHMEM;
255 		goto out;
256 	}
257 
258 	memcpy(mpxy->shmem, attributes_buf, attribute_count * sizeof(uint32_t));
259 
260 	sbiret = sbi_ecall(SBI_EXT_MPXY, SBI_EXT_MPXY_WRITE_ATTRS, channel_id,
261 			   base_attribute_id, attribute_count, 0, 0, 0);
262 
263 	if (sbiret.error)
264 		EMSG("MPXY SBI call failed: error=%ld", sbiret.error);
265 
266 	ret = sbiret.error;
267 out:
268 	thread_unmask_exceptions(exceptions);
269 	return ret;
270 }
271 
272 /**
273  * sbi_mpxy_send_message_with_response - Send a message and receive response
274  * via MPXY
275  * @channel_id: ID of the channel
276  * @message_id: ID of the message
277  * @message: Pointer to transmit buffer (can be NULL if message_len is 0)
278  * @message_len: Length of transmit buffer in bytes
279  * @response: Pointer to receive buffer
280  * @max_response_len: Maximum size of receive buffer in bytes
281  * @response_len: Pointer to store length of received data
282  *
283  * Copies transmit data into shared memory and makes an SBI call to send
284  * the message and receive a response. Copies the received response into
285  * the provided receive buffer.
286  *
287  * Return: SBI_SUCCESS on success, negative SBI error code on failure.
288  */
sbi_mpxy_send_message_with_response(uint32_t channel_id,uint32_t message_id,void * message,unsigned long message_len,void * response,unsigned long max_response_len,unsigned long * response_len)289 int sbi_mpxy_send_message_with_response(uint32_t channel_id,
290 					uint32_t message_id, void *message,
291 					unsigned long message_len,
292 					void *response,
293 					unsigned long max_response_len,
294 					unsigned long *response_len)
295 {
296 	struct mpxy_core_local *mpxy = NULL;
297 	unsigned long response_bytes = 0;
298 	struct sbiret sbiret = {};
299 	uint32_t exceptions = 0;
300 	int ret = SBI_ERR_FAILURE;
301 
302 	if (!message && message_len)
303 		return SBI_ERR_INVALID_PARAM;
304 
305 	exceptions = thread_mask_exceptions(THREAD_EXCP_ALL);
306 
307 	mpxy = mpxy_get_core_local();
308 
309 	if (!mpxy->shmem_active) {
310 		ret = SBI_ERR_NO_SHMEM;
311 		goto out;
312 	}
313 
314 	if (message_len)
315 		memcpy(mpxy->shmem, message, message_len);
316 
317 	sbiret = sbi_ecall(SBI_EXT_MPXY, SBI_EXT_MPXY_SEND_MSG_WITH_RESP,
318 			   channel_id, message_id, message_len, 0, 0, 0);
319 	if (response && !sbiret.error) {
320 		response_bytes = sbiret.value;
321 		if (response_bytes > max_response_len) {
322 			ret = SBI_ERR_INVALID_PARAM;
323 			goto out;
324 		}
325 
326 		memcpy(response, mpxy->shmem, response_bytes);
327 		if (response_len)
328 			*response_len = response_bytes;
329 	}
330 
331 	if (sbiret.error)
332 		EMSG("MPXY SBI call failed: error=%ld", sbiret.error);
333 
334 	ret = sbiret.error;
335 out:
336 	thread_unmask_exceptions(exceptions);
337 	return ret;
338 }
339 
340 /**
341  * sbi_mpxy_send_message_without_response - Send a message via MPXY without
342  * expecting a response
343  * @channel_id: ID of the channel
344  * @message_id: Message ID
345  * @message: Pointer to transmit buffer (may be NULL if message_len is 0)
346  * @message_len: Number of bytes to send
347  *
348  * Copies transmit data into shared memory and makes an SBI call to send the
349  * message without waiting for a response.
350  *
351  * Return: SBI_SUCCESS on success, negative SBI error code on failure.
352  */
sbi_mpxy_send_message_without_response(uint32_t channel_id,uint32_t message_id,void * message,unsigned long message_len)353 int sbi_mpxy_send_message_without_response(uint32_t channel_id,
354 					   uint32_t message_id, void *message,
355 					   unsigned long message_len)
356 {
357 	struct mpxy_core_local *mpxy = NULL;
358 	struct sbiret sbiret = {};
359 	uint32_t exceptions = 0;
360 	int ret = SBI_ERR_FAILURE;
361 
362 	if (!message && message_len)
363 		return SBI_ERR_INVALID_PARAM;
364 
365 	exceptions = thread_mask_exceptions(THREAD_EXCP_ALL);
366 
367 	mpxy = mpxy_get_core_local();
368 
369 	if (!mpxy->shmem_active) {
370 		ret = SBI_ERR_NO_SHMEM;
371 		goto out;
372 	}
373 
374 	if (message_len)
375 		memcpy(mpxy->shmem, message, message_len);
376 
377 	sbiret = sbi_ecall(SBI_EXT_MPXY, SBI_EXT_MPXY_SEND_MSG_WITHOUT_RESP,
378 			   channel_id, message_id, message_len, 0, 0, 0);
379 
380 	if (sbiret.error)
381 		EMSG("MPXY SBI call failed: error=%ld", sbiret.error);
382 
383 	ret = sbiret.error;
384 out:
385 	thread_unmask_exceptions(exceptions);
386 	return ret;
387 }
388 
389 /**
390  * sbi_mpxy_get_channel_count - Get the total number of MPXY channels
391  * @channel_count: Pointer to store the total number of channels
392  *
393  * Makes an SBI call to retrieve the number of channels by reading
394  * the remaining and returned fields from the shared memory structure.
395  *
396  * Return: SBI_SUCCESS on success, negative SBI error code on failure.
397  */
sbi_mpxy_get_channel_count(uint32_t * channel_count)398 int sbi_mpxy_get_channel_count(uint32_t *channel_count)
399 {
400 	struct mpxy_core_local *mpxy = NULL;
401 	struct sbi_mpxy_channel_ids_data *data = NULL;
402 	uint32_t remaining = 0;
403 	uint32_t returned = 0;
404 	struct sbiret sbiret = {};
405 	uint32_t exceptions = 0;
406 	int ret = SBI_ERR_FAILURE;
407 
408 	if (!channel_count)
409 		return SBI_ERR_INVALID_PARAM;
410 
411 	exceptions = thread_mask_exceptions(THREAD_EXCP_ALL);
412 
413 	mpxy = mpxy_get_core_local();
414 
415 	if (!mpxy->shmem_active) {
416 		ret = SBI_ERR_NO_SHMEM;
417 		goto out;
418 	}
419 
420 	data = mpxy->shmem;
421 
422 	sbiret = sbi_ecall(SBI_EXT_MPXY, SBI_EXT_MPXY_GET_CHANNEL_IDS, 0, 0, 0,
423 			   0, 0, 0);
424 	if (sbiret.error) {
425 		EMSG("MPXY SBI call failed: error=%ld", sbiret.error);
426 		goto out;
427 	}
428 
429 	remaining = data->remaining;
430 	returned = data->returned;
431 	*channel_count = remaining + returned;
432 
433 	ret = sbiret.error;
434 
435 out:
436 	thread_unmask_exceptions(exceptions);
437 	return ret;
438 }
439 
440 /**
441  * sbi_mpxy_get_notification_events - Retrieve notification events from an
442  * MPXY channel
443  * @channel_id: ID of the channel
444  * @notif_data: Pointer to buffer to store notification data
445  * @events_data_len: Pointer to store length of events data in bytes
446  *
447  * Makes an SBI call to fetch notification events from the specified channel
448  * and copies them from shared memory into the provided buffer.
449  *
450  * Return: SBI_SUCCESS on success, negative SBI error code on failure.
451  */
452 int
sbi_mpxy_get_notification_events(uint32_t channel_id,struct sbi_mpxy_notification_data * notif_data,unsigned long * events_data_len)453 sbi_mpxy_get_notification_events(uint32_t channel_id,
454 				 struct sbi_mpxy_notification_data *notif_data,
455 				 unsigned long *events_data_len)
456 {
457 	struct mpxy_core_local *mpxy = NULL;
458 	struct sbiret sbiret = {};
459 	uint32_t exceptions = 0;
460 	int ret = SBI_ERR_FAILURE;
461 
462 	if (!notif_data || !events_data_len)
463 		return SBI_ERR_INVALID_PARAM;
464 
465 	exceptions = thread_mask_exceptions(THREAD_EXCP_ALL);
466 
467 	mpxy = mpxy_get_core_local();
468 
469 	if (!mpxy->shmem_active) {
470 		ret = SBI_ERR_NO_SHMEM;
471 		goto out;
472 	}
473 
474 	sbiret = sbi_ecall(SBI_EXT_MPXY, SBI_EXT_MPXY_GET_NOTIFICATION_EVENTS,
475 			   channel_id, 0, 0, 0, 0, 0);
476 	if (sbiret.error) {
477 		EMSG("MPXY SBI call failed: error=%ld", sbiret.error);
478 		ret = sbiret.error;
479 		goto out;
480 	}
481 
482 	memcpy(notif_data, mpxy->shmem, sbiret.value + 16);
483 	*events_data_len = sbiret.value;
484 
485 	ret = sbiret.error;
486 
487 out:
488 	thread_unmask_exceptions(exceptions);
489 	return ret;
490 }
491