Skip to content

Instantly share code, notes, and snippets.

@justheuristic
Last active February 15, 2018 11:56
Show Gist options
  • Save justheuristic/8a31871fb78620200be9f124fc961563 to your computer and use it in GitHub Desktop.
Save justheuristic/8a31871fb78620200be9f124fc961563 to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
import numpy as np
import tensorflow as tf
from collections import Counter
# Auxiliary functions for sequence masks
def infer_length(seq, eos=1, time_major=False):
    """
    Infer the length of every sequence from its token indices.
    A position is counted as long as no EOS occurred strictly before it,
    so the first EOS token itself is included in the length.
    :param seq: tf matrix [time,batch] if time_major else [batch,time]
    :param eos: integer index of end-of-sentence token
    :param time_major: if True, time is axis 0 of seq, otherwise axis 1
    :returns: lengths, int32 vector of [batch_size]
    """
    time_axis = 0 if time_major else 1
    eos_hits = tf.cast(tf.equal(seq, eos), 'int32')
    # exclusive cumsum = how many EOS tokens were seen strictly before each position
    eos_seen_before = tf.cumsum(eos_hits, axis=time_axis, exclusive=True)
    still_inside = tf.cast(tf.equal(eos_seen_before, 0), 'int32')
    return tf.reduce_sum(still_inside, axis=time_axis)
def infer_mask(seq, eos=1, time_major=False, dtype=tf.bool):
    """
    Build a mask that is "on" for every position up to and including the first EOS.
    :param seq: tf matrix [time,batch] if time_major else [batch,time]
    :param eos: integer index of end-of-sentence token
    :param time_major: if True, time is axis 0 of seq, otherwise axis 1
    :param dtype: dtype of the returned mask
    :returns: mask, matrix of same shape as seq and of given dtype (bool by default)
    """
    time_axis = 0 if time_major else 1
    seq_lengths = infer_length(seq, eos=eos, time_major=time_major)
    # tf.sequence_mask always produces a batch-major [batch, maxlen] mask
    batch_major_mask = tf.sequence_mask(seq_lengths, maxlen=tf.shape(seq)[time_axis], dtype=dtype)
    if not time_major:
        return batch_major_mask
    return tf.transpose(batch_major_mask)  # [batch, len] -> [len, batch]
def safe_cumprod(tensor, axis=0, exclusive=True):
    """
    Cumulative product along one axis, computed with tf.scan over the leading axis.
    :param tensor: tf tensor with statically known rank
    :param axis: axis to accumulate over; negative values count from the last axis
    :param exclusive: if True, the product at position t only covers positions < t
        (the first slice is all ones and the last input slice is dropped)
    :returns: tensor of the same shape as the input
    """
    n_dim = len(tensor.get_shape())
    if axis < 0:
        # normalise before deciding whether a transpose is needed, so that
        # e.g. axis=-n_dim is treated as the leading axis
        axis = n_dim + axis

    transpose_permutation = None
    if n_dim > 1 and axis != 0:
        # tf.scan iterates over axis 0, so swap the target axis with the leading one.
        # A swap of two axes is its own inverse, hence the same permutation restores the layout.
        transpose_permutation = np.arange(n_dim)
        transpose_permutation[axis], transpose_permutation[0] = 0, axis
        tensor = tf.transpose(tensor, transpose_permutation)

    if exclusive:
        # shift by one step along the scan axis, starting from the multiplicative identity
        tensor = tf.concat([tf.ones_like(tensor[:1]), tensor[:-1]], axis=0)

    def prod(acc, x):
        return acc * x

    cumulative = tf.scan(prod, tensor)

    # Only undo a transpose that was actually applied: tf.transpose with perm=None
    # reverses all dimensions, which would scramble the axis=0 case for rank > 1.
    if transpose_permutation is not None:
        cumulative = tf.transpose(cumulative, transpose_permutation)
    return cumulative
def soft_mask(is_eos, time_major=False):
    """
    A differentiable version of sequence mask for gumbel:
    the exclusive cumulative product of (1 - is_eos) along the time axis.
    :param is_eos: matrix with probabilities of EOS at each tick
    :type is_eos: tf matrix[time,batch] if time_major or [batch,time]
    :param time_major: if True, time is axis 0 of is_eos, otherwise axis 1
    """
    not_eos = 1 - is_eos
    if time_major:
        return safe_cumprod(not_eos, axis=0, exclusive=True)
    return safe_cumprod(not_eos, axis=1, exclusive=True)
# BLEU implementation
def get_ngrams(ref, n=4):
    """
    Return all n-grams of consecutive tokens, [batch, time - n + 1, n].
    :param ref: token indices, [batch, time]
    :param n: n-gram size
    """
    ref_len = tf.shape(ref)[1]
    num_windows = ref_len - n + 1
    # the i-th slice holds the i-th token of every n-gram
    shifted_views = []
    for offset in range(n):
        shifted_views.append(ref[:, offset: offset + num_windows])
    return tf.stack(shifted_views, axis=-1)
def count_unique_ngrams(ngrams, pad_with=1, count_dtype='int32'):
    """
    returns all unique ngrams and their respective counts. CPU-only and non-differentiable
    The counting itself runs in python/numpy through tf.py_func.
    :param ngrams: int[batch, ngram_ix, token_ix_in_ngram]
    :param pad_with: token index used to fill padding n-grams for samples that have
        fewer unique n-grams than the largest sample in the batch
    :param count_dtype: dtype of the returned counts
    :returns: (unique_ngrams, counts)
        unique_ngrams: [batch, max_unique_ngrams, n], ordered by decreasing count per sample
        counts: [batch, max_unique_ngrams], 0 for padding n-grams
    """
    assert ngrams.shape.ndims == 3, "ngrams must be int[batch, ngram_ix, token_ix_in_ngram]"

    def _select_unique_ngrams(ngrams):
        # here ngrams is the numpy array handed over by tf.py_func
        ngram_size = ngrams.shape[-1]
        unique_ngrams = []
        ngram_counts = []
        for ngrams_for_sample in ngrams:
            ngrams_ctr = Counter(tuple(ngram) for ngram in ngrams_for_sample)
            # most_common() may be empty (sequence shorter than n), so do not unpack via zip(*...)
            ranked = ngrams_ctr.most_common()
            unique_ngrams.append([ngram for ngram, _ in ranked])
            ngram_counts.append([count for _, count in ranked])

        pad_ngram = (pad_with,) * ngram_size
        # `or [0]` keeps max() defined for an empty batch
        max_unique_ngrams = max([len(sample) for sample in unique_ngrams] or [0])
        for sample_ngrams, sample_counts in zip(unique_ngrams, ngram_counts):
            pad_size = max_unique_ngrams - len(sample_ngrams)
            sample_ngrams.extend([pad_ngram] * pad_size)
            sample_counts.extend([0] * pad_size)

        # explicit reshape keeps the trailing dimensions even when there are no n-grams at all
        batch_size = len(unique_ngrams)
        unique_array = np.array(unique_ngrams, ngrams.dtype).reshape(
            [batch_size, max_unique_ngrams, ngram_size])
        counts_array = np.array(ngram_counts, count_dtype).reshape(
            [batch_size, max_unique_ngrams])
        return unique_array, counts_array

    unique_ngrams, counts = tf.py_func(_select_unique_ngrams, [ngrams], [ngrams.dtype, count_dtype])
    # py_func outputs carry no static shape; restore what is known from the input
    unique_ngrams.set_shape([ngrams.shape[0], None, ngrams.shape[2]])
    counts.set_shape([ngrams.shape[0], None])
    return unique_ngrams, counts
def soft_count_ngrams(probs, ref_ngrams):
    """
    Selects probs for all ngrams in ref.
    :returns: For each n-gram enumerates all its possible positions in probs.
        [batch size, num_ngrams in probs, num_ngrams in ref, token index in ngram]
        float[batch_size, probs_len - n + 1, ref_len - n + 1, n]
    :param probs: predicted probabilities, float[batch, time, num_units]
    :param ref_ngrams: ngrams to count, int[batch, num_ngrams, n]
    """
    # everything below is laid out as [batch, time, ngram_index, token_index_in_ngram]
    dynamic_shape = tf.shape(ref_ngrams)
    batch_size, num_ngrams, ngram_size = [dynamic_shape[i] for i in range(ref_ngrams.shape.ndims)]
    seq_len = tf.shape(probs)[1]
    num_positions = seq_len - ngram_size + 1  # number of n-gram start positions in probs

    # index along the batch axis of probs
    batch_ix = tf.tile(tf.range(0, batch_size)[:, None, None, None],
                       [1, num_positions, num_ngrams, ngram_size])

    # index along the time axis of probs: start position + offset of the token within the n-gram
    start_ix = tf.range(0, num_positions)[None, :, None, None]
    offset_ix = tf.range(0, ngram_size)[None, None, None, :]
    time_ix = tf.tile(start_ix + offset_ix, [batch_size, 1, num_ngrams, 1])

    # index along the units axis of probs: the reference token itself, repeated for every position
    selector = tf.tile(ref_ngrams[:, None, :, :], [1, num_positions, 1, 1])

    # flat select [batch, num_ngrams_in_pred, num_ngrams_in_ref * ngram_size]
    flat_shape = [batch_size, num_positions, -1]
    indices_nd = tf.stack([tf.reshape(index, flat_shape) for index in (batch_ix, time_ix, selector)],
                          axis=-1)
    ngram_probs_flat = tf.gather_nd(probs, indices_nd)

    # reshape back into [batch, num_ngrams_in_pred, num_ngrams_in_ref, ngram_size]
    return tf.reshape(ngram_probs_flat, [batch_size, num_positions, num_ngrams, ngram_size])
