Created
June 27, 2018 15:43
-
-
Save rebeccabilbro/a9a3143ff0b20a51f17b65de6284890e to your computer and use it in GitHub Desktop.
Load the yellowbrick hobbies corpus
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
from sklearn.datasets.base import Bunch | |
from yellowbrick.download import download_all | |
# Root directory that holds the test data sets; resolved against the
# current working directory when this module is executed.
FIXTURES = os.path.join(os.getcwd(), "data")

# Registry of dataset loading locations: dataset name -> corpus directory.
datasets = dict(
    hobbies=os.path.join(FIXTURES, "hobbies"),
)
def load_data(name, download=True):
    """
    Load and wrangle the text corpus registered under ``name``.

    Each immediate subdirectory of the corpus directory is treated as a
    category, and every regular file inside a category directory is read
    as one document belonging to that category.

    Parameters
    ----------
    name : str
        Key into the module-level ``datasets`` mapping (e.g. ``'hobbies'``).
    download : bool, default: True
        If the corpus directory is missing, call ``download_all()`` to fetch
        it; if False, a missing corpus raises ``ValueError`` instead.

    Returns
    -------
    Bunch
        ``categories``: category (subdirectory) names, sorted;
        ``files``: full path of each document;
        ``data``: text of each document;
        ``target``: category name of each document.
        ``files``, ``data`` and ``target`` are aligned index-by-index.

    Raises
    ------
    KeyError
        If ``name`` is not a key of ``datasets``.
    ValueError
        If the corpus directory is missing and ``download`` is False.
    FileNotFoundError
        If the corpus directory still does not exist after ``download_all()``.
    """
    # Get the path from the datasets registry
    path = datasets[name]

    # Check if the data exists, otherwise download or raise
    if not os.path.exists(path):
        if not download:
            raise ValueError((
                "'{}' dataset has not been downloaded, "
                "use the download.py module to fetch datasets"
            ).format(name))
        download_all()
        # Where download_all() writes is not visible here; if it did not
        # populate this path, fail with a clear message instead of letting
        # os.listdir raise an opaque error (same exception type as before).
        if not os.path.exists(path):
            raise FileNotFoundError(
                "'{}' dataset not found at {} after downloading".format(
                    name, path
                )
            )

    # Category directories, sorted so document order is reproducible
    # (os.listdir returns entries in arbitrary order).
    categories = sorted(
        cat for cat in os.listdir(path)
        if os.path.isdir(os.path.join(path, cat))
    )

    files = []   # full path of each document
    data = []    # text read from each document
    target = []  # category name of each document

    # Load the data from the files in the corpus
    for cat in categories:
        cat_dir = os.path.join(path, cat)
        # Use ``fname`` so the ``name`` argument is not shadowed.
        for fname in sorted(os.listdir(cat_dir)):
            fpath = os.path.join(cat_dir, fname)
            # Skip nested directories and other non-regular entries, which
            # open() cannot read as documents.
            if not os.path.isfile(fpath):
                continue
            # NOTE(review): assumes the corpus files are UTF-8 encoded;
            # being explicit avoids depending on the platform default.
            with open(fpath, 'r', encoding='utf-8') as f:
                text = f.read()
            files.append(fpath)
            target.append(cat)
            data.append(text)

    # Return the data bunch for use similar to the newsgroups example
    return Bunch(
        categories=categories,
        files=files,
        data=data,
        target=target,
    )
# Load the hobbies corpus and group its documents by category. Each
# document's whitespace runs are collapsed to single spaces.
corpus = load_data('hobbies')

# Seed every category so the key order follows corpus.categories, and a
# category with no documents still maps to an empty list.
hobby_types = {category: [] for category in corpus.categories}

# One pass over the aligned data/target lists, instead of rescanning the
# whole corpus once per category; per-category document order is preserved.
for text, category in zip(corpus.data, corpus.target):
    hobby_types[category].append(' '.join(text.split()))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment