@zachcp
Last active November 27, 2024 14:39
Pytorch Sonnet Explanations

I am working on a translation of this PyTorch code to Rust/Candle. Can you step through this code line by line and term by term? Please give me the shapes of the tensors at each step, and be sure to mention where there is implicit broadcasting.

for t_ in range(L):
    t = decoding_order[:, t_]  # [B]
    chain_mask_t = torch.gather(chain_mask, 1, t[:, None])[:, 0]  # [B]
    mask_t = torch.gather(mask, 1, t[:, None])[:, 0]  # [B]
    bias_t = torch.gather(bias, 1, t[:, None, None].repeat(1, 1, 21))[
        :, 0, :
    ]  # [B,21]

    E_idx_t = torch.gather(
        E_idx, 1, t[:, None, None].repeat(1, 1, E_idx.shape[-1])
    )
    h_E_t = torch.gather(
        h_E,
        1,
        t[:, None, None, None].repeat(1, 1, h_E.shape[-2], h_E.shape[-1]),
    )
    h_ES_t = cat_neighbors_nodes(h_S, h_E_t, E_idx_t)
    h_EXV_encoder_t = torch.gather(
        h_EXV_encoder_fw,
        1,
        t[:, None, None, None].repeat(
            1, 1, h_EXV_encoder_fw.shape[-2], h_EXV_encoder_fw.shape[-1]
        ),
    )

    mask_bw_t = torch.gather(
        mask_bw,
        1,
        t[:, None, None, None].repeat(
            1, 1, mask_bw.shape[-2], mask_bw.shape[-1]
        ),
    )

def cat_neighbors_nodes(h_nodes, h_neighbors, E_idx):
    h_nodes = gather_nodes(h_nodes, E_idx)
    h_nn = torch.cat([h_neighbors, h_nodes], -1)
    return h_nn

def gather_nodes(nodes, neighbor_idx):
    # Features [B,N,C] at Neighbor indices [B,N,K] => [B,N,K,C]
    # Flatten and expand indices per batch [B,N,K] => [B,NK] => [B,NK,C]
    neighbors_flat = neighbor_idx.reshape((neighbor_idx.shape[0], -1))
    neighbors_flat = neighbors_flat.unsqueeze(-1).expand(-1, -1, nodes.size(2))
    # Gather and re-pack
    neighbor_features = torch.gather(nodes, 1, neighbors_flat)
    neighbor_features = neighbor_features.view(list(neighbor_idx.shape)[:3] + [-1])
    return neighbor_features

I'll break this down line by line, explaining the shapes and operations:

for t_ in range(L):
    # L is the sequence length
    
    t = decoding_order[:, t_]  # [B]
    # decoding_order shape: [B, L] (Batch, Length)
    # t shape: [B] - selects a column from decoding_order
    
    chain_mask_t = torch.gather(chain_mask, 1, t[:, None])[:, 0]  # [B]
    # chain_mask shape: [B, L]
    # t[:, None] adds dimension: [B, 1]
    # gather selects elements along dim=1 using indices
    # [:, 0] removes the added dimension
    # Final shape: [B]
    
    mask_t = torch.gather(mask, 1, t[:, None])[:, 0]  # [B]
    # Similar to chain_mask_t operation
    # Final shape: [B]
    
    bias_t = torch.gather(bias, 1, t[:, None, None].repeat(1, 1, 21))[:, 0, :]  # [B,21]
    # bias shape: [B, L, 21]
    # t[:, None, None] shape: [B, 1, 1]
    # repeat(1, 1, 21) tiles it to [B, 1, 21] (an explicit copy, not implicit broadcasting)
    # gather selects elements, then [:, 0, :] removes middle dimension
    # Final shape: [B, 21]
    
    E_idx_t = torch.gather(E_idx, 1, t[:, None, None].repeat(1, 1, E_idx.shape[-1]))
    # E_idx shape: [B, L, K] (K is number of neighbors)
    # Final shape: [B, 1, K]
    
    h_E_t = torch.gather(h_E, 1, t[:, None, None, None].repeat(1, 1, h_E.shape[-2], h_E.shape[-1]))
    # h_E shape: [B, L, K, C] (C is feature dimension)
    # Final shape: [B, 1, K, C]
    
    h_ES_t = cat_neighbors_nodes(h_S, h_E_t, E_idx_t)
    # h_S shape: [B, N, C] (N is number of nodes)
    # Inside cat_neighbors_nodes:
    #   1. gather_nodes gets features for neighbors: [B, 1, K, C]
    #   2. concatenates with h_E_t along last dimension
    # Final shape: [B, 1, K, 2C]
    
    h_EXV_encoder_t = torch.gather(h_EXV_encoder_fw, 1, 
        t[:, None, None, None].repeat(1, 1, h_EXV_encoder_fw.shape[-2], h_EXV_encoder_fw.shape[-1]))
    # h_EXV_encoder_fw shape: [B, L, K', C'] (K' might be different from K)
    # Final shape: [B, 1, K', C']
    
    mask_bw_t = torch.gather(mask_bw, 1,
        t[:, None, None, None].repeat(1, 1, mask_bw.shape[-2], mask_bw.shape[-1]))
    # mask_bw shape: [B, L, K'', C''] (4-D; its trailing dims need not match h_EXV_encoder_fw)
    # Final shape: [B, 1, K'', C'']
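
Every gather in the loop follows the same rule, so it may help to have a loop-based reference to test a port against. The sketch below models torch.gather(src, 1, index) for 3-D inputs on plain nested lists. The helper names (gather_dim1_3d, shape) and the toy values are made up for illustration, and it covers only the dim=1, 3-D case used for E_idx.

```python
def gather_dim1_3d(src, index):
    """Reference model of torch.gather(src, 1, index) for 3-D nested lists.

    Rule for dim=1: out[b][i][j] = src[b][index[b][i][j]][j].
    The output has exactly the shape of `index`.
    """
    return [
        [[src[b][index[b][i][j]][j] for j in range(len(index[b][i]))]
         for i in range(len(index[b]))]
        for b in range(len(index))
    ]


def shape(x):
    """Shape of a rectangular nested list."""
    dims = []
    while isinstance(x, list):
        dims.append(len(x))
        x = x[0]
    return dims


# Toy E_idx with B=1, L=3, K=2 (values are made up).
E_idx = [[[10, 11], [20, 21], [30, 31]]]
t = [2]  # position decoded at this step, shape [B]

# Equivalent of t[:, None, None].repeat(1, 1, K): shape [B, 1, K].
K = len(E_idx[0][0])
index = [[[t_b] * K] for t_b in t]

E_idx_t = gather_dim1_3d(E_idx, index)
print(shape(index), shape(E_idx_t), E_idx_t)
# prints: [1, 1, 2] [1, 1, 2] [[[30, 31]]]
```

Repeating t across the last axis is what makes the gather pick the whole row at position t, and the result keeps a size-1 dimension where L was.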

The gather_nodes function:

def gather_nodes(nodes, neighbor_idx):
    # nodes shape: [B, N, C]
    # neighbor_idx shape: [B, N', K]
    
    # Flatten neighbor indices
    neighbors_flat = neighbor_idx.reshape((neighbor_idx.shape[0], -1))
    # shape: [B, N'K]
    
    # Add feature dimension and broadcast
    neighbors_flat = neighbors_flat.unsqueeze(-1).expand(-1, -1, nodes.size(2))
    # shape: [B, N'K, C]
    
    # Gather features
    neighbor_features = torch.gather(nodes, 1, neighbors_flat)
    # shape: [B, N'K, C]
    
    # Reshape to match input structure
    neighbor_features = neighbor_features.view(list(neighbor_idx.shape)[:3] + [-1])
    # Final shape: [B, N', K, C]
    return neighbor_features
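
The flatten, gather and re-pack steps above amount to a per-neighbor lookup: out[b][n][k] = nodes[b][neighbor_idx[b][n][k]]. Here is a minimal sketch of that lookup on nested lists, which could serve as a check for a Rust port. The function name gather_nodes_ref and the toy data are assumptions for illustration, not part of the original code.

```python
def gather_nodes_ref(nodes, neighbor_idx):
    """Loop-based model of gather_nodes on nested lists.

    nodes:        [B][N][C] features
    neighbor_idx: [B][N'][K] integer indices into the N axis of nodes
    returns:      [B][N'][K][C], out[b][n][k] = nodes[b][neighbor_idx[b][n][k]]
    """
    return [
        [[list(nodes[b][j]) for j in row] for row in neighbor_idx[b]]
        for b in range(len(nodes))
    ]


# Toy data: B=1, N=3 nodes, C=2 features, K=2 neighbors (values are made up).
nodes = [[[0.0, 0.1], [1.0, 1.1], [2.0, 2.1]]]
E_idx = [[[1, 2], [0, 2], [0, 1]]]

full = gather_nodes_ref(nodes, E_idx)               # shape [1][3][2][2]
t = 1
single = gather_nodes_ref(nodes, [[E_idx[0][t]]])   # neighbor_idx is [1][1][2]

print(single)                       # [[[[0.0, 0.1], [2.0, 2.1]]]]
print(single[0][0] == full[0][t])   # True
```

N' does not have to equal N: a neighbor_idx with a single row is a valid input and yields the matching row of the full result.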

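There is no implicit broadcasting in this code. torch.gather requires the index to have the same number of dimensions as the input and does not broadcast one against the other, which is why t is first given extra axes (t[:, None, None] or t[:, None, None, None]) and then explicitly tiled with .repeat() to cover the trailing dimensions. The .expand() in gather_nodes is also explicit; it creates a view rather than a copy. In every gather the output takes the shape of the index tensor, so each per-step result keeps a size-1 dimension in place of L.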

Okay. My dimensions seem fine during my encoding code but not in my inference code. The correct code has input dimensions of:

IN GATHER NODES. Nodes, neighbor_idx: [1, 93, 128], [1, 93, 24]

But the incorrect code has dimensions of:

 Nodes, neighbor_idx: [1, 93, 128], [1, 1, 24]

Can you guess where this is happening in this code? E.g. an implicit 1->93 conversion that is not happening in Rust where it must be explicit?

Looking at the code, the [1, 1, 24] shape comes from the E_idx_t gathering operation:

E_idx_t = torch.gather(E_idx, 1, t[:, None, None].repeat(1, 1, E_idx.shape[-1]))

E_idx has shape [B, L, K] where L=93 and K=24, and the index tensor has shape [B, 1, K]. torch.gather returns a tensor with the same shape as its index, so E_idx_t is [1, 1, 24] in PyTorch as well. There is no implicit 1->93 broadcast here that Rust would be missing.

That suggests the two shapes you logged come from two different call sites rather than from a PyTorch/Candle difference. The encoding pass presumably calls cat_neighbors_nodes with the full E_idx ([1, 93, 24]), while the decoding loop calls it with E_idx_t ([1, 1, 24]), one position per step. With a neighbor_idx of [1, 1, 24], gather_nodes returns [1, 1, 24, 128], which is what the concatenation with h_E_t ([B, 1, K, C]) needs.

So I would not expand E_idx_t to [1, 93, 24]; that would no longer match h_E_t in torch.cat. If your Rust inference code fails inside gather_nodes, my guess is that a later step assumes the full length. Check that the flattened index is [1, 24], that it is expanded to [1, 24, 128], and that the final reshape uses the index's own dims rather than a hard-coded 93, something like:

// Take the reshape dims from the index itself (here [1, 1, 24]), not from the sequence length
let (b, n, k) = neighbor_idx.dims3()?;  // pseudo-code

The exact syntax will depend on Candle's API, but the key is that every shape inside gather_nodes is derived from neighbor_idx, so the same function works for both the full pass and a single decoding step.
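
To compare the two logged cases side by side, here is a small shape trace of gather_nodes for each input pair. It is pure shape arithmetic derived from the function body shown earlier. The helper name gather_nodes_shapes is hypothetical and no tensors are involved.

```python
def gather_nodes_shapes(nodes_shape, neighbor_idx_shape):
    """Trace the shapes produced by each step of gather_nodes."""
    B, N, C = nodes_shape
    B_idx, N_idx, K = neighbor_idx_shape
    assert B == B_idx, "batch sizes must match"
    flat = [B, N_idx * K]        # neighbor_idx.reshape((B, -1))
    expanded = flat + [C]        # .unsqueeze(-1).expand(-1, -1, C)
    gathered = list(expanded)    # gather output takes the index shape
    packed = [B, N_idx, K, C]    # .view(list(neighbor_idx.shape)[:3] + [-1])
    return {"flat": flat, "expanded": expanded,
            "gathered": gathered, "packed": packed}


full = gather_nodes_shapes([1, 93, 128], [1, 93, 24])
step = gather_nodes_shapes([1, 93, 128], [1, 1, 24])

print(full["flat"], full["packed"])   # [1, 2232] [1, 93, 24, 128]
print(step["flat"], step["packed"])   # [1, 24] [1, 1, 24, 128]
```

Both index shapes pass through the same steps; only the middle dimension of the result differs, so any value of 93 appearing in the single-step path would have to come from somewhere other than neighbor_idx.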
