I am working on a translation of this PyTorch code to Rust/Candle. Can you step through this code line by line and term by term? Please provide the shapes of the tensors at each step, and be sure to mention where there is implicit broadcasting.
for t_ in range(L):
    t = decoding_order[:, t_]  # [B]
    chain_mask_t = torch.gather(chain_mask, 1, t[:, None])[:, 0]  # [B]
    mask_t = torch.gather(mask, 1, t[:, None])[:, 0]  # [B]
    bias_t = torch.gather(bias, 1, t[:, None, None].repeat(1, 1, 21))[
        :, 0, :
    ]  # [B,21]
    E_idx_t = torch.gather(
        E_idx, 1, t[:, None, None].repeat(1, 1, E_idx.shape[-1])
    )
    h_E_t = torch.gather(
        h_E,
        1,
        t[:, None, None, None].repeat(1, 1, h_E.shape[-2], h_E.shape[-1]),
    )
    h_ES_t = cat_neighbors_nodes(h_S, h_E_t, E_idx_t)
    h_EXV_encoder_t = torch.gather(
        h_EXV_encoder_fw,
        1,
        t[:, None, None, None].repeat(
            1, 1, h_EXV_encoder_fw.shape[-2], h_EXV_encoder_fw.shape[-1]
        ),
    )
    mask_bw_t = torch.gather(
        mask_bw,
        1,
        t[:, None, None, None].repeat(
            1, 1, mask_bw.shape[-2], mask_bw.shape[-1]
        ),
    )

def cat_neighbors_nodes(h_nodes, h_neighbors, E_idx):
    h_nodes = gather_nodes(h_nodes, E_idx)
    h_nn = torch.cat([h_neighbors, h_nodes], -1)
    return h_nn

def gather_nodes(nodes, neighbor_idx):
    # Features [B,N,C] at Neighbor indices [B,N,K] => [B,N,K,C]
    # Flatten and expand indices per batch [B,N,K] => [B,NK] => [B,NK,C]
    neighbors_flat = neighbor_idx.reshape((neighbor_idx.shape[0], -1))
    neighbors_flat = neighbors_flat.unsqueeze(-1).expand(-1, -1, nodes.size(2))
    # Gather and re-pack
    neighbor_features = torch.gather(nodes, 1, neighbors_flat)
    neighbor_features = neighbor_features.view(list(neighbor_idx.shape)[:3] + [-1])
    return neighbor_features
I'll break this down line by line, explaining the shapes and operations:
for t_ in range(L):
    # L is the sequence length

    t = decoding_order[:, t_]  # [B]
    # decoding_order shape: [B, L] (Batch, Length)
    # t shape: [B] - selects one position per batch element

    chain_mask_t = torch.gather(chain_mask, 1, t[:, None])[:, 0]  # [B]
    # chain_mask shape: [B, L]
    # t[:, None] adds a dimension: [B, 1]
    # gather selects elements along dim=1 using those indices -> [B, 1]
    # [:, 0] removes the added dimension
    # Final shape: [B]

    mask_t = torch.gather(mask, 1, t[:, None])[:, 0]  # [B]
    # Same pattern as chain_mask_t
    # Final shape: [B]
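As a concrete sketch (with made-up values), this is the 2-D gather pattern used for chain_mask_t and mask_t:

import torch

B, L = 2, 5
chain_mask = torch.arange(B * L, dtype=torch.float32).view(B, L)  # [B, L]
t = torch.tensor([3, 1])                                          # [B]

# The index rank must match the source rank, hence t[:, None] -> [B, 1]
chain_mask_t = torch.gather(chain_mask, 1, t[:, None])[:, 0]      # [B]
print(chain_mask_t)  # tensor([3., 6.]): row 0 takes position 3, row 1 takes position 1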
    bias_t = torch.gather(bias, 1, t[:, None, None].repeat(1, 1, 21))[:, 0, :]  # [B,21]
    # bias shape: [B, L, 21]
    # t[:, None, None] shape: [B, 1, 1]
    # repeat(1, 1, 21) explicitly tiles the index to [B, 1, 21] -- a real copy,
    # not broadcasting; gather requires the index rank to match the source rank
    # gather selects along dim=1, then [:, 0, :] removes the middle dimension
    # Final shape: [B, 21]
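A minimal sketch of the 3-D pattern (made-up sizes), including the advanced-indexing equivalent, which may be easier to mirror with index-selection ops on the Rust side:

import torch

B, L, A = 2, 5, 21
bias = torch.randn(B, L, A)
t = torch.tensor([3, 1])

idx = t[:, None, None].repeat(1, 1, A)        # [B, 1, 21]: same position for every logit slot
bias_t = torch.gather(bias, 1, idx)[:, 0, :]  # [B, 21]

# Equivalent advanced indexing: take row t[b] from each batch element b
assert torch.equal(bias_t, bias[torch.arange(B), t])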
    E_idx_t = torch.gather(E_idx, 1, t[:, None, None].repeat(1, 1, E_idx.shape[-1]))
    # E_idx shape: [B, L, K] (K is the number of neighbors)
    # Index shape after repeat: [B, 1, K]; final shape: [B, 1, K]

    h_E_t = torch.gather(h_E, 1, t[:, None, None, None].repeat(1, 1, h_E.shape[-2], h_E.shape[-1]))
    # h_E shape: [B, L, K, C] (C is the feature dimension)
    # Index shape after repeat: [B, 1, K, C]; final shape: [B, 1, K, C]
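The same pattern one rank higher, with the equivalence spelled out (made-up sizes):

import torch

B, L, K, C = 2, 5, 3, 4
h_E = torch.randn(B, L, K, C)
t = torch.tensor([3, 1])

idx = t[:, None, None, None].repeat(1, 1, K, C)  # [B, 1, K, C]
h_E_t = torch.gather(h_E, 1, idx)                # [B, 1, K, C]
assert torch.equal(h_E_t[:, 0], h_E[torch.arange(B), t])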
    h_ES_t = cat_neighbors_nodes(h_S, h_E_t, E_idx_t)
    # h_S shape: [B, N, C] (N is the number of nodes; here N = L)
    # Inside cat_neighbors_nodes:
    #   1. gather_nodes looks up node features at E_idx_t: [B, 1, K, C]
    #   2. the result is concatenated with h_E_t along the last dimension
    # Final shape: [B, 1, K, 2C]
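A shape check for this step, reusing the cat_neighbors_nodes and gather_nodes definitions above with made-up sizes (B=1, L=93, K=24, C=128):

import torch

h_S = torch.randn(1, 93, 128)                # [B, N, C]
h_E_t = torch.randn(1, 1, 24, 128)           # [B, 1, K, C]
E_idx_t = torch.randint(0, 93, (1, 1, 24))   # [B, 1, K]

h_ES_t = cat_neighbors_nodes(h_S, h_E_t, E_idx_t)
print(h_ES_t.shape)  # torch.Size([1, 1, 24, 256]) -- last dim is 2C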
    h_EXV_encoder_t = torch.gather(
        h_EXV_encoder_fw, 1,
        t[:, None, None, None].repeat(1, 1, h_EXV_encoder_fw.shape[-2], h_EXV_encoder_fw.shape[-1]),
    )
    # h_EXV_encoder_fw shape: [B, L, K', C'] (K' and C' are read from the tensor
    # itself, so they need not match K and C)
    # Final shape: [B, 1, K', C']

    mask_bw_t = torch.gather(
        mask_bw, 1,
        t[:, None, None, None].repeat(1, 1, mask_bw.shape[-2], mask_bw.shape[-1]),
    )
    # mask_bw shape: [B, L, K', C']
    # Final shape: [B, 1, K', C']
The gather_nodes function:
def gather_nodes(nodes, neighbor_idx):
    # nodes shape: [B, N, C]
    # neighbor_idx shape: [B, N', K]

    # Flatten neighbor indices
    neighbors_flat = neighbor_idx.reshape((neighbor_idx.shape[0], -1))
    # shape: [B, N'K]

    # Add feature dimension and broadcast across C (expand is a zero-copy view)
    neighbors_flat = neighbors_flat.unsqueeze(-1).expand(-1, -1, nodes.size(2))
    # shape: [B, N'K, C]

    # Gather features
    neighbor_features = torch.gather(nodes, 1, neighbors_flat)
    # shape: [B, N'K, C]

    # Reshape to match input structure
    neighbor_features = neighbor_features.view(list(neighbor_idx.shape)[:3] + [-1])
    # Final shape: [B, N', K, C]
    return neighbor_features
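Worth noting for the Rust port: gather_nodes never forces N' to equal N; the output's second dimension simply follows the index. A quick sketch with made-up tensors (shapes chosen to mirror this conversation):

import torch

nodes = torch.randn(1, 93, 128)                  # [B, N, C]
full_idx = torch.randint(0, 93, (1, 93, 24))     # encoder-style index: [B, N, K]
step_idx = torch.randint(0, 93, (1, 1, 24))      # per-step index:      [B, 1, K]

print(gather_nodes(nodes, full_idx).shape)  # torch.Size([1, 93, 24, 128])
print(gather_nodes(nodes, step_idx).shape)  # torch.Size([1, 1, 24, 128])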
Strictly speaking, there is no implicit broadcasting in these gathers. torch.gather requires its index tensor to have the same number of dimensions as the source, so the code explicitly tiles t[:, None, None] or t[:, None, None, None] with .repeat() until the index matches the target tensor's trailing dimensions. All of the shape expansion is explicit, which should make the Rust translation mechanical, as the sketch below illustrates.
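For the port it may help to know that repeat materializes real copies, while expand yields the same shape as a zero-copy view (gather accepts either). A tiny sketch with made-up values:

import torch

t = torch.tensor([3, 1])               # [B]
idx = t[:, None, None]                 # [B, 1, 1]
print(idx.repeat(1, 1, 21).shape)      # torch.Size([2, 1, 21]) -- materializes copies
print(idx.expand(-1, -1, 21).shape)    # torch.Size([2, 1, 21]) -- stride-0 view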
Okay. My dimensions seem fine in my encoding code but not in my inference code. The correct code has input dimensions of:
IN GATHER NODES. Nodes, neighbor_idx: [1, 93, 128], [1, 93, 24]
But the incorrect code has dimensions of:
Nodes, neighbor_idx: [1, 93, 128], [1, 1, 24]
Can you guess where this is happening in this code? E.g., an implicit 1->93 expansion that PyTorch performs but that must be explicit in Rust?
Ah, I see the issue. Looking at the code, this difference most likely stems from the E_idx_t gathering operation:
E_idx_t = torch.gather(E_idx, 1, t[:, None, None].repeat(1, 1, E_idx.shape[-1]))
In the original code, E_idx has shape [B, L, K], with L=93 and K=24. One thing to note: torch.gather's output always takes the shape of its index tensor, so this per-step gather produces E_idx_t of shape [1, 1, 24] in PyTorch as well; there is no implicit 1->93 expansion hiding inside gather.
That suggests the [1, 93, 24] index your correct run prints inside gather_nodes comes from a different call site, one that passes the full E_idx rather than the per-step E_idx_t (perhaps wherever h_EXV_encoder_fw is assembled before the decoding loop). The likely bug is that your Rust inference path feeds the single-step E_idx_t into a gather_nodes call that should receive the full [1, 93, 24] index, or otherwise loses the L dimension before that call.
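A quick sanity check (made-up values) that the gather output is dictated by the index shape in PyTorch too:

import torch

E_idx = torch.randint(0, 93, (1, 93, 24))   # [B, L, K]
t = torch.tensor([7])                       # [B]
idx = t[:, None, None].repeat(1, 1, 24)     # [1, 1, 24]
print(torch.gather(E_idx, 1, idx).shape)    # torch.Size([1, 1, 24])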
To fix this in Rust, first pin down which call site is failing: if it is an encoder-style call that should see every position, pass the full E_idx there rather than the per-step E_idx_t. If you genuinely need to widen an index tensor, Candle requires the broadcast to be explicit, something like:
// Explicitly broadcast an index from [1, 1, 24] to [1, 93, 24]
let e_idx_t = e_idx_t.expand((1, 93, 24))?; // pseudo-code; check Candle's expand / broadcast_as
The exact syntax will depend on Candle's API, but the larger point stands: nothing in these gathers broadcasts implicitly even in PyTorch; every index shape is built by an explicit repeat, so the Rust port has to reproduce both the index shapes and the call sites exactly.
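If it helps to localize the divergence, here is a small hypothetical helper for comparing the two runs; the names in the usage comments follow the code above:

def shape_of(name, tensor):
    # Tiny debugging helper: print a tensor's shape so the PyTorch reference
    # and the Candle port can be compared call site by call site.
    print(f"{name}: {tuple(tensor.shape)}")

# Paste at the suspect call sites, e.g.:
#   shape_of("E_idx", E_idx)      # expect (1, 93, 24) at encoder-style sites
#   shape_of("E_idx_t", E_idx_t)  # expect (1, 1, 24) inside the decoding loop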