Minimum error-rate training for statistical machine translation
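The script below tunes the log-linear feature weights of an SMT system by repeated exact line searches, one feature axis at a time, maximizing corpus BLEU over a fixed n-best list. A minimal invocation sketch (assuming the file is saved as mert.py; the source does not name it):

    ./mert.py --epoch 10 ref.txt nbest.txt model

Here ref.txt holds one tokenized reference per line, nbest.txt holds Travatar-style n-best entries ("sid ||| hypothesis ||| ... ||| feat1=val1 feat2=val2 ..."), and the tuned weights are written to model.0001, model.0002, and so on after each epoch.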
#!/usr/bin/python3
import math
import random
import sys
from argparse import ArgumentParser
from collections import defaultdict
try:
    from util.functions import trace  # the author's personal logging helper
except ImportError:
    def trace(*args, rollback=False):
        # fallback (an assumption) so the script runs without util.functions:
        # rollback=True overwrites the current line instead of starting a new one
        print(' '.join(map(str, args)), end='\r' if rollback else '\n', file=sys.stderr)
def parse_args():
    def_epoch = 10
    def_metrics = 'BLEU'
    p = ArgumentParser(
        description='MERT trainer',
        usage=
            '\n %(prog)s [options] reference nbest model'
            '\n %(prog)s -h',
    )
    p.add_argument('reference',
        help='[in] reference corpus')
    p.add_argument('nbest',
        help='[in] Travatar hypothesis corpus')
    p.add_argument('model',
        help='[out] model prefix')
    p.add_argument('--epoch',
        default=def_epoch, metavar='INT', type=int,
        help='number of training epochs (default: %(default)d)')
    p.add_argument('--init-weight',
        default=None, metavar='FILE', type=str,
        help='initial weight file (default: %(default)s)')
    #p.add_argument('--metrics',
    #    default=def_metrics, metavar='STR', type=str,
    #    help='evaluation metrics name (default: %(default)s)')
    args = p.parse_args()
    # check args
    if args.epoch < 1: raise ValueError('--epoch must be at least 1')
    #if args.metrics not in {'BLEU'}: raise ValueError('--metrics must be one of {BLEU}')
    return args
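# The reference corpus: one tokenized sentence per line.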
def gen_ref(filename):
    with open(filename) as fp:
        for line in fp:
            yield line.split()
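# Group the n-best file into per-sentence batches. Each line looks like
#   sid ||| hypothesis ||| (unused) ||| key1=val1 key2=val2 ...
# and consecutive lines with the same sid belong to the same source sentence.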
def gen_hyps(filename):
    prev_sid = 0
    hyp_batch = []
    feature_batch = []
    with open(filename) as fp:
        for line in fp:
            sid, hyp, _, feature = line.strip().split(' ||| ')
            sid = int(sid)
            hyp = hyp.split()
            feature = [x.split('=') for x in feature.split()]
            feature = defaultdict(float, {x[0]: float(x[1]) for x in feature})
            if sid != prev_sid:
                yield prev_sid, hyp_batch, feature_batch
                prev_sid = sid
                hyp_batch = []
                feature_batch = []
            hyp_batch.append(hyp)
            feature_batch.append(feature)
    if hyp_batch:
        yield prev_sid, hyp_batch, feature_batch
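# BLEU sufficient statistics, stored as a flat vector:
# stats[2n] / stats[2n+1] hold candidate / matched (n+1)-gram counts,
# and the last two entries hold the reference and hypothesis lengths.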
def get_default_stats(N=4):
    return [0 for _ in range(2 * N + 2)]
def get_bleu_stats(ref, hyp, N=4):
    # stats[2n]: number of candidate (n+1)-grams
    # stats[2n+1]: number of matched (n+1)-grams
    stats = [0 for _ in range(2 * N)]
    for n in range(min(len(hyp), N)):
        matched = 0
        possible = defaultdict(int)
        for k in range(len(ref) - n):
            possible[tuple(ref[k : k + n + 1])] += 1
        for k in range(len(hyp) - n):
            ngram = tuple(hyp[k : k + n + 1])
            if possible[ngram] > 0:
                possible[ngram] -= 1
                matched += 1
        stats[2 * n] = len(hyp) - n
        stats[2 * n + 1] = matched
    return stats + [len(ref), len(hyp)]
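# Corpus BLEU from accumulated statistics: the geometric mean of the
# n-gram precisions times the brevity penalty exp(min(0, 1 - r/c)).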
def calculate_bleu(stats, N=4):
    log_prec = 0.0
    for n in range(N):
        nn = stats[2 * n + 1]
        if nn == 0:
            return 0.0
        log_prec += math.log(nn) - math.log(stats[2 * n])
    bp = 1.0 - stats[-2] / stats[-1]
    if bp > 0.0: bp = 0.0
    return math.exp(log_prec / N + bp)
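# Scan the whole n-best file once to collect every feature name.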
def get_feature_names(filename):
    keys = set()
    for sid, __, feature_batch in gen_hyps(filename):
        for features in feature_batch:
            keys |= set(features)
        trace(sid, rollback=True)
    return keys
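# With all weights but one frozen, each hypothesis score is linear in the
# free weight w: score(w) = a*w + b, where a is the hypothesis's value on
# the target feature and b is the dot product over the remaining features.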
def get_grad(feature, key):
    return feature[key]
def get_bias(feature, weights, keys):
    return sum(feature[k] * weights[k] for k in keys)
def accum_stats(dest, src):
    for i in range(len(dest)):
        dest[i] += src[i]
def get_diff_stats(a, b):
    return [b[i] - a[i] for i in range(len(a))]
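# One MERT pass (an exact line search per feature axis, in the style of
# Och, 2003): for each sentence, sort the hypothesis lines by slope, sweep
# the intersections of the upper envelope left to right, and record how the
# BLEU statistics change at each crossing. Sorting all crossings globally
# then yields corpus BLEU as a piecewise-constant function of the weight,
# whose best interval's midpoint becomes the new weight.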
def train(args, epoch, weights):
    for target_axis in weights:
        const_axis = set(weights) - {target_axis}
        trace('epoch %4d: weight: %s' % (epoch, target_axis))
        total_stats = get_default_stats()
        diff_stats_list = []
        for ref, (sid, hyp_batch, feature_batch) in zip(gen_ref(args.reference), gen_hyps(args.nbest)):
            # each hypothesis scores a*w + b, where w is the target weight
            a_batch = [get_grad(feature, target_axis) for feature in feature_batch]
            b_batch = [get_bias(feature, weights, const_axis) for feature in feature_batch]
            stats_batch = [get_bleu_stats(ref, hyp) for hyp in hyp_batch]
            # sort by slope (ascending); break ties by larger bias
            batch = sorted(zip(a_batch, b_batch, stats_batch), key=lambda x: (x[0], -x[1]))
            prev_n = 0
            prev_w = -1e20  # sentinel
            accum_stats(total_stats, batch[0][2])
            while True:
                next_n = None
                next_w = 1e20  # sentinel
                for n in range(prev_n + 1, len(batch)):
                    if batch[n][0] == batch[prev_n][0]:
                        continue  # ignore identical slopes
                    # intersection of the two lines
                    w = (batch[n][1] - batch[prev_n][1]) / (batch[prev_n][0] - batch[n][0])
                    if prev_w < w <= next_w:
                        next_n = n
                        next_w = w
                if next_n is None:
                    break  # no more intersections
                diff_stats_list.append((next_w, get_diff_stats(batch[prev_n][2], batch[next_n][2])))
                prev_n = next_n
                prev_w = next_w
            trace(sid, rollback=True)
        best_bleu = calculate_bleu(total_stats)
        if len(diff_stats_list) > 0:
            # find the global optimum over the target axis
            diff_stats_list = sorted(diff_stats_list, key=lambda x: x[0])
            best_m = -1
            for m, (w, diff) in enumerate(diff_stats_list):
                accum_stats(total_stats, diff)
                if m == len(diff_stats_list) - 1 or w < diff_stats_list[m + 1][0]:
                    bleu = calculate_bleu(total_stats)
                    if bleu > best_bleu:
                        best_m = m
                        best_bleu = bleu
            # update the weight
            if best_m == -1:
                weights[target_axis] = diff_stats_list[0][0] - 1.0
            elif best_m == len(diff_stats_list) - 1:
                weights[target_axis] = diff_stats_list[-1][0] + 1.0
            else:
                weights[target_axis] = \
                    0.5 * (diff_stats_list[best_m][0] + diff_stats_list[best_m + 1][0])
        else:
            # no intersections: the ranking never changes along this axis
            weights[target_axis] = 0.0
        # verify
        #total_stats = get_default_stats()
        #all_weights = set(weights)
        #for ref, (sid, hyp_batch, feature_batch) in zip(gen_ref(args.reference), gen_hyps(args.nbest)):
        #    best_hyp = None
        #    best_score = -1e20
        #    for i, (hyp, feature) in enumerate(zip(hyp_batch, feature_batch)):
        #        score = get_bias(feature, weights, all_weights)
        #        if score > best_score:
        #            best_hyp = hyp
        #            best_score = score
        #    accum_stats(total_stats, get_bleu_stats(ref, best_hyp))
        #bleu_verify = calculate_bleu(total_stats)
        #if best_bleu != bleu_verify:
        #    raise RuntimeError('abort')
        trace('%8s = %+.6f, BLEU = %.6f' % (target_axis, weights[target_axis], best_bleu))
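# Weights start at random values in [-1, 1]; if --init-weight is given,
# every weight is zeroed first and the listed features are overridden.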
def init_weights(args):
    weight = {k: random.uniform(-1, 1) for k in get_feature_names(args.nbest)}
    if args.init_weight is not None:
        weight = {k: 0.0 for k in weight}
        with open(args.init_weight) as fp:
            for line in fp:
                key, value = line.split()
                weight[key] = float(value)
    return weight
def save(weights, filename):
    with open(filename, 'w') as fp:
        for k, v in weights.items():
            print('%s\t%+.8e' % (k, v), file=fp)
def main():
    args = parse_args()
    trace('gathering weights ...')
    weights = init_weights(args)
    trace('start training ...')
    for i in range(args.epoch):
        train(args, i, weights)
        save(weights, args.model + '.%04d' % (i + 1))
    trace('finished.')
if __name__ == '__main__':
    main()
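Each saved model file is plain text with one feature per line, tab-separated; for example (hypothetical feature names and values):

    lm	+1.20000000e+00
    tm	-3.45000000e-01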