Created April 15, 2016 12:25
-
-
Save g2384/4878b216f3437478e986682f6d6ccdf8 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
%% Random-forest training driver
% Builds ntrees trees with grow() and stores them in model. Shared
% parameters reach the local functions below through global variables.
% NOTE: local functions placed after script code require MATLAB R2016b or
% later; GNU Octave requires functions in a script to be defined before
% their first use.

ntrees = 2;      % number of trees
ratios = 1;      % ratio (currently unused here)

global min_sample;
min_sample = 2;  % minimum number of samples; grow() only recurses into larger child sets
global max_depth;
max_depth = 15;  % maximum depth
global num_sample;
num_sample = 10; % total number of samples/pixels; here only the grow() pass count below
global num_image;
num_image = 10;  % number of random candidate tests tried per node in optimizeTest
global max_w;
max_w = 100;     % image width
global max_h;
max_h = 100;     % image height
global max_c;
max_c = 2;       % patch size (upper bound of the fifth random test parameter)
global num_leaf;
num_leaf = 0;    % running leaf id: makeLeaf stores it, grow increments it

model = cell(1, ntrees); % forest: one tree per cell
samples = 1;
%model = train(0,0,0,0);

%% function train(depth_part,label_part,coordinate,depth_whole)
for i = 1:ntrees
    tree = [];
    % function growTree()
    for l = 1:num_sample
        fprintf('hello\n')
        % NOTE(review): every pass regrows from the same samples starting at
        % the root node and writes into the same tree array - confirm intent.
        tree = grow(samples, 1, 1, tree);
    end
    model{i} = tree;
end

%% ------------------------- local functions -------------------------
function tree = grow(samples, node, depth, tree)
% GROW Recursively build the subtree rooted at index NODE of TREE.
%   tree = grow(samples, node, depth, tree)
%
%   samples - the samples that reach this node
%   node    - index of this node in TREE; its children are stored at
%             2*node+1 and 2*node+2 (with the driver's root index 1,
%             tree(2) is never written)
%   depth   - depth of this node (the driver script starts at 1)
%   tree    - array being filled in; returned with this subtree added
%
%   While depth < max_depth - 2, a split is searched with optimizeTest.
%   If one is found, the test is stored at tree(node). Each side is then
%   grown recursively when it holds more than min_sample samples, and made
%   a leaf otherwise. If no split is found, or the depth limit is reached,
%   this node itself becomes a leaf. Every leaf stores the current value of
%   the global counter num_leaf, which is then incremented.
global min_sample
global max_depth
global num_leaf

n_samples = numel(samples);

if depth < max_depth - 2
    % 0 = classification (info gain), 1 = regression; chosen at random.
    % round(rand()) instead of round(rand(1),0): Octave has no 2-arg round.
    measure_mode = round(rand());
    fprintf('MeasureMode %d, depth %d\n', measure_mode, depth);

    % Multiple outputs need brackets. "t, A, B, test = f(...)" only displays
    % t, A and B and assigns the FIRST output of f to test.
    [t, A, B, test] = optimizeTest(samples, n_samples, measure_mode);

    if isscalar(t) && t == 1
        % NOTE(review): test is a vector (random test parameters plus a
        % threshold), but tree(node) holds a single number, so this
        % assignment errors when numel(test) > 1. Storing split tests needs
        % a tree representation (e.g. a cell array) shared with makeLeaf
        % and the driver script.
        tree(node) = test;
        fprintf('Final Split: A - %d, B - %d\n', numel(A), numel(B));

        % Left child
        if numel(A) > min_sample
            tree = grow(A, 2*node + 1, depth + 1, tree);
        else
            tree = makeLeaf(A, 2*node + 1, tree, num_leaf);
            num_leaf = num_leaf + 1;
        end

        % Right child
        if numel(B) > min_sample
            tree = grow(B, 2*node + 2, depth + 1, tree);
        else
            tree = makeLeaf(B, 2*node + 2, tree, num_leaf);
            num_leaf = num_leaf + 1;
        end
        return;
    end
end

% No split was made (none found, or depth limit reached): this node itself
% becomes a leaf holding all of its samples.
tree = makeLeaf(samples, node, tree, num_leaf);
num_leaf = num_leaf + 1;
end
function [t, A, B, test] = optimizeTest(samples, size_samples, measure_mode)
% OPTIMIZETEST Search random tests for the lowest-entropy split of SAMPLES.
%   [t, A, B, test] = optimizeTest(samples, size_samples, measure_mode)
%
%   Each of num_image candidate tests has five random parameters: four
%   coordinates drawn from 0..max_w / 0..max_h and a fifth drawn from
%   0..max_c. For each candidate, a value is computed per sample and 10
%   random thresholds within the observed value range are tried. A
%   threshold splits samples into A (value < threshold) and B (the rest).
%   The split with the lowest size-weighted smoothed entropy is kept.
%
%   Outputs:
%     t    - 1 if a split with both sides non-empty was found, else 0
%     A, B - the best split (empty when t == 0)
%     test - [candidate parameters, threshold] of the best split
%            (empty when t == 0)
%
%   size_samples and measure_mode are kept for interface compatibility.
%   The sample count is taken from numel(samples), and only the
%   classification (info-gain) measure is implemented.
global max_w;
global max_h;
global max_c;
global num_image;

t = 0;
A = [];
B = [];
test = [];
bestDist = Inf;
n = numel(samples);

for i = 1:num_image
    % generateTest()
    cand = [round(rand()*max_w), round(rand()*max_h), ...
            round(rand()*max_w), round(rand()*max_h), round(rand()*max_c)];

    % evaluateTest()
    idx1 = max_w*cand(1) + cand(2);
    idx2 = max_w*cand(3) + cand(4);
    % Skip candidates whose linear indices fall outside samples
    % (e.g. index 0), which would otherwise raise an indexing error.
    if idx1 < 1 || idx1 > n || idx2 < 1 || idx2 > n
        continue;
    end
    values = zeros(1, n);
    for l = 1:n
        % NOTE(review): this value does not depend on l, so every sample
        % gets the same value and d below is always 0 (no split is ever
        % found). A per-sample lookup is presumably intended - confirm
        % against the sample representation.
        values(l) = samples(idx1) - samples(idx2);
    end
    vmin = min(values);
    vmax = max(values);
    d = vmax - vmin;
    if d <= 0
        continue;
    end

    for j = 1:10
        % Random threshold in [vmin, vmax] (rand(d) would be a d-by-d matrix).
        tr = vmin + round(rand()*d);

        % split()
        goLeft = values < tr;
        candA = samples(goLeft);
        candB = samples(~goLeft);
        nA = numel(candA);
        nB = numel(candB);
        if nA == 0 || nB == 0
            continue;
        end

        % measureSet(): info gain, i.e. minimise the weighted entropy.
        tmpDist = (nA*smoothedEntropy(candA) + nB*smoothedEntropy(candB)) / (nA + nB);
        if tmpDist < bestDist
            t = 1;
            bestDist = tmpDist;
            A = candA;
            B = candB;
            test = [cand, tr];
        end
    end
end
end

function h = smoothedEntropy(x)
% Entropy of the distribution of values in x, with one pseudo-count per distinct value.
counts = histc(x, unique(x)) + 1;
p = counts / sum(counts);
h = -sum(p .* log(p));
end
function tree = makeLeaf(leaf_samples, node, tree, num_leaf)
% MAKELEAF Mark index NODE of TREE as a leaf by storing the leaf id NUM_LEAF.
%   tree = makeLeaf(leaf_samples, node, tree, num_leaf)
%   leaf_samples is currently unused; no per-leaf statistics are stored.
%   (Renamed from "set" to avoid shadowing the built-in SET function.)
tree(node) = num_leaf;
end
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.