Created
October 19, 2018 11:44
-
-
Save koen-dejonghe/2b36b739e65d8fe8fdfd7f9b2f09aa1f to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/** Receive behavior for the forward pass.
  *
  * Matches only `Forward(x)`; no other message is handled by this partial function.
  * For a `Forward(x)` message it:
  *   1. applies `module` to `x`,
  *   2. sends the result, wrapped in `Forward`, to `wire.next`,
  *   3. switches this actor's behavior to `backwardHandle(x, result)`, so that both
  *      the input and the output of this pass are available when the matching
  *      `Backward` message is processed.
  */
def forwardHandle: Receive = {
  case Forward(x) => // accept forward messages
    val result = module(x) // calculate the result
    wire.next ! Forward(result) // forward the result
    // change state to backward, capturing this pass's input and output
    context.become(backwardHandle(x, result))
}
/** Receive behavior for the backward pass that follows one forward pass.
  *
  * Matches only `Backward(g)`; no other message is handled by this partial function.
  *
  * @param input  the value received in the preceding `Forward` message; its `grad`
  *               is what gets sent to `wire.prev`
  * @param output the value produced from `input` in the forward pass; `backward(g)`
  *               is invoked on it
  * @return a behavior that handles one `Backward` message and then switches back to
  *         `forwardHandle`
  */
def backwardHandle(input: Variable, output: Variable): Receive = {
  case Backward(g) => // accept backward messages
    optimizer.zeroGrad() // reset gradients to 0
    output.backward(g) // calculate local gradients
    wire.prev ! Backward(input.grad) // send local gradients
    optimizer.step() // update the weights
    context.become(forwardHandle) // change state to forward
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment