Created March 18, 2014 06:53
-
-
Save cairijun/9614846 to your computer and use it in GitHub Desktop.
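vq with BLAS: a patch to scipy/cluster/_vq_rewrite.pyx that rewrites float_tvq to build the observation-to-code squared-distance matrix with CBLAS calls (cblas_sdot, cblas_sger, cblas_sgemm) instead of nested per-feature loops, then takes the square root of each row's minimum.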
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
diff --git a/scipy/cluster/_vq_rewrite.pyx b/scipy/cluster/_vq_rewrite.pyx | |
index d9fe201..29318af 100644 | |
--- a/scipy/cluster/_vq_rewrite.pyx | |
+++ b/scipy/cluster/_vq_rewrite.pyx | |
@@ -9,6 +9,8 @@ Translated to Cython by David Warde-Farley, October 2009. | |
import numpy as np | |
cimport numpy as np | |
+from cpython.mem cimport PyMem_Malloc, PyMem_Free | |
+from libc.string cimport memset | |
cdef extern from "math.h": | |
float sqrtf(float num) | |
@@ -23,6 +25,31 @@ cdef extern from "numpy/npy_math.h": | |
cdef enum: | |
NPY_INFINITY | |
+ | |
+cdef extern from "cblas.h": | |
+ cdef enum CBLAS_ORDER: | |
+ CblasRowMajor=101 | |
+ CblasColMajor=102 | |
+ cdef enum CBLAS_TRANSPOSE: | |
+ CblasNoTrans=111 | |
+ CblasTrans=112 | |
+ CblasConjTrans=113 | |
+ AtlasConj=114 | |
+ | |
+ void cblas_sgemm(const CBLAS_ORDER Order, const CBLAS_TRANSPOSE TransA,\ | |
+ const CBLAS_TRANSPOSE TransB, const int M, const int N,\ | |
+ const int K, const float alpha, const float *A,\ | |
+ const int lda, const float *B, const int ldb,\ | |
+ const float beta, float *C, const int ldc) | |
+ | |
+ float cblas_sdot(const int N, const float *X, const int incX,\ | |
+ const float *Y, const int incY) | |
+ | |
+ void cblas_sger(const CBLAS_ORDER Order, const int M, const int N,\ | |
+ const float alpha, const float *X, const int incX,\ | |
+ const float *Y, const int incY, float *A, const int lda) | |
+ | |
+ | |
# C types | |
ctypedef np.float64_t float64_t | |
ctypedef np.float32_t float32_t | |
@@ -46,33 +73,55 @@ cdef void float_tvq(float32_t *obs, float32_t *code_book, | |
# Index and pointer to keep track of the current position in | |
# both arrays so that we don't have to always do index * nfeat. | |
cdef int codebook_pos | |
- cdef float32_t *current_obs | |
- | |
+ cdef float32_t *current_obs | |
+ cdef float32_t *current_code | |
+ cdef float32_t *dis_matrix_p | |
+ cdef float32_t *dis_matrix | |
+ cdef float32_t *obs_sqr | |
+ cdef float32_t *codes_sqr | |
+ cdef float32_t *ones_nobs | |
+ cdef float32_t *ones_ncodes | |
+ | |
+ dis_matrix = <float32_t *>PyMem_Malloc(nobs * ncodes * sizeof(float32_t)) | |
+ obs_sqr = <float32_t *>PyMem_Malloc(nobs * sizeof(float32_t)) | |
+ codes_sqr = <float32_t *>PyMem_Malloc(ncodes * sizeof(float32_t)) | |
+ ones_nobs = <float32_t *>PyMem_Malloc(nobs * sizeof(float32_t)) | |
+ ones_ncodes = <float32_t *>PyMem_Malloc(ncodes * sizeof(float32_t)) | |
+ memset(dis_matrix, 0, nobs * ncodes * sizeof(float32_t)) | |
+ | |
+ current_obs = obs | |
for obs_index in range(nobs): | |
- codebook_pos = 0 | |
- low_dist[obs_index] = NPY_INFINITY | |
+ ones_nobs[obs_index] = 1.0 | |
+ obs_sqr[obs_index] = cblas_sdot(nfeat, current_obs, 1, current_obs, 1) | |
+ current_obs += nfeat | |
+ | |
+ current_code = code_book | |
+ for code_index in range(ncodes): | |
+ ones_ncodes[code_index] = 1.0 | |
+ codes_sqr[code_index] = cblas_sdot(nfeat, current_code, 1, current_code, 1) | |
+ current_code += nfeat | |
+ | |
+ cblas_sger(CblasRowMajor, nobs, ncodes, 1.0, obs_sqr, 1, ones_ncodes, 1, dis_matrix, ncodes) | |
+ cblas_sger(CblasRowMajor, nobs, ncodes, 1.0, ones_nobs, 1, codes_sqr, 1, dis_matrix, ncodes) | |
+ | |
+ cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, nobs, ncodes, nfeat, -2.0, | |
+ obs, nfeat, code_book, nfeat, 1.0, dis_matrix, ncodes) | |
+ | |
+ dis_matrix_p = dis_matrix | |
+ for obs_index in range(nobs): | |
+ dist = NPY_INFINITY | |
for code_index in range(ncodes): | |
- dist = 0 | |
- | |
- # Distance between code_book[code_index] and obs[obs_index] | |
- for feature in range(nfeat): | |
- | |
- # Use current_obs pointer and codebook_pos to minimize | |
- # pointer arithmetic necessary (i.e. no multiplications) | |
- current_obs = &(obs[offset]) | |
- diff = code_book[codebook_pos] - current_obs[feature] | |
- dist += diff * diff | |
- codebook_pos += 1 | |
- | |
- dist = sqrtf(dist) | |
- | |
- # Replace the code assignment and record distance if necessary | |
- if dist < low_dist[obs_index]: | |
+ if dis_matrix_p[code_index] < dist: | |
codes[obs_index] = code_index | |
- low_dist[obs_index] = dist | |
- | |
- # Update the offset of the current observation | |
- offset += nfeat | |
+ dist = dis_matrix_p[code_index] | |
+ low_dist[obs_index] = sqrtf(dist) | |
+ dis_matrix_p += ncodes | |
+ | |
+ PyMem_Free(dis_matrix) | |
+ PyMem_Free(obs_sqr) | |
+ PyMem_Free(codes_sqr) | |
+ PyMem_Free(ones_nobs) | |
+ PyMem_Free(ones_ncodes) | |
def vq(np.ndarray obs, np.ndarray codes): | |
""" |
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment