Cropped Manual Training Loop

Here, we show cropped decoding for the case where you want to write your own training loop. For simpler code with a predefined training loop, and for a general explanation of cropped decoding, see the Cropped Decoding Tutorial.

Most of the code for cropped decoding is identical to the Trialwise Manual Training Loop Tutorial; the differences are explained in the text.

Load data

[2]:
import mne
from mne.io import concatenate_raws

# Runs 5, 6, 9, 10, 13, 14 contain executed and imagined hands/feet movements
subject_id = 22 # carefully cherry-picked to give nice results on such limited data :)
event_codes = [5,6,9,10,13,14]

# This will download the files if you don't have them yet,
# and then return the paths to the files.
physionet_paths = mne.datasets.eegbci.load_data(subject_id, event_codes)

# Load each of the files
parts = [mne.io.read_raw_edf(path, preload=True,stim_channel='auto', verbose='WARNING')
         for path in physionet_paths]

# Concatenate them
raw = concatenate_raws(parts)

# Find the events in this dataset
events, _ = mne.events_from_annotations(raw)

# Use only EEG channels
eeg_channel_inds = mne.pick_types(raw.info, meg=False, eeg=True, stim=False, eog=False,
                                  exclude='bads')

# Extract trials, only using EEG channels
epoched = mne.Epochs(raw, events, dict(hands=2, feet=3), tmin=1, tmax=4.1, proj=False,
                     picks=eeg_channel_inds, baseline=None, preload=True)

Convert data to Braindecode format

[3]:
import numpy as np
from braindecode.datautil.signal_target import SignalAndTarget
# Convert data from volts to microvolts
# PyTorch expects float32 for input and int64 for labels.
X = (epoched.get_data() * 1e6).astype(np.float32)
y = (epoched.events[:,2] - 2).astype(np.int64) #2,3 -> 0,1

train_set = SignalAndTarget(X[:40], y=y[:40])
valid_set = SignalAndTarget(X[40:70], y=y[40:70])

Create the model

For cropped decoding, we transform the model so that it outputs a dense time series of predictions. To do this, we manually set the length of the final convolution layer (final_conv_length) to a value that makes the receptive field of the ConvNet smaller than the number of samples in a trial. We also call to_dense_prediction_model, which removes the strides in the ConvNet and uses dilated convolutions instead to obtain a dense output (see Multi-Scale Context Aggregation by Dilated Convolutions and Section 2.5.4 of our paper Deep learning with convolutional neural networks for EEG decoding and visualization for background).
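To see why final_conv_length=12 gives a receptive field that fits inside a trial, the arithmetic can be sketched in plain Python. This is a minimal sketch that assumes ShallowFBCSPNet's default temporal hyperparameters (filter_time_length=25, pool_time_length=75, pool_time_stride=15) and treats the spatial convolution, batch norm, squaring, log and dropout layers as having no temporal extent. The helpers receptive_field and n_dense_preds are hypothetical, not Braindecode functions. Once strides are replaced by dilations, the network acts like a stride-1 sliding window, so an input of input_time_length samples yields input_time_length - receptive_field + 1 predictions.

```python
# Temporal layers of ShallowFBCSPNet, assuming the library defaults for the
# first two and final_conv_length=12 for the classifier.
# Each entry is (name, kernel_size, stride) along the time axis.
layers = [
    ("conv_time", 25, 1),
    ("pool", 75, 15),
    ("conv_classifier", 12, 1),
]


def receptive_field(layers):
    """Receptive field (in input samples) of a stack of 1D conv/pool layers."""
    rf = 1    # a single input sample
    jump = 1  # spacing, in input samples, between neighbouring outputs
    for _, kernel, stride in layers:
        rf += (kernel - 1) * jump
        jump *= stride
    return rf


def n_dense_preds(input_time_length, layers):
    """Number of outputs once strides are turned into dilations (stride-1 window)."""
    return input_time_length - receptive_field(layers) + 1


print("receptive field:", receptive_field(layers))       # receptive field: 264
print("dense predictions:", n_dense_preds(450, layers))  # dense predictions: 187
```

Under these assumptions the receptive field is 264 samples, well below the trial length, and a 450-sample input yields 187 predictions, which agrees with the number printed by the dummy forward pass further below.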

[4]:
from braindecode.models.shallow_fbcsp import ShallowFBCSPNet
from torch import nn
from braindecode.torch_ext.util import set_random_seeds
from braindecode.models.util import to_dense_prediction_model

# Set to True if you want to use a GPU.
# You can also use torch.cuda.is_available() to determine whether CUDA is available on your machine.
cuda = False
set_random_seeds(seed=20170629, cuda=cuda)

# Number of samples per network input; larger values mean more crops
# (predictions) are computed in parallel in one forward pass
input_time_length = 450
n_classes = 2
in_chans = train_set.X.shape[1]
# final_conv_length determines the size of the receptive field of the ConvNet
model = ShallowFBCSPNet(in_chans=in_chans, n_classes=n_classes, input_time_length=input_time_length,
                        final_conv_length=12).create_network()
to_dense_prediction_model(model)

if cuda:
    model.cuda()


Create cropped iterator

For extracting crops from the trials, Braindecode provides the CropsFromTrialsIterator class. This class needs to know the input time length of the inputs you feed into the network and the number of predictions that the ConvNet outputs per input. You can determine the number of predictions by passing dummy data through the ConvNet:

[5]:
from braindecode.torch_ext.util import np_to_var
# determine output size
test_input = np_to_var(np.ones((2, in_chans, input_time_length, 1), dtype=np.float32))
if cuda:
    test_input = test_input.cuda()
out = model(test_input)
n_preds_per_input = out.cpu().data.numpy().shape[2]
print("{:d} predictions per input/trial".format(n_preds_per_input))
187 predictions per input/trial
[6]:
from braindecode.datautil.iterators import CropsFromTrialsIterator
iterator = CropsFromTrialsIterator(batch_size=32,input_time_length=input_time_length,
                                  n_preds_per_input=n_preds_per_input)

The iterator has the method get_batches, which returns randomly shuffled training batches with shuffle=True or ordered batches (first all crops from trial 1, then from trial 2, etc.) with shuffle=False. Additionally, Braindecode provides the compute_preds_per_trial_from_crops function (used in the training loop below), which takes the predictions from the ordered batches and returns the predictions per trial. It removes any overlapping predictions, which occur when the number of predictions needed to cover a trial is not a multiple of the number of predictions per input.

These methods also work with trials of different lengths: in that case, set X to a list of 2D arrays (channels × samples), one per trial, instead of a single 3D array.
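The crop placement and overlap removal described above can be sketched as follows. This is a minimal sketch, not Braindecode's implementation: crop_windows and stitch_preds are hypothetical helpers, and we assume that the first prediction of an input window corresponds to sample input_time_length - n_preds_per_input of that window (the receptive field minus one), so the earliest samples of a trial cannot be predicted. The example uses a 497-sample trial, which is what tmin=1 and tmax=4.1 should give at PhysioNet's 160 Hz sampling rate, plus a hypothetical 600-sample trial to show that different lengths are handled the same way.

```python
import numpy as np


def crop_windows(n_samples, input_time_length, n_preds_per_input):
    """Hypothetical helper: (start, stop) input windows covering one trial.

    Assumes the first prediction of a window belongs to sample
    input_time_length - n_preds_per_input of that window.
    """
    assert n_samples >= input_time_length, "trial shorter than one input"
    windows = []
    stop = input_time_length - n_preds_per_input  # first predictable sample
    while stop < n_samples:
        # Advance by one window's worth of predictions, never past the trial end
        stop = min(stop + n_preds_per_input, n_samples)
        windows.append((stop - input_time_length, stop))
    return windows


def stitch_preds(crop_preds, windows):
    """Hypothetical helper: join per-crop predictions (classes x n_preds) of
    one trial, keeping from each later crop only the samples that the
    previous crop did not already cover."""
    parts = [crop_preds[0]]
    for i in range(1, len(windows)):
        n_new = windows[i][1] - windows[i - 1][1]
        parts.append(crop_preds[i][:, -n_new:])
    return np.concatenate(parts, axis=1)


rng = np.random.RandomState(0)
n_classes, input_time_length, n_preds_per_input = 2, 450, 187
for n_samples in (497, 600):  # trials of different lengths
    windows = crop_windows(n_samples, input_time_length, n_preds_per_input)
    # Random stand-ins for the network outputs of each crop
    crop_preds = [rng.randn(n_classes, n_preds_per_input) for _ in windows]
    trial_preds = stitch_preds(crop_preds, windows)
    print(n_samples, windows, trial_preds.shape)
# 497 [(0, 450), (47, 497)] (2, 234)
# 600 [(0, 450), (150, 600)] (2, 337)
```

In the 497-sample case the second crop overlaps the first by 140 predictions, so only its last 47 predictions are kept, and every predictable sample ends up with exactly one prediction.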

We can now set up the optimizer, because the iterator lets us compute the number of batches per epoch, which the cosine annealing schedule needs.

[7]:
from braindecode.torch_ext.optimizers import AdamW
from braindecode.torch_ext.schedulers import ScheduledOptimizer, CosineAnnealing
from braindecode.datautil.iterators import get_balanced_batches
from numpy.random import RandomState
rng = RandomState((2018,8,7))
#optimizer = AdamW(model.parameters(), lr=1*0.01, weight_decay=0.5*0.001) # these are good values for the deep model
optimizer = AdamW(model.parameters(), lr=0.0625 * 0.01, weight_decay=0)
# Need to determine number of batch passes per epoch for cosine annealing
n_epochs = 30
n_updates_per_epoch = len([None for b in iterator.get_batches(train_set, True)])
scheduler = CosineAnnealing(n_epochs * n_updates_per_epoch)
# schedule_weight_decay must be True for AdamW
optimizer = ScheduledOptimizer(scheduler, optimizer, schedule_weight_decay=True)
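The cosine annealing schedule configured in the cell above can be illustrated in a few lines. We assume here that CosineAnnealing scales the initial value by 0.5 * (1 + cos(pi * t / T)), where t is the update index and T = n_epochs * n_updates_per_epoch, and that schedule_weight_decay=True applies the same factor to the weight decay. The function name and the n_updates_per_epoch value below are hypothetical.

```python
import math


def cosine_annealing_factor(i_update, n_updates_total):
    """Assumed cosine schedule: 1.0 at the first update, 0.0 at the last."""
    return 0.5 * (1.0 + math.cos(math.pi * i_update / n_updates_total))


initial_lr = 0.0625 * 0.01
n_epochs = 30
n_updates_per_epoch = 3  # hypothetical; the cell above computes the real value
n_updates_total = n_epochs * n_updates_per_epoch

for epoch in (0, 15, 20, 30):
    i_update = epoch * n_updates_per_epoch
    lr = initial_lr * cosine_annealing_factor(i_update, n_updates_total)
    print("after {:2d} epochs: lr = {:.6f}".format(epoch, lr))
```

Under this assumption the factor is 0.5 halfway through and 0.25 after 20 of 30 epochs, independent of the number of updates per epoch.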

Training loop

The code below uses both the cropped iterator and the compute_preds_per_trial_from_crops function to train and evaluate the network.
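The training loop averages the log-softmax outputs over the time axis before computing the negative log-likelihood, and a comment in the loop states that this gives the same gradients as a per-prediction loss. The NumPy sketch here uses random stand-in logits, not real network outputs, and mimics F.nll_loss with its default mean reduction over the batch. It checks that both losses are numerically identical; since the two losses are the same function of the network outputs, their gradients are identical too.

```python
import numpy as np

rng = np.random.RandomState(0)
n_batch, n_classes, n_preds = 4, 2, 187
logits = rng.randn(n_batch, n_classes, n_preds)

# Log-softmax over the class axis (numerically stable form)
log_probs = logits - logits.max(axis=1, keepdims=True)
log_probs -= np.log(np.exp(log_probs).sum(axis=1, keepdims=True))
targets = rng.randint(0, n_classes, size=n_batch)
batch_idx = np.arange(n_batch)

# (a) Average log-probabilities over time first, then NLL (as in the training loop)
mean_log_probs = log_probs.mean(axis=2)                   # batch x classes
loss_mean_first = -mean_log_probs[batch_idx, targets].mean()

# (b) NLL of every single prediction, then average over batch and time
per_pred_nll = -log_probs[batch_idx, targets, :]          # batch x n_preds
loss_per_pred = per_pred_nll.mean()

print(loss_mean_first, loss_per_pred)
```

The equality holds because the negative log-likelihood is linear in the log-probabilities, so averaging before or after taking it makes no difference.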

[8]:
from braindecode.torch_ext.util import np_to_var, var_to_np
import torch.nn.functional as F
from numpy.random import RandomState
import torch as th
from braindecode.experiments.monitors import compute_preds_per_trial_from_crops
rng = RandomState((2017,6,30))
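# Note: the cosine schedule above was set up for n_epochs = 30, but this loop
# runs 20 epochs, so the learning rate does not anneal all the way to zero
for i_epoch in range(20):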
    # Set model to training mode
    model.train()
    for batch_X, batch_y in iterator.get_batches(train_set, shuffle=True):
        net_in = np_to_var(batch_X)
        if cuda:
            net_in = net_in.cuda()
        net_target = np_to_var(batch_y)
        if cuda:
            net_target = net_target.cuda()
        # Remove gradients of last backward pass from all parameters
        optimizer.zero_grad()
        outputs = model(net_in)
        # Average the log-probabilities over the time (prediction) axis
        # Note that this will give identical gradients to computing
        # a per-prediction loss (at least for the combination of log softmax activation
        # and negative log likelihood loss which we are using here)
        outputs = th.mean(outputs, dim=2, keepdim=False)
        loss = F.nll_loss(outputs, net_target)
        loss.backward()
        optimizer.step()

    # Print some statistics each epoch
    model.eval()
    print("Epoch {:d}".format(i_epoch))
    for setname, dataset in (('Train', train_set),('Valid', valid_set)):
        # Collect all predictions and losses
        all_preds = []
        all_losses = []
        batch_sizes = []
        for batch_X, batch_y in iterator.get_batches(dataset, shuffle=False):
            net_in = np_to_var(batch_X)
            if cuda:
                net_in = net_in.cuda()
            net_target = np_to_var(batch_y)
            if cuda:
                net_target = net_target.cuda()
            outputs = model(net_in)
            all_preds.append(var_to_np(outputs))
            outputs = th.mean(outputs, dim=2, keepdim=False)
            loss = F.nll_loss(outputs, net_target)
            loss = float(var_to_np(loss))
            all_losses.append(loss)
            batch_sizes.append(len(batch_X))
        # Compute mean per-input loss
        loss = np.mean(np.array(all_losses) * np.array(batch_sizes) /
                       np.mean(batch_sizes))
        print("{:6s} Loss: {:.5f}".format(setname, loss))
        # Assign the predictions to the trials
        preds_per_trial = compute_preds_per_trial_from_crops(all_preds,
                                                          input_time_length,
                                                          dataset.X)
        # preds per trial are now trials x classes x timesteps/predictions
        # Now mean across timesteps for each trial to get per-trial predictions
        meaned_preds_per_trial = np.array([np.mean(p, axis=1) for p in preds_per_trial])
        predicted_labels = np.argmax(meaned_preds_per_trial, axis=1)
        accuracy = np.mean(predicted_labels == dataset.y)
        print("{:6s} Accuracy: {:.1f}%".format(
            setname, accuracy * 100))
Epoch 0
Train  Loss: 3.82019
Train  Accuracy: 50.0%
Valid  Loss: 3.16695
Valid  Accuracy: 46.7%
Epoch 1
Train  Loss: 1.88180
Train  Accuracy: 50.0%
Valid  Loss: 1.52698
Valid  Accuracy: 50.0%
Epoch 2
Train  Loss: 1.01281
Train  Accuracy: 60.0%
Valid  Loss: 0.95791
Valid  Accuracy: 56.7%
Epoch 3
Train  Loss: 0.72270
Train  Accuracy: 67.5%
Valid  Loss: 0.85744
Valid  Accuracy: 56.7%
Epoch 4
Train  Loss: 0.56512
Train  Accuracy: 72.5%
Valid  Loss: 0.79851
Valid  Accuracy: 63.3%
Epoch 5
Train  Loss: 0.34636
Train  Accuracy: 82.5%
Valid  Loss: 0.61648
Valid  Accuracy: 73.3%
Epoch 6
Train  Loss: 0.25957
Train  Accuracy: 90.0%
Valid  Loss: 0.55787
Valid  Accuracy: 83.3%
Epoch 7
Train  Loss: 0.20769
Train  Accuracy: 95.0%
Valid  Loss: 0.51277
Valid  Accuracy: 83.3%
Epoch 8
Train  Loss: 0.17177
Train  Accuracy: 97.5%
Valid  Loss: 0.45372
Valid  Accuracy: 86.7%
Epoch 9
Train  Loss: 0.14081
Train  Accuracy: 97.5%
Valid  Loss: 0.40558
Valid  Accuracy: 86.7%
Epoch 10
Train  Loss: 0.10214
Train  Accuracy: 100.0%
Valid  Loss: 0.36364
Valid  Accuracy: 86.7%
Epoch 11
Train  Loss: 0.07835
Train  Accuracy: 100.0%
Valid  Loss: 0.35407
Valid  Accuracy: 90.0%
Epoch 12
Train  Loss: 0.07564
Train  Accuracy: 100.0%
Valid  Loss: 0.36432
Valid  Accuracy: 90.0%
Epoch 13
Train  Loss: 0.07691
Train  Accuracy: 100.0%
Valid  Loss: 0.36904
Valid  Accuracy: 90.0%
Epoch 14
Train  Loss: 0.06718
Train  Accuracy: 100.0%
Valid  Loss: 0.35551
Valid  Accuracy: 90.0%
Epoch 15
Train  Loss: 0.05421
Train  Accuracy: 100.0%
Valid  Loss: 0.33592
Valid  Accuracy: 90.0%
Epoch 16
Train  Loss: 0.04269
Train  Accuracy: 100.0%
Valid  Loss: 0.31994
Valid  Accuracy: 90.0%
Epoch 17
Train  Loss: 0.03827
Train  Accuracy: 100.0%
Valid  Loss: 0.31341
Valid  Accuracy: 90.0%
Epoch 18
Train  Loss: 0.03477
Train  Accuracy: 100.0%
Valid  Loss: 0.30831
Valid  Accuracy: 86.7%
Epoch 19
Train  Loss: 0.03244
Train  Accuracy: 100.0%
Valid  Loss: 0.30248
Valid  Accuracy: 90.0%

Eventually, we arrive at 90.0% validation accuracy, so 27 of 30 trials are predicted correctly, 5 more than with the trialwise decoding method.

Evaluation

Once all hyperparameter and architecture choices are final, we can evaluate on the held-out test set to obtain the accuracies to report in a publication:

[9]:
test_set = SignalAndTarget(X[70:], y=y[70:])

model.eval()
# Collect all predictions and losses
all_preds = []
all_losses = []
batch_sizes = []
for batch_X, batch_y in iterator.get_batches(test_set, shuffle=False):
    net_in = np_to_var(batch_X)
    if cuda:
        net_in = net_in.cuda()
    net_target = np_to_var(batch_y)
    if cuda:
        net_target = net_target.cuda()
    outputs = model(net_in)
    all_preds.append(var_to_np(outputs))
    outputs = th.mean(outputs, dim=2, keepdim=False)
    loss = F.nll_loss(outputs, net_target)
    loss = float(var_to_np(loss))
    all_losses.append(loss)
    batch_sizes.append(len(batch_X))
# Compute mean per-input loss
loss = np.mean(np.array(all_losses) * np.array(batch_sizes) /
               np.mean(batch_sizes))
print("Test Loss: {:.5f}".format(loss))
# Assign the predictions to the trials
preds_per_trial = compute_preds_per_trial_from_crops(all_preds,
                                                  input_time_length,
                                                  test_set.X)
# preds per trial are now trials x classes x timesteps/predictions
# Now mean across timesteps for each trial to get per-trial predictions
meaned_preds_per_trial = np.array([np.mean(p, axis=1) for p in preds_per_trial])
predicted_labels = np.argmax(meaned_preds_per_trial, axis=1)
accuracy = np.mean(predicted_labels == test_set.y)
print("Test Accuracy: {:.1f}%".format(accuracy * 100))
Test Loss: 0.42250
Test Accuracy: 90.0%
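The last steps of the evaluation cell, the per-trial averaging and the batch-size-weighted loss, can be checked on toy numbers. The per-trial log-probabilities, labels, batch losses and batch sizes below are all hypothetical; only the formulas are taken from the cell above.

```python
import numpy as np

# Hypothetical per-trial log-probabilities (classes x timesteps), as returned
# after stitching crops; the trials deliberately have different lengths.
preds_per_trial = [
    np.log(np.array([[0.7, 0.6, 0.8], [0.3, 0.4, 0.2]])),
    np.log(np.array([[0.2, 0.4], [0.8, 0.6]])),
    np.log(np.array([[0.6, 0.3, 0.4, 0.3], [0.4, 0.7, 0.6, 0.7]])),
]
y_true = np.array([0, 1, 0])

# Mean over timesteps per trial, then pick the class with the highest mean
meaned_preds_per_trial = np.array([np.mean(p, axis=1) for p in preds_per_trial])
predicted_labels = np.argmax(meaned_preds_per_trial, axis=1)
accuracy = np.mean(predicted_labels == y_true)
print("predicted:", predicted_labels, "accuracy: {:.1f}%".format(accuracy * 100))

# Hypothetical per-batch mean losses; the last batch is smaller
all_losses = np.array([0.5, 0.2, 0.9])
batch_sizes = np.array([32, 32, 16])
# Formula used in the evaluation cell...
loss_cell = np.mean(all_losses * batch_sizes / np.mean(batch_sizes))
# ...equals the batch-size-weighted average of the batch losses
loss_weighted = np.sum(all_losses * batch_sizes) / np.sum(batch_sizes)
print(loss_cell, loss_weighted)
```

Here the third trial is misclassified, so 2 of 3 trials are correct. The two loss expressions agree because mean(l * b) / mean(b) = sum(l * b) / sum(b), which keeps a smaller final batch from being over-weighted.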

Dataset references

This dataset was created and contributed to PhysioNet by the developers of the BCI2000 instrumentation system, which they used in making these recordings. The system is described in:

Schalk, G., McFarland, D.J., Hinterberger, T., Birbaumer, N., Wolpaw, J.R. (2004) BCI2000: A General-Purpose Brain-Computer Interface (BCI) System. IEEE TBME 51(6):1034-1043.

PhysioBank is a large and growing archive of well-characterized digital recordings of physiologic signals and related data for use by the biomedical research community, and is further described in:

Goldberger, A.L., Amaral, L.A.N., Glass, L., Hausdorff, J.M., Ivanov, P.Ch., Mark, R.G., Mietus, J.E., Moody, G.B., Peng, C.-K., Stanley, H.E. (2000) PhysioBank, PhysioToolkit, and PhysioNet: Components of a New Research Resource for Complex Physiologic Signals. Circulation 101(23):e215-e220.