{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": { "collapsed": true, "nbsphinx": "hidden" }, "outputs": [], "source": [ "%load_ext autoreload\n", "%autoreload 2\n", "import os\n", "os.sys.path.insert(0, '/home/schirrmr/braindecode/code/braindecode/')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Cropped Decoding" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now we will use cropped decoding. Cropped decoding means the ConvNet is trained on time windows/time crops within the trials. We will explain this visually by comparing trialwise to cropped decoding.\n", "\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Trialwise Decoding | Cropped Decoding\n", "- | - \n", "![Trialwise Decoding](./trialwise_explanation.png \"Trialwise Decoding\") | ![Cropped Decoding](./cropped_explanation.png \"Cropped Decoding\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "On the left, you see trialwise decoding:\n", "\n", "1. A complete trial is pushed through the network\n", "2. The network produces a prediction\n", "3. The prediction is compared to the target (label) for that trial to compute the loss\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "On the right, you see cropped decoding:\n", "\n", "1. Instead of a complete trial, windows within the trial, here called *crops*, are pushed through the network\n", "2. For computational efficiency, multiple neighbouring crops are pushed through the network simultaneously (these neighbouring crops are called a *supercrop*)\n", "3. Therefore, the network produces multiple predictions (one per crop in the supercrop)\n", "4. 
The individual crop predictions are averaged before computing the loss function\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Notes:\n", "\n", "* The network architecture implicitly defines the crop size (it is the receptive field size, i.e., the number of timesteps the network uses to make a single prediction)\n", "* The supercrop size is a user-defined hyperparameter, called `input_time_length` in Braindecode. It mostly affects runtime (larger supercrop sizes should be faster). As a rule of thumb, you can set it to two times the crop size.\n", "* Crop size and supercrop size together define how many predictions the network makes per supercrop: $n_{\\mathrm{predictions}} = n_{\\mathrm{supercrop}} - n_{\\mathrm{crop}} + 1$, where $n_{\\mathrm{supercrop}}$ and $n_{\\mathrm{crop}}$ are the supercrop and crop sizes in timesteps" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "For cropped decoding, the above training setup is mathematically identical to sampling crops in your dataset, pushing them through the network and training directly on the individual crops. At the same time, the above training setup is much faster, as it avoids redundant computations by using dilated convolutions; see our paper [Deep learning with convolutional neural networks for EEG decoding and visualization](https://arxiv.org/abs/1703.05051). However, the two setups are only mathematically identical if (1) your network does not use any padding and (2) your loss function leads to the same gradients when using the averaged output. The first is true for our shallow and deep ConvNet models, and the second is true for the log-softmax outputs and negative log-likelihood loss that are typically used for classification in PyTorch." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Most of the code for cropped decoding is identical to the [Trialwise Decoding Tutorial](Trialwise_Decoding.html); differences are explained in the text."
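The crop/supercrop arithmetic and the averaging of crop predictions before the loss can be sketched in plain NumPy. The sizes below are made-up illustrative numbers, not those of a specific Braindecode model:

```python
import numpy as np

# Hypothetical sizes, for illustration only:
crop_size = 522           # implicitly defined by the network (receptive field)
input_time_length = 1044  # supercrop size, a user-defined hyperparameter

# Number of predictions per supercrop: n_supercrop - n_crop + 1
n_predictions = input_time_length - crop_size + 1  # -> 523

# Fake per-crop log-softmax outputs for a 2-class problem
rng = np.random.RandomState(0)
logits = rng.randn(n_predictions, 2)
log_probs = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))

# The individual crop predictions are averaged before the loss is computed
mean_log_probs = log_probs.mean(axis=0)

# Negative log likelihood loss for a (fake) target label
target = 0
loss = -mean_log_probs[target]
```

The averaging of log-softmax outputs is what makes this setup yield the same gradients as training on individual crops, as discussed above.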
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Enable logging" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "collapsed": true }, "outputs": [], "source": [ "import logging\n", "import importlib\n", "import sys\n", "importlib.reload(logging)  # see https://stackoverflow.com/a/21475297/1469195\n", "log = logging.getLogger()\n", "log.setLevel('INFO')\n", "logging.basicConfig(format='%(asctime)s %(levelname)s : %(message)s',\n", "                    level=logging.INFO, stream=sys.stdout)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Load data" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "collapsed": true }, "outputs": [], "source": [ "import mne\n", "from mne.io import concatenate_raws\n", "\n", "# 5,6,9,10,13,14 are codes for executed and imagined hands/feet runs\n", "subject_id = 22\n", "event_codes = [5, 6, 9, 10, 13, 14]\n", "#event_codes = [3,4,5,6,7,8,9,10,11,12,13,14]\n", "\n", "# This will download the files if you don't have them yet,\n", "# and then return the paths to the files.\n", "physionet_paths = mne.datasets.eegbci.load_data(subject_id, event_codes)\n", "\n", "# Load each of the files\n", "parts = [mne.io.read_raw_edf(path, preload=True, stim_channel='auto', verbose='WARNING')\n", "         for path in physionet_paths]\n", "\n", "# Concatenate them\n", "raw = concatenate_raws(parts)\n", "\n", "# Find the events in this dataset\n", "events, _ = mne.events_from_annotations(raw)\n", "\n", "# Use only EEG channels\n", "eeg_channel_inds = mne.pick_types(raw.info, meg=False, eeg=True, stim=False, eog=False,\n", "                                  exclude='bads')\n", "\n", "# Extract trials, only using EEG channels\n", "epoched = mne.Epochs(raw, events, dict(hands_or_left=2, feet_or_right=3), tmin=1, tmax=4.1,\n", "                     proj=False, picks=eeg_channel_inds, baseline=None, preload=True)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Convert data to Braindecode format" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "collapsed": 
true }, "outputs": [], "source": [ "import numpy as np\n", "# Convert data from volt to microvolt\n", "# PyTorch expects float32 for inputs and int64 for labels.\n", "X = (epoched.get_data() * 1e6).astype(np.float32)\n", "y = (epoched.events[:, 2] - 2).astype(np.int64)  # 2,3 -> 0,1" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "collapsed": true }, "outputs": [], "source": [ "from braindecode.datautil.signal_target import SignalAndTarget\n", "\n", "train_set = SignalAndTarget(X[:40], y=y[:40])\n", "valid_set = SignalAndTarget(X[40:70], y=y[40:70])\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Create the model" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "
\n",
"| epoch | train_loss | valid_loss | train_misclass | valid_misclass | runtime |\n",
"|---|---|---|---|---|---|\n",
"| 0 | 16.718943 | 16.038786 | 0.525 | 0.533333 | 0.000000 |\n",
"| 1 | 5.756664 | 5.052062 | 0.525 | 0.533333 | 6.956634 |\n",
"| 2 | 3.647677 | 3.016214 | 0.500 | 0.533333 | 6.908031 |\n",
"| 3 | 2.342561 | 1.909614 | 0.500 | 0.533333 | 7.240101 |\n",
"| 4 | 1.506759 | 1.290458 | 0.475 | 0.500000 | 7.399405 |\n",
"| 5 | 1.019237 | 0.968446 | 0.400 | 0.433333 | 7.065770 |\n",
"| 6 | 0.758563 | 0.821283 | 0.350 | 0.433333 | 7.214281 |\n",
"| 7 | 0.611557 | 0.758545 | 0.300 | 0.433333 | 5.655106 |\n",
"| 8 | 0.518917 | 0.740452 | 0.250 | 0.366667 | 3.405795 |\n",
"| 9 | 0.433832 | 0.710416 | 0.200 | 0.300000 | 3.383932 |\n",
"| 10 | 0.342782 | 0.652211 | 0.175 | 0.266667 | 3.303478 |\n",
"| 11 | 0.270821 | 0.600610 | 0.150 | 0.266667 | 3.260772 |\n",
"| 12 | 0.218609 | 0.563165 | 0.025 | 0.266667 | 3.459012 |\n",
"| 13 | 0.181991 | 0.535691 | 0.000 | 0.233333 | 3.346393 |\n",
"| 14 | 0.154839 | 0.508405 | 0.000 | 0.200000 | 3.260915 |\n",
"| 15 | 0.133887 | 0.480797 | 0.000 | 0.166667 | 3.267457 |\n",
"| 16 | 0.117386 | 0.454011 | 0.000 | 0.133333 | 3.318306 |\n",
"| 17 | 0.104724 | 0.430165 | 0.000 | 0.166667 | 3.256755 |\n",
"| 18 | 0.094064 | 0.406926 | 0.000 | 0.166667 | 3.273476 |\n",
"| 19 | 0.085207 | 0.386002 | 0.000 | 0.133333 | 3.268065 |\n",
"| 20 | 0.077945 | 0.367422 | 0.000 | 0.133333 | 3.257596 |\n",
"| 21 | 0.071709 | 0.350847 | 0.000 | 0.100000 | 3.971547 |\n",
"| 22 | 0.066811 | 0.337260 | 0.000 | 0.100000 | 4.507347 |\n",
"| 23 | 0.062891 | 0.325995 | 0.000 | 0.100000 | 6.996382 |\n",
"| 24 | 0.059993 | 0.317353 | 0.000 | 0.100000 | 6.871776 |\n",
"| 25 | 0.057849 | 0.310633 | 0.000 | 0.100000 | 6.769685 |\n",
"| 26 | 0.056350 | 0.305617 | 0.000 | 0.100000 | 6.659091 |\n",
"| 27 | 0.055315 | 0.301796 | 0.000 | 0.100000 | 6.849715 |\n",
"| 28 | 0.054613 | 0.298791 | 0.000 | 0.100000 | 6.767460 |\n",
"| 29 | 0.054149 | 0.296389 | 0.000 | 0.100000 | 6.771852 |\n",
"| 30 | 0.053844 | 0.294413 | 0.000 | 0.100000 | 6.877121 |"
] } ], "metadata": {}, "nbformat": 4, "nbformat_minor": 2 }