braindecode.experiments package

Convenience classes for experiments, including monitoring and stop criteria.

Submodules

braindecode.experiments.experiment module

class braindecode.experiments.experiment.RememberBest(column_name)

Bases: object

Class to remember and restore the parameters of the model and the parameters of the optimizer at the epoch with the best performance.

Parameters: column_name (str) – The lowest value in this column should indicate the epoch with the best performance (e.g. misclass might make sense).
Variables: best_epoch (int) – Index of the best epoch.
remember_epoch(epochs_df, model, optimizer)

Remember this epoch: store the parameter values if this epoch has the best performance so far.

Parameters:
  • epochs_df (pandas.DataFrame) – DataFrame containing the column column_name with which performance is evaluated.
  • model (torch.nn.Module)
  • optimizer (torch.optim.Optimizer)
reset_to_best_model(epochs_df, model, optimizer)

Reset the parameters to those at the best epoch and remove the rows after the best epoch from the epochs dataframe.

Modifies the parameters of model and optimizer, and changes epochs_df in place.

Parameters:
  • epochs_df (pandas.DataFrame)
  • model (torch.nn.Module)
  • optimizer (torch.optim.Optimizer)
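A minimal sketch of the remember/reset bookkeeping described above, using only the standard library. Epoch rows are plain dicts instead of a pandas DataFrame, and the model and optimizer states are plain dicts standing in for torch state dicts. RememberBestSketch is a hypothetical name, not part of braindecode, and the tie-breaking rule (a later epoch with an equal value wins) is an assumption.

```python
import copy


class RememberBestSketch:
    """Hypothetical stand-in for RememberBest working on plain Python data."""

    def __init__(self, column_name):
        self.column_name = column_name
        self.best_epoch = 0
        self.lowest_val = float("inf")
        self.model_state = None
        self.optimizer_state = None

    def remember_epoch(self, epoch_rows, model_state, optimizer_state):
        i_epoch = len(epoch_rows) - 1
        current_val = epoch_rows[-1][self.column_name]
        # Ties go to the later epoch because of "<=" (an assumption).
        if current_val <= self.lowest_val:
            self.best_epoch = i_epoch
            self.lowest_val = current_val
            # Deep copies so later training steps cannot mutate the snapshot.
            self.model_state = copy.deepcopy(model_state)
            self.optimizer_state = copy.deepcopy(optimizer_state)

    def reset_to_best_model(self, epoch_rows):
        # Drop the rows recorded after the best epoch, in place.
        del epoch_rows[self.best_epoch + 1:]
        return copy.deepcopy(self.model_state), copy.deepcopy(self.optimizer_state)


rows = []
rb = RememberBestSketch("valid_misclass")
for misclass, weight in [(0.5, 1.0), (0.3, 2.0), (0.4, 3.0)]:
    rows.append({"valid_misclass": misclass})
    rb.remember_epoch(rows, {"w": weight}, {"lr": 0.1})
model_state, opt_state = rb.reset_to_best_model(rows)
```

After the loop the best epoch is the second one (misclass 0.3), so the rows are truncated to two and the restored state holds the weight from that epoch.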
class braindecode.experiments.experiment.Experiment(model, train_set, valid_set, test_set, iterator, loss_function, optimizer, model_constraint, monitors, stop_criterion, remember_best_column, run_after_early_stop, model_loss_function=None, batch_modifier=None, cuda=True, pin_memory=False, do_early_stop=True, reset_after_second_run=False, log_0_epoch=True, loggers=('print', ))

Bases: object

Class that performs one experiment on the training, validation and test sets.

It trains as follows:

  1. Train on training set until a given stop criterion is fulfilled
  2. Reset to the best epoch, i.e. reset parameters of the model and the optimizer to the state at the best epoch (“best” according to a given criterion)
  3. Continue training on the combined training + validation set until the loss on the validation set is as low as the training-set loss at the best epoch (or until the model has been trained for twice as many epochs as the best epoch, to prevent infinite training).
Parameters:
  • model (torch.nn.Module)
  • train_set (SignalAndTarget)
  • valid_set (SignalAndTarget)
  • test_set (SignalAndTarget)
  • iterator (iterator object)
  • loss_function (function) – Function mapping predictions and targets to a loss: (predictions: torch.autograd.Variable, targets: torch.autograd.Variable) -> loss: torch.autograd.Variable
  • optimizer (torch.optim.Optimizer)
  • model_constraint (object) – Object with an apply method that takes the model and constrains its parameters. None for no constraint.
  • monitors (list of objects) – List of objects with monitor_epoch and monitor_set methods that monitor the training progress.
  • stop_criterion (object) – Object with a should_stop method that takes the monitoring dataframe and returns whether training should stop.
  • remember_best_column (str) – Name of the column used to decide which parameters to store as the best model. The lowest value in this column should indicate the best performance.
  • run_after_early_stop (bool) – Whether to continue training on the combined training + validation set after the early stop.
  • model_loss_function (function, optional) – Function (model -> loss) to add a model loss like L2 regularization. Note that this loss is currently not accounted for in monitoring.
  • batch_modifier (object, optional) – Object with a modify method that can change the batch, e.g. for data augmentation.
  • cuda (bool, optional) – Whether to use CUDA.
  • pin_memory (bool, optional) – Whether to pin the memory of the batch inputs and targets.
  • do_early_stop (bool) – Whether to do an early stop at all. If True, reset to the best model even if the experiment does not run after the early stop.
  • reset_after_second_run (bool) – If True, reset to the best model when the second run did not find a validation loss below or equal to the best training loss of the first run.
  • log_0_epoch (bool) – Whether to compute and log monitor values before the start of training.
  • loggers (list of Logger) – How to show the computed metrics.
Variables:

epochs_df (pandas.DataFrame) – Monitoring values for all epochs.
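The arguments stop_criterion, model_constraint and monitors are duck-typed: only the method names listed above are required. One way to make those expectations explicit is typing.Protocol. The protocol names and the ClipWeights constraint below are hypothetical, and the "model" is a dict of weight lists rather than a torch.nn.Module so that the sketch runs without torch.

```python
from typing import Any, Protocol, runtime_checkable


# Hypothetical protocols describing the duck-typed arguments of Experiment.
@runtime_checkable
class StopCriterion(Protocol):
    def should_stop(self, epochs_df: Any) -> bool: ...


@runtime_checkable
class ModelConstraint(Protocol):
    def apply(self, model: Any) -> None: ...


@runtime_checkable
class Monitor(Protocol):
    def monitor_epoch(self) -> Any: ...

    def monitor_set(self, setname: str, all_preds: Any, all_losses: Any,
                    all_batch_sizes: Any, all_targets: Any, dataset: Any) -> Any: ...


class ClipWeights:
    """Hypothetical constraint: clip every weight to [-max_abs, max_abs]."""

    def __init__(self, max_abs: float):
        self.max_abs = max_abs

    def apply(self, model: dict) -> None:
        # The model is a dict of weight lists standing in for a torch module.
        for name in list(model):
            model[name] = [max(-self.max_abs, min(self.max_abs, w)) for w in model[name]]


model = {"w": [-3.0, 0.5, 2.0]}
constraint = ClipWeights(1.0)
constraint.apply(model)
```

Runtime-checkable protocols only check that the methods exist, which matches how loosely these arguments are specified.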

run()

Run the complete training.

setup_training()

Set up training, i.e. move the model to CUDA (if cuda is True) and initialize monitoring.

run_until_first_stop()

Run training and evaluation, using only the training set for training, until the stop criterion is fulfilled.

run_until_second_stop()

Run training and evaluation using the combined training + validation set for training.

Runs until the loss on the validation set decreases below the training-set loss of the best epoch, or until as many epochs have been trained after the first stop as before it.
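The second stop condition can be sketched as a plain function. This is an illustration under assumptions: valid_losses is a hypothetical sequence of validation losses observed during the second run, loss_to_reach is the training loss at the best epoch of the first run, and "below" is taken as strictly less than.

```python
def n_second_run_epochs(valid_losses, loss_to_reach, best_epoch):
    """Return how many epochs the second run trains before stopping.

    Training resumes from best_epoch, so the run also stops once the total
    number of epochs reaches 2 * best_epoch (as many after the reset as before).
    """
    for n_trained, valid_loss in enumerate(valid_losses, start=1):
        total_epochs = best_epoch + n_trained
        if valid_loss < loss_to_reach or total_epochs >= 2 * best_epoch:
            return n_trained
    # Ran out of observed losses without meeting either condition.
    return len(valid_losses)
```

With a loss target of 0.6 and losses 0.9, 0.7, 0.5, the run stops after three epochs; if the target is never reached, the epoch cap ends it.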

run_until_stop(datasets, remember_best)

Run training and evaluation on the given datasets until the stop criterion is fulfilled.

Parameters:
  • datasets (OrderedDict) – Dictionary mapping the strings 'train', 'valid' and 'test' to SignalAndTarget objects.
  • remember_best (bool) – Whether to remember the parameters at the best epoch.
run_one_epoch(datasets, remember_best)

Run training and evaluation on the given datasets for one epoch.

Parameters:
  • datasets (OrderedDict) – Dictionary mapping the strings 'train', 'valid' and 'test' to SignalAndTarget objects.
  • remember_best (bool) – Whether to remember the parameters if this epoch is the best epoch.
train_batch(inputs, targets)

Train on given inputs and targets.

Parameters:
  • inputs (torch.autograd.Variable)
  • targets (torch.autograd.Variable)
eval_on_batch(inputs, targets)

Evaluate given inputs and targets.

Parameters:
  • inputs (torch.autograd.Variable)
  • targets (torch.autograd.Variable)
Returns:

  • predictions (torch.autograd.Variable)
  • loss (torch.autograd.Variable)

monitor_epoch(datasets)

Evaluate one epoch for given datasets.

Stores the results in epochs_df.

Parameters: datasets (OrderedDict) – Dictionary mapping the strings 'train', 'valid' and 'test' to SignalAndTarget objects.
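A sketch of how per-monitor results could be merged into one monitoring row per epoch. It assumes that monitors return dicts of column values (or None) and that column names combine the set name with a suffix, such as "train_loss". ConstantLossMonitor and build_epoch_row are hypothetical, and the prediction arguments are left empty instead of coming from a model.

```python
from collections import OrderedDict


class ConstantLossMonitor:
    """Hypothetical monitor that reports a fixed loss per dataset."""

    def __init__(self, losses):
        self.losses = losses

    def monitor_epoch(self):
        return None  # nothing to report at the epoch level

    def monitor_set(self, setname, all_preds, all_losses, all_batch_sizes,
                    all_targets, dataset):
        return {"{:s}_loss".format(setname): self.losses[setname]}


def build_epoch_row(monitors, setnames):
    """Merge the dicts returned by all monitors into one monitoring row."""
    row = OrderedDict()
    for monitor in monitors:
        result = monitor.monitor_epoch()
        if result is not None:
            row.update(result)
    for setname in setnames:
        for monitor in monitors:
            # In a real run these arguments would hold the predictions, losses,
            # batch sizes and targets collected on this set.
            result = monitor.monitor_set(setname, [], [], [], [], None)
            if result is not None:
                row.update(result)
    return row


epochs_rows = []  # stands in for epochs_df
monitor = ConstantLossMonitor({"train": 0.5, "valid": 0.7})
epochs_rows.append(build_epoch_row([monitor], ["train", "valid"]))
```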
log_epoch()

Print monitoring values for this epoch.

setup_after_stop_training()

Set up training after the first stop.

Resets the parameters to the best parameters and updates the stop criterion.

braindecode.experiments.loggers module

class braindecode.experiments.loggers.Logger

Bases: abc.ABC

log_epoch(epochs_df)
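A custom logger only needs to implement log_epoch(epochs_df). The sketch assumes that log_epoch is the single abstract method of Logger. It uses a local stand-in base class and a list of dicts in place of the DataFrame, so LoggerSketch and LastRowRecorder are hypothetical.

```python
import abc


class LoggerSketch(abc.ABC):
    """Local stand-in for the Logger base class, so this runs without braindecode."""

    @abc.abstractmethod
    def log_epoch(self, epochs_df):
        """Called once per epoch with all monitoring values so far."""


class LastRowRecorder(LoggerSketch):
    """Hypothetical logger that stores a copy of the newest monitoring row."""

    def __init__(self):
        self.history = []

    def log_epoch(self, epochs_df):
        # epochs_df is a list of dicts in this sketch; with a pandas DataFrame
        # the newest row would be epochs_df.iloc[-1].
        self.history.append(dict(epochs_df[-1]))


recorder = LastRowRecorder()
recorder.log_epoch([{"train_loss": 0.9}])
recorder.log_epoch([{"train_loss": 0.9}, {"train_loss": 0.6}])
```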
class braindecode.experiments.loggers.Printer

Bases: braindecode.experiments.loggers.Logger

Prints output to the terminal using Python’s logging module.

log_epoch(epochs_df)
class braindecode.experiments.loggers.TensorboardWriter(log_dir)

Bases: braindecode.experiments.loggers.Logger

Logs all values for TensorBoard visualization using tensorboardX.

Parameters: log_dir (string) – Directory path to log the output to.
log_epoch(epochs_df)

braindecode.experiments.monitors module

class braindecode.experiments.monitors.MisclassMonitor(col_suffix='misclass', threshold_for_binary_case=None)

Bases: object

Monitor the examplewise misclassification rate.

Parameters:
  • col_suffix (str, optional) – Suffix for the column name in the monitoring output.
  • threshold_for_binary_case (float, optional) – In case of binary classification with only one output prediction per target, the threshold for separating the classes, e.g. 0.5 for sigmoid outputs or np.log(0.5) for log-sigmoid outputs.
monitor_epoch()
monitor_set(setname, all_preds, all_losses, all_batch_sizes, all_targets, dataset)
braindecode.experiments.monitors.compute_pred_labels_from_trial_preds(all_preds, threshold_for_binary_case=None)
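The label extraction behind the misclassification rate can be sketched in plain Python. The sketch assumes that multi-output predictions are turned into labels by argmax and that single-output predictions are compared with threshold_for_binary_case using ">". pred_labels_from_preds and misclass_rate are hypothetical names, and lists stand in for NumPy arrays.

```python
import math


def pred_labels_from_preds(preds, threshold_for_binary_case=None):
    """Turn per-example predictions into class labels.

    Each prediction is either a list of per-class scores (argmax is taken) or a
    single score for binary problems (compared against the threshold).
    """
    labels = []
    for pred in preds:
        if isinstance(pred, (list, tuple)) and len(pred) > 1:
            # Index of the highest score; ties resolve to the first class.
            labels.append(max(range(len(pred)), key=pred.__getitem__))
        else:
            if threshold_for_binary_case is None:
                raise ValueError("single-output predictions need threshold_for_binary_case")
            score = pred[0] if isinstance(pred, (list, tuple)) else pred
            labels.append(int(score > threshold_for_binary_case))
    return labels


def misclass_rate(preds, targets, threshold_for_binary_case=None):
    labels = pred_labels_from_preds(preds, threshold_for_binary_case)
    return sum(label != target for label, target in zip(labels, targets)) / len(targets)


# Binary case with log-sigmoid outputs: threshold at log(0.5).
log_sigmoid_preds = [math.log(0.9), math.log(0.2), math.log(0.6)]
```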
class braindecode.experiments.monitors.AveragePerClassMisclassMonitor(col_suffix='misclass')

Bases: object

Compute the average of the per-class misclassification rates; useful if classes are highly imbalanced.

Parameters: col_suffix (str) – Suffix for the column name in the monitoring output.
monitor_epoch()
monitor_set(setname, all_preds, all_losses, all_batch_sizes, all_targets, dataset)
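The difference from a plain misclassification rate shows up on imbalanced data. In the sketch below, the per-class average is the mean of the within-class error rates, which is an assumption about how the averaging is done. The function names are hypothetical, and plain lists replace arrays.

```python
def misclass(pred_labels, targets):
    return sum(p != t for p, t in zip(pred_labels, targets)) / len(targets)


def average_per_class_misclass(pred_labels, targets):
    """Average over classes of the misclassification rate within each class."""
    rates = []
    for cls in sorted(set(targets)):
        idx = [i for i, t in enumerate(targets) if t == cls]
        rates.append(sum(pred_labels[i] != cls for i in idx) / len(idx))
    return sum(rates) / len(rates)


targets = [0] * 9 + [1]
always_zero = [0] * 10  # a classifier that ignores the rare class
```

A classifier that always predicts the majority class gets a plain misclassification rate of 0.1 here, but a per-class average of 0.5, which exposes that it never detects the rare class.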
class braindecode.experiments.monitors.LossMonitor

Bases: object

Monitor the examplewise loss.

monitor_epoch()
monitor_set(setname, all_preds, all_losses, all_batch_sizes, all_targets, dataset)
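Batches can differ in size (for example, the last batch may be smaller), so an examplewise loss is a batch-size-weighted average of the per-batch losses. The sketch assumes that each entry of all_losses is already the mean loss over its batch. examplewise_loss is a hypothetical name.

```python
def examplewise_loss(all_losses, all_batch_sizes):
    """Mean loss per example, given one mean loss per batch.

    Weighting by batch size keeps a small final batch from counting as much
    as a full batch.
    """
    n_examples = sum(all_batch_sizes)
    return sum(loss * size for loss, size in zip(all_losses, all_batch_sizes)) / n_examples
```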
class braindecode.experiments.monitors.CroppedTrialMisclassMonitor(input_time_length=None)

Bases: object

Compute the trialwise misclassification rate from predictions for crops.

Parameters: input_time_length (int) – Temporal length of one input to the model.
monitor_epoch()
monitor_set(setname, all_preds, all_losses, all_batch_sizes, all_targets, dataset)

Assumes one-hot encoding for now.

braindecode.experiments.monitors.compute_trial_labels_from_crop_preds(all_preds, input_time_length, X)

Compute predicted trial labels from arrays of crop predictions.

Parameters:
  • all_preds (list of 2darrays (classes x time)) – All predictions for the crops.
  • input_time_length (int) – Temporal length of one input to the model.
  • X (ndarray) – Input tensor the crops were taken from.
Returns:

pred_labels_per_trial – Predicted label for each trial.

Return type:

1darray
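Once the crop predictions are regrouped per trial, a trial label can be obtained by averaging over time and taking the highest-scoring class. The sketch below assumes exactly that reduction (mean over time, then argmax). The function name is hypothetical, and nested lists stand in for NumPy arrays.

```python
def trial_labels_from_preds_per_trial(preds_per_trial):
    """preds_per_trial: one (classes x time) nested list per trial."""
    labels = []
    for trial_preds in preds_per_trial:
        # Average each class's predictions over time ...
        mean_per_class = [sum(row) / len(row) for row in trial_preds]
        # ... then pick the class with the highest average.
        labels.append(max(range(len(mean_per_class)), key=mean_per_class.__getitem__))
    return labels
```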

braindecode.experiments.monitors.compute_preds_per_trial_from_crops(all_preds, input_time_length, X)

Compute predictions per trial from predictions for crops.

Parameters:
  • all_preds (list of 2darrays (classes x time)) – All predictions for the crops.
  • input_time_length (int) – Temporal length of one input to the model.
  • X (ndarray) – Input tensor the crops were taken from.
Returns:

preds_per_trial – Predictions for each trial, without overlapping predictions.

Return type:

list of 2darrays (classes x time)
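For a model that predicts once per sample within its window, the link between input_time_length, the crops and the number of non-overlapping predictions per trial is simple arithmetic. The sketch below assumes such a model and uses hypothetical numbers. The function name is not part of braindecode.

```python
def n_preds_per_trial(trial_lengths, input_time_length, n_preds_per_input):
    """Number of predictions a dense-prediction model yields for each trial.

    A window of input_time_length samples producing n_preds_per_input outputs
    implies a receptive field of input_time_length - n_preds_per_input + 1
    samples; a trial of T samples then yields T - receptive_field + 1 outputs.
    """
    receptive_field = input_time_length - n_preds_per_input + 1
    return [n_samples - receptive_field + 1 for n_samples in trial_lengths]
```

As a sanity check, a trial exactly one window long yields n_preds_per_input predictions.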

braindecode.experiments.monitors.compute_preds_per_trial_from_n_preds_per_trial(all_preds, n_preds_per_trial)

Compute predictions per trial from predictions for crops, given the number of predictions for each trial.

Parameters:
  • all_preds (list of 2darrays (classes x time)) – All predictions for the crops.
  • n_preds_per_trial (list of int) – Number of predictions for each trial.
Returns:

preds_per_trial – Predictions for each trial, without overlapping predictions.

Return type:

list of 2darrays (classes x time)
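A sketch of regrouping crop predictions into per-trial predictions without duplicates. It assumes the predictions arrive as consecutive blocks (one per forward pass, each classes x time) ordered by trial. It also assumes the last block of a trial may overlap the previous one, so only the final samples of that block are kept. The function name is hypothetical.

```python
def preds_per_trial_from_blocks(pred_blocks, n_preds_per_trial):
    """pred_blocks: list of (classes x time) nested lists, one per forward pass."""
    preds_per_trial = []
    i_block = 0
    for n_needed in n_preds_per_trial:
        n_classes = len(pred_blocks[i_block])
        trial_preds = [[] for _ in range(n_classes)]
        while n_needed > 0:
            block = pred_blocks[i_block]
            # Take at most the last n_needed time steps of this block; a final
            # block that overlaps its predecessor contributes only new samples.
            n_taken = min(n_needed, len(block[0]))
            for i_class in range(n_classes):
                trial_preds[i_class].extend(block[i_class][-n_taken:])
            n_needed -= n_taken
            i_block += 1
        preds_per_trial.append(trial_preds)
    return preds_per_trial
```

For a 5-sample trial covered by two 3-sample blocks, the first block contributes all 3 samples and the second only its last 2, so the shared middle sample is not counted twice.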

class braindecode.experiments.monitors.RuntimeMonitor

Bases: object

Monitor the runtime of each epoch.

The first epoch will have runtime 0.

monitor_epoch()
monitor_set(setname, all_preds, all_losses, all_batch_sizes, all_targets, dataset)
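A sketch of the "first epoch has runtime 0" behaviour: the monitor remembers when it was last called and reports the time elapsed since then. The injectable clock and the "runtime" column name are assumptions made so that the sketch is testable. RuntimeMonitorSketch is hypothetical.

```python
import time


class RuntimeMonitorSketch:
    def __init__(self, clock=time.time):
        self.clock = clock
        self.last_call_time = None

    def monitor_epoch(self):
        now = self.clock()
        # No previous call on the first epoch, so its runtime is reported as 0.
        runtime = 0.0 if self.last_call_time is None else now - self.last_call_time
        self.last_call_time = now
        return {"runtime": runtime}


# A fake clock makes the output deterministic.
fake_times = iter([10.0, 12.5, 20.0])
monitor = RuntimeMonitorSketch(clock=lambda: next(fake_times))
runtimes = [monitor.monitor_epoch()["runtime"] for _ in range(3)]
```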

braindecode.experiments.stopcriteria module

class braindecode.experiments.stopcriteria.MaxEpochs(max_epochs)

Bases: object

Stop when the given number of epochs is reached.

Parameters: max_epochs (int)
should_stop(epochs_df)
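A sketch of the check, assuming the monitoring rows include an extra "epoch 0" row logged before training (log_0_epoch=True) that should not count towards max_epochs. MaxEpochsSketch is hypothetical and accepts any sized sequence of rows.

```python
class MaxEpochsSketch:
    def __init__(self, max_epochs):
        self.max_epochs = max_epochs

    def should_stop(self, epochs_df):
        # Assumes row 0 holds the evaluation before training (log_0_epoch=True),
        # so it is not counted as a trained epoch.
        n_trained_epochs = len(epochs_df) - 1
        return n_trained_epochs >= self.max_epochs
```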
class braindecode.experiments.stopcriteria.Or(stop_criteria)

Bases: object

Stop when one of the given stop criteria is triggered.

Parameters: stop_criteria (iterable of stop criteria objects)
should_stop(epochs_df)
was_triggered(criterion)

Return whether the given criterion was triggered in the last call to should_stop.

Parameters: criterion (stop criterion)
Returns: triggered
Return type: bool
class braindecode.experiments.stopcriteria.And(stop_criteria)

Bases: object

Stop when all of the given stop criteria are triggered.

Parameters: stop_criteria (iterable of stop criteria objects)
should_stop(epochs_df)
was_triggered(criterion)

Return whether the given criterion was triggered in the last call to should_stop.

Parameters: criterion (stop criterion)
Returns: triggered
Return type: bool
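The Or and And combinators can be sketched as below. As was_triggered implies, the sketch assumes that every criterion is evaluated on each call rather than short-circuiting. OrSketch, AndSketch and FixedCriterion are hypothetical names.

```python
class OrSketch:
    """Stop when any criterion fires; remembers which ones fired."""

    combine = staticmethod(any)

    def __init__(self, stop_criteria):
        self.stop_criteria = list(stop_criteria)
        self.triggered = {c: False for c in self.stop_criteria}

    def should_stop(self, epochs_df):
        # Evaluate every criterion (no short-circuiting) so that
        # was_triggered() is up to date for all of them.
        for criterion in self.stop_criteria:
            self.triggered[criterion] = criterion.should_stop(epochs_df)
        return self.combine(self.triggered.values())

    def was_triggered(self, criterion):
        return self.triggered[criterion]


class AndSketch(OrSketch):
    """Stop only when every criterion fires."""

    combine = staticmethod(all)


class FixedCriterion:
    """Hypothetical criterion that always returns the same decision."""

    def __init__(self, decision):
        self.decision = decision

    def should_stop(self, epochs_df):
        return self.decision


fires, silent = FixedCriterion(True), FixedCriterion(False)
either = OrSketch([fires, silent])
both = AndSketch([fires, silent])
```

Avoiding short-circuiting also matters for stateful criteria such as NoDecrease, which need to see every epoch to track their best value.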
class braindecode.experiments.stopcriteria.NoDecrease(column_name, num_epochs, min_decrease=1e-06)

Bases: object

Stops if there is no decrease on a given monitor channel for a given number of epochs.

Parameters:
  • column_name (str) – Name of the column to monitor for decrease.
  • num_epochs (int) – Number of epochs to wait before stopping when there is no decrease.
  • min_decrease (float, optional) – Minimum relative decrease that counts as a decrease. E.g. 0.1 means only decreases of at least 10% count as a decrease and reset the counter.
should_stop(epochs_df)
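A sketch of the relative-decrease rule. It assumes the criterion is called once per epoch and keeps the lowest value seen so far together with the epoch where it occurred. It also assumes the monitored values are positive, since a relative threshold makes little sense otherwise. NoDecreaseSketch is hypothetical and works on a list of dict rows.

```python
class NoDecreaseSketch:
    def __init__(self, column_name, num_epochs, min_decrease=1e-6):
        self.column_name = column_name
        self.num_epochs = num_epochs
        self.min_decrease = min_decrease
        self.best_epoch = 0
        self.lowest_val = float("inf")

    def should_stop(self, epoch_rows):
        i_epoch = len(epoch_rows) - 1
        current_val = epoch_rows[-1][self.column_name]
        # A new value only counts as a decrease if it is at least
        # min_decrease (relative) below the lowest value so far.
        if current_val < (1 - self.min_decrease) * self.lowest_val:
            self.best_epoch = i_epoch
            self.lowest_val = current_val
        return (i_epoch - self.best_epoch) >= self.num_epochs


criterion = NoDecreaseSketch("valid_misclass", num_epochs=2, min_decrease=0.1)
rows, decisions = [], []
for value in [1.0, 0.95, 0.94, 0.93]:
    rows.append({"valid_misclass": value})
    decisions.append(criterion.should_stop(rows))
```

With min_decrease=0.1, none of the later values is at least 10% below 1.0, so the counter is never reset and the criterion fires two epochs after the first one.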
class braindecode.experiments.stopcriteria.ColumnBelow(column_name, target_value)

Bases: object

Stops if the given column is below the given value.

Parameters:
  • column_name (str) – Name of the column to monitor.
  • target_value (float) – When the column decreases below this value, the criterion will say to stop.
should_stop(epochs_df)
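A sketch, assuming that only the most recent row is compared with target_value. ColumnBelowSketch is hypothetical and works on a list of dict rows.

```python
class ColumnBelowSketch:
    def __init__(self, column_name, target_value):
        self.column_name = column_name
        self.target_value = target_value

    def should_stop(self, epoch_rows):
        # Only the newest row is checked; earlier dips below the target
        # do not trigger the criterion.
        return epoch_rows[-1][self.column_name] < self.target_value
```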