Created
November 14, 2019 08:51
-
-
Save dizcza/003d65f4c2d9ee86000ce85ea27e8df2 to your computer and use it in GitHub Desktop.
Cross-correlation of two 1-dimensional sparse CSR matrices.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import scipy.sparse | |
def correlate_sparse(matrix1, matrix2, mode='valid'):
    """
    Cross-correlation of two 1-dimensional sparse CSR matrices.

    Drop-in replacement for `np.correlate` for inputs stored as single-row
    CSR matrices. The value at shift ``k`` is
    ``sum_i matrix1[0, i] * matrix2[0, i - k]``; no complex conjugation is
    applied to `matrix2`.

    Parameters
    ----------
    matrix1, matrix2 : scipy.sparse.csr.csr_matrix
        Input sequences, each of shape ``(1, n)``.
    mode : {'valid', 'same', 'full'} or (left, right) pair, optional
        For the string modes refer to the `np.convolve` docstring.
        A ``(left, right)`` pair of integers selects an explicit, inclusive
        range of shifts of `matrix2` relative to `matrix1`; the result then
        has ``right - left + 1`` elements.

    Returns
    -------
    cross_corr : np.ndarray
        Discrete cross-correlation of `matrix1` and `matrix2` (1-d array).

    Raises
    ------
    ValueError
        If either input does not have exactly one row, if `mode` is an
        unknown string, or if a custom window has ``right < left``.
    """
    if matrix1.shape[0] != 1 or matrix2.shape[0] != 1:
        raise ValueError("Input matrices should have 1 row")
    m1_size = matrix1.shape[1]
    m2_size = matrix2.shape[1]

    # compute left and right shifts
    if isinstance(mode, str):
        if mode not in ('valid', 'same', 'full'):
            raise ValueError(
                "mode should be 'valid', 'same', 'full' or a (left, right) "
                "pair, got {!r}".format(mode))
        if m1_size < m2_size:
            # Like np.correlate: put the longer sequence first and reverse
            # the result. Only done for the string modes, whose shift ranges
            # are defined relative to the longer input.
            return correlate_sparse(matrix2, matrix1, mode=mode)[::-1]
        if mode == 'full':
            left = -m2_size + 1
            right = m1_size - 1
        elif mode == 'same':
            left = -(m2_size // 2)
            right = m1_size - 1 + left
        else:  # 'valid'
            left = 0
            right = m1_size - m2_size
    else:
        # An explicit window of shifts; it is always interpreted relative to
        # (matrix1, matrix2) in the given order, so no swapping here.
        left, right = mode
        if right < left:
            raise ValueError("Custom mode (left, right) requires "
                             "left <= right")

    nrows = right - left + 1

    # Build a sparse Toeplitz-like matrix whose r-th row is `matrix2`
    # shifted by (left + r) columns, clipped to the extent of `matrix1`.
    row_wise_shifts = np.expand_dims(np.arange(left, left + nrows), axis=1)
    # broadcasting yields an (nrows, nnz) array of shifted column indices
    index_shifts = matrix2.indices + row_wise_shifts
    mask_valid = (index_shifts >= 0) & (index_shifts < m1_size)
    indptr = np.r_[0, mask_valid.sum(axis=1).cumsum()]
    indices = index_shifts[mask_valid]
    # column position inside `mask_valid` == position in `matrix2.data`
    mask_col_ids = mask_valid.nonzero()[1]
    data = matrix2.data[mask_col_ids]
    matrix2_toeplitz = scipy.sparse.csr_matrix((data, indices, indptr),
                                               shape=(nrows, m1_size),
                                               copy=False)

    # The single row of `matrix1` is broadcast against every shifted row;
    # summing each row gives the correlation at the corresponding shift.
    cross_corr = matrix1.multiply(matrix2_toeplitz).sum(axis=1)
    cross_corr = np.asarray(cross_corr)[:, 0]
    return cross_corr
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment