Created September 16, 2020 14:57
-
-
Save Deepayan137/5e3febbc8bfc7b926dac472864ce7242 to your computer and use it in GitHub Desktop.
A toy dataset for a seq2seq implementation.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
from torch.utils.data import Dataset | |
import numpy as np | |
import pdb | |
class DummyDataset(Dataset):
    """Toy seq2seq dataset of random source id sequences and derived targets.

    Each source sequence is a 1-D NumPy array of integer ids drawn uniformly
    (with replacement) from ``range(vocab_size)``; its length is drawn
    uniformly from ``1..max_len`` inclusive. Source sequences are sampled
    once, at construction time, from NumPy's global random state.

    Targets are produced element-wise on *every* ``__getitem__`` call: each
    source id ``x`` maps to ``func2(x) = 2*x + 1`` with probability ``prob``
    and to ``func1(x) = x // 2`` otherwise. The target returned for a given
    index can therefore differ between calls.

    Note: target ids can be as large as ``2*vocab_size - 1``; they are not
    confined to the source vocabulary.

    Args:
        prob: Probability of applying ``func2`` (rather than ``func1``) to
            each source id.
        vocab_size: Number of distinct source ids; falsy values (None, 0)
            fall back to 10.
        nSamples: Number of sequences in the dataset; falsy values fall
            back to 20.
        max_len: Maximum source sequence length (inclusive); falsy values
            fall back to 5.
    """

    def __init__(self, prob, vocab_size=None,
                 nSamples=None, max_len=None):
        self.prob = prob
        # Falsy (not just None) arguments select the defaults, as before.
        self.vocab_size = vocab_size or 10
        self.nSamples = nSamples or 20
        self.max_len = max_len or 5
        self.src_data = self._prepare_src_data()

    def __len__(self):
        """Return the number of samples."""
        return self.nSamples

    def __getitem__(self, index):
        """Return ``{'src': src_ids, 'tgt': tgt_ids}`` for sample ``index``.

        ``src`` is the stored NumPy array; ``tgt`` is a freshly sampled list
        of the same length.

        Raises:
            IndexError: if ``index`` is outside ``[-nSamples, nSamples)``.
        """
        # Raise IndexError (not assert): the check then survives `python -O`,
        # and iteration through the sequence protocol stops cleanly.
        if not -self.nSamples <= index < self.nSamples:
            raise IndexError(
                f"index {index} out of range for {self.nSamples} samples")
        src = self.src_data[index]
        tgt = [self.get_target_id(x) for x in src]
        return {'src': src, 'tgt': tgt}

    def func1(self, x):
        """Halve ``x`` with floor division."""
        return x // 2

    def func2(self, x):
        """Map ``x`` to ``2*x + 1``."""
        return 2 * x + 1

    def get_target_id(self, x):
        """Map one source id to a target id, choosing ``func2`` w.p. ``prob``."""
        # random() is in [0, 1), so P(random() > prob) == 1 - prob.
        if np.random.random() > self.prob:
            return self.func1(x)
        return self.func2(x)

    def sample_src_ids(self):
        """Sample one source sequence of ids from ``range(vocab_size)``."""
        src_len = self.get_src_len()
        return np.random.choice(self.vocab_size, src_len)

    def get_src_len(self):
        """Sample a sequence length uniformly from ``1..max_len`` inclusive."""
        # randint's upper bound is exclusive; +1 makes max_len reachable
        # and keeps max_len == 1 valid (randint(1, 1) would raise).
        return np.random.randint(1, self.max_len + 1)

    def _prepare_src_data(self):
        """Pre-sample all ``nSamples`` source sequences."""
        return [self.sample_src_ids() for _ in range(self.nSamples)]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.