Skip to content

Instantly share code, notes, and snippets.

@cairijun
Created March 18, 2014 06:53
Show Gist options
  • Save cairijun/9614846 to your computer and use it in GitHub Desktop.
Save cairijun/9614846 to your computer and use it in GitHub Desktop.
vq with BLAS
diff --git a/scipy/cluster/_vq_rewrite.pyx b/scipy/cluster/_vq_rewrite.pyx
index d9fe201..29318af 100644
--- a/scipy/cluster/_vq_rewrite.pyx
+++ b/scipy/cluster/_vq_rewrite.pyx
@@ -9,6 +9,8 @@ Translated to Cython by David Warde-Farley, October 2009.
import numpy as np
cimport numpy as np
+from cpython.mem cimport PyMem_Malloc, PyMem_Free
+from libc.string cimport memset
cdef extern from "math.h":
float sqrtf(float num)
@@ -23,6 +25,31 @@ cdef extern from "numpy/npy_math.h":
cdef enum:
NPY_INFINITY
+# CBLAS (cblas.h) declarations. NOTE(review): AtlasConj is ATLAS-specific.
+cdef extern from "cblas.h":
+    cdef enum CBLAS_ORDER:
+        CblasRowMajor=101
+        CblasColMajor=102
+    cdef enum CBLAS_TRANSPOSE:
+        CblasNoTrans=111
+        CblasTrans=112
+        CblasConjTrans=113
+        AtlasConj=114
+    # C := alpha * op(A) * op(B) + beta * C, where op(X) is X or X^T.
+    void cblas_sgemm(const CBLAS_ORDER Order, const CBLAS_TRANSPOSE TransA,\
+                     const CBLAS_TRANSPOSE TransB, const int M, const int N,\
+                     const int K, const float alpha, const float *A,\
+                     const int lda, const float *B, const int ldb,\
+                     const float beta, float *C, const int ldc)
+    # Returns the dot product of X and Y, read with strides incX and incY.
+    float cblas_sdot(const int N, const float *X, const int incX,\
+                     const float *Y, const int incY)
+    # Rank-1 update A := alpha * X * Y^T + A.
+    void cblas_sger(const CBLAS_ORDER Order, const int M, const int N,\
+                    const float alpha, const float *X, const int incX,\
+                    const float *Y, const int incY, float *A, const int lda)
+
+
# C types
ctypedef np.float64_t float64_t
ctypedef np.float32_t float32_t
@@ -46,33 +73,85 @@ cdef void float_tvq(float32_t *obs, float32_t *code_book,
     # Index and pointer to keep track of the current position in
     # both arrays so that we don't have to always do index * nfeat.
     cdef int codebook_pos
-    cdef float32_t *current_obs
-
-    for obs_index in range(nobs):
-        codebook_pos = 0
-        low_dist[obs_index] = NPY_INFINITY
-        for code_index in range(ncodes):
-            dist = 0
-
-            # Distance between code_book[code_index] and obs[obs_index]
-            for feature in range(nfeat):
-
-                # Use current_obs pointer and codebook_pos to minimize
-                # pointer arithmetic necessary (i.e. no multiplications)
-                current_obs = &(obs[offset])
-                diff = code_book[codebook_pos] - current_obs[feature]
-                dist += diff * diff
-                codebook_pos += 1
-
-            dist = sqrtf(dist)
-
-            # Replace the code assignment and record distance if necessary
-            if dist < low_dist[obs_index]:
-                codes[obs_index] = code_index
-                low_dist[obs_index] = dist
-
-        # Update the offset of the current observation
-        offset += nfeat
+    cdef float32_t *current_obs
+    cdef float32_t *current_code
+    cdef float32_t *dis_row
+    cdef float32_t *dis_matrix = NULL
+    cdef float32_t *codes_sqr = NULL
+    cdef float32_t obs_norm_sq, sq_dist, delta
+    cdef int feat_index
+
+    # Squared distances use ||x - c||^2 = ||x||^2 + ||c||^2 - 2 x.c so that
+    # the cross term for every (observation, code) pair comes from one SGEMM.
+    # NOTE(review): float32 cancellation in this expansion can resolve
+    # near-ties differently from the direct loop when the data lie far from
+    # the origin relative to their spread.
+    # NOTE(review): PyMem_* require the GIL; this assumes float_tvq is not
+    # called from a nogil section -- confirm against the callers.
+    if nobs > 0 and ncodes > 0 and nfeat > 0:
+        # Cast to size_t before multiplying so nobs * ncodes cannot
+        # overflow a C int on large inputs.
+        dis_matrix = <float32_t *>PyMem_Malloc(
+            <size_t>nobs * <size_t>ncodes * sizeof(float32_t))
+        codes_sqr = <float32_t *>PyMem_Malloc(
+            <size_t>ncodes * sizeof(float32_t))
+
+    if dis_matrix == NULL or codes_sqr == NULL:
+        # Empty input (BLAS rejects lda/ldc < 1) or out of memory: fall
+        # back to the direct O(nobs * ncodes * nfeat) loop, which needs no
+        # scratch buffers.  PyMem_Free(NULL) is a no-op.
+        PyMem_Free(dis_matrix)
+        PyMem_Free(codes_sqr)
+        current_obs = obs
+        for obs_index in range(nobs):
+            dist = NPY_INFINITY
+            current_code = code_book
+            for code_index in range(ncodes):
+                sq_dist = 0
+                for feat_index in range(nfeat):
+                    delta = current_code[feat_index] - current_obs[feat_index]
+                    sq_dist += delta * delta
+                # Strict "<": the first of several equally close codes wins.
+                if sq_dist < dist:
+                    codes[obs_index] = code_index
+                    dist = sq_dist
+                current_code += nfeat
+            low_dist[obs_index] = sqrtf(dist)
+            current_obs += nfeat
+        return
+
+    # ||c||^2 for every code (row) of code_book.
+    current_code = code_book
+    for code_index in range(ncodes):
+        codes_sqr[code_index] = cblas_sdot(nfeat, current_code, 1,
+                                           current_code, 1)
+        current_code += nfeat
+
+    # dis_matrix (nobs x ncodes, row-major) := -2 * obs * code_book^T.
+    # beta is 0, so per the BLAS contract dis_matrix need not be set first.
+    cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, nobs, ncodes, nfeat,
+                -2.0, obs, nfeat, code_book, nfeat, 0.0, dis_matrix, ncodes)
+
+    dis_row = dis_matrix
+    current_obs = obs
+    for obs_index in range(nobs):
+        obs_norm_sq = cblas_sdot(nfeat, current_obs, 1, current_obs, 1)
+        dist = NPY_INFINITY
+        for code_index in range(ncodes):
+            sq_dist = obs_norm_sq + codes_sqr[code_index] + dis_row[code_index]
+            if sq_dist < dist:
+                codes[obs_index] = code_index
+                dist = sq_dist
+        # Rounding can leave a tiny negative value when an observation
+        # (nearly) coincides with a code; clamp so sqrtf never yields NaN.
+        if dist < 0:
+            dist = 0
+        low_dist[obs_index] = sqrtf(dist)
+        dis_row += ncodes
+        current_obs += nfeat
+
+    PyMem_Free(dis_matrix)
+    PyMem_Free(codes_sqr)
 
def vq(np.ndarray obs, np.ndarray codes):
"""
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment