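# --- Assumed setup -----------------------------------------------------------
# The preprocessing below expects `tokenizer`, `max_len`, `train_path`, and
# `eval_path` to be defined elsewhere. A minimal sketch of one possible setup,
# assuming the HuggingFace `tokenizers` BertWordPieceTokenizer, a local BERT
# vocab file, and the SQuAD 1.1 train/dev JSON files:
import json

import numpy as np
from tensorflow import keras
from tokenizers import BertWordPieceTokenizer

max_len = 384  # maximum input sequence length (assumed value)

# Fast WordPiece tokenizer; encode() returns token ids plus character offsets,
# which get_answer_position() relies on. The vocab path is an assumption.
tokenizer = BertWordPieceTokenizer("bert_base_uncased/vocab.txt", lowercase=True)

# SQuAD 1.1 train/dev files (assumed source of train_path / eval_path)
train_path = keras.utils.get_file(
    "train.json", "https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v1.1.json"
)
eval_path = keras.utils.get_file(
    "eval.json", "https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v1.1.json"
)
# -----------------------------------------------------------------------------
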
class SquadExample:
    def __init__(self, question, context, start_char_idx, answer_text, all_answers):
        # Normalize whitespace in the raw strings
        self.question = " ".join(str(question).split())
        self.context = " ".join(str(context).split())
        self.answer_text = " ".join(str(answer_text).split())
        self.start_char_idx = start_char_idx
        self.all_answers = all_answers
        self.skip = False

    def get_answer_position(self, context, tokenized_context, answer, start_char_idx):
        # Find end character index of answer in context
        end_char_idx = start_char_idx + len(answer)
        if end_char_idx >= len(context):
            self.skip = True
            return None, None

        # Mark the character indexes in context that are in answer
        is_char_in_ans = [0] * len(context)
        for idx in range(start_char_idx, end_char_idx):
            is_char_in_ans[idx] = 1

        # Find tokens that were created from answer characters
        ans_token_idx = []
        for idx, (start, end) in enumerate(tokenized_context.offsets):
            if sum(is_char_in_ans[start:end]) > 0:
                ans_token_idx.append(idx)
        if len(ans_token_idx) == 0:
            self.skip = True
            return None, None

        # Find start and end token index for tokens from answer
        start_token_idx = ans_token_idx[0]
        end_token_idx = ans_token_idx[-1]
        return start_token_idx, end_token_idx

    def get_indexed_inputs(self, tokenized_context, tokenized_question):
        # Create inputs: context tokens followed by question tokens
        # (the question's leading [CLS] token is dropped)
        input_ids = tokenized_context.ids + tokenized_question.ids[1:]
        token_type_ids = [0] * len(tokenized_context.ids) + [1] * len(
            tokenized_question.ids[1:]
        )
        attention_mask = [1] * len(input_ids)

        # Pad to max_len, or skip examples that are too long
        padding_length = max_len - len(input_ids)
        if padding_length > 0:  # pad
            input_ids = input_ids + ([0] * padding_length)
            attention_mask = attention_mask + ([0] * padding_length)
            token_type_ids = token_type_ids + ([0] * padding_length)
        elif padding_length < 0:  # skip
            self.skip = True
            return None, None, None
        return input_ids, attention_mask, token_type_ids

    def preprocess(self):
        context = self.context
        question = self.question
        answer = self.answer_text
        start_char_idx = self.start_char_idx

        tokenized_context = tokenizer.encode(context)
        tokenized_question = tokenizer.encode(question)

        start_token_idx, end_token_idx = self.get_answer_position(
            context, tokenized_context, answer, start_char_idx
        )
        input_ids, attention_mask, token_type_ids = self.get_indexed_inputs(
            tokenized_context, tokenized_question
        )

        self.input_ids = input_ids
        self.token_type_ids = token_type_ids
        self.attention_mask = attention_mask
        self.start_token_idx = start_token_idx
        self.end_token_idx = end_token_idx
        self.context_token_to_char = tokenized_context.offsets

with open(train_path) as f:
    raw_train_data = json.load(f)

with open(eval_path) as f:
    raw_eval_data = json.load(f)

def create_squad_examples(raw_data):
    # Build and preprocess one SquadExample per question/answer pair
    squad_examples = []
    for item in raw_data["data"]:
        for para in item["paragraphs"]:
            context = para["context"]
            for qa in para["qas"]:
                question = qa["question"]
                answer_text = qa["answers"][0]["text"]
                all_answers = [_["text"] for _ in qa["answers"]]
                start_char_idx = qa["answers"][0]["answer_start"]
                squad_eg = SquadExample(
                    question, context, start_char_idx, answer_text, all_answers
                )
                squad_eg.preprocess()
                squad_examples.append(squad_eg)
    return squad_examples

def create_inputs_targets(squad_examples):
    # Collect model inputs and span targets from the non-skipped examples
    dataset_dict = {
        "input_ids": [],
        "token_type_ids": [],
        "attention_mask": [],
        "start_token_idx": [],
        "end_token_idx": [],
    }
    for item in squad_examples:
        if not item.skip:
            for key in dataset_dict:
                dataset_dict[key].append(getattr(item, key))
    for key in dataset_dict:
        dataset_dict[key] = np.array(dataset_dict[key])

    x = [
        dataset_dict["input_ids"],
        dataset_dict["token_type_ids"],
        dataset_dict["attention_mask"],
    ]
    y = [dataset_dict["start_token_idx"], dataset_dict["end_token_idx"]]
    return x, y
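
# Note: each array in x has shape (num_kept_examples, max_len) and the two
# arrays in y have shape (num_kept_examples,); examples flagged skip=True are
# excluded from both.
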
train_squad_examples = create_squad_examples(raw_train_data)
x_train, y_train = create_inputs_targets(train_squad_examples)
print(f"{len(train_squad_examples)} training points created.")

eval_squad_examples = create_squad_examples(raw_eval_data)
x_eval, y_eval = create_inputs_targets(eval_squad_examples)
print(f"{len(eval_squad_examples)} evaluation points created.")