Skip to content

Instantly share code, notes, and snippets.

@kingjr
Created July 20, 2015 21:17

Revisions

  1. kingjr created this gist Jul 20, 2015.
    152 changes: 152 additions & 0 deletions riemann_taj.py
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,152 @@
    import numpy as np
    import mne
    from mne.decoding import GeneralizationAcrossTime as GAT
    from sklearn.metrics import roc_auc_score
    from sklearn.preprocessing import StandardScaler
    from sklearn.svm import SVC
    from sklearn.pipeline import make_pipeline
    from sklearn.cross_validation import StratifiedKFold
    from meeg_preprocessing.utils import setup_provenance

    import matplotlib.pyplot as plt

    from pyriemann.estimation import ERPCovariances
    from pyriemann.tangentspace import TangentSpace


    report, run_id, results_dir, logger = setup_provenance(
    __file__, results_dir='results')


    conditions = ['LSGS', 'LSGD', 'LDGS', 'LDGD']
    local_cond, global_cond = [
    [['LSGS', 'LSGD'], ['LDGS', 'LDGD']],
    [['LSGS', 'LDGS'], ['LSGD', 'LDGD']]]


    class reshape_X(object):
    def __init__(self, window=1):
    self.window = window # in time sample

    def fit(self, X, y=None):
    return self

    def fit_transform(self, X, y=None):
    return self.transform(X)

    def transform(self, X):
    if X.shape[1] / self.window != X.shape[1] / float(self.window):
    raise ValueError('Wrong window size')
    return X.reshape([X.shape[0], X.shape[1] / self.window, self.window])


    class force_predict(object):
    def __init__(self, clf, mode='predict_proba', axis=0):
    self._mode = mode
    self._axis = axis
    self._clf = clf

    def fit(self, X, y, **kwargs):
    self._clf.fit(X, y, **kwargs)
    self._copyattr()

    def predict(self, X):
    if self._mode == 'predict_proba':
    return self._clf.predict_proba(X)[:, self._axis]
    elif self._mode == 'decision_function':
    distances = self._clf.decision_function(X)
    if len(distances.shape) > 1:
    return distances[:, self._axis]
    else:
    return distances
    else:
    return self._clf.predict(X)

    def get_params(self, deep=True):
    return dict(clf=self._clf, mode=self._mode, axis=self._axis)

    def _copyattr(self):
    for key, value in self._clf.__dict__.iteritems():
    self.__setattr__(key, value)


    class force_weight(object):
    def __init__(self, clf, weights=None):
    self._clf = clf

    def fit(self, X, y):
    return self._clf.fit(X, np.array(y[:, 0], dtype=int),
    sample_weight=np.array(y[:, 1]))

    def predict(self, X):
    return self._clf.predict(X)

    def get_params(self, deep=True):
    return dict(clf=self._clf)

    # weighted probablistic linear classifier
    clf = make_pipeline(StandardScaler(),
    force_predict(force_weight(SVC(
    kernel='linear', probability=True)), axis=1))

    results = list()
    for subject_name in ['TAJ20081223']:
    # Preproc
    epochs = mne.read_epochs('TAJ-epo.fif')
    # this_epochs = epochs[conditions].crop(0.6, None)
    this_epochs = epochs[conditions].crop(0.750, None)

    # Contrast definitions
    event_id = {v: k for k, v in this_epochs.event_id.items()}
    y_raw = [event_id[k] for k in this_epochs.events[:, 2]]
    sample_weight = [1. / y_raw.count(k) for k in y_raw]
    y_local = [int(v in local_cond[1]) for v in y_raw]
    y_global = [int(v in global_cond[1]) for v in y_raw]
    iter_contrast = [[y_local, y_global], ['local', 'global']]

    # GAT
    cv = StratifiedKFold(y=y_raw, n_folds=5) # ensure full stratification
    for y_fit, names_fit in zip(*iter_contrast):
    window = 10
    window_s = window / epochs.info['sfreq'] # in seconds
    step = 5. / epochs.info['sfreq']
    step = window_s
    # test_times = 'diagonal'
    # train_times = dict(length=window_s, step=step)
    window = len(this_epochs.times)
    train_times = dict(slices=[range(window)], start=this_epochs.times[0],
    stop=this_epochs.times[-1],
    times=[this_epochs.times[0]])
    test_times = dict(slices=[[range(window)]], start=this_epochs.times[0],
    stop=this_epochs.times[-1],
    times=[this_epochs.times[0]])

    kwargs = dict(test_times=test_times, scorer=roc_auc_score, cv=cv,
    train_times=train_times)

    # # Fit & Score on a single validation
    svc = force_predict(force_weight(SVC(kernel='linear', probability=True)), axis=1)
    # clf_classic = make_pipeline(StandardScaler(), svc)
    # gat = GAT(n_jobs=-1, clf=clf_classic, **kwargs)
    # gat.fit(this_epochs, y=np.c_[y_fit, sample_weight])
    # gat.score(this_epochs)

    svc = force_predict(SVC(kernel='linear', probability=True), axis=1)
    clf_rieman = make_pipeline(reshape_X(window=window),
    ERPCovariances(estimator='lwf', svd=4),
    TangentSpace(metric='logeuclid'), svc)
    gat_rieman = GAT(n_jobs=1, clf=clf_rieman, **kwargs)
    gat_rieman.fit(this_epochs, y=y_fit)
    score = gat_rieman.score(this_epochs, y=y_fit)
    print score
    # Plot
    # fig, axes = plt.subplots(3)
    # # fig = gat.plot_diagonal(label='SVC', show=False, chance=.5, color='b',
    # # ax=axes[0])
    # gat_rieman.plot_diagonal(label='Riemann+SVC', ax=axes[0], color='r',
    # show=False, chance=False)
    # # gat.plot(show=False, ax=axes[1], title='SVC')
    # gat_rieman.plot(show=False, ax=axes[2], title='Riemann')
    # report.add_figs_to_section(fig, names_fit, subject_name)

    report.save()