Created
March 21, 2019 11:09
-
-
Save nunenuh/2b3969b53b6b1ddfc3d45c8409d64f5b to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import random | |
import torch | |
import pathlib | |
import os | |
from torch.utils import data | |
import PIL | |
import PIL.Image | |
class SiameseDataset(data.Dataset):
    """Dataset of labelled image pairs for training a Siamese network.

    Expects the layout ``root/<class_dir>/<image>.<ext>``: every immediate
    sub-directory of ``root`` is one class. Each item is
    ``(image_1, image_2, label)``. ``label`` is ``0`` for a "similar" pair
    (both files from the same directory) and ``1`` for a "different" pair
    (files from two different directories).

    For a directory with ``n`` files, all ``n * n`` ordered pairs are used as
    similar pairs, including a file paired with itself. An equal number of
    different pairs is sampled at random (fewer if not enough candidates
    exist), and the two kinds are interleaved. Sampling uses the global
    :mod:`random` state, so seed it beforehand for reproducible pairs.

    Args:
        root: dataset root directory.
        ext: file extension without the leading dot (e.g. ``"pgm"``).
        transform: optional callable applied to each image separately.
        pair_transform: optional callable ``(im1, im2) -> (im1, im2)``
            applied to both images together, after ``transform``.
        target_transform: optional callable applied to the label.

    Raises:
        ValueError: if ``root`` contains exactly one class directory, since
            no different pairs can be formed.
    """

    def __init__(self, root, ext, transform=None, pair_transform=None, target_transform=None):
        super().__init__()
        self.transform = transform
        self.pair_transform = pair_transform
        self.target_transform = target_transform
        self.root = root
        self.base_path = pathlib.Path(root)
        # The original code globbed on an undefined name ``path``, which
        # raised NameError on every construction.
        self.files = sorted(self.base_path.glob("*/*." + ext))
        self.files_map = self._files_mapping()
        self.pair_files = self._pair_files()

    def __len__(self):
        return len(self.pair_files)

    def __getitem__(self, idx):
        """Return ``(im1, im2, label)`` for the ``idx``-th pair."""
        (imp1, imp2), sim = self.pair_files[idx]
        im1 = self._load_image(imp1)
        im2 = self._load_image(imp2)
        if self.transform is not None:
            im1 = self.transform(im1)
            im2 = self.transform(im2)
        if self.pair_transform is not None:
            # The original code called a nonexistent ``self.transform_pair``.
            im1, im2 = self.pair_transform(im1, im2)
        if self.target_transform is not None:
            sim = self.target_transform(sim)
        return im1, im2, sim

    @staticmethod
    def _load_image(path):
        """Open the image at ``path`` and read its pixels so the file is closed.

        ``PIL.Image.open`` is lazy and keeps the file handle open until the
        data is read. Loading inside ``with`` avoids leaking one descriptor
        per image fetched.
        """
        with PIL.Image.open(path) as img:
            img.load()
        return img

    def _files_mapping(self):
        """Map each class directory name to the sorted list of its file names."""
        dct = {}
        for f in self.files:
            # Record every file. The original code only created the list the
            # first time it saw a directory and dropped that file. Path
            # components replace splitting on '/', which was not portable.
            dct.setdefault(f.parent.name, []).append(f.name)
        for names in dct.values():
            names.sort()
        return dct

    def _similar_pair(self):
        """Return, per directory, all ordered pairs of its files labelled 0.

        This includes each file paired with itself, so a directory of ``n``
        files yields ``n * n`` pairs. Paths are relative to ``root``.
        """
        atp = {}
        for key, names in self.files_map.items():
            paths = [os.path.join(key, name) for name in names]
            atp[key] = [((fp, fo), 0) for fp in paths for fo in paths]
        return atp

    def _len_similar_pair(self):
        """Return the number of similar pairs per directory (``n * n``)."""
        return {key: len(names) ** 2 for key, names in self.files_map.items()}

    def _diff_pair_dircomp(self):
        """Return ``(dirname, [every other dirname])`` for each directory."""
        dirnames = list(self.files_map)
        return [(d, dirnames[:i] + dirnames[i + 1:]) for i, d in enumerate(dirnames)]

    def _different_pair(self):
        """Randomly sample cross-directory pairs labelled 1, per primary directory.

        Phase 1: for every file of a directory and every other directory,
        sample up to ``max(1, n // 4)`` partners from that other directory.
        Phase 2: from those candidates, sample as many pairs as the directory
        has similar pairs, or all candidates if there are fewer. Both phases
        are kept separate so the random draw order matches the original
        implementation. Paths are relative to ``root``.
        """
        fmap = self.files_map
        if len(fmap) == 1:
            raise ValueError("need at least two class directories to build different pairs")
        len_spair = self._len_similar_pair()
        candidates = {}
        for kp, others in self._diff_pair_dircomp():
            primaries = fmap[kp]
            # At least one partner per directory, so classes with fewer than
            # four images still get negatives. The original raised
            # ValueError in that case.
            per_dir = max(1, len(primaries) // 4)
            cand = []
            for vp in primaries:
                fp = os.path.join(kp, vp)
                for ko in others:
                    pool = [((fp, os.path.join(ko, vo)), 1) for vo in fmap[ko]]
                    cand.extend(random.sample(pool, min(per_dir, len(pool))))
            candidates[kp] = cand
        return {
            kp: random.sample(cand, min(len_spair[kp], len(cand)))
            for kp, cand in candidates.items()
        }

    def _pair_files(self):
        """Build the final list of ``((path1, path2), label)`` items.

        Similar and different pairs of each directory are interleaved
        (similar, different, similar, ...). If fewer different pairs could be
        sampled, the list is truncated to keep the classes balanced. Paths
        are joined onto ``root``.
        """
        sim_pair = self._similar_pair()
        diff_pair = self._different_pair()
        files_list = []
        for key in self.files_map:
            for ((sp, so), slbl), ((dp, do), dlbl) in zip(sim_pair[key], diff_pair[key]):
                files_list.append(((os.path.join(self.root, sp), os.path.join(self.root, so)), slbl))
                files_list.append(((os.path.join(self.root, dp), os.path.join(self.root, do)), dlbl))
        return files_list
if __name__ == "__main__":
    # Manual smoke test against a local copy of the AT&T faces dataset.
    # It is guarded so that importing this module (including by DataLoader
    # worker processes) does not scan a hard-coded path.
    root = '/data/att_faces'
    sd = SiameseDataset(root, ext="pgm")
    print(sd[3][0])
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
variable ctp gapernah dipake ya bang?