Last active
October 27, 2017 00:35
-
-
Save crockpotveggies/a061311b88cf21d3e662f7834f3a9b03 to your computer and use it in GitHub Desktop.
Sequence to Sequence Autoencoder Preprocessor
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/**
 * Holder for the sequence-to-sequence autoencoder pre-processor.
 *
 * The enclosing object is `Serializable` so that the nested pre-processor can be referenced
 * from code that gets serialized.
 */
object Preprocessor extends Serializable {

  /**
   * Turns a single-input `MultiDataSet` into the two-input / one-output layout used by a
   * sequence-to-sequence autoencoder.
   *
   * The first feature array is assumed to be laid out as `[miniBatch, nClasses, timeSteps]`
   * (that is how its three dimensions are read below). After pre-processing:
   *
   *  - `features(0)`: the original input, untouched (encoder input).
   *  - `features(1)`: `[miniBatch, nClasses + 1, timeSteps + 1]`; a GO token at time step 0
   *    followed by the original sequence shifted right by one step (decoder input).
   *  - `labels(0)`: `[miniBatch, nClasses + 1, timeSteps + 1]`; the original sequence followed
   *    by a STOP token one step after the last valid time step of each example.
   *
   * GO and STOP share one extra class, stored at index `nClasses`.
   *
   * If a feature mask is present it is kept for `features(0)` and a copy with one extra leading
   * column of ones is used for both `features(1)` and `labels(0)`. Without a mask, all mask
   * arrays are set to `null` (all sequences have the same length).
   */
  class Seq2SeqAutoencoderPreProcessor extends MultiDataSetPreProcessor {

    /**
     * Rewrites `mds` in place; see the class documentation for the resulting layout.
     *
     * @param mds data set whose first feature array (and optional first feature mask) is used
     *            as the source sequence; its features, labels and mask arrays are replaced
     */
    override def preProcess(mds: MultiDataSet): Unit = {
      val input: INDArray = mds.getFeatures(0)

      val mb: Int = input.size(0)
      val nClasses: Int = input.size(1)
      val origMaxTsLength: Int = input.size(2)
      // GO and STOP share the single class appended after the original ones.
      val goStopTokenPos: Int = nClasses

      // One new class (GO/STOP) and one new time step to hold it.
      val newShape: Array[Int] = Array(mb, nClasses + 1, origMaxTsLength + 1)
      val decoderInput: INDArray = Nd4j.create(newShape: _*)
      val decoderLabels: INDArray = Nd4j.create(newShape: _*)

      // Decoder input: original sequence occupies time steps 1..T, leaving step 0 for GO.
      decoderInput.put(
        Array[INDArrayIndex](all(), interval(0, nClasses), interval(1, newShape(2))),
        input)
      // Labels: original sequence occupies time steps 0..T-1, leaving room for STOP after it.
      decoderLabels.put(
        Array[INDArrayIndex](all(), interval(0, nClasses), interval(0, newShape(2) - 1)),
        input)

      // The Java API signals "no mask" with null; lift it into an Option once.
      val inputMask: Option[INDArray] = Option(mds.getFeaturesMaskArray(0))

      // Index (in the ORIGINAL sequence) of the last valid time step of every example.
      // NOTE(review): the masked case assumes masks are left-aligned (valid steps first).
      val lastTimeStepPos: Array[Int] = inputMask match {
        case Some(fm) =>
          BooleanIndexing.lastIndex(fm, Conditions.notEquals(0), 1).data().asInt()
        case None =>
          Array.fill(mb)(origMaxTsLength - 1)
      }

      for (i <- lastTimeStepPos.indices) {
        // GO token only at time step 0 (previously it was set on every time step).
        decoderInput.putScalar(i, goStopTokenPos, 0, 1.0)
        // STOP token goes one step AFTER the last valid data step. Writing it at the last data
        // step itself (as before) overlapped the final real label and left the appended time
        // step, which the extended mask marks as valid, completely empty.
        decoderLabels.putScalar(i, goStopTokenPos, lastTimeStepPos(i) + 1, 1.0)
      }

      // Every sequence is now one step longer, so the mask simply gains a leading column of ones.
      val masks: Option[(Array[INDArray], Array[INDArray])] = inputMask.map { fm =>
        val extendedMask: INDArray = Nd4j.hstack(Nd4j.ones(mb, 1), fm)
        (Array[INDArray](fm, extendedMask), Array[INDArray](extendedMask))
      }

      mds.setFeatures(Array[INDArray](input, decoderInput))
      mds.setLabels(Array[INDArray](decoderLabels))
      // null mask arrays mean "all examples have the same length".
      mds.setFeaturesMaskArrays(masks.map(_._1).orNull)
      mds.setLabelsMaskArray(masks.map(_._2).orNull)
    }
  }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment