Created
July 20, 2015 21:17
Revisions
-
kingjr created this gist
Jul 20, 2015 .There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -0,0 +1,152 @@ import numpy as np import mne from mne.decoding import GeneralizationAcrossTime as GAT from sklearn.metrics import roc_auc_score from sklearn.preprocessing import StandardScaler from sklearn.svm import SVC from sklearn.pipeline import make_pipeline from sklearn.cross_validation import StratifiedKFold from meeg_preprocessing.utils import setup_provenance import matplotlib.pyplot as plt from pyriemann.estimation import ERPCovariances from pyriemann.tangentspace import TangentSpace report, run_id, results_dir, logger = setup_provenance( __file__, results_dir='results') conditions = ['LSGS', 'LSGD', 'LDGS', 'LDGD'] local_cond, global_cond = [ [['LSGS', 'LSGD'], ['LDGS', 'LDGD']], [['LSGS', 'LDGS'], ['LSGD', 'LDGD']]] class reshape_X(object): def __init__(self, window=1): self.window = window # in time sample def fit(self, X, y=None): return self def fit_transform(self, X, y=None): return self.transform(X) def transform(self, X): if X.shape[1] / self.window != X.shape[1] / float(self.window): raise ValueError('Wrong window size') return X.reshape([X.shape[0], X.shape[1] / self.window, self.window]) class force_predict(object): def __init__(self, clf, mode='predict_proba', axis=0): self._mode = mode self._axis = axis self._clf = clf def fit(self, X, y, **kwargs): self._clf.fit(X, y, **kwargs) self._copyattr() def predict(self, X): if self._mode == 'predict_proba': return self._clf.predict_proba(X)[:, self._axis] elif self._mode == 'decision_function': distances = self._clf.decision_function(X) if len(distances.shape) > 1: return distances[:, self._axis] else: return distances else: return self._clf.predict(X) def get_params(self, deep=True): return dict(clf=self._clf, mode=self._mode, axis=self._axis) def _copyattr(self): for key, value in self._clf.__dict__.iteritems(): self.__setattr__(key, value) class force_weight(object): def __init__(self, clf, weights=None): self._clf = clf def fit(self, X, y): return self._clf.fit(X, np.array(y[:, 0], dtype=int), sample_weight=np.array(y[:, 1])) def predict(self, X): return self._clf.predict(X) def get_params(self, deep=True): return dict(clf=self._clf) # weighted probablistic linear classifier clf = make_pipeline(StandardScaler(), force_predict(force_weight(SVC( kernel='linear', probability=True)), axis=1)) results = list() for subject_name in ['TAJ20081223']: # Preproc epochs = mne.read_epochs('TAJ-epo.fif') # this_epochs = epochs[conditions].crop(0.6, None) this_epochs = epochs[conditions].crop(0.750, None) # Contrast definitions event_id = {v: k for k, v in this_epochs.event_id.items()} y_raw = [event_id[k] for k in this_epochs.events[:, 2]] sample_weight = [1. / y_raw.count(k) for k in y_raw] y_local = [int(v in local_cond[1]) for v in y_raw] y_global = [int(v in global_cond[1]) for v in y_raw] iter_contrast = [[y_local, y_global], ['local', 'global']] # GAT cv = StratifiedKFold(y=y_raw, n_folds=5) # ensure full stratification for y_fit, names_fit in zip(*iter_contrast): window = 10 window_s = window / epochs.info['sfreq'] # in seconds step = 5. / epochs.info['sfreq'] step = window_s # test_times = 'diagonal' # train_times = dict(length=window_s, step=step) window = len(this_epochs.times) train_times = dict(slices=[range(window)], start=this_epochs.times[0], stop=this_epochs.times[-1], times=[this_epochs.times[0]]) test_times = dict(slices=[[range(window)]], start=this_epochs.times[0], stop=this_epochs.times[-1], times=[this_epochs.times[0]]) kwargs = dict(test_times=test_times, scorer=roc_auc_score, cv=cv, train_times=train_times) # # Fit & Score on a single validation svc = force_predict(force_weight(SVC(kernel='linear', probability=True)), axis=1) # clf_classic = make_pipeline(StandardScaler(), svc) # gat = GAT(n_jobs=-1, clf=clf_classic, **kwargs) # gat.fit(this_epochs, y=np.c_[y_fit, sample_weight]) # gat.score(this_epochs) svc = force_predict(SVC(kernel='linear', probability=True), axis=1) clf_rieman = make_pipeline(reshape_X(window=window), ERPCovariances(estimator='lwf', svd=4), TangentSpace(metric='logeuclid'), svc) gat_rieman = GAT(n_jobs=1, clf=clf_rieman, **kwargs) gat_rieman.fit(this_epochs, y=y_fit) score = gat_rieman.score(this_epochs, y=y_fit) print score # Plot # fig, axes = plt.subplots(3) # # fig = gat.plot_diagonal(label='SVC', show=False, chance=.5, color='b', # # ax=axes[0]) # gat_rieman.plot_diagonal(label='Riemann+SVC', ax=axes[0], color='r', # show=False, chance=False) # # gat.plot(show=False, ax=axes[1], title='SVC') # gat_rieman.plot(show=False, ax=axes[2], title='Riemann') # report.add_figs_to_section(fig, names_fit, subject_name) report.save()