1*4882a593Smuzhiyun /* SPDX-License-Identifier: GPL-2.0 */
2*4882a593Smuzhiyun #ifndef _LINUX_MIN_HEAP_H
3*4882a593Smuzhiyun #define _LINUX_MIN_HEAP_H
4*4882a593Smuzhiyun
5*4882a593Smuzhiyun #include <linux/bug.h>
6*4882a593Smuzhiyun #include <linux/string.h>
7*4882a593Smuzhiyun #include <linux/types.h>
8*4882a593Smuzhiyun
/**
 * struct min_heap - Array-backed binary min-heap.
 * @data: Base of the backing array. Element i is stored at byte offset
 *        i * elem_size, where elem_size comes from the callbacks passed
 *        to the heap operations.
 * @nr: How many elements the heap currently holds.
 * @size: Capacity of the backing array, counted in elements.
 */
struct min_heap {
	void *data;	/* contiguous element storage */
	int nr;		/* live element count */
	int size;	/* storage capacity */
};
20*4882a593Smuzhiyun
/**
 * struct min_heap_callbacks - Data/functions to customise the min_heap.
 * @elem_size: The size of each element in bytes.
 * @less: Partial order function for this heap. Returns true when the
 *        element at @lhs must be placed nearer the root than the element
 *        at @rhs (i.e. @lhs is "less than" @rhs).
 * @swp: Swap elements function. Exchanges the contents of the two
 *       elements pointed to by @lhs and @rhs.
 */
struct min_heap_callbacks {
	int elem_size;
	bool (*less)(const void *lhs, const void *rhs);
	void (*swp)(void *lhs, void *rhs);
};
32*4882a593Smuzhiyun
33*4882a593Smuzhiyun /* Sift the element at pos down the heap. */
34*4882a593Smuzhiyun static __always_inline
min_heapify(struct min_heap * heap,int pos,const struct min_heap_callbacks * func)35*4882a593Smuzhiyun void min_heapify(struct min_heap *heap, int pos,
36*4882a593Smuzhiyun const struct min_heap_callbacks *func)
37*4882a593Smuzhiyun {
38*4882a593Smuzhiyun void *left, *right, *parent, *smallest;
39*4882a593Smuzhiyun void *data = heap->data;
40*4882a593Smuzhiyun
41*4882a593Smuzhiyun for (;;) {
42*4882a593Smuzhiyun if (pos * 2 + 1 >= heap->nr)
43*4882a593Smuzhiyun break;
44*4882a593Smuzhiyun
45*4882a593Smuzhiyun left = data + ((pos * 2 + 1) * func->elem_size);
46*4882a593Smuzhiyun parent = data + (pos * func->elem_size);
47*4882a593Smuzhiyun smallest = parent;
48*4882a593Smuzhiyun if (func->less(left, smallest))
49*4882a593Smuzhiyun smallest = left;
50*4882a593Smuzhiyun
51*4882a593Smuzhiyun if (pos * 2 + 2 < heap->nr) {
52*4882a593Smuzhiyun right = data + ((pos * 2 + 2) * func->elem_size);
53*4882a593Smuzhiyun if (func->less(right, smallest))
54*4882a593Smuzhiyun smallest = right;
55*4882a593Smuzhiyun }
56*4882a593Smuzhiyun if (smallest == parent)
57*4882a593Smuzhiyun break;
58*4882a593Smuzhiyun func->swp(smallest, parent);
59*4882a593Smuzhiyun if (smallest == left)
60*4882a593Smuzhiyun pos = (pos * 2) + 1;
61*4882a593Smuzhiyun else
62*4882a593Smuzhiyun pos = (pos * 2) + 2;
63*4882a593Smuzhiyun }
64*4882a593Smuzhiyun }
65*4882a593Smuzhiyun
66*4882a593Smuzhiyun /* Floyd's approach to heapification that is O(nr). */
67*4882a593Smuzhiyun static __always_inline
min_heapify_all(struct min_heap * heap,const struct min_heap_callbacks * func)68*4882a593Smuzhiyun void min_heapify_all(struct min_heap *heap,
69*4882a593Smuzhiyun const struct min_heap_callbacks *func)
70*4882a593Smuzhiyun {
71*4882a593Smuzhiyun int i;
72*4882a593Smuzhiyun
73*4882a593Smuzhiyun for (i = heap->nr / 2; i >= 0; i--)
74*4882a593Smuzhiyun min_heapify(heap, i, func);
75*4882a593Smuzhiyun }
76*4882a593Smuzhiyun
77*4882a593Smuzhiyun /* Remove minimum element from the heap, O(log2(nr)). */
78*4882a593Smuzhiyun static __always_inline
min_heap_pop(struct min_heap * heap,const struct min_heap_callbacks * func)79*4882a593Smuzhiyun void min_heap_pop(struct min_heap *heap,
80*4882a593Smuzhiyun const struct min_heap_callbacks *func)
81*4882a593Smuzhiyun {
82*4882a593Smuzhiyun void *data = heap->data;
83*4882a593Smuzhiyun
84*4882a593Smuzhiyun if (WARN_ONCE(heap->nr <= 0, "Popping an empty heap"))
85*4882a593Smuzhiyun return;
86*4882a593Smuzhiyun
87*4882a593Smuzhiyun /* Place last element at the root (position 0) and then sift down. */
88*4882a593Smuzhiyun heap->nr--;
89*4882a593Smuzhiyun memcpy(data, data + (heap->nr * func->elem_size), func->elem_size);
90*4882a593Smuzhiyun min_heapify(heap, 0, func);
91*4882a593Smuzhiyun }
92*4882a593Smuzhiyun
93*4882a593Smuzhiyun /*
94*4882a593Smuzhiyun * Remove the minimum element and then push the given element. The
95*4882a593Smuzhiyun * implementation performs 1 sift (O(log2(nr))) and is therefore more
96*4882a593Smuzhiyun * efficient than a pop followed by a push that does 2.
97*4882a593Smuzhiyun */
98*4882a593Smuzhiyun static __always_inline
min_heap_pop_push(struct min_heap * heap,const void * element,const struct min_heap_callbacks * func)99*4882a593Smuzhiyun void min_heap_pop_push(struct min_heap *heap,
100*4882a593Smuzhiyun const void *element,
101*4882a593Smuzhiyun const struct min_heap_callbacks *func)
102*4882a593Smuzhiyun {
103*4882a593Smuzhiyun memcpy(heap->data, element, func->elem_size);
104*4882a593Smuzhiyun min_heapify(heap, 0, func);
105*4882a593Smuzhiyun }
106*4882a593Smuzhiyun
107*4882a593Smuzhiyun /* Push an element on to the heap, O(log2(nr)). */
108*4882a593Smuzhiyun static __always_inline
min_heap_push(struct min_heap * heap,const void * element,const struct min_heap_callbacks * func)109*4882a593Smuzhiyun void min_heap_push(struct min_heap *heap, const void *element,
110*4882a593Smuzhiyun const struct min_heap_callbacks *func)
111*4882a593Smuzhiyun {
112*4882a593Smuzhiyun void *data = heap->data;
113*4882a593Smuzhiyun void *child, *parent;
114*4882a593Smuzhiyun int pos;
115*4882a593Smuzhiyun
116*4882a593Smuzhiyun if (WARN_ONCE(heap->nr >= heap->size, "Pushing on a full heap"))
117*4882a593Smuzhiyun return;
118*4882a593Smuzhiyun
119*4882a593Smuzhiyun /* Place at the end of data. */
120*4882a593Smuzhiyun pos = heap->nr;
121*4882a593Smuzhiyun memcpy(data + (pos * func->elem_size), element, func->elem_size);
122*4882a593Smuzhiyun heap->nr++;
123*4882a593Smuzhiyun
124*4882a593Smuzhiyun /* Sift child at pos up. */
125*4882a593Smuzhiyun for (; pos > 0; pos = (pos - 1) / 2) {
126*4882a593Smuzhiyun child = data + (pos * func->elem_size);
127*4882a593Smuzhiyun parent = data + ((pos - 1) / 2) * func->elem_size;
128*4882a593Smuzhiyun if (func->less(parent, child))
129*4882a593Smuzhiyun break;
130*4882a593Smuzhiyun func->swp(parent, child);
131*4882a593Smuzhiyun }
132*4882a593Smuzhiyun }
133*4882a593Smuzhiyun
134*4882a593Smuzhiyun #endif /* _LINUX_MIN_HEAP_H */
135