Skip to content

Instantly share code, notes, and snippets.

@grafi-tt
Created November 27, 2018 08:56
Show Gist options
  • Save grafi-tt/be72108a5c3a8bfb3e86cfa22957b124 to your computer and use it in GitHub Desktop.
import chainer
import chainer.functions as F
import chainer.links as L
import numpy as np
from chainer import dataset, initializer
from chainer.backends import cuda
def cipher_ctx(key):
    """Build an AES-CTR encryptor that serves as a deterministic byte stream.

    The 16-byte CTR nonce is fixed to all zeros, so the produced stream
    depends only on ``key``. Callers in this file feed the encryptor zero
    bytes (``ctx.update(bytes(n))``); in CTR mode the encryption of zeros is
    the raw keystream.

    Args:
        key: AES key as ``bytes``. Its length must be accepted by
            ``cryptography``'s ``algorithms.AES`` (``main`` passes 32-byte
            keys). Because the nonce is fixed, the same key always yields the
            same stream, which is why ``main`` derives a distinct key per use.

    Returns:
        The context returned by ``Cipher.encryptor()``; draw bytes from it
        with its ``update`` method.
    """
    # Imported locally, presumably so the module itself can be imported
    # without ``cryptography`` installed.
    from cryptography.hazmat.backends import default_backend
    from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes

    zero_counter = b'\x00' * 16
    aes_ctr = Cipher(
        algorithms.AES(key),
        modes.CTR(zero_counter),
        backend=default_backend(),
    )
    return aes_ctr.encryptor()
class RandDataset(dataset.DatasetMixin):
    """Labelled byte strings drawn from one of two generators.

    Every example is ``nbytes`` bytes taken either from ``random.bytes``
    (label ``True``) or from the keystream of ``ctx_data`` (label ``False``).
    The source is picked by one keystream byte of ``ctx_choose`` (>= 0x80
    selects ``random``), so both sources appear with roughly equal frequency.

    All three sources are stateful streams, so examples must be requested
    strictly in order 0, 1, ..., ``len(self) - 1`` (then 0 again); any other
    index raises ``ValueError``. Use a non-shuffling iterator.
    """

    def __init__(self, nbytes, ctx_choose, ctx_data, random, length=40000):
        """
        Args:
            nbytes: Number of bytes per example.
            ctx_choose: Encryptor (``cipher_ctx``) whose ``update`` output
                decides each example's source.
            ctx_data: Encryptor whose keystream supplies label-``False`` data.
            random: Object with a ``bytes(n)`` method (``main`` passes a
                ``np.random.RandomState``) supplying label-``True`` data.
            length: Number of examples per epoch (default 40000).
        """
        self.nbytes = nbytes
        self._ctx_choose = ctx_choose
        self._ctx_data = ctx_data
        self._random = random
        self._length = length
        self._i = 0  # next index that get_example will accept

    def __len__(self):
        return self._length

    def get_example(self, i):
        """Return ``(bits, t)`` for index ``i``.

        ``bits`` is a uint8 array of shape ``(nbytes,)``; ``t`` is ``True``
        when the bytes came from ``random`` and ``False`` when they came from
        the AES keystream.

        Raises:
            ValueError: if ``i`` is not the next index in sequence.
        """
        if i != self._i:
            raise ValueError(
                'RandDataset must be read sequentially: expected index {}, '
                'got {}'.format(self._i, i))
        self._i += 1
        if self._i == len(self):
            self._i = 0
        # Encryptors are not callable; bytes are drawn via ``update``.
        t = self._ctx_choose.update(b'\0') >= b'\x80'
        if t:
            raw = self._random.bytes(self.nbytes)
        else:
            # Encrypting zeros in CTR mode yields the raw keystream.
            raw = self._ctx_data.update(bytes(self.nbytes))
        # np.asarray(bytes) would give a 0-d string array; unpackbits in the
        # model needs a uint8 vector.
        bits = np.frombuffer(raw, dtype=np.uint8)
        return bits, t
class GlorotUniformAES(initializer.Initializer):
    """Glorot (Xavier) uniform initializer driven by a deterministic keystream.

    Fills a float32 array with values from U(-s, s), where
    ``s = sqrt(6 / (fan_in + fan_out))`` and the fans come from
    ``chainer.initializer.get_fans``. The randomness is read from
    ``ctx.update`` (e.g. an AES-CTR encryptor from ``cipher_ctx``), so the
    initial weights are reproducible from the key alone. Several instances
    may share one ``ctx``; each call then consumes the next
    ``4 * array.size`` bytes of that stream.
    """

    def __init__(self, ctx):
        # Only float32 is supported by __call__, so advertise that dtype.
        super().__init__(dtype=np.float32)
        self._ctx = ctx

    def __call__(self, array):
        """Fill ``array`` in place with Glorot-uniform values.

        Raises:
            ValueError: if ``array`` is not float32.
        """
        if array.dtype != np.float32:
            raise ValueError(
                'GlorotUniformAES supports only float32 arrays, got {}'
                .format(array.dtype))
        keystream = self._ctx.update(bytes(4 * array.size))
        # np.frombuffer returns a read-only view, so shift out-of-place.
        # Keeping the top 24 bits makes every value exactly representable
        # in float32.
        words = np.frombuffer(keystream, dtype='<u4') >> 8
        unit = words.astype(np.float32) * np.float32(2 ** -24)  # [0, 1)
        symmetric = unit * 2 - 1                                # [-1, 1)
        # get_fans expects the shape tuple, not the array itself.
        fan_in, fan_out = initializer.get_fans(array.shape)
        scale = np.sqrt(6 / (fan_in + fan_out))
        values = (symmetric * scale).astype(np.float32).reshape(array.shape)
        # Assign through the full index instead of ravel(): ravel() may
        # return a copy for non-contiguous arrays, silently dropping writes.
        # Convert to the target's array module in case it lives on a GPU.
        xp = cuda.get_array_module(array)
        array[...] = xp.asarray(values)
class RandNet(chainer.Chain):
    """Binary classifier over raw byte strings.

    ``forward`` unpacks each example's bytes into bits and reshapes them to
    a ``(batch, 32, 1, -1)`` float32 tensor. It then applies five ReLU
    convolutions, a global mean over the spatial axes, and a one-output
    linear layer. The result is a sigmoid cross-entropy loss.

    Every weight matrix is initialised by a ``GlorotUniformAES`` sharing the
    single stream ``ctx``, so the bytes each layer receives depend on the
    order in which the layers' weights are initialised (presumably
    construction order here, since every input size is given).
    """

    def __init__(self, ctx):
        super().__init__()
        with self.init_scope():
            self.l1 = L.Convolution2D(32, 1024, ksize=(1, 8), stride=4,
                                      initialW=GlorotUniformAES(ctx))
            self.l2 = L.Convolution2D(1024, 1024, ksize=1, stride=1,
                                      initialW=GlorotUniformAES(ctx))
            self.l3 = L.Convolution2D(1024, 2048, ksize=(1, 8), stride=4,
                                      initialW=GlorotUniformAES(ctx))
            self.l4 = L.Convolution2D(2048, 1024, ksize=1, stride=1,
                                      initialW=GlorotUniformAES(ctx))
            self.l5 = L.Convolution2D(1024, 1024, ksize=(1, 8), stride=4,
                                      initialW=GlorotUniformAES(ctx))
            self.l6 = L.Linear(1024, 1, initialW=GlorotUniformAES(ctx))

    def forward(self, bits, t):
        """Return the sigmoid cross-entropy loss for a batch.

        Args:
            bits: uint8 array of shape ``(batch, nbytes)``; ``8 * nbytes``
                must be divisible by 32.
            t: Labels with ``batch`` elements (e.g. the per-example bools
                produced by ``RandDataset``); truthy means "from random".
        """
        xp = self.xp
        bits = xp.asarray(bits)
        batch_size = len(bits)
        x = xp.unpackbits(bits.ravel()).astype(xp.float32)
        x = x.reshape((batch_size, 32, 1, -1))
        h1 = F.relu(self.l1(x))
        h2 = F.relu(self.l2(h1))
        h3 = F.relu(self.l3(h2))
        h4 = F.relu(self.l4(h3))
        h5 = F.relu(self.l5(h4))
        logit = self.l6(F.mean(h5, axis=(2, 3)))
        # sigmoid_cross_entropy requires int32 labels with the same shape
        # as logit, i.e. (batch, 1); batched bool labels arrive as (batch,).
        t = xp.asarray(t).astype(xp.int32).reshape(logit.shape)
        return F.sigmoid_cross_entropy(logit, t, normalize=False)
def main(key, epoch, outdir):
    """Build the keyed streams, the model and the dataset for an experiment.

    Independent streams are derived from ``key`` by appending a one-byte
    tag: ``b'w'`` for weight initialisation, ``b'c'`` for choosing each
    example's source, ``b'd'`` for the AES data stream, and ``b'r'`` for
    seeding ``np.random.RandomState``.

    Args:
        key: Master key as ``bytes``; ``key + tag`` must be a valid AES key
            length (the script passes a 31-byte key, giving 32-byte keys).
        epoch: Intended number of training epochs. NOTE(review): not used
            yet -- no training loop is set up here.
        outdir: Intended output directory. NOTE(review): not used yet.

    Returns:
        ``(model, train_data)``: the ``RandNet`` and ``RandDataset`` built
        here, so a caller can attach a (non-shuffling) training loop.
    """
    ctx_weights = cipher_ctx(key + b'w')
    ctx_choose = cipher_ctx(key + b'c')
    ctx_data = cipher_ctx(key + b'd')
    # RandomState cannot be seeded with bytes directly (it raises
    # TypeError). Reinterpret the zero-padded seed bytes as little-endian
    # uint32 words, which RandomState accepts as an array seed.
    seed_bytes = key + b'r'
    seed_bytes += bytes(-len(seed_bytes) % 4)
    random = np.random.RandomState(np.frombuffer(seed_bytes, dtype='<u4'))
    model = RandNet(ctx_weights)
    # Named train_data to avoid shadowing the chainer ``dataset`` module.
    train_data = RandDataset(4096, ctx_choose, ctx_data, random)
    return model, train_data
if __name__ == '__main__':
    # 31-byte master key: main() appends a one-byte tag for each stream,
    # so every AES key it builds is 32 bytes long.
    key, epoch, outdir = b'lovelovelovelovelovelovelovelov', 100, 'results'
    main(key=key, epoch=epoch, outdir=outdir)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment