Created
February 16, 2018 04:26
-
-
Save kuenishi/0a705a0f46322ef69092aedd96b376f0 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import getpass | |
import os | |
import threading | |
import time | |
from urllib.parse import urlparse | |
import zipfile | |
import numpy | |
try: | |
from PIL import Image | |
from pyarrow import hdfs | |
available = True | |
except ImportError as e: | |
available = False | |
_import_error = e | |
import six | |
import chainer | |
# Lock serializing every use of the shared HDFS client below: connecting in
# setup_hdfs() and opening/reading files through it.
FSLOCK = threading.Lock()
# FS: process-wide HDFS client, None until setup_hdfs() has connected.
# COUNTER: number of times setup_hdfs() has connected in this process.
FS, COUNTER = None, 0
def _read_image_as_array(path, dtype):
    """Read the image file at ``path`` through the shared HDFS client.

    The open, decode and conversion all run while holding ``FSLOCK``, so
    concurrent callers are serialized on the single ``FS`` client.

    Args:
        path: path of the image file, as understood by ``FS.open``.
        dtype: numpy dtype of the returned array.

    Returns:
        numpy.ndarray: the decoded image converted with ``numpy.asarray``.

    Raises:
        AssertionError: if ``setup_hdfs`` has not been called yet.
    """
    with FSLOCK:
        assert FS is not None, 'setup_hdfs() must be called first'
        # Managing the image with ``with`` guarantees it is closed, and avoids
        # the old ``finally: f.close()`` that raised UnboundLocalError (masking
        # the real error) when Image.open itself failed.
        with FS.open(path, 'rb') as fp, Image.open(fp) as img:
            return numpy.asarray(img, dtype=dtype)
def _read_image_inzip_as_array(zipfile, path, dtype):
    """Decode one member of an open zip archive into a numpy array.

    Args:
        zipfile: an open archive object exposing ``open(name, mode)``. The
            parameter name shadows the ``zipfile`` module inside this
            function; it is kept for backward compatibility.
        path: member name inside the archive.
        dtype: numpy dtype of the returned array.

    Returns:
        numpy.ndarray: the decoded image converted with ``numpy.asarray``.
    """
    assert zipfile is not None
    # ``with`` closes both the member stream and the image. Previously a
    # failing Image.open left ``f`` unbound and the ``finally`` clause raised
    # UnboundLocalError instead of the real decode error.
    with zipfile.open(path, 'r') as fp, Image.open(fp) as img:
        return numpy.asarray(img, dtype=dtype)
def setup_hdfs(host, port, user=None):
    """(Re)connect the process-wide HDFS client ``FS`` and log the connection.

    Each call creates a new connection with ``hdfs.connect`` and increments
    ``COUNTER``. Any previous client is replaced but not closed, because
    datasets may still hold files opened through it.

    Args:
        host: HDFS namenode host name.
        port: HDFS namenode port.
        user: user to connect as; defaults to ``getpass.getuser()``.
    """
    global FS, COUNTER
    if user is None:
        user = getpass.getuser()
    with FSLOCK:
        fs = hdfs.connect(host, port, user=user)
        FS = fs
        COUNTER += 1
        # Snapshot under the lock. Reading the globals after releasing it could
        # report another thread's connection or count.
        count = COUNTER
    if fs is not None:
        print('Connected to HDFS', host, port, 'as', user, 'at process', os.getpid(), 'counter =', count)
class ImageDataset(chainer.datasets.ImageDataset):
    """Variant of :class:`chainer.datasets.ImageDataset` reading via HDFS.

    Images are loaded with ``_read_image_as_array`` through the shared ``FS``
    client, so ``setup_hdfs`` must have connected before construction.
    Examples are returned with axis 2 moved to the front (channels-first);
    greyscale images get a single channel axis added first.
    """

    def __init__(self, paths, root='.', dtype=numpy.float32):
        # Check the connection under the lock, then fail fast if missing.
        with FSLOCK:
            connected = FS is not None
        assert connected
        super(ImageDataset, self).__init__(paths, root, dtype)

    def get_example(self, i):
        full_path = os.path.join(self._root, self._paths[i])
        img = _read_image_as_array(full_path, self._dtype)
        if img.ndim == 2:
            # Greyscale: append a trailing channel axis of length 1.
            img = img[..., numpy.newaxis]
        return img.transpose(2, 0, 1)
class LabeledImageDataset(chainer.datasets.LabeledImageDataset):
    """Variant of :class:`chainer.datasets.LabeledImageDataset` reading via HDFS.

    Each example is ``(image, label)``:

    - ``image`` is the file at ``root/path`` read through the shared ``FS``
      client with axis 2 moved to the front (channels-first).
    - ``label`` is the stored integer cast to ``label_dtype``.

    ``setup_hdfs`` must have connected before construction.
    """

    def __init__(self, pairs, root, dtype=numpy.float32,
                 label_dtype=numpy.int32):
        # Check the connection under the lock, then fail fast if missing.
        with FSLOCK:
            connected = FS is not None
        assert connected
        super(LabeledImageDataset, self).__init__(
            pairs, root, dtype, label_dtype)

    def get_example(self, i):
        rel_path, raw_label = self._pairs[i]
        img = _read_image_as_array(os.path.join(self._root, rel_path),
                                   self._dtype)
        if img.ndim == 2:
            # Greyscale: append a trailing channel axis of length 1.
            img = img[..., numpy.newaxis]
        return (img.transpose(2, 0, 1),
                numpy.array(raw_label, dtype=self._label_dtype))
class ZippedImageDataset(chainer.datasets.ImageDataset):
    """Image dataset whose images are members of one zip archive on HDFS.

    ``root`` is the ``hdfs://host:port/...`` URL of the archive and ``paths``
    are member names inside it.

    The HDFS connection and the archive are opened lazily on the first
    ``get_example`` call. They are opened again whenever the process id
    changes, so each forked worker gets its own handles.
    """

    def __init__(self, paths, root='.', dtype=numpy.float32):
        if not root.startswith('hdfs://'):
            raise ValueError('root must be an hdfs:// URL, got {!r}'.format(root))
        super(ZippedImageDataset, self).__init__(paths, root, dtype)
        self._root = root
        self._pid = os.getpid()
        self._hdfsfile = None
        self._zipfile = None

    def __reduce__(self):
        # Open HDFS/zip handles cannot be pickled (e.g. for multiprocess
        # iterators). Rebuild from constructor arguments so the copy reopens
        # lazily in its own process.
        return self.__class__, (self._paths, self._root, self._dtype)

    def _open(self):
        """Connect to HDFS and open the archive for the current process."""
        self._pid = os.getpid()
        # Fixed: previously ``urlparse(root)`` referenced an undefined name.
        self._url = urlparse(self._root)
        setup_hdfs(self._url.hostname, self._url.port)
        with FSLOCK:
            # The archive lives on HDFS, so it must be opened through the HDFS
            # client. Passing the hdfs:// URL straight to ZipFile treated it
            # as a local path.
            self._hdfsfile = FS.open(self._root, 'rb')
            self._zipfile = zipfile.ZipFile(self._hdfsfile, 'r')

    def get_example(self, i):
        if (self._pid != os.getpid() or self._zipfile is None
                or self._hdfsfile is None):
            self._open()
        image = _read_image_inzip_as_array(self._zipfile, self._paths[i],
                                           self._dtype)
        if image.ndim == 2:
            # Greyscale: append a trailing channel axis of length 1.
            image = image[:, :, numpy.newaxis]
        return image.transpose(2, 0, 1)

    def finalize(self):
        '''Close the archive and the HDFS file handle.

        Note that iterators do not finalize datasets, so use this dataset at
        your own risk! A later ``get_example`` call reopens both handles.
        '''
        if self._zipfile is not None:
            self._zipfile.close()
            self._zipfile = None
        if self._hdfsfile is not None:
            self._hdfsfile.close()
            self._hdfsfile = None
class ZippedLabeledImageDataset(chainer.datasets.LabeledImageDataset):
    """Labeled image dataset read from one zip archive stored on HDFS.

    ``root`` is the ``hdfs://host:port/...`` URL of the archive. Each entry
    of ``pairs`` is ``(path, label)``, where ``path`` is joined onto
    ``prefix`` to form the member name inside the archive. ``prefix``
    defaults to the previously hard-coded ``'ILSVRC2012'``.

    The HDFS connection and the archive are opened lazily on the first
    ``get_example`` call, and again whenever the process id changes.
    Per-example decode times are accumulated and reported by :meth:`stats`.
    """

    def __init__(self, pairs, root, dtype=numpy.float32,
                 label_dtype=numpy.int32, prefix='ILSVRC2012'):
        if not root.startswith('hdfs://'):
            raise ValueError('root must be an hdfs:// URL, got {!r}'.format(root))
        super(ZippedLabeledImageDataset, self).__init__(pairs, root, dtype, label_dtype)
        self._prefix = prefix
        self._pid = os.getpid()
        self._hdfsfile = None
        self._zipfile = None
        # Running totals instead of a per-example list. The list grew by one
        # entry per example forever, leaking memory over long training runs.
        self._timing_total = 0.0
        self._timing_count = 0

    def __reduce__(self):
        # Open HDFS/zip handles cannot be pickled. Rebuild from constructor
        # arguments so the copy reopens lazily in its own process.
        return self.__class__, (self._pairs, self._root, self._dtype,
                                self._label_dtype, self._prefix)

    def _open(self):
        """Connect to HDFS and open the archive for the current process."""
        self._pid = os.getpid()
        self._url = urlparse(self._root)
        setup_hdfs(self._url.hostname, self._url.port)
        with FSLOCK:
            self._hdfsfile = FS.open(self._root, 'rb')
            self._zipfile = zipfile.ZipFile(self._hdfsfile, 'r')

    def get_example(self, i):
        if (self._pid != os.getpid() or self._zipfile is None
                or self._hdfsfile is None):
            self._open()
        path, int_label = self._pairs[i]
        member = os.path.join(self._prefix, path)
        start = time.time()
        image = _read_image_inzip_as_array(self._zipfile, member, self._dtype)
        self._timing_total += time.time() - start
        self._timing_count += 1
        if image.ndim == 2:
            # Greyscale: append a trailing channel axis of length 1.
            image = image[:, :, numpy.newaxis]
        label = numpy.array(int_label, dtype=self._label_dtype)
        return image.transpose(2, 0, 1), label

    def finalize(self):
        '''Close the archive and the HDFS file handle.

        Note that iterators do not finalize datasets, so use this dataset at
        your own risk! A later ``get_example`` call reopens both handles.
        '''
        if self._zipfile is not None:
            self._zipfile.close()
            self._zipfile = None
        if self._hdfsfile is not None:
            self._hdfsfile.close()
            self._hdfsfile = None

    def stats(self):
        """Return ``(mean decode seconds, examples timed)`` for this process.

        The mean is NaN when nothing has been timed yet.
        """
        if not self._timing_count:
            return float('nan'), 0
        return self._timing_total / self._timing_count, self._timing_count
class ZippedLabeledImageDataset2(chainer.datasets.LabeledImageDataset):
    '''Labeled image dataset read from a zip archive opened with ``zipfile``.

    ``root`` is the archive path and must end in ``.zip``. It is passed
    directly to ``zipfile.ZipFile``, so it must be openable by the local
    file system (presumably a mounted copy of the HDFS archive; confirm with
    the caller).

    Each entry of ``pairs`` is ``(internal-path, label)``. The member name
    is ``prefix`` joined with ``internal-path``, and ``prefix`` defaults to
    the previously hard-coded ``'ILSVRC2012'``.

    The archive is opened lazily and again whenever the process id changes.
    Per-example decode times are reported by :meth:`stats`.
    '''

    def __init__(self, pairs, root, dtype=numpy.float32,
                 label_dtype=numpy.int32, prefix='ILSVRC2012'):
        if not root.endswith('.zip'):
            raise ValueError('root must be a .zip file, got {!r}'.format(root))
        super(ZippedLabeledImageDataset2, self).__init__(pairs, root, dtype, label_dtype)
        self._prefix = prefix
        self._pid = os.getpid()
        self._zipfile = None
        # Running totals instead of an ever-growing per-example list.
        self._timing_total = 0.0
        self._timing_count = 0

    def __reduce__(self):
        # An open ZipFile cannot be pickled. Rebuild from constructor
        # arguments so the copy reopens the archive in its own process.
        return self.__class__, (self._pairs, self._root, self._dtype,
                                self._label_dtype, self._prefix)

    def get_example(self, i):
        if self._pid != os.getpid() or self._zipfile is None:
            self._pid = os.getpid()
            self._zipfile = zipfile.ZipFile(self._root, 'r')
        path, int_label = self._pairs[i]
        member = os.path.join(self._prefix, path)
        start = time.time()
        image = _read_image_inzip_as_array(self._zipfile, member, self._dtype)
        self._timing_total += time.time() - start
        self._timing_count += 1
        if image.ndim == 2:
            # Greyscale: append a trailing channel axis of length 1.
            image = image[:, :, numpy.newaxis]
        label = numpy.array(int_label, dtype=self._label_dtype)
        return image.transpose(2, 0, 1), label

    def finalize(self):
        '''Close the archive.

        Note that iterators do not finalize datasets, so use this dataset at
        your own risk! A later ``get_example`` call reopens the archive.
        '''
        if self._zipfile is not None:
            self._zipfile.close()
            self._zipfile = None

    def stats(self):
        """Return ``(mean decode seconds, examples timed)`` for this process.

        The mean is NaN when nothing has been timed yet.
        """
        if not self._timing_count:
            return float('nan'), 0
        return self._timing_total / self._timing_count, self._timing_count
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment