Created
June 13, 2017 13:10
{ | |
"cells": [ | |
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Tensor Decompositions with NumPy and Dask.array\n",
"===========================\n",
"\n",
"We implement a very simple algorithm to compute the PARAFAC/CANDECOMP tensor decomposition using alternating least squares.\n",
"\n",
"We follow section 4.1 from this paper: https://www.cs.cmu.edu/~pmuthuku/mlsp_page/lectures/Parafac.pdf\n",
"\n",
"[PARAFAC/CANDECOMP](https://en.wikipedia.org/wiki/Tensor_rank_decomposition) is an array decomposition that generalizes SVD/PCA to higher dimensions. It is commonly used to find low-rank structure in observational data with multiple dimensions.\n",
"\n",
"<img src=\"https://upload.wikimedia.org/wikipedia/commons/5/52/Collaborative_filtering.gif\" \n",
"     align=\"right\"\n",
"     width=\"40%\">\n",
"\n",
"In a typical two-dimensional setting we might call this collaborative filtering, an approach used by companies like Amazon and Netflix to recommend products to users based on the habits of similar users. However, nothing about this problem is restricted to two dimensions. We might consider the following applications:\n",
"\n",
"1. What customers bought what products at what times of day or year\n",
"2. Which computers talked to which other computers over which ports under which user\n",
"3. Outcomes of treatments affecting patients with which diseases\n",
"\n",
"When we look for algorithms that generalize SVD/PCA to multiple dimensions we come across a family of algorithms called tensor decompositions. This notebook implements the simplest such algorithm, alternating least squares, in a straightforward way using NumPy and then scales it out naively using Dask arrays."
]
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Alternating Least Squares" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": { | |
"collapsed": true | |
},
"outputs": [],
"source": [
"import numpy as np\n",
"from functools import reduce\n",
"import operator\n",
"import dask\n",
"\n",
"\n",
"def outer_product(*Xs):\n",
"    \"\"\" Outer/tensor product of multiple 2d arrays\n",
"\n",
"    Parameters\n",
"    ----------\n",
"    *Xs: arrays of size (k, n)\n",
"        All inputs must have the same first dimension but may have varying\n",
"        second dimensions\n",
"    \"\"\"\n",
"    n = len(Xs)\n",
"    # Give the i-th input its own axis: shape (k, 1, ..., n_i, ..., 1)\n",
"    indexes = [(slice(None, None),) +\n",
"               (None,) * i +\n",
"               (slice(None, None),) +\n",
"               (None,) * (n - i - 1) for i in range(n)]\n",
"    Ys = [X[ind] for X, ind in zip(Xs, indexes)]\n",
"    Ys = sorted(Ys, key=lambda y: y.nbytes)  # smaller outer products first\n",
"    return reduce(operator.mul, Ys)  # broadcasting forms the outer product\n",
"\n",
"\n",
"def parafac_als(X, n_factors, n_iter=100):\n",
"    \"\"\" Parafac tensor decomposition\n",
"\n",
"    This implements the basic algorithm in section 4.1 of this paper\n",
"\n",
"    https://www.cs.cmu.edu/~pmuthuku/mlsp_page/lectures/Parafac.pdf\n",
"    \"\"\"\n",
"    # Randomly initialize factors\n",
"    factors = [np.random.random((n_factors, X.shape[i])) for i in range(0, X.ndim)]\n",
"\n",
"    # Solve\n",
"    for itr in range(n_iter):\n",
"        for i in range(X.ndim):\n",
"            # Unfold X along axis i into a (X.shape[i], everything else) matrix\n",
"            not_i = tuple(j for j in range(X.ndim) if j != i)\n",
"            Xp = X.transpose((i,) + not_i)\n",
"            Xp = Xp.reshape((Xp.shape[0], np.prod(Xp.shape[1:])))\n",
"            # Combine the other factors into an (n_factors, everything else) matrix\n",
"            Z = outer_product(*[factors[j] for j in not_i])\n",
"            Z = Z.reshape((Z.shape[0], np.prod(Z.shape[1:])))\n",
"\n",
"            # Hold the other factors fixed and solve for factor i\n",
"            factor, residuals, rank, s = np.linalg.lstsq(Z.T, Xp.T)\n",
"            factors[i] = factor\n",
"    return factors"
] | |
}, | |
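{
"cell_type": "markdown",
"metadata": {},
"source": [
"The decomposition models a tensor as a sum of rank-one outer products, one per factor row. The cell below is a minimal sketch, not part of the original notebook, that builds a small three-way tensor from made-up factor matrices in two ways: by summing explicit outer products and with a single `np.einsum` call. The names `A`, `B`, `C` and their shapes are illustrative assumptions."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"# Hypothetical factor matrices, each of shape (n_factors, dim), the same\n",
"# (k, n) layout that outer_product expects\n",
"rng = np.random.RandomState(0)\n",
"A, B, C = rng.rand(2, 4), rng.rand(2, 3), rng.rand(2, 5)\n",
"\n",
"# Sum of rank-one terms: one outer product per factor row\n",
"explicit = sum(np.multiply.outer(np.multiply.outer(a, b), c)\n",
"               for a, b, c in zip(A, B, C))\n",
"\n",
"# The same tensor in one call; r indexes the components\n",
"compact = np.einsum('ra,rb,rc->abc', A, B, C)\n",
"\n",
"assert explicit.shape == (4, 3, 5)\n",
"assert np.allclose(explicit, compact)"
]
},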
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"### Demonstrate accuracy" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": { | |
"collapsed": false, | |
"scrolled": false | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"array([[[[1, 1, 0, 1, 1],\n", | |
" [1, 1, 1, 2, 1]],\n", | |
"\n", | |
" [[1, 1, 0, 1, 1],\n", | |
" [1, 1, 0, 1, 1]]],\n", | |
"\n", | |
"\n", | |
" [[[0, 0, 0, 0, 0],\n", | |
" [0, 0, 1, 1, 0]],\n", | |
"\n", | |
" [[0, 0, 0, 0, 0],\n", | |
" [0, 0, 0, 0, 0]]],\n", | |
"\n", | |
"\n", | |
" [[[1, 1, 0, 1, 1],\n", | |
" [1, 1, 1, 2, 1]],\n", | |
"\n", | |
" [[1, 1, 0, 1, 1],\n", | |
" [1, 1, 0, 1, 1]]],\n", | |
"\n", | |
"\n", | |
" [[[0, 0, 0, 0, 0],\n", | |
" [0, 0, 0, 0, 0]],\n", | |
"\n", | |
" [[0, 0, 0, 0, 0],\n", | |
" [0, 0, 0, 0, 0]]],\n", | |
"\n", | |
"\n", | |
" [[[1, 1, 0, 1, 1],\n", | |
" [1, 1, 0, 1, 1]],\n", | |
"\n", | |
" [[1, 1, 0, 1, 1],\n", | |
" [1, 1, 0, 1, 1]]],\n", | |
"\n", | |
"\n", | |
" [[[0, 0, 0, 0, 0],\n", | |
" [0, 0, 0, 0, 0]],\n", | |
"\n", | |
" [[0, 0, 0, 0, 0],\n", | |
" [0, 0, 0, 0, 0]]]])" | |
] | |
}, | |
"execution_count": 2, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
],
"source": [
"true_factors = [[[1, 1, 1, 0, 0, 0],\n",
"                 [1, 0, 1, 0, 1, 0]],\n",
"                [[1, 0],\n",
"                 [1, 1]],\n",
"                [[0, 1],\n",
"                 [1, 1]],\n",
"                [[0, 0, 1, 1, 0],\n",
"                 [1, 1, 0, 1, 1]]]\n",
"true_factors = [np.array(x).reshape((2, len(x[0]))) for x in true_factors]\n",
"\n",
"X = outer_product(*true_factors).sum(axis=0)\n",
"X"
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"array([[[[ 1.00001470e+00, 1.00001470e+00, -1.14710333e-04,\n", | |
" 9.99899985e-01, 1.00001470e+00],\n", | |
" [ 1.00004331e+00, 1.00004331e+00, 1.00028312e+00,\n", | |
" 2.00032642e+00, 1.00004331e+00]],\n", | |
"\n", | |
" [[ 9.99849453e-01, 9.99849453e-01, -5.35488732e-04,\n", | |
" 9.99313964e-01, 9.99849453e-01],\n", | |
" [ 1.00004716e+00, 1.00004716e+00, 1.26153790e-04,\n", | |
" 1.00017331e+00, 1.00004716e+00]]],\n", | |
"\n", | |
"\n", | |
" [[[ 2.49442423e-04, 2.49442423e-04, 9.35537112e-04,\n", | |
" 1.18497953e-03, 2.49442423e-04],\n", | |
" [ -8.92432553e-05, -8.92432553e-05, 9.99421862e-01,\n", | |
" 9.99332618e-01, -8.92432553e-05]],\n", | |
"\n", | |
" [[ -4.84114017e-04, -4.84114017e-04, -2.04090234e-03,\n", | |
" -2.52501636e-03, -4.84114017e-04],\n", | |
" [ 1.88343989e-04, 1.88343989e-04, 9.63345064e-04,\n", | |
" 1.15168905e-03, 1.88343989e-04]]],\n", | |
"\n", | |
"\n", | |
" [[[ 1.00001470e+00, 1.00001470e+00, -1.14710333e-04,\n", | |
" 9.99899985e-01, 1.00001470e+00],\n", | |
" [ 1.00004331e+00, 1.00004331e+00, 1.00028312e+00,\n", | |
" 2.00032642e+00, 1.00004331e+00]],\n", | |
"\n", | |
" [[ 9.99849453e-01, 9.99849453e-01, -5.35488732e-04,\n", | |
" 9.99313964e-01, 9.99849453e-01],\n", | |
" [ 1.00004716e+00, 1.00004716e+00, 1.26153790e-04,\n", | |
" 1.00017331e+00, 1.00004716e+00]]],\n", | |
"\n", | |
"\n", | |
" [[[ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", | |
" 0.00000000e+00, 0.00000000e+00],\n", | |
" [ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", | |
" 0.00000000e+00, 0.00000000e+00]],\n", | |
"\n", | |
" [[ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", | |
" 0.00000000e+00, 0.00000000e+00],\n", | |
" [ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", | |
" 0.00000000e+00, 0.00000000e+00]]],\n", | |
"\n", | |
"\n", | |
" [[[ 9.99765253e-01, 9.99765253e-01, -1.05024745e-03,\n", | |
" 9.98715005e-01, 9.99765253e-01],\n", | |
" [ 1.00013255e+00, 1.00013255e+00, 8.61254208e-04,\n", | |
" 1.00099380e+00, 1.00013255e+00]],\n", | |
"\n", | |
" [[ 1.00033357e+00, 1.00033357e+00, 1.50541361e-03,\n", | |
" 1.00183898e+00, 1.00033357e+00],\n", | |
" [ 9.99858812e-01, 9.99858812e-01, -8.37191274e-04,\n", | |
" 9.99021621e-01, 9.99858812e-01]]],\n", | |
"\n", | |
"\n", | |
" [[[ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", | |
" 0.00000000e+00, 0.00000000e+00],\n", | |
" [ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", | |
" 0.00000000e+00, 0.00000000e+00]],\n", | |
"\n", | |
" [[ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", | |
" 0.00000000e+00, 0.00000000e+00],\n", | |
" [ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", | |
" 0.00000000e+00, 0.00000000e+00]]]])" | |
] | |
}, | |
"execution_count": 3, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"factors = parafac_als(X, 3, 10)\n", | |
"\n", | |
"computed = outer_product(*factors).sum(axis=0)\n", | |
"computed" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"array([[[[ 0., 0., -0., 0., 0.],\n", | |
" [ 0., 0., 0., 0., 0.]],\n", | |
"\n", | |
" [[ 0., 0., -0., 0., 0.],\n", | |
" [ 0., 0., 0., 0., 0.]]],\n", | |
"\n", | |
"\n", | |
" [[[ 0., 0., 0., 0., 0.],\n", | |
" [-0., -0., 0., 0., -0.]],\n", | |
"\n", | |
" [[-0., -0., -0., -0., -0.],\n", | |
" [ 0., 0., 0., 0., 0.]]],\n", | |
"\n", | |
"\n", | |
" [[[ 0., 0., -0., 0., 0.],\n", | |
" [ 0., 0., 0., 0., 0.]],\n", | |
"\n", | |
" [[ 0., 0., -0., 0., 0.],\n", | |
" [ 0., 0., 0., 0., 0.]]],\n", | |
"\n", | |
"\n", | |
" [[[ 0., 0., 0., 0., 0.],\n", | |
" [ 0., 0., 0., 0., 0.]],\n", | |
"\n", | |
" [[ 0., 0., 0., 0., 0.],\n", | |
" [ 0., 0., 0., 0., 0.]]],\n", | |
"\n", | |
"\n", | |
" [[[ 0., 0., -0., 0., 0.],\n", | |
" [ 0., 0., 0., 0., 0.]],\n", | |
"\n", | |
" [[ 0., 0., 0., 0., 0.],\n", | |
" [ 0., 0., -0., 0., 0.]]],\n", | |
"\n", | |
"\n", | |
" [[[ 0., 0., 0., 0., 0.],\n", | |
" [ 0., 0., 0., 0., 0.]],\n", | |
"\n", | |
" [[ 0., 0., 0., 0., 0.],\n", | |
" [ 0., 0., 0., 0., 0.]]]])" | |
] | |
}, | |
"execution_count": 4, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"computed.round(0) - X" | |
] | |
}, | |
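{
"cell_type": "markdown",
"metadata": {},
"source": [
"Rounding to integers hides how close the fit actually is. A relative error is one common way to summarize it in a single number. The helper below is a sketch: the name `relative_error` and the toy data are hypothetical and not part of the original notebook."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"\n",
"def relative_error(X, approx):\n",
"    ''' Norm of the residual divided by the norm of X (over all entries) '''\n",
"    X = np.asarray(X, dtype=float)\n",
"    return np.linalg.norm(X - approx) / np.linalg.norm(X)\n",
"\n",
"\n",
"# Toy check with made-up values: every entry is off by 1e-3\n",
"X_toy = np.ones((2, 3, 4))\n",
"err = relative_error(X_toy, X_toy + 1e-3)\n",
"assert np.isclose(err, 1e-3)\n",
"\n",
"# In this notebook the call would be: relative_error(X, computed)"
]
},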
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Scale with Dask.array\n",
"\n",
"The ALS algorithm above is largely reshapes, broadcasted elementwise multiplication, matrix multiplication, and least-squares computations (a QR decomposition followed by a triangular solve), all of which Dask.array can do easily. Here we adapt our code from before to work on dask arrays. We leave each replaced line in as a comment next to its replacement to show how little has to change."
] | |
}, | |
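{
"cell_type": "markdown",
"metadata": {},
"source": [
"The least-squares solve is the step that changes most. As a sketch, assuming `da.linalg.lstsq` is given a tall-and-skinny array chunked only along its rows, the cell below compares it with `np.linalg.lstsq` on made-up data. The shapes and chunk sizes are illustrative and not taken from the original notebook."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import dask.array as da\n",
"\n",
"rng = np.random.RandomState(0)\n",
"A = rng.rand(1000, 3)  # tall and skinny, playing the role of Z.T\n",
"b = rng.rand(1000, 2)  # playing the role of Xp.T\n",
"\n",
"# Chunk along rows only; each block keeps all of its columns\n",
"dA = da.from_array(A, chunks=(250, 3))\n",
"db = da.from_array(b, chunks=(250, 2))\n",
"\n",
"x_np = np.linalg.lstsq(A, b, rcond=None)[0]\n",
"x_da, residuals, rank, s = da.linalg.lstsq(dA, db)\n",
"\n",
"# The dask result is lazy; compute() runs the task graph\n",
"assert np.allclose(x_np, x_da.compute())"
]
},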
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"metadata": { | |
"collapsed": true | |
},
"outputs": [],
"source": [
"import dask.array as da\n",
"\n",
"def parafac_als(X, n_factors, n_iter=100):\n",
"    \"\"\" Parafac tensor decomposition\n",
"\n",
"    This implements the basic algorithm in section 4.1 of this paper\n",
"\n",
"    https://www.cs.cmu.edu/~pmuthuku/mlsp_page/lectures/Parafac.pdf\n",
"    \"\"\"\n",
"    # Randomly initialize factors\n",
"    # factors = [np.random.random((n_factors, X.shape[i])) for i in range(0, X.ndim)]\n",
"    factors = [da.random.random((n_factors, X.shape[i]),\n",
"                                chunks=(None, X.chunks[i]))\n",
"               for i in range(0, X.ndim)]\n",
"\n",
"    # Solve\n",
"    for itr in range(n_iter):\n",
"        for i in range(X.ndim):\n",
"            not_i = tuple(j for j in range(X.ndim) if j != i)\n",
"            Xp = X.transpose((i,) + not_i)\n",
"            Xp = Xp.reshape((Xp.shape[0], np.prod(Xp.shape[1:])))\n",
"            Z = outer_product(*[factors[j] for j in not_i])\n",
"            Z = Z.reshape((Z.shape[0], np.prod(Z.shape[1:])))\n",
"\n",
"            # factor, residuals, rank, s = np.linalg.lstsq(Z.T, Xp.T)\n",
"            factor, residuals, rank, s = da.linalg.lstsq(Z.T, Xp.T)\n",
"            factors[i] = factor\n",
"    return factors"
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"### Visualize result\n", | |
"\n", | |
"To show the complexity of the algorithm we visualize a single round of chunked ALS as a task graph" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"data": { |