1 #pragma once
2 
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <iterator>
#include <tuple>
#include <type_traits>
#include <utility>
#include <vector>
9 
namespace kdbush {

/// Reads the I-th coordinate of a point of type T.
///
/// The primary template forwards to std::get<I>, so it works for tuple-like
/// point types such as std::pair and std::tuple. Being a class template, it can
/// be specialized for point types that are not tuple-like.
template <std::uint8_t I, typename T>
struct nth {
    inline static typename std::tuple_element<I, T>::type get(const T &t) {
        return std::get<I>(t);
    }
};

/// Static 2-D point index stored as an implicit KD-tree in flat arrays.
///
/// `fill` (or a filling constructor) copies the points in and reorders them so
/// that every index range [left, right] is split at its middle element along
/// alternating axes (x first, then y). Queries report each matching point's id,
/// which is its zero-based position in the sequence that was passed to `fill`.
///
/// @tparam TPoint point type; coordinates are read with nth<0, TPoint>::get and
///                nth<1, TPoint>::get, which must return the same type (TNumber).
/// @tparam TIndex integer type used for point ids and array positions.
template <typename TPoint, typename TIndex = std::size_t>
class KDBush {

public:
    using TNumber = decltype(nth<0, TPoint>::get(std::declval<TPoint>()));
    static_assert(
        std::is_same<TNumber, decltype(nth<1, TPoint>::get(std::declval<TPoint>()))>::value,
        "point component types must be identical");

    /// Default leaf threshold: ranges with right - left <= nodeSize are not split
    /// further and are scanned linearly during queries.
    static const std::uint8_t defaultNodeSize = 64;

    /// Creates an empty index; populate it later with `fill`.
    KDBush(const std::uint8_t nodeSize_ = defaultNodeSize) : nodeSize(nodeSize_) {
    }

    /// Builds the index from all points in `points_`.
    KDBush(const std::vector<TPoint> &points_, const std::uint8_t nodeSize_ = defaultNodeSize)
        : KDBush(std::begin(points_), std::end(points_), nodeSize_) {
    }

    /// Builds the index from the points in [points_begin, points_end).
    template <typename TPointIter>
    KDBush(const TPointIter &points_begin,
           const TPointIter &points_end,
           const std::uint8_t nodeSize_ = defaultNodeSize)
        : nodeSize(nodeSize_) {
        fill(points_begin, points_end);
    }

    /// Loads all points of `points_` into an empty index.
    void fill(const std::vector<TPoint> &points_) {
        fill(std::begin(points_), std::end(points_));
    }

    /// Loads the points in [points_begin, points_end) into an empty index and
    /// sorts them into KD-tree order. The index must be empty (asserted).
    /// An empty input range leaves the index empty.
    template <typename TPointIter>
    void fill(const TPointIter &points_begin, const TPointIter &points_end) {
        assert(points.empty());
        const TIndex size = static_cast<TIndex>(std::distance(points_begin, points_end));

        points.reserve(size);
        ids.reserve(size);

        TIndex i = 0;
        for (auto p = points_begin; p != points_end; ++p) {
            points.emplace_back(nth<0, TPoint>::get(*p), nth<1, TPoint>::get(*p));
            ids.push_back(i++);
        }

        // With no points there is no last index; `size - 1` would wrap around.
        if (size == 0) return;
        sortKD(0, size - 1, 0);
    }

    /// Calls `visitor(id)` once for every point with minX <= x <= maxX and
    /// minY <= y <= maxY (bounds inclusive). Visiting order is unspecified.
    /// Does nothing on an empty index.
    template <typename TVisitor>
    void range(const TNumber minX,
               const TNumber minY,
               const TNumber maxX,
               const TNumber maxY,
               const TVisitor &visitor) const {
        if (ids.empty()) return;
        range(minX, minY, maxX, maxY, visitor, 0, static_cast<TIndex>(ids.size() - 1), 0);
    }

    /// Calls `visitor(id)` once for every point whose squared Euclidean distance
    /// to (qx, qy) is <= r * r. Visiting order is unspecified.
    /// Does nothing on an empty index.
    template <typename TVisitor>
    void within(const TNumber qx, const TNumber qy, const TNumber r, const TVisitor &visitor) const {
        if (ids.empty()) return;
        within(qx, qy, r, visitor, 0, static_cast<TIndex>(ids.size() - 1), 0);
    }

private:
    std::vector<TIndex> ids;                          // original position of each stored point
    std::vector<std::pair<TNumber, TNumber>> points;  // (x, y), kept in KD-tree order
    std::uint8_t nodeSize;                            // leaf threshold, see defaultNodeSize

    // Middle position of the non-empty range [left, right]; same value as
    // (left + right) / 2 but without risking overflow of left + right.
    static TIndex middle(const TIndex left, const TIndex right) {
        return static_cast<TIndex>(left + (right - left) / 2);
    }

    // Axis used one tree level below `axis` (0 = x, 1 = y).
    static std::uint8_t nextAxis(const std::uint8_t axis) {
        return static_cast<std::uint8_t>((axis + 1) % 2);
    }

    // Box query over the subtree stored in [left, right], split on `axis`.
    template <typename TVisitor>
    void range(const TNumber minX,
               const TNumber minY,
               const TNumber maxX,
               const TNumber maxY,
               const TVisitor &visitor,
               const TIndex left,
               const TIndex right,
               const std::uint8_t axis) const {

        if (right - left <= nodeSize) {
            for (auto i = left; i <= right; i++) {
                const TNumber x = std::get<0>(points[i]);
                const TNumber y = std::get<1>(points[i]);
                if (x >= minX && x <= maxX && y >= minY && y <= maxY) visitor(ids[i]);
            }
            return;
        }

        const TIndex m = middle(left, right);
        const TNumber x = std::get<0>(points[m]);
        const TNumber y = std::get<1>(points[m]);

        if (x >= minX && x <= maxX && y >= minY && y <= maxY) visitor(ids[m]);

        // When nodeSize == 0, a two-point range has m == left and an empty left
        // half; `m - 1` would wrap around for unsigned TIndex, so skip it.
        if (m > left && (axis == 0 ? minX <= x : minY <= y))
            range(minX, minY, maxX, maxY, visitor, left, m - 1, nextAxis(axis));

        if (axis == 0 ? maxX >= x : maxY >= y)
            range(minX, minY, maxX, maxY, visitor, m + 1, right, nextAxis(axis));
    }

    // Radius query over the subtree stored in [left, right], split on `axis`.
    template <typename TVisitor>
    void within(const TNumber qx,
                const TNumber qy,
                const TNumber r,
                const TVisitor &visitor,
                const TIndex left,
                const TIndex right,
                const std::uint8_t axis) const {

        const TNumber r2 = r * r;

        if (right - left <= nodeSize) {
            for (auto i = left; i <= right; i++) {
                const TNumber x = std::get<0>(points[i]);
                const TNumber y = std::get<1>(points[i]);
                if (sqDist(x, y, qx, qy) <= r2) visitor(ids[i]);
            }
            return;
        }

        const TIndex m = middle(left, right);
        const TNumber x = std::get<0>(points[m]);
        const TNumber y = std::get<1>(points[m]);

        if (sqDist(x, y, qx, qy) <= r2) visitor(ids[m]);

        // Same empty-left-half guard as in range().
        if (m > left && (axis == 0 ? qx - r <= x : qy - r <= y))
            within(qx, qy, r, visitor, left, m - 1, nextAxis(axis));

        if (axis == 0 ? qx + r >= x : qy + r >= y)
            within(qx, qy, r, visitor, m + 1, right, nextAxis(axis));
    }

    // Recursively reorders [left, right] so that the middle element is the
    // median along `axis`, with smaller coordinates before it and larger after.
    void sortKD(const TIndex left, const TIndex right, const std::uint8_t axis) {
        if (right - left <= nodeSize) return;
        const TIndex m = middle(left, right);
        if (axis == 0) {
            select<0>(m, left, right);
        } else {
            select<1>(m, left, right);
        }
        // Guard against `m - 1` wrapping when nodeSize == 0 (see range()).
        if (m > left) sortKD(left, m - 1, nextAxis(axis));
        sortKD(m + 1, right, nextAxis(axis));
    }

    // Floyd-Rivest selection: partially orders [left, right] on coordinate I so
    // that position k holds the element that would be there if fully sorted.
    template <std::uint8_t I>
    void select(const TIndex k, TIndex left, TIndex right) {

        while (right > left) {
            if (right - left > 600) {
                // Select on a sample around k first so the partition below uses
                // a pivot close to the k-th element.
                const double n = right - left + 1;
                const double m = k - left + 1;
                const double z = std::log(n);
                const double s = 0.5 * std::exp(2 * z / 3);
                const double r =
                    k - m * s / n + 0.5 * std::sqrt(z * s * (1 - s / n)) * (2 * m < n ? -1 : 1);
                // Clamp while still in floating point: converting a negative
                // double to an unsigned TIndex is undefined behavior.
                const double sampleLeft = std::max(static_cast<double>(left), r);
                const double sampleRight = std::min(static_cast<double>(right), r + s);
                select<I>(k, static_cast<TIndex>(sampleLeft), static_cast<TIndex>(sampleRight));
            }

            const TNumber t = std::get<I>(points[k]);
            TIndex i = left;
            TIndex j = right;

            swapItem(left, k);
            if (std::get<I>(points[right]) > t) swapItem(left, right);

            while (i < j) {
                swapItem(i++, j--);
                while (std::get<I>(points[i]) < t) i++;
                while (std::get<I>(points[j]) > t) j--;
            }

            if (std::get<I>(points[left]) == t)
                swapItem(left, j);
            else {
                swapItem(++j, right);
            }

            // j is the pivot's final position. If it is k we are done; handling
            // this explicitly avoids `j - 1` underflowing when j == k == 0.
            if (j == k) return;
            if (j < k)
                left = j + 1;
            else
                right = j - 1;
        }
    }

    // Swaps positions i and j in both parallel arrays.
    void swapItem(const TIndex i, const TIndex j) {
        std::iter_swap(ids.begin() + i, ids.begin() + j);
        std::iter_swap(points.begin() + i, points.begin() + j);
    }

    // Squared Euclidean distance. Plain multiplication avoids the cost of
    // std::pow and the round-trip through double for integer coordinates.
    TNumber sqDist(const TNumber ax, const TNumber ay, const TNumber bx, const TNumber by) const {
        const auto dx = ax - bx;
        const auto dy = ay - by;
        return dx * dx + dy * dy;
    }
};

} // namespace kdbush
211