Created August 4, 2017 15:15
-
-
Save hengzhe-zhang/4404fa3203cf1b24740e4131971c7939 to your computer and use it in GitHub Desktop.
基于ELM的图片多分类器 (ELM-based multi-class image classifier)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# 基于ELM的图片多分类器 | |
import os | |
import cv2 | |
import hpelm | |
import numpy as np | |
class FishElm(object):
    """Multi-class image classifier built on an Extreme Learning Machine (hpelm).

    Images are expected at ``<photo_dir>/<kind>/<id>.jpg``. Each image is read
    as grayscale, resized to ``img_size`` with bicubic interpolation and
    flattened into one feature vector. Constructing an instance immediately
    reads the training images and trains the ELM (see ``train``).
    """

    sample_num = 0  # number of training images to read per kind
    photo_dir = ''  # root directory containing one sub-directory per kind
    kind_list = []  # kind (class) names; rebound per instance in __init__
    img_size = (100, 100)  # (width, height) passed to cv2.resize; features = width * height

    def __init__(self, sample_num, photo_dir, kind_list):
        """Store the configuration and train the ELM right away.

        Args:
            sample_num: number of images to read per kind for training.
            photo_dir: root directory of the per-kind image folders.
            kind_list: kind names; a kind's 1-based position in this list is
                the label returned by ``predict_photo``.
        """
        self.sample_num = sample_num
        self.photo_dir = photo_dir
        self.kind_list = kind_list
        self.elm = self.train()

    def _photo_path(self, kind, photo_id):
        """Return the path of image ``<photo_dir>/<kind>/<photo_id>.jpg``."""
        return os.path.join(self.photo_dir, kind, '{}.jpg'.format(photo_id))

    def _load_vector(self, file_path):
        """Load one image as a flat grayscale feature vector of width*height values.

        Returns None when OpenCV cannot read or resize the file; cv2.imread
        reports failure by returning None rather than raising.
        """
        img = cv2.imread(file_path, cv2.IMREAD_GRAYSCALE)
        if img is None:
            return None
        try:
            resized = cv2.resize(img, self.img_size, interpolation=cv2.INTER_CUBIC)
        except cv2.error:
            return None
        return resized.reshape(-1)

    def read_as_list(self, kind, from_id, all_num, max_missing=1000):
        """Read up to ``all_num`` images of one kind, trying ids from ``from_id`` upward.

        Some crawled images failed to download, so ids have gaps: each id is
        tried in turn and missing or unreadable files are skipped.

        Args:
            kind: kind name (sub-directory of ``photo_dir``).
            from_id: first image id to try.
            all_num: number of images wanted.
            max_missing: stop after this many consecutive ids without a
                readable image, so a folder holding fewer than ``all_num``
                images ends the search instead of looping forever.

        Returns:
            List of flattened grayscale images (one 1-D array per image). It
            has fewer than ``all_num`` entries if the search stopped early.
        """
        imglist = []
        photo_id = from_id
        missing = 0  # consecutive ids that yielded no readable image
        while len(imglist) < all_num and missing < max_missing:
            file_path = self._photo_path(kind, photo_id)
            vector = None
            if os.path.exists(file_path):
                vector = self._load_vector(file_path)
                if vector is None:
                    # Report the id that actually failed, not the next one.
                    print('图片{}读取失败'.format(photo_id))
            if vector is None:
                missing += 1
            else:
                imglist.append(vector)
                missing = 0
            photo_id += 1
        return imglist

    def train(self):
        """Train an ELM on up to ``sample_num`` images of every kind.

        Returns:
            The trained ``hpelm.ELM``. Output column ``i`` corresponds to
            ``kind_list[i]``.

        Raises:
            ValueError: if no training image could be read at all.
        """
        n_kinds = len(self.kind_list)
        identity = np.eye(n_kinds)
        samples = []
        targets = []
        for pos, kind in enumerate(self.kind_list):
            images = self.read_as_list(kind, 0, self.sample_num)
            samples.extend(images)
            # One one-hot target row per image actually read, so inputs and
            # targets stay aligned even when a kind has fewer images.
            targets.extend(identity[pos] for _ in images)
        if not samples:
            raise ValueError('no training images found in {}'.format(self.photo_dir))
        width, height = self.img_size
        elm = hpelm.ELM(width * height, n_kinds)
        # Hidden layer: 30 linear neurons plus 15 'rbf_linf' neurons.
        elm.add_neurons(30, 'lin')
        elm.add_neurons(15, 'rbf_linf')
        # Pixel values come from OpenCV as integers; train on floats instead.
        elm.train(np.asarray(samples, dtype=np.float64),
                  np.asarray(targets, dtype=np.float64))
        return elm

    def predict_photo(self, kind_name, photo_num):
        """Classify image ``<photo_dir>/<kind_name>/<photo_num>.jpg``.

        Returns:
            ``(kind_id, score)``. ``kind_id`` is the 1-based position in
            ``kind_list`` of the highest ELM output. ``score`` is that raw
            output value, which is not guaranteed to be a probability in
            [0, 1]. Returns ``(-1, -1)`` if the image cannot be read.
        """
        vector = self._load_vector(self._photo_path(kind_name, photo_num))
        if vector is None:
            return (-1, -1)
        test_data = vector.reshape(1, -1).astype(np.float64)
        scores = np.asarray(self.elm.predict(test_data))[0]
        # argmax rather than a running max seeded with 0: ELM outputs may all be negative.
        best = int(np.argmax(scores))
        return (best + 1, scores[best])

    def test_performance(self, begin_num, end_num):
        """Return classification accuracy on ids ``begin_num..end_num`` (inclusive) of every kind.

        Images that cannot be read are skipped and excluded from the
        denominator. Returns 0.0 when no image could be evaluated (for example
        an empty id range or an empty ``kind_list``).
        """
        evaluated = 0
        correct = 0
        for pos, kind in enumerate(self.kind_list, 1):
            for photo_num in range(begin_num, end_num + 1):
                ans = self.predict_photo(kind, photo_num)
                if ans == (-1, -1):
                    continue
                evaluated += 1
                if ans[0] == pos:
                    correct += 1
        if evaluated == 0:
            return 0.0
        # float() keeps true division even under Python 2.
        return correct / float(evaluated)
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.