Last active
February 15, 2018 11:56
-
-
Save justheuristic/8a31871fb78620200be9f124fc961563 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import tensorflow as tf | |
from collections import Counter | |
# Auxiliary functions for sequence masks
def infer_length(seq, eos=1, time_major=False):
    """
    Count, for every sequence, the tokens up to and including its first EOS.
    A sequence that contains no EOS gets the full time dimension as its length.
    :param seq: tf matrix [time,batch] if time_major else [batch,time]
    :param eos: integer index of end-of-sentence token
    :param time_major: if True, time is axis 0 of seq; otherwise it is axis 1
    :returns: lengths, int32 vector of [batch_size]
    """
    if time_major:
        time_axis = 0
    else:
        time_axis = 1

    eos_hits = tf.cast(tf.equal(seq, eos), 'int32')
    # exclusive cumsum: how many EOS tokens occur strictly before each position
    eos_seen_before = tf.cumsum(eos_hits, axis=time_axis, exclusive=True)
    # a position is counted while no EOS precedes it, so the first EOS itself is included
    still_running = tf.cast(tf.equal(eos_seen_before, 0), 'int32')
    return tf.reduce_sum(still_running, axis=time_axis)
def infer_mask(seq, eos=1, time_major=False, dtype=tf.bool):
    """
    Build a mask that is "on" for every position up to and including the first EOS.
    :param seq: tf matrix [time,batch] if time_major else [batch,time]
    :param eos: integer index of end-of-sentence token
    :param time_major: if True, time is axis 0 of seq; otherwise it is axis 1
    :param dtype: dtype of the returned mask
    :returns: mask, matrix of same shape as seq and of given dtype (bool by default)
    """
    time_axis = 0 if time_major else 1
    seq_lengths = infer_length(seq, eos=eos, time_major=time_major)
    # tf.sequence_mask always produces a [batch, len] layout
    batch_major_mask = tf.sequence_mask(seq_lengths, maxlen=tf.shape(seq)[time_axis], dtype=dtype)
    if not time_major:
        return batch_major_mask
    return tf.transpose(batch_major_mask)  # [batch, len] -> [len, batch]
def safe_cumprod(tensor, axis=0, exclusive=True):
    """
    Cumulative product along one axis, computed with tf.scan.
    NOTE(review): presumably used instead of tf.cumprod to keep gradients
    well-behaved when the input contains zeros -- confirm with the author.
    :param tensor: tf tensor with statically known rank
    :param axis: axis to accumulate over; negative values count from the end
    :param exclusive: if True, output[i] is the product of elements strictly before i
        (the first element along axis is 1)
    :returns: tensor of the same shape as the input
    """
    n_dim = len(tensor.get_shape())
    if axis < 0:
        axis += n_dim

    # tf.scan iterates over axis 0, so bring the requested axis to the front.
    # Swapping two axes is its own inverse, so the same permutation restores the layout.
    swap_permutation = None
    if n_dim > 1 and axis != 0:
        swap_permutation = list(range(n_dim))
        swap_permutation[0], swap_permutation[axis] = axis, 0
        tensor = tf.transpose(tensor, swap_permutation)

    if exclusive:
        # shift by one step along the scan axis and start from the multiplicative identity
        tensor = tf.concat([tf.ones_like(tensor[:1]), tensor[:-1]], axis=0)

    def prod(acc, x):
        return acc * x

    cumulative = tf.scan(prod, tensor)

    # Only undo a swap that was actually applied: tf.transpose with perm=None
    # reverses all dimensions rather than being a no-op.
    if swap_permutation is not None:
        cumulative = tf.transpose(cumulative, swap_permutation)
    return cumulative
def soft_mask(is_eos, time_major=False):
    """
    Differentiable counterpart of a sequence mask (e.g. for gumbel-softmax outputs).
    The value at each tick is the product of (1 - is_eos) over all earlier ticks,
    so the first tick always gets 1.
    :param is_eos: matrix with probabilities of EOS at each tick
    :type is_eos: tf matrix[time,batch] if time_major or [batch,time]
    :param time_major: if True, time is axis 0 of is_eos; otherwise it is axis 1
    :returns: matrix of the same shape as is_eos
    """
    if time_major:
        time_axis = 0
    else:
        time_axis = 1
    not_eos_prob = 1 - is_eos
    return safe_cumprod(not_eos_prob, axis=time_axis, exclusive=True)
# BLEU implementation
def get_ngrams(ref, n=4):
    """
    Slide a window of n tokens over every sequence.
    :param ref: token indices, [batch, time]
    :param n: n-gram length
    :returns: all n-grams, [batch, time - n + 1, n]
    """
    ref_len = tf.shape(ref)[1]
    num_windows = ref_len - n + 1
    # the k-th shifted copy holds the k-th token of every n-gram
    shifted_views = []
    for offset in range(n):
        shifted_views.append(ref[:, offset: offset + num_windows])
    return tf.stack(shifted_views, axis=-1)