def compute_ngram_stats(probs, ref, eos=1, n=4, smoothing=None):
    """
    computes soft n-gram precision and recall in a differentiable way
    :param probs: predicted probabilities, float[batch, time, num_units]
    :param ref: reference token indices, int[batch, time]
    :param eos: token index for EOS
    :param n: n-gram size
    :param smoothing: if not None, adds this value to both numerator and denominator of precision & recall
    :returns: (precision, recall), one value per batch element
    """
    pred_soft_mask = soft_mask(probs[..., eos])

    ref_ngrams = get_ngrams(ref, n=n)
    ref_ngrams, ref_ngrams_counts = count_unique_ngrams(ref_ngrams, pad_with=eos)

    # mask-out any n-gram that contains EOS; this also drops the padding n-grams,
    # which are filled with eos
    ref_ngrams_mask = tf.equal(tf.reduce_sum(tf.cast(tf.equal(ref_ngrams, eos), 'int32'), axis=-1), 0)

    # dimensions: batch, prob position, ngram index, token within ngram
    probs_masked = probs * pred_soft_mask[:, :, None]
    ngram_token_probs_elwise = soft_count_ngrams(probs_masked, ref_ngrams)
    # probability of a whole n-gram at a position = product over its tokens
    ngram_probs_elwise = tf.reduce_prod(ngram_token_probs_elwise, axis=3)
    # expected number of occurrences = sum over positions
    ngram_soft_counts = tf.reduce_sum(ngram_probs_elwise, axis=1)
    ref_ngrams_counts = tf.cast(ref_ngrams_counts, 'float32') * tf.cast(ref_ngrams_mask, 'float32')

    # compute precision / recall
    # clip soft counts by reference counts before summing over n-grams
    true_positive = tf.reduce_sum(tf.minimum(ngram_soft_counts, ref_ngrams_counts), -1)
    all_true = tf.nn.relu(tf.reduce_sum(ref_ngrams_counts, -1))

    # there are len - n + 1 ngrams; subtract 1 for EOS. The count is clipped at 0 if fewer than N are predicted.
    # NOTE: this approximation is different (simpler) than what's in the original article.
    all_positive = tf.nn.relu(tf.reduce_sum(pred_soft_mask, axis=1) - n)

    if smoothing is not None:
        true_positive += smoothing
        all_positive += smoothing
        all_true += smoothing

    # NOTE(review): without smoothing, all_positive / all_true can be exactly 0,
    # in which case these divisions produce non-finite values.
    precision = true_positive / all_positive
    recall = true_positive / all_true
    return precision, recall
def compute_bleu_with_logits(logits, ref, eos=1, min_n=1, max_n=4, weights=None, smoothing=None):
    """
    Computes differentiable BLEU with logits.
    :param logits: logits for predicted probabilities, float[batch, time, num_units]
    :param ref: reference token indices, int[batch, time]
    :param eos: token index for EOS
    :param min_n: minimum ngram length (inclusive)
    :param max_n: maximum ngram length (inclusive)
    :param weights: weights for ngrams, one per n in [min_n, max_n]; defaults to 1 / num_ngrams
    :param smoothing: if not None, adds this value to numerator and denominator of all precisions.
        Default smooth BLEU has smoothing=1
    :returns: differentiable BLEU, float[batch_size,]
    """
    num_ngram_orders = max_n - min_n + 1
    if weights is None:
        # uniform over the n-gram orders actually used, so the weights sum to 1 even when min_n > 1
        weights = [1. / num_ngram_orders] * num_ngram_orders
    else:
        assert np.shape(weights)[0] == num_ngram_orders, \
            "There must be exactly as many weights as there are ngrams (%i)" % num_ngram_orders

    probs = tf.nn.softmax(logits, -1)

    precisions = []
    for n in range(min_n, max_n + 1):
        # only precision enters BLEU; recall is discarded
        precision_n, _ = compute_ngram_stats(probs, ref, eos, n=n, smoothing=smoothing)
        precisions.append(precision_n)

    ref_len = tf.cast(infer_length(ref, eos), 'float32')
    # soft (expected) prediction length: sum of the soft mask over time
    pred_soft_len = tf.reduce_sum(soft_mask(probs[..., eos]), axis=-1)

    brevity_penalty = tf.minimum(1.0, tf.exp(1 - ref_len / pred_soft_len))
    # weighted geometric mean of the n-gram precisions, computed in log space
    bleu = brevity_penalty * tf.exp(sum(w * tf.log(pr) for w, pr in zip(weights, precisions)))
    return bleu
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment