Skip to content

Instantly share code, notes, and snippets.

@carlosh93
Last active July 13, 2020 15:30
Show Gist options
  • Save carlosh93/d1ab46d395d8483425548569bcba7d32 to your computer and use it in GitHub Desktop.
Save carlosh93/d1ab46d395d8483425548569bcba7d32 to your computer and use it in GitHub Desktop.
HSI Classification #matlab
function [train_SL, test_SL] = data_training(varargin)
% DATA_TRAINING Randomly split the labelled pixels of a label map into
% training and testing sets, sampling each class independently.
%
% Syntax:
%   [train_SL, test_SL] = data_training(data_lbls, num)
%   [train_SL, test_SL] = data_training(data_lbls, num, percFlag)
%
% Input:
%   data_lbls: label map; pixels equal to 0 are unlabelled and every nonzero
%              value is treated as a class label
%   num      : when percFlag is true, the percentage of each class drawn for
%              training (count = round(num/100 * class size), so num = 10
%              means 10%); when percFlag is false, the absolute number of
%              training pixels per class, either a scalar (same count for
%              every class) or a vector with one entry per class in
%              ascending label order
%   percFlag : optional, default true; selects how num is interpreted
%
% Output:
%   train_SL: N-by-2 matrix; column 1 is the linear index into data_lbls,
%             column 2 the class label. Rows are grouped by class in
%             ascending label order and randomly ordered within a class.
%   test_SL : M-by-2 matrix with the same layout containing every labelled
%             pixel not chosen for training, grouped by class and in
%             increasing linear-index order within a class.
%
% A requested count larger than a class's size is clamped to that size.
% The draw uses randperm, so call rng beforehand for reproducible splits.
if nargin < 2
    error('Insufficient number of inputs');
end
data_lbls = varargin{1};
num = varargin{2};
if nargin > 2
    percFlag = varargin{3};
else
    percFlag = true;
end

% Classes actually present (background 0 excluded). Cast to double so the
% label column cannot force integer-typed (saturating) pixel indices when
% the two columns are concatenated below.
classes = double(unique(data_lbls(data_lbls ~= 0)));
no_classes = numel(classes);

% Number of training pixels requested for each class.
if percFlag
    RandSampled_Num = zeros(no_classes, 1);
    for k = 1:no_classes
        RandSampled_Num(k) = round((num/100) * nnz(data_lbls == classes(k)));
    end
elseif isscalar(num)
    RandSampled_Num = repmat(num, no_classes, 1);
elseif numel(num) == no_classes
    RandSampled_Num = num(:);
else
    error('data_training:countMismatch', ...
        'Expected %d per-class training counts, got %d.', no_classes, numel(num));
end

train_parts = cell(no_classes, 1);
test_parts = cell(no_classes, 1);
for k = 1:no_classes
    idx = reshape(find(data_lbls == classes(k)), [], 1);  % ascending linear indices
    perm = randperm(numel(idx));
    % Clamp so an oversized request cannot index past the class size.
    n_train = min(max(RandSampled_Num(k), 0), numel(idx));
    picked = false(numel(idx), 1);
    picked(perm(1:n_train)) = true;
    trn = reshape(idx(perm(1:n_train)), [], 1);   % random order within class
    tst = reshape(idx(~picked), [], 1);           % remaining pixels, index order
    train_parts{k} = [trn, repmat(classes(k), numel(trn), 1)];
    test_parts{k} = [tst, repmat(classes(k), numel(tst), 1)];
end
% The zeros(0,2) seed keeps the N-by-2 shape even when nothing is selected.
train_SL = vertcat(zeros(0, 2), train_parts{:});
test_SL = vertcat(zeros(0, 2), test_parts{:});
end
function rstl = evaluate_results(labs_est, labs)
% Evaluate Classification Results
%
% Syntax:
% rstl = evaluate_results(labs_est, labs)
%
% Input:
% labs_est: M-by-1 vector of labels given to test data
% labs    : M-by-1 vector of ground truth labels
% (row or column vectors are accepted; both are flattened)
%
% Output:
% rstl: struct with fields
%   acc  : c-by-1 accuracy per class, CM(j,j)/sum(CM(:,j)), i.e. the fraction
%          of ground-truth class-j samples predicted as j (NaN when class j
%          has no ground-truth samples)
%   acc_o: overall accuracy, trace(CM)/sum(CM(:))
%   acc_a: average accuracy, mean of acc over classes that have samples
%   kappa: kappa coefficient of agreement
%
% CM(i,j) counts samples predicted as class i whose ground truth is j.
% Labels index CM directly, so only integer labels in 1..c are counted.
labs_est = labs_est(:);
labs = labs(:);
if (max(labs) - min(labs)) ==13
    % Special case kept unchanged: when the ground-truth labels span exactly
    % 13, label 15 is folded into class 1 and c is the number of distinct
    % ground-truth labels.
    % NOTE(review): looks dataset-specific — confirm the intent.
    c = length(unique(labs));
    labs(labs==15)=1;
    labs_est(labs_est==15)=1;
else
    % CM is indexed by label value, so it must extend to the largest label
    % (max - min + 1 under-counts whenever the smallest label is not 1).
    c = double(max(labs));
end
% make confusion matrix
est = double(labs_est);
ref = double(labs);
in_range = est >= 1 & est <= c & est == round(est) & ...
           ref >= 1 & ref <= c & ref == round(ref);
CM = accumarray([est(in_range), ref(in_range)], 1, [c c]);
% Class accuracy (0/0 yields NaN for a class absent from the ground truth)
acc = diag(CM) ./ sum(CM, 1)';
% Overall and average accuracy
acc_o = trace(CM)/sum(CM(:));
acc_a = mean(acc(~isnan(acc)));
% Kappa coefficient of agreement; pe is the expected chance agreement
n_total = sum(CM(:));
pe = (sum(CM,1) * sum(CM,2)) / n_total^2;
kappa = (acc_o - pe)/(1 - pe);
rstl.acc = acc;
rstl.acc_o = acc_o;
rstl.acc_a = acc_a;
rstl.kappa = kappa;
% Load the data
load('Salinas_corrected.mat');
load('Salinas_gt.mat');
Fullimg = salinas_corrected;
Fullimg = Fullimg(:,1:216,:);          % keep only the first 216 columns
data_lbls = salinas_gt(:,1:216);       % ground-truth label map, cropped the same way
clear salinas_corrected salinas_gt
% Cube dimensions taken from the data instead of hard-coded constants
% (512 rows x 216 columns x 204 bands for the cropped Salinas cube).
[Mm, Nm, Lh] = size(Fullimg);
% Fori is the HSI cube as a matrix: one row per pixel (Mm*Nm rows in
% column-major order, matching linear indices into data_lbls) and one
% column per band (Lh columns).
Fori = reshape(Fullimg, Mm*Nm, Lh);
% Training size, passed to data_training as a percentage of each class:
% 0.2 means 0.2% because data_training divides it by 100.
trng = 0.2;
% data_lbls is the ground truth. Column 1 of each returned set holds the
% linear pixel indices.
[trn_set, tst_set] = data_training(data_lbls, trng, true);
trn_indx = trn_set(:,1);   % training pixel indices
tst_indx = tst_set(:,1);   % held-out pixel indices, used for evaluation
%% Classification with full data
% Normalize every pixel spectrum (columns of Fori') once; normc works per
% pixel, so the training rows can be taken from the full normalized matrix.
% The time is added to both classifiers, which previously each normalized.
tic;
Xall = normc(Fori')';
Xtrn = Xall(trn_indx,:);
norm_time = toc;
% kNN (1 nearest neighbour) on the original data
tic;
Md1 = fitcknn(Xtrn, data_lbls(trn_indx), 'NumNeighbors', 1); % train kNN
est_labs = predict(Md1, Xall);   % classify every pixel of the HSI with the trained model
knn_ori_time = norm_time + toc;
% Evaluate only on held-out labelled pixels: including training pixels would
% inflate the scores (a 1-NN classifier typically labels its own training
% pixels correctly).
rstl_ori_knn = evaluate_results(est_labs(tst_indx), data_lbls(tst_indx));
rstl_ori_knn.time = knn_ori_time;
full_labs_ori_knn = reshape(est_labs, Mm, Nm);
% SVM (fitcecoc multiclass model with standardized SVM learners)
tic; % start counting time for SVM
t = templateSVM('Standardize',1,'BoxConstraint',3);
Md1 = fitcecoc(Xtrn, data_lbls(trn_indx), 'Learners', t); % train SVM
est_labs = predict(Md1, Xall);
svmtime = norm_time + toc;
rstl_svm = evaluate_results(est_labs(tst_indx), data_lbls(tst_indx));
full_labs_svm = reshape(est_labs, Mm, Nm);
rstl_svm.time = svmtime;
% Show results
fprintf("Resultados Knn: \n");
display(rstl_ori_knn);
fprintf("Resultados SVM: \n");
display(rstl_svm);
% acc:   per-class accuracy
% acc_o: overall accuracy
% acc_a: average accuracy
% kappa: kappa coefficient
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment