Last active: August 29, 2015 14:07
-
-
Save muupan/93babddb73d25f544fb2 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <array>
#include <iostream>
#include <memory>
#include <random>

#include <caffe/caffe.hpp>
#include <glog/logging.h>
int main(int argc, char** argv) { | |
// glogの初期化 | |
google::InitGoogleLogging(argv[0]); | |
// 教師データとして用いる入力データと目標データをfloat配列として準備する. | |
// 入力データ:2次元 | |
// 目標データ:1次元 | |
constexpr auto kMinibatchSize = 32; | |
constexpr auto kDataSize = kMinibatchSize * 10; | |
std::array<float, kDataSize * 2> input_data; | |
std::array<float, kDataSize> target_data; | |
std::mt19937 random_engine; | |
std::uniform_real_distribution<> dist(0.0, 1.0); | |
// 3x - 2y + 4 = target に従ってデータを生成する. | |
for (auto i = 0; i < kDataSize; ++i) { | |
const auto x = dist(random_engine); | |
const auto y = dist(random_engine); | |
const auto target = 3 * x - 2 * y + 4; | |
input_data[i * 2] = x; | |
input_data[i * 2 + 1] = y; | |
target_data[i] = target; | |
} | |
// MemoryDataLayerはメモリ上の値を出力できるDataLayer. | |
// 各MemoryDataLayerには入力データとラベルデータ(1次元の整数)の2つを与える必要があるが, | |
// ここでは回帰を行いたいので,入力データと目標データそれぞれを別のMemoryDataLayerで出力し, | |
// ラベルデータの代わりに使用されないダミーの値を与えておく. | |
std::array<float, kDataSize> dummy_data; | |
std::fill(dummy_data.begin(), dummy_data.end(), 0.0); | |
// Solverの設定をテキストファイルから読み込む | |
caffe::SolverParameter solver_param; | |
caffe::ReadProtoFromTextFileOrDie("solver.prototxt", &solver_param); | |
const auto solver = | |
std::shared_ptr<caffe::Solver<float>>( | |
caffe::GetSolver<float>(solver_param)); | |
const auto net = solver->net(); | |
// 入力データをMemoryDataLayer"input"にセットする | |
const auto input_layer = | |
boost::dynamic_pointer_cast<caffe::MemoryDataLayer<float>>( | |
net->layer_by_name("input")); | |
assert(input_layer); | |
input_layer->Reset(input_data.data(), dummy_data.data(), kDataSize); | |
// 目標データをMemoryDataLayer"target"にセットする | |
const auto target_layer = | |
boost::dynamic_pointer_cast<caffe::MemoryDataLayer<float>>( | |
net->layer_by_name("target")); | |
assert(target_layer); | |
target_layer->Reset(target_data.data(), dummy_data.data(), kDataSize); | |
// Solverの設定通りに学習を行う | |
solver->Solve(); | |
// 学習されたパラメータを出力してみる | |
// ax + by + c = target | |
const auto ip_blobs = net->layer_by_name("ip")->blobs(); | |
const auto learned_a = ip_blobs[0]->cpu_data()[0]; | |
const auto learned_b = ip_blobs[0]->cpu_data()[1]; | |
const auto learned_c = ip_blobs[1]->cpu_data()[0]; | |
std::cout << learned_a << "x + " << learned_b << "y + " << learned_c | |
<< " = target" << std::endl; | |
// 学習されたモデルを使って予測してみる | |
// x = 10, y = 20 | |
std::array<float, kDataSize * 2> sample_input; | |
sample_input[0] = 10; | |
sample_input[1] = 20; | |
input_layer->Reset(sample_input.data(), dummy_data.data(), kDataSize); | |
net->ForwardPrefilled(nullptr); | |
std::cout << "10a + 20b + c = " << net->blob_by_name("ip")->cpu_data()[0] << std::endl; | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Linear regression net: "ip" computes a*x + b*y + c from "input", and a
# Euclidean loss compares it against "target".
# NOTE: batch_size of both MemoryDataLayers is expected to match
# kMinibatchSize in the C++ driver that feeds them.
layers {
  name: "input"
  type: MEMORY_DATA
  top: "input"
  top: "dummy_label1"
  memory_data_param {
    batch_size: 32
    channels: 2
    height: 1
    width: 1
  }
}
layers {
  name: "ip"
  type: INNER_PRODUCT
  bottom: "input"
  top: "ip"
  inner_product_param {
    num_output: 1
    weight_filler {
      type: "constant"
      value: 0
    }
    bias_filler {
      type: "constant"
      value: 0
    }
  }
}
layers {
  name: "target"
  type: MEMORY_DATA
  top: "target"
  top: "dummy_label2"
  memory_data_param {
    batch_size: 32
    channels: 1
    height: 1
    width: 1
  }
}
layers {
  name: "loss"
  type: EUCLIDEAN_LOSS
  bottom: "ip"
  bottom: "target"
  top: "loss"
}
# Each MemoryDataLayer also emits a label blob, which this net never uses.
# Consume both so they are not left dangling as net outputs.
layers {
  name: "silence"
  type: SILENCE
  bottom: "dummy_label1"
  bottom: "dummy_label2"
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Plain SGD with momentum, training the net defined in net.prototxt.
net: "net.prototxt"
solver_type: SGD
momentum: 0.9
max_iter: 4000

# "step" learning-rate policy (see Caffe's solver.proto):
# lr = base_lr * gamma ^ floor(iter / stepsize).
base_lr: 0.01
lr_policy: "step"
gamma: 0.1
stepsize: 1000

# Report training progress every 1000 iterations.
display: 1000
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.