Created
March 16, 2012 18:25
Revisions
-
zed created this gist
Mar 16, 2012 .There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -0,0 +1,24 @@ #cython: boundscheck=False, wraparound=False import numpy as np cimport numpy as np from cython.parallel cimport prange def dot(np.ndarray[np.float32_t, ndim=2] a not None, np.ndarray[np.float32_t, ndim=2] b not None, np.ndarray[np.float32_t, ndim=2] out=None): """Naive O(N**3) 2D np.dot() implementation.""" if out is None: out = np.empty((a.shape[0], b.shape[1]), dtype=a.dtype) if (a.shape[1] != b.shape[0] or out.shape[0] != a.shape[0] or out.shape[1] != b.shape[1]): raise ValueError("wrong shape") cdef Py_ssize_t i, j, k with nogil: for i in prange(a.shape[0]): for j in range(b.shape[1]): out[i,j] = 0 for k in range(a.shape[1]): out[i,j] += a[i,k] * b[k,j] return out This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -0,0 +1,7 @@ from distutils.extension import Extension def make_ext(modname, pyxfilename): return Extension(name=modname, sources=[pyxfilename], extra_compile_args=['-fopenmp'], extra_link_args=['-fopenmp']) This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -0,0 +1,15 @@ Without `prange()` (single-threaded): python -mtimeit -s'from test_cydot import a,b,out,cydot' 'cydot.dot(a,b,out)' 10 loops, best of 3: 119 msec per loop With `prange()` (number of threads == number of cores): python -mtimeit -s'from test_cydot import a,b,out,cydot' 'cydot.dot(a,b,out)' 10 loops, best of 3: 69.9 msec per loop `numpy.dot()` version for comparison: python -mtimeit -s'from test_cydot import a,b,out,np' 'np.dot(a,b,out)' 100 loops, best of 3: 9.97 msec per loop This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -0,0 +1,19 @@ import pyximport; pyximport.install() # pip install cython import numpy as np import cydot a = np.random.rand(50, 10000).astype(np.float32) b = np.random.rand(10000, 60).astype(np.float32) out = np.empty((a.shape[0], b.shape[1]), dtype=a.dtype) def test(): assert np.allclose(np.dot(a,b), cydot.dot(a,b)) out2 = out.copy() out[:] = -1; out2[:] = -2 assert np.allclose(out, -1) and np.allclose(out2, -2) np.dot(a, b, out); cydot.dot(a, b, out2) assert np.allclose(out, out2), (out,out2) if __name__=="__main__": test()