#pragma once

#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <iterator>
#include <tuple>
#include <type_traits>
#include <utility>
#include <vector>

namespace kdbush {

/// Coordinate accessor: `nth<I, T>::get(p)` returns component `I` of point `p`.
/// The primary template forwards to `std::get<I>`, so it works for any
/// tuple-like `T`; other point types can provide their own specialization.
template <std::uint8_t I, typename T>
struct nth {
    inline static typename std::tuple_element<I, T>::type get(const T &t) {
        return std::get<I>(t);
    }
};

/// Static 2-D spatial index over a set of points.
///
/// Points are copied into an internal array that is partially sorted into an
/// implicit kd-tree (alternating x / y splits at the median). Queries report
/// each matching point by calling a visitor with the point's id, which is its
/// zero-based position in the sequence originally passed to the constructor
/// or to `fill`.
///
/// @tparam TPoint point type readable through `nth<0, TPoint>` / `nth<1, TPoint>`.
/// @tparam TIndex integer type used for ids and array positions.
template <typename TPoint, typename TIndex = std::size_t>
class KDBush {

public:
    /// Coordinate type, deduced from what `nth<0, TPoint>::get` returns.
    using TNumber = decltype(nth<0, TPoint>::get(std::declval<TPoint>()));
    static_assert(
        std::is_same<TNumber, decltype(nth<1, TPoint>::get(std::declval<TPoint>()))>::value,
        "point component types must be identical");

    static const std::uint8_t defaultNodeSize = 64;

    /// Creates an empty index; populate it later with `fill`.
    /// @param nodeSize_ largest `right - left` span that is scanned linearly
    ///        instead of being split further. A value of 0 is treated as 1.
    KDBush(const std::uint8_t nodeSize_ = defaultNodeSize) : nodeSize(clampNodeSize(nodeSize_)) {
    }

    /// Builds the index from a vector of points.
    KDBush(const std::vector<TPoint> &points_, const std::uint8_t nodeSize_ = defaultNodeSize)
        : KDBush(std::begin(points_), std::end(points_), nodeSize_) {
    }

    /// Builds the index from the iterator range `[points_begin, points_end)`.
    template <typename TPointIter>
    KDBush(const TPointIter &points_begin,
           const TPointIter &points_end,
           const std::uint8_t nodeSize_ = defaultNodeSize)
        : nodeSize(clampNodeSize(nodeSize_)) {
        fill(points_begin, points_end);
    }

    /// Populates an empty index from a vector of points.
    void fill(const std::vector<TPoint> &points_) {
        fill(std::begin(points_), std::end(points_));
    }

    /// Populates an empty index from `[points_begin, points_end)`.
    /// The index must not already contain points (checked with `assert`).
    template <typename TPointIter>
    void fill(const TPointIter &points_begin, const TPointIter &points_end) {
        assert(points.empty());
        const TIndex size = static_cast<TIndex>(std::distance(points_begin, points_end));

        // Nothing to index. Returning here also avoids evaluating `size - 1`
        // below, which wraps around when TIndex is unsigned.
        if (size == 0) return;

        points.reserve(size);
        ids.reserve(size);

        TIndex i = 0;
        for (auto p = points_begin; p != points_end; ++p) {
            points.emplace_back(nth<0, TPoint>::get(*p), nth<1, TPoint>::get(*p));
            ids.push_back(i++);
        }

        sortKD(0, size - 1, 0);
    }

    /// Calls `visitor(id)` for every point with
    /// `minX <= x <= maxX` and `minY <= y <= maxY` (bounds inclusive).
    template <typename TVisitor>
    void range(const TNumber minX,
               const TNumber minY,
               const TNumber maxX,
               const TNumber maxY,
               const TVisitor &visitor) const {
        // An empty index has no valid `size - 1` upper bound to recurse on.
        if (ids.empty()) return;
        range(minX, minY, maxX, maxY, visitor, 0, static_cast<TIndex>(ids.size() - 1), 0);
    }

    /// Calls `visitor(id)` for every point whose squared distance to
    /// `(qx, qy)` is at most `r * r`.
    template <typename TVisitor>
    void within(const TNumber qx, const TNumber qy, const TNumber r, const TVisitor &visitor) const {
        // An empty index has no valid `size - 1` upper bound to recurse on.
        if (ids.empty()) return;
        within(qx, qy, r, visitor, 0, static_cast<TIndex>(ids.size() - 1), 0);
    }

private:
    std::vector<TIndex> ids;                         // ids[i] is the input position of points[i]
    std::vector<std::pair<TNumber, TNumber>> points; // (x, y) pairs, reordered by sortKD
    std::uint8_t nodeSize;                           // leaf threshold, always >= 1

    /// A leaf size of 0 would make the recursion compute `m - 1` with
    /// `m == left`, which wraps for unsigned TIndex; use 1 in that case.
    static std::uint8_t clampNodeSize(const std::uint8_t requested) {
        return requested == 0 ? static_cast<std::uint8_t>(1) : requested;
    }

    /// Recursive box query over positions `[left, right]`, split on `axis`
    /// (0 = x, 1 = y).
    template <typename TVisitor>
    void range(const TNumber minX,
               const TNumber minY,
               const TNumber maxX,
               const TNumber maxY,
               const TVisitor &visitor,
               const TIndex left,
               const TIndex right,
               const std::uint8_t axis) const {

        // Small span: test every point directly.
        if (right - left <= nodeSize) {
            for (auto i = left; i <= right; i++) {
                const TNumber x = std::get<0>(points[i]);
                const TNumber y = std::get<1>(points[i]);
                if (x >= minX && x <= maxX && y >= minY && y <= maxY) visitor(ids[i]);
            }
            return;
        }

        // Otherwise test the median point, then descend into whichever
        // halves the query box can overlap.
        const TIndex m = (left + right) >> 1;
        const TNumber x = std::get<0>(points[m]);
        const TNumber y = std::get<1>(points[m]);
        const std::uint8_t nextAxis = static_cast<std::uint8_t>((axis + 1) % 2);

        if (x >= minX && x <= maxX && y >= minY && y <= maxY) visitor(ids[m]);

        if (axis == 0 ? minX <= x : minY <= y)
            range(minX, minY, maxX, maxY, visitor, left, m - 1, nextAxis);

        if (axis == 0 ? maxX >= x : maxY >= y)
            range(minX, minY, maxX, maxY, visitor, m + 1, right, nextAxis);
    }

    /// Recursive radius query over positions `[left, right]`, split on `axis`
    /// (0 = x, 1 = y).
    template <typename TVisitor>
    void within(const TNumber qx,
                const TNumber qy,
                const TNumber r,
                const TVisitor &visitor,
                const TIndex left,
                const TIndex right,
                const std::uint8_t axis) const {

        const TNumber r2 = r * r;

        // Small span: test every point directly.
        if (right - left <= nodeSize) {
            for (auto i = left; i <= right; i++) {
                const TNumber x = std::get<0>(points[i]);
                const TNumber y = std::get<1>(points[i]);
                if (sqDist(x, y, qx, qy) <= r2) visitor(ids[i]);
            }
            return;
        }

        // Otherwise test the median point, then descend into whichever
        // halves the query circle's bounding interval can overlap.
        const TIndex m = (left + right) >> 1;
        const TNumber x = std::get<0>(points[m]);
        const TNumber y = std::get<1>(points[m]);
        const std::uint8_t nextAxis = static_cast<std::uint8_t>((axis + 1) % 2);

        if (sqDist(x, y, qx, qy) <= r2) visitor(ids[m]);

        if (axis == 0 ? qx - r <= x : qy - r <= y)
            within(qx, qy, r, visitor, left, m - 1, nextAxis);

        if (axis == 0 ? qx + r >= x : qy + r >= y)
            within(qx, qy, r, visitor, m + 1, right, nextAxis);
    }

    /// Recursively arranges positions `[left, right]`: the median along
    /// `axis` is moved to the middle position, then both halves are arranged
    /// along the other axis. Spans of at most `nodeSize` are left unsorted.
    void sortKD(const TIndex left, const TIndex right, const std::uint8_t axis) {
        if (right - left <= nodeSize) return;
        const TIndex m = (left + right) >> 1;
        if (axis == 0) {
            select<0>(m, left, right);
        } else {
            select<1>(m, left, right);
        }
        const std::uint8_t nextAxis = static_cast<std::uint8_t>((axis + 1) % 2);
        sortKD(left, m - 1, nextAxis);
        sortKD(m + 1, right, nextAxis);
    }

    /// Selection on coordinate `I`: rearranges positions `[left, right]` so
    /// that position `k` holds the element that would be there if the span
    /// were sorted, with no larger element before it and no smaller element
    /// after it.
    template <std::uint8_t I>
    void select(const TIndex k, TIndex left, TIndex right) {

        while (right > left) {
            // For large spans, first run the selection on a narrower window
            // around k so the partition below starts from a better pivot.
            if (right - left > 600) {
                const double n = static_cast<double>(right - left + 1);
                const double m = static_cast<double>(k - left + 1);
                const double z = std::log(n);
                const double s = 0.5 * std::exp(2 * z / 3);
                const double r =
                    k - m * s / n + 0.5 * std::sqrt(z * s * (1 - s / n)) * (2 * m < n ? -1 : 1);
                select<I>(k,
                          std::max(left, static_cast<TIndex>(r)),
                          std::min(right, static_cast<TIndex>(r + s)));
            }

            // Partition [left, right] around the value currently at k.
            const TNumber t = std::get<I>(points[k]);
            TIndex i = left;
            TIndex j = right;

            swapItem(left, k);
            if (std::get<I>(points[right]) > t) swapItem(left, right);

            while (i < j) {
                swapItem(i++, j--);
                while (std::get<I>(points[i]) < t) i++;
                while (std::get<I>(points[j]) > t) j--;
            }

            if (std::get<I>(points[left]) == t)
                swapItem(left, j);
            else {
                swapItem(++j, right);
            }

            // Keep only the side that still contains position k.
            if (j <= k) left = j + 1;
            if (k <= j) right = j - 1;
        }
    }

    /// Swaps positions `i` and `j` in both arrays so ids stay paired with
    /// their coordinates.
    void swapItem(const TIndex i, const TIndex j) {
        std::iter_swap(ids.begin() + i, ids.begin() + j);
        std::iter_swap(points.begin() + i, points.begin() + j);
    }

    /// Squared Euclidean distance between `(ax, ay)` and `(bx, by)`.
    /// Uses plain multiplication rather than `std::pow`, which would
    /// round-trip integer coordinates through `double`.
    static TNumber sqDist(const TNumber ax, const TNumber ay, const TNumber bx, const TNumber by) {
        const auto dx = ax - bx;
        const auto dy = ay - by;
        return static_cast<TNumber>(dx * dx + dy * dy);
    }
};

} // namespace kdbush