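# The snippet below references `tokenizer`, `max_len`, `train_path`, and
# `eval_path` without defining them. This is a minimal setup sketch, assuming
# the HuggingFace `tokenizers` BertWordPieceTokenizer (whose Encoding objects
# expose the `.ids` and `.offsets` attributes used below) and local copies of
# the SQuAD 1.1 JSON files. The vocab path, `max_len` value, and file paths
# are placeholders, not part of the original snippet.
import json

import numpy as np
from tokenizers import BertWordPieceTokenizer

max_len = 384  # assumed maximum sequence length
tokenizer = BertWordPieceTokenizer("bert-base-uncased-vocab.txt", lowercase=True)
train_path = "train-v1.1.json"  # assumed local SQuAD 1.1 train file
eval_path = "dev-v1.1.json"  # assumed local SQuAD 1.1 dev file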
class SquadExample:
    def __init__(self, question, context, start_char_idx, answer_text, all_answers):
        # Normalize whitespace in all text fields
        self.question = " ".join(str(question).split())
        self.context = " ".join(str(context).split())
        self.answer_text = " ".join(str(answer_text).split())
        self.start_char_idx = start_char_idx
        self.all_answers = all_answers
        self.skip = False

    def get_answer_position(self, context, tokenized_context, answer, start_char_idx):
        # Find end character index of answer in context
        end_char_idx = start_char_idx + len(answer)
        if end_char_idx >= len(context):
            self.skip = True
            return None, None

        # Mark the character indexes in context that are in answer
        is_char_in_ans = [0] * len(context)
        for idx in range(start_char_idx, end_char_idx):
            is_char_in_ans[idx] = 1

        # Find tokens that were created from answer characters
        ans_token_idx = []
        for idx, (start, end) in enumerate(tokenized_context.offsets):
            if sum(is_char_in_ans[start:end]) > 0:
                ans_token_idx.append(idx)
        if len(ans_token_idx) == 0:
            self.skip = True
            return None, None

        # Find start and end token index for tokens from answer
        start_token_idx = ans_token_idx[0]
        end_token_idx = ans_token_idx[-1]
        return start_token_idx, end_token_idx

    def get_indexed_inputs(self, tokenized_context, tokenized_question):
        # Create inputs: context tokens followed by question tokens
        # (the question's leading [CLS] token is dropped)
        input_ids = tokenized_context.ids + tokenized_question.ids[1:]
        token_type_ids = [0] * len(tokenized_context.ids) + [1] * len(
            tokenized_question.ids[1:]
        )
        attention_mask = [1] * len(input_ids)

        padding_length = max_len - len(input_ids)
        if padding_length > 0:  # pad
            input_ids = input_ids + ([0] * padding_length)
            attention_mask = attention_mask + ([0] * padding_length)
            token_type_ids = token_type_ids + ([0] * padding_length)
        elif padding_length < 0:  # skip examples longer than max_len
            self.skip = True
            return None, None, None
        return input_ids, attention_mask, token_type_ids

    def preprocess(self):
        context = self.context
        question = self.question
        answer = self.answer_text
        start_char_idx = self.start_char_idx

        tokenized_context = tokenizer.encode(context)
        tokenized_question = tokenizer.encode(question)

        start_token_idx, end_token_idx = self.get_answer_position(
            context, tokenized_context, answer, start_char_idx
        )
        input_ids, attention_mask, token_type_ids = self.get_indexed_inputs(
            tokenized_context, tokenized_question
        )

        self.input_ids = input_ids
        self.token_type_ids = token_type_ids
        self.attention_mask = attention_mask
        self.start_token_idx = start_token_idx
        self.end_token_idx = end_token_idx
        self.context_token_to_char = tokenized_context.offsets
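# Illustration (not in the original gist): assuming the placeholder setup
# above, preprocess() maps a character-level answer span onto token indices
# through the tokenizer's offset mapping.
demo_example = SquadExample(
    question="Who wrote Hamlet?",
    context="Hamlet is a tragedy written by William Shakespeare.",
    start_char_idx=31,  # character offset of "William" in the context
    answer_text="William Shakespeare",
    all_answers=["William Shakespeare"],
)
demo_example.preprocess()
# demo_example.start_token_idx / end_token_idx now delimit the answer tokens,
# and demo_example.context_token_to_char maps predictions back to characters.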
with open(train_path) as f:
    raw_train_data = json.load(f)

with open(eval_path) as f:
    raw_eval_data = json.load(f)
def create_squad_examples(raw_data):
    squad_examples = []
    for item in raw_data["data"]:
        for para in item["paragraphs"]:
            context = para["context"]
            for qa in para["qas"]:
                question = qa["question"]
                answer_text = qa["answers"][0]["text"]
                all_answers = [ans["text"] for ans in qa["answers"]]
                start_char_idx = qa["answers"][0]["answer_start"]
                squad_eg = SquadExample(
                    question, context, start_char_idx, answer_text, all_answers
                )
                squad_eg.preprocess()
                squad_examples.append(squad_eg)
    return squad_examples
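# Not part of the original gist: a small helper to see how many examples the
# skip flag drops (answer out of range, no answer tokens, or over max_len).
def count_skipped(squad_examples):
    return sum(1 for eg in squad_examples if eg.skip)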
def create_inputs_targets(squad_examples):
    dataset_dict = {
        "input_ids": [],
        "token_type_ids": [],
        "attention_mask": [],
        "start_token_idx": [],
        "end_token_idx": [],
    }
    for item in squad_examples:
        if not item.skip:
            for key in dataset_dict:
                dataset_dict[key].append(getattr(item, key))
    for key in dataset_dict:
        dataset_dict[key] = np.array(dataset_dict[key])

    x = [
        dataset_dict["input_ids"],
        dataset_dict["token_type_ids"],
        dataset_dict["attention_mask"],
    ]
    y = [dataset_dict["start_token_idx"], dataset_dict["end_token_idx"]]
    return x, y
train_squad_examples = create_squad_examples(raw_train_data)
x_train, y_train = create_inputs_targets(train_squad_examples)
print(f"{len(train_squad_examples)} training points created.")
eval_squad_examples = create_squad_examples(raw_eval_data)
x_eval, y_eval = create_inputs_targets(eval_squad_examples)
print(f"{len(eval_squad_examples)} evaluation points created.")