Created April 13, 2022 00:20
-
-
Save AnirudhDagar/d674419224adf608ec76d3c0b9ce1f00 to your computer and use it in GitHub Desktop.
This script adds a #@tab mxnet mark to every Python code cell that is missing a tab mark, so that PyTorch is picked up by default when the order of tabs is changed in config.ini.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import glob
import os
import re

import nbconvert
import nbformat
from d2lbook import notebook
def find_files(pattern='*/*.md'):
    """
    Find the markdown (.md) source files to process.

    Args:
        pattern: glob pattern, relative to the current working directory.
            The default '*/*.md' matches .md files exactly one directory
            level below the current directory (e.g. 'chapter_x/section.md').
            Files directly in the current directory, or nested deeper, are
            not matched by the default. Because glob is called with
            recursive=True, a pattern containing '**' (e.g. '**/*.md')
            matches at any depth.

    Returns:
        A sorted list of matching paths that are regular files. Sorting
        makes the processing order (and therefore the log output)
        deterministic; raw glob order is arbitrary.
    """
    matches = glob.glob(pattern, recursive=True)
    # glob also returns directories whose names end in ".md"; keep files only.
    return sorted(fn for fn in matches if os.path.isfile(fn))
def _mark_untagged_cells(content):
    """
    Prepend "#@tab mxnet" to every untagged Python input cell, in place.

    Args:
        content: the lines of one markdown file, as returned by readlines().
            The list is modified in place. The mark is prepended to the
            cell's first line string, so the list length does not change.

    A cell is a line starting with "```{.python .input" followed by its
    first body line. The cell is left untouched when that body line already
    starts with "#@tab", "%%tab" or "%load_ext".

    Returns:
        The number of cells that were marked.
    """
    marked = 0
    for idx, line in enumerate(content):
        # A fence on the very last line has no body line to inspect;
        # indexing content[idx + 1] there would raise IndexError.
        if idx + 1 == len(content) or not line.startswith("```{.python .input"):
            continue
        body = content[idx + 1]
        # Cells that already carry a tab mark, or that start with a
        # "%%tab"/"%load_ext" magic, are skipped.
        # NOTE(review): untagged cells always receive the "#@tab mxnet" form.
        # If untagged cells in files written in the "%%tab" magic style should
        # instead receive "%%tab mxnet", that needs a per-file check; confirm
        # the intent.
        if body.startswith(("#@tab", "%%tab", "%load_ext")):
            continue
        print("MXNet Cell Detected...")
        content[idx + 1] = "#@tab mxnet\n" + body
        marked += 1
    return marked


def add_mxnet_mark(markdowns):
    """
    Add a "#@tab mxnet" mark to untagged Python input cells and save the files.

    Args:
        markdowns: iterable of paths to markdown files (e.g. the result of
            find_files()).

    Returns:
        The total number of cells marked across all files. Running the
        function again on the same files marks nothing further, because
        marked cells then start with "#@tab".
    """
    total = 0
    for fn in markdowns:
        # Explicit UTF-8: the files may contain non-ASCII text, and the
        # platform default encoding is not UTF-8 everywhere (e.g. Windows).
        with open(fn, 'r', encoding='utf-8') as f:
            content = f.readlines()
        marked = _mark_untagged_cells(content)
        if marked:
            # Rewrite only files that actually changed, so untouched files
            # keep their original bytes and modification time.
            with open(fn, 'w', encoding='utf-8') as f:
                f.writelines(content)
        total += marked
    return total
if __name__ == "__main__":
    # Script entry point: mark every untagged Python input cell in the
    # .md files returned by find_files() (paths relative to the cwd).
    print("Adding #@tab mxnet mark in notebooks...")
    add_mxnet_mark(find_files())
    print("Done!")
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.