Skip to content

Instantly share code, notes, and snippets.

@ingted
Created March 5, 2024 23:02
Show Gist options
  • Save ingted/df616af5f0b5802fbdc459db03c6e15e to your computer and use it in GitHub Desktop.
Save ingted/df616af5f0b5802fbdc459db03c6e15e to your computer and use it in GitHub Desktop.
Time2vec001
def t2v(tau, f, out_features, w, b, w0, b0, arg=None):
    """Compute a Time2Vec-style embedding of ``tau``.

    The result is ``f(tau @ w + b)`` (the periodic part) concatenated with
    ``tau @ w0 + b0`` (the linear part) along the last dimension, so the
    periodic features come first and the single linear feature comes last.

    Args:
        tau: Input tensor whose last dimension matches ``w.shape[0]`` /
            ``w0.shape[0]`` (e.g. ``(batch, in_features)`` or
            ``(batch, seq, in_features)``).
        f: Element-wise function applied to the periodic part
            (e.g. ``torch.sin``).
        out_features: Unused here; kept so existing callers keep working.
        w, b: Weight and bias of the periodic part.
        w0, b0: Weight and bias of the linear part.
        arg: Optional extra positional argument forwarded to ``f``. It is
            forwarded whenever it is not ``None``, including falsy values
            such as ``0``.

    Returns:
        Tensor with the same leading dimensions as ``tau`` and last
        dimension ``w.shape[1] + w0.shape[1]``.
    """
    periodic = torch.matmul(tau, w) + b
    # Test against None, not truthiness, so e.g. arg=0 is still passed to f.
    v1 = f(periodic) if arg is None else f(periodic, arg)
    v2 = torch.matmul(tau, w0) + b0
    # Concatenate on the last axis so batched sequences (3-D tau) work too;
    # for 2-D input this is the same as the original dim=1.
    return torch.cat([v1, v2], -1)


class _PeriodicActivation(nn.Module):
    """Shared Time2Vec layer: ``out_features - 1`` periodic features + 1 linear."""

    def __init__(self, in_features, out_features, f):
        super().__init__()
        self.out_features = out_features
        # Creation order (w0, b0, w, b) matches the original so seeded
        # initialisation is unchanged.
        self.w0 = nn.Parameter(torch.randn(in_features, 1))
        # Biases have a leading dimension of 1 so they broadcast over the
        # batch for any in_features. For in_features == 1 this equals the
        # original shape, so previously saved state_dicts still load.
        self.b0 = nn.Parameter(torch.randn(1, 1))
        self.w = nn.Parameter(torch.randn(in_features, out_features - 1))
        self.b = nn.Parameter(torch.randn(1, out_features - 1))
        self.f = f

    def forward(self, tau):
        return t2v(tau, self.f, self.out_features, self.w, self.b, self.w0, self.b0)


class SineActivation(_PeriodicActivation):
    """Time2Vec layer whose periodic features use ``torch.sin``."""

    def __init__(self, in_features, out_features):
        super().__init__(in_features, out_features, torch.sin)


class CosineActivation(_PeriodicActivation):
    """Time2Vec layer whose periodic features use ``torch.cos``."""

    def __init__(self, in_features, out_features):
        super().__init__(in_features, out_features, torch.cos)


class Time2Vec(nn.Module):
    """Time2Vec embedding of a scalar input followed by a linear head with 2 outputs.

    Args:
        activation: ``"sin"`` or ``"cos"``; selects the periodic function.
        hiddem_dim: Width of the Time2Vec embedding (name kept as-is for
            backward compatibility with keyword callers).

    Raises:
        ValueError: If ``activation`` is not ``"sin"`` or ``"cos"``.
    """

    def __init__(self, activation, hiddem_dim):
        super().__init__()
        if activation == "sin":
            self.l1 = SineActivation(1, hiddem_dim)
        elif activation == "cos":
            self.l1 = CosineActivation(1, hiddem_dim)
        else:
            # Fail fast here instead of with an AttributeError in forward().
            raise ValueError(
                f"activation must be 'sin' or 'cos', got {activation!r}"
            )

        self.fc1 = nn.Linear(hiddem_dim, 2)

    def forward(self, x):
        # l1 is built with in_features=1, so x's last dimension must be 1
        # (e.g. (batch, 1)) — presumably a scalar time value; confirm with callers.
        x = self.l1(x)
        x = self.fc1(x)
        return x
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment