Trialwise Decoding

In this example, we will use a convolutional neural network on the PhysioBank EEG Motor Movement/Imagery Dataset to decode two classes:

  1. Executed and imagined opening and closing of both hands
  2. Executed and imagined opening and closing of both feet

We use only one subject (with 90 trials) in this tutorial for demonstration purposes. A more interesting decoding task with many more trials would be cross-subject decoding on the same dataset; code to get you started is shown at the end of this tutorial.

Enable logging

[2]:
import logging
import importlib
importlib.reload(logging) # see https://stackoverflow.com/a/21475297/1469195
log = logging.getLogger()
log.setLevel('INFO')
import sys

logging.basicConfig(format='%(asctime)s %(levelname)s : %(message)s',
                    level=logging.INFO, stream=sys.stdout)

Load data

You can load and preprocess your EEG dataset in any way you like; Braindecode only expects a 3D array of input signals X with dimensions (trials, channels, timesteps) and a vector of labels y (see below). In this tutorial, we will use the MNE library to load an EEG motor imagery/motor execution dataset. For a tutorial from MNE using Common Spatial Patterns to decode this data, see here. For another library useful for loading EEG data, take a look at Neo IO.

[3]:
import mne
from mne.io import concatenate_raws

# 5,6,9,10,13,14 are codes for executed and imagined hands/feet
subject_id = 22 # carefully cherry-picked to give nice results on such limited data :)
event_codes = [5,6,9,10,13,14]
#event_codes = [3,4,5,6,7,8,9,10,11,12,13,14]

# This will download the files if you don't have them yet,
# and then return the paths to the files.
physionet_paths = mne.datasets.eegbci.load_data(subject_id, event_codes)

# Load each of the files
parts = [mne.io.read_raw_edf(path, preload=True, stim_channel='auto', verbose='WARNING')
         for path in physionet_paths]

# Concatenate them
raw = concatenate_raws(parts)

# Find the events in this dataset
events, _ = mne.events_from_annotations(raw)

# Use only EEG channels
eeg_channel_inds = mne.pick_types(raw.info, meg=False, eeg=True, stim=False,
                                  eog=False, exclude='bads')

# Extract trials, only using EEG channels
epoched = mne.Epochs(raw, events, dict(hands_or_left=2, feet_or_right=3),
                     tmin=1, tmax=4.1, proj=False, picks=eeg_channel_inds,
                     baseline=None, preload=True)

Convert data to Braindecode format

Braindecode has a minimalistic SignalAndTarget class, with attributes X for the signal and y for the labels. X should have these dimensions: trials x channels x timesteps. y should have one label per trial.

[4]:
import numpy as np
# Convert data from volt to microvolt
# Pytorch expects float32 for input and int64 for labels.
X = (epoched.get_data() * 1e6).astype(np.float32)
y = (epoched.events[:,2] - 2).astype(np.int64) #2,3 -> 0,1
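
As a quick sanity check, you can inspect the resulting shapes and label values (a small sketch; the exact number of timesteps depends on the epoching window chosen above):

[ ]:
# X: trials x channels x timesteps, y: one label per trial
print(X.shape)                 # e.g. (90, 64, 497) for this subject
print(y.shape, np.unique(y))   # (90,) [0 1]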

We use the first 40 trials for training and the next 30 trials for validation. The validation accuracies can be used to tune hyperparameters such as the learning rate. The final 20 trials are held out as an evaluation set that is not part of any hyperparameter optimization. As mentioned before, this dataset is dangerously small for obtaining meaningful results and is used here only for quick demonstration purposes.

[5]:
from braindecode.datautil.signal_target import SignalAndTarget

train_set = SignalAndTarget(X[:40], y=y[:40])
valid_set = SignalAndTarget(X[40:70], y=y[40:70])

Create the model

Braindecode comes with some predefined convolutional neural network architectures for raw time-domain EEG. Here, we use the shallow ConvNet model from Deep learning with convolutional neural networks for EEG decoding and visualization.

[6]:
from braindecode.models.shallow_fbcsp import ShallowFBCSPNet
from torch import nn
from braindecode.torch_ext.util import set_random_seeds

# Set if you want to use GPU
# You can also use torch.cuda.is_available() to determine if cuda is available on your machine.
cuda = False
set_random_seeds(seed=20170629, cuda=cuda)
n_classes = 2
in_chans = train_set.X.shape[1]
# final_conv_length = auto ensures we only get a single output in the time dimension
model = ShallowFBCSPNet(in_chans=in_chans, n_classes=n_classes,
                        input_time_length=train_set.X.shape[2],
                        final_conv_length='auto')
if cuda:
    model.cuda()

We use AdamW to optimize the parameters of our network together with cosine annealing of the learning rate. We supply some default parameters that we have found to work well for motor decoding; however, we strongly encourage you to perform your own hyperparameter optimization using cross-validation on your training data, as sketched below.
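
As a rough illustration, such a search could look like the following sketch, here over a small hypothetical grid of learning rates. It reuses the compile/fit API that is introduced in the next cells; sklearn's KFold is an assumption for the splitting, and the candidate values are placeholders, not recommendations.

[ ]:
import numpy as np
import torch.nn.functional as F
from sklearn.model_selection import KFold
from braindecode.torch_ext.optimizers import AdamW

candidate_lrs = [1e-2, 0.0625 * 0.01, 1e-4]  # hypothetical grid
for lr in candidate_lrs:
    fold_misclasses = []
    for train_inds, valid_inds in KFold(n_splits=4).split(train_set.X):
        # Re-create and re-compile the model for every fold
        fold_model = ShallowFBCSPNet(in_chans=in_chans, n_classes=n_classes,
                                     input_time_length=train_set.X.shape[2],
                                     final_conv_length='auto')
        fold_model.compile(loss=F.nll_loss,
                           optimizer=AdamW(fold_model.parameters(),
                                           lr=lr, weight_decay=0),
                           iterator_seed=1)
        fold_model.fit(train_set.X[train_inds], train_set.y[train_inds],
                       epochs=30, batch_size=64, scheduler='cosine',
                       validation_data=(train_set.X[valid_inds],
                                        train_set.y[valid_inds]))
        fold_misclasses.append(fold_model.epochs_df.valid_misclass.iloc[-1])
    print('lr {:g}: mean valid misclass {:.2f}'.format(
        lr, np.mean(fold_misclasses)))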

We will now use the Braindecode model class directly to perform the training in a few lines of code. If you instead want to use your own training loop, have a look at the Trialwise Low-Level Tutorial.

[7]:
from braindecode.torch_ext.optimizers import AdamW
import torch.nn.functional as F
#optimizer = AdamW(model.parameters(), lr=1*0.01, weight_decay=0.5*0.001) # these are good values for the deep model
optimizer = AdamW(model.parameters(), lr=0.0625 * 0.01, weight_decay=0)
model.compile(loss=F.nll_loss, optimizer=optimizer, iterator_seed=1)

Run the training

[8]:
model.fit(train_set.X, train_set.y, epochs=30, batch_size=64, scheduler='cosine',
          validation_data=(valid_set.X, valid_set.y))
2019-05-27 13:45:17,834 INFO : Run until first stop...
2019-05-27 13:45:18,483 INFO : Epoch 0
2019-05-27 13:45:18,485 INFO : train_loss                5.34665
2019-05-27 13:45:18,485 INFO : valid_loss                5.13145
2019-05-27 13:45:18,486 INFO : train_misclass            0.47500
2019-05-27 13:45:18,486 INFO : valid_misclass            0.46667
2019-05-27 13:45:18,487 INFO : runtime                   0.00000
2019-05-27 13:45:18,488 INFO :
2019-05-27 13:45:19,574 INFO : Time only for training updates: 1.09s
2019-05-27 13:45:20,319 INFO : Epoch 1
2019-05-27 13:45:20,322 INFO : train_loss                1.18747
2019-05-27 13:45:20,324 INFO : valid_loss                1.44545
2019-05-27 13:45:20,325 INFO : train_misclass            0.45000
2019-05-27 13:45:20,327 INFO : valid_misclass            0.53333
2019-05-27 13:45:20,329 INFO : runtime                   1.74098
2019-05-27 13:45:20,330 INFO :
2019-05-27 13:45:21,350 INFO : Time only for training updates: 1.02s
2019-05-27 13:45:22,071 INFO : Epoch 2
2019-05-27 13:45:22,074 INFO : train_loss                0.97473
2019-05-27 13:45:22,076 INFO : valid_loss                1.21793
2019-05-27 13:45:22,077 INFO : train_misclass            0.40000
2019-05-27 13:45:22,079 INFO : valid_misclass            0.50000
2019-05-27 13:45:22,081 INFO : runtime                   1.77702
2019-05-27 13:45:22,082 INFO :
2019-05-27 13:45:23,099 INFO : Time only for training updates: 1.01s
2019-05-27 13:45:23,833 INFO : Epoch 3
2019-05-27 13:45:23,836 INFO : train_loss                0.67758
2019-05-27 13:45:23,838 INFO : valid_loss                0.90153
2019-05-27 13:45:23,839 INFO : train_misclass            0.30000
2019-05-27 13:45:23,841 INFO : valid_misclass            0.40000
2019-05-27 13:45:23,843 INFO : runtime                   1.74868
2019-05-27 13:45:23,844 INFO :
2019-05-27 13:45:24,854 INFO : Time only for training updates: 1.01s
2019-05-27 13:45:25,586 INFO : Epoch 4
2019-05-27 13:45:25,589 INFO : train_loss                0.44457
2019-05-27 13:45:25,591 INFO : valid_loss                0.67692
2019-05-27 13:45:25,592 INFO : train_misclass            0.27500
2019-05-27 13:45:25,594 INFO : valid_misclass            0.30000
2019-05-27 13:45:25,596 INFO : runtime                   1.75507
2019-05-27 13:45:25,597 INFO :
2019-05-27 13:45:26,624 INFO : Time only for training updates: 1.03s
2019-05-27 13:45:27,406 INFO : Epoch 5
2019-05-27 13:45:27,409 INFO : train_loss                0.29853
2019-05-27 13:45:27,410 INFO : valid_loss                0.55171
2019-05-27 13:45:27,412 INFO : train_misclass            0.20000
2019-05-27 13:45:27,414 INFO : valid_misclass            0.20000
2019-05-27 13:45:27,415 INFO : runtime                   1.77113
2019-05-27 13:45:27,417 INFO :
2019-05-27 13:45:28,837 INFO : Time only for training updates: 1.42s
2019-05-27 13:45:29,965 INFO : Epoch 6
2019-05-27 13:45:29,968 INFO : train_loss                0.21160
2019-05-27 13:45:29,970 INFO : valid_loss                0.50791
2019-05-27 13:45:29,971 INFO : train_misclass            0.10000
2019-05-27 13:45:29,973 INFO : valid_misclass            0.16667
2019-05-27 13:45:29,975 INFO : runtime                   2.21169
2019-05-27 13:45:29,976 INFO :
2019-05-27 13:45:31,564 INFO : Time only for training updates: 1.59s
2019-05-27 13:45:32,295 INFO : Epoch 7
2019-05-27 13:45:32,298 INFO : train_loss                0.13291
2019-05-27 13:45:32,300 INFO : valid_loss                0.46993
2019-05-27 13:45:32,302 INFO : train_misclass            0.05000
2019-05-27 13:45:32,303 INFO : valid_misclass            0.16667
2019-05-27 13:45:32,305 INFO : runtime                   2.72866
2019-05-27 13:45:32,307 INFO :
2019-05-27 13:45:33,327 INFO : Time only for training updates: 1.02s
2019-05-27 13:45:34,071 INFO : Epoch 8
2019-05-27 13:45:34,074 INFO : train_loss                0.09791
2019-05-27 13:45:34,076 INFO : valid_loss                0.45241
2019-05-27 13:45:34,077 INFO : train_misclass            0.00000
2019-05-27 13:45:34,079 INFO : valid_misclass            0.16667
2019-05-27 13:45:34,081 INFO : runtime                   1.76224
2019-05-27 13:45:34,082 INFO :
2019-05-27 13:45:35,169 INFO : Time only for training updates: 1.08s
2019-05-27 13:45:35,935 INFO : Epoch 9
2019-05-27 13:45:35,938 INFO : train_loss                0.07745
2019-05-27 13:45:35,940 INFO : valid_loss                0.45301
2019-05-27 13:45:35,942 INFO : train_misclass            0.00000
2019-05-27 13:45:35,943 INFO : valid_misclass            0.20000
2019-05-27 13:45:35,945 INFO : runtime                   1.84197
2019-05-27 13:45:35,947 INFO :
2019-05-27 13:45:36,983 INFO : Time only for training updates: 1.03s
2019-05-27 13:45:37,707 INFO : Epoch 10
2019-05-27 13:45:37,710 INFO : train_loss                0.06647
2019-05-27 13:45:37,711 INFO : valid_loss                0.45929
2019-05-27 13:45:37,713 INFO : train_misclass            0.00000
2019-05-27 13:45:37,715 INFO : valid_misclass            0.20000
2019-05-27 13:45:37,716 INFO : runtime                   1.81326
2019-05-27 13:45:37,718 INFO :
2019-05-27 13:45:38,732 INFO : Time only for training updates: 1.01s
2019-05-27 13:45:39,462 INFO : Epoch 11
2019-05-27 13:45:39,465 INFO : train_loss                0.05948
2019-05-27 13:45:39,467 INFO : valid_loss                0.47226
2019-05-27 13:45:39,469 INFO : train_misclass            0.00000
2019-05-27 13:45:39,470 INFO : valid_misclass            0.20000
2019-05-27 13:45:39,472 INFO : runtime                   1.74926
2019-05-27 13:45:39,473 INFO :
2019-05-27 13:45:40,488 INFO : Time only for training updates: 1.01s
2019-05-27 13:45:41,211 INFO : Epoch 12
2019-05-27 13:45:41,214 INFO : train_loss                0.05326
2019-05-27 13:45:41,215 INFO : valid_loss                0.48799
2019-05-27 13:45:41,217 INFO : train_misclass            0.00000
2019-05-27 13:45:41,219 INFO : valid_misclass            0.20000
2019-05-27 13:45:41,220 INFO : runtime                   1.75707
2019-05-27 13:45:41,222 INFO :
2019-05-27 13:45:42,236 INFO : Time only for training updates: 1.01s
2019-05-27 13:45:42,997 INFO : Epoch 13
2019-05-27 13:45:43,000 INFO : train_loss                0.04709
2019-05-27 13:45:43,001 INFO : valid_loss                0.50407
2019-05-27 13:45:43,003 INFO : train_misclass            0.00000
2019-05-27 13:45:43,005 INFO : valid_misclass            0.20000
2019-05-27 13:45:43,007 INFO : runtime                   1.74734
2019-05-27 13:45:43,009 INFO :
2019-05-27 13:45:44,112 INFO : Time only for training updates: 1.10s
2019-05-27 13:45:44,842 INFO : Epoch 14
2019-05-27 13:45:44,844 INFO : train_loss                0.04247
2019-05-27 13:45:44,846 INFO : valid_loss                0.50820
2019-05-27 13:45:44,848 INFO : train_misclass            0.00000
2019-05-27 13:45:44,849 INFO : valid_misclass            0.20000
2019-05-27 13:45:44,851 INFO : runtime                   1.87550
2019-05-27 13:45:44,853 INFO :
2019-05-27 13:45:45,869 INFO : Time only for training updates: 1.01s
2019-05-27 13:45:46,605 INFO : Epoch 15
2019-05-27 13:45:46,608 INFO : train_loss                0.03919
2019-05-27 13:45:46,610 INFO : valid_loss                0.51205
2019-05-27 13:45:46,611 INFO : train_misclass            0.00000
2019-05-27 13:45:46,613 INFO : valid_misclass            0.20000
2019-05-27 13:45:46,615 INFO : runtime                   1.75688
2019-05-27 13:45:46,616 INFO :
2019-05-27 13:45:47,633 INFO : Time only for training updates: 1.01s
2019-05-27 13:45:48,365 INFO : Epoch 16
2019-05-27 13:45:48,368 INFO : train_loss                0.03620
2019-05-27 13:45:48,370 INFO : valid_loss                0.51271
2019-05-27 13:45:48,371 INFO : train_misclass            0.00000
2019-05-27 13:45:48,373 INFO : valid_misclass            0.20000
2019-05-27 13:45:48,374 INFO : runtime                   1.76364
2019-05-27 13:45:48,376 INFO :
2019-05-27 13:45:49,392 INFO : Time only for training updates: 1.01s
2019-05-27 13:45:50,111 INFO : Epoch 17
2019-05-27 13:45:50,114 INFO : train_loss                0.03297
2019-05-27 13:45:50,116 INFO : valid_loss                0.50701
2019-05-27 13:45:50,117 INFO : train_misclass            0.00000
2019-05-27 13:45:50,119 INFO : valid_misclass            0.16667
2019-05-27 13:45:50,121 INFO : runtime                   1.75996
2019-05-27 13:45:50,123 INFO :
2019-05-27 13:45:51,134 INFO : Time only for training updates: 1.01s
2019-05-27 13:45:51,858 INFO : Epoch 18
2019-05-27 13:45:51,860 INFO : train_loss                0.03033
2019-05-27 13:45:51,862 INFO : valid_loss                0.50634
2019-05-27 13:45:51,864 INFO : train_misclass            0.00000
2019-05-27 13:45:51,865 INFO : valid_misclass            0.16667
2019-05-27 13:45:51,867 INFO : runtime                   1.74143
2019-05-27 13:45:51,869 INFO :
2019-05-27 13:45:52,946 INFO : Time only for training updates: 1.08s
2019-05-27 13:45:53,671 INFO : Epoch 19
2019-05-27 13:45:53,674 INFO : train_loss                0.02815
2019-05-27 13:45:53,675 INFO : valid_loss                0.50780
2019-05-27 13:45:53,677 INFO : train_misclass            0.00000
2019-05-27 13:45:53,679 INFO : valid_misclass            0.16667
2019-05-27 13:45:53,680 INFO : runtime                   1.81164
2019-05-27 13:45:53,682 INFO :
2019-05-27 13:45:54,690 INFO : Time only for training updates: 1.01s
2019-05-27 13:45:55,414 INFO : Epoch 20
2019-05-27 13:45:55,417 INFO : train_loss                0.02614
2019-05-27 13:45:55,418 INFO : valid_loss                0.51158
2019-05-27 13:45:55,420 INFO : train_misclass            0.00000
2019-05-27 13:45:55,422 INFO : valid_misclass            0.16667
2019-05-27 13:45:55,424 INFO : runtime                   1.74491
2019-05-27 13:45:55,425 INFO :
2019-05-27 13:45:56,435 INFO : Time only for training updates: 1.01s
2019-05-27 13:45:57,163 INFO : Epoch 21
2019-05-27 13:45:57,166 INFO : train_loss                0.02446
2019-05-27 13:45:57,168 INFO : valid_loss                0.51534
2019-05-27 13:45:57,169 INFO : train_misclass            0.00000
2019-05-27 13:45:57,171 INFO : valid_misclass            0.16667
2019-05-27 13:45:57,173 INFO : runtime                   1.74495
2019-05-27 13:45:57,174 INFO :
2019-05-27 13:45:58,181 INFO : Time only for training updates: 1.01s
2019-05-27 13:45:58,907 INFO : Epoch 22
2019-05-27 13:45:58,910 INFO : train_loss                0.02319
2019-05-27 13:45:58,911 INFO : valid_loss                0.51801
2019-05-27 13:45:58,913 INFO : train_misclass            0.00000
2019-05-27 13:45:58,915 INFO : valid_misclass            0.16667
2019-05-27 13:45:58,916 INFO : runtime                   1.74666
2019-05-27 13:45:58,918 INFO :
2019-05-27 13:45:59,961 INFO : Time only for training updates: 1.04s
2019-05-27 13:46:00,709 INFO : Epoch 23
2019-05-27 13:46:00,711 INFO : train_loss                0.02210
2019-05-27 13:46:00,713 INFO : valid_loss                0.52047
2019-05-27 13:46:00,715 INFO : train_misclass            0.00000
2019-05-27 13:46:00,716 INFO : valid_misclass            0.16667
2019-05-27 13:46:00,718 INFO : runtime                   1.77873
2019-05-27 13:46:00,720 INFO :
2019-05-27 13:46:01,750 INFO : Time only for training updates: 1.03s
2019-05-27 13:46:02,475 INFO : Epoch 24
2019-05-27 13:46:02,477 INFO : train_loss                0.02131
2019-05-27 13:46:02,479 INFO : valid_loss                0.52308
2019-05-27 13:46:02,481 INFO : train_misclass            0.00000
2019-05-27 13:46:02,482 INFO : valid_misclass            0.16667
2019-05-27 13:46:02,484 INFO : runtime                   1.78904
2019-05-27 13:46:02,486 INFO :
2019-05-27 13:46:03,495 INFO : Time only for training updates: 1.01s
2019-05-27 13:46:04,222 INFO : Epoch 25
2019-05-27 13:46:04,225 INFO : train_loss                0.02070
2019-05-27 13:46:04,227 INFO : valid_loss                0.52473
2019-05-27 13:46:04,228 INFO : train_misclass            0.00000
2019-05-27 13:46:04,230 INFO : valid_misclass            0.16667
2019-05-27 13:46:04,232 INFO : runtime                   1.74557
2019-05-27 13:46:04,233 INFO :
2019-05-27 13:46:05,245 INFO : Time only for training updates: 1.01s
2019-05-27 13:46:05,973 INFO : Epoch 26
2019-05-27 13:46:05,976 INFO : train_loss                0.02029
2019-05-27 13:46:05,978 INFO : valid_loss                0.52578
2019-05-27 13:46:05,979 INFO : train_misclass            0.00000
2019-05-27 13:46:05,981 INFO : valid_misclass            0.16667
2019-05-27 13:46:05,983 INFO : runtime                   1.74980
2019-05-27 13:46:05,984 INFO :
2019-05-27 13:46:07,009 INFO : Time only for training updates: 1.02s
2019-05-27 13:46:07,731 INFO : Epoch 27
2019-05-27 13:46:07,734 INFO : train_loss                0.02002
2019-05-27 13:46:07,736 INFO : valid_loss                0.52576
2019-05-27 13:46:07,737 INFO : train_misclass            0.00000
2019-05-27 13:46:07,739 INFO : valid_misclass            0.16667
2019-05-27 13:46:07,741 INFO : runtime                   1.76430
2019-05-27 13:46:07,742 INFO :
2019-05-27 13:46:08,749 INFO : Time only for training updates: 1.01s
2019-05-27 13:46:09,489 INFO : Epoch 28
2019-05-27 13:46:09,491 INFO : train_loss                0.01986
2019-05-27 13:46:09,493 INFO : valid_loss                0.52516
2019-05-27 13:46:09,495 INFO : train_misclass            0.00000
2019-05-27 13:46:09,496 INFO : valid_misclass            0.16667
2019-05-27 13:46:09,498 INFO : runtime                   1.73994
2019-05-27 13:46:09,500 INFO :
2019-05-27 13:46:10,512 INFO : Time only for training updates: 1.01s
2019-05-27 13:46:11,233 INFO : Epoch 29
2019-05-27 13:46:11,236 INFO : train_loss                0.01978
2019-05-27 13:46:11,238 INFO : valid_loss                0.52423
2019-05-27 13:46:11,239 INFO : train_misclass            0.00000
2019-05-27 13:46:11,241 INFO : valid_misclass            0.16667
2019-05-27 13:46:11,243 INFO : runtime                   1.76238
2019-05-27 13:46:11,244 INFO :
2019-05-27 13:46:12,273 INFO : Time only for training updates: 1.03s
2019-05-27 13:46:12,997 INFO : Epoch 30
2019-05-27 13:46:13,000 INFO : train_loss                0.01976
2019-05-27 13:46:13,002 INFO : valid_loss                0.52316
2019-05-27 13:46:13,004 INFO : train_misclass            0.00000
2019-05-27 13:46:13,005 INFO : valid_misclass            0.16667
2019-05-27 13:46:13,007 INFO : runtime                   1.76154
2019-05-27 13:46:13,009 INFO :
[8]:
<braindecode.experiments.experiment.Experiment at 0x7fb488906e80>

The monitored values are also stored in a pandas dataframe:

[9]:
model.epochs_df
[9]:
train_loss valid_loss train_misclass valid_misclass runtime
0 5.346649 5.131454 0.475 0.466667 0.000000
1 1.187472 1.445452 0.450 0.533333 1.740984
2 0.974734 1.217927 0.400 0.500000 1.777023
3 0.677576 0.901526 0.300 0.400000 1.748683
4 0.444574 0.676924 0.275 0.300000 1.755075
5 0.298528 0.551715 0.200 0.200000 1.771130
6 0.211601 0.507913 0.100 0.166667 2.211694
7 0.132909 0.469927 0.050 0.166667 2.728657
8 0.097908 0.452408 0.000 0.166667 1.762245
9 0.077446 0.453012 0.000 0.200000 1.841966
10 0.066468 0.459292 0.000 0.200000 1.813261
11 0.059477 0.472259 0.000 0.200000 1.749265
12 0.053263 0.487993 0.000 0.200000 1.757071
13 0.047088 0.504074 0.000 0.200000 1.747344
14 0.042468 0.508204 0.000 0.200000 1.875502
15 0.039186 0.512052 0.000 0.200000 1.756878
16 0.036201 0.512707 0.000 0.200000 1.763641
17 0.032971 0.507012 0.000 0.166667 1.759965
18 0.030331 0.506342 0.000 0.166667 1.741430
19 0.028149 0.507796 0.000 0.166667 1.811639
20 0.026143 0.511578 0.000 0.166667 1.744909
21 0.024460 0.515335 0.000 0.166667 1.744946
22 0.023193 0.518014 0.000 0.166667 1.746660
23 0.022102 0.520473 0.000 0.166667 1.778730
24 0.021306 0.523084 0.000 0.166667 1.789039
25 0.020703 0.524732 0.000 0.166667 1.745571
26 0.020286 0.525779 0.000 0.166667 1.749797
27 0.020023 0.525761 0.000 0.166667 1.764305
28 0.019861 0.525160 0.000 0.166667 1.739945
29 0.019783 0.524232 0.000 0.166667 1.762381
30 0.019764 0.523157 0.000 0.166667 1.761538
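
To inspect the training behavior, you can plot these learning curves directly from the dataframe. A quick sketch, assuming matplotlib is available (it is not used elsewhere in this tutorial):

[ ]:
import matplotlib.pyplot as plt

# Plot train and valid loss per epoch from the monitored values
model.epochs_df[['train_loss', 'valid_loss']].plot()
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.show()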

Eventually, we arrive at 83.3% validation accuracy, i.e., 25 of the 30 validation trials are correctly predicted. The Cropped Decoding Tutorial shows how to achieve higher accuracies using cropped training.

Evaluation

Once we have made all our hyperparameter and architectural choices, we can evaluate on the held-out test set to obtain the accuracies we would report in a publication:

[10]:
test_set = SignalAndTarget(X[70:], y=y[70:])

model.evaluate(test_set.X, test_set.y)
[10]:
{'loss': 0.43049120903015137,
 'misclass': 0.19999999999999996,
 'runtime': 0.00031256675720214844}

A misclass of 0.2 corresponds to 80% test accuracy, i.e., 16 of the 20 test trials are correctly predicted. We can also retrieve the predicted labels per trial:

[11]:
model.predict_classes(test_set.X)
[11]:
array([1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 1, 0, 1, 0, 0, 0, 0, 0, 0])
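
As a sanity check, these predictions can be compared against the true labels to reproduce the misclass reported by evaluate (a small sketch using numpy):

[ ]:
import numpy as np

preds = model.predict_classes(test_set.X)
accuracy = np.mean(preds == test_set.y)
print('Test accuracy: {:.1%}'.format(accuracy))  # 1 - misclass, i.e. 80% here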

We can also retrieve the raw network outputs per trial:

Note that these are log-softmax outputs, so to get probabilities you would have to exponentiate them (e.g., with numpy.exp, since they are returned as a numpy array).

[12]:
model.predict_outs(test_set.X)
[12]:
array([[-3.108049  , -0.0457173 ],
       [-0.18737975, -1.7668451 ],
       [-3.550216  , -0.02913891],
       [-0.00889281, -4.7269597 ],
       [-0.03029956, -3.5117326 ],
       [-0.00847233, -4.775178  ],
       [-4.006974  , -0.01835574],
       [-0.4073605 , -1.0948324 ],
       [-0.02217731, -3.819753  ],
       [-0.22672828, -1.5952262 ],
       [-3.5868406 , -0.02807612],
       [-1.3834822 , -0.28862125],
       [-0.32644472, -1.2782807 ],
       [-1.3229265 , -0.30972955],
       [-0.08954818, -2.4574184 ],
       [-0.0186951 , -3.9888284 ],
       [-0.09142652, -2.4375842 ],
       [-0.24392553, -1.5303771 ],
       [-0.03591001, -3.3446407 ],
       [-0.16686489, -1.8728433 ]], dtype=float32)
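
For example, to turn these log-probabilities into probabilities (a quick sketch; predict_outs returns a numpy array here, so numpy.exp applies):

[ ]:
import numpy as np

log_probs = model.predict_outs(test_set.X)
probs = np.exp(log_probs)
print(probs.sum(axis=1))  # each row should sum to approximately 1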

If you want to try cross-subject decoding, changing the loading code to the following will perform cross-subject decoding on imagined left vs. right hand closing, with 50 training subjects and 5 validation subjects (warning: this may be very slow on a CPU):

[ ]:
import mne
import numpy as np
from mne.io import concatenate_raws
from braindecode.datautil.signal_target import SignalAndTarget

# First 50 subjects as train
physionet_paths = [mne.datasets.eegbci.load_data(sub_id, [4, 8, 12])
                   for sub_id in range(1, 51)]
physionet_paths = np.concatenate(physionet_paths)
parts = [mne.io.read_raw_edf(path, preload=True, stim_channel='auto')
         for path in physionet_paths]

raw = concatenate_raws(parts)

picks = mne.pick_types(raw.info, meg=False, eeg=True, stim=False, eog=False,
                       exclude='bads')

# Find the events in this dataset
events, _ = mne.events_from_annotations(raw)

# Extract trials between 1 s and 4.1 s after cue onset,
# only using the EEG channels
epoched = mne.Epochs(raw, events, dict(left_hand=2, right_hand=3),
                     tmin=1, tmax=4.1, proj=False, picks=picks,
                     baseline=None, preload=True)

# 51-55 as validation subjects
physionet_paths_valid = [mne.datasets.eegbci.load_data(sub_id, [4, 8, 12])
                         for sub_id in range(51, 56)]
physionet_paths_valid = np.concatenate(physionet_paths_valid)
parts_valid = [mne.io.read_raw_edf(path, preload=True, stim_channel='auto')
               for path in physionet_paths_valid]
raw_valid = concatenate_raws(parts_valid)

picks_valid = mne.pick_types(raw_valid.info, meg=False, eeg=True, stim=False,
                             eog=False, exclude='bads')

# Find the events in the validation dataset
events_valid, _ = mne.events_from_annotations(raw_valid)

# Extract validation trials between 1 s and 4.1 s after cue onset
epoched_valid = mne.Epochs(raw_valid, events_valid, dict(left_hand=2, right_hand=3),
                           tmin=1, tmax=4.1, proj=False, picks=picks_valid,
                           baseline=None, preload=True)

train_X = (epoched.get_data() * 1e6).astype(np.float32)
train_y = (epoched.events[:,2] - 2).astype(np.int64) #2,3 -> 0,1
valid_X = (epoched_valid.get_data() * 1e6).astype(np.float32)
valid_y = (epoched_valid.events[:,2] - 2).astype(np.int64) #2,3 -> 0,1
train_set = SignalAndTarget(train_X, y=train_y)
valid_set = SignalAndTarget(valid_X, y=valid_y)
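
The model creation, compilation, and training code from above can then be reused unchanged with these larger train_set and valid_set.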

Dataset references

This dataset was created and contributed to PhysioNet by the developers of the BCI2000 instrumentation system, which they used in making these recordings. The system is described in:

Schalk, G., McFarland, D.J., Hinterberger, T., Birbaumer, N., Wolpaw, J.R. (2004) BCI2000: A General-Purpose Brain-Computer Interface (BCI) System. IEEE TBME 51(6):1034-1043.

PhysioBank is a large and growing archive of well-characterized digital recordings of physiologic signals and related data for use by the biomedical research community. It is further described in:

Goldberger AL, Amaral LAN, Glass L, Hausdorff JM, Ivanov PCh, Mark RG, Mietus JE, Moody GB, Peng C-K, Stanley HE. (2000) PhysioBank, PhysioToolkit, and PhysioNet: Components of a New Research Resource for Complex Physiologic Signals. Circulation 101(23):e215-e220.