Created June 19, 2018 16:10
Save mitmul/32f6e9b000f1acb8aa118c41afec1a14 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python | |
# -*- coding: utf-8 -*- | |
import argparse | |
import collections | |
import os | |
import shutil | |
import numpy as np | |
import pandas as pd | |
import tabulate | |
def create_dataset(original_excel_fn):
    """Build a label table from the product-image spreadsheet.

    Each spreadsheet row must hold six columns, in this order: image path,
    JAN code, product name, an unused column, package size and flavor.  The
    image file name must look like ``<JAN>_<shape>_<vertical>_<horizontal>.<ext>``.
    A ``D`` in ``<vertical>`` marks a downward angle (stored negative), a
    ``U`` an upward one (stored positive); an unmarked number is used as is.

    Args:
        original_excel_fn: Path to the Excel file, or an already loaded
            ``pandas.DataFrame`` with the same column layout.

    Returns:
        pandas.DataFrame with columns ``filename``, ``jan``, ``class_id``,
        ``shape_id`` (string), ``vertical_angle`` (int),
        ``horizontal_angle`` (int), ``product_name``, ``size`` and
        ``flavor``.  Rows whose file-name JAN prefix disagrees with the JAN
        column are reported on stdout and skipped.  The columns are present
        even when no row survives.

    Raises:
        KeyError: If a JAN code, product name, size or flavor is missing
            from the lookup tables below.
        ValueError: If an angle in a file name is not an integer.
    """
    if isinstance(original_excel_fn, pd.DataFrame):
        info = original_excel_fn
    else:
        info = pd.read_excel(original_excel_fn)

    # Translate the Japanese spreadsheet labels into ASCII identifiers.
    product_name_en = {
        'ポテトチップス': 'PotatoChips',
    }
    size_en = {
        'BIGBAG': 'BIGBAG',
        'Lサイズ': 'L_Size',
        'レギュラー': 'Regular',
    }
    flavor_en = {
        'ウスシオ': 'LightSalt',
        'コンソメパンチ': 'Consomme',
        'キュウシュウショウユ': 'KyushuSoySauce',
    }
    # JAN barcode -> contiguous class index used for training.
    class_id = {
        4901330502911: 0,
        4901330502928: 1,
        4901330503284: 2,
        4901330523121: 3,
        4901330523176: 4,
        4901330523183: 5,
        4901330532734: 6,
        4901330532918: 7,
        4901330534516: 8,
    }
    columns = ['filename', 'jan', 'class_id', 'shape_id', 'vertical_angle',
               'horizontal_angle', 'product_name', 'size', 'flavor']

    def parse_vertical_angle(token):
        # 'D' (down) yields a negative angle; 'U' (up) or no marker a
        # positive one.  Always returns an int so the column has one dtype.
        if 'D' in token:
            return -int(token.replace('D', ''))
        return int(token.replace('U', ''))

    data = []
    for _, row in info.iterrows():
        fn, jan, product, _, size, flavor = row
        fn = os.path.basename(fn)
        head, shape, vertical_angle, horizontal_angle = \
            os.path.splitext(fn)[0].split('_')
        # Check consistency first so mismatched rows are skipped before any
        # angle parsing can fail on them.
        if str(head) != str(jan):
            print(head, '!=', jan)
            continue
        data.append({
            'filename': fn,
            'jan': jan,
            'class_id': class_id[jan],
            'shape_id': shape,
            'vertical_angle': parse_vertical_angle(vertical_angle),
            'horizontal_angle': int(horizontal_angle),
            'product_name': product_name_en[product],
            'size': size_en[size],
            'flavor': flavor_en[flavor],
        })
    return pd.DataFrame(data, columns=columns)
def mkdir(dname):
    """Create directory ``dname`` and any missing parents; no-op if it exists.

    ``exist_ok=True`` replaces a separate ``os.path.exists`` check, which
    could race with another process creating the same directory.  Raises
    ``FileExistsError`` if ``dname`` exists but is not a directory.
    """
    os.makedirs(dname, exist_ok=True)
def _prepare_split(output_dir, split):
    """Create ``<output_dir>/<split>/images``; return (split_dir, image_dir)."""
    split_dir = os.path.join(output_dir, split)
    image_dir = os.path.join(split_dir, 'images')
    mkdir(split_dir)
    mkdir(image_dir)
    return split_dir, image_dir


def _export_example(fn, class_id, image_dir, dst_image_dir, fp_labels,
                    balance):
    """Copy one image into a split, write its label line and count it.

    Returns False, writing and counting nothing, when the source image is
    missing.  The image is copied before the label line is written, so a
    failed copy never leaves a dangling label.
    """
    src = os.path.join(image_dir, fn)
    if not os.path.exists(src):
        print('Missing image, skipped: {}'.format(src))
        return False
    shutil.copy(src, dst_image_dir)
    print('{} {}'.format(fn, class_id), file=fp_labels)
    balance[class_id] += 1
    return True


def _print_balance(title, balance):
    """Print a per-class frequency table, sorted by class id."""
    print(title)
    print(tabulate.tabulate(
        sorted(balance.items()), headers=('class_id', 'frequency')))


def main():
    """Split the labelled images into train/valid/test directories.

    The distinct ``shape_id`` values are shuffled using ``--seed``.  Images
    whose shape is among the first ``--n-shapes-for-train`` shuffled shapes
    go to ``<output-dir>/train``.  The remaining images are shuffled again;
    the first ``--n-valid-examples`` go to ``valid`` and the rest to
    ``test``.  Each split gets an ``images/`` directory and a
    ``<split>_labels.txt`` file with one ``<filename> <class_id>`` line per
    copied image.  Images missing from ``--image-dir`` are reported and
    skipped in every split.  A per-class frequency table is printed for
    each split.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--n-shapes-for-train', type=int, default=27)
    parser.add_argument('--n-valid-examples', type=int, default=2000)
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--output-dir', type=str, default='data')
    parser.add_argument('--image-dir', type=str, default='data/images')
    parser.add_argument('--original-excel-filename', type=str,
                        default='hackathon data/ファイル情報.xlsx')
    args = parser.parse_args()

    # Parse the given excel file
    labels = create_dataset(args.original_excel_filename)

    np.random.seed(args.seed)
    # Use the shape ids that actually occur rather than assuming 1..n.  When
    # they are exactly 1..n this is the same sorted array, so seeded splits
    # are unchanged.
    shape_ids = np.unique(labels['shape_id'].astype(int))
    np.random.shuffle(shape_ids)
    shapes_train = set(shape_ids[:args.n_shapes_for_train].tolist())

    train_dir, train_img_dir = _prepare_split(args.output_dir, 'train')
    valtest_labels = []
    train_class_balance = collections.defaultdict(int)
    with open(os.path.join(train_dir, 'train_labels.txt'), 'w') as fp_train:
        for _, row in labels.iterrows():
            if int(row['shape_id']) in shapes_train:
                _export_example(row['filename'], row['class_id'],
                                args.image_dir, train_img_dir, fp_train,
                                train_class_balance)
            else:
                valtest_labels.append([row['filename'], row['class_id']])
    _print_balance('Train class balance:', train_class_balance)

    # Split valtest into val and test
    valid_dir, valid_img_dir = _prepare_split(args.output_dir, 'valid')
    test_dir, test_img_dir = _prepare_split(args.output_dir, 'test')
    np.random.shuffle(valtest_labels)
    valid_class_balance = collections.defaultdict(int)
    test_class_balance = collections.defaultdict(int)
    with open(os.path.join(valid_dir, 'valid_labels.txt'), 'w') as fp_valid, \
            open(os.path.join(test_dir, 'test_labels.txt'), 'w') as fp_test:
        for i, (fn, class_id) in enumerate(valtest_labels):
            if i < args.n_valid_examples:
                _export_example(fn, class_id, args.image_dir, valid_img_dir,
                                fp_valid, valid_class_balance)
            else:
                _export_example(fn, class_id, args.image_dir, test_img_dir,
                                fp_test, test_class_balance)
    _print_balance('\nValid class balance:', valid_class_balance)
    _print_balance('\nTest class balance:', test_class_balance)
if __name__ == '__main__':
    # Perform the dataset split only when run as a script, not on import.
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.