Skip to content

Instantly share code, notes, and snippets.

@justheuristic
Last active February 15, 2018 11:56
Show Gist options
  • Save justheuristic/8a31871fb78620200be9f124fc961563 to your computer and use it in GitHub Desktop.
Save justheuristic/8a31871fb78620200be9f124fc961563 to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"env: CUDA_VISIBLE_DEVICES=\n",
"data matrix (bach-major, 1 = EOS) \n",
" [[10 11 12 13 14 15 16 1 1 1]\n",
" [20 21 22 23 1 1 1 1 1 1]\n",
" [30 31 32 33 34 35 1 1 1 1]\n",
" [40 41 42 43 44 45 46 47 48 1]\n",
" [50 51 52 53 54 55 56 57 1 1]]\n"
]
}
],
"source": [
"%env CUDA_VISIBLE_DEVICES=\n",
"import matplotlib.pyplot as plt\n",
"%matplotlib inline\n",
"import numpy as np\n",
"import tensorflow as tf\n",
"\n",
"tf.reset_default_graph()\n",
"if 'sess' in globals(): sess.close()\n",
"sess = tf.InteractiveSession()\n",
"\n",
"data = np.arange(10,60, dtype='int32').reshape([5,10])\n",
"# WARNING! in current implementation there MUST be at least 1 EOS at the end\n",
"for i, len_i in enumerate([7, 4, 6, 9, 8]):\n",
" data[i, len_i:] = 1\n",
" \n",
" \n",
"print(\"data matrix (bach-major, 1 = EOS) \\n\", data)"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"from tf_diffbleu import compute_bleu_with_logits, compute_ngram_stats\n",
"\n",
"ref = tf.placeholder_with_default(data, [None, None], name='reference_answers')\n",
"logits = tf.get_variable('logits', [5, 25, 60]) \n",
"\n",
"smooth_bleu = compute_bleu_with_logits(logits, ref, eos=1, smoothing=1)\n",
"hard_bleu = compute_bleu_with_logits(logits, ref, eos=1, smoothing=0)\n",
"precisions, recalls = compute_ngram_stats(tf.nn.softmax(logits), ref, eos=1, n=4)\n",
"\n",
"loss = - tf.reduce_mean(smooth_bleu)\n",
"step = tf.train.AdamOptimizer(0.1).minimize(loss)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"sess.run(tf.global_variables_initializer())\n",
"sess.run(tf.assign(logits, tf.zeros_like(logits)));"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"-0.071 -0.076 -0.081 -0.087 -0.094 -0.101 -0.110 -0.120 -0.131 -0.143 -0.158 -0.175 -0.196 -0.217 -0.234 -0.246 -0.259 -0.264 -0.265 -0.263 -0.263 -0.267 -0.271 -0.276 -0.281 -0.282 -0.285 -0.291 -0.299 -0.303 -0.302 -0.303 -0.304 -0.305 -0.307 -0.309 -0.311 -0.312 -0.314 -0.316 -0.317 -0.320 -0.322 -0.324 -0.327 -0.327 -0.329 -0.331 -0.335 -0.340 -0.345 -0.349 -0.355 -0.363 -0.374 -0.387 -0.404 -0.423 -0.448 -0.475 -0.506 -0.540 -0.575 -0.611 -0.647 -0.680 -0.713 -0.744 -0.772 -0.797 -0.819 -0.839 -0.858 -0.874 -0.889 -0.902 -0.914 -0.924 -0.934 -0.942 -0.949 -0.956 -0.961 -0.966 -0.971 -0.975 -0.978 -0.981 -0.983 -0.985 -0.987 -0.989 -0.990 -0.991 -0.992 -0.993 -0.993 -0.994 -0.994 -0.995 -0.995 -0.995 -0.995 -0.996 -0.996 -0.996 -0.996 -0.996 -0.996 -0.996 -0.997 -0.997 -0.997 -0.997 -0.997 -0.997 -0.997 -0.997 -0.997 -0.997 -0.997 -0.997 -0.997 -0.997 -0.997 -0.997 -0.997 -0.997 -0.997 -0.997 -0.997 -0.998 -0.998 -0.998 -0.998 -0.998 -0.998 -0.998 -0.998 -0.998 -0.998 -0.998 -0.998 -0.998 -0.998 -0.998 -0.998 -0.998 -0.998 -0.998 -0.998 -0.998 -0.998 -0.998 -0.998 -0.998 -0.998 -0.998 -0.998 -0.998 -0.998 -0.998 -0.998 -0.998 -0.998 -0.998 -0.998 -0.998 -0.998 -0.998 -0.998 -0.998 -0.998 -0.998 -0.998 -0.998 -0.998 -0.998 -0.998 -0.998 -0.998 -0.998 -0.998 -0.998 -0.998 -0.998 -0.998 -0.998 -0.998 -0.998 -0.998 -0.998 -0.998 -0.998 -0.998 -0.998 -0.998 -0.998 -0.998 -0.998 "
]
}
],
"source": [
"for _ in range(200):\n",
" print('%.3f'%sess.run([loss, step])[0].mean(), end=' '*7)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<matplotlib.colorbar.Colorbar at 0x7f67104eee80>"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<matplotlib.figure.Figure at 0x7f671053c438>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"plt.figure(figsize=[12,12])\n",
"plt.xlabel(\"token index\")\n",
"plt.ylabel(\"batch x time\")\n",
"plt.title(\"predicted probs (token 1 means EOS)\")\n",
"plt.imshow(sess.run(tf.nn.softmax(logits)).reshape([-1, logits.shape[-1]]))\n",
"plt.colorbar()"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Reference:\n",
"[10, 11, 12, 13, 14, 15, 16, 1]\n",
"Predicted:\n",
"[10, 11, 12, 13, 14, 15, 16, 1]\n",
"\n",
"Reference:\n",
"[20, 21, 22, 23, 1]\n",
"Predicted:\n",
"[20, 21, 22, 23, 1]\n",
"\n",
"Reference:\n",
"[30, 31, 32, 33, 34, 35, 1]\n",
"Predicted:\n",
"[30, 31, 32, 33, 34, 35, 1]\n",
"\n",
"Reference:\n",
"[40, 41, 42, 43, 44, 45, 46, 47, 48, 1]\n",
"Predicted:\n",
"[40, 41, 42, 43, 44, 45, 46, 47, 48, 1]\n",
"\n",
"Reference:\n",
"[50, 51, 52, 53, 54, 55, 56, 57, 1]\n",
"Predicted:\n",
"[50, 51, 52, 53, 54, 55, 56, 57, 1]\n",
"\n"
]
}
],
"source": [
"preds = sess.run(tf.argmax(logits, -1))\n",
"\n",
"for pred, ref in zip(map(list,preds), map(list,data)):\n",
" # crop EOS after first\n",
" if 1 in pred: \n",
" pred = pred[:pred.index(1) + 1]\n",
" if 1 in ref: \n",
" ref = ref[:ref.index(1) + 1]\n",
" \n",
" print(\"Reference:\") \n",
" print(ref)\n",
" print(\"Predicted:\")\n",
" print(pred)\n",
" print()\n",
" "
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.5.2"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
import numpy as np
import tensorflow as tf
from collections import Counter
# Auxiliary functions for sequence masks
def infer_length(seq, eos=1, time_major=False):
    """
    Infer per-sequence lengths from token indices; the first EOS is counted as part of the sequence.
    :param seq: tf matrix [time,batch] if time_major else [batch,time]
    :param eos: integer index of end-of-sentence token
    :param time_major: if True, time is axis 0 of seq, otherwise axis 1
    :returns: lengths, int32 vector of [batch_size]
    """
    time_axis = 0 if time_major else 1
    eos_hits = tf.cast(tf.equal(seq, eos), 'int32')
    # exclusive cumsum = number of EOS tokens strictly before each position
    eos_seen_before = tf.cumsum(eos_hits, axis=time_axis, exclusive=True)
    # a position belongs to the sequence as long as no EOS precedes it
    still_running = tf.cast(tf.equal(eos_seen_before, 0), 'int32')
    return tf.reduce_sum(still_running, axis=time_axis)
def infer_mask(seq, eos=1, time_major=False, dtype=tf.bool):
    """
    Build a mask that is on for every position up to and including the first EOS.
    :param seq: tf matrix [time,batch] if time_major else [batch,time]
    :param eos: integer index of end-of-sentence token
    :param time_major: if True, time is axis 0 of seq, otherwise axis 1
    :param dtype: dtype of the returned mask
    :returns: mask, matrix of same shape as seq and of given dtype (bool by default)
    """
    time_axis = 0 if time_major else 1
    seq_lengths = infer_length(seq, eos=eos, time_major=time_major)
    batch_major_mask = tf.sequence_mask(seq_lengths, maxlen=tf.shape(seq)[time_axis], dtype=dtype)
    # tf.sequence_mask yields [batch, len]; flip to [len, batch] for time-major input
    return tf.transpose(batch_major_mask) if time_major else batch_major_mask
def safe_cumprod(tensor, axis=0, exclusive=True):
    """
    Cumulative product along one axis, implemented with tf.scan rather than tf.cumprod.
    :param tensor: tf tensor whose rank is statically known
    :param axis: axis to accumulate over; negative values count from the last axis
    :param exclusive: if True, the output at position t is the product of elements [0, t),
        so the first output along axis is 1
    :returns: tensor of the same shape as the input
    """
    n_dim = len(tensor.get_shape())
    if axis < 0:
        axis += n_dim

    # tf.scan iterates over axis 0, so bring the requested axis to the front by swapping the two.
    swap_axes = None
    if n_dim > 1 and axis != 0:
        swap_axes = np.arange(n_dim)
        swap_axes[axis], swap_axes[0] = 0, axis
        tensor = tf.transpose(tensor, swap_axes)

    if exclusive:
        # shift by one step along the scan axis and start from the multiplicative identity
        tensor = tf.concat([tf.ones_like(tensor[:1]), tensor[:-1]], axis=0)

    result = tf.scan(lambda acc, x: acc * x, tensor)

    # Swapping two axes is its own inverse, so the same permutation restores the layout.
    # Skip the transpose entirely when nothing was swapped: tf.transpose without an explicit
    # perm reverses all dimensions, which would corrupt the axis=0 result.
    if swap_axes is not None:
        result = tf.transpose(result, swap_axes)
    return result
def soft_mask(is_eos, time_major=False):
    """
    Differentiable counterpart of a sequence mask: the exclusive cumulative product
    of (1 - P(EOS)) over time, i.e. the probability that no EOS was emitted strictly before each tick.
    :param is_eos: matrix with probabilities of EOS at each tick
    :type is_eos: tf matrix[time,batch] if time_major or [batch,time]
    :param time_major: if True, time is axis 0 of is_eos, otherwise axis 1
    """
    not_eos = 1 - is_eos
    if time_major:
        return safe_cumprod(not_eos, axis=0, exclusive=True)
    return safe_cumprod(not_eos, axis=1, exclusive=True)
# BLEU implementation
def get_ngrams(ref, n=4):
    """
    Return all n-grams of consecutive tokens, [batch, max(time - n + 1, 0), n].
    :param ref: token indices, [batch, time]
    :param n: n-gram order
    """
    ref_len = tf.shape(ref)[1]
    # Clamp at zero: for sequences shorter than n the raw count is negative, and a negative
    # slice end would be read as "from the end", yielding slices of mismatched lengths.
    num_ngrams = tf.maximum(ref_len - n + 1, 0)
    # the i-th slice holds the i-th token of every n-gram
    return tf.stack([ref[:, i: i + num_ngrams] for i in range(n)], axis=-1)
def count_unique_ngrams(ngrams, pad_with=1, count_dtype='int32'):
    """
    returns all unique ngrams and their respective counts. CPU-only and non-differentiable
    :param ngrams: int[batch, ngram_ix, token_ix_in_ngram]
    :param pad_with: token index used to fill padding n-grams for samples with fewer unique n-grams
    :param count_dtype: dtype of the returned counts
    :returns: (unique_ngrams, counts):
        unique_ngrams, [batch, max_unique_ngrams, n], ordered from most to least frequent per sample;
        counts, [batch, max_unique_ngrams], padding n-grams have count 0
    """
    assert ngrams.shape.ndims == 3, "ngrams must be int[batch, ngram_ix, token_ix_in_ngram]"

    def _select_unique_ngrams(ngrams):
        # Runs inside tf.py_func, so `ngrams` here is a numpy array, not a tensor.
        batch_size, _, ngram_size = ngrams.shape
        most_common = [Counter(tuple(ngram) for ngram in sample).most_common() for sample in ngrams]
        # "+ [0]" keeps max() defined for an empty batch
        max_unique_ngrams = max([len(common) for common in most_common] + [0])

        # Pre-fill with padding so that samples with few (or zero) n-grams need no special casing.
        unique_ngrams = np.full([batch_size, max_unique_ngrams, ngram_size], pad_with, dtype=ngrams.dtype)
        ngram_counts = np.zeros([batch_size, max_unique_ngrams], dtype=count_dtype)
        for i, common in enumerate(most_common):
            for j, (ngram, count) in enumerate(common):
                unique_ngrams[i, j] = ngram
                ngram_counts[i, j] = count
        return unique_ngrams, ngram_counts

    unique_ngrams, counts = tf.py_func(_select_unique_ngrams, [ngrams], [ngrams.dtype, count_dtype])
    # py_func outputs carry no static shape; restore what is known from the input
    unique_ngrams.set_shape([ngrams.shape[0], None, ngrams.shape[2]])
    counts.set_shape([ngrams.shape[0], None])
    return unique_ngrams, counts
def soft_count_ngrams(probs, ref_ngrams):
    """
    Gathers, for every n-gram window in probs, the probability of each token of each reference n-gram.
    :param probs: predicted probabilities, float[batch, time, num_units]
    :param ref_ngrams: ngrams to count, int[batch, num_ngrams, n]
    :returns: float[batch_size, probs_len - n + 1, num_ngrams, n]
        [batch size, window start in probs, ngram index in ref, token index in ngram];
        element [b, t, k, j] is probs[b, t + j, ref_ngrams[b, k, j]]
    """
    batch_size, num_ngrams, ngram_size = [tf.shape(ref_ngrams)[i] for i in range(ref_ngrams.shape.ndims)]
    # number of positions at which an n-gram can start inside probs
    num_positions = tf.shape(probs)[1] - ngram_size + 1
    full_shape = [batch_size, num_positions, num_ngrams, ngram_size]
    flat_shape = [batch_size, num_positions, -1]

    # three index tensors, all broadcast to [batch, position, ngram_index, token_index_in_ngram]
    batch_ix = tf.tile(tf.range(0, batch_size)[:, None, None, None],
                       [1, num_positions, num_ngrams, ngram_size])
    window_start = tf.range(0, num_positions)[None, :, None, None]
    offset_in_ngram = tf.range(0, ngram_size)[None, None, None, :]
    time_ix = tf.tile(window_start + offset_in_ngram, [batch_size, 1, num_ngrams, 1])
    token_ix = tf.tile(ref_ngrams[:, None, :, :], [1, num_positions, 1, 1])

    # flat select [batch, num_ngrams_in_pred, num_ngrams_in_ref * ngram_size]
    indices_nd = tf.stack([tf.reshape(ix, flat_shape) for ix in (batch_ix, time_ix, token_ix)], axis=-1)
    flat_token_probs = tf.gather_nd(probs, indices_nd)

    # reshape back into [batch, num_ngrams_in_pred, num_ngrams_in_ref, ngram_size]
    return tf.reshape(flat_token_probs, full_shape)
def compute_ngram_stats(probs, ref, eos=1, n=4, smoothing=None):
    """
    computes soft n-gram precision and recall in a differentiable way
    :param probs: predicted probabilities, float[batch, time, num_units]
    :param ref: reference token indices, int[batch, time]
    :param eos: token index for EOS
    :param n: n-gram order
    :param smoothing: if not None, adds this value to both numerator and denominator of precision & recall
    :returns: (precision, recall), each float[batch_size]
    """
    # probability that the prediction has not emitted EOS strictly before each tick, [batch, time]
    pred_soft_mask = soft_mask(probs[..., eos])

    ref_ngrams, ref_ngrams_counts = count_unique_ngrams(get_ngrams(ref, n=n), pad_with=eos)
    # zero-out the count of any n-gram that contains an EOS token;
    # this also discards the padding n-grams, which consist of EOS only
    num_eos_in_ngram = tf.reduce_sum(tf.cast(tf.equal(ref_ngrams, eos), 'int32'), axis=-1)
    ref_ngrams_mask = tf.equal(num_eos_in_ngram, 0)
    ref_ngrams_counts = tf.cast(ref_ngrams_counts, 'float32') * tf.cast(ref_ngrams_mask, 'float32')

    # dimensions: batch, prob position, ngram index, token within ngram
    probs_masked = probs * pred_soft_mask[:, :, None]
    ngram_token_probs_elwise = soft_count_ngrams(probs_masked, ref_ngrams)
    # probability of a whole n-gram at one position = product over its tokens;
    # expected number of occurrences = sum over positions
    ngram_probs_elwise = tf.reduce_prod(ngram_token_probs_elwise, axis=3)
    ngram_soft_counts = tf.reduce_sum(ngram_probs_elwise, axis=1)

    # compute precision / recall
    # soft counts are clipped by the reference counts before summing over n-grams
    true_positive = tf.reduce_sum(tf.minimum(ngram_soft_counts, ref_ngrams_counts), -1)
    all_true = tf.nn.relu(tf.reduce_sum(ref_ngrams_counts, -1))
    # there are len - n + 1 ngrams; subtract 1 for EOS. Clipped at 0 if there's less than N predicted.
    # NOTE: this approximation is different (simpler) than what's in the original article.
    all_positive = tf.nn.relu(tf.reduce_sum(pred_soft_mask, axis=1) - n)

    if smoothing is not None:
        true_positive += smoothing
        all_positive += smoothing
        all_true += smoothing

    precision = true_positive / all_positive
    recall = true_positive / all_true
    return precision, recall
def compute_bleu_with_logits(logits, ref, eos=1, min_n=1, max_n=4, weights=None, smoothing=None):
    """
    Computes differentiable BLEU with logits.
    :param logits: logits for predicted probabilities, float[batch, time, num_units]
    :param ref: reference token indices, int[batch, time]
    :param eos: token index for EOS
    :param min_n: minimum ngram length (inclusive)
    :param max_n: maximum ngram length (inclusive)
    :param weights: weights for ngrams, one per ngram length from min_n to max_n;
        defaults to uniform weights 1 / (max_n - min_n + 1)
    :param smoothing: if not None, adds this value to numerator and denominator of all precisions.
        Default smooth BLEU has smoothing=1
    :returns: differentiable BLEU, float[batch_size,]
    """
    num_orders = max_n - min_n + 1
    if weights is None:
        # one weight per n-gram order actually used, so the weights sum to 1 even when min_n > 1
        weights = [1. / num_orders] * num_orders
    else:
        assert np.shape(weights)[0] == num_orders, \
            "There must be exactly as many weights as there are ngrams (%i)" % num_orders

    probs = tf.nn.softmax(logits, -1)
    # only precision enters BLEU; recall is discarded
    precisions = [
        compute_ngram_stats(probs, ref, eos, n=n, smoothing=smoothing)[0]
        for n in range(min_n, max_n + 1)
    ]

    ref_len = tf.cast(infer_length(ref, eos), 'float32')
    # expected prediction length: sum of probabilities that the sequence is still running
    pred_soft_len = tf.reduce_sum(soft_mask(probs[..., eos]), axis=-1)
    brevity_penalty = tf.minimum(1.0, tf.exp(1 - ref_len / pred_soft_len))

    # weighted geometric mean of the precisions, computed in log-space
    weighted_log_precision = sum(w * tf.log(pr) for w, pr in zip(weights, precisions))
    return brevity_penalty * tf.exp(weighted_log_precision)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment