"""The energy matching loss.
Energy matching regresses an energy function to match a target energy function at the points in the
dataset. ("Energy" refers to an unnormalized negative log probability: for a sequence model it is
the sum of the cross-entropy losses of a sequence's completion tokens plus some constant. Two energy
functions are considered "the same" by the energy matching loss if they differ by an arbitrary
constant.) This is useful for on-policy reinforcement learning or off-policy preference tuning of
sequence models, where the target energies are:
[the sequences' energies under the reference model] - [the sequences' rewards] / beta,
where `beta` is the KL regularization coefficient.
When evaluated on samples drawn from the current energy function (that is, when used on-policy), the
gradient of the energy matching loss is an unbiased estimator of the gradient of the KL divergence
of the target energy function from the current energy function.
"""
import torch


def energy_matching_loss(input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Compute the energy matching loss.

    Args:
        input (torch.Tensor): The input energies. Each row represents a group of at least two
            energies of independent samples from the same distribution. In the case of a sequence
            model, this means they all have the same prompt. Shape: (..., num_samples).
        target (torch.Tensor): The target energies. Shape: (..., num_samples).

    Returns:
        torch.Tensor: The energy matching loss. Shape: ().
    """
    if input.shape != target.shape:
        raise ValueError("Input and target must have the same shape.")
    if input.shape[-1] < 2:
        raise ValueError("Input and target must have at least two elements in the last dimension.")
    # Center the error within each group: the loss only cares about energies up to a per-group
    # additive constant.
    error = input - target
    error = error - torch.mean(error, dim=-1, keepdim=True)
    # Centering removes one degree of freedom per group, so divide by
    # 2 * num_groups * (num_samples - 1), which is 2 * error[..., 1:].numel().
    return torch.sum(error**2) / (2 * error[..., 1:].numel())
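

# Illustrative usage (not part of the original gist): a toy problem with 3 prompts and 4 sampled
# completions per prompt. The rewards, reference energies, and `beta` below are made up; the
# targets follow the formula from the module docstring.
def _demo_energy_matching_loss() -> None:
    torch.manual_seed(0)
    model_energies = torch.randn(3, 4, requires_grad=True)  # hypothetical current-model energies
    ref_energies = torch.randn(3, 4)  # hypothetical reference-model energies
    rewards = torch.randn(3, 4)  # hypothetical per-sequence rewards
    beta = 0.1  # hypothetical KL regularization coefficient
    target = ref_energies - rewards / beta
    loss = energy_matching_loss(model_energies, target)
    loss.backward()  # gradients flow to the current model's energies
    print(f"energy matching loss: {loss.item():.4f}")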


def energy_matching_loss_with_group_indices(
    input: torch.Tensor, target: torch.Tensor, indices: torch.Tensor
) -> torch.Tensor:
    """Compute the energy matching loss with manually specified group indices. This is used when
    you want to mix groups with different numbers of samples in the same batch.

    Args:
        input (torch.Tensor): A 1D tensor of input energies.
        target (torch.Tensor): A 1D tensor of target energies.
        indices (torch.Tensor): A 1D tensor of group indices, where each index corresponds to a
            group of independent samples from the same distribution. In the case of a sequence
            model, this means they all have the same prompt. Each group must have at least two
            samples.

    Returns:
        torch.Tensor: The energy matching loss. Shape: ().
    """
    if input.shape != target.shape:
        raise ValueError("Input and target must have the same shape.")
    if input.shape != indices.shape:
        raise ValueError("Input and indices must have the same shape.")
    if input.ndim != 1:
        raise ValueError("Input, target, and indices must be 1D tensors.")
    # k[i] is the size of the group that sample i belongs to.
    k = indices.bincount()[indices]
    if torch.any(k < 2):
        raise ValueError("Each group must have at least two samples.")
    num_groups = indices.amax() + 1
    error = input - target
    # Subtract each group's mean error from its members (index_add_ accumulates per-group sums).
    error = error - error.new_zeros(num_groups).index_add_(0, indices, error)[indices] / k
    # The k / (k - 1) factor corrects for the degree of freedom lost to per-group centering.
    return torch.mean(error**2 * k / (k - 1)) / 2
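

# A minimal sketch (not from the original gist) of mixing group sizes in one batch: two samples
# from prompt 0 and three from prompt 1, flattened into 1D tensors with explicit group indices.
# All values are made up.
def _demo_energy_matching_loss_with_group_indices() -> None:
    torch.manual_seed(0)
    indices = torch.tensor([0, 0, 1, 1, 1])  # group 0 has 2 samples, group 1 has 3
    model_energies = torch.randn(5, requires_grad=True)
    target = torch.randn(5)  # stands in for ref_energies - rewards / beta
    loss = energy_matching_loss_with_group_indices(model_energies, target, indices)
    loss.backward()
    print(f"grouped energy matching loss: {loss.item():.4f}")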


def sequence_energy(logits: torch.Tensor, tokens: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """Compute a sequence model's energies for sequences of tokens. (These are just the sums of
    the cross-entropy losses of each sequence's completion tokens.)

    Args:
        logits (torch.Tensor): Logits of the sequences. Shape: (..., seq_len, vocab_size).
        tokens (torch.Tensor): Tokens of the sequences. Shape: (..., seq_len).
        mask (torch.Tensor): `True` for completion tokens, `False` for prompt or padding tokens.
            Shape: (..., seq_len).

    Returns:
        torch.Tensor: The energies for the sequences. Shape: (...).
    """
    # Cross-entropy of token t+1 under the logits at position t: logsumexp over the vocabulary
    # minus the logit of the observed next token.
    lse = torch.logsumexp(logits[..., :-1, :], dim=-1)
    pos = torch.gather(logits[..., :-1, :], -1, tokens[..., 1:, None])[..., 0]
    # Sum over completion positions only; the mask is shifted to align with the predictions.
    return torch.sum((lse - pos) * mask[..., 1:], dim=-1)
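

# End-to-end sketch (illustrative; the shapes, rewards, and prompt length are made up): compute
# per-sequence energies under the current and reference "models" (random logits here, forward
# passes in practice), form targets as ref_energy - reward / beta, and take the energy matching
# loss.
def _demo_sequence_energy_pipeline() -> None:
    torch.manual_seed(0)
    batch, num_samples, seq_len, vocab_size = 2, 3, 8, 16
    tokens = torch.randint(vocab_size, (batch, num_samples, seq_len))
    # Hypothetical masking: the first 3 tokens are the prompt, the rest are the completion.
    mask = (torch.arange(seq_len) >= 3).expand(batch, num_samples, seq_len)
    model_logits = torch.randn(batch, num_samples, seq_len, vocab_size, requires_grad=True)
    ref_logits = torch.randn(batch, num_samples, seq_len, vocab_size)
    rewards = torch.randn(batch, num_samples)
    beta = 0.1  # hypothetical KL regularization coefficient
    input = sequence_energy(model_logits, tokens, mask)
    target = sequence_energy(ref_logits, tokens, mask) - rewards / beta
    loss = energy_matching_loss(input, target)
    loss.backward()
    print(f"sequence energy matching loss: {loss.item():.4f}")


if __name__ == "__main__":
    _demo_energy_matching_loss()
    _demo_energy_matching_loss_with_group_indices()
    _demo_sequence_energy_pipeline()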