Created
February 8, 2017 10:17
-
-
Save blackball/1f22896d7cc4d054702db8bbbdcb2512 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include "paraknn.h"

#include <immintrin.h>  // AVX2 / FMA intrinsics
#include <cmath>        // std::sqrt (float overload)
#include <cstddef>      // size_t
#include <utility>      // std::swap
#include <vector>

#define SIFTDIM 128
// need AVX2 + FMA, compiler options: -march=native -O3
static inline float | |
_Distance2_SIFT128(const float *va, const float *vb) { | |
__m256 zmm= _mm256_set1_ps(0.f), amm, bmm; | |
#define AVX_ONCE(start) \ | |
amm = _mm256_loadu_ps(va + start); \ | |
bmm = _mm256_loadu_ps(vb + start); \ | |
amm = _mm256_sub_ps(amm, bmm); \ | |
zmm = _mm256_fmadd_ps(amm, amm, zmm) | |
AVX_ONCE(0); | |
AVX_ONCE(8); | |
AVX_ONCE(16); | |
AVX_ONCE(24); | |
AVX_ONCE(32); | |
AVX_ONCE(40); | |
AVX_ONCE(48); | |
AVX_ONCE(56); | |
AVX_ONCE(64); | |
AVX_ONCE(72); | |
AVX_ONCE(80); | |
AVX_ONCE(88); | |
AVX_ONCE(96); | |
AVX_ONCE(104); | |
AVX_ONCE(112); | |
AVX_ONCE(120); | |
#undef AVX_ONCE | |
// add up all | |
const __m256 t1 = _mm256_hadd_ps(zmm, zmm); | |
const __m256 t2 = _mm256_hadd_ps(t1, t1); | |
const __m128 t3 = _mm256_extractf128_ps(t2, 1); | |
const __m128 t4 = _mm_add_ss(_mm256_castps256_ps128(t2), t3); | |
return _mm_cvtss_f32(t4); | |
} | |
static inline void | |
_Get_2NN(const float *query, const float *trains, const size_t trainnum, int indices[2], float distances[2]) { | |
int id0 = 0, id1 = 1; | |
float d0, d1; // keep smallest in d0, second smallest in d1 | |
d0 = _Distance2_SIFT128(query, trains); | |
d1 = _Distance2_SIFT128(query, trains + SIFTDIM); | |
if (d0 > d1) { | |
float tmp = d0; | |
d0 = d1; | |
d1 = tmp; | |
id0 = 1; | |
id1 = 0; | |
} | |
for (size_t i = 2; i < trainnum; ++i) { | |
float dist = _Distance2_SIFT128(query, trains + i * (SIFTDIM)); | |
if (dist < d0) { | |
d1 = d0; id1 = id0; | |
d0 = dist; id0 = i; | |
} | |
else if (dist < d1) { | |
d1 = dist; | |
id1 = i; | |
} | |
} | |
indices[0] = id0; | |
indices[1] = id1; | |
distances[0] = sqrt(d0); | |
distances[1] = sqrt(d1); | |
} | |
// shared data for parallelelizing | |
struct _MTData { | |
_MTData(const float *querydescriptors, | |
const size_t querynum, | |
const float *traindescriptors, | |
const size_t trainnum, | |
std::vector<DMatch2NN> &matches) | |
: _querydescriptors(querydescriptors), | |
_querynum(querynum), | |
_traindescriptors(traindescriptors), | |
_trainnum(trainnum), | |
_matches(matches) | |
{ | |
;// empty | |
} | |
const float *_querydescriptors; | |
const size_t _querynum; | |
const float *_traindescriptors; | |
const size_t _trainnum; | |
std::vector<DMatch2NN> &_matches; | |
}; | |
static void | |
_MTCallback(void *data, const int index, const int tid) { | |
_MTData *mdata = (_MTData *)data; | |
DMatch2NN &match = mdata->matches[index]; | |
_Get_2NN(mdata->_querydescriptors + index * (SIFTDIM), | |
mdata->_traindescriptors, | |
mdata->_trainnum, | |
match.trainIdx, | |
match.distance); | |
match.queryIdx[0] = match.queryIdx[1] = index; | |
} | |
int | |
BruteForce_2NN_SIFT128(const float *querydescriptors, | |
const size_t querynum, | |
const float *traindescriptors, | |
const size_t trainnum, | |
std::vector<DMatch2NN> &matches, | |
const size_t numthreads = 1) { | |
if (querynum < 1 || trainnum < 2 || numthreads < 1 || numthreads > 32) { | |
return -1; | |
} | |
matches.resize(querynum); | |
if (numthreads == 1) { | |
for (size_t i = 0; i < querynum; ++i) { | |
_Get_2NN(querydescriptors + i*(SIFTDIM), traindescriptors, trainnum, matches[i].trainIdx, matches[i].distance); | |
} | |
matches[i].queryIdx[0] = matches[i].queryIdx[1] = i; | |
} | |
else { | |
_MTData mtdata(querydescriptors, querynum, traindescriptors, trainnum, matches); | |
ParallelizedFor(numthreads, querynum, (void *)mtdata, &_MTCallback); | |
} | |
return 0; | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment