braindecode.experiments package
Convenience classes for experiments, including monitoring and stop criteria.
Submodules
braindecode.experiments.experiment module
class braindecode.experiments.experiment.RememberBest(column_name)
Bases: object
Class to remember and restore the parameters of the model and the optimizer at the epoch with the best performance.
Parameters: column_name (str) – Column whose lowest value indicates the epoch with the best performance (e.g. misclass).
Variables: best_epoch (int) – Index of the best epoch.
remember_epoch(epochs_df, model, optimizer)
Remember this epoch: store the parameter values in case this epoch has the best performance so far.
Parameters:
- epochs_df (pandas.DataFrame) – DataFrame containing the column column_name, which is used to evaluate performance.
- model (torch.nn.Module)
- optimizer (torch.optim.Optimizer)
reset_to_best_model(epochs_df, model, optimizer)
Reset the parameters to those at the best epoch and remove the rows after the best epoch from the epochs DataFrame.
Modifies the parameters of model and optimizer and changes epochs_df, all in place.
Parameters:
- epochs_df (pandas.DataFrame)
- model (torch.nn.Module)
- optimizer (torch.optim.Optimizer)
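The remember/restore cycle of RememberBest can be sketched in plain Python. This is a hypothetical stand-in (RememberBestSketch is not a braindecode class): the epochs DataFrame is replaced by a list of dicts, and the model and optimizer by plain dicts playing the role of their state dicts. The tie rule (a later epoch wins a tie) is an assumption.

```python
import copy


class RememberBestSketch:
    """Hypothetical stand-in for RememberBest (not the braindecode class).

    epochs_log: list of dicts, one row per epoch (stand-in for epochs_df).
    model_state / optimizer_state: plain dicts standing in for state dicts.
    """

    def __init__(self, column_name):
        self.column_name = column_name
        self.best_epoch = 0
        self.lowest_val = float("inf")
        self.model_state = None
        self.optimizer_state = None

    def remember_epoch(self, epochs_log, model_state, optimizer_state):
        i_epoch = len(epochs_log) - 1
        current_val = epochs_log[-1][self.column_name]
        # Lower is better; "<=" lets a later epoch win a tie (assumption)
        if current_val <= self.lowest_val:
            self.best_epoch = i_epoch
            self.lowest_val = current_val
            # Deep copies, so further training cannot alter the snapshot
            self.model_state = copy.deepcopy(model_state)
            self.optimizer_state = copy.deepcopy(optimizer_state)

    def reset_to_best_model(self, epochs_log, model_state, optimizer_state):
        # Remove the rows after the best epoch, in place
        del epochs_log[self.best_epoch + 1:]
        # Restore the remembered parameters, in place
        model_state.clear()
        model_state.update(copy.deepcopy(self.model_state))
        optimizer_state.clear()
        optimizer_state.update(copy.deepcopy(self.optimizer_state))


# Example: three epochs where epoch 1 has the lowest validation misclass
model, opt, rows = {"w": [0.0]}, {"lr": 0.1}, []
rb = RememberBestSketch("valid_misclass")
for i, misclass in enumerate([0.5, 0.3, 0.4]):
    model["w"][0] = float(i)  # pretend training changed the weights in place
    rows.append({"valid_misclass": misclass})
    rb.remember_epoch(rows, model, opt)
rb.reset_to_best_model(rows, model, opt)
print(rb.best_epoch, model, len(rows))  # 1 {'w': [1.0]} 2
```

The deep copies matter because the example mutates the weights in place; a plain reference would silently track the latest epoch instead of the best one.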
class braindecode.experiments.experiment.Experiment(model, train_set, valid_set, test_set, iterator, loss_function, optimizer, model_constraint, monitors, stop_criterion, remember_best_column, run_after_early_stop, model_loss_function=None, batch_modifier=None, cuda=True, pin_memory=False, do_early_stop=True, reset_after_second_run=False, log_0_epoch=True, loggers=('print',))
Bases: object
Class that performs one experiment on a training, validation and test set.
It trains as follows:
- Train on the training set until a given stop criterion is fulfilled.
- Reset to the best epoch, i.e. reset the parameters of the model and the optimizer to their state at the best epoch (“best” according to a given criterion).
- Continue training on the combined training + validation set until the loss on the validation set is as low as the training-set loss at the best epoch (or until the ConvNet has been trained for twice as many epochs as the best epoch, to prevent infinite training).
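The three-step schedule above can be sketched as a small driver function. Everything here is hypothetical and plain Python: two_phase_schedule, the callbacks, and the toy problem are illustrations, not braindecode code. The sketch assumes row 0 of the log is evaluated before training (as with log_0_epoch=True) and does not model monitors, loggers, or reset_after_second_run.

```python
def two_phase_schedule(state, train_one_epoch, evaluate, first_stop):
    """Hypothetical sketch of the schedule above (not the braindecode code).

    state: dict standing in for the model and optimizer parameters.
    train_one_epoch(state, use_valid): trains one epoch, modifying state.
    evaluate(state): returns a dict of monitor values for the current state.
    first_stop(log): stop criterion for the first phase.
    """
    # Phase 1: train on the training set only; row 0 is logged before training
    log = [evaluate(state)]
    best_i, best_state = 0, dict(state)
    while not first_stop(log):
        train_one_epoch(state, use_valid=False)
        log.append(evaluate(state))
        if log[-1]["valid_misclass"] <= log[best_i]["valid_misclass"]:
            # Shallow copy suffices for this toy state of plain numbers
            best_i, best_state = len(log) - 1, dict(state)

    # Reset to the best epoch and drop the rows after it
    state.clear()
    state.update(best_state)
    del log[best_i + 1:]
    loss_to_reach = log[-1]["train_loss"]

    # Phase 2: train on train + valid until the validation loss is as low as
    # the best epoch's training loss, or until 2 * best_i epochs in total
    while not (log[-1]["valid_loss"] <= loss_to_reach
               or len(log) - 1 >= 2 * best_i):
        train_one_epoch(state, use_valid=True)
        log.append(evaluate(state))
    return best_i, log


# Toy problem: "training" advances a counter, losses are fixed functions of it
def toy_train_one_epoch(state, use_valid):
    state["n"] += 1  # use_valid is ignored in this toy


def toy_evaluate(state):
    n = state["n"]
    return {"train_loss": 1.0 / (n + 1),
            "valid_loss": 1.4 / (n + 1),
            "valid_misclass": abs(n - 3) / 10}


state = {"n": 0}
best_i, log = two_phase_schedule(
    state, toy_train_one_epoch, toy_evaluate,
    first_stop=lambda rows: len(rows) - 1 >= 6)
print(best_i, state["n"], len(log))  # 3 5 6
```

In the toy run, phase 1 trains six epochs and finds epoch 3 best; phase 2 resumes from epoch 3 and stops after two more epochs, once the validation loss (about 0.233) drops to the best epoch's training loss of 0.25.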
Parameters:
- model (torch.nn.Module)
- train_set (SignalAndTarget)
- valid_set (SignalAndTarget)
- test_set (SignalAndTarget)
- iterator (iterator object)
- loss_function (function) – Function mapping predictions and targets to a loss: (predictions: torch.autograd.Variable, targets: torch.autograd.Variable) -> loss: torch.autograd.Variable
- optimizer (torch.optim.Optimizer)
- model_constraint (object) – Object with an apply method that takes the model and constrains its parameters. None for no constraint.
- monitors (list of objects) – List of objects with monitor_epoch and monitor_set methods that monitor the training progress.
- stop_criterion (object) – Object with a should_stop method that takes the monitoring DataFrame and returns whether training should stop.
- remember_best_column (str) – Name of the column used to select the best model, whose parameters are stored. The lowest value in this column should indicate the best performance.
- run_after_early_stop (bool) – Whether to continue running after the early stop.
- model_loss_function (function, optional) – Function (model -> loss) that adds a model loss such as L2 regularization. Note that this loss is currently not included in the monitored values.
- batch_modifier (object, optional) – Object with a modify method that can change the batch, e.g. for data augmentation.
- cuda (bool, optional) – Whether to use CUDA.
- pin_memory (bool, optional) – Whether to pin the memory of the batch inputs and targets.
- do_early_stop (bool) – Whether to do an early stop at all. If True, reset to the best model even if the experiment does not run after the early stop.
- reset_after_second_run (bool) – If True, reset to the best model when the second run did not reach a validation loss below or equal to the best training loss of the first run.
- log_0_epoch (bool) – Whether to compute monitor values and log them before the start of training.
- loggers (list of Logger) – How to show the computed metrics.
Variables: epochs_df (pandas.DataFrame) – Monitoring values for all epochs.
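The loss_function and model_loss_function contracts can be illustrated with plain-Python stand-ins. The names nll_loss, l2_model_loss and training_loss are hypothetical, lists replace tensors, and a model is assumed to be a dict of parameter lists; the sketch only shows how a data loss (predictions, targets -> loss) and a model loss (model -> loss) combine into the loss used for the update.

```python
import math


def nll_loss(log_probs, targets):
    """Hypothetical loss_function: mean negative log-likelihood.

    log_probs: per-example lists of class log-probabilities (predictions).
    targets: per-example class indices.
    """
    return -sum(row[t] for row, t in zip(log_probs, targets)) / len(targets)


def l2_model_loss(model, weight=1e-3):
    """Hypothetical model_loss_function: weighted sum of squared parameters.

    model: dict mapping parameter names to lists of floats.
    """
    return weight * sum(p * p for params in model.values() for p in params)


def training_loss(model, log_probs, targets):
    # Loss used for the parameter update: data loss plus model loss.
    # The docs note that monitoring does not include the model loss.
    return nll_loss(log_probs, targets) + l2_model_loss(model)


log_probs = [[math.log(0.5), math.log(0.5)], [math.log(0.9), math.log(0.1)]]
targets = [0, 0]
model = {"w": [3.0, 4.0]}
print(round(training_loss(model, log_probs, targets), 4))  # 0.4243
```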
run_until_first_stop()
Run training and evaluation, using only the training set for training, until the stop criterion is fulfilled.
run_until_second_stop()
Run training and evaluation using the combined training + validation set for training.
Runs until the loss on the validation set decreases below the training-set loss of the best epoch, or until as many epochs have been trained after the first stop as before it.
run_until_stop(datasets, remember_best)
Run training and evaluation on the given datasets until the stop criterion is fulfilled.
Parameters:
- datasets (OrderedDict) – Dictionary mapping the strings train, valid and test to SignalAndTarget objects.
- remember_best (bool) – Whether to remember the parameters at the best epoch.
run_one_epoch(datasets, remember_best)
Run training and evaluation on the given datasets for one epoch.
Parameters:
- datasets (OrderedDict) – Dictionary mapping the strings train, valid and test to SignalAndTarget objects.
- remember_best (bool) – Whether to remember the parameters if this epoch is the best epoch.
train_batch(inputs, targets)
Train on the given inputs and targets.
Parameters:
- inputs (torch.autograd.Variable)
- targets (torch.autograd.Variable)
eval_on_batch(inputs, targets)
Evaluate the given inputs and targets.
Parameters:
- inputs (torch.autograd.Variable)
- targets (torch.autograd.Variable)
Returns:
- predictions (torch.autograd.Variable)
- loss (torch.autograd.Variable)
monitor_epoch(datasets)
Evaluate one epoch for the given datasets.
Stores the results in epochs_df.
Parameters: datasets (OrderedDict) – Dictionary mapping the strings train, valid and test to SignalAndTarget objects.
braindecode.experiments.loggers module
class braindecode.experiments.loggers.Printer
Bases: braindecode.experiments.loggers.Logger
Prints output to the terminal using Python’s logging module.
class braindecode.experiments.loggers.TensorboardWriter(log_dir)
Bases: braindecode.experiments.loggers.Logger
Logs all values for TensorBoard visualization using tensorboardX.
Parameters: log_dir (string) – Directory path to write the logs to.
braindecode.experiments.monitors module
class braindecode.experiments.monitors.MisclassMonitor(col_suffix='misclass', threshold_for_binary_case=None)
Bases: object
Monitor the examplewise misclassification rate.
Parameters:
- col_suffix (str, optional) – Name of the column in the monitoring output.
- threshold_for_binary_case (float, optional) – In case of binary classification with only one output prediction per target, the threshold for separating the classes, e.g. 0.5 for sigmoid outputs or np.log(0.5) for log-sigmoid outputs.
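The role of threshold_for_binary_case can be shown with a hypothetical plain-Python helper (pred_labels and misclass_rate are illustrative names, not braindecode functions). With one output per example, the output is thresholded; otherwise the class with the highest score is taken. Assigning a value exactly at the threshold to class 1 is an assumption.

```python
import math


def pred_labels(outputs, threshold_for_binary_case=None):
    """Hypothetical helper (not the braindecode implementation).

    outputs: per-example predictions, either a list of class scores or,
    in the binary case with one output per target, a single float.
    """
    labels = []
    for out in outputs:
        if threshold_for_binary_case is not None:
            # Single output: class 1 if at or above the threshold (tie rule assumed)
            labels.append(int(out >= threshold_for_binary_case))
        else:
            # One score per class: pick the highest-scoring class
            labels.append(max(range(len(out)), key=out.__getitem__))
    return labels


def misclass_rate(outputs, targets, threshold_for_binary_case=None):
    labels = pred_labels(outputs, threshold_for_binary_case)
    return sum(pred != t for pred, t in zip(labels, targets)) / len(targets)


# Sigmoid outputs are thresholded at 0.5 ...
sig = [0.2, 0.7, 0.6, 0.4]
print(misclass_rate(sig, [0, 1, 0, 0], 0.5))  # 0.25
# ... log-sigmoid outputs at log(0.5), which gives the same labels
print(pred_labels([math.log(p) for p in sig], math.log(0.5)))  # [0, 1, 1, 0]
```

Because the logarithm is monotonic, thresholding log-sigmoid outputs at np.log(0.5) selects the same class as thresholding sigmoid outputs at 0.5.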
braindecode.experiments.monitors.compute_pred_labels_from_trial_preds(all_preds, threshold_for_binary_case=None)
class braindecode.experiments.monitors.AveragePerClassMisclassMonitor(col_suffix='misclass')
Bases: object
Compute the average of the per-class misclassification rates; useful if the classes are highly imbalanced.
Parameters: col_suffix (str, optional) – Name of the column in the monitoring output.
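Averaging per class matters when classes are imbalanced, as this hypothetical plain-Python sketch shows (average_per_class_misclass is an illustrative name; it assumes the misclassification rate is computed within each true class and then averaged with equal weight per class).

```python
def average_per_class_misclass(pred_labels, targets):
    """Hypothetical sketch (not the braindecode implementation): misclassification
    rate within each true class, then the unweighted mean over classes."""
    rates = []
    for cls in sorted(set(targets)):
        idx = [i for i, t in enumerate(targets) if t == cls]
        n_wrong = sum(pred_labels[i] != cls for i in idx)
        rates.append(n_wrong / len(idx))
    return sum(rates) / len(rates)


# 8 examples of class 0, 2 of class 1; a classifier that always predicts 0
targets = [0] * 8 + [1] * 2
preds = [0] * 10
overall = sum(p != t for p, t in zip(preds, targets)) / len(targets)
print(overall)                                     # 0.2
print(average_per_class_misclass(preds, targets))  # 0.5
```

The overall rate of 0.2 hides that every class-1 example is wrong; the per-class average of 0.5 exposes it.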
class braindecode.experiments.monitors.LossMonitor
Bases: object
Monitor the examplewise loss.
class braindecode.experiments.monitors.CroppedTrialMisclassMonitor(input_time_length=None)
Bases: object
Compute trialwise misclassification rates from the predictions for crops.
Parameters: input_time_length (int) – Temporal length of one input to the model.
braindecode.experiments.monitors.compute_trial_labels_from_crop_preds(all_preds, input_time_length, X)
Compute predicted trial labels from arrays of crop predictions.
Parameters:
- all_preds (list of 2darrays (classes x time)) – All predictions for the crops.
- input_time_length (int) – Temporal length of one input to the model.
- X (ndarray) – Input tensor the crops were taken from.
Returns: pred_labels_per_trial – Predicted label for each trial.
Return type: 1darray
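One plausible way to turn per-trial crop predictions into a single label per trial is to average each class over time and take the class with the highest mean. The sketch below assumes exactly that (the averaging rule is an assumption, not confirmed by the docs above) and uses a hypothetical helper name with nested lists instead of arrays.

```python
def trial_labels_from_preds_per_trial(preds_per_trial):
    """Hypothetical helper (not a braindecode function).

    preds_per_trial: one classes x time nested list per trial, e.g. the
    non-overlapping crop predictions of each trial.
    Returns one predicted class index per trial.
    """
    labels = []
    for trial_preds in preds_per_trial:
        # Mean prediction of each class over all time steps of the trial
        class_means = [sum(row) / len(row) for row in trial_preds]
        # Predicted label: class with the highest mean prediction
        labels.append(max(range(len(class_means)), key=class_means.__getitem__))
    return labels


preds_per_trial = [
    [[0.9, 0.2, 0.8], [0.1, 0.8, 0.2]],  # trial 0: class 0 wins on average
    [[0.1, 0.3], [0.9, 0.7]],            # trial 1: class 1 wins on average
]
print(trial_labels_from_preds_per_trial(preds_per_trial))  # [0, 1]
```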
braindecode.experiments.monitors.compute_preds_per_trial_from_crops(all_preds, input_time_length, X)
Compute predictions per trial from predictions for crops.
Parameters:
- all_preds (list of 2darrays (classes x time)) – All predictions for the crops.
- input_time_length (int) – Temporal length of one input to the model.
- X (ndarray) – Input tensor the crops were taken from.
Returns: preds_per_trial – Predictions for each trial, without overlapping predictions.
Return type: list of 2darrays (classes x time)
braindecode.experiments.monitors.compute_preds_per_trial_from_n_preds_per_trial(all_preds, n_preds_per_trial)
Compute predictions per trial from predictions for crops, given the number of predictions for each trial.
Parameters:
- all_preds (list of 2darrays (classes x time)) – All predictions for the crops.
- n_preds_per_trial (list of int) – Number of predictions for each trial.
Returns: preds_per_trial – Predictions for each trial, without overlapping predictions.
Return type: list of 2darrays (classes x time)
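The "without overlapping predictions" part can be illustrated with a hypothetical plain-Python sketch (preds_per_trial_from_blocks is not a braindecode function). It assumes each crop yields one classes x time block, blocks arrive in trial order, and only the last crop of a trial may overlap the previous one, so taking the last n_needed time steps of each block discards the redundant ones.

```python
def preds_per_trial_from_blocks(pred_blocks, n_preds_per_trial):
    """Hypothetical sketch (not the braindecode implementation).

    pred_blocks: one classes x time nested list per crop, in trial order.
    n_preds_per_trial: number of predictions (time steps) for each trial.
    """
    preds_per_trial = []
    i_block = 0
    for n_needed in n_preds_per_trial:
        n_classes = len(pred_blocks[i_block])
        trial = [[] for _ in range(n_classes)]
        while n_needed > 0:
            block = pred_blocks[i_block]
            # Keep only the last n_needed time steps: the final crop of a trial
            # may overlap the previous crop, so its first steps are redundant
            taken = [row[-n_needed:] for row in block]
            for c in range(n_classes):
                trial[c].extend(taken[c])
            n_needed -= len(taken[0])
            i_block += 1
        preds_per_trial.append(trial)
    return preds_per_trial


blocks = [
    [[0.1, 0.2, 0.3], [0.9, 0.8, 0.7]],  # trial 0, first crop
    [[0.3, 0.4, 0.5], [0.7, 0.6, 0.5]],  # trial 0, last crop (1 step overlap)
    [[0.6, 0.6, 0.6], [0.4, 0.4, 0.4]],  # trial 1, single crop
]
result = preds_per_trial_from_blocks(blocks, [5, 3])
print(result[0])  # [[0.1, 0.2, 0.3, 0.4, 0.5], [0.9, 0.8, 0.7, 0.6, 0.5]]
```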
braindecode.experiments.stopcriteria module
class braindecode.experiments.stopcriteria.MaxEpochs(max_epochs)
Bases: object
Stop when the given number of epochs is reached.
Parameters: max_epochs (int)
class braindecode.experiments.stopcriteria.Or(stop_criteria)
Bases: object
Stop when one of the given stop criteria is triggered.
Parameters: stop_criteria (iterable of stop criteria objects)
class braindecode.experiments.stopcriteria.And(stop_criteria)
Bases: object
Stop when all of the given stop criteria are triggered.
Parameters: stop_criteria (iterable of stop criteria objects)
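Stop criteria are objects with a should_stop method taking the monitoring log, so they compose naturally. The classes below are hypothetical plain-Python stand-ins (a list of dicts replaces the DataFrame). Two points are assumptions: MaxEpochs-like counting treats row 0 as the evaluation before training (log_0_epoch=True), and the composites query every criterion without short-circuiting so that stateful criteria still see every epoch.

```python
class MaxEpochsSketch:
    def __init__(self, max_epochs):
        self.max_epochs = max_epochs

    def should_stop(self, epochs_log):
        # Row 0 is assumed to be the evaluation before training
        return len(epochs_log) - 1 >= self.max_epochs


class ColumnBelowSketch:
    def __init__(self, column_name, target_value):
        self.column_name = column_name
        self.target_value = target_value

    def should_stop(self, epochs_log):
        return epochs_log[-1][self.column_name] < self.target_value


class _CompositeSketch:
    def __init__(self, stop_criteria):
        self.stop_criteria = list(stop_criteria)

    def _evaluate_all(self, epochs_log):
        # Query every criterion (no short-circuit) so stateful criteria
        # see every epoch
        return [s.should_stop(epochs_log) for s in self.stop_criteria]


class OrSketch(_CompositeSketch):
    def should_stop(self, epochs_log):
        return any(self._evaluate_all(epochs_log))


class AndSketch(_CompositeSketch):
    def should_stop(self, epochs_log):
        return all(self._evaluate_all(epochs_log))


either = OrSketch([MaxEpochsSketch(3), ColumnBelowSketch("valid_loss", 0.2)])
rows = [{"valid_loss": 0.9}, {"valid_loss": 0.5}]
print(either.should_stop(rows))  # False
rows.append({"valid_loss": 0.1})
print(either.should_stop(rows))  # True
```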
class braindecode.experiments.stopcriteria.NoDecrease(column_name, num_epochs, min_decrease=1e-06)
Bases: object
Stops if there is no decrease in a given monitor column for a given number of epochs.
Parameters:
- column_name (str) – Name of the column to monitor for a decrease.
- num_epochs (int) – Number of epochs to wait before stopping when there is no decrease.
- min_decrease (float, optional) – Minimum relative decrease that counts as a decrease. E.g. 0.1 means that only relative decreases of 10% count as a decrease and reset the counter.
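A patience-style criterion like NoDecrease can be sketched as follows. NoDecreaseSketch is hypothetical, and two details are assumptions: the relative decrease is measured against the lowest value seen so far, and the comparison is strict. The criterion is stateful, so it must be called once per epoch on the growing log.

```python
class NoDecreaseSketch:
    def __init__(self, column_name, num_epochs, min_decrease=1e-6):
        self.column_name = column_name
        self.num_epochs = num_epochs
        self.min_decrease = min_decrease
        self.best_epoch = 0
        self.lowest_val = float("inf")

    def should_stop(self, epochs_log):
        i_epoch = len(epochs_log) - 1
        current_val = epochs_log[-1][self.column_name]
        # Count as a decrease only if the value drops below the lowest value
        # so far by more than the relative margin min_decrease
        if current_val < (1 - self.min_decrease) * self.lowest_val:
            self.best_epoch = i_epoch
            self.lowest_val = current_val
        # Stop once num_epochs epochs have passed without such a decrease
        return i_epoch - self.best_epoch >= self.num_epochs


criterion = NoDecreaseSketch("valid_misclass", num_epochs=2, min_decrease=0.1)
rows, results = [], []
for value in [1.0, 0.95, 0.8, 0.79, 0.78]:
    rows.append({"valid_misclass": value})
    results.append(criterion.should_stop(rows))
print(results)  # [False, False, False, False, True]
```

The drop from 1.0 to 0.95 is under 10%, so it does not reset the counter; 0.8 does, and the two small improvements after it are again too small, which triggers the stop at epoch 4.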
class braindecode.experiments.stopcriteria.ColumnBelow(column_name, target_value)
Bases: object
Stops if the given column is below the given value.
Parameters:
- column_name (str) – Name of the column to monitor.
- target_value (float) – When the column decreases below this value, the criterion signals to stop.