Created November 27, 2018 08:56
Save grafi-tt/be72108a5c3a8bfb3e86cfa22957b124 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import chainer | |
import chainer.functions as F | |
import chainer.links as L | |
import numpy as np | |
from chainer import dataset, initializer | |
from chainer.backends import cuda | |
def cipher_ctx(key):
    """Build an AES-CTR encryption context for *key* with an all-zero nonce.

    Args:
        key: AES key as bytes (AES accepts 16-, 24- or 32-byte keys).

    Returns:
        A ``cryptography`` cipher context; its ``update(data)`` method returns
        ``data`` XORed with the CTR keystream, so feeding it zero bytes
        yields the raw keystream.
    """
    from cryptography.hazmat.backends import default_backend
    from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes

    zero_nonce = b'\x00' * 16
    aes_ctr = Cipher(
        algorithms.AES(key),
        modes.CTR(zero_nonce),
        backend=default_backend(),
    )
    return aes_ctr.encryptor()
class RandDataset(dataset.DatasetMixin):
    """Sequential dataset of byte strings that are either random or keystream.

    Each example is ``nbytes`` bytes. It is labelled 1 when the bytes come
    from ``random`` (via its ``bytes()`` method) and 0 when they come from
    the ``ctx_data`` cipher context applied to zero bytes. The label is
    chosen by the top bit of one byte from ``ctx_choose``.

    The cipher contexts are stateful streams, so examples must be read in
    order 0, 1, ..., len(self) - 1, after which the index wraps to 0. An
    out-of-order read raises ``ValueError``. Use a non-shuffling iterator.
    """

    def __init__(self, nbytes, ctx_choose, ctx_data, random):
        self.nbytes = nbytes
        self._ctx_choose = ctx_choose
        self._ctx_data = ctx_data
        self._random = random
        self._i = 0

    def __len__(self):
        return 40000

    def get_example(self, i):
        """Return ``(bits, t)`` for example *i*.

        Returns:
            bits: uint8 array of shape ``(nbytes,)``.
            t: int32 array of shape ``(1,)``; 1 for bytes from ``random``,
                0 for bytes from ``ctx_data``.

        Raises:
            ValueError: if *i* is not the next index in sequence.
        """
        if i != self._i:
            raise ValueError(
                'RandDataset must be read sequentially: expected index {}, '
                'got {}'.format(self._i, i))
        self._i = (self._i + 1) % len(self)

        # Cipher contexts are not callable; bytes are drawn via update().
        is_random = self._ctx_choose.update(b'\0')[0] >= 0x80
        if is_random:
            raw = self._random.bytes(self.nbytes)
        else:
            raw = self._ctx_data.update(bytes(self.nbytes))
        # np.asarray(bytes) would give a 0-d 'S' array; we need a byte vector
        # (copied so the result is writable).
        bits = np.frombuffer(raw, dtype=np.uint8).copy()
        # int32 and shape (1,) so batches match a (batch, 1) logit.
        t = np.array([is_random], dtype=np.int32)
        return bits, t
class GlorotUniformAES(initializer.Initializer):
    """Glorot (Xavier) uniform initializer driven by a cipher keystream.

    Weights are drawn from U(-s, s) with ``s = sqrt(6 / (fan_in + fan_out))``.
    The random bits come from ``ctx.update(...)`` instead of NumPy's RNG, so
    the weights depend only on the stream state. Each call consumes
    ``4 * array.size`` bytes from ``ctx``.

    Args:
        ctx: object whose ``update(data)`` returns ``len(data)`` bytes
            (e.g. a cipher encryption context).
        dtype: forwarded to ``Initializer``; only float32 arrays are accepted.
    """

    def __init__(self, ctx, dtype=None):
        super().__init__(dtype)
        self._ctx = ctx

    def __call__(self, array):
        if array.dtype != np.float32:
            raise ValueError(
                'GlorotUniformAES supports only float32 arrays, got {}'
                .format(array.dtype))
        raw = self._ctx.update(bytes(4 * array.size))
        # Keep the top 24 bits of each little-endian 32-bit word. These are
        # exactly representable in float32, which gives a uniform grid on
        # [0, 1). The shift is out-of-place because frombuffer is read-only.
        words = np.frombuffer(raw, dtype='<u4') >> 8
        unit = words.astype(np.float32) * np.float32(2 ** -24)
        # Map [0, 1) to [-1, 1) so the distribution is centred on zero.
        symmetric = unit * np.float32(2) - np.float32(1)
        fan_in, fan_out = initializer.get_fans(array.shape)
        scale = np.float32(np.sqrt(6 / (fan_in + fan_out)))
        xp = cuda.get_array_module(array)
        array[...] = xp.asarray((symmetric * scale).reshape(array.shape))
class RandNet(chainer.Chain):
    """Binary classifier over byte strings, producing one logit per example.

    Input bytes are unpacked into bits and viewed as a 32-channel, height-1
    signal. Strided 1x8 convolutions alternate with 1x1 convolutions, then
    global average pooling feeds a single-output linear layer. All weights
    use ``GlorotUniformAES(ctx)``.
    """

    def __init__(self, ctx):
        super().__init__()
        with self.init_scope():
            # All initializers share one ctx stream, so the order in which
            # weights are initialized determines their values.
            self.l1 = L.Convolution2D(32, 1024, ksize=(1, 8), stride=4,
                                      initialW=GlorotUniformAES(ctx))
            self.l2 = L.Convolution2D(1024, 1024, ksize=1, stride=1,
                                      initialW=GlorotUniformAES(ctx))
            self.l3 = L.Convolution2D(1024, 2048, ksize=(1, 8), stride=4,
                                      initialW=GlorotUniformAES(ctx))
            self.l4 = L.Convolution2D(2048, 1024, ksize=1, stride=1,
                                      initialW=GlorotUniformAES(ctx))
            self.l5 = L.Convolution2D(1024, 1024, ksize=(1, 8), stride=4,
                                      initialW=GlorotUniformAES(ctx))
            self.l6 = L.Linear(1024, 1, initialW=GlorotUniformAES(ctx))

    def forward(self, bits, t):
        """Return the sigmoid cross-entropy loss for a batch.

        Args:
            bits: uint8 array of shape (batch, nbytes). ``8 * nbytes`` must
                be divisible by 32 for the channel reshape below.
            t: binary labels, one per example, with shape (batch,) or
                (batch, 1). Any integer or bool dtype is accepted.
        """
        xp = self.xp
        batch_size = len(bits)
        # One float per bit, laid out as (batch, 32 channels, 1, width).
        x = xp.unpackbits(bits.ravel()).astype(xp.float32)
        x = x.reshape((batch_size, 32, 1, -1))
        h1 = F.relu(self.l1(x))
        h2 = F.relu(self.l2(h1))
        h3 = F.relu(self.l3(h2))
        h4 = F.relu(self.l4(h3))
        h5 = F.relu(self.l5(h4))
        logit = self.l6(F.mean(h5, axis=(2, 3)))
        # sigmoid_cross_entropy requires int32 labels shaped like the logit.
        t = xp.asarray(t).astype(xp.int32).reshape(logit.shape)
        return F.sigmoid_cross_entropy(logit, t, normalize=False)
def main(key, epoch, outdir):
    """Build the model and dataset for the experiment from a single key.

    Three cipher contexts are derived by appending b'w', b'c' and b'd' to
    *key*. They drive weight initialization, label choice and keystream
    data. A NumPy RandomState seeded from ``key + b'r'`` supplies the
    random-byte examples.

    Args:
        key: bytes such that ``key`` plus one suffix byte is a valid AES key
            length (16, 24 or 32 bytes).
        epoch: intended number of training epochs (not used yet).
        outdir: intended output directory (not used yet).

    Returns:
        ``(model, dataset)`` so a caller can drive training.
    """
    ctx_weights = cipher_ctx(key + b'w')
    ctx_choose = cipher_ctx(key + b'c')
    ctx_data = cipher_ctx(key + b'd')
    # RandomState cannot take a bytes object as a seed; pass the key bytes as
    # an integer array instead.
    rng = np.random.RandomState(np.frombuffer(key + b'r', dtype=np.uint8))
    model = RandNet(ctx_weights)
    dataset = RandDataset(4096, ctx_choose, ctx_data, rng)
    # NOTE(review): no training loop is implemented yet, so epoch and outdir
    # are unused.
    return model, dataset
if __name__ == '__main__':
    # Demo settings. The key gets a one-byte suffix per stream, so its length
    # plus one must be a valid AES key size.
    settings = dict(
        key=b'lovelovelovelovelovelovelovelov',
        epoch=100,
        outdir='results',
    )
    main(**settings)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.