1 // SPDX-License-Identifier: BSD-2-Clause
2 /*
3 * Copyright 2025 NXP
4 */
5
#include <kernel/misc.h>
#include <mm/core_memprot.h>
#include <sbi.h>
#include <sbi_mpxy.h>
#include <stddef.h>
#include <string.h>
11
12 /*
13 * struct mpxy_core_local - MPXY per-hart local context
14 * @shmem: Virtual base address of MPXY shared memory
15 * @shmem_pa: Physical base address of MPXY shared memory
16 * @shmem_active:Indicates whether shared memory is active for this hart
17 *
18 * Holds MPXY-related per-hart data required for message exchange via
19 * the SBI MPXY extension.
20 */
21 struct mpxy_core_local {
22 void *shmem;
23 paddr_t shmem_pa;
24 bool shmem_active;
25 };
26
27 static struct mpxy_core_local mpxy_core_local_array[CFG_TEE_CORE_NB_CORE];
28
mpxy_get_core_local(void)29 static struct mpxy_core_local *mpxy_get_core_local(void)
30 {
31 struct mpxy_core_local *mpxy = NULL;
32 uint32_t hart_id = 0;
33
34 assert((thread_get_exceptions() & THREAD_EXCP_ALL) == THREAD_EXCP_ALL);
35
36 hart_id = thread_get_hartid();
37
38 mpxy = &mpxy_core_local_array[hart_id];
39
40 return mpxy;
41 }
42
43 /**
44 * sbi_mpxy_get_shmem_size - Retrieve the MPXY shared memory size
45 * @shmem_size: Pointer to store the shared memory size in bytes
46 *
47 * Makes an SBI call to query the shared memory size used for
48 * sending and receiving messages via the MPXY extension.
49 *
50 * Return: 0 on success, negative SBI error code on failure.
51 */
sbi_mpxy_get_shmem_size(unsigned long * shmem_size)52 int sbi_mpxy_get_shmem_size(unsigned long *shmem_size)
53 {
54 struct sbiret sbiret = {};
55
56 sbiret = sbi_ecall(SBI_EXT_MPXY, SBI_EXT_MPXY_GET_SHMEM_SIZE, 0, 0, 0,
57 0, 0, 0);
58 if (sbiret.error) {
59 EMSG("MPXY SBI call failed: error=%ld value=%ld", sbiret.error,
60 sbiret.value);
61 return sbiret.error;
62 }
63
64 if (shmem_size)
65 *shmem_size = sbiret.value;
66
67 return SBI_SUCCESS;
68 }
69
70 /**
71 * sbi_mpxy_set_shmem - Set up MPXY shared memory on the current hart
72 *
73 * Allocates and registers a 4 KiB shared memory region, aligned to 4 KiB,
74 * as required by the MPXY extension. This memory is used for sending and
75 * receiving messages. Registers the shared memory with the SBI MPXY extension.
76 *
77 * Return: SBI_SUCCESS on success, negative SBI error code on failure.
78 */
sbi_mpxy_set_shmem(void)79 int sbi_mpxy_set_shmem(void)
80 {
81 struct mpxy_core_local *mpxy = NULL;
82 struct sbiret sbiret = {};
83 void *shmem = NULL;
84 uint32_t exceptions = 0;
85 int ret = SBI_ERR_FAILURE;
86
87 exceptions = thread_mask_exceptions(THREAD_EXCP_ALL);
88
89 mpxy = mpxy_get_core_local();
90 if (mpxy->shmem_active)
91 goto out;
92
93 shmem = memalign(SMALL_PAGE_SIZE, SMALL_PAGE_SIZE);
94 if (!shmem)
95 goto out;
96
97 mpxy->shmem = shmem;
98 mpxy->shmem_pa = virt_to_phys(shmem);
99
100 sbiret = sbi_ecall(SBI_EXT_MPXY, SBI_EXT_MPXY_SET_SHMEM, mpxy->shmem_pa,
101 0, 0);
102 if (sbiret.error) {
103 EMSG("MPXY SBI call failed: error=%ld", sbiret.error);
104 free(shmem);
105 ret = sbiret.error;
106 goto out;
107 }
108
109 mpxy->shmem_active = true;
110
111 ret = SBI_SUCCESS;
112
113 out:
114 thread_unmask_exceptions(exceptions);
115 return ret;
116 }
117
118 /**
119 * sbi_mpxy_get_channel_ids - Retrieve MPXY channel IDs
120 * @channel_count: Number of channels expected
121 * @channel_ids: Buffer to store the retrieved channel IDs
122 *
123 * Uses the SBI MPXY extension to query the list of available channel IDs
124 * into the provided buffer.
125 *
126 * Return: SBI_SUCCESS on success, negative SBI error code on failure.
127 */
sbi_mpxy_get_channel_ids(uint32_t channel_count,uint32_t * channel_ids)128 int sbi_mpxy_get_channel_ids(uint32_t channel_count, uint32_t *channel_ids)
129 {
130 struct mpxy_core_local *mpxy = NULL;
131 struct sbi_mpxy_channel_ids_data *data = NULL;
132 uint32_t remaining = 0;
133 uint32_t returned = 0;
134 uint32_t count = 0;
135 uint32_t start_index = 0;
136 struct sbiret sbiret = {};
137 uint32_t exceptions = 0;
138
139 if (!channel_count || !channel_ids)
140 return SBI_ERR_INVALID_PARAM;
141
142 exceptions = thread_mask_exceptions(THREAD_EXCP_ALL);
143
144 mpxy = mpxy_get_core_local();
145
146 if (!mpxy->shmem_active) {
147 sbiret.error = SBI_ERR_NO_SHMEM;
148 goto out;
149 }
150
151 data = mpxy->shmem;
152
153 do {
154 sbiret = sbi_ecall(SBI_EXT_MPXY, SBI_EXT_MPXY_GET_CHANNEL_IDS,
155 start_index, 0, 0, 0, 0, 0);
156 if (sbiret.error) {
157 EMSG("MPXY SBI call failed: error=%ld", sbiret.error);
158 goto out;
159 }
160
161 remaining = data->remaining;
162 returned = data->returned;
163
164 count = returned < (channel_count - start_index) ?
165 returned :
166 (channel_count - start_index);
167 memcpy(&channel_ids[start_index], data->channel_array,
168 count * sizeof(uint32_t));
169 start_index += count;
170 } while (remaining && start_index < channel_count);
171
172 out:
173 thread_unmask_exceptions(exceptions);
174 return sbiret.error;
175 }
176
177 /**
178 * sbi_mpxy_read_attributes - Read attributes from an MPXY channel
179 * @channel_id: ID of the channel
180 * @base_attribute_id: Starting attribute ID
181 * @attribute_count: Number of attributes to read
182 * @attribute_buf: Buffer to store the read attribute values
183 *
184 * Makes an SBI call to read attributes from the specified channel and copies
185 * the values from shared memory into the provided buffer.
186 *
187 * Return: SBI_SUCCESS on success, negative SBI error code on failure.
188 */
sbi_mpxy_read_attributes(uint32_t channel_id,uint32_t base_attribute_id,uint32_t attribute_count,void * attribute_buf)189 int sbi_mpxy_read_attributes(uint32_t channel_id, uint32_t base_attribute_id,
190 uint32_t attribute_count, void *attribute_buf)
191 {
192 struct mpxy_core_local *mpxy = NULL;
193 struct sbiret sbiret = {};
194 uint32_t exceptions = 0;
195 int ret = SBI_ERR_FAILURE;
196
197 if (!attribute_count || !attribute_buf)
198 return SBI_ERR_INVALID_PARAM;
199
200 exceptions = thread_mask_exceptions(THREAD_EXCP_ALL);
201
202 mpxy = mpxy_get_core_local();
203
204 if (!mpxy->shmem_active) {
205 ret = SBI_ERR_NO_SHMEM;
206 goto out;
207 }
208
209 sbiret = sbi_ecall(SBI_EXT_MPXY, SBI_EXT_MPXY_READ_ATTRS, channel_id,
210 base_attribute_id, attribute_count, 0, 0, 0);
211 if (!sbiret.error)
212 memcpy(attribute_buf, (void *)mpxy->shmem,
213 attribute_count * sizeof(uint32_t));
214 else
215 EMSG("MPXY SBI call failed: error=%ld", sbiret.error);
216
217 ret = sbiret.error;
218 out:
219 thread_unmask_exceptions(exceptions);
220 return ret;
221 }
222
223 /**
224 * sbi_mpxy_write_attributes - Write attributes to an MPXY channel
225 * @channel_id: ID of the channel to write attributes to
226 * @base_attribute_id: Starting attribute ID
227 * @attribute_count: Number of attributes to write
228 * @attributes_buf: Buffer containing the attribute values
229 *
230 * Copies the attribute values into shared memory and makes an SBI call to
231 * write them to the specified channel.
232 *
233 * Return: SBI_SUCCESS on success, negative SBI error code on failure.
234 */
sbi_mpxy_write_attributes(uint32_t channel_id,uint32_t base_attribute_id,uint32_t attribute_count,uint32_t * attributes_buf)235 int sbi_mpxy_write_attributes(uint32_t channel_id, uint32_t base_attribute_id,
236 uint32_t attribute_count,
237 uint32_t *attributes_buf)
238 {
239 struct mpxy_core_local *mpxy = NULL;
240 struct sbiret sbiret = {};
241 uint32_t exceptions = 0;
242 int ret = SBI_ERR_FAILURE;
243
244 if (!attribute_count || !attributes_buf)
245 return SBI_ERR_INVALID_PARAM;
246
247 exceptions = thread_mask_exceptions(THREAD_EXCP_ALL);
248
249 mpxy = mpxy_get_core_local();
250
251 if (!mpxy->shmem_active) {
252 ret = SBI_ERR_NO_SHMEM;
253 goto out;
254 }
255
256 memcpy(mpxy->shmem, attributes_buf, attribute_count * sizeof(uint32_t));
257
258 sbiret = sbi_ecall(SBI_EXT_MPXY, SBI_EXT_MPXY_WRITE_ATTRS, channel_id,
259 base_attribute_id, attribute_count, 0, 0, 0);
260
261 if (sbiret.error)
262 EMSG("MPXY SBI call failed: error=%ld", sbiret.error);
263
264 ret = sbiret.error;
265 out:
266 thread_unmask_exceptions(exceptions);
267 return ret;
268 }
269
270 /**
271 * sbi_mpxy_send_message_with_response - Send a message and receive response
272 * via MPXY
273 * @channel_id: ID of the channel
274 * @message_id: ID of the message
275 * @message: Pointer to transmit buffer (can be NULL if message_len is 0)
276 * @message_len: Length of transmit buffer in bytes
277 * @response: Pointer to receive buffer
278 * @max_response_len: Maximum size of receive buffer in bytes
279 * @response_len: Pointer to store length of received data
280 *
281 * Copies transmit data into shared memory and makes an SBI call to send
282 * the message and receive a response. Copies the received response into
283 * the provided receive buffer.
284 *
285 * Return: SBI_SUCCESS on success, negative SBI error code on failure.
286 */
sbi_mpxy_send_message_with_response(uint32_t channel_id,uint32_t message_id,void * message,unsigned long message_len,void * response,unsigned long max_response_len,unsigned long * response_len)287 int sbi_mpxy_send_message_with_response(uint32_t channel_id,
288 uint32_t message_id, void *message,
289 unsigned long message_len,
290 void *response,
291 unsigned long max_response_len,
292 unsigned long *response_len)
293 {
294 struct mpxy_core_local *mpxy = NULL;
295 unsigned long response_bytes = 0;
296 struct sbiret sbiret = {};
297 uint32_t exceptions = 0;
298 int ret = SBI_ERR_FAILURE;
299
300 if (!message && message_len)
301 return SBI_ERR_INVALID_PARAM;
302
303 exceptions = thread_mask_exceptions(THREAD_EXCP_ALL);
304
305 mpxy = mpxy_get_core_local();
306
307 if (!mpxy->shmem_active) {
308 ret = SBI_ERR_NO_SHMEM;
309 goto out;
310 }
311
312 if (message_len)
313 memcpy(mpxy->shmem, message, message_len);
314
315 sbiret = sbi_ecall(SBI_EXT_MPXY, SBI_EXT_MPXY_SEND_MSG_WITH_RESP,
316 channel_id, message_id, message_len, 0, 0, 0);
317 if (response && !sbiret.error) {
318 response_bytes = sbiret.value;
319 if (response_bytes > max_response_len) {
320 ret = SBI_ERR_INVALID_PARAM;
321 goto out;
322 }
323
324 memcpy(response, mpxy->shmem, response_bytes);
325 if (response_len)
326 *response_len = response_bytes;
327 }
328
329 if (sbiret.error)
330 EMSG("MPXY SBI call failed: error=%ld", sbiret.error);
331
332 ret = sbiret.error;
333 out:
334 thread_unmask_exceptions(exceptions);
335 return ret;
336 }
337
338 /**
339 * sbi_mpxy_send_message_without_response - Send a message via MPXY without
340 * expecting a response
341 * @channel_id: ID of the channel
342 * @message_id: Message ID
343 * @message: Pointer to transmit buffer (may be NULL if message_len is 0)
344 * @message_len: Number of bytes to send
345 *
346 * Copies transmit data into shared memory and makes an SBI call to send the
347 * message without waiting for a response.
348 *
349 * Return: SBI_SUCCESS on success, negative SBI error code on failure.
350 */
sbi_mpxy_send_message_without_response(uint32_t channel_id,uint32_t message_id,void * message,unsigned long message_len)351 int sbi_mpxy_send_message_without_response(uint32_t channel_id,
352 uint32_t message_id, void *message,
353 unsigned long message_len)
354 {
355 struct mpxy_core_local *mpxy = NULL;
356 struct sbiret sbiret = {};
357 uint32_t exceptions = 0;
358 int ret = SBI_ERR_FAILURE;
359
360 if (!message && message_len)
361 return SBI_ERR_INVALID_PARAM;
362
363 exceptions = thread_mask_exceptions(THREAD_EXCP_ALL);
364
365 mpxy = mpxy_get_core_local();
366
367 if (!mpxy->shmem_active) {
368 ret = SBI_ERR_NO_SHMEM;
369 goto out;
370 }
371
372 if (message_len)
373 memcpy(mpxy->shmem, message, message_len);
374
375 sbiret = sbi_ecall(SBI_EXT_MPXY, SBI_EXT_MPXY_SEND_MSG_WITHOUT_RESP,
376 channel_id, message_id, message_len, 0, 0, 0);
377
378 if (sbiret.error)
379 EMSG("MPXY SBI call failed: error=%ld", sbiret.error);
380
381 ret = sbiret.error;
382 out:
383 thread_unmask_exceptions(exceptions);
384 return ret;
385 }
386
387 /**
388 * sbi_mpxy_get_channel_count - Get the total number of MPXY channels
389 * @channel_count: Pointer to store the total number of channels
390 *
391 * Makes an SBI call to retrieve the number of channels by reading
392 * the remaining and returned fields from the shared memory structure.
393 *
394 * Return: SBI_SUCCESS on success, negative SBI error code on failure.
395 */
sbi_mpxy_get_channel_count(uint32_t * channel_count)396 int sbi_mpxy_get_channel_count(uint32_t *channel_count)
397 {
398 struct mpxy_core_local *mpxy = NULL;
399 struct sbi_mpxy_channel_ids_data *data = NULL;
400 uint32_t remaining = 0;
401 uint32_t returned = 0;
402 struct sbiret sbiret = {};
403 uint32_t exceptions = 0;
404 int ret = SBI_ERR_FAILURE;
405
406 if (!channel_count)
407 return SBI_ERR_INVALID_PARAM;
408
409 exceptions = thread_mask_exceptions(THREAD_EXCP_ALL);
410
411 mpxy = mpxy_get_core_local();
412
413 if (!mpxy->shmem_active) {
414 ret = SBI_ERR_NO_SHMEM;
415 goto out;
416 }
417
418 data = mpxy->shmem;
419
420 sbiret = sbi_ecall(SBI_EXT_MPXY, SBI_EXT_MPXY_GET_CHANNEL_IDS, 0, 0, 0,
421 0, 0, 0);
422 if (sbiret.error) {
423 EMSG("MPXY SBI call failed: error=%ld", sbiret.error);
424 goto out;
425 }
426
427 remaining = data->remaining;
428 returned = data->returned;
429 *channel_count = remaining + returned;
430
431 ret = sbiret.error;
432
433 out:
434 thread_unmask_exceptions(exceptions);
435 return ret;
436 }
437
438 /**
439 * sbi_mpxy_get_notification_events - Retrieve notification events from an
440 * MPXY channel
441 * @channel_id: ID of the channel
442 * @notif_data: Pointer to buffer to store notification data
443 * @events_data_len: Pointer to store length of events data in bytes
444 *
445 * Makes an SBI call to fetch notification events from the specified channel
446 * and copies them from shared memory into the provided buffer.
447 *
448 * Return: SBI_SUCCESS on success, negative SBI error code on failure.
449 */
450 int
sbi_mpxy_get_notification_events(uint32_t channel_id,struct sbi_mpxy_notification_data * notif_data,unsigned long * events_data_len)451 sbi_mpxy_get_notification_events(uint32_t channel_id,
452 struct sbi_mpxy_notification_data *notif_data,
453 unsigned long *events_data_len)
454 {
455 struct mpxy_core_local *mpxy = NULL;
456 struct sbiret sbiret = {};
457 uint32_t exceptions = 0;
458 int ret = SBI_ERR_FAILURE;
459
460 if (!notif_data || !events_data_len)
461 return SBI_ERR_INVALID_PARAM;
462
463 exceptions = thread_mask_exceptions(THREAD_EXCP_ALL);
464
465 mpxy = mpxy_get_core_local();
466
467 if (!mpxy->shmem_active) {
468 ret = SBI_ERR_NO_SHMEM;
469 goto out;
470 }
471
472 sbiret = sbi_ecall(SBI_EXT_MPXY, SBI_EXT_MPXY_GET_NOTIFICATION_EVENTS,
473 channel_id, 0, 0, 0, 0, 0);
474 if (sbiret.error) {
475 EMSG("MPXY SBI call failed: error=%ld", sbiret.error);
476 ret = sbiret.error;
477 goto out;
478 }
479
480 memcpy(notif_data, mpxy->shmem, sbiret.value + 16);
481 *events_data_len = sbiret.value;
482
483 ret = sbiret.error;
484
485 out:
486 thread_unmask_exceptions(exceptions);
487 return ret;
488 }
489