def count_unique_ngrams(ngrams, pad_with=1, count_dtype='int32'):
    """
    Return all unique n-grams of every sample and their respective counts.
    Runs as a python function inside the graph (tf.py_func): CPU-only and non-differentiable.
    Samples with fewer unique n-grams are padded with n-grams made of `pad_with`
    whose count is 0.
    :param ngrams: int[batch, ngram_ix, token_ix_in_ngram]
    :param pad_with: token index used to fill padding n-grams
    :param count_dtype: dtype of the returned counts
    :returns: (unique_ngrams int[batch, max_unique, n], counts count_dtype[batch, max_unique]);
        within a sample, n-grams are ordered by decreasing count
    """
    assert ngrams.shape.ndims == 3, "ngrams must be int[batch, ngram_ix, token_ix_in_ngram]"

    def _select_unique_ngrams(ngrams):
        # here `ngrams` is the numpy array handed over by tf.py_func
        ngram_size = ngrams.shape[2]
        unique_ngrams = []
        ngram_counts = []
        for ngrams_for_sample in ngrams:
            # most_common() of an empty Counter is [], which simply yields two empty lists
            ranked = Counter(tuple(ngram) for ngram in ngrams_for_sample).most_common()
            unique_ngrams.append([ngram for ngram, _ in ranked])
            ngram_counts.append([count for _, count in ranked])

        # take the n-gram width from the array shape so an empty n-gram axis does not crash
        pad_ngram = (pad_with,) * ngram_size
        max_unique_ngrams = max([len(sample) for sample in unique_ngrams] + [0])
        for sample_ngrams, sample_counts in zip(unique_ngrams, ngram_counts):
            pad_size = max_unique_ngrams - len(sample_ngrams)
            sample_ngrams.extend([pad_ngram] * pad_size)
            sample_counts.extend([0] * pad_size)

        # explicit reshape keeps the rank correct even when there is nothing to count
        batch_size = len(unique_ngrams)
        padded_ngrams = np.array(unique_ngrams, ngrams.dtype).reshape(
            [batch_size, max_unique_ngrams, ngram_size])
        padded_counts = np.array(ngram_counts, count_dtype).reshape(
            [batch_size, max_unique_ngrams])
        return padded_ngrams, padded_counts

    unique_ngrams, counts = tf.py_func(_select_unique_ngrams, [ngrams], [ngrams.dtype, count_dtype])
    # py_func outputs have no static shape; restore what is known from the input
    unique_ngrams.set_shape([ngrams.shape[0], None, ngrams.shape[2]])
    counts.set_shape([ngrams.shape[0], None])
    return unique_ngrams, counts
def soft_count_ngrams(probs, ref_ngrams):
    """
    Selects probs for all ngrams in ref.
    :param probs: predicted probabilities, float[batch, time, num_units]
    :param ref_ngrams: ngrams to count, int[batch, num_ngrams, n]
    :returns: For each n-gram, the token probabilities at each of its possible positions in probs.
        [batch size, num_ngrams in probs, num_ngrams in ref, token index in ngram]
        float[batch_size, probs_len - n + 1, ref_len - n + 1, n]
    """
    dynamic_shape = tf.shape(ref_ngrams)
    batch_size, num_ngrams, ngram_size = (dynamic_shape[dim] for dim in range(ref_ngrams.shape.ndims))
    seq_len = tf.shape(probs)[1]
    # number of positions in probs at which an n-gram can start
    num_positions = seq_len - ngram_size + 1

    # everything below is laid out as [batch, position, ngram_index, token_index_in_ngram]
    full_shape = [batch_size, num_positions, num_ngrams, ngram_size]
    flat_shape = [batch_size, num_positions, -1]

    batch_ix = tf.tile(tf.range(0, batch_size)[:, None, None, None],
                       [1, num_positions, num_ngrams, ngram_size])

    window_start = tf.range(0, num_positions)[None, :, None, None]
    offset_in_ngram = tf.range(0, ngram_size)[None, None, None, :]
    time_ix = tf.tile(window_start + offset_in_ngram, [batch_size, 1, num_ngrams, 1])

    token_ix = tf.tile(ref_ngrams[:, None, :, :], [1, num_positions, 1, 1])

    # flat select: [batch, num_ngrams_in_pred, num_ngrams_in_ref * ngram_size, 3]
    indices_nd = tf.stack(
        [tf.reshape(index, flat_shape) for index in (batch_ix, time_ix, token_ix)],
        axis=-1)
    ngram_probs_flat = tf.gather_nd(probs, indices_nd)

    # reshape back into [batch, num_ngrams_in_pred, num_ngrams_in_ref, ngram_size]
    return tf.reshape(ngram_probs_flat, full_shape)
def compute_ngram_stats(probs, ref, eos=1, n=4, smoothing=None):
    """
    Computes soft n-gram precision and recall in a differentiable way.
    :param probs: predicted probabilities, float[batch, time, num_units]
    :param ref: reference token indices, int[batch, time]
    :param eos: token index for EOS
    :param n: n-gram length
    :param smoothing: if not None, adds this value to both numerator and denominator of precision & recall.
        NOTE: with smoothing=None the denominators can be 0, in which case the division
        yields inf/nan.
    :returns: (precision, recall), each float[batch]
    """
    pred_soft_mask = soft_mask(probs[..., eos])

    ref_ngrams = get_ngrams(ref, n=n)
    ref_ngrams, ref_ngrams_counts = count_unique_ngrams(ref_ngrams, pad_with=eos)

    # mask out every n-gram that contains EOS; this also drops the all-EOS padding n-grams
    ref_ngrams_mask = tf.equal(tf.reduce_sum(tf.cast(tf.equal(ref_ngrams, eos), 'int32'), axis=-1), 0)

    # dimensions: batch, prob position, ngram index, token within ngram
    probs_masked = probs * pred_soft_mask[:, :, None]
    ngram_token_probs_elwise = soft_count_ngrams(probs_masked, ref_ngrams)
    # probability of a whole n-gram at a position = product of its token probabilities
    ngram_probs_elwise = tf.reduce_prod(ngram_token_probs_elwise, axis=3)
    # expected number of occurrences = sum over positions
    ngram_soft_counts = tf.reduce_sum(ngram_probs_elwise, axis=1)
    ref_ngrams_counts = tf.cast(ref_ngrams_counts, 'float32') * tf.cast(ref_ngrams_mask, 'float32')

    # compute precision / recall
    true_positive = tf.reduce_sum(tf.minimum(ngram_soft_counts, ref_ngrams_counts), -1)
    all_true = tf.nn.relu(tf.reduce_sum(ref_ngrams_counts, -1))
    # there are len - n + 1 ngrams; subtract 1 for EOS. If there's less than N predicted, return 0.
    # NOTE: this approximation is different (simpler) than what's in the original article.
    all_positive = tf.nn.relu(tf.reduce_sum(pred_soft_mask, axis=1) - n)

    if smoothing is not None:
        true_positive += smoothing
        all_positive += smoothing
        all_true += smoothing

    precision = true_positive / all_positive
    recall = true_positive / all_true
    return precision, recall
def compute_bleu_with_logits(logits, ref, eos=1, min_n=1, max_n=4, weights=None, smoothing=None):
    """
    Computes differentiable BLEU with logits.
    :param logits: logits for predicted probabilities, float[batch, time, num_units]
    :param ref: reference token indices, int[batch, time]
    :param eos: token index for EOS
    :param min_n: minimum ngram length (inclusive)
    :param max_n: maximum ngram length (inclusive)
    :param weights: one weight per ngram length from min_n to max_n;
        defaults to uniform 1 / (max_n - min_n + 1)
    :param smoothing: if not None, adds this value to numerator and denominator of all precisions.
        Default smooth BLEU has smoothing=1
    :returns: differentiable BLEU, float[batch_size,]
    """
    num_orders = max_n - min_n + 1
    if weights is None:
        # uniform over the n-gram lengths actually used, so the weights sum to 1
        # even when min_n > 1 (identical to 1 / max_n for the default min_n=1)
        weights = [1. / num_orders] * num_orders if num_orders > 0 else []
    else:
        assert np.shape(weights)[0] == num_orders, \
            "There must be exactly as many weights as there are ngrams (%i)" % num_orders

    probs = tf.nn.softmax(logits, -1)

    precisions = []
    for n in range(min_n, max_n + 1):
        # only precision enters BLEU; recall is discarded
        precision_n, _ = compute_ngram_stats(probs, ref, eos, n=n, smoothing=smoothing)
        precisions.append(precision_n)

    ref_len = tf.cast(infer_length(ref, eos), 'float32')
    # expected prediction length = sum of the soft mask over time
    pred_soft_len = tf.reduce_sum(soft_mask(probs[..., eos]), axis=-1)
    # penalise predictions shorter than the reference; never reward longer ones
    brevity_penalty = tf.minimum(1.0, tf.exp(1 - ref_len / pred_soft_len))

    # weighted geometric mean of the precisions, computed in log space
    bleu = brevity_penalty * tf.exp(sum(w * tf.log(pr) for w, pr in zip(weights, precisions)))
    return bleu
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment