{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": { "collapsed": true, "nbsphinx": "hidden" }, "outputs": [], "source": [ "%load_ext autoreload\n", "%autoreload 2\n", "import os\n", "os.sys.path.insert(0, '/home/schirrmr/braindecode/code/braindecode/')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Trialwise Manual Training Loop" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "
\n", "\n", "Here, we show the trialwise decoding when you want to write your own training loop. For more simple code with a predefined training loop, see the [Trialwise Decoding Tutorial](./Trialwise_Decoding.html).\n", "\n", "
" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In this example, we will use a convolutional neural network on the [Physiobank EEG Motor Movement/Imagery Dataset](https://www.physionet.org/physiobank/database/eegmmidb/) to decode two classes:\n", "\n", "1. Executed and imagined opening and closing of both hands\n", "2. Executed and imagined opening and closing of both feet\n", "\n", "
\n", "\n", "We use only one subject (with 90 trials) in this tutorial for demonstration purposes. A more interesting decoding task with many more trials would be to do cross-subject decoding on the same dataset.\n", "\n", "
" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Enable logging" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "collapsed": true }, "outputs": [], "source": [ "import logging\n", "import importlib\n", "importlib.reload(logging) # see https://stackoverflow.com/a/21475297/1469195\n", "log = logging.getLogger()\n", "log.setLevel('INFO')\n", "import sys\n", "\n", "logging.basicConfig(format='%(asctime)s %(levelname)s : %(message)s',\n", " level=logging.INFO, stream=sys.stdout)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Load data" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "You can load and preprocess your EEG dataset in any way, Braindecode only expects a 3darray (trials, channels, timesteps) of input signals `X` and a vector of labels `y` later (see below). In this tutorial, we will use the [MNE](https://www.martinos.org/mne/stable/index.html) library to load an EEG motor imagery/motor execution dataset. For a tutorial from MNE using Common Spatial Patterns to decode this data, see [here](http://martinos.org/mne/stable/auto_examples/decoding/plot_decoding_csp_eeg.html). For another library useful for loading EEG data, take a look at [Neo IO](https://pythonhosted.org/neo/io.html)." ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "collapsed": true }, "outputs": [], "source": [ "import mne\n", "from mne.io import concatenate_raws\n", "\n", "# 5,6,7,10,13,14 are codes for executed and imagined hands/feet\n", "subject_id = 22 # carefully cherry-picked to give nice results on such limited data :)\n", "event_codes = [5,6,9,10,13,14]\n", "\n", "# This will download the files if you don't have them yet,\n", "# and then return the paths to the files.\n", "physionet_paths = mne.datasets.eegbci.load_data(subject_id, event_codes)\n", "\n", "# Load each of the files\n", "parts = [mne.io.read_raw_edf(path, preload=True,stim_channel='auto', verbose='WARNING')\n", " for path in physionet_paths]\n", "\n", "# Concatenate them\n", "raw = concatenate_raws(parts)\n", "\n", "# Find the events in this dataset\n", "events, _ = mne.events_from_annotations(raw)\n", "\n", "# Use only EEG channels\n", "eeg_channel_inds = mne.pick_types(raw.info, meg=False, eeg=True, stim=False, eog=False,\n", " exclude='bads')\n", "\n", "# Extract trials, only using EEG channels\n", "epoched = mne.Epochs(raw, events, dict(hands=2, feet=3), tmin=1, tmax=4.1, proj=False, picks=eeg_channel_inds,\n", " baseline=None, preload=True)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Convert data to Braindecode format" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Braindecode has a minimalistic ```SignalAndTarget``` class, with attributes `X` for the signal and `y` for the labels. `X` should have these dimensions: trials x channels x timesteps. `y` should have one label per trial." ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "collapsed": true }, "outputs": [], "source": [ "import numpy as np\n", "# Convert data from volt to millivolt\n", "# Pytorch expects float32 for input and int64 for labels.\n", "X = (epoched.get_data() * 1e6).astype(np.float32)\n", "y = (epoched.events[:,2] - 2).astype(np.int64) #2,3 -> 0,1" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We use the first 40 trials for training and the next 30 trials for validation. The validation accuracies can be used to tune hyperparameters such as learning rate etc. 
{ "cell_type": "markdown", "metadata": {}, "source": [ "We use the first 40 trials for training and the next 30 trials for validation. The validation accuracies can be used to tune hyperparameters such as the learning rate. The final 20 trials are split apart so that we have a final hold-out evaluation set that is not part of any hyperparameter optimization. As mentioned before, this dataset is far too small to yield meaningful results and is only used here for quick demonstration purposes." ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "collapsed": true }, "outputs": [], "source": [ "from braindecode.datautil.signal_target import SignalAndTarget\n", "\n", "train_set = SignalAndTarget(X[:40], y=y[:40])\n", "valid_set = SignalAndTarget(X[40:70], y=y[40:70])\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Create the model" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Braindecode comes with some predefined convolutional neural network architectures for raw time-domain EEG. Here, we use the shallow ConvNet model from [Deep learning with convolutional neural networks for EEG decoding and visualization](https://arxiv.org/abs/1703.05051)." ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "collapsed": true, "scrolled": true }, "outputs": [], "source": [ "from braindecode.models.shallow_fbcsp import ShallowFBCSPNet\n", "from torch import nn\n", "from braindecode.torch_ext.util import set_random_seeds\n", "\n", "# Set if you want to use GPU\n", "# You can also use torch.cuda.is_available() to determine if cuda is available on your machine.\n", "cuda = False\n", "set_random_seeds(seed=20170629, cuda=cuda)\n", "n_classes = 2\n", "in_chans = train_set.X.shape[1]\n", "# final_conv_length = auto ensures we only get a single output in the time dimension\n", "model = ShallowFBCSPNet(in_chans=in_chans, n_classes=n_classes,\n", " input_time_length=train_set.X.shape[2],\n", " final_conv_length='auto').create_network()\n", "if cuda:\n", " model.cuda()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We use [AdamW](https://arxiv.org/abs/1711.05101) to optimize the parameters of our network together with [Cosine Annealing](https://arxiv.org/abs/1608.03983) of the learning rate. We supply some default parameters that we have found to work well for motor decoding; however, we strongly encourage you to perform your own hyperparameter optimization using cross validation on your training data." ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "collapsed": true }, "outputs": [], "source": [ "from braindecode.torch_ext.optimizers import AdamW\n", "from braindecode.torch_ext.schedulers import ScheduledOptimizer, CosineAnnealing\n", "from braindecode.datautil.iterators import get_balanced_batches\n", "from numpy.random import RandomState\n", "rng = RandomState((2018,8,7))\n", "#optimizer = AdamW(model.parameters(), lr=1*0.01, weight_decay=0.5*0.001) # these are good values for the deep model\n", "optimizer = AdamW(model.parameters(), lr=0.0625 * 0.01, weight_decay=0)\n", "# Need to determine number of batch passes per epoch for cosine annealing\n", "n_epochs = 30\n", "n_updates_per_epoch = len(list(get_balanced_batches(len(train_set.X), rng, shuffle=True,\n", " batch_size=30)))\n", "scheduler = CosineAnnealing(n_epochs * n_updates_per_epoch)\n", "# schedule_weight_decay must be True for AdamW\n", "optimizer = ScheduledOptimizer(scheduler, optimizer, schedule_weight_decay=True)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Training loop" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This is a conventional mini-batch stochastic gradient descent training loop:\n", "\n", "1. 
Get randomly shuffled batches of trials\n", "2. Compute outputs, loss and gradients on the batches of trials\n", "3. Update your model\n", "4. After iterating through all batches of your dataset, report some statistics like mean accuracy and mean loss." ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "scrolled": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch 0\n", "Train Loss: 0.73690\n", "Train Accuracy: 62.5%\n", "Valid Loss: 1.13041\n", "Valid Accuracy: 53.3%\n", "Epoch 1\n", "Train Loss: 1.17932\n", "Train Accuracy: 57.5%\n", "Valid Loss: 1.35246\n", "Valid Accuracy: 50.0%\n", "Epoch 2\n", "Train Loss: 0.81899\n", "Train Accuracy: 65.0%\n", "Valid Loss: 0.96027\n", "Valid Accuracy: 56.7%\n", "Epoch 3\n", "Train Loss: 0.53547\n", "Train Accuracy: 75.0%\n", "Valid Loss: 0.77725\n", "Valid Accuracy: 66.7%\n", "Epoch 4\n", "Train Loss: 0.30195\n", "Train Accuracy: 85.0%\n", "Valid Loss: 0.60807\n", "Valid Accuracy: 73.3%\n", "Epoch 5\n", "Train Loss: 0.18695\n", "Train Accuracy: 90.0%\n", "Valid Loss: 0.54184\n", "Valid Accuracy: 76.7%\n", "Epoch 6\n", "Train Loss: 0.13377\n", "Train Accuracy: 95.0%\n", "Valid Loss: 0.50131\n", "Valid Accuracy: 80.0%\n", "Epoch 7\n", "Train Loss: 0.11521\n", "Train Accuracy: 95.0%\n", "Valid Loss: 0.47909\n", "Valid Accuracy: 80.0%\n", "Epoch 8\n", "Train Loss: 0.09841\n", "Train Accuracy: 97.5%\n", "Valid Loss: 0.47807\n", "Valid Accuracy: 80.0%\n", "Epoch 9\n", "Train Loss: 0.08487\n", "Train Accuracy: 97.5%\n", "Valid Loss: 0.47951\n", "Valid Accuracy: 80.0%\n", "Epoch 10\n", "Train Loss: 0.07319\n", "Train Accuracy: 97.5%\n", "Valid Loss: 0.48485\n", "Valid Accuracy: 80.0%\n", "Epoch 11\n", "Train Loss: 0.06363\n", "Train Accuracy: 100.0%\n", "Valid Loss: 0.49065\n", "Valid Accuracy: 80.0%\n", "Epoch 12\n", "Train Loss: 0.05509\n", "Train Accuracy: 100.0%\n", "Valid Loss: 0.49882\n", "Valid Accuracy: 76.7%\n", "Epoch 13\n", "Train Loss: 0.04861\n", "Train Accuracy: 100.0%\n", "Valid Loss: 0.50564\n", "Valid Accuracy: 80.0%\n", "Epoch 14\n", "Train Loss: 0.04340\n", "Train Accuracy: 100.0%\n", "Valid Loss: 0.51243\n", "Valid Accuracy: 80.0%\n", "Epoch 15\n", "Train Loss: 0.03985\n", "Train Accuracy: 100.0%\n", "Valid Loss: 0.51889\n", "Valid Accuracy: 83.3%\n", "Epoch 16\n", "Train Loss: 0.03644\n", "Train Accuracy: 100.0%\n", "Valid Loss: 0.52753\n", "Valid Accuracy: 83.3%\n", "Epoch 17\n", "Train Loss: 0.03371\n", "Train Accuracy: 100.0%\n", "Valid Loss: 0.53611\n", "Valid Accuracy: 83.3%\n", "Epoch 18\n", "Train Loss: 0.03111\n", "Train Accuracy: 100.0%\n", "Valid Loss: 0.54486\n", "Valid Accuracy: 83.3%\n", "Epoch 19\n", "Train Loss: 0.02891\n", "Train Accuracy: 100.0%\n", "Valid Loss: 0.55211\n", "Valid Accuracy: 83.3%\n", "Epoch 20\n", "Train Loss: 0.02694\n", "Train Accuracy: 100.0%\n", "Valid Loss: 0.55715\n", "Valid Accuracy: 83.3%\n", "Epoch 21\n", "Train Loss: 0.02522\n", "Train Accuracy: 100.0%\n", "Valid Loss: 0.56237\n", "Valid Accuracy: 80.0%\n", "Epoch 22\n", "Train Loss: 0.02372\n", "Train Accuracy: 100.0%\n", "Valid Loss: 0.56693\n", "Valid Accuracy: 80.0%\n", "Epoch 23\n", "Train Loss: 0.02254\n", "Train Accuracy: 100.0%\n", "Valid Loss: 0.57061\n", "Valid Accuracy: 80.0%\n", "Epoch 24\n", "Train Loss: 0.02160\n", "Train Accuracy: 100.0%\n", "Valid Loss: 0.57319\n", "Valid Accuracy: 80.0%\n", "Epoch 25\n", "Train Loss: 0.02083\n", "Train Accuracy: 100.0%\n", "Valid Loss: 0.57516\n", "Valid Accuracy: 80.0%\n", "Epoch 26\n", "Train Loss: 0.02021\n", "Train Accuracy: 100.0%\n", "Valid 
Loss: 0.57674\n", "Valid Accuracy: 80.0%\n", "Epoch 27\n", "Train Loss: 0.01972\n", "Train Accuracy: 100.0%\n", "Valid Loss: 0.57793\n", "Valid Accuracy: 80.0%\n", "Epoch 28\n", "Train Loss: 0.01934\n", "Train Accuracy: 100.0%\n", "Valid Loss: 0.57883\n", "Valid Accuracy: 80.0%\n", "Epoch 29\n", "Train Loss: 0.01903\n", "Train Accuracy: 100.0%\n", "Valid Loss: 0.57959\n", "Valid Accuracy: 80.0%\n" ] } ], "source": [ "from braindecode.torch_ext.util import np_to_var, var_to_np\n", "import torch.nn.functional as F\n", "for i_epoch in range(n_epochs):\n", " i_trials_in_batch = get_balanced_batches(len(train_set.X), rng, shuffle=True,\n", " batch_size=30)\n", " # Set model to training mode\n", " model.train()\n", " for i_trials in i_trials_in_batch:\n", " # Have to add empty fourth dimension to X\n", " batch_X = train_set.X[i_trials][:,:,:,None]\n", " batch_y = train_set.y[i_trials]\n", " net_in = np_to_var(batch_X)\n", " if cuda:\n", " net_in = net_in.cuda()\n", " net_target = np_to_var(batch_y)\n", " if cuda:\n", " net_target = net_target.cuda()\n", " # Remove gradients of last backward pass from all parameters \n", " optimizer.zero_grad()\n", " # Compute outputs of the network\n", " outputs = model(net_in)\n", " # Compute the loss\n", " loss = F.nll_loss(outputs, net_target)\n", " # Do the backpropagation\n", " loss.backward()\n", " # Update parameters with the optimizer\n", " optimizer.step()\n", " \n", " # Print some statistics each epoch\n", " model.eval()\n", " print(\"Epoch {:d}\".format(i_epoch))\n", " for setname, dataset in (('Train', train_set), ('Valid', valid_set)):\n", " # Here, we will use the entire dataset at once, which is still possible\n", " # for such smaller datasets. Otherwise we would have to use batches.\n", " net_in = np_to_var(dataset.X[:,:,:,None])\n", " if cuda:\n", " net_in = net_in.cuda()\n", " net_target = np_to_var(dataset.y)\n", " if cuda:\n", " net_target = net_target.cuda()\n", " outputs = model(net_in)\n", " loss = F.nll_loss(outputs, net_target)\n", " print(\"{:6s} Loss: {:.5f}\".format(\n", " setname, float(var_to_np(loss))))\n", " predicted_labels = np.argmax(var_to_np(outputs), axis=1)\n", " accuracy = np.mean(dataset.y == predicted_labels)\n", " print(\"{:6s} Accuracy: {:.1f}%\".format(\n", " setname, accuracy * 100))\n", " " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Eventually, we arrive at 80.0% accuracy, so 24 from 30 trials are correctly predicted. In the [Cropped Decoding Tutorial](./Cropped_Decoding.html), we can learn do the same decoding using Cropped Decoding." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Evaluation" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Once we have all our hyperparameters and architectural choices done, we can evaluate the accuracies to report in our publication by evaluating on the test set:" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Test Loss: 0.31152\n", "Test Accuracy: 80.0%\n" ] } ], "source": [ "test_set = SignalAndTarget(X[70:], y=y[70:])\n", "\n", "model.eval()\n", "# Here, we will use the entire dataset at once, which is still possible\n", "# for such smaller datasets. 
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Evaluation" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Once all our hyperparameters and architectural choices are fixed, we can evaluate the accuracies to report in our publication by evaluating on the test set:" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Test Loss: 0.31152\n", "Test Accuracy: 80.0%\n" ] } ], "source": [ "test_set = SignalAndTarget(X[70:], y=y[70:])\n", "\n", "model.eval()\n", "# Here, we will use the entire dataset at once, which is still possible\n", "# for such a small dataset. Otherwise we would have to use batches.\n", "net_in = np_to_var(test_set.X[:,:,:,None])\n", "if cuda:\n", " net_in = net_in.cuda()\n", "net_target = np_to_var(test_set.y)\n", "if cuda:\n", " net_target = net_target.cuda()\n", "outputs = model(net_in)\n", "loss = F.nll_loss(outputs, net_target)\n", "print(\"Test Loss: {:.5f}\".format(float(var_to_np(loss))))\n", "predicted_labels = np.argmax(var_to_np(outputs), axis=1)\n", "accuracy = np.mean(test_set.y == predicted_labels)\n", "print(\"Test Accuracy: {:.1f}%\".format(accuracy * 100))" ] }, 
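{ "cell_type": "markdown", "metadata": {}, "source": [ "As a small optional addition (not part of the original tutorial), you can also break the test predictions down per class with a confusion matrix; rows are the true labels (0: hands, 1: feet) and columns are the predicted labels." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": true }, "outputs": [], "source": [ "# Optional per-class breakdown of the test predictions:\n", "# rows = true class (0: hands, 1: feet), columns = predicted class.\n", "confusion = np.zeros((2, 2), dtype=int)\n", "for true_label, pred_label in zip(test_set.y, predicted_labels):\n", "    confusion[true_label, pred_label] += 1\n", "print(confusion)" ] }, 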
\n", "\n", "If you want to try cross-subject decoding, changing the loading code to the following will perform cross-subject decoding on imagined left vs right hand closing, with 50 training and 5 validation subjects (Warning, might be very slow if you are on CPU):\n", "\n", "
" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": true }, "outputs": [], "source": [ "import mne\n", "import numpy as np\n", "from mne.io import concatenate_raws\n", "from braindecode.datautil.signal_target import SignalAndTarget\n", "\n", "# First 50 subjects as train\n", "physionet_paths = [ mne.datasets.eegbci.load_data(sub_id,[4,8,12,]) for sub_id in range(1,51)]\n", "physionet_paths = np.concatenate(physionet_paths)\n", "parts = [mne.io.read_raw_edf(path, preload=True,stim_channel='auto')\n", " for path in physionet_paths] \n", "\n", "raw = concatenate_raws(parts)\n", "\n", "picks = mne.pick_types(raw.info, meg=False, eeg=True, stim=False, eog=False,\n", " exclude='bads')\n", "\n", "events, _ = mne.events_from_annotations(raw)\n", "\n", "# Read epochs (train will be done only between 1 and 2s)\n", "# Testing will be done with a running classifier\n", "epoched = mne.Epochs(raw, events, dict(hands=2, feet=3), tmin=1, tmax=4.1, proj=False, picks=picks,\n", " baseline=None, preload=True)\n", "\n", "# 51-55 as validation subjects\n", "physionet_paths_valid = [mne.datasets.eegbci.load_data(sub_id,[4,8,12,]) for sub_id in range(51,56)]\n", "physionet_paths_valid = np.concatenate(physionet_paths_valid)\n", "parts_valid = [mne.io.read_raw_edf(path, preload=True,stim_channel='auto')\n", " for path in physionet_paths_valid]\n", "raw_valid = concatenate_raws(parts_valid)\n", "\n", "picks_valid = mne.pick_types(raw_valid.info, meg=False, eeg=True, stim=False, eog=False,\n", " exclude='bads')\n", "\n", "events_valid = mne.find_events(raw_valid, shortest_event=0, stim_channel='STI 014')\n", "\n", "# Read epochs (train will be done only between 1 and 2s)\n", "# Testing will be done with a running classifier\n", "epoched_valid = mne.Epochs(raw_valid, events_valid, dict(hands=2, feet=3), tmin=1, tmax=4.1, proj=False, picks=picks_valid,\n", " baseline=None, preload=True)\n", "\n", "train_X = (epoched.get_data() * 1e6).astype(np.float32)\n", "train_y = (epoched.events[:,2] - 2).astype(np.int64) #2,3 -> 0,1\n", "valid_X = (epoched_valid.get_data() * 1e6).astype(np.float32)\n", "valid_y = (epoched_valid.events[:,2] - 2).astype(np.int64) #2,3 -> 0,1\n", "train_set = SignalAndTarget(train_X, y=train_y)\n", "valid_set = SignalAndTarget(valid_X, y=valid_y)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Dataset references\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ " This dataset was created and contributed to PhysioNet by the developers of the [BCI2000](http://www.schalklab.org/research/bci2000) instrumentation system, which they used in making these recordings. The system is described in:\n", " \n", " Schalk, G., McFarland, D.J., Hinterberger, T., Birbaumer, N., Wolpaw, J.R. (2004) BCI2000: A General-Purpose Brain-Computer Interface (BCI) System. IEEE TBME 51(6):1034-1043.\n", "\n", "[PhysioBank](https://physionet.org/physiobank/) is a large and growing archive of well-characterized digital recordings of physiologic signals and related data for use by the biomedical research community and further described in:\n", "\n", " Goldberger AL, Amaral LAN, Glass L, Hausdorff JM, Ivanov PCh, Mark RG, Mietus JE, Moody GB, Peng C-K, Stanley HE. (2000) PhysioBank, PhysioToolkit, and PhysioNet: Components of a New Research Resource for Complex Physiologic Signals. Circulation 101(23):e215-e220." 
] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.8" } }, "nbformat": 4, "nbformat_minor": 2 }