Last active
February 26, 2020 08:55
-
-
Save ramcandrews/ddf3975cc2aea805a3fe580a33c3f5bd to your computer and use it in GitHub Desktop.
Batch data into chunks using the PyTorch TensorDataset and DataLoader classes
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from torch.utils.data import TensorDataset, DataLoader | |
import torch | |
# Record whether CUDA is usable so the rest of the script can branch on it.
train_on_gpu = torch.cuda.is_available()
# Only a warning is printed when no GPU is present; execution continues.
if train_on_gpu is False:
    print('No GPU found. Please use a GPU to train your neural network.')
def batch_data(words, sequence_length, batch_size):
    """
    Batch the neural network data using DataLoader.

    A window of ``sequence_length`` word ids is slid over ``words`` one
    position at a time. Each window becomes one feature row, and the word id
    immediately following the window becomes that row's target.

    :param words: The word ids of the TV scripts; any sequence supporting
        ``len()``, indexing and slicing (e.g. ``list`` or ``range``)
    :param sequence_length: The number of word ids in each feature sequence
    :param batch_size: The size of each batch; the number of sequences in a batch
    :return: DataLoader yielding ``(features, targets)`` pairs, where features
        have shape ``(batch, sequence_length)`` and targets have shape
        ``(batch,)``, both of dtype ``torch.long``. The final batch may be
        smaller than ``batch_size``. If ``words`` is too short to form a
        single window plus target, the DataLoader is empty.
    :raises ValueError: if ``sequence_length`` is less than 1
    """
    if sequence_length < 1:
        raise ValueError("sequence_length must be at least 1")

    # A window is only usable when at least one word follows it to act as the
    # target, hence len(words) - sequence_length windows (never negative).
    n_windows = max(len(words) - sequence_length, 0)

    features = [
        list(words[start:start + sequence_length])
        for start in range(n_windows)
    ]
    targets = [words[start + sequence_length] for start in range(n_windows)]

    # Word ids are integer indices, so store them as int64 rather than float.
    # The reshape keeps a (0, sequence_length) shape when there are no windows.
    feature_tensor = torch.tensor(features, dtype=torch.long).reshape(
        -1, sequence_length
    )
    target_tensor = torch.tensor(targets, dtype=torch.long)

    data = TensorDataset(feature_tensor, target_tensor)
    # return a dataloader
    return DataLoader(data, batch_size=batch_size)
### create test data
test_text = range(50)
### chunk and load data
t_loader = batch_data(test_text, sequence_length=5, batch_size=10)
data_iter = iter(t_loader)
### test dataloader: pull the first batch and show its shape and contents
# The built-in next() is used because a `.next()` method is not part of the
# Python 3 iterator protocol.
sample_x, sample_y = next(data_iter)
print(sample_x.shape)
print(sample_x)
print()
print(sample_y.shape)
print(sample_y)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment