PyTorch linear regression
import torch
from torch.autograd import Variable
from torch.nn import MSELoss

# Data points (x, y)
points = [
    (0, 20),
    (2, 16),
    (6, 14),
    (4, 10),
    (10, 10),
    (10, 4)
]
points = Variable(torch.FloatTensor(points))

# Parameters of the line f(x) = m*x + b, initialised randomly
m = Variable(torch.rand(1), requires_grad=True)
b = Variable(torch.rand(1), requires_grad=True)

loss_function = MSELoss()
learning_rate = 0.01

for step in range(1000):
    x, y = points[:, 0], points[:, 1]
    y_pred = m * x + b
    error = loss_function(y_pred, y)
    error.backward()  # Compute the gradients
    print("Step: {} - Loss: {}".format(step, float(error)))

    # .data gives access to the underlying FloatTensor rather than the Variable.
    m.data -= learning_rate * m.grad.data  # Update m and b
    b.data -= learning_rate * b.grad.data

    # Reset the gradients to zero.
    m.grad.data.zero_()
    b.grad.data.zero_()

print('\nLearning finished!')
print('\tf(x) = {m}x + {b}'.format(m=float(m), b=float(b)))
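Note: the Variable wrapper used above is deprecated in PyTorch 0.4 and later, where tensors track gradients directly. Below is a minimal sketch of an equivalent fit using plain tensors and torch.optim.SGD; the variable names mirror the gist, but the optimizer-based update is an alternative formulation, not the original code.

import torch

# Same data points (x, y) as above.
points = torch.tensor([
    (0., 20.), (2., 16.), (6., 14.),
    (4., 10.), (10., 10.), (10., 4.)
])
x, y = points[:, 0], points[:, 1]

# Parameters of f(x) = m*x + b, tracked directly (no Variable needed).
m = torch.rand(1, requires_grad=True)
b = torch.rand(1, requires_grad=True)

loss_function = torch.nn.MSELoss()
optimizer = torch.optim.SGD([m, b], lr=0.01)

for step in range(1000):
    y_pred = m * x + b
    error = loss_function(y_pred, y)

    optimizer.zero_grad()   # reset gradients from the previous step
    error.backward()        # compute gradients of the loss w.r.t. m and b
    optimizer.step()        # gradient-descent update of m and b

print('f(x) = {}x + {}'.format(float(m), float(b)))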