JAX training steps progress:
0) loss = 32.684505462646484 grad_norm = 46.189605712890625
1) loss = 32.684505462646484 grad_norm = 46.189605712890625
2) loss = 32.2853889465332 grad_norm = 47.51691818237305
3) loss = 31.46957778930664 grad_norm = 49.439979553222656
4) loss = 30.260791778564453 grad_norm = 53.63009262084961
5) loss = 28.776248931884766 grad_norm = 58.055015563964844
6) loss = 27.165119171142578 grad_norm = 73.13194274902344
7) loss = 25.350372314453125 grad_norm = 83.24612426757812
8) loss = 23.403196334838867 grad_norm = 82.15689086914062
9) loss = 21.36658477783203 grad_norm = 66.725830078125
10) loss = 19.72756576538086 grad_norm = 52.60395812988281
11) loss = 18.716609954833984 grad_norm = 13.802923202514648
12) loss = 19.0362491607666 grad_norm = 39.36472702026367
13) loss = 19.368528366088867 grad_norm = 49.91731262207031
14) loss = 19.246591567993164 grad_norm = 49.75506591796875
15) loss = 18.68264389038086 grad_norm = 42.25536346435547
16) loss = 17.978076934814453 grad_norm = 27.26990509033203
17) loss = 17.526891708374023 grad_norm = 17.5865421295166
18) loss = 17.449459075927734 grad_norm = 32.81880187988281
19) loss = 17.382862091064453 grad_norm = 43.94654083251953
20) loss = 17.09641456604004 grad_norm = 47.34761428833008
21) loss = 16.57782745361328 grad_norm = 45.367431640625
22) loss = 15.897388458251953 grad_norm = 38.91270065307617
23) loss = 15.132328033447266 grad_norm = 30.798185348510742
24) loss = 14.477874755859375 grad_norm = 21.245559692382812
25) loss = 13.980663299560547 grad_norm = 21.04851531982422
26) loss = 13.48892593383789 grad_norm = 22.92694091796875
27) loss = 12.980422973632812 grad_norm = 23.357032775878906
28) loss = 12.4077730178833 grad_norm = 23.73438262939453
29) loss = 11.470919609069824 grad_norm = 44.06322479248047
30) loss = 9.974434852600098 grad_norm = 62.56569290161133
31) loss = 8.121500015258789 grad_norm = 32.62065505981445
32) loss = 7.7279953956604 grad_norm = 12.60580062866211
33) loss = 8.078917503356934 grad_norm = 20.27943229675293
34) loss = 8.428027153015137 grad_norm = 22.718721389770508
35) loss = 8.606433868408203 grad_norm = 23.325977325439453
36) loss = 8.611522674560547 grad_norm = 23.37453269958496
37) loss = 8.46786880493164 grad_norm = 23.161762237548828
38) loss = 8.200492858886719 grad_norm = 22.680795669555664
39) loss = 7.833049774169922 grad_norm = 21.734981536865234
40) loss = 7.39454984664917 grad_norm = 19.83635711669922
41) loss = 6.936933994293213 grad_norm = 15.858280181884766
42) loss = 6.5785231590271 grad_norm = 7.353372097015381
43) loss = 6.590823650360107 grad_norm = 10.804391860961914
44) loss = 6.976533889770508 grad_norm = 28.068567276000977
45) loss = 7.364764213562012 grad_norm = 39.22932052612305
46) loss = 7.52872896194458 grad_norm = 43.379493713378906
47) loss = 7.444374084472656 grad_norm = 42.099246978759766
48) loss = 7.193683624267578 grad_norm = 37.185791015625
49) loss = 6.878286361694336 grad_norm = 30.11634635925293
50) loss = 6.578378677368164 grad_norm = 22.0616512298584
51) loss = 6.342897891998291 grad_norm = 13.930447578430176
52) loss = 6.193991184234619 grad_norm = 6.407179355621338
53) loss = 6.135108470916748 grad_norm = 1.4418846368789673
54) loss = 6.147031784057617 grad_norm = 5.338115692138672
55) loss = 6.180387496948242 grad_norm = 8.108418464660645
56) loss = 6.207060813903809 grad_norm = 9.701889038085938
57) loss = 6.215665817260742 grad_norm = 10.471610069274902
58) loss = 6.202969074249268 grad_norm = 10.607603073120117
59) loss = 6.170167922973633 grad_norm = 10.189746856689453
60) loss = 6.121431827545166 grad_norm = 9.220542907714844
61) loss = 6.063607692718506 grad_norm = 7.641906261444092
62) loss = 6.006581783294678 grad_norm = 5.349120140075684
63) loss = 5.963957786560059 grad_norm = 2.2978594303131104
64) loss = 5.951851844787598 grad_norm = 2.2282962799072266
65) loss = 5.971628665924072 grad_norm = 5.987220764160156
66) loss = 6.005858421325684 grad_norm = 9.195704460144043
67) loss = 6.038926124572754 grad_norm = 11.548829078674316
68) loss = 6.059482097625732 grad_norm = 12.938092231750488
69) loss = 6.061285972595215 grad_norm = 13.334378242492676
70) loss = 6.043198108673096 grad_norm = 12.766958236694336
71) loss = 6.008614540100098 grad_norm = 11.311990737915039
72) loss = 5.964560508728027 grad_norm = 9.086621284484863
73) loss = 5.920508861541748 grad_norm = 6.246153354644775
74) loss = 5.887033462524414 grad_norm = 2.993448495864868
75) loss = 5.874838352203369 grad_norm = 0.8009199500083923
76) loss = 5.884912014007568 grad_norm = 3.272224187850952
77) loss = 5.907811641693115 grad_norm = 5.344062328338623
78) loss = 5.9315948486328125 grad_norm = 6.702033996582031
79) loss = 5.949487686157227 grad_norm = 7.477400779724121
80) loss = 5.958114147186279 grad_norm = 7.759702205657959
81) loss = 5.956571578979492 grad_norm = 7.596379280090332
82) loss = 5.945903301239014 grad_norm = 6.9987993240356445
83) loss = 5.928875923156738 grad_norm = 5.9481329917907715
84) loss = 5.909940242767334 grad_norm = 4.404060363769531
85) loss = 5.895603179931641 grad_norm = 2.398494243621826
86) loss = 5.892603874206543 grad_norm = 0.8986305594444275
87) loss = 5.902677536010742 grad_norm = 2.8863940238952637
88) loss = 5.920483589172363 grad_norm = 5.06048059463501
89) loss = 5.935702800750732 grad_norm = 6.575645446777344
90) loss = 5.942634105682373 grad_norm = 7.361220836639404
91) loss = 5.9386491775512695 grad_norm = 7.407079219818115
92) loss = 5.924259185791016 grad_norm = 6.745035171508789
93) loss = 5.902726173400879 grad_norm = 5.442744731903076
94) loss = 5.87941837310791 grad_norm = 3.6016411781311035
95) loss = 5.86182165145874 grad_norm = 1.5386890172958374
96) loss = 5.854459762573242 grad_norm = 0.8638936281204224
97) loss = 5.856882095336914 grad_norm = 2.5303218364715576
98) loss = 5.86497688293457 grad_norm = 3.910663366317749
99) loss = 5.873371601104736 grad_norm = 4.834494590759277
JAX program execution took 140.81151580810547 seconds
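
A minimal JAX sketch of the kind of training step that would produce the per-step loss/grad_norm trace above. This is not the original script: the model, optimizer, and data are not shown in this gist, so the toy regression loss, the optax Adam optimizer, and the optax.global_norm gradient norm here are all assumptions.

import jax
import jax.numpy as jnp
import optax  # assumed optimizer library; not confirmed by the log

def loss_fn(params, batch):
    # Toy regression loss standing in for the real (unknown) model's loss.
    inputs, targets = batch
    preds = inputs @ params["w"] + params["b"]
    return jnp.mean((preds - targets) ** 2)

optimizer = optax.adam(1e-3)  # assumed optimizer and learning rate

@jax.jit
def train_step(params, opt_state, batch):
    # One jitted step: forward + backward, log stats, apply the update.
    loss, grads = jax.value_and_grad(loss_fn)(params, batch)
    grad_norm = optax.global_norm(grads)  # L2 norm over all gradient leaves
    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss, grad_norm

params = {"w": jnp.zeros((16, 1)), "b": jnp.zeros((1,))}
opt_state = optimizer.init(params)
batch = (jnp.ones((8, 16)), jnp.ones((8, 1)))  # fixed dummy batch
for step in range(100):
    params, opt_state, loss, grad_norm = train_step(params, opt_state, batch)
    print(f"{step}) loss = {loss} grad_norm = {grad_norm}")
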
PyTorch training steps progress:
I1111 18:56:44.580436 140143604311872 torch_e2e.py:143] 0) loss = 32.799583435058594, grad_norm = 5.480198383331299
I1111 18:56:44.602942 140143604311872 distributed.py:995] Reducer buckets have been rebuilt in this iteration.
I1111 18:56:44.602947 140194652907328 distributed.py:995] Reducer buckets have been rebuilt in this iteration.
I1111 18:56:44.602951 140007166039872 distributed.py:995] Reducer buckets have been rebuilt in this iteration.
I1111 18:56:44.604685 140110968313664 distributed.py:995] Reducer buckets have been rebuilt in this iteration.
I1111 18:56:44.604720 140128537937728 distributed.py:995] Reducer buckets have been rebuilt in this iteration.
I1111 18:56:44.604736 139625133123392 distributed.py:995] Reducer buckets have been rebuilt in this iteration.
I1111 18:56:44.604791 140694398179136 distributed.py:995] Reducer buckets have been rebuilt in this iteration.
I1111 18:56:44.604785 140680726771520 distributed.py:995] Reducer buckets have been rebuilt in this iteration.
I1111 18:56:45.760965 140143604311872 torch_e2e.py:143] 1) loss = 32.81392288208008, grad_norm = 5.564063549041748
I1111 18:56:46.859724 140143604311872 torch_e2e.py:143] 2) loss = 32.78697204589844, grad_norm = 5.491509914398193
I1111 18:56:47.799092 140143604311872 torch_e2e.py:143] 3) loss = 32.7858772277832, grad_norm = 5.60224723815918
I1111 18:56:48.779518 140143604311872 torch_e2e.py:143] 4) loss = 32.76042175292969, grad_norm = 5.729595184326172
I1111 18:56:49.569550 140143604311872 torch_e2e.py:143] 5) loss = 32.7255973815918, grad_norm = 5.890851020812988
I1111 18:56:50.360571 140143604311872 torch_e2e.py:143] 6) loss = 32.686431884765625, grad_norm = 6.0222039222717285
I1111 18:56:51.153121 140143604311872 torch_e2e.py:143] 7) loss = 32.63671875, grad_norm = 6.230602264404297
I1111 18:56:51.946784 140143604311872 torch_e2e.py:143] 8) loss = 32.585205078125, grad_norm = 6.409563064575195
I1111 18:56:52.741659 140143604311872 torch_e2e.py:143] 9) loss = 32.50416564941406, grad_norm = 6.716286659240723
I1111 18:56:53.535360 140143604311872 torch_e2e.py:143] 10) loss = 32.42424774169922, grad_norm = 7.101415634155273
I1111 18:56:54.330379 140143604311872 torch_e2e.py:143] 11) loss = 32.317264556884766, grad_norm = 7.50632381439209
I1111 18:56:55.125588 140143604311872 torch_e2e.py:143] 12) loss = 32.208160400390625, grad_norm = 7.92339563369751
I1111 18:56:55.921339 140143604311872 torch_e2e.py:143] 13) loss = 32.077720642089844, grad_norm = 8.465384483337402
I1111 18:56:56.714532 140143604311872 torch_e2e.py:143] 14) loss = 31.917701721191406, grad_norm = 9.081777572631836
I1111 18:56:57.510020 140143604311872 torch_e2e.py:143] 15) loss = 31.73249053955078, grad_norm = 9.849095344543457
I1111 18:56:58.305579 140143604311872 torch_e2e.py:143] 16) loss = 31.553165435791016, grad_norm = 10.648540496826172
I1111 18:56:59.098185 140143604311872 torch_e2e.py:143] 17) loss = 31.305742263793945, grad_norm = 11.513188362121582
I1111 18:56:59.893549 140143604311872 torch_e2e.py:143] 18) loss = 31.056217193603516, grad_norm = 12.512486457824707
I1111 18:57:00.690072 140143604311872 torch_e2e.py:143] 19) loss = 30.75790786743164, grad_norm = 13.616098403930664
I1111 18:57:01.483093 140143604311872 torch_e2e.py:143] 20) loss = 30.423843383789062, grad_norm = 14.62498950958252
I1111 18:57:02.277024 140143604311872 torch_e2e.py:143] 21) loss = 30.038217544555664, grad_norm = 15.507715225219727
I1111 18:57:03.072989 140143604311872 torch_e2e.py:143] 22) loss = 29.629596710205078, grad_norm = 16.433927536010742
I1111 18:57:03.873271 140143604311872 torch_e2e.py:143] 23) loss = 29.170137405395508, grad_norm = 17.48212432861328
I1111 18:57:04.667958 140143604311872 torch_e2e.py:143] 24) loss = 28.667102813720703, grad_norm = 18.283262252807617
I1111 18:57:05.461400 140143604311872 torch_e2e.py:143] 25) loss = 28.125272750854492, grad_norm = 19.037086486816406
I1111 18:57:06.256874 140143604311872 torch_e2e.py:143] 26) loss = 27.546295166015625, grad_norm = 19.721418380737305
I1111 18:57:07.051243 140143604311872 torch_e2e.py:143] 27) loss = 26.921031951904297, grad_norm = 20.364442825317383
I1111 18:57:07.846579 140143604311872 torch_e2e.py:143] 28) loss = 26.258169174194336, grad_norm = 20.912199020385742
I1111 18:57:08.642325 140143604311872 torch_e2e.py:143] 29) loss = 25.564598083496094, grad_norm = 21.407268524169922
I1111 18:57:09.437666 140143604311872 torch_e2e.py:143] 30) loss = 24.84899139404297, grad_norm = 21.567232131958008
I1111 18:57:10.234629 140143604311872 torch_e2e.py:143] 31) loss = 24.078826904296875, grad_norm = 21.630910873413086
I1111 18:57:11.031373 140143604311872 torch_e2e.py:143] 32) loss = 23.287574768066406, grad_norm = 21.588590621948242
I1111 18:57:11.826267 140143604311872 torch_e2e.py:143] 33) loss = 22.48159408569336, grad_norm = 21.48728370666504
I1111 18:57:12.623108 140143604311872 torch_e2e.py:143] 34) loss = 21.630659103393555, grad_norm = 21.31056022644043
I1111 18:57:13.418298 140143604311872 torch_e2e.py:143] 35) loss = 20.80036163330078, grad_norm = 21.08628273010254
I1111 18:57:14.212611 140143604311872 torch_e2e.py:143] 36) loss = 19.9207820892334, grad_norm = 20.772138595581055
I1111 18:57:15.009569 140143604311872 torch_e2e.py:143] 37) loss = 19.082284927368164, grad_norm = 20.419795989990234
I1111 18:57:15.806542 140143604311872 torch_e2e.py:143] 38) loss = 18.18767738342285, grad_norm = 19.978591918945312
I1111 18:57:16.604197 140143604311872 torch_e2e.py:143] 39) loss = 17.313161849975586, grad_norm = 19.488441467285156
I1111 18:57:17.398409 140143604311872 torch_e2e.py:143] 40) loss = 16.46230125427246, grad_norm = 18.933177947998047
I1111 18:57:18.195473 140143604311872 torch_e2e.py:143] 41) loss = 15.606142044067383, grad_norm = 18.297391891479492
I1111 18:57:18.992078 140143604311872 torch_e2e.py:143] 42) loss = 14.758169174194336, grad_norm = 17.583349227905273
I1111 18:57:19.790554 140143604311872 torch_e2e.py:143] 43) loss = 13.962091445922852, grad_norm = 16.807153701782227
I1111 18:57:20.587406 140143604311872 torch_e2e.py:143] 44) loss = 13.139028549194336, grad_norm = 15.904189109802246
I1111 18:57:21.384323 140143604311872 torch_e2e.py:143] 45) loss = 12.383018493652344, grad_norm = 14.943195343017578
I1111 18:57:22.180801 140143604311872 torch_e2e.py:143] 46) loss = 11.665107727050781, grad_norm = 13.905163764953613
I1111 18:57:22.979599 140143604311872 torch_e2e.py:143] 47) loss = 10.987071990966797, grad_norm = 12.793498039245605
I1111 18:57:23.774337 140143604311872 torch_e2e.py:143] 48) loss = 10.358452796936035, grad_norm = 11.621541976928711
I1111 18:57:24.569179 140143604311872 torch_e2e.py:143] 49) loss = 9.786215782165527, grad_norm = 10.415675163269043
I1111 18:57:25.367716 140143604311872 torch_e2e.py:143] 50) loss = 9.264466285705566, grad_norm = 9.165142059326172
I1111 18:57:26.164860 140143604311872 torch_e2e.py:143] 51) loss = 8.827374458312988, grad_norm = 7.97666072845459
I1111 18:57:26.960957 140143604311872 torch_e2e.py:143] 52) loss = 8.432060241699219, grad_norm = 6.76027250289917
I1111 18:57:27.761851 140143604311872 torch_e2e.py:143] 53) loss = 8.101480484008789, grad_norm = 5.599133491516113
I1111 18:57:28.555984 140143604311872 torch_e2e.py:143] 54) loss = 7.838068008422852, grad_norm = 4.5281829833984375
I1111 18:57:29.352050 140143604311872 torch_e2e.py:143] 55) loss = 7.619318008422852, grad_norm = 3.4904284477233887
I1111 18:57:30.147860 140143604311872 torch_e2e.py:143] 56) loss = 7.455406665802002, grad_norm = 2.54789662361145
I1111 18:57:30.943805 140143604311872 torch_e2e.py:143] 57) loss = 7.346614360809326, grad_norm = 1.7550326585769653
I1111 18:57:31.742405 140143604311872 torch_e2e.py:143] 58) loss = 7.276298999786377, grad_norm = 1.0543296337127686
I1111 18:57:32.537595 140143604311872 torch_e2e.py:143] 59) loss = 7.240423679351807, grad_norm = 0.5265750885009766
I1111 18:57:33.333628 140143604311872 torch_e2e.py:143] 60) loss = 7.236795425415039, grad_norm = 0.480295866727829
I1111 18:57:34.131216 140143604311872 torch_e2e.py:143] 61) loss = 7.256988048553467, grad_norm = 0.8030412793159485
I1111 18:57:34.932935 140143604311872 torch_e2e.py:143] 62) loss = 7.296460151672363, grad_norm = 1.1590831279754639
I1111 18:57:35.725716 140143604311872 torch_e2e.py:143] 63) loss = 7.349276065826416, grad_norm = 1.4545279741287231
I1111 18:57:36.523608 140143604311872 torch_e2e.py:143] 64) loss = 7.411477565765381, grad_norm = 1.7063922882080078
I1111 18:57:37.323120 140143604311872 torch_e2e.py:143] 65) loss = 7.481031894683838, grad_norm = 1.9173399209976196
I1111 18:57:38.120862 140143604311872 torch_e2e.py:143] 66) loss = 7.552356719970703, grad_norm = 2.088557004928589
I1111 18:57:38.916439 140143604311872 torch_e2e.py:143] 67) loss = 7.622851848602295, grad_norm = 2.2259576320648193
I1111 18:57:39.713130 140143604311872 torch_e2e.py:143] 68) loss = 7.695225715637207, grad_norm = 2.343404769897461
I1111 18:57:40.512876 140143604311872 torch_e2e.py:143] 69) loss = 7.762754440307617, grad_norm = 2.4399759769439697
I1111 18:57:41.310517 140143604311872 torch_e2e.py:143] 70) loss = 7.825161457061768, grad_norm = 2.5152499675750732
I1111 18:57:42.105275 140143604311872 torch_e2e.py:143] 71) loss = 7.885710716247559, grad_norm = 2.581749677658081
I1111 18:57:42.904766 140143604311872 torch_e2e.py:143] 72) loss = 7.9366374015808105, grad_norm = 2.6322216987609863
I1111 18:57:43.701193 140143604311872 torch_e2e.py:143] 73) loss = 7.984954357147217, grad_norm = 2.6764440536499023
I1111 18:57:44.499818 140143604311872 torch_e2e.py:143] 74) loss = 8.027819633483887, grad_norm = 2.7141380310058594
I1111 18:57:45.294919 140143604311872 torch_e2e.py:143] 75) loss = 8.060647964477539, grad_norm = 2.740922689437866
I1111 18:57:46.094676 140143604311872 torch_e2e.py:143] 76) loss = 8.086572647094727, grad_norm = 2.7629334926605225
I1111 18:57:46.896962 140143604311872 torch_e2e.py:143] 77) loss = 8.109407424926758, grad_norm = 2.783541679382324
I1111 18:57:47.694501 140143604311872 torch_e2e.py:143] 78) loss = 8.125061988830566, grad_norm = 2.797971248626709
I1111 18:57:48.491031 140143604311872 torch_e2e.py:143] 79) loss = 8.133414268493652, grad_norm = 2.808274745941162
I1111 18:57:49.288383 140143604311872 torch_e2e.py:143] 80) loss = 8.135396957397461, grad_norm = 2.8145511150360107
I1111 18:57:50.087886 140143604311872 torch_e2e.py:143] 81) loss = 8.132173538208008, grad_norm = 2.819377899169922
I1111 18:57:50.885110 140143604311872 torch_e2e.py:143] 82) loss = 8.122648239135742, grad_norm = 2.819749116897583
I1111 18:57:51.681151 140143604311872 torch_e2e.py:143] 83) loss = 8.107420921325684, grad_norm = 2.8163645267486572
I1111 18:57:52.479457 140143604311872 torch_e2e.py:143] 84) loss = 8.086889266967773, grad_norm = 2.8111491203308105
I1111 18:57:53.276332 140143604311872 torch_e2e.py:143] 85) loss = 8.05954647064209, grad_norm = 2.799006462097168
I1111 18:57:54.073843 140143604311872 torch_e2e.py:143] 86) loss = 8.027091026306152, grad_norm = 2.784994125366211
I1111 18:57:54.869232 140143604311872 torch_e2e.py:143] 87) loss = 7.9884209632873535, grad_norm = 2.764768123626709
I1111 18:57:55.669992 140143604311872 torch_e2e.py:143] 88) loss = 7.947974681854248, grad_norm = 2.742222785949707
I1111 18:57:56.470902 140143604311872 torch_e2e.py:143] 89) loss = 7.900345802307129, grad_norm = 2.7122137546539307
I1111 18:57:57.267645 140143604311872 torch_e2e.py:143] 90) loss = 7.849843978881836, grad_norm = 2.6777143478393555
I1111 18:57:58.063938 140143604311872 torch_e2e.py:143] 91) loss = 7.796028137207031, grad_norm = 2.635615825653076
I1111 18:57:58.861551 140143604311872 torch_e2e.py:143] 92) loss = 7.7386155128479, grad_norm = 2.585268497467041
I1111 18:57:59.659476 140143604311872 torch_e2e.py:143] 93) loss = 7.676787376403809, grad_norm = 2.5245890617370605
I1111 18:58:00.456698 140143604311872 torch_e2e.py:143] 94) loss = 7.611974716186523, grad_norm = 2.4550864696502686
I1111 18:58:01.255294 140143604311872 torch_e2e.py:143] 95) loss = 7.5449748039245605, grad_norm = 2.3730602264404297
I1111 18:58:02.051496 140143604311872 torch_e2e.py:143] 96) loss = 7.480135440826416, grad_norm = 2.2816708087921143
I1111 18:58:02.850461 140143604311872 torch_e2e.py:143] 97) loss = 7.412574768066406, grad_norm = 2.1708152294158936
I1111 18:58:03.647884 140143604311872 torch_e2e.py:143] 98) loss = 7.346048355102539, grad_norm = 2.0456466674804688
PyTorch program execution took 83.153489112854 seconds
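
For comparison, a minimal PyTorch sketch of a training step that logs loss and gradient norm in the same format. This is not the original torch_e2e.py (the real model, optimizer, and distributed launch are not shown in this gist); the Linear model, Adam optimizer, and clip_grad_norm_-based norm are assumptions. The "Reducer buckets have been rebuilt" messages in the log come from DistributedDataParallel's first backward pass on each of the eight ranks.

import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(16, 1)  # placeholder; the real model is not shown
# The multi-process run above wraps the model in DDP, e.g.
#   model = nn.parallel.DistributedDataParallel(model)
# after torch.distributed.init_process_group(); omitted here so the
# sketch runs as a single process.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)  # assumed optimizer
loss_fn = nn.MSELoss()

inputs, targets = torch.ones(8, 16), torch.ones(8, 1)  # fixed dummy batch
for step in range(100):
    optimizer.zero_grad()
    loss = loss_fn(model(inputs), targets)
    loss.backward()
    # clip_grad_norm_ returns the total pre-clipping gradient norm; using it
    # with a huge max_norm is one plausible way the logged grad_norm was
    # computed (an assumption, not confirmed by the log).
    grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1e9)
    optimizer.step()
    print(f"{step}) loss = {loss.item()}, grad_norm = {grad_norm.item()}")
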