Last active
October 10, 2018 05:58
-
-
Save whusnoopy/af0aa6fd276ace8a7c4d483e586e936d to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# coding: utf8 | |
# https://www.yewen.us/blog/2018/10/append-machine-learning-3-linear-regression/ | |
from mxnet import autograd, nd | |
num_inputs = 9  # number of features, i.e. the problem's variables numbered 1-9
num_examples = 1000  # number of random examples generated to learn from
true_w = nd.array([0, 0, 0, 0, 0, 1, 0, 2, 1])  # ground-truth weights the model should recover
features = nd.random.normal(scale=1, shape=(num_examples, num_inputs))  # randomly generated dataset
labels = nd.dot(features, true_w)  # noise-free target for each example in the dataset
# Model weights: a (num_inputs, 1) column initialised with small random values.
# The shape is derived from num_inputs (not a literal 9) so that changing the
# feature count keeps features, true_w and w consistent with each other.
w = nd.random.normal(scale=0.01, shape=(num_inputs, 1))
w.attach_grad()  # make w.grad available; sgd() reads it after backward()
def linreg(X, w):
    """Linear model without a bias term: return the matrix product of X and w.

    In this script X is (num_examples, num_inputs) and w is (num_inputs, 1),
    so the result is a column of one prediction per example.
    """
    predictions = nd.dot(X, w)
    return predictions
def squared_loss(y_hat, y):
    """Return the element-wise halved squared error between y_hat and y.

    y is first reshaped to y_hat's shape, so a flat label vector can be
    compared against a column of predictions.
    """
    residual = y_hat - y.reshape(y_hat.shape)
    return residual ** 2 / 2
def sgd(param, lr, batch_size):
    """Apply one gradient-descent step to param, in place.

    The update is param <- param - lr * param.grad / batch_size. Dividing by
    batch_size presumably turns a gradient summed over the batch into a mean
    (confirm against how the loss is back-propagated). The result is written
    back through slice assignment so the caller's array object is updated.
    """
    step = lr * param.grad / batch_size
    param[:] = param - step
def train(lr=0.01, num_epochs=1000, log_every=100):
    """Fit the module-level weights ``w`` to ``features``/``labels``.

    Runs full-batch gradient descent. Each epoch does a forward pass over
    every example under ``autograd.record()``, back-propagates the squared
    loss, and takes one ``sgd`` step with ``batch_size = labels.size`` (all
    examples form a single batch).

    Args:
        lr: learning rate passed to ``sgd``.
        num_epochs: number of full passes over the data.
        log_every: after every ``log_every``-th epoch, print the mean training
            loss and the current weights; a non-positive value disables logging.
    """
    net = linreg
    loss = squared_loss
    for epoch in range(num_epochs):
        with autograd.record():
            l = loss(net(features, w), labels)
        l.backward()
        sgd(w, lr, labels.size)
        # Evaluate the post-update loss only when it is actually printed,
        # rather than paying for an extra forward pass on every epoch.
        if log_every > 0 and (epoch + 1) % log_every == 0:
            train_l = loss(net(features, w), labels)
            print('epoch {}, loss {}, w {}'.format(epoch + 1, train_l.mean().asnumpy(), w))
if __name__ == "__main__":
    train()
    # Test input: the count of each digit 1-9 in the number 565441
    # (one 1, two 4s, two 5s, one 6).
    test = nd.array([1, 0, 0, 2, 2, 1, 0, 0, 0])
    prediction = nd.dot(test, w)
    print(prediction)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment