Copyright @ Kai Kang ([email protected]) 2016
Last active
June 3, 2016 09:21
-
-
Save myfavouritekk/eb15260128223993d70b815e17e30f98 to your computer and use it in GitHub Desktop.
GeneDataLayer: Python layer in Caffe for gene data
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
sequence_file: "data/rnac/sequences.tsv.gz"
target_file: "data/rnac/targets.tsv.gz"
set: ["A"]
# set: ["B"]
# set: ["A", "B"]
batch_size: 32
length: 30
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python | |
import caffe | |
import yaml | |
import gzip | |
import csv | |
import math | |
import numpy as np | |
seq_map = {'A': [1, 0, 0, 0], | |
'C': [0, 1, 0, 0], | |
'G': [0, 0, 1, 0], | |
'U': [0, 0, 0, 1], | |
'T': [0, 0, 0, 1]} | |
def _process_gene_sequence(seq): | |
return np.asarray(map(lambda x:seq_map[x], seq), dtype=np.float32) | |
def _process_targets(target): | |
res = float(target) | |
if math.isnan(res): | |
return 0 | |
else: | |
return res | |
def _open_gzip_text(path):
    """Open the gzip file at ``path`` in a form ``csv.reader`` accepts.

    On Python 3 the csv module needs text lines, so the archive is opened in
    text mode; on Python 2 csv works on byte strings, which is what
    ``gzip.GzipFile`` yields.
    """
    import sys  # local import: only needed for the interpreter check

    if sys.version_info[0] >= 3:
        return gzip.open(path, 'rt', newline='')
    return gzip.GzipFile(path)


class GeneDataLayer(caffe.Layer):
    """Caffe Python data layer that serves random batches of gene data.

    The layer reads no bottoms and fills two tops:

    * ``top[0]``: one-hot encoded sequences, reshaped to
      ``(batch_size, 4, 1, length)``.
    * ``top[1]``: target vectors, reshaped to
      ``(batch_size, len(target_names), 1, 1)``.

    ``param_str`` must be a YAML mapping with a ``config`` key naming a YAML
    file that provides ``sequence_file``, ``target_file``, ``set``,
    ``batch_size`` and ``length``.
    """

    def setup(self, bottom, top):
        """Load the data named by the config file and size the tops.

        Raises:
            NotImplementedError: if either data file name does not end in
                ``.gz``.
            ValueError: if the sequence header does not end with ``seq``,
                the two files have different row counts, or no row belongs
                to the requested subset.
        """
        # NOTE: safe_load replaces the original bare yaml.load.  It only
        # builds plain YAML types (all this layer reads) and cannot
        # instantiate arbitrary Python objects from the input.
        layer_params = yaml.safe_load(self.param_str)
        with open(layer_params['config']) as config_file:
            config = yaml.safe_load(config_file)

        seq_file = config['sequence_file']
        target_file = config['target_file']
        subset = config['set']
        self.sequences = []
        self.event_ids = []
        self.targets = []

        if not (seq_file.endswith('.gz') and target_file.endswith('.gz')):
            raise NotImplementedError(
                '{} or {} not a valid Gzip file.'.format(seq_file, target_file))

        with _open_gzip_text(seq_file) as seqfile, \
                _open_gzip_text(target_file) as tarfile:
            seq_reader = csv.reader(seqfile, delimiter='\t')
            self.col_names = next(seq_reader)
            tar_reader = csv.reader(tarfile, delimiter='\t')
            self.target_names = next(tar_reader)
            # Explicit raises instead of assert: asserts vanish under -O.
            if self.col_names[-1] != 'seq':
                raise ValueError(
                    "last column of {} must be 'seq', got {!r}".format(
                        seq_file, self.col_names[-1]))
            print("Reading {}...".format(seq_file))
            sequences = list(seq_reader)
            print("Reading {}...".format(target_file))
            targets = list(tar_reader)

        if len(sequences) != len(targets):
            raise ValueError(
                '{} has {} rows but {} has {} rows'.format(
                    seq_file, len(sequences), target_file, len(targets)))

        # Keep only the rows whose first column names a requested subset.
        print("Processing data...")
        for sequence, target in zip(sequences, targets):
            if sequence[0] not in subset:
                continue
            self.event_ids.append(sequence[1])
            self.sequences.append(_process_gene_sequence(sequence[2]))
            self.targets.append(
                np.asarray([_process_targets(value) for value in target],
                           dtype=np.float32))

        # Fail here rather than on every forward pass, where sampling from
        # an empty data set would raise a far less helpful error.
        if not self.sequences:
            raise ValueError(
                'no rows of {} belong to set {}'.format(seq_file, subset))

        batch_size = self.batch_size = config['batch_size']
        length = self.length = config['length']
        # sequence: [batch_size, 4, 1, length]
        top[0].reshape(batch_size, 4, 1, length)
        # targets: [batch_size, number of target columns, 1, 1]
        top[1].reshape(batch_size, len(self.target_names), 1, 1)

    def _length_process(self, seq):
        """Pad or randomly crop ``seq`` to exactly ``self.length`` positions.

        Args:
            seq: array with one row of 4 values per base.

        Returns:
            float32 array of shape ``(4, self.length)``.  Positions beyond
            the end of a shorter sequence hold 0.25 in every channel.
        """
        # Normalise to (n, 4) so an empty sequence still broadcasts below.
        seq = np.asarray(seq, dtype=np.float32).reshape(-1, 4)
        res = np.full((self.length, 4), 0.25, dtype=np.float32)
        cur_length = seq.shape[0]
        if cur_length <= self.length:
            res[:cur_length, :] = seq
        else:
            # Random crop: every window of self.length bases is equally
            # likely.
            rand_st = np.random.randint(cur_length - self.length + 1)
            res[...] = seq[rand_st:rand_st + self.length, :]
        return res.T

    def forward(self, bottom, top):
        """Fill both tops with a batch sampled uniformly with replacement."""
        rand_idx = np.random.choice(len(self.sequences), size=self.batch_size)
        seq = np.asarray(
            [self._length_process(self.sequences[idx]) for idx in rand_idx])
        target = np.asarray([self.targets[idx] for idx in rand_idx])
        top[0].data[...] = seq[:, :, np.newaxis, :]
        top[1].data[...] = target[:, :, np.newaxis, np.newaxis]

    def backward(self, bottom, top):
        """Do nothing: the layer has no bottoms to propagate into."""
        pass

    def reshape(self, bottom, top):
        """Do nothing: the top shapes are fixed once in ``setup``."""
        pass
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
train_net: "train.prototxt"
display: 100
average_loss: 100
base_lr: 0.0005
lr_policy: "step"
gamma: 0.1
max_iter: 200000
stepsize: 60000
momentum: 0.9
weight_decay: 0.0005
snapshot: 10000
snapshot_prefix: ""
solver_mode: GPU
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python | |
import argparse | |
from .layer import GeneDataLayer | |
def parse_args(argv=None):
    """Parse the command-line arguments of the test script.

    Args:
        argv: optional list of argument strings.  ``None`` (the default)
            makes argparse read ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with ``sequence_file`` and ``target_file``.
    """
    # ArgumentParser's first positional parameter is ``prog``; this text is
    # a description, so it has to be passed by keyword to keep the usage
    # line showing the real program name.
    parser = argparse.ArgumentParser(description='Test Gene Data Layer.')
    parser.add_argument('sequence_file')
    parser.add_argument('target_file')
    return parser.parse_args(argv)
if __name__ == "__main__":
    # Script entry point: only the command line is parsed here.
    args = parse_args()
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
name: "genenet"
layer {
  name: 'gene_data'
  type: 'Python'
  top: 'sequence'
  top: 'targets'
  python_param {
    module: 'gene_data_layer.layer'
    layer: 'GeneDataLayer'
    param_str: "config: config.yml"
  }
  include {phase: TRAIN}
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment