Created
July 6, 2023 04:46
-
-
Save mcdlee/567d8b95759914f3596cb5ffb224b9a8 to your computer and use it in GitHub Desktop.
Loss (ground-cost) matrix for EMD with Python Optimal Transport (POT)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"id": "347a3f1f", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import numpy as np\n", | |
"from matplotlib import pyplot as plt" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"id": "30265d7f", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"array([[0, 1, 2, 3, 4, 5, 6, 7]])" | |
] | |
}, | |
"execution_count": 2, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"x = np.array(range(8)).reshape(1,-1)\n", | |
"x" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"id": "f75b5486", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"array([[0, 1, 2, 3, 4, 5, 6, 7],\n", | |
" [0, 1, 2, 3, 4, 5, 6, 7],\n", | |
" [0, 1, 2, 3, 4, 5, 6, 7],\n", | |
" [0, 1, 2, 3, 4, 5, 6, 7],\n", | |
" [0, 1, 2, 3, 4, 5, 6, 7],\n", | |
" [0, 1, 2, 3, 4, 5, 6, 7],\n", | |
" [0, 1, 2, 3, 4, 5, 6, 7],\n", | |
" [0, 1, 2, 3, 4, 5, 6, 7]])" | |
] | |
}, | |
"execution_count": 3, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"x = np.repeat(x,8, axis=0)\n", | |
"x" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"id": "13eea9d8", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"array([[0],\n", | |
" [1],\n", | |
" [2],\n", | |
" [3],\n", | |
" [4],\n", | |
" [5],\n", | |
" [6],\n", | |
" [7]])" | |
] | |
}, | |
"execution_count": 4, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"y = np.array(range(8)).reshape(-1,1)\n", | |
"y" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"id": "5271daf4", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"array([[0, 0, 0, 0, 0, 0, 0, 0],\n", | |
" [1, 1, 1, 1, 1, 1, 1, 1],\n", | |
" [2, 2, 2, 2, 2, 2, 2, 2],\n", | |
" [3, 3, 3, 3, 3, 3, 3, 3],\n", | |
" [4, 4, 4, 4, 4, 4, 4, 4],\n", | |
" [5, 5, 5, 5, 5, 5, 5, 5],\n", | |
" [6, 6, 6, 6, 6, 6, 6, 6],\n", | |
" [7, 7, 7, 7, 7, 7, 7, 7]])" | |
] | |
}, | |
"execution_count": 5, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"y = np.repeat(y,8, axis=1)\n", | |
"y" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"id": "4d12443f", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"c = np.empty((8,8), dtype=np.complex128) #c stands for coordinate" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"id": "654fa579", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"c.real = x\n", | |
"c.imag = y" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"id": "d570c6ca", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"array([[0.+0.j, 1.+0.j, 2.+0.j, 3.+0.j, 4.+0.j, 5.+0.j, 6.+0.j, 7.+0.j],\n", | |
" [0.+1.j, 1.+1.j, 2.+1.j, 3.+1.j, 4.+1.j, 5.+1.j, 6.+1.j, 7.+1.j],\n", | |
" [0.+2.j, 1.+2.j, 2.+2.j, 3.+2.j, 4.+2.j, 5.+2.j, 6.+2.j, 7.+2.j],\n", | |
" [0.+3.j, 1.+3.j, 2.+3.j, 3.+3.j, 4.+3.j, 5.+3.j, 6.+3.j, 7.+3.j],\n", | |
" [0.+4.j, 1.+4.j, 2.+4.j, 3.+4.j, 4.+4.j, 5.+4.j, 6.+4.j, 7.+4.j],\n", | |
" [0.+5.j, 1.+5.j, 2.+5.j, 3.+5.j, 4.+5.j, 5.+5.j, 6.+5.j, 7.+5.j],\n", | |
" [0.+6.j, 1.+6.j, 2.+6.j, 3.+6.j, 4.+6.j, 5.+6.j, 6.+6.j, 7.+6.j],\n", | |
" [0.+7.j, 1.+7.j, 2.+7.j, 3.+7.j, 4.+7.j, 5.+7.j, 6.+7.j, 7.+7.j]])" | |
] | |
}, | |
"execution_count": 8, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"c" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 9, | |
"id": "012d465c", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"dist = np.abs(np.subtract.outer(c,c)) #歐氏距離 (Euclidean distance),可採用其他距離,例如 Chebyshev distance" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 10, | |
"id": "7ee85dc8", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"(8, 8, 8, 8)" | |
] | |
}, | |
"execution_count": 10, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"dist.shape" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 11, | |
"id": "f835baf6", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"5.0" | |
] | |
}, | |
"execution_count": 11, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"dist[3,2,7,5] # 可以解讀成是 (3,2) 和 (7,5)的距離" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 12, | |
"id": "addc7478", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"1.4142135623730951" | |
] | |
}, | |
"execution_count": 12, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"dist[0,0,1,1] #根號 2" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 20, | |
"id": "3b3f1745", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"image/png": "\n", | |
"text/plain": [ | |
"<Figure size 640x480 with 1 Axes>" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
} | |
], | |
"source": [ | |
"plt.imshow(dist[3,5,:,:], \"binary\") #即 (3,5) 與各座標的距離\n", | |
"plt.show()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 14, | |
"id": "0c6cd7e5", | |
"metadata": { | |
"scrolled": true | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"array([0.+0.j, 1.+0.j, 2.+0.j, 3.+0.j, 4.+0.j, 5.+0.j, 6.+0.j, 7.+0.j,\n", | |
" 0.+1.j, 1.+1.j, 2.+1.j, 3.+1.j, 4.+1.j, 5.+1.j, 6.+1.j, 7.+1.j,\n", | |
" 0.+2.j, 1.+2.j, 2.+2.j, 3.+2.j, 4.+2.j, 5.+2.j, 6.+2.j, 7.+2.j,\n", | |
" 0.+3.j, 1.+3.j, 2.+3.j, 3.+3.j, 4.+3.j, 5.+3.j, 6.+3.j, 7.+3.j,\n", | |
" 0.+4.j, 1.+4.j, 2.+4.j, 3.+4.j, 4.+4.j, 5.+4.j, 6.+4.j, 7.+4.j,\n", | |
" 0.+5.j, 1.+5.j, 2.+5.j, 3.+5.j, 4.+5.j, 5.+5.j, 6.+5.j, 7.+5.j,\n", | |
" 0.+6.j, 1.+6.j, 2.+6.j, 3.+6.j, 4.+6.j, 5.+6.j, 6.+6.j, 7.+6.j,\n", | |
" 0.+7.j, 1.+7.j, 2.+7.j, 3.+7.j, 4.+7.j, 5.+7.j, 6.+7.j, 7.+7.j])" | |
] | |
}, | |
"execution_count": 14, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"c.ravel() #此為下圖的x軸和y軸" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "824ccf50", | |
"metadata": {}, | |
"source": [ | |
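    "Python Optimal Transport (POT) 的 ot.emd() 及 ot.emd2() 所需的成本矩陣 M 可以用這種方式生出來:把形狀 (8,8,8,8) 的 dist reshape 成 (64,64),M[p,q] 即第 p 格與第 q 格之間的距離;因為 reshape 與 ravel() 預設都採 row-major (C order),(i,j) 格會對應到索引 8*i+j,所以 M 的列/行順序與 A.ravel()、B.ravel() 一致。" | |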
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 18, | |
"id": "894ba1df", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"M = dist.reshape(64,64)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 19, | |
"id": "0a43500e", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"image/png": "\n", | |
"text/plain": [ | |
"<Figure size 640x480 with 1 Axes>" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
} | |
], | |
"source": [ | |
"plt.imshow(dist.reshape(64,64),\"binary\")\n", | |
"plt.show()" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "5ddc2d90", | |
"metadata": {}, | |
"source": [ | |
    "# 實驗:兩個單點分布之間的 EMD" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 27, | |
"id": "05ead127", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"A = np.zeros((8,8))\n", | |
"B = np.zeros((8,8))\n", | |
"\n", | |
"A[2,3]=1\n", | |
"B[5,7]=1" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 28, | |
"id": "554e7dc5", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import ot" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 29, | |
"id": "07ccf212", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"5.0" | |
] | |
}, | |
"execution_count": 29, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"ot.emd2(A.ravel(),B.ravel(),M)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 31, | |
"id": "2e730f0d", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"<matplotlib.image.AxesImage at 0x7fc8d3e63b80>" | |
] | |
}, | |
"execution_count": 31, | |
"metadata": {}, | |
"output_type": "execute_result" | |
}, | |
{ | |
"data": { | |
"image/png": "\n", | |
"text/plain": [ | |
"<Figure size 640x480 with 1 Axes>" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
} | |
], | |
"source": [ | |
"plt.imshow(ot.emd(A.ravel(),B.ravel(),M))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "baba7763", | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3 (ipykernel)", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.8.10" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 5 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment