{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": { "collapsed": true, "nbsphinx": "hidden" }, "outputs": [], "source": [ "%load_ext autoreload\n", "%autoreload 2\n", "import os\n", "os.sys.path.insert(0, '/home/schirrmr/braindecode/code/braindecode/')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Trialwise Decoding" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In this example, we will use a convolutional neural network on the [Physiobank EEG Motor Movement/Imagery Dataset](https://www.physionet.org/physiobank/database/eegmmidb/) to decode two classes:\n", "\n", "1. Executed and imagined opening and closing of both hands\n", "2. Executed and imagined opening and closing of both feet\n", "\n", "
\n", "\n", "We use only one subject (with 90 trials) in this tutorial for demonstration purposes. A more interesting decoding task with many more trials would be to do cross-subject decoding on the same dataset.\n", "\n", "
" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Enable logging" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "collapsed": true }, "outputs": [], "source": [ "import logging\n", "import importlib\n", "importlib.reload(logging) # see https://stackoverflow.com/a/21475297/1469195\n", "log = logging.getLogger()\n", "log.setLevel('INFO')\n", "import sys\n", "\n", "logging.basicConfig(format='%(asctime)s %(levelname)s : %(message)s',\n", " level=logging.INFO, stream=sys.stdout)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Load data" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "You can load and preprocess your EEG dataset in any way you like; Braindecode only expects a 3D array of input signals `X` with dimensions (trials, channels, timesteps) and a vector of labels `y` (see below). In this tutorial, we use the [MNE](https://www.martinos.org/mne/stable/index.html) library to load an EEG motor imagery/motor execution dataset. For an MNE tutorial that decodes this data with Common Spatial Patterns, see [here](http://martinos.org/mne/stable/auto_examples/decoding/plot_decoding_csp_eeg.html). Another library useful for loading EEG data is [Neo IO](https://pythonhosted.org/neo/io.html)."
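 ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Here is a minimal sketch of this expected format. It is an illustration added here and not part of the Braindecode API. The cell below builds random stand-in data with the shapes and dtypes this tutorial ends up with. The sizes are assumptions: 90 trials, 64 EEG channels, and 497 timesteps. The last figure is what the 1 s to 4.1 s trial window used below gives at the dataset's sampling rate of 160 Hz (3.1 x 160 + 1 samples). Casting `X` to float32 and `y` to int64 matches what PyTorch expects." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "\n", "# Hypothetical stand-in data; these sizes are assumptions for illustration only\n", "n_trials, n_chans, n_times = 90, 64, 497\n", "rng = np.random.RandomState(0)\n", "\n", "# X: one (channels x timesteps) signal per trial, as float32\n", "X_demo = rng.randn(n_trials, n_chans, n_times).astype(np.float32)\n", "# y: one integer class label (0 or 1) per trial, as int64\n", "y_demo = rng.randint(0, 2, size=n_trials).astype(np.int64)\n", "\n", "print(X_demo.shape, X_demo.dtype)  # (90, 64, 497) float32\n", "print(y_demo.shape, y_demo.dtype)  # (90,) int64"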
 ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "import mne\n", "from mne.io import concatenate_raws\n", "\n", "# Runs 5,6,9,10,13,14 contain the executed and imagined hands/feet movements\n", "subject_id = 22 # carefully cherry-picked to give nice results on such limited data :)\n", "event_codes = [5,6,9,10,13,14]\n", "#event_codes = [3,4,5,6,7,8,9,10,11,12,13,14]\n", "\n", "# This will download the files if you don't have them yet,\n", "# and then return the paths to the files.\n", "physionet_paths = mne.datasets.eegbci.load_data(subject_id, event_codes)\n", "\n", "# Load each of the files\n", "parts = [mne.io.read_raw_edf(path, preload=True, stim_channel='auto', verbose='WARNING')\n", " for path in physionet_paths]\n", "\n", "# Concatenate them\n", "raw = concatenate_raws(parts)\n", "\n", "# Find the events in this dataset\n", "events, _ = mne.events_from_annotations(raw)\n", "\n", "# Use only EEG channels\n", "eeg_channel_inds = mne.pick_types(raw.info, meg=False, eeg=True, stim=False, eog=False,\n", " exclude='bads')\n", "\n", "# Extract trials, only using EEG channels\n", "epoched = mne.Epochs(raw, events, dict(hands_or_left=2, feet_or_right=3), tmin=1, tmax=4.1, proj=False, picks=eeg_channel_inds,\n", " baseline=None, preload=True)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Convert data to Braindecode format" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Braindecode has a minimalistic `SignalAndTarget` class, with attributes `X` for the signal and `y` for the labels. `X` should have the dimensions trials x channels x timesteps, and `y` should have one label per trial."
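 ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The conversion in the following cells comes down to three array operations. First, it scales volts to microvolts. Second, it shifts the event codes 2 and 3 to the class labels 0 and 1. Third, it slices the trials chronologically into 40 training, 30 validation and 20 test trials. The cell below is a minimal sketch of these steps on a small hypothetical array with made-up signal values and event codes. It is only meant to make the indexing explicit." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "\n", "# Hypothetical epochs in volts: 90 trials, 2 channels, 5 timesteps (made-up values)\n", "demo_volts = np.full((90, 2, 5), 2e-5)\n", "# Hypothetical event codes, alternating 2 (hands) and 3 (feet)\n", "demo_codes = np.array([2, 3] * 45)\n", "\n", "# Scale volts to microvolts and cast to float32\n", "X_demo = (demo_volts * 1e6).astype(np.float32)\n", "# Map event codes 2, 3 to class labels 0, 1 and cast to int64\n", "y_demo = (demo_codes - 2).astype(np.int64)\n", "\n", "# Chronological split: first 40 train, next 30 validation, last 20 test\n", "X_train, X_valid, X_test = X_demo[:40], X_demo[40:70], X_demo[70:]\n", "y_train, y_valid, y_test = y_demo[:40], y_demo[40:70], y_demo[70:]\n", "\n", "print(y_demo[:4])  # [0 1 0 1]\n", "print(len(X_train), len(X_valid), len(X_test))  # 40 30 20"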
 ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "collapsed": true }, "outputs": [], "source": [ "import numpy as np\n", "# Convert data from volts to microvolts\n", "# PyTorch expects float32 for input and int64 for labels.\n", "X = (epoched.get_data() * 1e6).astype(np.float32)\n", "y = (epoched.events[:,2] - 2).astype(np.int64) #2,3 -> 0,1" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We use the first 40 trials for training and the next 30 trials for validation. The validation accuracies can be used to tune hyperparameters such as the learning rate. The final 20 trials are held out as a test set that is not used for any hyperparameter optimization. As mentioned before, this dataset is far too small to yield meaningful results and is only used here for a quick demonstration." ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "collapsed": true }, "outputs": [], "source": [ "from braindecode.datautil.signal_target import SignalAndTarget\n", "\n", "train_set = SignalAndTarget(X[:40], y=y[:40])\n", "valid_set = SignalAndTarget(X[40:70], y=y[40:70])\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Create the model" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Braindecode comes with some predefined convolutional neural network architectures for raw time-domain EEG. Here, we use the shallow ConvNet model from [Deep learning with convolutional neural networks for EEG decoding and visualization](https://arxiv.org/abs/1703.05051)."
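 ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The model cell below passes `final_conv_length='auto'` so that the network produces a single output in the time dimension. The arithmetic behind that can be sketched as follows. This is an illustration under assumptions, not code taken from Braindecode. It assumes the default ShallowFBCSPNet sizes: a temporal convolution over 25 samples, then mean pooling over 75 samples with a stride of 15, all without padding. It also assumes an input of 497 timesteps. The spatial convolution spans all channels and leaves the time length unchanged, so it is skipped here." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def conv_out_len(n_in, kernel, stride=1):\n", "    # Output length of an unpadded convolution or pooling layer along time\n", "    return (n_in - kernel) // stride + 1\n", "\n", "# Assumed sizes (not read from Braindecode)\n", "input_time_length = 497\n", "after_conv = conv_out_len(input_time_length, 25)  # temporal convolution\n", "after_pool = conv_out_len(after_conv, 75, stride=15)  # mean pooling\n", "\n", "# A final convolution spanning all remaining samples yields exactly one output\n", "final_conv_length = after_pool\n", "n_outputs = conv_out_len(after_pool, final_conv_length)\n", "print(after_conv, after_pool, n_outputs)  # 473 27 1"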
 ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "collapsed": true, "scrolled": true }, "outputs": [], "source": [ "from braindecode.models.shallow_fbcsp import ShallowFBCSPNet\n", "from torch import nn\n", "from braindecode.torch_ext.util import set_random_seeds\n", "\n", "# Set to True if you want to use a GPU\n", "# You can also use torch.cuda.is_available() to determine if CUDA is available on your machine.\n", "cuda = False\n", "set_random_seeds(seed=20170629, cuda=cuda)\n", "n_classes = 2\n", "in_chans = train_set.X.shape[1]\n", "# final_conv_length='auto' ensures we only get a single output in the time dimension\n", "model = ShallowFBCSPNet(in_chans=in_chans, n_classes=n_classes,\n", " input_time_length=train_set.X.shape[2],\n", " final_conv_length='auto')\n", "if cuda:\n", " model.cuda()\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We use [AdamW](https://arxiv.org/abs/1711.05101) to optimize the parameters of our network together with [Cosine Annealing](https://arxiv.org/abs/1608.03983) of the learning rate. We supply some default parameters that we have found to work well for motor decoding; however, we strongly encourage you to perform your own hyperparameter optimization using cross-validation on your training data." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "
\n", "\n", "We will now use the Braindecode model class directly to perform the training in a few lines of code. If you instead want to use your own training loop, have a look at the [Trialwise Low-Level Tutorial](./TrialWise_LowLevel.html).\n", "\n", "
" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "collapsed": true }, "outputs": [], "source": [ "from braindecode.torch_ext.optimizers import AdamW\n", "import torch.nn.functional as F\n", "#optimizer = AdamW(model.parameters(), lr=1*0.01, weight_decay=0.5*0.001) # these are good values for the deep model\n", "optimizer = AdamW(model.parameters(), lr=0.0625 * 0.01, weight_decay=0)\n", "model.compile(loss=F.nll_loss, optimizer=optimizer, iterator_seed=1,)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Run the training" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "2019-05-27 13:45:17,834 INFO : Run until first stop...\n", "2019-05-27 13:45:18,483 INFO : Epoch 0\n", "2019-05-27 13:45:18,485 INFO : train_loss 5.34665\n", "2019-05-27 13:45:18,485 INFO : valid_loss 5.13145\n", "2019-05-27 13:45:18,486 INFO : train_misclass 0.47500\n", "2019-05-27 13:45:18,486 INFO : valid_misclass 0.46667\n", "2019-05-27 13:45:18,487 INFO : runtime 0.00000\n", "2019-05-27 13:45:18,488 INFO : \n", "2019-05-27 13:45:19,574 INFO : Time only for training updates: 1.09s\n", "2019-05-27 13:45:20,319 INFO : Epoch 1\n", "2019-05-27 13:45:20,322 INFO : train_loss 1.18747\n", "2019-05-27 13:45:20,324 INFO : valid_loss 1.44545\n", "2019-05-27 13:45:20,325 INFO : train_misclass 0.45000\n", "2019-05-27 13:45:20,327 INFO : valid_misclass 0.53333\n", "2019-05-27 13:45:20,329 INFO : runtime 1.74098\n", "2019-05-27 13:45:20,330 INFO : \n", "2019-05-27 13:45:21,350 INFO : Time only for training updates: 1.02s\n", "2019-05-27 13:45:22,071 INFO : Epoch 2\n", "2019-05-27 13:45:22,074 INFO : train_loss 0.97473\n", "2019-05-27 13:45:22,076 INFO : valid_loss 1.21793\n", "2019-05-27 13:45:22,077 INFO : train_misclass 0.40000\n", "2019-05-27 13:45:22,079 INFO : valid_misclass 0.50000\n", "2019-05-27 13:45:22,081 INFO : runtime 1.77702\n", "2019-05-27 13:45:22,082 INFO : \n", "2019-05-27 
13:45:23,099 INFO : Time only for training updates: 1.01s\n", "2019-05-27 13:45:23,833 INFO : Epoch 3\n", "2019-05-27 13:45:23,836 INFO : train_loss 0.67758\n", "2019-05-27 13:45:23,838 INFO : valid_loss 0.90153\n", "2019-05-27 13:45:23,839 INFO : train_misclass 0.30000\n", "2019-05-27 13:45:23,841 INFO : valid_misclass 0.40000\n", "2019-05-27 13:45:23,843 INFO : runtime 1.74868\n", "2019-05-27 13:45:23,844 INFO : \n", "2019-05-27 13:45:24,854 INFO : Time only for training updates: 1.01s\n", "2019-05-27 13:45:25,586 INFO : Epoch 4\n", "2019-05-27 13:45:25,589 INFO : train_loss 0.44457\n", "2019-05-27 13:45:25,591 INFO : valid_loss 0.67692\n", "2019-05-27 13:45:25,592 INFO : train_misclass 0.27500\n", "2019-05-27 13:45:25,594 INFO : valid_misclass 0.30000\n", "2019-05-27 13:45:25,596 INFO : runtime 1.75507\n", "2019-05-27 13:45:25,597 INFO : \n", "2019-05-27 13:45:26,624 INFO : Time only for training updates: 1.03s\n", "2019-05-27 13:45:27,406 INFO : Epoch 5\n", "2019-05-27 13:45:27,409 INFO : train_loss 0.29853\n", "2019-05-27 13:45:27,410 INFO : valid_loss 0.55171\n", "2019-05-27 13:45:27,412 INFO : train_misclass 0.20000\n", "2019-05-27 13:45:27,414 INFO : valid_misclass 0.20000\n", "2019-05-27 13:45:27,415 INFO : runtime 1.77113\n", "2019-05-27 13:45:27,417 INFO : \n", "2019-05-27 13:45:28,837 INFO : Time only for training updates: 1.42s\n", "2019-05-27 13:45:29,965 INFO : Epoch 6\n", "2019-05-27 13:45:29,968 INFO : train_loss 0.21160\n", "2019-05-27 13:45:29,970 INFO : valid_loss 0.50791\n", "2019-05-27 13:45:29,971 INFO : train_misclass 0.10000\n", "2019-05-27 13:45:29,973 INFO : valid_misclass 0.16667\n", "2019-05-27 13:45:29,975 INFO : runtime 2.21169\n", "2019-05-27 13:45:29,976 INFO : \n", "2019-05-27 13:45:31,564 INFO : Time only for training updates: 1.59s\n", "2019-05-27 13:45:32,295 INFO : Epoch 7\n", "2019-05-27 13:45:32,298 INFO : train_loss 0.13291\n", "2019-05-27 13:45:32,300 INFO : valid_loss 0.46993\n", "2019-05-27 13:45:32,302 INFO : 
train_misclass 0.05000\n", "2019-05-27 13:45:32,303 INFO : valid_misclass 0.16667\n", "2019-05-27 13:45:32,305 INFO : runtime 2.72866\n", "2019-05-27 13:45:32,307 INFO : \n", "2019-05-27 13:45:33,327 INFO : Time only for training updates: 1.02s\n", "2019-05-27 13:45:34,071 INFO : Epoch 8\n", "2019-05-27 13:45:34,074 INFO : train_loss 0.09791\n", "2019-05-27 13:45:34,076 INFO : valid_loss 0.45241\n", "2019-05-27 13:45:34,077 INFO : train_misclass 0.00000\n", "2019-05-27 13:45:34,079 INFO : valid_misclass 0.16667\n", "2019-05-27 13:45:34,081 INFO : runtime 1.76224\n", "2019-05-27 13:45:34,082 INFO : \n", "2019-05-27 13:45:35,169 INFO : Time only for training updates: 1.08s\n", "2019-05-27 13:45:35,935 INFO : Epoch 9\n", "2019-05-27 13:45:35,938 INFO : train_loss 0.07745\n", "2019-05-27 13:45:35,940 INFO : valid_loss 0.45301\n", "2019-05-27 13:45:35,942 INFO : train_misclass 0.00000\n", "2019-05-27 13:45:35,943 INFO : valid_misclass 0.20000\n", "2019-05-27 13:45:35,945 INFO : runtime 1.84197\n", "2019-05-27 13:45:35,947 INFO : \n", "2019-05-27 13:45:36,983 INFO : Time only for training updates: 1.03s\n", "2019-05-27 13:45:37,707 INFO : Epoch 10\n", "2019-05-27 13:45:37,710 INFO : train_loss 0.06647\n", "2019-05-27 13:45:37,711 INFO : valid_loss 0.45929\n", "2019-05-27 13:45:37,713 INFO : train_misclass 0.00000\n", "2019-05-27 13:45:37,715 INFO : valid_misclass 0.20000\n", "2019-05-27 13:45:37,716 INFO : runtime 1.81326\n", "2019-05-27 13:45:37,718 INFO : \n", "2019-05-27 13:45:38,732 INFO : Time only for training updates: 1.01s\n", "2019-05-27 13:45:39,462 INFO : Epoch 11\n", "2019-05-27 13:45:39,465 INFO : train_loss 0.05948\n", "2019-05-27 13:45:39,467 INFO : valid_loss 0.47226\n", "2019-05-27 13:45:39,469 INFO : train_misclass 0.00000\n", "2019-05-27 13:45:39,470 INFO : valid_misclass 0.20000\n", "2019-05-27 13:45:39,472 INFO : runtime 1.74926\n", "2019-05-27 13:45:39,473 INFO : \n", "2019-05-27 13:45:40,488 INFO : Time only for training updates: 1.01s\n", 
"2019-05-27 13:45:41,211 INFO : Epoch 12\n", "2019-05-27 13:45:41,214 INFO : train_loss 0.05326\n", "2019-05-27 13:45:41,215 INFO : valid_loss 0.48799\n", "2019-05-27 13:45:41,217 INFO : train_misclass 0.00000\n", "2019-05-27 13:45:41,219 INFO : valid_misclass 0.20000\n", "2019-05-27 13:45:41,220 INFO : runtime 1.75707\n", "2019-05-27 13:45:41,222 INFO : \n", "2019-05-27 13:45:42,236 INFO : Time only for training updates: 1.01s\n", "2019-05-27 13:45:42,997 INFO : Epoch 13\n", "2019-05-27 13:45:43,000 INFO : train_loss 0.04709\n", "2019-05-27 13:45:43,001 INFO : valid_loss 0.50407\n", "2019-05-27 13:45:43,003 INFO : train_misclass 0.00000\n", "2019-05-27 13:45:43,005 INFO : valid_misclass 0.20000\n", "2019-05-27 13:45:43,007 INFO : runtime 1.74734\n", "2019-05-27 13:45:43,009 INFO : \n", "2019-05-27 13:45:44,112 INFO : Time only for training updates: 1.10s\n", "2019-05-27 13:45:44,842 INFO : Epoch 14\n", "2019-05-27 13:45:44,844 INFO : train_loss 0.04247\n", "2019-05-27 13:45:44,846 INFO : valid_loss 0.50820\n", "2019-05-27 13:45:44,848 INFO : train_misclass 0.00000\n", "2019-05-27 13:45:44,849 INFO : valid_misclass 0.20000\n", "2019-05-27 13:45:44,851 INFO : runtime 1.87550\n", "2019-05-27 13:45:44,853 INFO : \n", "2019-05-27 13:45:45,869 INFO : Time only for training updates: 1.01s\n", "2019-05-27 13:45:46,605 INFO : Epoch 15\n", "2019-05-27 13:45:46,608 INFO : train_loss 0.03919\n", "2019-05-27 13:45:46,610 INFO : valid_loss 0.51205\n", "2019-05-27 13:45:46,611 INFO : train_misclass 0.00000\n", "2019-05-27 13:45:46,613 INFO : valid_misclass 0.20000\n", "2019-05-27 13:45:46,615 INFO : runtime 1.75688\n", "2019-05-27 13:45:46,616 INFO : \n", "2019-05-27 13:45:47,633 INFO : Time only for training updates: 1.01s\n", "2019-05-27 13:45:48,365 INFO : Epoch 16\n", "2019-05-27 13:45:48,368 INFO : train_loss 0.03620\n", "2019-05-27 13:45:48,370 INFO : valid_loss 0.51271\n", "2019-05-27 13:45:48,371 INFO : train_misclass 0.00000\n", "2019-05-27 13:45:48,373 INFO : 
valid_misclass 0.20000\n", "2019-05-27 13:45:48,374 INFO : runtime 1.76364\n", "2019-05-27 13:45:48,376 INFO : \n", "2019-05-27 13:45:49,392 INFO : Time only for training updates: 1.01s\n", "2019-05-27 13:45:50,111 INFO : Epoch 17\n", "2019-05-27 13:45:50,114 INFO : train_loss 0.03297\n", "2019-05-27 13:45:50,116 INFO : valid_loss 0.50701\n", "2019-05-27 13:45:50,117 INFO : train_misclass 0.00000\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "2019-05-27 13:45:50,119 INFO : valid_misclass 0.16667\n", "2019-05-27 13:45:50,121 INFO : runtime 1.75996\n", "2019-05-27 13:45:50,123 INFO : \n", "2019-05-27 13:45:51,134 INFO : Time only for training updates: 1.01s\n", "2019-05-27 13:45:51,858 INFO : Epoch 18\n", "2019-05-27 13:45:51,860 INFO : train_loss 0.03033\n", "2019-05-27 13:45:51,862 INFO : valid_loss 0.50634\n", "2019-05-27 13:45:51,864 INFO : train_misclass 0.00000\n", "2019-05-27 13:45:51,865 INFO : valid_misclass 0.16667\n", "2019-05-27 13:45:51,867 INFO : runtime 1.74143\n", "2019-05-27 13:45:51,869 INFO : \n", "2019-05-27 13:45:52,946 INFO : Time only for training updates: 1.08s\n", "2019-05-27 13:45:53,671 INFO : Epoch 19\n", "2019-05-27 13:45:53,674 INFO : train_loss 0.02815\n", "2019-05-27 13:45:53,675 INFO : valid_loss 0.50780\n", "2019-05-27 13:45:53,677 INFO : train_misclass 0.00000\n", "2019-05-27 13:45:53,679 INFO : valid_misclass 0.16667\n", "2019-05-27 13:45:53,680 INFO : runtime 1.81164\n", "2019-05-27 13:45:53,682 INFO : \n", "2019-05-27 13:45:54,690 INFO : Time only for training updates: 1.01s\n", "2019-05-27 13:45:55,414 INFO : Epoch 20\n", "2019-05-27 13:45:55,417 INFO : train_loss 0.02614\n", "2019-05-27 13:45:55,418 INFO : valid_loss 0.51158\n", "2019-05-27 13:45:55,420 INFO : train_misclass 0.00000\n", "2019-05-27 13:45:55,422 INFO : valid_misclass 0.16667\n", "2019-05-27 13:45:55,424 INFO : runtime 1.74491\n", "2019-05-27 13:45:55,425 INFO : \n", "2019-05-27 13:45:56,435 INFO : Time only for training updates: 1.01s\n", 
"2019-05-27 13:45:57,163 INFO : Epoch 21\n", "2019-05-27 13:45:57,166 INFO : train_loss 0.02446\n", "2019-05-27 13:45:57,168 INFO : valid_loss 0.51534\n", "2019-05-27 13:45:57,169 INFO : train_misclass 0.00000\n", "2019-05-27 13:45:57,171 INFO : valid_misclass 0.16667\n", "2019-05-27 13:45:57,173 INFO : runtime 1.74495\n", "2019-05-27 13:45:57,174 INFO : \n", "2019-05-27 13:45:58,181 INFO : Time only for training updates: 1.01s\n", "2019-05-27 13:45:58,907 INFO : Epoch 22\n", "2019-05-27 13:45:58,910 INFO : train_loss 0.02319\n", "2019-05-27 13:45:58,911 INFO : valid_loss 0.51801\n", "2019-05-27 13:45:58,913 INFO : train_misclass 0.00000\n", "2019-05-27 13:45:58,915 INFO : valid_misclass 0.16667\n", "2019-05-27 13:45:58,916 INFO : runtime 1.74666\n", "2019-05-27 13:45:58,918 INFO : \n", "2019-05-27 13:45:59,961 INFO : Time only for training updates: 1.04s\n", "2019-05-27 13:46:00,709 INFO : Epoch 23\n", "2019-05-27 13:46:00,711 INFO : train_loss 0.02210\n", "2019-05-27 13:46:00,713 INFO : valid_loss 0.52047\n", "2019-05-27 13:46:00,715 INFO : train_misclass 0.00000\n", "2019-05-27 13:46:00,716 INFO : valid_misclass 0.16667\n", "2019-05-27 13:46:00,718 INFO : runtime 1.77873\n", "2019-05-27 13:46:00,720 INFO : \n", "2019-05-27 13:46:01,750 INFO : Time only for training updates: 1.03s\n", "2019-05-27 13:46:02,475 INFO : Epoch 24\n", "2019-05-27 13:46:02,477 INFO : train_loss 0.02131\n", "2019-05-27 13:46:02,479 INFO : valid_loss 0.52308\n", "2019-05-27 13:46:02,481 INFO : train_misclass 0.00000\n", "2019-05-27 13:46:02,482 INFO : valid_misclass 0.16667\n", "2019-05-27 13:46:02,484 INFO : runtime 1.78904\n", "2019-05-27 13:46:02,486 INFO : \n", "2019-05-27 13:46:03,495 INFO : Time only for training updates: 1.01s\n", "2019-05-27 13:46:04,222 INFO : Epoch 25\n", "2019-05-27 13:46:04,225 INFO : train_loss 0.02070\n", "2019-05-27 13:46:04,227 INFO : valid_loss 0.52473\n", "2019-05-27 13:46:04,228 INFO : train_misclass 0.00000\n", "2019-05-27 13:46:04,230 INFO : 
valid_misclass 0.16667\n", "2019-05-27 13:46:04,232 INFO : runtime 1.74557\n", "2019-05-27 13:46:04,233 INFO : \n", "2019-05-27 13:46:05,245 INFO : Time only for training updates: 1.01s\n", "2019-05-27 13:46:05,973 INFO : Epoch 26\n", "2019-05-27 13:46:05,976 INFO : train_loss 0.02029\n", "2019-05-27 13:46:05,978 INFO : valid_loss 0.52578\n", "2019-05-27 13:46:05,979 INFO : train_misclass 0.00000\n", "2019-05-27 13:46:05,981 INFO : valid_misclass 0.16667\n", "2019-05-27 13:46:05,983 INFO : runtime 1.74980\n", "2019-05-27 13:46:05,984 INFO : \n", "2019-05-27 13:46:07,009 INFO : Time only for training updates: 1.02s\n", "2019-05-27 13:46:07,731 INFO : Epoch 27\n", "2019-05-27 13:46:07,734 INFO : train_loss 0.02002\n", "2019-05-27 13:46:07,736 INFO : valid_loss 0.52576\n", "2019-05-27 13:46:07,737 INFO : train_misclass 0.00000\n", "2019-05-27 13:46:07,739 INFO : valid_misclass 0.16667\n", "2019-05-27 13:46:07,741 INFO : runtime 1.76430\n", "2019-05-27 13:46:07,742 INFO : \n", "2019-05-27 13:46:08,749 INFO : Time only for training updates: 1.01s\n", "2019-05-27 13:46:09,489 INFO : Epoch 28\n", "2019-05-27 13:46:09,491 INFO : train_loss 0.01986\n", "2019-05-27 13:46:09,493 INFO : valid_loss 0.52516\n", "2019-05-27 13:46:09,495 INFO : train_misclass 0.00000\n", "2019-05-27 13:46:09,496 INFO : valid_misclass 0.16667\n", "2019-05-27 13:46:09,498 INFO : runtime 1.73994\n", "2019-05-27 13:46:09,500 INFO : \n", "2019-05-27 13:46:10,512 INFO : Time only for training updates: 1.01s\n", "2019-05-27 13:46:11,233 INFO : Epoch 29\n", "2019-05-27 13:46:11,236 INFO : train_loss 0.01978\n", "2019-05-27 13:46:11,238 INFO : valid_loss 0.52423\n", "2019-05-27 13:46:11,239 INFO : train_misclass 0.00000\n", "2019-05-27 13:46:11,241 INFO : valid_misclass 0.16667\n", "2019-05-27 13:46:11,243 INFO : runtime 1.76238\n", "2019-05-27 13:46:11,244 INFO : \n", "2019-05-27 13:46:12,273 INFO : Time only for training updates: 1.03s\n", "2019-05-27 13:46:12,997 INFO : Epoch 30\n", "2019-05-27 
13:46:13,000 INFO : train_loss 0.01976\n", "2019-05-27 13:46:13,002 INFO : valid_loss 0.52316\n", "2019-05-27 13:46:13,004 INFO : train_misclass 0.00000\n", "2019-05-27 13:46:13,005 INFO : valid_misclass 0.16667\n", "2019-05-27 13:46:13,007 INFO : runtime 1.76154\n", "2019-05-27 13:46:13,009 INFO : \n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model.fit(train_set.X, train_set.y, epochs=30, batch_size=64, scheduler='cosine',\n", " validation_data=(valid_set.X, valid_set.y),)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The monitored values are also stored in a pandas DataFrame:" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "text/plain": [ " train_loss valid_loss train_misclass valid_misclass runtime\n",
"0 5.346649 5.131454 0.475 0.466667 0.000000\n",
"1 1.187472 1.445452 0.450 0.533333 1.740984\n",
"2 0.974734 1.217927 0.400 0.500000 1.777023\n",
"3 0.677576 0.901526 0.300 0.400000 1.748683\n",
"4 0.444574 0.676924 0.275 0.300000 1.755075\n",
"5 0.298528 0.551715 0.200 0.200000 1.771130\n",
"6 0.211601 0.507913 0.100 0.166667 2.211694\n",
"7 0.132909 0.469927 0.050 0.166667 2.728657\n",
"8 0.097908 0.452408 0.000 0.166667 1.762245\n",
"9 0.077446 0.453012 0.000 0.200000 1.841966\n",
"10 0.066468 0.459292 0.000 0.200000 1.813261\n",
"11 0.059477 0.472259 0.000 0.200000 1.749265\n",
"12 0.053263 0.487993 0.000 0.200000 1.757071\n",
"13 0.047088 0.504074 0.000 0.200000 1.747344\n",
"14 0.042468 0.508204 0.000 0.200000 1.875502\n",
"15 0.039186 0.512052 0.000 0.200000 1.756878\n",
"16 0.036201 0.512707 0.000 0.200000 1.763641\n",
"17 0.032971 0.507012 0.000 0.166667 1.759965\n",
"18 0.030331 0.506342 0.000 0.166667 1.741430\n",
"19 0.028149 0.507796 0.000 0.166667 1.811639\n",
"20 0.026143 0.511578 0.000 0.166667 1.744909\n",
"21 0.024460 0.515335 0.000 0.166667 1.744946\n",
"22 0.023193 0.518014 0.000 0.166667 1.746660\n",
"23 0.022102 0.520473 0.000 0.166667 1.778730\n",
"24 0.021306 0.523084 0.000 0.166667 1.789039\n",
"25 0.020703 0.524732 0.000 0.166667 1.745571\n",
"26 0.020286 0.525779 0.000 0.166667 1.749797\n",
"27 0.020023 0.525761 0.000 0.166667 1.764305\n",
"28 0.019861 0.525160 0.000 0.166667 1.739945\n",
"29 0.019783 0.524232 0.000 0.166667 1.762381\n",
"30 0.019764 0.523157 0.000 0.166667 1.761538" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model.epochs_df" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "At the end of training, we arrive at 83.3% validation accuracy, so 25 out of 30 validation trials are correctly predicted. 
In the [Cropped Decoding Tutorial](./Cropped_Decoding.html), you can learn how to achieve higher accuracies using cropped training." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Evaluation" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Once all hyperparameters and architectural choices are fixed, we can compute the accuracies to report in a publication by evaluating on the held-out test set:" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'loss': 0.43049120903015137,\n", " 'misclass': 0.19999999999999996,\n", " 'runtime': 0.00031256675720214844}" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "test_set = SignalAndTarget(X[70:], y=y[70:])\n", "\n", "model.evaluate(test_set.X, test_set.y)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can also retrieve the predicted label for each trial:" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 1, 0, 1, 0, 0, 0, 0, 0, 0])" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model.predict_classes(test_set.X)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can also retrieve the raw network outputs for each trial:\n", "\n", "
\n", "Note that these are log-softmax outputs; to get probabilities, exponentiate them, for example with `np.exp`.\n", "
" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([[-3.108049 , -0.0457173 ],\n", " [-0.18737975, -1.7668451 ],\n", " [-3.550216 , -0.02913891],\n", " [-0.00889281, -4.7269597 ],\n", " [-0.03029956, -3.5117326 ],\n", " [-0.00847233, -4.775178 ],\n", " [-4.006974 , -0.01835574],\n", " [-0.4073605 , -1.0948324 ],\n", " [-0.02217731, -3.819753 ],\n", " [-0.22672828, -1.5952262 ],\n", " [-3.5868406 , -0.02807612],\n", " [-1.3834822 , -0.28862125],\n", " [-0.32644472, -1.2782807 ],\n", " [-1.3229265 , -0.30972955],\n", " [-0.08954818, -2.4574184 ],\n", " [-0.0186951 , -3.9888284 ],\n", " [-0.09142652, -2.4375842 ],\n", " [-0.24392553, -1.5303771 ],\n", " [-0.03591001, -3.3446407 ],\n", " [-0.16686489, -1.8728433 ]], dtype=float32)" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model.predict_outs(test_set.X)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "
\n", "\n", "If you want to try cross-subject decoding, replacing the loading code with the following will perform cross-subject decoding of imagined left vs. right hand opening and closing, with 50 training and 5 validation subjects (warning: this may be very slow on a CPU):\n", "\n", "
" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": true }, "outputs": [], "source": [ "import mne\n", "import numpy as np\n", "from mne.io import concatenate_raws\n", "from braindecode.datautil.signal_target import SignalAndTarget\n", "\n", "# First 50 subjects as train; runs 4, 8, 12 are imagined left vs. right hand\n", "physionet_paths = [mne.datasets.eegbci.load_data(sub_id, [4, 8, 12]) for sub_id in range(1, 51)]\n", "physionet_paths = np.concatenate(physionet_paths)\n", "parts = [mne.io.read_raw_edf(path, preload=True, stim_channel='auto')\n", " for path in physionet_paths]\n", "\n", "raw = concatenate_raws(parts)\n", "\n", "picks = mne.pick_types(raw.info, meg=False, eeg=True, stim=False, eog=False,\n", " exclude='bads')\n", "\n", "# Find the events in this dataset\n", "events, _ = mne.events_from_annotations(raw)\n", "\n", "# Extract trials from 1 s to 4.1 s after each cue, as in the single-subject example\n", "epoched = mne.Epochs(raw, events, dict(left_hand=2, right_hand=3), tmin=1, tmax=4.1, proj=False, picks=picks,\n", " baseline=None, preload=True)\n", "\n", "# Subjects 51-55 as validation subjects\n", "physionet_paths_valid = [mne.datasets.eegbci.load_data(sub_id, [4, 8, 12]) for sub_id in range(51, 56)]\n", "physionet_paths_valid = np.concatenate(physionet_paths_valid)\n", "parts_valid = [mne.io.read_raw_edf(path, preload=True, stim_channel='auto')\n", " for path in physionet_paths_valid]\n", "raw_valid = concatenate_raws(parts_valid)\n", "\n", "picks_valid = mne.pick_types(raw_valid.info, meg=False, eeg=True, stim=False, eog=False,\n", " exclude='bads')\n", "\n", "# Find the events the same way as for the training data\n", "events_valid, _ = mne.events_from_annotations(raw_valid)\n", "\n", "# Extract trials from 1 s to 4.1 s after each cue\n", "epoched_valid = mne.Epochs(raw_valid, events_valid, dict(left_hand=2, right_hand=3), tmin=1, tmax=4.1, proj=False, picks=picks_valid,\n", " baseline=None, preload=True)\n", "\n", "train_X = 
(epoched.get_data() * 1e6).astype(np.float32)\n", "train_y = (epoched.events[:,2] - 2).astype(np.int64) #2,3 -> 0,1\n", "valid_X = (epoched_valid.get_data() * 1e6).astype(np.float32)\n", "valid_y = (epoched_valid.events[:,2] - 2).astype(np.int64) #2,3 -> 0,1\n", "train_set = SignalAndTarget(train_X, y=train_y)\n", "valid_set = SignalAndTarget(valid_X, y=valid_y)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Dataset references\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ " This dataset was created and contributed to PhysioNet by the developers of the [BCI2000](http://www.schalklab.org/research/bci2000) instrumentation system, which they used in making these recordings. The system is described in:\n", " \n", " Schalk, G., McFarland, D.J., Hinterberger, T., Birbaumer, N., Wolpaw, J.R. (2004) BCI2000: A General-Purpose Brain-Computer Interface (BCI) System. IEEE TBME 51(6):1034-1043.\n", "\n", "[PhysioBank](https://physionet.org/physiobank/) is a large and growing archive of well-characterized digital recordings of physiologic signals and related data for use by the biomedical research community and further described in:\n", "\n", " Goldberger AL, Amaral LAN, Glass L, Hausdorff JM, Ivanov PCh, Mark RG, Mietus JE, Moody GB, Peng C-K, Stanley HE. (2000) PhysioBank, PhysioToolkit, and PhysioNet: Components of a New Research Resource for Complex Physiologic Signals. Circulation 101(23):e215-e220." ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.8" } }, "nbformat": 4, "nbformat_minor": 2 }