/*
 * C++ KD Tree with Vector3 and a condition flag
 *
 * Source: GitHub gist juaxix/458b9a61654803017bcbe249a51d4be7
 * (last active March 9, 2022 19:38).
 *
 * GitHub warns that this file may contain hidden or bidirectional Unicode
 * text that could be interpreted or compiled differently than it appears.
 * Review it in an editor that reveals hidden Unicode characters.
 */
#include <math.h>

#include <algorithm>
#include <cmath>
#include <cstdlib>
#include <iostream>
#include <limits>
#include <memory>
#include <numeric>
#include <string>
#include <utility>
#include <vector>

using namespace std;
// Left-pads `str` in place with `paddingChar` until it is `num` characters
// long. Strings that are already `num` characters or longer are left untouched.
void padTo(std::string &str, const size_t num, const char paddingChar = ' ')
{
    if (str.size() >= num) {
        return;
    }
    const size_t missing = num - str.size();
    str.insert(str.begin(), missing, paddingChar);
}
// Minimal 3-component float vector used as the point type of the KD tree.
// All accessors are const so the type can be used through const references
// (e.g. the const parameters of Distance).
struct Vector3 {
public:
    float X;
    float Y;
    float Z;

    // Euclidean distance between v1 and v2, computed in float precision.
    static float Distance(const Vector3 &v1, const Vector3 &v2)
    {
        const float dx = v1.X - v2.X;
        const float dy = v1.Y - v2.Y;
        const float dz = v1.Z - v2.Z;
        return std::sqrt(dx * dx + dy * dy + dz * dz);
    }

    // Component by axis index: 1 -> Y, 2 -> Z, any other value -> X.
    // Out-of-range indices deliberately fall back to X (existing behaviour).
    float operator[](int i) const
    {
        switch (i) {
        case 1:
            return Y;
        case 2:
            return Z;
        default:
            return X;
        }
    }

    // Component-wise difference (*this - b).
    Vector3 operator-(const Vector3 &b) const
    {
        return Vector3{X - b.X, Y - b.Y, Z - b.Z};
    }

    // Exact component-wise equality (no epsilon tolerance).
    bool operator==(const Vector3 &b) const
    {
        return X == b.X && Y == b.Y && Z == b.Z;
    }

    // Squared length; cheaper than the length when only comparing distances.
    float sqrMagnitude() const
    {
        return X * X + Y * Y + Z * Z;
    }

    // The origin (0, 0, 0).
    static Vector3 zero()
    {
        return Vector3{0.0f, 0.0f, 0.0f};
    }

    // Human-readable "(X,Y,Z)" using std::to_string formatting.
    std::string ToString() const
    {
        return "(" + std::to_string(X) + "," + std::to_string(Y) + "," + std::to_string(Z) + ")";
    }
};
struct TreeNode{ | |
public: | |
Vector3 position; | |
bool healed; | |
}; | |
class KDTree | |
{ | |
private: | |
static int callcounter; | |
public: | |
static TreeNode lastPivot; | |
vector<shared_ptr<KDTree>> lr; | |
TreeNode pivot; | |
int pivotIndex; | |
int axis; | |
// Change this value to 3 if you need three-dimensional X,Y,Z points. The search will be quicker in two dimensions. | |
static const int numDims = 3; | |
KDTree() : lr() | |
{ | |
lr.reserve(2); | |
} | |
// Make a new tree from a list of points. | |
static shared_ptr<KDTree> MakeFromPoints(vector<TreeNode>& points) { | |
vector<int> indices = Iota(points.size()); | |
return MakeFromPointsInner(0, 0, points.size() - 1, points, indices); | |
} | |
// Recursively build a tree by separating points at plane boundaries. | |
static shared_ptr<KDTree> MakeFromPointsInner( | |
int depth, | |
int stIndex, int enIndex, | |
vector<TreeNode>& points, | |
vector<int>& inds | |
) | |
{ | |
shared_ptr<KDTree> root = make_shared<KDTree>(); | |
root->axis = depth % KDTree::numDims; | |
int splitPoint = FindPivotIndex(points, inds, stIndex, enIndex, root->axis); | |
root->pivotIndex = inds[splitPoint]; | |
root->pivot = points[root->pivotIndex]; | |
int leftEndIndex = splitPoint - 1; | |
if (leftEndIndex >= stIndex) { | |
root->lr[0] = MakeFromPointsInner(depth + 1, stIndex, leftEndIndex, points, inds); | |
} | |
int rightStartIndex = splitPoint + 1; | |
if (rightStartIndex <= enIndex) { | |
root->lr[1] = MakeFromPointsInner(depth + 1, rightStartIndex, enIndex, points, inds); | |
} | |
return root; | |
} | |
static void SwapElements(vector<int> &arr, int a, int b) { | |
int temp = arr[a]; | |
arr[a] = arr[b]; | |
arr[b] = temp; | |
} | |
// Simple "median of three" heuristic to find a reasonable splitting plane. | |
static int FindSplitPoint(vector<TreeNode>& points, const vector<int>& inds, int stIndex, int enIndex, int axis) { | |
float a = points[inds[stIndex]].position[axis]; | |
float b = points[inds[enIndex]].position[axis]; | |
int midIndex = (stIndex + enIndex) / 2; | |
float m = points[inds[midIndex]].position[axis]; | |
if (a > b) { | |
if (m > a) { | |
return stIndex; | |
} | |
if (b > m) { | |
return enIndex; | |
} | |
return midIndex; | |
} else { | |
if (a > m) { | |
return stIndex; | |
} | |
if (m > b) { | |
return enIndex; | |
} | |
return midIndex; | |
} | |
} | |
// Find a new pivot index from the range by splitting the points that fall either side | |
// of its plane. | |
static int FindPivotIndex(vector<TreeNode>& points, vector<int>& inds, int stIndex, int enIndex, int axis) { | |
int splitPoint = FindSplitPoint(points, inds, stIndex, enIndex, axis); | |
// int splitPoint = Random.Range(stIndex, enIndex); | |
Vector3 pivot = points[inds[splitPoint]].position; | |
SwapElements(inds, stIndex, splitPoint); | |
int currPt = stIndex + 1; | |
int endPt = enIndex; | |
while (currPt <= endPt) { | |
Vector3 curr = points[inds[currPt]].position; | |
if ((curr[axis] > pivot[axis])) { | |
SwapElements(inds, currPt, endPt); | |
endPt--; | |
} else { | |
SwapElements(inds, currPt - 1, currPt); | |
currPt++; | |
} | |
} | |
return currPt - 1; | |
} | |
static vector<int> Iota(int num) { | |
vector<int> result(num); | |
for (int i = 0; i < num; i++) { | |
result[i] = i; | |
} | |
return result; | |
} | |
// Find the nearest point in the set to the supplied point. | |
int FindNearest(Vector3 pt) { | |
float bestSqDist = std::numeric_limits<float>::max(); | |
int bestIndex = -1; | |
Search(pt, bestSqDist, bestIndex); | |
return bestIndex; | |
} | |
// Recursively search the tree. | |
void Search(Vector3& pt, float &bestSqSoFar, int &bestIndex) | |
{ | |
float mySqDist = std::numeric_limits<float>::max(); | |
if (!pivot.healed) { | |
mySqDist = (pivot.position - pt).sqrMagnitude(); | |
} | |
if (mySqDist < bestSqSoFar) { | |
bestSqSoFar = mySqDist; | |
bestIndex = pivotIndex; | |
if (lastPivot.position==Vector3::zero() || | |
Vector3::Distance( | |
pt,pivot.position | |
)< | |
Vector3::Distance( | |
pt,lastPivot.position | |
) | |
){ | |
callcounter++; | |
lastPivot = TreeNode(); | |
lastPivot.position=pivot.position; | |
lastPivot.healed=pivot.healed; | |
//Debug.Log(callcounter.ToString() + ": New position " + bestIndex.ToString()+ "; " + lastPivot.position.ToString()); | |
} | |
} | |
float planeDist = pt[axis] - pivot.position[axis]; //DistFromSplitPlane(pt, pivot, axis); | |
int selector = planeDist <= 0 ? 0 : 1; | |
if (lr[selector] != nullptr) { | |
lr[selector]->Search(pt, bestSqSoFar, bestIndex); | |
} | |
selector = (selector + 1) % 2; | |
float sqPlaneDist = planeDist * planeDist; | |
if ((lr[selector] != nullptr) && (bestSqSoFar > sqPlaneDist)) { | |
lr[selector]->Search(pt, bestSqSoFar, bestIndex); | |
} | |
} | |
// Get a point's distance from an axis-aligned plane. | |
float DistFromSplitPlane(Vector3 pt, Vector3 planePt, int axis) | |
{ | |
return pt[axis] - planePt[axis]; | |
} | |
// Simple output of tree structure - mainly useful for getting a rough | |
// idea of how deep the tree is (and therefore how well the splitting | |
// heuristic is performing). | |
string Dump(int level) | |
{ | |
string result = to_string(pivotIndex); | |
padTo(result,level); | |
result += "\n"; | |
if (lr[0] != nullptr) { | |
result += lr[0]->Dump(level + 2); | |
} | |
if (lr[1] != nullptr) { | |
result += lr[1]->Dump(level + 2); | |
} | |
return result; | |
} | |
}; | |
int KDTree::callcounter = 0; | |
TreeNode KDTree::lastPivot=TreeNode(); | |
int main(int argc,char **argv) | |
{ | |
const int total_trees = 600; | |
vector<TreeNode> vectors(total_trees); | |
if(argc<4) { | |
cout << "Please use X,Y,Z coordinates to find in the tree of trees"<<endl; | |
cout << argv[0] << " X Y Z"<<endl; | |
} | |
Vector3 coordinates; | |
coordinates.X = atof(argv[1]); | |
coordinates.Y = atof(argv[2]); | |
coordinates.Z = atof(argv[3]); | |
for (int i=0; i<total_trees; i++) | |
{ | |
vectors[i] = TreeNode(); | |
vectors[i].position.X = i*6; | |
vectors[i].position.Y = i*2; | |
vectors[i].position.Z = i; | |
vectors[i].healed = false; | |
} | |
shared_ptr<KDTree> TreeInstances = KDTree::MakeFromPoints(vectors); | |
cout << "Trees created: " << total_trees << endl; | |
cout << "---------------------"<<endl; | |
cout << "Searching for the index of the tree nearest to "<< | |
coordinates.ToString()<<endl; | |
int index = TreeInstances->FindNearest(coordinates); | |
cout << "Nearest tree at: " << index << vectors[index].position.ToString(); | |
TreeInstances.reset(); | |
vectors.clear(); | |
return 0; | |
} |
// GitHub page footer (not part of the program): "Sign up for free to join
// this conversation on GitHub. Already have an account? Sign in to comment."