Last active
August 4, 2022 05:23
-
-
Save jeremytanjianle/9dcb4ca017eadfedec2f75058a3da1cf to your computer and use it in GitHub Desktop.
Basic AllenNLP Classification DatasetReader
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
@DatasetReader.register('classification-tsv')
class ClassificationTsvReader(DatasetReader):
    """Read a two-column, tab-separated text classification file.

    Each non-blank line of the input is expected to hold exactly two
    tab-separated columns: the text to classify, then its label. Every such
    line is turned into an ``Instance`` with a ``'text'`` field and a
    ``'label'`` field.
    """

    def __init__(self,
                 lazy: bool = False,
                 tokenizer: Tokenizer = None,
                 token_indexers: Dict[str, TokenIndexer] = None):
        """Configure how raw text is tokenized and indexed.

        Args:
            lazy: Forwarded unchanged to ``DatasetReader.__init__``.
            tokenizer: Tokenizer applied to the text column. A
                ``WhitespaceTokenizer`` is used when omitted (or falsy).
            token_indexers: Indexers attached to the resulting ``TextField``.
                ``{'tokens': SingleIdTokenIndexer()}`` is used when omitted
                (or empty).
        """
        super().__init__(lazy)
        self.tokenizer = tokenizer or WhitespaceTokenizer()
        self.token_indexers = token_indexers or {'tokens': SingleIdTokenIndexer()}

    def text_to_instance(self, text: str, label: str = None) -> Instance:
        """Build one ``Instance`` from a text and an optional label.

        Args:
            text: Raw text; it is tokenized with ``self.tokenizer``.
            label: Label for the text. When ``None`` or empty, the returned
                instance carries no ``'label'`` field.

        Returns:
            An ``Instance`` with a ``'text'`` ``TextField`` and, when a label
            was given, a ``'label'`` ``LabelField``.
        """
        tokens = self.tokenizer.tokenize(text)
        text_field = TextField(tokens, self.token_indexers)
        fields = {'text': text_field}
        # Truthiness check is deliberate: an empty-string label is treated
        # the same as a missing one.
        if label:
            fields['label'] = LabelField(label)
        return Instance(fields)

    def _read(self, file_path: str) -> Iterable[Instance]:
        """Yield one ``Instance`` per non-blank line of a TSV file.

        Args:
            file_path: Path of the tab-separated file to read. It is decoded
                as UTF-8 regardless of the platform's default encoding.

        Yields:
            The result of ``text_to_instance(text, label)`` for each line.

        Raises:
            ValueError: If a non-blank line does not contain exactly two
                tab-separated columns. The message names the file and the
                1-based line number.
        """
        # An explicit encoding keeps the reader's behaviour independent of
        # the locale of the machine it runs on.
        with open(file_path, 'r', encoding='utf-8') as data_file:
            for line_number, line in enumerate(data_file, start=1):
                stripped = line.strip()
                if not stripped:
                    # Blank lines (typically a trailing newline at the end
                    # of the file) carry no example; skip them.
                    continue
                columns = stripped.split('\t')
                if len(columns) != 2:
                    raise ValueError(
                        '{}:{}: expected 2 tab-separated columns '
                        '(text, label) but found {}'.format(
                            file_path, line_number, len(columns)))
                text, sentiment = columns
                yield self.text_to_instance(text, sentiment)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment