Skip to content

Instantly share code, notes, and snippets.

@bayerj
Created October 21, 2016 14:10
Show Gist options
Save bayerj/73aa2d6a3a3e5db14df071b91bcb2d22 to your computer and use it in GitHub Desktop.
class _SequentialAutoregressive(tf.contrib.distributions.Distribution):
    """Autoregressive distribution over sequences, evaluated step by step.

    The first time step is handled by ``initial_dist``. Every later step is
    modelled by ``base_dist_cls(*stats)``, where ``stats`` is produced by
    ``f_process`` from the observations that precede that step.

    Args:
        f_process: Callable invoked as ``f_process(inputs, n_steps, state)``
            and returning a pair ``(states, stats)``. ``stats`` is unpacked
            positionally into ``base_dist_cls``; during sampling
            ``states[-1]`` is fed back as ``state`` for the next step
            (``state`` is None on the first call).
        base_dist_cls: Distribution class instantiated as
            ``base_dist_cls(*stats)`` for the conditional steps.
        initial_dist: Distribution of the first time step.
        n_time_steps: Sequence length used when sampling; also reported as
            the batch shape. Must not be None for sampling.
        dtype: Dtype forwarded to the base ``Distribution``.
        name: Name forwarded to the base ``Distribution``.
    """

    def __init__(self, f_process, base_dist_cls, initial_dist,
                 n_time_steps=None,
                 dtype=tf.float32,
                 name='sequential_auto_regressive'):
        self.f_process = f_process
        self.base_dist_cls = base_dist_cls
        self.initial_dist = initial_dist
        self.n_time_steps = n_time_steps
        super(_SequentialAutoregressive, self).__init__(
            dtype=dtype,
            parameters={},
            is_continuous=True,
            is_reparameterized=False,
            validate_args=False,
            allow_nan_stats=False,
            name=name
        )

    def _log_prob(self, value):
        """Return the per-step log densities of ``value``.

        ``value`` is indexed as ``value[t, :, :]`` with time on axis 0
        (presumably ``(time, batch, dim)`` -- TODO confirm against callers).
        Step 0 is scored under ``initial_dist``. Steps ``1..T-1`` are scored
        under ``base_dist_cls(*stats)``, where ``stats`` comes from running
        ``f_process`` over the observed steps ``0..T-2``.

        Returns:
            The step-0 log density (given a new leading axis) concatenated
            along axis 0 with the conditional log densities.

        Raises:
            ValueError: If the time dimension of ``value`` is not statically
                known.
        """
        n_steps = value.get_shape()[0].value
        if n_steps is None:
            # f_process needs a concrete step count; fail clearly instead of
            # with ``None - 1``.
            raise ValueError(
                'The time dimension (axis 0) of value must be statically '
                'known.')
        log_pdf0 = self.initial_dist.log_pdf(value[0, :, :])
        _, stats = self.f_process(value[:-1, :, :], n_steps - 1, None)
        rv = self.base_dist_cls(*stats)
        log_pdfs = rv.log_pdf(value[1:, :, :])
        # tf.concat(concat_dim, values) is the pre-1.0 TensorFlow signature.
        return tf.concat(0, [tf.expand_dims(log_pdf0, 0), log_pdfs])

    def _sample_n(self, n, seed=None):
        """Draw ``n`` sequences of length ``n_time_steps``, one step at a time.

        Step 0 is drawn from ``initial_dist``. Each later step is drawn from
        ``base_dist_cls(*stats)``, where ``stats`` comes from a one-step call
        of ``f_process`` on the previous observation and the carried state.

        Note: ``seed`` is currently ignored.

        Returns:
            ``tf.pack`` of the per-step samples, each reshaped to
            ``(n, -1)``, so time is on axis 0.

        Raises:
            ValueError: If ``n_time_steps`` is None.
        """
        if self.n_time_steps is None:
            raise ValueError('n_time_steps must be set to draw samples.')
        obs = [tf.reshape(self.initial_dist.sample_n(n), (n, -1))]
        state_tm1 = None
        for _ in range(self.n_time_steps - 1):
            # obs[-1:] is a one-element list: only the most recent step is
            # fed in, together with the state carried from the previous call.
            states, stats = self.f_process(obs[-1:], 1, state_tm1)
            state_tm1 = states[-1]
            rv = self.base_dist_cls(*stats)
            # Reshape to (n, -1) like step 0 so all steps stack in tf.pack.
            # The old (1, -1) folded the n samples into one row and made the
            # shapes mismatch whenever n > 1.
            obs.append(tf.reshape(rv.sample(1)[0], (n, -1)))
        return tf.pack(obs)

    def _sample(self, sample_shape=(), seed=None):
        """Not implemented; sampling is provided via ``_sample_n``."""
        raise NotImplementedError

    def _get_batch_shape(self):
        """Report the sequence length ``n_time_steps`` as the batch shape."""
        # NOTE(review): time sits on axis 0 in _log_prob/_sample_n, which
        # matches treating it as the batch dimension -- confirm intent.
        return tf.TensorShape([self.n_time_steps])

    def _get_event_shape(self):
        """Report a rank-2 event shape with both dimensions unknown."""
        return tf.TensorShape([None, None])
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.