@gngdb
Last active April 27, 2020 04:03
Least Squares in PyTorch
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Adapting [this excellent blog post](http://drsfenner.org/blog/2015/12/three-paths-to-least-squares-linear-regression/) to show three ways to solve a linear least-squares problem in PyTorch: via the normal equations (the post's Cholesky path), via QR, and via SVD."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import torch"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [],
"source": [
"def create_data(n=5000, p=100):\n",
" ''' \n",
" n is the number of cases/observations/examples\n",
" p is the number of features/attributes/variables\n",
" '''\n",
" X = torch.rand(n,p)*10.\n",
" coeffs = (torch.rand(p)*10.).view(p,1)\n",
" def f(X): return X.mm(coeffs)\n",
"\n",
" noise = torch.randn(n,1)\n",
" Y = f(X) + noise\n",
" Y = Y.view(n,1)\n",
" \n",
" return X,Y"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [],
"source": [
"X,Y = create_data()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Via Cholesky\n",
"\n",
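"This path solves the normal equations $X^T X \\beta = X^T Y$. At the time of writing there was no `torch.solve`, but there was `torch.gesv`, which [solves systems of linear equations](https://pytorch.org/docs/stable/torch.html?highlight=least%20squares#torch.gesv). Despite this section's title, `gesv` uses an LU factorization with partial pivoting rather than Cholesky, though both give the same solution of the normal equations up to rounding. Its arguments are in the reverse order from NumPy's:\n",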
"\n",
"```\n",
"torch.gesv(B,A)\n",
"AX = B\n",
"```\n",
"\n",
"```\n",
"numpy.linalg.solve(a,b)\n",
"ax = b\n",
"```"
]
},
{
"cell_type": "code",
"execution_count": 34,
"metadata": {},
"outputs": [],
"source": [
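"# normal equations: (X^T X) beta = X^T Y; permute(1,0) transposes X\n",
"XtX, XtY = X.permute(1,0).mm(X), X.permute(1,0).mm(Y)\n",
"# gesv(B, A) solves A beta = B; the second output (the LU factorization) is unused\n",
"betas_cholesky, _ = torch.gesv(XtY, XtX)"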
]
},
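{
"cell_type": "markdown",
"metadata": {},
"source": [
"`torch.gesv` was later deprecated, first in favor of `torch.solve` and then of `torch.linalg.solve`, and it is not available in current PyTorch releases. The cell below is a minimal sketch of the same normal-equations solve, assuming a release that provides `torch.linalg.solve` and reusing `X` and `Y` from `create_data`. Note that `torch.linalg.solve(A, B)` takes the coefficient matrix first, in the same order as `numpy.linalg.solve`; the name `betas_linalg_solve` is illustrative only."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"# normal equations: (X^T X) beta = X^T Y, built with .t() instead of .permute(1,0)\n",
"XtX, XtY = X.t().mm(X), X.t().mm(Y)\n",
"\n",
"# assumes torch.linalg.solve exists (newer PyTorch); it solves XtX @ beta = XtY\n",
"betas_linalg_solve = torch.linalg.solve(XtX, XtY)\n",
"betas_linalg_solve.size()"
]
},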
{
"cell_type": "code",
"execution_count": 51,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([100, 1])"
]
},
"execution_count": 51,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"betas_cholesky.size()"
]
},
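{
"cell_type": "markdown",
"metadata": {},
"source": [
"To use an actual Cholesky factorization, as the section title suggests, note that $X^T X$ is symmetric positive definite whenever `X` has full column rank, so it factors as $X^T X = L L^T$ with $L$ lower triangular. A sketch assuming a PyTorch release with `torch.linalg.cholesky` and `torch.cholesky_solve`, again reusing `X` and `Y`; `L_chol` and `betas_chol` are illustrative names:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"# normal equations: (X^T X) beta = X^T Y\n",
"XtX, XtY = X.t().mm(X), X.t().mm(Y)\n",
"\n",
"# lower-triangular factor with XtX = L_chol @ L_chol^T (raises an error if XtX is not positive definite)\n",
"L_chol = torch.linalg.cholesky(XtX)\n",
"\n",
"# cholesky_solve(B, L) solves (L @ L^T) beta = B using the precomputed factor\n",
"betas_chol = torch.cholesky_solve(XtY, L_chol)\n",
"betas_chol.size()"
]
},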
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Via QR\n",
"\n",
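"PyTorch does have [`torch.qr`](https://pytorch.org/docs/stable/torch.html?highlight=qr#torch.qr), but the blog post uses `gels`, the LAPACK routine that solves a full-rank least-squares problem via a QR factorization in a single call. Luckily PyTorch exposes that too, as `torch.gels(B, A)`, again with the arguments reversed (later releases replaced it with `torch.lstsq` and then `torch.linalg.lstsq(A, B)`):"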
]
},
{
"cell_type": "code",
"execution_count": 45,
"metadata": {},
"outputs": [],
"source": [
"betas_qr,_ = torch.gels(Y,X)"
]
},
{
"cell_type": "code",
"execution_count": 46,
"metadata": {},
"outputs": [],
"source": [
"# gels returns max(n, p) rows; only the first p = X.size(1) rows hold the coefficients\n",
"betas_qr = betas_qr[:X.size(1)]"
]
},
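{
"cell_type": "markdown",
"metadata": {},
"source": [
"`gels` hides the QR step. Writing it out makes the method explicit: with the reduced factorization $X = QR$ ($Q$ is $n \\times p$ with orthonormal columns, $R$ is $p \\times p$ upper triangular), the normal equations reduce to the triangular system $R \\beta = Q^T Y$. A sketch assuming a PyTorch release with `torch.linalg.qr` and `torch.linalg.solve_triangular`, reusing `X` and `Y`; `betas_qr_explicit` is an illustrative name:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"# reduced QR (the default mode): Q is n x p, R is p x p upper triangular\n",
"Q, R = torch.linalg.qr(X)\n",
"\n",
"# least squares reduces to R @ beta = Q^T Y, solved by back substitution\n",
"betas_qr_explicit = torch.linalg.solve_triangular(R, Q.t().mm(Y), upper=True)\n",
"betas_qr_explicit.size()"
]
},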
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Via SVD\n",
"\n",
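"The blog post's SVD path uses `lstsq`, which PyTorch did not have at the time of writing. PyTorch does provide an SVD, though, so we can assemble the least-squares solution from it: with the thin SVD $X = U \\Sigma V^T$, the coefficients are $\\beta = V \\Sigma^{-1} U^T Y$."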
]
},
{
"cell_type": "code",
"execution_count": 53,
"metadata": {},
"outputs": [],
"source": [
"# thin SVD: X = U diag(S) V^T; note torch.svd returns V itself, not V^T\n",
"U, S, V = torch.svd(X)"
]
},
{
"cell_type": "code",
"execution_count": 70,
"metadata": {},
"outputs": [],
"source": [
"# inverse of a diagonal matrix is just the reciprocal of its diagonal\n",
"S_inv = (1./S).view(1,S.size(0))\n",
"VS = V*S_inv # broadcasting scales column j of V by 1/S[j], i.e. V @ diag(1/S)\n",
"UtY = torch.mm(U.permute(1,0), Y)\n",
"betas_svd = torch.mm(VS, UtY)"
]
},
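{
"cell_type": "markdown",
"metadata": {},
"source": [
"`torch.svd` returns `V` itself, whereas its newer replacement `torch.linalg.svd` returns `Vh`, the (conjugate) transpose of `V`, matching `numpy.linalg.svd`. A sketch of the same computation as the cell above, $\\beta = V \\Sigma^{-1} U^T Y$, assuming `torch.linalg.svd` is available and reusing `X` and `Y`; the `_l` suffixes and `betas_svd_linalg` are illustrative names:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"# thin SVD: X = U_l @ diag(S_l) @ Vh_l with U_l n x p, S_l of length p, Vh_l p x p\n",
"U_l, S_l, Vh_l = torch.linalg.svd(X, full_matrices=False)\n",
"\n",
"# dividing row j of U^T Y by S_l[j] applies diag(1/S), the inverse of diag(S)\n",
"Sinv_UtY = U_l.t().mm(Y) / S_l.view(-1, 1)\n",
"\n",
"# Vh_l.t() recovers V for real-valued X\n",
"betas_svd_linalg = Vh_l.t().mm(Sinv_UtY)\n",
"betas_svd_linalg.size()"
]
},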
{
"cell_type": "markdown",
"metadata": {},
"source": [
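"# Comparing Betas\n",
"\n",
"Each printed row is one coefficient as estimated by the Cholesky-section (`gesv`), QR (`gels`) and SVD paths, in that order. In this float32 run the three estimates agree to within about $5 \\times 10^{-4}$."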
]
},
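{
"cell_type": "markdown",
"metadata": {},
"source": [
"Scanning 100 rows by eye is tedious. A small sketch that summarizes the agreement instead, printing the largest absolute difference between each pair of solutions computed above; the helper `max_abs_diff` is hypothetical, not part of the original notebook:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def max_abs_diff(a, b):\n",
"    # largest elementwise |a - b| over all coefficients, as a Python float\n",
"    return (a - b).abs().max().item()\n",
"\n",
"print('cholesky vs qr :', max_abs_diff(betas_cholesky, betas_qr))\n",
"print('cholesky vs svd:', max_abs_diff(betas_cholesky, betas_svd))\n",
"print('qr vs svd      :', max_abs_diff(betas_qr, betas_svd))"
]
},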
{
"cell_type": "code",
"execution_count": 75,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"4.789241313934326 4.789309501647949 4.789262294769287\n",
"9.147521018981934 9.14731502532959 9.147207260131836\n",
"5.609707355499268 5.60968017578125 5.609684467315674\n",
"7.863295078277588 7.86319637298584 7.863215446472168\n",
"7.966612815856934 7.966621398925781 7.966643333435059\n",
"0.9358347058296204 0.9359120726585388 0.935868501663208\n",
"6.460849285125732 6.460752487182617 6.46080207824707\n",
"2.361510753631592 2.3616764545440674 2.3616724014282227\n",
"3.4893341064453125 3.4897570610046387 3.489757776260376\n",
"0.40589192509651184 0.4058658182621002 0.4058196544647217\n",
"6.3950676918029785 6.395297527313232 6.395309925079346\n",
"4.863685131072998 4.863367557525635 4.863370418548584\n",
"9.596863746643066 9.596785545349121 9.596807479858398\n",
"3.2956292629241943 3.2953946590423584 3.295461654663086\n",
"0.7079331874847412 0.7079015374183655 0.7079324722290039\n",
"7.927335739135742 7.927225589752197 7.92726993560791\n",
"1.7596901655197144 1.7599283456802368 1.759899616241455\n",
"3.1518752574920654 3.151834487915039 3.1518335342407227\n",
"6.198609352111816 6.198548793792725 6.198549270629883\n",
"2.877772092819214 2.877809762954712 2.877819776535034\n",
"0.1144978478550911 0.11482840776443481 0.11482787132263184\n",
"1.659272313117981 1.6594208478927612 1.6593914031982422\n",
"8.555513381958008 8.555625915527344 8.555619239807129\n",
"6.668213844299316 6.668307304382324 6.668299198150635\n",
"4.201225280761719 4.200991153717041 4.201006889343262\n",
"8.882617950439453 8.882568359375 8.882552146911621\n",
"1.1241998672485352 1.1241655349731445 1.1241430044174194\n",
"9.722527503967285 9.722660064697266 9.722660064697266\n",
"7.96357536315918 7.96376895904541 7.963817119598389\n",
"4.061668395996094 4.061661243438721 4.061673164367676\n",
"0.4477618336677551 0.4477052688598633 0.44772469997406006\n",
"8.820137977600098 8.819872856140137 8.819879531860352\n",
"2.5416290760040283 2.5413544178009033 2.5413389205932617\n",
"4.627619743347168 4.627681732177734 4.627681255340576\n",
"0.7478161454200745 0.7477755546569824 0.7478164434432983\n",
"7.434598922729492 7.434365749359131 7.434381008148193\n",
"5.782316207885742 5.782223701477051 5.7822265625\n",
"7.303110122680664 7.30283260345459 7.302789688110352\n",
"9.179882049560547 9.180000305175781 9.180000305175781\n",
"3.4134645462036133 3.413600206375122 3.41367506980896\n",
"8.272011756896973 8.272011756896973 8.272056579589844\n",
"8.254677772521973 8.254998207092285 8.254950523376465\n",
"4.843873500823975 4.843794822692871 4.8437299728393555\n",
"8.814054489135742 8.814101219177246 8.814138412475586\n",
"3.537532329559326 3.537382125854492 3.537411689758301\n",
"3.009495735168457 3.009181022644043 3.0092389583587646\n",
"2.756869077682495 2.7569947242736816 2.7569937705993652\n",
"1.9333038330078125 1.93343985080719 1.9334814548492432\n",
"9.179861068725586 9.179898262023926 9.179922103881836\n",
"7.7894978523254395 7.789270401000977 7.789253234863281\n",
"4.843989372253418 4.843654155731201 4.843679904937744\n",
"6.666652202606201 6.666760444641113 6.666749477386475\n",
"6.028740882873535 6.028807640075684 6.0288286209106445\n",
"9.00769329071045 9.007808685302734 9.007761001586914\n",
"0.13813860714435577 0.13820038735866547 0.13817358016967773\n",
"9.304071426391602 9.304142951965332 9.304139137268066\n",
"4.384835720062256 4.384726047515869 4.384748458862305\n",
"4.6577630043029785 4.657683849334717 4.657710552215576\n",
"9.897286415100098 9.897216796875 9.897197723388672\n",
"1.960221290588379 1.960351586341858 1.960319995880127\n",
"9.509464263916016 9.509519577026367 9.509477615356445\n",
"1.4898333549499512 1.4899165630340576 1.4899344444274902\n",
"8.848724365234375 8.848979949951172 8.848989486694336\n",
"6.172708511352539 6.172853469848633 6.172830581665039\n",
"7.406793594360352 7.406737804412842 7.406713008880615\n",
"1.8598474264144897 1.8596168756484985 1.8596140146255493\n",
"7.90094518661499 7.901154041290283 7.901152610778809\n",
"1.0882967710494995 1.0882912874221802 1.0882902145385742\n",
"7.895720958709717 7.895816326141357 7.895867347717285\n",
"5.307201385498047 5.307127475738525 5.307128429412842\n",
"1.5453896522521973 1.545369029045105 1.5453554391860962\n",
"7.144125461578369 7.1442389488220215 7.144256591796875\n",
"8.263419151306152 8.26330280303955 8.263258934020996\n",
"4.024933815002441 4.025058269500732 4.025018215179443\n",
"8.888937950134277 8.889105796813965 8.889098167419434\n",
"4.1783928871154785 4.178523063659668 4.178526401519775\n",
"9.600228309631348 9.600030899047852 9.600034713745117\n",
"5.944829940795898 5.944741725921631 5.944738864898682\n",
"5.6164655685424805 5.616551876068115 5.616508483886719\n",
"4.597523212432861 4.597728729248047 4.597701549530029\n",
"7.265918254852295 7.265911102294922 7.265933036804199\n",
"4.980377197265625 4.980417728424072 4.98041296005249\n",
"8.399260520935059 8.399263381958008 8.399232864379883\n",
"2.6021928787231445 2.602515459060669 2.6025009155273438\n",
"6.948577404022217 6.9487457275390625 6.94873046875\n",
"0.7757428884506226 0.7756509780883789 0.7756406664848328\n",
"1.3978784084320068 1.3978636264801025 1.3978710174560547\n",
"8.347896575927734 8.34741497039795 8.347456932067871\n",
"3.9364943504333496 3.93650484085083 3.936490297317505\n",
"5.9613776206970215 5.96109676361084 5.961121082305908\n",
"3.430814743041992 3.430952310562134 3.4309439659118652\n",
"9.821599006652832 9.821892738342285 9.821885108947754\n",
"8.789693832397461 8.789381980895996 8.78941822052002\n",
"3.9870569705963135 3.9870316982269287 3.9870190620422363\n",
"5.418898582458496 5.419239044189453 5.41928768157959\n",
"6.473894119262695 6.474045753479004 6.474076747894287\n",
"3.898988962173462 3.898843288421631 3.8988606929779053\n",
"0.7752768397331238 0.7750235199928284 0.7750210762023926\n",
"9.980328559875488 9.980111122131348 9.98007583618164\n",
"0.684233546257019 0.6842225193977356 0.6841509342193604\n"
]
}
],
"source": [
"for i in range(betas_cholesky.size(0)):\n",
" print(betas_cholesky[i,0].item(),\n",
" betas_qr[i,0].item(),\n",
" betas_svd[i,0].item())"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.6"
}
},
"nbformat": 4,
"nbformat_minor": 2
}