The energy matching loss
"""The energy matching loss. | |
Energy matching regresses an energy function to match a target energy function at the points in the | |
dataset. ("Energy" refers to an unnormalized negative log probability: for a sequence model it is | |
the sum of the cross-entropy losses of a sequence's completion tokens plus some constant. Two energy | |
functions are considered "the same" by the energy matching loss if they differ by an arbitrary | |
constant.) This is useful for on-policy reinforcement learning or off-policy preference tuning of | |
sequence models, where the target energies are: | |
[the sequences' energies under the reference model] - [the sequences' rewards] / beta, | |
where `beta` is the KL regularization coefficient. | |
When evaluated on samples drawn from the current energy function (that is, when used on-policy), the | |
gradient of the energy matching loss is an unbiased estimator of the gradient of the KL divergence | |
of the target energy function from the current energy function. | |
""" | |
import torch | |
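

# An illustrative sketch (all names and values here are hypothetical) of building
# the target energies described in the module docstring: `ref_energies` stands in
# for the sequences' energies under the reference model, `rewards` for their
# scalar rewards, and `beta` for the KL regularization coefficient.
def _example_targets() -> torch.Tensor:
    ref_energies = torch.tensor([10.0, 12.5, 9.0])
    rewards = torch.tensor([1.0, -0.5, 0.2])
    beta = 0.1
    # target energy = reference energy - reward / beta
    return ref_energies - rewards / beta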


def energy_matching_loss(input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Compute the energy matching loss.

    Args:
        input (torch.Tensor): The input energies. Each row represents a group of at least two
            energies of independent samples from the same distribution. In the case of a sequence
            model, this means they all have the same prompt. Shape: (..., num_samples).
        target (torch.Tensor): The target energies. Shape: (..., num_samples).

    Returns:
        torch.Tensor: The energy matching loss. Shape: ().
    """
    if input.shape != target.shape:
        raise ValueError("Input and target must have the same shape.")
    if input.shape[-1] < 2:
        raise ValueError("Input and target must have at least two elements in the last dimension.")
    error = input - target
    # Subtract each group's mean error: the loss is invariant to a per-group constant shift.
    error = error - torch.mean(error, dim=-1, keepdim=True)
    # error[..., 1:].numel() is num_groups * (num_samples - 1); the num_samples - 1
    # accounts for the degree of freedom removed by centering.
    return torch.sum(error**2) / (2 * error[..., 1:].numel())
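

# A minimal usage sketch for `energy_matching_loss`, with random stand-in values:
# 2 prompts, 4 independent samples each. In practice `input` would be the energies
# under the model being tuned and `target` the reference energies minus
# reward / beta.
def _example_energy_matching_loss() -> None:
    input = torch.randn(2, 4, requires_grad=True)  # hypothetical model energies
    target = torch.randn(2, 4)  # hypothetical target energies
    loss = energy_matching_loss(input, target)
    loss.backward()  # gradients flow to `input` only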


def energy_matching_loss_with_group_indices(
    input: torch.Tensor, target: torch.Tensor, indices: torch.Tensor
) -> torch.Tensor:
    """Compute the energy matching loss with manually specified group indices. This is used when
    you want to mix groups with different numbers of samples in the same batch.

    Args:
        input (torch.Tensor): A 1D tensor of input energies.
        target (torch.Tensor): A 1D tensor of target energies.
        indices (torch.Tensor): A 1D tensor of group indices, where each index corresponds to a
            group of independent samples from the same distribution. In the case of a sequence
            model, this means they all have the same prompt. Each group must have at least two
            samples.

    Returns:
        torch.Tensor: The energy matching loss. Shape: ().
    """
    if input.shape != target.shape:
        raise ValueError("Input and target must have the same shape.")
    if input.shape != indices.shape:
        raise ValueError("Input and indices must have the same shape.")
    if input.ndim != 1:
        raise ValueError("Input, target, and indices must be 1D tensors.")
    # k[i] is the size of the group that element i belongs to.
    k = indices.bincount()[indices]
    if torch.any(k < 2):
        raise ValueError("Each group must have at least two samples.")
    num_groups = indices.amax() + 1
    error = input - target
    # Subtract each group's mean error via scatter-add: the loss is invariant to a
    # per-group constant shift.
    error = error - error.new_zeros(num_groups).index_add_(0, indices, error)[indices] / k
    # The k / (k - 1) factor corrects for the degree of freedom removed by centering,
    # so groups of different sizes are weighted consistently.
    return torch.mean(error**2 * k / (k - 1)) / 2
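

# A minimal usage sketch for the grouped variant, mixing group sizes in one batch
# (hypothetical values): group 0 has three samples and group 1 has two.
def _example_energy_matching_loss_with_group_indices() -> None:
    indices = torch.tensor([0, 0, 0, 1, 1])
    input = torch.randn(5, requires_grad=True)
    target = torch.randn(5)
    loss = energy_matching_loss_with_group_indices(input, target, indices)
    loss.backward()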


def sequence_energy(logits: torch.Tensor, tokens: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """Compute a sequence model's energies for sequences of tokens. (These are just the sums of
    the cross-entropy losses of each sequence's completion tokens.)

    Args:
        logits (torch.Tensor): Logits of the sequences. Shape: (..., seq_len, vocab_size).
        tokens (torch.Tensor): Tokens of the sequences. Shape: (..., seq_len).
        mask (torch.Tensor): `True` for completion tokens, `False` for prompt or padding tokens.
            Shape: (..., seq_len).

    Returns:
        torch.Tensor: The energies for the sequences. Shape: (...).
    """
    # Log normalizer at each position (the logits at position t predict token t + 1).
    lse = torch.logsumexp(logits[..., :-1, :], dim=-1)
    # Logit of the token that actually came next.
    pos = torch.gather(logits[..., :-1, :], -1, tokens[..., 1:, None])[..., 0]
    # Per-token cross-entropy, summed over the completion tokens only.
    return torch.sum((lse - pos) * mask[..., 1:], dim=-1)
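

# A minimal end-to-end sketch with random logits standing in for a language
# model's output (all shapes and values hypothetical). The first two positions
# of each sequence are treated as the prompt and masked out.
def _example_sequence_energy() -> torch.Tensor:
    batch, seq_len, vocab_size = 4, 8, 100
    logits = torch.randn(batch, seq_len, vocab_size)
    tokens = torch.randint(vocab_size, (batch, seq_len))
    mask = torch.zeros(batch, seq_len, dtype=torch.bool)
    mask[:, 2:] = True  # completion tokens start at position 2
    return sequence_energy(logits, tokens, mask)  # shape: (batch,)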