Last active
March 28, 2017 19:41
-
-
Save Trion129/bdad06ad8b2c9fee589f0ccf0aba101b to your computer and use it in GitHub Desktop.
Agents
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include "environment.hpp" | |
#include <iostream> | |
#include "NaiveAgent.hpp" | |
using namespace gym; | |
int main(int argc, char* argv[]) | |
{ | |
const std::string environment = "SpaceInvaders-v0"; | |
const std::string host = "kurg.org"; | |
const std::string port = "4040"; | |
double totalReward = 0; | |
size_t totalSteps = 0; | |
Environment env(host, port, environment); | |
env.compression(9); | |
env.monitor.start("./dummy/", true, true); | |
env.reset(); | |
env.render(); | |
// while (1) | |
// { | |
// arma::mat action = env.action_space.sample(); | |
// std::cout << "action: \n" << action << std::endl; | |
// env.step(action); | |
// totalReward += env.reward; | |
// totalSteps += 1; | |
// if (env.done) | |
// { | |
// break; | |
// } | |
// std::cout << "Current step: " << totalSteps << " current reward: " | |
// << totalReward << std::endl; | |
// } | |
Agent agent(210, 160); | |
agent.Play(env, 0.2); | |
std::cout << "Instance: " << env.instance << " total steps: " << totalSteps | |
<< " reward: " << totalReward << std::endl; | |
return 0; | |
} |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <cmath> | |
#include <mlpack/core.hpp> | |
#include <mlpack/core/optimizers/sgd/sgd.hpp> | |
#include <mlpack/methods/ann/layer/layer.hpp> | |
#include <mlpack/methods/ann/ffn.hpp> | |
#include <mlpack/methods/ann/layer/linear.hpp> | |
#include <mlpack/methods/ann/layer/dropout.hpp> | |
#include <mlpack/methods/ann/layer/leaky_relu.hpp> | |
#include <mlpack/methods/ann/layer/convolution.hpp> | |
using namespace mlpack; | |
using namespace optimization; | |
using namespace ann; | |
using namespace gym; | |
class Agent{ | |
public: | |
FFN<> model; | |
Agent(size_t inputW, size_t inputH){ | |
// 1 stream to 10 streams of depth, filter 5x5 | |
model.Add<Convolution<>>(1, 10, 5, 5, 1, 1, 0, 0, inputW, inputH); | |
model.Add<LeakyReLU<>>(); | |
// 10 depth to 15 depth, filter 5x5, strides of 2 | |
model.Add<Convolution<>>(10, 15, 5, 5, 2, 2); | |
model.Add<LeakyReLU<>>(); | |
model.Add<Linear<>>(115140, 700); | |
model.Add<LeakyReLU<>>(); | |
// Fully Connected from 700 -> 3 nodes | |
model.Add<Linear<>>(700, 3); | |
} | |
void Play(Environment& env, double explore){ | |
arma::mat result, actionMat, frame; | |
double maxRewardAction; | |
double totalReward = 0; | |
size_t totalSteps = 0; | |
while(1){ | |
//Get observation | |
frame = arma::vectorise(env.observation); | |
//Predict a reward for each action in current frame | |
model.Predict(frame, result); | |
if(arma::randu() > explore){ | |
//Pick the action with maximum reward | |
maxRewardAction = arma::mat(result.t()).index_max(); | |
} | |
else{ | |
//Pick random action number | |
maxRewardAction = floor(arma::randu() * 3); | |
} | |
//Make 1 hot vector for chosen action | |
actionMat = arma::zeros<arma::mat>(1, 3); | |
actionMat[maxRewardAction] = 1; | |
env.step(actionMat); | |
// New Matrix for correct result | |
// replace current reward with correct one | |
arma::mat correctResult(result); | |
correctResult[maxRewardAction] = env.reward; | |
// Optimize for the correct result given same frame | |
std::cout << "Corrected result:"; | |
for(int i : correctResult){ | |
std::cout << i << " "; | |
} | |
model.Train(std::move(frame), std::move(correctResult)); | |
if(env.done){ | |
break; | |
} | |
totalReward += env.reward; | |
totalSteps += 1; | |
std::cout << "Current step: " << totalSteps << " current reward: " | |
<< totalReward << std::endl; | |
} | |
} | |
}; |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Some comments:
So at the end it should look like:
Let me know if that was helpful.