@Mason-McGough · Created April 23, 2022
The NeRF PyTorch module: a multilayer perceptron that maps an encoded 3D sample position (and, optionally, an encoded view direction) to an RGB color and a volume density (alpha).
from typing import Optional, Tuple

import torch
from torch import nn


class NeRF(nn.Module):
    r"""
    Neural radiance fields module.

    Args:
        d_input: Dimensionality of the encoded input position.
        n_layers: Number of fully-connected layers before the bottleneck.
        d_filter: Width of each hidden layer.
        skip: Indices of layers after which the input is re-concatenated
            (skip connections).
        d_viewdirs: Dimensionality of the encoded view direction, or None
            to disable view dependence.
    """
    def __init__(
        self,
        d_input: int = 3,
        n_layers: int = 8,
        d_filter: int = 256,
        skip: Tuple[int, ...] = (4,),
        d_viewdirs: Optional[int] = None
    ):
        super().__init__()
        self.d_input = d_input
        self.skip = skip
        self.act = nn.functional.relu
        self.d_viewdirs = d_viewdirs

        # Create model layers; the layer following each skip connection
        # takes the original input concatenated to the previous output
        self.layers = nn.ModuleList(
            [nn.Linear(self.d_input, d_filter)] +
            [nn.Linear(d_filter + self.d_input, d_filter) if i in skip
             else nn.Linear(d_filter, d_filter) for i in range(n_layers - 1)]
        )

        # Bottleneck layers
        if self.d_viewdirs is not None:
            # If using viewdirs, split alpha and RGB
            self.alpha_out = nn.Linear(d_filter, 1)
            self.rgb_filters = nn.Linear(d_filter, d_filter)
            self.branch = nn.Linear(d_filter + self.d_viewdirs, d_filter // 2)
            self.output = nn.Linear(d_filter // 2, 3)
        else:
            # If no viewdirs, use simpler output
            self.output = nn.Linear(d_filter, 4)

    def forward(
        self,
        x: torch.Tensor,
        viewdirs: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        r"""
        Forward pass with optional view direction.
        """
        # Cannot use viewdirs if instantiated with d_viewdirs = None
        if self.d_viewdirs is None and viewdirs is not None:
            raise ValueError('Cannot input viewdirs if d_viewdirs was not given.')

        # Apply forward pass up to bottleneck
        x_input = x
        for i, layer in enumerate(self.layers):
            x = self.act(layer(x))
            if i in self.skip:
                x = torch.cat([x, x_input], dim=-1)

        # Apply bottleneck
        if self.d_viewdirs is not None:
            # Split alpha from network output
            alpha = self.alpha_out(x)

            # Pass through bottleneck to get RGB
            x = self.rgb_filters(x)
            x = torch.cat([x, viewdirs], dim=-1)
            x = self.act(self.branch(x))
            x = self.output(x)

            # Concatenate alphas to output
            x = torch.cat([x, alpha], dim=-1)
        else:
            # Simple output
            x = self.output(x)
        return x
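
Below is a minimal usage sketch (not part of the original gist). The encoded input sizes are assumptions matching the positional encoding from the NeRF paper: 10 frequency bands plus the raw 3D position gives d_input = 3 + 3 * 2 * 10 = 63, and 4 bands for the view direction gives d_viewdirs = 3 + 3 * 2 * 4 = 27.

# Minimal usage sketch; batch size and encoded dimensions are assumptions
model = NeRF(d_input=63, d_viewdirs=27)
encoded_points = torch.rand(1024, 63)  # encoded sample positions
encoded_dirs = torch.rand(1024, 27)    # encoded view directions
out = model(encoded_points, encoded_dirs)
print(out.shape)  # torch.Size([1024, 4]): RGB in out[..., :3], alpha in out[..., 3:]

Note that alpha is read out before the view direction is injected, so density depends only on position while color can vary with viewing angle, mirroring the design of the original NeRF architecture.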