Last active
January 15, 2022 12:30
-
-
Save bkaankuguoglu/75a59c4c590454776874da992a94f58c to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class Optimization:
    # ...
    def forecast_with_predictors(
        self, forecast_loader, batch_size=1, n_features=1, n_steps=100
    ):
        """Forecast values for RNNs with predictors and one-dimensional output.

        Runs the model in evaluation mode, without gradient tracking, over at
        most ``n_steps`` batches from ``forecast_loader``. Each batch's inputs
        (the predictor/feature columns) are reshaped to
        ``(batch_size, -1, n_features)`` and moved to the module-level
        ``device`` (defined outside this snippet) before being fed to the
        model.

        Args:
            forecast_loader (torch.utils.data.DataLoader): Iterable of
                ``(inputs, targets)`` pairs; the targets are ignored.
            batch_size (int): Number of samples per batch, used for reshaping.
            n_features (int): Number of feature columns.
            n_steps (int): Maximum number of batches to predict. Values <= 0
                produce no predictions.

        Returns:
            list[numpy.ndarray]: One array of model outputs per processed
            batch, copied to host (CPU) memory.
        """
        # Local import keeps this snippet self-contained.
        from itertools import islice

        # Switch to inference mode once rather than on every batch.
        self.model.eval()
        predictions = []
        with torch.no_grad():
            # islice stops after n_steps batches without pulling an extra one
            # from the loader; clamping makes a non-positive n_steps mean "none".
            for x_test, _ in islice(forecast_loader, max(n_steps, 0)):
                # NOTE(review): assumes every batch holds exactly batch_size
                # samples; a smaller final batch (e.g. DataLoader with
                # drop_last=False) may fail to reshape -- confirm with callers.
                x_test = x_test.view([batch_size, -1, n_features]).to(device)
                yhat = self.model(x_test)
                # Tensor.numpy() only supports CPU tensors, so copy the output
                # back to host memory; the old `.to(device)` broke CUDA runs.
                predictions.append(yhat.detach().cpu().numpy())
        return predictions
    # ...
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment