xref: /optee_os/core/arch/riscv/kernel/sbi_mpxy.c (revision 2ac77846aae2186e5ae2422cf3796166737d10b8)
1 // SPDX-License-Identifier: BSD-2-Clause
2 /*
3  * Copyright 2025 NXP
4  */
5 
6 #include <kernel/misc.h>
7 #include <mm/core_memprot.h>
8 #include <sbi.h>
9 #include <sbi_mpxy.h>
10 #include <string.h>
11 
12 /*
13  * struct mpxy_core_local - MPXY per-hart local context
14  * @shmem:       Virtual base address of MPXY shared memory
15  * @shmem_pa:    Physical base address of MPXY shared memory
16  * @shmem_active:Indicates whether shared memory is active for this hart
17  *
18  * Holds MPXY-related per-hart data required for message exchange via
19  * the SBI MPXY extension.
20  */
21 struct mpxy_core_local {
22 	void *shmem;
23 	paddr_t shmem_pa;
24 	bool shmem_active;
25 };
26 
27 static struct mpxy_core_local mpxy_core_local_array[CFG_TEE_CORE_NB_CORE];
28 
mpxy_get_core_local(void)29 static struct mpxy_core_local *mpxy_get_core_local(void)
30 {
31 	struct mpxy_core_local *mpxy = NULL;
32 	uint32_t hart_id = 0;
33 
34 	assert((thread_get_exceptions() & THREAD_EXCP_ALL) == THREAD_EXCP_ALL);
35 
36 	hart_id = thread_get_hartid();
37 
38 	mpxy = &mpxy_core_local_array[hart_id];
39 
40 	return mpxy;
41 }
42 
43 /**
44  * sbi_mpxy_get_shmem_size - Retrieve the MPXY shared memory size
45  * @shmem_size: Pointer to store the shared memory size in bytes
46  *
47  * Makes an SBI call to query the shared memory size used for
48  * sending and receiving messages via the MPXY extension.
49  *
50  * Return: 0 on success, negative SBI error code on failure.
51  */
sbi_mpxy_get_shmem_size(unsigned long * shmem_size)52 int sbi_mpxy_get_shmem_size(unsigned long *shmem_size)
53 {
54 	struct sbiret sbiret = {};
55 
56 	sbiret = sbi_ecall(SBI_EXT_MPXY, SBI_EXT_MPXY_GET_SHMEM_SIZE, 0, 0, 0,
57 			   0, 0, 0);
58 	if (sbiret.error) {
59 		EMSG("MPXY SBI call failed: error=%ld value=%ld", sbiret.error,
60 		     sbiret.value);
61 		return sbiret.error;
62 	}
63 
64 	if (shmem_size)
65 		*shmem_size = sbiret.value;
66 
67 	return SBI_SUCCESS;
68 }
69 
70 /**
71  * sbi_mpxy_set_shmem - Set up MPXY shared memory on the current hart
72  *
73  * Allocates and registers a 4 KiB shared memory region, aligned to 4 KiB,
74  * as required by the MPXY extension. This memory is used for sending and
75  * receiving messages. Registers the shared memory with the SBI MPXY extension.
76  *
77  * Return: SBI_SUCCESS on success, negative SBI error code on failure.
78  */
sbi_mpxy_set_shmem(void)79 int sbi_mpxy_set_shmem(void)
80 {
81 	struct mpxy_core_local *mpxy = NULL;
82 	struct sbiret sbiret = {};
83 	void *shmem = NULL;
84 	uint32_t exceptions = 0;
85 	int ret = SBI_ERR_FAILURE;
86 
87 	exceptions = thread_mask_exceptions(THREAD_EXCP_ALL);
88 
89 	mpxy = mpxy_get_core_local();
90 	if (mpxy->shmem_active)
91 		goto out;
92 
93 	shmem = memalign(SMALL_PAGE_SIZE, SMALL_PAGE_SIZE);
94 	if (!shmem)
95 		goto out;
96 
97 	mpxy->shmem = shmem;
98 	mpxy->shmem_pa = virt_to_phys(shmem);
99 
100 	sbiret = sbi_ecall(SBI_EXT_MPXY, SBI_EXT_MPXY_SET_SHMEM, mpxy->shmem_pa,
101 			   0, 0);
102 	if (sbiret.error) {
103 		EMSG("MPXY SBI call failed: error=%ld", sbiret.error);
104 		free(shmem);
105 		ret = sbiret.error;
106 		goto out;
107 	}
108 
109 	mpxy->shmem_active = true;
110 
111 	ret = SBI_SUCCESS;
112 
113 out:
114 	thread_unmask_exceptions(exceptions);
115 	return ret;
116 }
117 
118 /**
119  * sbi_mpxy_get_channel_ids - Retrieve MPXY channel IDs
120  * @channel_count: Number of channels expected
121  * @channel_ids: Buffer to store the retrieved channel IDs
122  *
123  * Uses the SBI MPXY extension to query the list of available channel IDs
124  * into the provided buffer.
125  *
126  * Return: SBI_SUCCESS on success, negative SBI error code on failure.
127  */
sbi_mpxy_get_channel_ids(uint32_t channel_count,uint32_t * channel_ids)128 int sbi_mpxy_get_channel_ids(uint32_t channel_count, uint32_t *channel_ids)
129 {
130 	struct mpxy_core_local *mpxy = NULL;
131 	struct sbi_mpxy_channel_ids_data *data = NULL;
132 	uint32_t remaining = 0;
133 	uint32_t returned = 0;
134 	uint32_t count = 0;
135 	uint32_t start_index = 0;
136 	struct sbiret sbiret = {};
137 	uint32_t exceptions = 0;
138 
139 	if (!channel_count || !channel_ids)
140 		return SBI_ERR_INVALID_PARAM;
141 
142 	exceptions = thread_mask_exceptions(THREAD_EXCP_ALL);
143 
144 	mpxy = mpxy_get_core_local();
145 
146 	if (!mpxy->shmem_active) {
147 		sbiret.error = SBI_ERR_NO_SHMEM;
148 		goto out;
149 	}
150 
151 	data = mpxy->shmem;
152 
153 	do {
154 		sbiret = sbi_ecall(SBI_EXT_MPXY, SBI_EXT_MPXY_GET_CHANNEL_IDS,
155 				   start_index, 0, 0, 0, 0, 0);
156 		if (sbiret.error) {
157 			EMSG("MPXY SBI call failed: error=%ld", sbiret.error);
158 			goto out;
159 		}
160 
161 		remaining = data->remaining;
162 		returned = data->returned;
163 
164 		count = returned < (channel_count - start_index) ?
165 				returned :
166 				(channel_count - start_index);
167 		memcpy(&channel_ids[start_index], data->channel_array,
168 		       count * sizeof(uint32_t));
169 		start_index += count;
170 	} while (remaining && start_index < channel_count);
171 
172 out:
173 	thread_unmask_exceptions(exceptions);
174 	return sbiret.error;
175 }
176 
177 /**
178  * sbi_mpxy_read_attributes - Read attributes from an MPXY channel
179  * @channel_id: ID of the channel
180  * @base_attribute_id: Starting attribute ID
181  * @attribute_count: Number of attributes to read
182  * @attribute_buf: Buffer to store the read attribute values
183  *
184  * Makes an SBI call to read attributes from the specified channel and copies
185  * the values from shared memory into the provided buffer.
186  *
187  * Return: SBI_SUCCESS on success, negative SBI error code on failure.
188  */
sbi_mpxy_read_attributes(uint32_t channel_id,uint32_t base_attribute_id,uint32_t attribute_count,void * attribute_buf)189 int sbi_mpxy_read_attributes(uint32_t channel_id, uint32_t base_attribute_id,
190 			     uint32_t attribute_count, void *attribute_buf)
191 {
192 	struct mpxy_core_local *mpxy = NULL;
193 	struct sbiret sbiret = {};
194 	uint32_t exceptions = 0;
195 	int ret = SBI_ERR_FAILURE;
196 
197 	if (!attribute_count || !attribute_buf)
198 		return SBI_ERR_INVALID_PARAM;
199 
200 	exceptions = thread_mask_exceptions(THREAD_EXCP_ALL);
201 
202 	mpxy = mpxy_get_core_local();
203 
204 	if (!mpxy->shmem_active) {
205 		ret = SBI_ERR_NO_SHMEM;
206 		goto out;
207 	}
208 
209 	sbiret = sbi_ecall(SBI_EXT_MPXY, SBI_EXT_MPXY_READ_ATTRS, channel_id,
210 			   base_attribute_id, attribute_count, 0, 0, 0);
211 	if (!sbiret.error)
212 		memcpy(attribute_buf, (void *)mpxy->shmem,
213 		       attribute_count * sizeof(uint32_t));
214 	else
215 		EMSG("MPXY SBI call failed: error=%ld", sbiret.error);
216 
217 	ret = sbiret.error;
218 out:
219 	thread_unmask_exceptions(exceptions);
220 	return ret;
221 }
222 
223 /**
224  * sbi_mpxy_write_attributes - Write attributes to an MPXY channel
225  * @channel_id: ID of the channel to write attributes to
226  * @base_attribute_id: Starting attribute ID
227  * @attribute_count: Number of attributes to write
228  * @attributes_buf: Buffer containing the attribute values
229  *
230  * Copies the attribute values into shared memory and makes an SBI call to
231  * write them to the specified channel.
232  *
233  * Return: SBI_SUCCESS on success, negative SBI error code on failure.
234  */
sbi_mpxy_write_attributes(uint32_t channel_id,uint32_t base_attribute_id,uint32_t attribute_count,uint32_t * attributes_buf)235 int sbi_mpxy_write_attributes(uint32_t channel_id, uint32_t base_attribute_id,
236 			      uint32_t attribute_count,
237 			      uint32_t *attributes_buf)
238 {
239 	struct mpxy_core_local *mpxy = NULL;
240 	struct sbiret sbiret = {};
241 	uint32_t exceptions = 0;
242 	int ret = SBI_ERR_FAILURE;
243 
244 	if (!attribute_count || !attributes_buf)
245 		return SBI_ERR_INVALID_PARAM;
246 
247 	exceptions = thread_mask_exceptions(THREAD_EXCP_ALL);
248 
249 	mpxy = mpxy_get_core_local();
250 
251 	if (!mpxy->shmem_active) {
252 		ret = SBI_ERR_NO_SHMEM;
253 		goto out;
254 	}
255 
256 	memcpy(mpxy->shmem, attributes_buf, attribute_count * sizeof(uint32_t));
257 
258 	sbiret = sbi_ecall(SBI_EXT_MPXY, SBI_EXT_MPXY_WRITE_ATTRS, channel_id,
259 			   base_attribute_id, attribute_count, 0, 0, 0);
260 
261 	if (sbiret.error)
262 		EMSG("MPXY SBI call failed: error=%ld", sbiret.error);
263 
264 	ret = sbiret.error;
265 out:
266 	thread_unmask_exceptions(exceptions);
267 	return ret;
268 }
269 
270 /**
271  * sbi_mpxy_send_message_with_response - Send a message and receive response
272  * via MPXY
273  * @channel_id: ID of the channel
274  * @message_id: ID of the message
275  * @message: Pointer to transmit buffer (can be NULL if message_len is 0)
276  * @message_len: Length of transmit buffer in bytes
277  * @response: Pointer to receive buffer
278  * @max_response_len: Maximum size of receive buffer in bytes
279  * @response_len: Pointer to store length of received data
280  *
281  * Copies transmit data into shared memory and makes an SBI call to send
282  * the message and receive a response. Copies the received response into
283  * the provided receive buffer.
284  *
285  * Return: SBI_SUCCESS on success, negative SBI error code on failure.
286  */
sbi_mpxy_send_message_with_response(uint32_t channel_id,uint32_t message_id,void * message,unsigned long message_len,void * response,unsigned long max_response_len,unsigned long * response_len)287 int sbi_mpxy_send_message_with_response(uint32_t channel_id,
288 					uint32_t message_id, void *message,
289 					unsigned long message_len,
290 					void *response,
291 					unsigned long max_response_len,
292 					unsigned long *response_len)
293 {
294 	struct mpxy_core_local *mpxy = NULL;
295 	unsigned long response_bytes = 0;
296 	struct sbiret sbiret = {};
297 	uint32_t exceptions = 0;
298 	int ret = SBI_ERR_FAILURE;
299 
300 	if (!message && message_len)
301 		return SBI_ERR_INVALID_PARAM;
302 
303 	exceptions = thread_mask_exceptions(THREAD_EXCP_ALL);
304 
305 	mpxy = mpxy_get_core_local();
306 
307 	if (!mpxy->shmem_active) {
308 		ret = SBI_ERR_NO_SHMEM;
309 		goto out;
310 	}
311 
312 	if (message_len)
313 		memcpy(mpxy->shmem, message, message_len);
314 
315 	sbiret = sbi_ecall(SBI_EXT_MPXY, SBI_EXT_MPXY_SEND_MSG_WITH_RESP,
316 			   channel_id, message_id, message_len, 0, 0, 0);
317 	if (response && !sbiret.error) {
318 		response_bytes = sbiret.value;
319 		if (response_bytes > max_response_len) {
320 			ret = SBI_ERR_INVALID_PARAM;
321 			goto out;
322 		}
323 
324 		memcpy(response, mpxy->shmem, response_bytes);
325 		if (response_len)
326 			*response_len = response_bytes;
327 	}
328 
329 	if (sbiret.error)
330 		EMSG("MPXY SBI call failed: error=%ld", sbiret.error);
331 
332 	ret = sbiret.error;
333 out:
334 	thread_unmask_exceptions(exceptions);
335 	return ret;
336 }
337 
338 /**
339  * sbi_mpxy_send_message_without_response - Send a message via MPXY without
340  * expecting a response
341  * @channel_id: ID of the channel
342  * @message_id: Message ID
343  * @message: Pointer to transmit buffer (may be NULL if message_len is 0)
344  * @message_len: Number of bytes to send
345  *
346  * Copies transmit data into shared memory and makes an SBI call to send the
347  * message without waiting for a response.
348  *
349  * Return: SBI_SUCCESS on success, negative SBI error code on failure.
350  */
sbi_mpxy_send_message_without_response(uint32_t channel_id,uint32_t message_id,void * message,unsigned long message_len)351 int sbi_mpxy_send_message_without_response(uint32_t channel_id,
352 					   uint32_t message_id, void *message,
353 					   unsigned long message_len)
354 {
355 	struct mpxy_core_local *mpxy = NULL;
356 	struct sbiret sbiret = {};
357 	uint32_t exceptions = 0;
358 	int ret = SBI_ERR_FAILURE;
359 
360 	if (!message && message_len)
361 		return SBI_ERR_INVALID_PARAM;
362 
363 	exceptions = thread_mask_exceptions(THREAD_EXCP_ALL);
364 
365 	mpxy = mpxy_get_core_local();
366 
367 	if (!mpxy->shmem_active) {
368 		ret = SBI_ERR_NO_SHMEM;
369 		goto out;
370 	}
371 
372 	if (message_len)
373 		memcpy(mpxy->shmem, message, message_len);
374 
375 	sbiret = sbi_ecall(SBI_EXT_MPXY, SBI_EXT_MPXY_SEND_MSG_WITHOUT_RESP,
376 			   channel_id, message_id, message_len, 0, 0, 0);
377 
378 	if (sbiret.error)
379 		EMSG("MPXY SBI call failed: error=%ld", sbiret.error);
380 
381 	ret = sbiret.error;
382 out:
383 	thread_unmask_exceptions(exceptions);
384 	return ret;
385 }
386 
387 /**
388  * sbi_mpxy_get_channel_count - Get the total number of MPXY channels
389  * @channel_count: Pointer to store the total number of channels
390  *
391  * Makes an SBI call to retrieve the number of channels by reading
392  * the remaining and returned fields from the shared memory structure.
393  *
394  * Return: SBI_SUCCESS on success, negative SBI error code on failure.
395  */
sbi_mpxy_get_channel_count(uint32_t * channel_count)396 int sbi_mpxy_get_channel_count(uint32_t *channel_count)
397 {
398 	struct mpxy_core_local *mpxy = NULL;
399 	struct sbi_mpxy_channel_ids_data *data = NULL;
400 	uint32_t remaining = 0;
401 	uint32_t returned = 0;
402 	struct sbiret sbiret = {};
403 	uint32_t exceptions = 0;
404 	int ret = SBI_ERR_FAILURE;
405 
406 	if (!channel_count)
407 		return SBI_ERR_INVALID_PARAM;
408 
409 	exceptions = thread_mask_exceptions(THREAD_EXCP_ALL);
410 
411 	mpxy = mpxy_get_core_local();
412 
413 	if (!mpxy->shmem_active) {
414 		ret = SBI_ERR_NO_SHMEM;
415 		goto out;
416 	}
417 
418 	data = mpxy->shmem;
419 
420 	sbiret = sbi_ecall(SBI_EXT_MPXY, SBI_EXT_MPXY_GET_CHANNEL_IDS, 0, 0, 0,
421 			   0, 0, 0);
422 	if (sbiret.error) {
423 		EMSG("MPXY SBI call failed: error=%ld", sbiret.error);
424 		goto out;
425 	}
426 
427 	remaining = data->remaining;
428 	returned = data->returned;
429 	*channel_count = remaining + returned;
430 
431 	ret = sbiret.error;
432 
433 out:
434 	thread_unmask_exceptions(exceptions);
435 	return ret;
436 }
437 
438 /**
439  * sbi_mpxy_get_notification_events - Retrieve notification events from an
440  * MPXY channel
441  * @channel_id: ID of the channel
442  * @notif_data: Pointer to buffer to store notification data
443  * @events_data_len: Pointer to store length of events data in bytes
444  *
445  * Makes an SBI call to fetch notification events from the specified channel
446  * and copies them from shared memory into the provided buffer.
447  *
448  * Return: SBI_SUCCESS on success, negative SBI error code on failure.
449  */
450 int
sbi_mpxy_get_notification_events(uint32_t channel_id,struct sbi_mpxy_notification_data * notif_data,unsigned long * events_data_len)451 sbi_mpxy_get_notification_events(uint32_t channel_id,
452 				 struct sbi_mpxy_notification_data *notif_data,
453 				 unsigned long *events_data_len)
454 {
455 	struct mpxy_core_local *mpxy = NULL;
456 	struct sbiret sbiret = {};
457 	uint32_t exceptions = 0;
458 	int ret = SBI_ERR_FAILURE;
459 
460 	if (!notif_data || !events_data_len)
461 		return SBI_ERR_INVALID_PARAM;
462 
463 	exceptions = thread_mask_exceptions(THREAD_EXCP_ALL);
464 
465 	mpxy = mpxy_get_core_local();
466 
467 	if (!mpxy->shmem_active) {
468 		ret = SBI_ERR_NO_SHMEM;
469 		goto out;
470 	}
471 
472 	sbiret = sbi_ecall(SBI_EXT_MPXY, SBI_EXT_MPXY_GET_NOTIFICATION_EVENTS,
473 			   channel_id, 0, 0, 0, 0, 0);
474 	if (sbiret.error) {
475 		EMSG("MPXY SBI call failed: error=%ld", sbiret.error);
476 		ret = sbiret.error;
477 		goto out;
478 	}
479 
480 	memcpy(notif_data, mpxy->shmem, sbiret.value + 16);
481 	*events_data_len = sbiret.value;
482 
483 	ret = sbiret.error;
484 
485 out:
486 	thread_unmask_exceptions(exceptions);
487 	return ret;
488 }
489