Skip to content

Instantly share code, notes, and snippets.

@zed
Created March 16, 2012 18:25

Revisions

  1. zed created this gist Mar 16, 2012.
    24 changes: 24 additions & 0 deletions cydot.pyx
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,24 @@
    #cython: boundscheck=False, wraparound=False
    import numpy as np
    cimport numpy as np

    from cython.parallel cimport prange

    def dot(np.ndarray[np.float32_t, ndim=2] a not None,
            np.ndarray[np.float32_t, ndim=2] b not None,
            np.ndarray[np.float32_t, ndim=2] out=None):
        """Multiply two 2-D float32 matrices with a naive O(N**3) triple loop.

        Every element is computed as ``out[i, j] = sum_k a[i, k] * b[k, j]``.
        The loop over the rows of ``a`` is a ``prange`` executed inside a
        ``nogil`` block; each iteration of that loop writes only to its own
        row of ``out``.

        Parameters
        ----------
        a : 2-D float32 ndarray, must not be None
            Left operand.
        b : 2-D float32 ndarray, must not be None
            Right operand; ``b.shape[0]`` must equal ``a.shape[1]``.
        out : 2-D float32 ndarray, optional
            Destination of shape ``(a.shape[0], b.shape[1])``.  When omitted,
            a new array is created with ``np.empty``.

        Returns
        -------
        The ``out`` array (the caller's buffer when one was supplied).

        Raises
        ------
        ValueError
            If the inner dimensions of ``a`` and ``b`` differ, or ``out`` does
            not have shape ``(a.shape[0], b.shape[1])``.
        """
        # Allocate the destination only when the caller did not provide one.
        if out is None:
            out = np.empty((a.shape[0], b.shape[1]), dtype=a.dtype)

        # All three dimension pairs must agree before any element is touched;
        # bounds checking is disabled for this module, so this is the only guard.
        if not (a.shape[1] == b.shape[0] and
                out.shape[0] == a.shape[0] and
                out.shape[1] == b.shape[1]):
            raise ValueError("wrong shape")

        cdef Py_ssize_t n_rows = a.shape[0]
        cdef Py_ssize_t n_cols = b.shape[1]
        cdef Py_ssize_t n_inner = a.shape[1]
        cdef Py_ssize_t row, col, idx

        with nogil:
            # Rows are distributed by prange; the two inner loops are plain
            # sequential loops.
            for row in prange(n_rows):
                for col in range(n_cols):
                    # np.empty leaves the buffer uninitialised, so clear the
                    # cell before accumulating into it.
                    out[row, col] = 0
                    for idx in range(n_inner):
                        out[row, col] += a[row, idx] * b[idx, col]
        return out
    7 changes: 7 additions & 0 deletions cydot.pyxbld
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,7 @@
    from distutils.extension import Extension

    def make_ext(modname, pyxfilename):
        """Describe how to build the extension module *modname*.

        Returns an ``Extension`` named *modname* whose only source is
        *pyxfilename*, with ``-fopenmp`` passed to both the compile and the
        link step.
        """
        # The same switch has to reach the compiler and the linker.
        openmp_flag = '-fopenmp'
        build_options = dict(
            name=modname,
            sources=[pyxfilename],
            extra_compile_args=[openmp_flag],
            extra_link_args=[openmp_flag],
        )
        return Extension(**build_options)
    15 changes: 15 additions & 0 deletions results.md
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,15 @@
    Without `prange()` (single-threaded):

    python -mtimeit -s'from test_cydot import a,b,out,cydot' 'cydot.dot(a,b,out)'
    10 loops, best of 3: 119 msec per loop

    With `prange()` (number of threads == number of cores):

    python -mtimeit -s'from test_cydot import a,b,out,cydot' 'cydot.dot(a,b,out)'
    10 loops, best of 3: 69.9 msec per loop


    `numpy.dot()` version for comparison:

    python -mtimeit -s'from test_cydot import a,b,out,np' 'np.dot(a,b,out)'
    100 loops, best of 3: 9.97 msec per loop
    19 changes: 19 additions & 0 deletions test_cydot.py
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,19 @@
    import pyximport; pyximport.install() # pip install cython
    import numpy as np
    import cydot

    # Shared fixtures: random float32 operands of shape (50, 10000) and
    # (10000, 60), plus a preallocated (50, 60) result buffer of the same dtype.
    a = np.random.rand(50, 10000).astype(np.float32)
    b = np.random.rand(10000, 60).astype(np.float32)
    # np.empty leaves the buffer uninitialised; it is only used as an `out` target.
    out = np.empty((a.shape[0], b.shape[1]), dtype=a.dtype)

    def test():
        """Check that cydot.dot agrees with np.dot, with and without ``out``."""
        # Case 1: no output buffer is passed, so cydot.dot returns its own.
        reference = np.dot(a, b)
        assert np.allclose(reference, cydot.dot(a, b))

        # Case 2: each implementation writes into a caller-supplied buffer.
        # The buffers are pre-filled with different sentinels so that a
        # function which leaves its buffer untouched is detected.
        scratch = out.copy()
        out[:] = -1
        scratch[:] = -2
        assert np.allclose(out, -1) and np.allclose(scratch, -2)
        np.dot(a, b, out)
        cydot.dot(a, b, scratch)
        assert np.allclose(out, scratch), (out, scratch)

    # Run the self-check when this file is executed directly as a script.
    if __name__=="__main__":
    test()