Amplitude Perturbation Visualization¶
In this tutorial, we show how to perturb the spectral amplitudes of the input signals to find out which input features a trained convolutional network relies on. For more background, see "Deep learning with convolutional neural networks for EEG decoding and visualization" (Schirrmeister et al., 2017), Section A.5.2.
First, we train a network for cross-subject decoding, again using the PhysioBank EEG Motor Movement/Imagery Dataset, this time to decode imagined left hand vs. imagined right hand movement.
Enable logging¶
[2]:
import logging
import importlib
importlib.reload(logging) # see https://stackoverflow.com/a/21475297/1469195
log = logging.getLogger()
log.setLevel('INFO')
import sys
logging.basicConfig(format='%(asctime)s %(levelname)s : %(message)s',
                    level=logging.INFO, stream=sys.stdout)
Load data¶
[3]:
import mne
import numpy as np
from mne.io import concatenate_raws
from braindecode.datautil.signal_target import SignalAndTarget
# First 50 subjects as train
physionet_paths = [ mne.datasets.eegbci.load_data(sub_id,[4,8,12,]) for sub_id in range(1,51)]
physionet_paths = np.concatenate(physionet_paths)
parts = [mne.io.read_raw_edf(path, preload=True, stim_channel='auto')
         for path in physionet_paths]
raw = concatenate_raws(parts)
picks = mne.pick_types(raw.info, meg=False, eeg=True, stim=False, eog=False,
                       exclude='bads')
events = mne.find_events(raw, shortest_event=0, stim_channel='STI 014')
# Read epochs from 1 s to 4.1 s after each cue.
# In runs 4, 8 and 12, event codes 2 and 3 mark imagined left and right hand movement.
epoched = mne.Epochs(raw, events, dict(left_hand=2, right_hand=3), tmin=1, tmax=4.1,
                     proj=False, picks=picks, baseline=None, preload=True)
# 51-55 as validation subjects
physionet_paths_valid = [mne.datasets.eegbci.load_data(sub_id,[4,8,12,]) for sub_id in range(51,56)]
physionet_paths_valid = np.concatenate(physionet_paths_valid)
parts_valid = [mne.io.read_raw_edf(path, preload=True, stim_channel='auto')
               for path in physionet_paths_valid]
raw_valid = concatenate_raws(parts_valid)
picks_valid = mne.pick_types(raw_valid.info, meg=False, eeg=True, stim=False, eog=False,
                             exclude='bads')
events_valid = mne.find_events(raw_valid, shortest_event=0, stim_channel='STI 014')
# Read epochs from 1 s to 4.1 s after each cue (same settings as for training)
epoched_valid = mne.Epochs(raw_valid, events_valid, dict(left_hand=2, right_hand=3),
                           tmin=1, tmax=4.1, proj=False, picks=picks_valid,
                           baseline=None, preload=True)
# MNE returns volts; convert to microvolts and cast to float32 for PyTorch
train_X = (epoched.get_data() * 1e6).astype(np.float32)
# Map event codes 2 (left hand) and 3 (right hand) to class labels 0 and 1
train_y = (epoched.events[:,2] - 2).astype(np.int64)
valid_X = (epoched_valid.get_data() * 1e6).astype(np.float32)
valid_y = (epoched_valid.events[:,2] - 2).astype(np.int64)
train_set = SignalAndTarget(train_X, y=train_y)
valid_set = SignalAndTarget(valid_X, y=valid_y)
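A quick sanity check on the resulting arrays, as a small sketch that assumes the 160 Hz sampling rate of the EEGBCI recordings and MNE's convention of including both epoch endpoints: each trial should then be 497 samples long, long enough to hold the 450-sample crops used below, and the event codes 2 and 3 become class labels 0 and 1.

```python
import numpy as np

# Assumed values: EEGBCI recordings are sampled at 160 Hz; the epochs above
# use tmin=1 and tmax=4.1 seconds relative to each cue.
sfreq = 160.0
tmin, tmax = 1.0, 4.1

# MNE includes both endpoints, so an epoch spans
# round(tmax * sfreq) - round(tmin * sfreq) + 1 samples
n_times = int(round(tmax * sfreq)) - int(round(tmin * sfreq)) + 1
print(n_times)  # 497

# Toy event codes as returned by mne.find_events: 2 = left hand, 3 = right hand
event_codes = np.array([2, 3, 3, 2])
labels = (event_codes - 2).astype(np.int64)
print(labels)  # [0 1 1 0]

# Each trial is longer than one training crop, so several crops fit per trial
input_time_length = 450
assert n_times >= input_time_length
```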
Create the model¶
We use the deep ConvNet from Deep learning with convolutional neural networks for EEG decoding and visualization (Section 2.4.2).
[4]:
from braindecode.models.deep4 import Deep4Net
from torch import nn
from braindecode.torch_ext.util import set_random_seeds
from braindecode.models.util import to_dense_prediction_model
# Set if you want to use GPU
# You can also use torch.cuda.is_available() to determine if cuda is available on your machine.
cuda = True
set_random_seeds(seed=20170629, cuda=cuda)
# This will determine how many crops are processed in parallel
input_time_length = 450
# final_conv_length determines the size of the receptive field of the ConvNet
model = Deep4Net(in_chans=64, n_classes=2, input_time_length=input_time_length,
                 filter_length_3=5, filter_length_4=5,
                 pool_time_stride=2,
                 stride_before_pool=True,
                 final_conv_length=1)
if cuda:
    model.cuda()
[5]:
from braindecode.torch_ext.optimizers import AdamW
import torch.nn.functional as F
optimizer = AdamW(model.parameters(), lr=1*0.01, weight_decay=0.5*0.001) # these are good values for the deep model
model.compile(loss=F.nll_loss, optimizer=optimizer, iterator_seed=1, cropped=True)
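With cropped=True, the network emits one prediction per time step of each crop, while each trial carries a single label. A common way to combine these, and the one we assume here (braindecode's internals may differ in detail), is to average the per-time-step log-probabilities over time and then apply the negative log-likelihood loss. A minimal NumPy sketch with toy sizes:

```python
import numpy as np

def log_softmax(x, axis):
    # Numerically stable log-softmax along the class axis
    shifted = x - np.max(x, axis=axis, keepdims=True)
    return shifted - np.log(np.sum(np.exp(shifted), axis=axis, keepdims=True))

def cropped_nll_loss(log_probs, targets):
    """log_probs: (batch, n_classes, n_preds_per_input); targets: (batch,)."""
    # Average the per-time-step log-probabilities over the time axis ...
    mean_log_probs = log_probs.mean(axis=2)
    # ... then take the usual negative log-likelihood of the trial label
    return -np.mean(mean_log_probs[np.arange(len(targets)), targets])

rng = np.random.RandomState(0)
logits = rng.randn(4, 2, 10)  # toy: 4 crops, 2 classes, 10 predictions per crop
log_probs = log_softmax(logits, axis=1)
targets = np.array([0, 1, 1, 0])
print(cropped_nll_loss(log_probs, targets))
```

With uninformative (all-zero) logits the loss equals log(2), the chance level for two classes.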
Run the training¶
[6]:
input_time_length = 450
model.fit(train_set.X, train_set.y, epochs=30, batch_size=64, scheduler='cosine',
          input_time_length=input_time_length,
          validation_data=(valid_set.X, valid_set.y),)
2018-08-27 12:04:33,164 INFO : Run until first stop...
2018-08-27 12:04:34,651 INFO : Epoch 0
2018-08-27 12:04:34,652 INFO : train_loss 3.30341
2018-08-27 12:04:34,655 INFO : valid_loss 3.40263
2018-08-27 12:04:34,657 INFO : train_misclass 0.51359
2018-08-27 12:04:34,658 INFO : valid_misclass 0.53052
2018-08-27 12:04:34,660 INFO : runtime 0.00000
2018-08-27 12:04:34,662 INFO :
2018-08-27 12:04:36,947 INFO : Time only for training updates: 1.89s
2018-08-27 12:04:38,297 INFO : Epoch 1
2018-08-27 12:04:38,298 INFO : train_loss 0.67976
2018-08-27 12:04:38,299 INFO : valid_loss 0.61942
2018-08-27 12:04:38,300 INFO : train_misclass 0.41846
2018-08-27 12:04:38,301 INFO : valid_misclass 0.30516
2018-08-27 12:04:38,302 INFO : runtime 3.78315
2018-08-27 12:04:38,303 INFO :
2018-08-27 12:04:40,568 INFO : Time only for training updates: 1.88s
2018-08-27 12:04:41,919 INFO : Epoch 2
2018-08-27 12:04:41,920 INFO : train_loss 0.64919
2018-08-27 12:04:41,921 INFO : valid_loss 0.56678
2018-08-27 12:04:41,922 INFO : train_misclass 0.34114
2018-08-27 12:04:41,924 INFO : valid_misclass 0.21127
2018-08-27 12:04:41,925 INFO : runtime 3.62068
2018-08-27 12:04:41,926 INFO :
2018-08-27 12:04:44,187 INFO : Time only for training updates: 1.88s
2018-08-27 12:04:45,484 INFO : Epoch 3
2018-08-27 12:04:45,485 INFO : train_loss 0.63218
2018-08-27 12:04:45,486 INFO : valid_loss 0.54783
2018-08-27 12:04:45,487 INFO : train_misclass 0.31584
2018-08-27 12:04:45,488 INFO : valid_misclass 0.16432
2018-08-27 12:04:45,489 INFO : runtime 3.61918
2018-08-27 12:04:45,490 INFO :
2018-08-27 12:04:47,751 INFO : Time only for training updates: 1.88s
2018-08-27 12:04:49,053 INFO : Epoch 4
2018-08-27 12:04:49,054 INFO : train_loss 0.61976
2018-08-27 12:04:49,055 INFO : valid_loss 0.52507
2018-08-27 12:04:49,056 INFO : train_misclass 0.29475
2018-08-27 12:04:49,057 INFO : valid_misclass 0.15962
2018-08-27 12:04:49,058 INFO : runtime 3.56433
2018-08-27 12:04:49,060 INFO :
2018-08-27 12:04:51,321 INFO : Time only for training updates: 1.88s
2018-08-27 12:04:52,620 INFO : Epoch 5
2018-08-27 12:04:52,621 INFO : train_loss 0.61424
2018-08-27 12:04:52,622 INFO : valid_loss 0.51823
2018-08-27 12:04:52,624 INFO : train_misclass 0.29053
2018-08-27 12:04:52,625 INFO : valid_misclass 0.15023
2018-08-27 12:04:52,626 INFO : runtime 3.56998
2018-08-27 12:04:52,627 INFO :
2018-08-27 12:04:54,927 INFO : Time only for training updates: 1.92s
2018-08-27 12:04:56,223 INFO : Epoch 6
2018-08-27 12:04:56,224 INFO : train_loss 0.64316
2018-08-27 12:04:56,225 INFO : valid_loss 0.54869
2018-08-27 12:04:56,226 INFO : train_misclass 0.32755
2018-08-27 12:04:56,228 INFO : valid_misclass 0.19249
2018-08-27 12:04:56,229 INFO : runtime 3.60571
2018-08-27 12:04:56,230 INFO :
2018-08-27 12:04:58,492 INFO : Time only for training updates: 1.88s
2018-08-27 12:04:59,791 INFO : Epoch 7
2018-08-27 12:04:59,792 INFO : train_loss 0.60664
2018-08-27 12:04:59,793 INFO : valid_loss 0.51988
2018-08-27 12:04:59,794 INFO : train_misclass 0.29100
2018-08-27 12:04:59,795 INFO : valid_misclass 0.11268
2018-08-27 12:04:59,796 INFO : runtime 3.56507
2018-08-27 12:04:59,798 INFO :
2018-08-27 12:05:02,061 INFO : Time only for training updates: 1.88s
2018-08-27 12:05:03,358 INFO : Epoch 8
2018-08-27 12:05:03,359 INFO : train_loss 0.61118
2018-08-27 12:05:03,360 INFO : valid_loss 0.52178
2018-08-27 12:05:03,361 INFO : train_misclass 0.27179
2018-08-27 12:05:03,362 INFO : valid_misclass 0.12207
2018-08-27 12:05:03,364 INFO : runtime 3.56884
2018-08-27 12:05:03,365 INFO :
2018-08-27 12:05:05,627 INFO : Time only for training updates: 1.88s
2018-08-27 12:05:06,930 INFO : Epoch 9
2018-08-27 12:05:06,931 INFO : train_loss 0.61678
2018-08-27 12:05:06,935 INFO : valid_loss 0.52481
2018-08-27 12:05:06,936 INFO : train_misclass 0.29428
2018-08-27 12:05:06,939 INFO : valid_misclass 0.14085
2018-08-27 12:05:06,939 INFO : runtime 3.56590
2018-08-27 12:05:06,943 INFO :
2018-08-27 12:05:09,208 INFO : Time only for training updates: 1.89s
2018-08-27 12:05:10,507 INFO : Epoch 10
2018-08-27 12:05:10,508 INFO : train_loss 0.61286
2018-08-27 12:05:10,511 INFO : valid_loss 0.51615
2018-08-27 12:05:10,512 INFO : train_misclass 0.28913
2018-08-27 12:05:10,513 INFO : valid_misclass 0.14554
2018-08-27 12:05:10,515 INFO : runtime 3.58118
2018-08-27 12:05:10,516 INFO :
2018-08-27 12:05:12,785 INFO : Time only for training updates: 1.89s
2018-08-27 12:05:14,086 INFO : Epoch 11
2018-08-27 12:05:14,087 INFO : train_loss 0.59369
2018-08-27 12:05:14,091 INFO : valid_loss 0.50478
2018-08-27 12:05:14,091 INFO : train_misclass 0.27976
2018-08-27 12:05:14,095 INFO : valid_misclass 0.13615
2018-08-27 12:05:14,095 INFO : runtime 3.57760
2018-08-27 12:05:14,099 INFO :
2018-08-27 12:05:16,369 INFO : Time only for training updates: 1.89s
2018-08-27 12:05:17,691 INFO : Epoch 12
2018-08-27 12:05:17,692 INFO : train_loss 0.60356
2018-08-27 12:05:17,693 INFO : valid_loss 0.53101
2018-08-27 12:05:17,694 INFO : train_misclass 0.26945
2018-08-27 12:05:17,695 INFO : valid_misclass 0.17840
2018-08-27 12:05:17,697 INFO : runtime 3.58372
2018-08-27 12:05:17,698 INFO :
2018-08-27 12:05:19,976 INFO : Time only for training updates: 1.89s
2018-08-27 12:05:21,278 INFO : Epoch 13
2018-08-27 12:05:21,279 INFO : train_loss 0.60062
2018-08-27 12:05:21,280 INFO : valid_loss 0.47819
2018-08-27 12:05:21,281 INFO : train_misclass 0.26148
2018-08-27 12:05:21,282 INFO : valid_misclass 0.13615
2018-08-27 12:05:21,283 INFO : runtime 3.60660
2018-08-27 12:05:21,284 INFO :
2018-08-27 12:05:23,559 INFO : Time only for training updates: 1.89s
2018-08-27 12:05:24,876 INFO : Epoch 14
2018-08-27 12:05:24,877 INFO : train_loss 0.58935
2018-08-27 12:05:24,879 INFO : valid_loss 0.48233
2018-08-27 12:05:24,880 INFO : train_misclass 0.26101
2018-08-27 12:05:24,881 INFO : valid_misclass 0.14085
2018-08-27 12:05:24,882 INFO : runtime 3.58363
2018-08-27 12:05:24,883 INFO :
2018-08-27 12:05:27,163 INFO : Time only for training updates: 1.89s
2018-08-27 12:05:28,477 INFO : Epoch 15
2018-08-27 12:05:28,478 INFO : train_loss 0.58065
2018-08-27 12:05:28,479 INFO : valid_loss 0.49514
2018-08-27 12:05:28,480 INFO : train_misclass 0.24977
2018-08-27 12:05:28,481 INFO : valid_misclass 0.12676
2018-08-27 12:05:28,481 INFO : runtime 3.60388
2018-08-27 12:05:28,482 INFO :
2018-08-27 12:05:30,756 INFO : Time only for training updates: 1.89s
2018-08-27 12:05:32,070 INFO : Epoch 16
2018-08-27 12:05:32,071 INFO : train_loss 0.57086
2018-08-27 12:05:32,072 INFO : valid_loss 0.48204
2018-08-27 12:05:32,073 INFO : train_misclass 0.24321
2018-08-27 12:05:32,073 INFO : valid_misclass 0.09390
2018-08-27 12:05:32,074 INFO : runtime 3.59321
2018-08-27 12:05:32,074 INFO :
2018-08-27 12:05:34,346 INFO : Time only for training updates: 1.89s
2018-08-27 12:05:35,652 INFO : Epoch 17
2018-08-27 12:05:35,653 INFO : train_loss 0.58726
2018-08-27 12:05:35,653 INFO : valid_loss 0.49676
2018-08-27 12:05:35,654 INFO : train_misclass 0.25070
2018-08-27 12:05:35,655 INFO : valid_misclass 0.16432
2018-08-27 12:05:35,655 INFO : runtime 3.58996
2018-08-27 12:05:35,656 INFO :
2018-08-27 12:05:37,934 INFO : Time only for training updates: 1.89s
2018-08-27 12:05:39,250 INFO : Epoch 18
2018-08-27 12:05:39,252 INFO : train_loss 0.56666
2018-08-27 12:05:39,252 INFO : valid_loss 0.46518
2018-08-27 12:05:39,253 INFO : train_misclass 0.23711
2018-08-27 12:05:39,254 INFO : valid_misclass 0.12676
2018-08-27 12:05:39,254 INFO : runtime 3.58806
2018-08-27 12:05:39,255 INFO :
2018-08-27 12:05:41,566 INFO : Time only for training updates: 1.89s
2018-08-27 12:05:42,999 INFO : Epoch 19
2018-08-27 12:05:43,000 INFO : train_loss 0.55883
2018-08-27 12:05:43,000 INFO : valid_loss 0.47641
2018-08-27 12:05:43,001 INFO : train_misclass 0.23805
2018-08-27 12:05:43,005 INFO : valid_misclass 0.10798
2018-08-27 12:05:43,006 INFO : runtime 3.63181
2018-08-27 12:05:43,007 INFO :
2018-08-27 12:05:45,276 INFO : Time only for training updates: 1.89s
2018-08-27 12:05:46,646 INFO : Epoch 20
2018-08-27 12:05:46,647 INFO : train_loss 0.55717
2018-08-27 12:05:46,649 INFO : valid_loss 0.48102
2018-08-27 12:05:46,650 INFO : train_misclass 0.22259
2018-08-27 12:05:46,651 INFO : valid_misclass 0.11737
2018-08-27 12:05:46,652 INFO : runtime 3.71030
2018-08-27 12:05:46,653 INFO :
2018-08-27 12:05:48,932 INFO : Time only for training updates: 1.90s
2018-08-27 12:05:50,236 INFO : Epoch 21
2018-08-27 12:05:50,237 INFO : train_loss 0.55138
2018-08-27 12:05:50,238 INFO : valid_loss 0.47634
2018-08-27 12:05:50,240 INFO : train_misclass 0.21603
2018-08-27 12:05:50,241 INFO : valid_misclass 0.09390
2018-08-27 12:05:50,242 INFO : runtime 3.65521
2018-08-27 12:05:50,243 INFO :
2018-08-27 12:05:52,521 INFO : Time only for training updates: 1.90s
2018-08-27 12:05:53,862 INFO : Epoch 22
2018-08-27 12:05:53,863 INFO : train_loss 0.54738
2018-08-27 12:05:53,864 INFO : valid_loss 0.46885
2018-08-27 12:05:53,866 INFO : train_misclass 0.22165
2018-08-27 12:05:53,867 INFO : valid_misclass 0.08920
2018-08-27 12:05:53,868 INFO : runtime 3.58897
2018-08-27 12:05:53,869 INFO :
2018-08-27 12:05:56,148 INFO : Time only for training updates: 1.90s
2018-08-27 12:05:57,515 INFO : Epoch 23
2018-08-27 12:05:57,517 INFO : train_loss 0.53695
2018-08-27 12:05:57,519 INFO : valid_loss 0.45847
2018-08-27 12:05:57,520 INFO : train_misclass 0.21087
2018-08-27 12:05:57,521 INFO : valid_misclass 0.10798
2018-08-27 12:05:57,522 INFO : runtime 3.62732
2018-08-27 12:05:57,524 INFO :
2018-08-27 12:05:59,804 INFO : Time only for training updates: 1.90s
2018-08-27 12:06:01,105 INFO : Epoch 24
2018-08-27 12:06:01,106 INFO : train_loss 0.53602
2018-08-27 12:06:01,107 INFO : valid_loss 0.46697
2018-08-27 12:06:01,108 INFO : train_misclass 0.20993
2018-08-27 12:06:01,109 INFO : valid_misclass 0.11268
2018-08-27 12:06:01,111 INFO : runtime 3.65615
2018-08-27 12:06:01,112 INFO :
2018-08-27 12:06:03,392 INFO : Time only for training updates: 1.90s
2018-08-27 12:06:04,690 INFO : Epoch 25
2018-08-27 12:06:04,691 INFO : train_loss 0.52928
2018-08-27 12:06:04,693 INFO : valid_loss 0.46500
2018-08-27 12:06:04,694 INFO : train_misclass 0.21134
2018-08-27 12:06:04,695 INFO : valid_misclass 0.12676
2018-08-27 12:06:04,696 INFO : runtime 3.58741
2018-08-27 12:06:04,697 INFO :
2018-08-27 12:06:06,978 INFO : Time only for training updates: 1.90s
2018-08-27 12:06:08,274 INFO : Epoch 26
2018-08-27 12:06:08,275 INFO : train_loss 0.53385
2018-08-27 12:06:08,276 INFO : valid_loss 0.45885
2018-08-27 12:06:08,277 INFO : train_misclass 0.21649
2018-08-27 12:06:08,278 INFO : valid_misclass 0.12207
2018-08-27 12:06:08,279 INFO : runtime 3.58615
2018-08-27 12:06:08,280 INFO :
2018-08-27 12:06:10,560 INFO : Time only for training updates: 1.90s
2018-08-27 12:06:11,854 INFO : Epoch 27
2018-08-27 12:06:11,855 INFO : train_loss 0.52531
2018-08-27 12:06:11,856 INFO : valid_loss 0.45631
2018-08-27 12:06:11,857 INFO : train_misclass 0.20337
2018-08-27 12:06:11,858 INFO : valid_misclass 0.12207
2018-08-27 12:06:11,860 INFO : runtime 3.58260
2018-08-27 12:06:11,861 INFO :
2018-08-27 12:06:14,147 INFO : Time only for training updates: 1.90s
2018-08-27 12:06:15,447 INFO : Epoch 28
2018-08-27 12:06:15,448 INFO : train_loss 0.52515
2018-08-27 12:06:15,450 INFO : valid_loss 0.45932
2018-08-27 12:06:15,451 INFO : train_misclass 0.20197
2018-08-27 12:06:15,452 INFO : valid_misclass 0.11268
2018-08-27 12:06:15,453 INFO : runtime 3.58636
2018-08-27 12:06:15,454 INFO :
2018-08-27 12:06:17,739 INFO : Time only for training updates: 1.90s
2018-08-27 12:06:19,094 INFO : Epoch 29
2018-08-27 12:06:19,095 INFO : train_loss 0.52143
2018-08-27 12:06:19,099 INFO : valid_loss 0.45449
2018-08-27 12:06:19,099 INFO : train_misclass 0.19541
2018-08-27 12:06:19,103 INFO : valid_misclass 0.11737
2018-08-27 12:06:19,103 INFO : runtime 3.59220
2018-08-27 12:06:19,106 INFO :
2018-08-27 12:06:21,387 INFO : Time only for training updates: 1.90s
2018-08-27 12:06:22,693 INFO : Epoch 30
2018-08-27 12:06:22,694 INFO : train_loss 0.51937
2018-08-27 12:06:22,698 INFO : valid_loss 0.45642
2018-08-27 12:06:22,698 INFO : train_misclass 0.19588
2018-08-27 12:06:22,702 INFO : valid_misclass 0.11737
2018-08-27 12:06:22,702 INFO : runtime 3.64778
2018-08-27 12:06:22,704 INFO :
[6]:
<braindecode.experiments.experiment.Experiment at 0x7f0a08f9c2b0>
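The fit call passes scheduler='cosine'. As a rough illustration, assuming the schedule is standard cosine annealing from the initial value down to zero over all training updates (braindecode's exact implementation, and whether it also scales the AdamW weight decay, may differ), the learning rate would evolve as in this sketch. The number of updates per epoch is hypothetical.

```python
import numpy as np

def cosine_annealing(initial_val, i_update, n_updates):
    """Value after i_update of n_updates updates under cosine annealing to zero."""
    fraction = i_update / float(n_updates)
    return initial_val * 0.5 * (1.0 + np.cos(np.pi * fraction))

initial_lr = 1 * 0.01
n_updates = 30 * 100  # hypothetical: 30 epochs x 100 batches per epoch
for i_update in [0, n_updates // 4, n_updates // 2, n_updates]:
    print(i_update, cosine_annealing(initial_lr, i_update, n_updates))
```

The learning rate stays near its initial value early on, halves at the midpoint and approaches zero at the end, which lets the last epochs make only small refinements.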
Compute correlation: amplitude perturbation - prediction change¶
First collect all batches and concatenate them into one array of examples:
[7]:
from braindecode.datautil.iterators import CropsFromTrialsIterator
from braindecode.torch_ext.util import np_to_var
# Pass a dummy input through the network to find out how many predictions
# it outputs per input window of input_time_length samples
test_input = np_to_var(np.ones((2, 64, input_time_length, 1), dtype=np.float32))
if cuda:
    test_input = test_input.cuda()
out = model.network(test_input)
n_preds_per_input = out.cpu().data.numpy().shape[2]
iterator = CropsFromTrialsIterator(batch_size=32, input_time_length=input_time_length,
                                   n_preds_per_input=n_preds_per_input)
train_batches = list(iterator.get_batches(train_set, shuffle=False))
# Each batch is an (X, y) pair; keep only the inputs and stack them
train_X_batches = np.concatenate(list(zip(*train_batches))[0])
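CropsFromTrialsIterator cuts each trial into overlapping crops of input_time_length samples. One plausible tiling scheme, sketched below with hypothetical helpers (crop_starts and predicted_samples are not braindecode's code), shifts consecutive crops by n_preds_per_input samples and assumes each crop's outputs belong to its last n_preds_per_input samples, so that together the crops cover every sample after the receptive field is filled. The value 30 for n_preds_per_input is made up; the real value is whatever the dummy forward pass above returns.

```python
import numpy as np

def crop_starts(trial_len, input_time_length, n_preds_per_input):
    """Hypothetical helper: start indices of crops whose predictions tile a trial.

    Consecutive crops are shifted by n_preds_per_input samples; an extra crop
    aligned to the trial end is appended if the last one falls short.
    """
    starts = list(range(0, trial_len - input_time_length + 1, n_preds_per_input))
    last_start = trial_len - input_time_length
    if starts[-1] != last_start:
        starts.append(last_start)
    return starts

def predicted_samples(starts, input_time_length, n_preds_per_input):
    """Sample indices covered by the last n_preds_per_input outputs of each crop."""
    covered = set()
    for s in starts:
        stop = s + input_time_length
        covered.update(range(stop - n_preds_per_input, stop))
    return covered

# Toy numbers: 497-sample trials, 450-sample crops, hypothetical 30 predictions per crop
starts = crop_starts(497, 450, 30)
print(starts)  # [0, 30, 47]

trial = np.zeros((64, 497), dtype=np.float32)  # (channels, time)
crops = np.stack([trial[:, s:s + 450] for s in starts])
print(crops.shape)  # (3, 64, 450)
```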
Next, create a prediction function that wraps the model and returns its predictions as numpy arrays. We use the predictions before the softmax, so we build a new module containing all layers of the original network up to, but not including, the softmax.
[8]:
from torch import nn
from braindecode.torch_ext.util import var_to_np
import torch as th
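# Copy all layers of the trained network up to (not including) the softmax
new_model = nn.Sequential()
for name, module in model.network.named_children():
    if name == 'softmax':
        break
    new_model.add_module(name, module)
new_model.eval();
def pred_fn(x):
    """Return pre-softmax outputs for a numpy batch x, averaged over time."""
    x_var = np_to_var(x)
    if cuda:
        x_var = x_var.cuda()
    # Output shape: (batch, n_classes, n_preds_per_input, 1) -> (batch, n_classes)
    return var_to_np(th.mean(new_model(x_var)[:, :, :, 0], dim=2, keepdim=False))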
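pred_fn condenses the dense network output into one score per class and crop. The NumPy sketch below replays the same reduction on a random array of the shape we assume new_model returns, (batch, n_classes, n_preds_per_input, 1); the sizes 30 and 40 are arbitrary.

```python
import numpy as np

rng = np.random.RandomState(0)
# Toy stand-in for the pre-softmax network output of one batch:
# (batch, n_classes, n_preds_per_input, 1)
raw_out = rng.randn(30, 2, 40, 1).astype(np.float32)

# Drop the trailing singleton axis, then average over the prediction (time) axis
preds = raw_out[:, :, :, 0].mean(axis=2)
print(preds.shape)  # (30, 2)
```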
[16]:
from braindecode.visualization.perturbation import compute_amplitude_prediction_correlations
amp_pred_corrs = compute_amplitude_prediction_correlations(pred_fn, train_X_batches, n_iterations=12,
                                                           batch_size=30)
2018-08-08 12:19:10,049 INFO : Compute original predictions...
2018-08-08 12:19:27,871 INFO : Iteration 0...
2018-08-08 12:19:27,873 INFO : Sample perturbation...
2018-08-08 12:19:37,305 INFO : Compute perturbed complex inputs...
2018-08-08 12:19:50,788 INFO : Compute perturbed real inputs...
2018-08-08 12:20:08,863 INFO : Compute new predictions...
2018-08-08 12:20:09,609 INFO : Layer 0...
2018-08-08 12:20:09,610 INFO : Compute activation difference...
2018-08-08 12:20:09,611 INFO : Compute correlation...
2018-08-08 12:20:13,067 INFO : Iteration 1...
2018-08-08 12:20:13,068 INFO : Sample perturbation...
2018-08-08 12:20:20,858 INFO : Compute perturbed complex inputs...
2018-08-08 12:20:30,561 INFO : Compute perturbed real inputs...
2018-08-08 12:20:45,390 INFO : Compute new predictions...
2018-08-08 12:20:46,164 INFO : Layer 0...
2018-08-08 12:20:46,165 INFO : Compute activation difference...
2018-08-08 12:20:46,166 INFO : Compute correlation...
2018-08-08 12:20:49,880 INFO : Iteration 2...
2018-08-08 12:20:49,880 INFO : Sample perturbation...
2018-08-08 12:20:59,277 INFO : Compute perturbed complex inputs...
2018-08-08 12:21:11,041 INFO : Compute perturbed real inputs...
2018-08-08 12:21:32,809 INFO : Compute new predictions...
2018-08-08 12:21:33,628 INFO : Layer 0...
2018-08-08 12:21:33,629 INFO : Compute activation difference...
2018-08-08 12:21:33,631 INFO : Compute correlation...
2018-08-08 12:21:37,634 INFO : Iteration 3...
2018-08-08 12:21:37,635 INFO : Sample perturbation...
2018-08-08 12:21:47,688 INFO : Compute perturbed complex inputs...
2018-08-08 12:21:59,979 INFO : Compute perturbed real inputs...
2018-08-08 12:22:21,661 INFO : Compute new predictions...
2018-08-08 12:22:22,430 INFO : Layer 0...
2018-08-08 12:22:22,431 INFO : Compute activation difference...
2018-08-08 12:22:22,432 INFO : Compute correlation...
2018-08-08 12:22:26,532 INFO : Iteration 4...
2018-08-08 12:22:26,533 INFO : Sample perturbation...
2018-08-08 12:22:36,511 INFO : Compute perturbed complex inputs...
2018-08-08 12:22:48,632 INFO : Compute perturbed real inputs...
2018-08-08 12:23:05,148 INFO : Compute new predictions...
2018-08-08 12:23:05,920 INFO : Layer 0...
2018-08-08 12:23:05,921 INFO : Compute activation difference...
2018-08-08 12:23:05,922 INFO : Compute correlation...
2018-08-08 12:23:08,494 INFO : Iteration 5...
2018-08-08 12:23:08,495 INFO : Sample perturbation...
2018-08-08 12:23:17,066 INFO : Compute perturbed complex inputs...
2018-08-08 12:23:29,459 INFO : Compute perturbed real inputs...
2018-08-08 12:23:51,135 INFO : Compute new predictions...
2018-08-08 12:23:51,910 INFO : Layer 0...
2018-08-08 12:23:51,911 INFO : Compute activation difference...
2018-08-08 12:23:51,912 INFO : Compute correlation...
2018-08-08 12:23:52,788 INFO : Iteration 6...
2018-08-08 12:23:52,789 INFO : Sample perturbation...
2018-08-08 12:23:59,554 INFO : Compute perturbed complex inputs...
2018-08-08 12:24:11,074 INFO : Compute perturbed real inputs...
2018-08-08 12:24:31,349 INFO : Compute new predictions...
2018-08-08 12:24:32,134 INFO : Layer 0...
2018-08-08 12:24:32,135 INFO : Compute activation difference...
2018-08-08 12:24:32,135 INFO : Compute correlation...
2018-08-08 12:24:36,217 INFO : Iteration 7...
2018-08-08 12:24:36,218 INFO : Sample perturbation...
2018-08-08 12:24:45,191 INFO : Compute perturbed complex inputs...
2018-08-08 12:24:55,820 INFO : Compute perturbed real inputs...
2018-08-08 12:25:13,979 INFO : Compute new predictions...
2018-08-08 12:25:14,761 INFO : Layer 0...
2018-08-08 12:25:14,762 INFO : Compute activation difference...
2018-08-08 12:25:14,763 INFO : Compute correlation...
2018-08-08 12:25:17,312 INFO : Iteration 8...
2018-08-08 12:25:17,313 INFO : Sample perturbation...
2018-08-08 12:25:25,635 INFO : Compute perturbed complex inputs...
2018-08-08 12:25:37,586 INFO : Compute perturbed real inputs...
2018-08-08 12:25:59,839 INFO : Compute new predictions...
2018-08-08 12:26:00,618 INFO : Layer 0...
2018-08-08 12:26:00,619 INFO : Compute activation difference...
2018-08-08 12:26:00,620 INFO : Compute correlation...
2018-08-08 12:26:01,544 INFO : Iteration 9...
2018-08-08 12:26:01,545 INFO : Sample perturbation...
2018-08-08 12:26:07,447 INFO : Compute perturbed complex inputs...
2018-08-08 12:26:17,042 INFO : Compute perturbed real inputs...
2018-08-08 12:26:31,921 INFO : Compute new predictions...
2018-08-08 12:26:32,650 INFO : Layer 0...
2018-08-08 12:26:32,651 INFO : Compute activation difference...
2018-08-08 12:26:32,653 INFO : Compute correlation...
2018-08-08 12:26:35,548 INFO : Iteration 10...
2018-08-08 12:26:35,549 INFO : Sample perturbation...
2018-08-08 12:26:42,538 INFO : Compute perturbed complex inputs...
2018-08-08 12:26:51,004 INFO : Compute perturbed real inputs...
2018-08-08 12:27:04,210 INFO : Compute new predictions...
2018-08-08 12:27:04,980 INFO : Layer 0...
2018-08-08 12:27:04,981 INFO : Compute activation difference...
2018-08-08 12:27:04,982 INFO : Compute correlation...
2018-08-08 12:27:07,018 INFO : Iteration 11...
2018-08-08 12:27:07,019 INFO : Sample perturbation...
2018-08-08 12:27:14,177 INFO : Compute perturbed complex inputs...
2018-08-08 12:27:24,518 INFO : Compute perturbed real inputs...
2018-08-08 12:27:40,668 INFO : Compute new predictions...
2018-08-08 12:27:41,417 INFO : Layer 0...
2018-08-08 12:27:41,418 INFO : Compute activation difference...
2018-08-08 12:27:41,419 INFO : Compute correlation...
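compute_amplitude_prediction_correlations hides the mechanics behind the log messages above. The NumPy sketch below is our own simplified reimplementation, not braindecode's code. We assume additive Gaussian noise on the Fourier amplitudes (clipped at zero), unchanged phases, and a Pearson correlation across examples between the amplitude change and the prediction change, averaged over iterations; the noise scale, batching and handling of the trailing singleton input axis in the real function may differ. The result has one value per channel, frequency bin and class, like the (64, 226, 2) array inspected in the next section.

```python
import numpy as np

def amplitude_perturbation_correlations(pred_fn, X, n_iterations=3, noise_std=1.0, seed=0):
    """Simplified sketch of amplitude-perturbation correlations.

    X: real array of shape (n_examples, n_chans, n_times).
    pred_fn: maps such an array to predictions of shape (n_examples, n_classes).
    Returns the mean correlation, shape (n_chans, n_freqs, n_classes).
    """
    rng = np.random.RandomState(seed)
    orig_preds = pred_fn(X)
    spectrum = np.fft.rfft(X, axis=2)
    amps, phases = np.abs(spectrum), np.angle(spectrum)
    corrs = []
    for _ in range(n_iterations):
        # Add Gaussian noise to the amplitudes (clipped at zero), keep phases fixed
        new_amps = np.maximum(amps + noise_std * rng.randn(*amps.shape), 0)
        pert = new_amps - amps  # effective perturbation after clipping
        X_pert = np.fft.irfft(new_amps * np.exp(1j * phases), n=X.shape[2], axis=2)
        pred_diff = pred_fn(X_pert) - orig_preds
        # Pearson correlation across examples for every (channel, freq, class)
        p = (pert - pert.mean(axis=0)) / pert.std(axis=0)
        d = (pred_diff - pred_diff.mean(axis=0)) / pred_diff.std(axis=0)
        corrs.append(np.einsum('ecf,ek->cfk', p, d) / X.shape[0])
    return np.mean(corrs, axis=0)

# Toy check: class 0 reads out the amplitude of channel 0 at frequency bin 5
def toy_pred_fn(X):
    amp = np.abs(np.fft.rfft(X, axis=2))[:, 0, 5]
    return np.stack([amp, -amp], axis=1)

X = np.random.RandomState(1).randn(200, 3, 64)
corrs = amplitude_perturbation_correlations(toy_pred_fn, X)
print(corrs.shape)  # (3, 33, 2)
print(corrs[0, 5])  # close to [1, -1]
```

In the toy check only the amplitude of channel 0 at bin 5 affects the output, so only that entry correlates strongly (positively for class 0, negatively for class 1) while all other entries stay near zero.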
Plot correlations¶
Pick one frequency range and average the correlations within it to obtain one value per channel and class for a scalp plot. Here we use the alpha band, taken as 7 to 14 Hz in the code below.
[19]:
amp_pred_corrs.shape
[19]:
(64, 226, 2)
[20]:
fs = epoched.info['sfreq']
freqs = np.fft.rfftfreq(train_X_batches.shape[2], d=1.0/fs)
start_freq = 7
stop_freq = 14
i_start = np.searchsorted(freqs,start_freq)
i_stop = np.searchsorted(freqs, stop_freq) + 1
freq_corr = np.mean(amp_pred_corrs[:,i_start:i_stop], axis=1)
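To see which Fourier bins end up in this average, the sketch below repeats the frequency computation for a single 450-sample crop, assuming the 160 Hz EEGBCI sampling rate. The 226 bins match the second axis of amp_pred_corrs, and the bin spacing is fs / n_times.

```python
import numpy as np

fs = 160.0       # assumed EEGBCI sampling rate
n_times = 450    # crop length (input_time_length)
freqs = np.fft.rfftfreq(n_times, d=1.0 / fs)
print(len(freqs))  # 226
print(freqs[1])    # bin spacing, about 0.356 Hz

i_start = np.searchsorted(freqs, 7)
i_stop = np.searchsorted(freqs, 14) + 1
print(i_start, i_stop)                     # 20 41
print(freqs[i_start], freqs[i_stop - 1])   # about 7.11 Hz and 14.22 Hz
```

Because of the `+ 1`, the slice includes the first bin at or above 14 Hz, so the averaged band reaches slightly beyond 14 Hz.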
Now get approximate positions of the channels in the 10-20 system.
[21]:
from braindecode.datasets.sensor_positions import get_channelpos, CHANNEL_10_20_APPROX
ch_names = [s.strip('.') for s in epoched.ch_names]
positions = [get_channelpos(name, CHANNEL_10_20_APPROX) for name in ch_names]
positions = np.array(positions)
Plot with MNE¶
[22]:
import matplotlib.pyplot as plt
from matplotlib import cm
%matplotlib inline
max_abs_val = np.max(np.abs(freq_corr))
[23]:
fig, axes = plt.subplots(1, 2)
class_names = ['Left Hand', 'Right Hand']
for i_class in range(2):
    ax = axes[i_class]
    mne.viz.plot_topomap(freq_corr[:,i_class], positions,
                         vmin=-max_abs_val, vmax=max_abs_val, contours=0,
                         cmap=cm.coolwarm, axes=ax, show=False);
    ax.set_title(class_names[i_class])
Plot with Braindecode¶
[24]:
from braindecode.visualization.plot import ax_scalp
fig, axes = plt.subplots(1, 2)
class_names = ['Left Hand', 'Right Hand']
for i_class in range(2):
    ax = axes[i_class]
    ax_scalp(freq_corr[:,i_class], ch_names, chan_pos_list=CHANNEL_10_20_APPROX, cmap=cm.coolwarm,
             vmin=-max_abs_val, vmax=max_abs_val, ax=ax)
    ax.set_title(class_names[i_class])
From these plots we can see that the ConvNet clearly learned to use the lateralized response in the alpha band. A positive correlation at an electrode means that increasing the alpha amplitude there tends to increase the network's output for that class. Note, however, that the positive correlations for the left hand on the left side do not imply an increase of alpha activity for the left hand in the data; see "Deep learning with convolutional neural networks for EEG decoding and visualization", Result 12, for notes on interpretability.
Dataset references¶
This dataset was created and contributed to PhysioNet by the developers of the BCI2000 instrumentation system, which they used in making these recordings. The system is described in:
Schalk, G., McFarland, D.J., Hinterberger, T., Birbaumer, N., Wolpaw, J.R. (2004) BCI2000: A General-Purpose Brain-Computer Interface (BCI) System. IEEE TBME 51(6):1034-1043.
PhysioBank is a large and growing archive of well-characterized digital recordings of physiologic signals and related data for use by the biomedical research community; it is further described in:
Goldberger AL, Amaral LAN, Glass L, Hausdorff JM, Ivanov PCh, Mark RG, Mietus JE, Moody GB, Peng C-K, Stanley HE. (2000) PhysioBank, PhysioToolkit, and PhysioNet: Components of a New Research Resource for Complex Physiologic Signals. Circulation 101(23):e215-e220.