Least Squares in PyTorch
gngdb/611d8f180ef0f0baddaa539e29a4200e (last active April 27, 2020 04:03)
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Adapting [this excellent blog post](http://drsfenner.org/blog/2015/12/three-paths-to-least-squares-linear-regression/) to show three ways to solve linear least squares in PyTorch: via the normal equations (Cholesky), via QR, and via SVD."
]
},
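{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a reference for the three paths below, assuming $X$ has full column rank: each one solves the same problem,\n",
"\n",
"$$\\hat\\beta = \\arg\\min_\\beta \\lVert X\\beta - Y \\rVert_2^2,$$\n",
"\n",
"whose minimizer satisfies the normal equations $X^TX\\hat\\beta = X^TY$. Substituting a thin QR factorization $X = QR$ (so $Q^TQ = I$) turns these into $R^TR\\hat\\beta = R^TQ^TY$, i.e. the triangular system $R\\hat\\beta = Q^TY$. Substituting a thin SVD $X = U\\Sigma V^T$ gives $V\\Sigma^2V^T\\hat\\beta = V\\Sigma U^TY$, i.e. $\\hat\\beta = V\\Sigma^{-1}U^TY$."
]
},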
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import torch"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [],
"source": [
"def create_data(n=5000, p=100):\n",
"    '''\n",
"    n is the number of cases/observations/examples\n",
"    p is the number of features/attributes/variables\n",
"    '''\n",
"    X = torch.rand(n, p) * 10.\n",
"    # true coefficients, one per feature\n",
"    coeffs = (torch.rand(p) * 10.).view(p, 1)\n",
"    def f(X): return X.mm(coeffs)\n",
"\n",
"    # unit-variance Gaussian noise on each observation\n",
"    noise = torch.randn(n, 1)\n",
"    Y = f(X) + noise\n",
"    Y = Y.view(n, 1)\n",
"\n",
"    return X, Y"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [],
"source": [
"X, Y = create_data()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Via Cholesky\n",
"\n",
"The first path solves the normal equations $X^TX\\beta = X^TY$. [`torch.linalg.solve`](https://pytorch.org/docs/stable/generated/torch.linalg.solve.html) solves systems of linear equations and takes its arguments in the same order as NumPy (the older `torch.gesv(B, A)` reversed them):\n",
"\n",
"```\n",
"torch.linalg.solve(A, B)\n",
"AX = B\n",
"```\n",
"\n",
"```\n",
"numpy.linalg.solve(a, b)\n",
"ax = b\n",
"```\n",
"\n",
"`torch.linalg.solve` factors $X^TX$ with LU rather than Cholesky, but for a symmetric positive definite matrix such as $X^TX$ (when $X$ has full column rank) both give the same solution up to rounding error."
]
},
{
"cell_type": "code",
"execution_count": 34,
"metadata": {},
"outputs": [],
"source": [
"# Normal equations: (X^T X) beta = X^T Y\n",
"XtX, XtY = X.T @ X, X.T @ Y\n",
"betas_cholesky = torch.linalg.solve(XtX, XtY)"
]
},
{
"cell_type": "code",
"execution_count": 51,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([100, 1])"
]
},
"execution_count": 51,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"betas_cholesky.size()"
]
},
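{
"cell_type": "markdown",
"metadata": {},
"source": [
"The solve above goes through an LU factorization of $X^TX$ rather than a Cholesky one. A minimal sketch of an actual Cholesky path follows. It assumes a PyTorch version that provides `torch.linalg.cholesky` and `torch.cholesky_solve`, and the names `L` and `betas_chol_factor` are introduced here for illustration. It factors $X^TX = LL^T$ and lets `torch.cholesky_solve` do the two triangular solves. The last line shows how far this result is from `betas_cholesky`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Sketch (assumes torch.linalg.cholesky / torch.cholesky_solve are available)\n",
"# Factor the normal-equations matrix as XtX = L @ L.T with L lower triangular;\n",
"# this requires XtX to be symmetric positive definite (X of full column rank)\n",
"L = torch.linalg.cholesky(XtX)\n",
"# cholesky_solve(B, L) solves (L @ L.T) x = B; note that B comes first\n",
"betas_chol_factor = torch.cholesky_solve(XtY, L)\n",
"(betas_chol_factor - betas_cholesky).abs().max()"
]
},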
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Via QR\n",
"\n",
"PyTorch has a QR decomposition, [`torch.linalg.qr`](https://pytorch.org/docs/stable/generated/torch.linalg.qr.html), but the blog post uses LAPACK's `gels`, which solves the least squares problem through a QR factorization internally. [`torch.linalg.lstsq`](https://pytorch.org/docs/stable/generated/torch.linalg.lstsq.html) exposes that routine through `driver='gels'`. It takes its arguments in the order `(A, B)` and returns the $p \\times 1$ solution directly in its `solution` field:"
]
},
{
"cell_type": "code",
"execution_count": 45,
"metadata": {},
"outputs": [],
"source": [
"# QR-based least squares (LAPACK gels); solution has shape (p, 1)\n",
"betas_qr = torch.linalg.lstsq(X, Y, driver='gels').solution"
]
},
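{
"cell_type": "markdown",
"metadata": {},
"source": [
"To see the QR path spelled out, here is a minimal sketch that uses the factorization directly. It assumes a PyTorch version that provides `torch.linalg.qr` and `torch.linalg.solve_triangular`, and `Q`, `R` and `betas_qr_explicit` are names introduced here. It takes a thin QR factorization $X = QR$, then solves the upper-triangular system $R\\beta = Q^TY$ by back-substitution:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Sketch (assumes torch.linalg.qr / torch.linalg.solve_triangular are available)\n",
"# Thin QR: Q is n x p with orthonormal columns, R is p x p upper triangular\n",
"Q, R = torch.linalg.qr(X, mode='reduced')\n",
"# Back-substitution for R @ beta = Q^T Y\n",
"betas_qr_explicit = torch.linalg.solve_triangular(R, Q.T @ Y, upper=True)\n",
"(betas_qr_explicit - betas_qr).abs().max()"
]
},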
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Via SVD\n",
"\n",
"The blog post uses a `lstsq` function in this section. `torch.linalg.lstsq` also offers SVD-based drivers (`'gelsd'` and `'gelss'`, CPU only), but it is instructive to assemble the solution directly from the thin SVD $X = U\\Sigma V^T$:\n",
"\n",
"$$\\beta = V\\Sigma^{-1}U^TY$$"
]
},
{
"cell_type": "code",
"execution_count": 53,
"metadata": {},
"outputs": [],
"source": [
"# Thin SVD: U is n x p, S holds the p singular values, Vh is V^T (p x p)\n",
"U, S, Vh = torch.linalg.svd(X, full_matrices=False)"
]
},
{
"cell_type": "code",
"execution_count": 70,
"metadata": {},
"outputs": [],
"source": [
"S_inv = (1. / S).view(1, -1)  # inverse of a diagonal matrix is just the reciprocal of its diagonal\n",
"VS = Vh.T * S_inv  # scales column j of V by 1 / S[j], i.e. V @ diag(1 / S)\n",
"UtY = U.T @ Y\n",
"betas_svd = VS @ UtY"
]
},
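{
"cell_type": "markdown",
"metadata": {},
"source": [
"When $X$ has full column rank, $V\\Sigma^{-1}U^T$ is the Moore-Penrose pseudoinverse of $X$. One possible cross-check is therefore to compare the hand-built SVD solution with `torch.linalg.pinv`, which by default also computes the pseudoinverse from an SVD. This is a sketch assuming a PyTorch version with `torch.linalg.pinv`, and `betas_pinv` is a name introduced here:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Sketch (assumes torch.linalg.pinv is available)\n",
"# For full-column-rank X, pinv(X) = V @ diag(1 / S) @ U^T\n",
"betas_pinv = torch.linalg.pinv(X) @ Y\n",
"(betas_pinv - betas_svd).abs().max()"
]
},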
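{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Comparing Betas\n",
"\n",
"Each printed row holds one coefficient as estimated by the Cholesky, QR and SVD paths, in that order. In this run the three estimates differ by less than $10^{-3}$."
]
},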
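{
"cell_type": "markdown",
"metadata": {},
"source": [
"A more compact check than reading every printed row is the largest absolute difference between each pair of estimates. The sketch below assumes a PyTorch version with `torch.linalg.cond`, and `estimates` is a name introduced here. It also prints the 2-norm condition numbers of $X$ and $X^TX$. In exact arithmetic the second is the square of the first, which is one reason solving the normal equations can lose more precision than the QR and SVD paths, which work with $X$ itself."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Sketch (assumes torch.linalg.cond is available)\n",
"import itertools\n",
"\n",
"estimates = {'cholesky': betas_cholesky, 'qr': betas_qr, 'svd': betas_svd}\n",
"# Largest absolute difference between each pair of estimates\n",
"for (name_a, a), (name_b, b) in itertools.combinations(estimates.items(), 2):\n",
"    print(name_a, name_b, (a - b).abs().max().item())\n",
"\n",
"# 2-norm condition numbers; in exact arithmetic cond(X^T X) == cond(X) ** 2\n",
"print(torch.linalg.cond(X).item(), torch.linalg.cond(X.T @ X).item())"
]
},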
{
"cell_type": "code",
"execution_count": 75,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"4.789241313934326 4.789309501647949 4.789262294769287\n",
"9.147521018981934 9.14731502532959 9.147207260131836\n",
"5.609707355499268 5.60968017578125 5.609684467315674\n",
"7.863295078277588 7.86319637298584 7.863215446472168\n",
"7.966612815856934 7.966621398925781 7.966643333435059\n",
"0.9358347058296204 0.9359120726585388 0.935868501663208\n",
"6.460849285125732 6.460752487182617 6.46080207824707\n",
"2.361510753631592 2.3616764545440674 2.3616724014282227\n",
"3.4893341064453125 3.4897570610046387 3.489757776260376\n",
"0.40589192509651184 0.4058658182621002 0.4058196544647217\n",
"6.3950676918029785 6.395297527313232 6.395309925079346\n",
"4.863685131072998 4.863367557525635 4.863370418548584\n",
"9.596863746643066 9.596785545349121 9.596807479858398\n",
"3.2956292629241943 3.2953946590423584 3.295461654663086\n",
"0.7079331874847412 0.7079015374183655 0.7079324722290039\n",
"7.927335739135742 7.927225589752197 7.92726993560791\n",
"1.7596901655197144 1.7599283456802368 1.759899616241455\n",
"3.1518752574920654 3.151834487915039 3.1518335342407227\n",
"6.198609352111816 6.198548793792725 6.198549270629883\n",
"2.877772092819214 2.877809762954712 2.877819776535034\n",
"0.1144978478550911 0.11482840776443481 0.11482787132263184\n",
"1.659272313117981 1.6594208478927612 1.6593914031982422\n",
"8.555513381958008 8.555625915527344 8.555619239807129\n",
"6.668213844299316 6.668307304382324 6.668299198150635\n",
"4.201225280761719 4.200991153717041 4.201006889343262\n",
"8.882617950439453 8.882568359375 8.882552146911621\n",
"1.1241998672485352 1.1241655349731445 1.1241430044174194\n",
"9.722527503967285 9.722660064697266 9.722660064697266\n",
"7.96357536315918 7.96376895904541 7.963817119598389\n",
"4.061668395996094 4.061661243438721 4.061673164367676\n",
"0.4477618336677551 0.4477052688598633 0.44772469997406006\n",
"8.820137977600098 8.819872856140137 8.819879531860352\n",
"2.5416290760040283 2.5413544178009033 2.5413389205932617\n",
"4.627619743347168 4.627681732177734 4.627681255340576\n",
"0.7478161454200745 0.7477755546569824 0.7478164434432983\n",
"7.434598922729492 7.434365749359131 7.434381008148193\n",
"5.782316207885742 5.782223701477051 5.7822265625\n",
"7.303110122680664 7.30283260345459 7.302789688110352\n",
"9.179882049560547 9.180000305175781 9.180000305175781\n",
"3.4134645462036133 3.413600206375122 3.41367506980896\n",
"8.272011756896973 8.272011756896973 8.272056579589844\n",
"8.254677772521973 8.254998207092285 8.254950523376465\n",
"4.843873500823975 4.843794822692871 4.8437299728393555\n",
"8.814054489135742 8.814101219177246 8.814138412475586\n",
"3.537532329559326 3.537382125854492 3.537411689758301\n",
"3.009495735168457 3.009181022644043 3.0092389583587646\n",
"2.756869077682495 2.7569947242736816 2.7569937705993652\n",
"1.9333038330078125 1.93343985080719 1.9334814548492432\n",
"9.179861068725586 9.179898262023926 9.179922103881836\n",
"7.7894978523254395 7.789270401000977 7.789253234863281\n",
"4.843989372253418 4.843654155731201 4.843679904937744\n",
"6.666652202606201 6.666760444641113 6.666749477386475\n",
"6.028740882873535 6.028807640075684 6.0288286209106445\n",
"9.00769329071045 9.007808685302734 9.007761001586914\n",
"0.13813860714435577 0.13820038735866547 0.13817358016967773\n",
"9.304071426391602 9.304142951965332 9.304139137268066\n",
"4.384835720062256 4.384726047515869 4.384748458862305\n",
"4.6577630043029785 4.657683849334717 4.657710552215576\n",
"9.897286415100098 9.897216796875 9.897197723388672\n",
"1.960221290588379 1.960351586341858 1.960319995880127\n",
"9.509464263916016 9.509519577026367 9.509477615356445\n",
"1.4898333549499512 1.4899165630340576 1.4899344444274902\n",
"8.848724365234375 8.848979949951172 8.848989486694336\n",
"6.172708511352539 6.172853469848633 6.172830581665039\n",
"7.406793594360352 7.406737804412842 7.406713008880615\n",
"1.8598474264144897 1.8596168756484985 1.8596140146255493\n",
"7.90094518661499 7.901154041290283 7.901152610778809\n",
"1.0882967710494995 1.0882912874221802 1.0882902145385742\n",
"7.895720958709717 7.895816326141357 7.895867347717285\n",
"5.307201385498047 5.307127475738525 5.307128429412842\n",
"1.5453896522521973 1.545369029045105 1.5453554391860962\n",
"7.144125461578369 7.1442389488220215 7.144256591796875\n",
"8.263419151306152 8.26330280303955 8.263258934020996\n",
"4.024933815002441 4.025058269500732 4.025018215179443\n",
"8.888937950134277 8.889105796813965 8.889098167419434\n",
"4.1783928871154785 4.178523063659668 4.178526401519775\n",
"9.600228309631348 9.600030899047852 9.600034713745117\n",
"5.944829940795898 5.944741725921631 5.944738864898682\n",
"5.6164655685424805 5.616551876068115 5.616508483886719\n",
"4.597523212432861 4.597728729248047 4.597701549530029\n",
"7.265918254852295 7.265911102294922 7.265933036804199\n",
"4.980377197265625 4.980417728424072 4.98041296005249\n",
"8.399260520935059 8.399263381958008 8.399232864379883\n",
"2.6021928787231445 2.602515459060669 2.6025009155273438\n",
"6.948577404022217 6.9487457275390625 6.94873046875\n",
"0.7757428884506226 0.7756509780883789 0.7756406664848328\n",
"1.3978784084320068 1.3978636264801025 1.3978710174560547\n",
"8.347896575927734 8.34741497039795 8.347456932067871\n",
"3.9364943504333496 3.93650484085083 3.936490297317505\n",
"5.9613776206970215 5.96109676361084 5.961121082305908\n",
"3.430814743041992 3.430952310562134 3.4309439659118652\n",
"9.821599006652832 9.821892738342285 9.821885108947754\n",
"8.789693832397461 8.789381980895996 8.78941822052002\n",
"3.9870569705963135 3.9870316982269287 3.9870190620422363\n",
"5.418898582458496 5.419239044189453 5.41928768157959\n",
"6.473894119262695 6.474045753479004 6.474076747894287\n",
"3.898988962173462 3.898843288421631 3.8988606929779053\n",
"0.7752768397331238 0.7750235199928284 0.7750210762023926\n",
"9.980328559875488 9.980111122131348 9.98007583618164\n",
"0.684233546257019 0.6842225193977356 0.6841509342193604\n"
]
}
],
"source": [
"for i in range(betas_cholesky.size(0)):\n",
"    print(betas_cholesky[i, 0].item(),\n",
"          betas_qr[i, 0].item(),\n",
"          betas_svd[i, 0].item())"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.6"
}
},
"nbformat": 4,
"nbformat_minor": 2
}