Skip to content

Instantly share code, notes, and snippets.

@magalhini
Created September 6, 2017 17:57
Show Gist options
  • Save magalhini/eebb04fd975e170bd8bf49b762e44287 to your computer and use it in GitHub Desktop.
Save magalhini/eebb04fd975e170bd8bf49b762e44287 to your computer and use it in GitHub Desktop.
Training a neural network to recognise shapes drawn on a 15-cell (3 rows × 5 columns) grid
const Layer = require('synaptic').Layer;
const Network = require('synaptic').Network;
const Trainer = require('synaptic').Trainer;
// Three-layer perceptron with layer sizes 15 -> 40 -> 4: one input per grid
// cell, one hidden layer, and one output per shape class. The layers are
// created in input, hidden, output order, then chained input -> hidden -> output.
const [inputLayer, hiddenLayer, outputLayer] = [15, 40, 4].map((size) => new Layer(size));

inputLayer.project(hiddenLayer);
hiddenLayer.project(outputLayer);

const myNetwork = new Network({
  input: inputLayer,
  hidden: [hiddenLayer],
  output: outputLayer,
});
// Lookup from output-neuron index to shape label. The position of each label
// in this list is the position of the 1 in that shape's one-hot `output`
// vector in the training data.
const shapesMap = ['RECTANGLE', 'TRIANGLE', 'SQUARE', 'CIRCLE'].reduce(
  (labels, name, position) => {
    labels[position] = name;
    return labels;
  },
  {}
);
// Labelled examples. Each `input` is a 15-cell drawing read as 3 rows of
// 5 cells (1 = filled). Each `output` is a one-hot vector whose hot position
// is the shape's index in shapesMap.
// `const`: the binding is never reassigned.
const trainingData = [
  { // Rectangle: full-width outline
    input: [1,1,1,1,1, 1,0,0,0,1, 1,1,1,1,1],
    output: [1, 0, 0, 0]
  },
  { // Triangle
    input: [0,0,1,0,0, 0,1,0,1,0, 1,1,1,1,1],
    output: [0, 1, 0, 0]
  },
  { // Square, left-aligned 3x3 outline
    input: [1,1,1,0,0, 1,0,1,0,0, 1,1,1,0,0],
    output: [0, 0, 1, 0]
  },
  { // Circle, drawn as a small diamond
    input: [0,0,1,0,0, 0,1,0,1,0, 0,0,1,0,0],
    output: [0, 0, 0, 1]
  },
  { // Square 2: the same 3x3 outline shifted one column right
    input: [0,1,1,1,0, 0,1,0,1,0, 0,1,1,1,0],
    output: [0, 0, 1, 0]
  },
];
/**
 * Split an array into consecutive chunks of `size` elements. The last chunk
 * may be shorter. Unlike a splice-based loop, the input array is left untouched.
 *
 * @param {Array} array - Items to split.
 * @param {number} size - Chunk length. Fractional values are truncated
 *   (as Array#splice would), and the result must be at least 1.
 * @returns {Array[]} The chunks, in their original order.
 * @throws {RangeError} If `size` truncates to less than 1 or is NaN. A
 *   splice loop would never terminate for such sizes, because nothing is removed.
 */
const toChunks = function(array, size) {
  const step = Math.trunc(Number(size));
  if (Number.isNaN(step) || step < 1) {
    throw new RangeError(`toChunks: size must be at least 1, got ${String(size)}`);
  }
  const results = [];
  for (let start = 0; start < array.length; start += step) {
    results.push(array.slice(start, start + step));
  }
  return results;
};
// Fit myNetwork to the labelled shapes. The hyper-parameters are passed to
// Trainer#train: learning rate 0.2, cross-entropy cost, 200 iterations and
// `log: 10`.
const trainer = new Trainer(myNetwork);
const trainingOptions = {
  rate: 0.2,
  cost: Trainer.cost.CROSS_ENTROPY,
  iterations: 200,
  log: 10,
};
trainer.train(trainingData, trainingOptions);
// The "Square 2" training pattern with one top-row cell missing. It checks
// that the trained network still classifies an imperfect drawing.
const badlyDrawnSquare = [0,0,1,1,0, 0,1,0,1,0, 0,1,1,1,0];

console.log('---------- Testing training data --------- ');
console.log(myNetwork.activate(trainingData[0].input));

const badSquarePrediction = myNetwork.activate(badlyDrawnSquare);
// The strongest output neuron is the predicted class. Its activation is a
// per-sample score, not an accuracy measured over a test set, so it is
// reported as "confidence". Compute the max once and look up its position.
const percentage = Math.max(...badSquarePrediction);
const index = badSquarePrediction.indexOf(percentage);

// Render the 15 cells as 3 rows of 5 so the drawing can be checked by eye.
const draw = badlyDrawnSquare.map((item) => (item === 1 ? 'x' : ' '));
console.log(toChunks(draw, 5));
console.log(`Prediction is ${shapesMap[index]} with ${(percentage * 100).toFixed(2)}% confidence`);
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment