bob.med.tb.engine.trainer
Functions

- check_exist_logfile – Check the existence of the logfile (trainlog.csv); if it exists and the number of epochs is still 0, the logfile is replaced.
- check_gpu – Check the device type and the availability of a GPU.
- checkpointer_process – Run the checkpointer, save the final model and keep track of the best model.
- create_logfile_fields – Create the fields that will appear in the logfile.
- run – Fit a CNN model using supervised learning and save it to disk.
- save_model_summary – Save a small summary of the model in a text file.
- static_information_to_csv – Save the static information in a CSV file.
- torch_evaluation – Context manager to turn model evaluation ON/OFF.
- train_epoch – Train the model for a single epoch (through all batches).
- validate_epoch – Process input samples and return the loss (scalar).
- write_log_info – Write log info to trainlog.csv.
- bob.med.tb.engine.trainer.torch_evaluation(model)[source]
  Context manager to turn ON/OFF model evaluation.
  This context manager turns evaluation mode ON on entry and turns it OFF when exiting the with statement block.
  - Parameters
    - model (torch.nn.Module) – Network
  - Yields
    - model (torch.nn.Module) – Network
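A minimal usage sketch (the toy torch.nn.Linear below stands in for a real network):

```python
import torch

from bob.med.tb.engine.trainer import torch_evaluation

model = torch.nn.Linear(10, 2)  # toy stand-in for a real CNN
with torch.no_grad(), torch_evaluation(model) as m:
    predictions = m(torch.randn(4, 10))  # forward pass in evaluation mode
# on exit, the model is back in training mode
```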
- bob.med.tb.engine.trainer.check_gpu(device)[source]
  Check the device type and the availability of a GPU.
  - Parameters
    - device (torch.device) – device to use
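The documentation does not show the implementation; a plausible equivalent of this check, written here as an assumption, would be:

```python
import torch

def check_gpu_sketch(device):
    # Assumed behaviour: fail early if a CUDA device was requested
    # but no GPU is actually available on this machine.
    if device.type == "cuda":
        assert torch.cuda.is_available(), "CUDA device requested, but no GPU is available"
```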
- bob.med.tb.engine.trainer.save_model_summary(output_folder, model)[source]
  Save a small summary of the model in a text file.
  - Parameters
    - output_folder (str) – output path
    - model (torch.nn.Module) – Network (e.g. driu, hed, unet)
  - Returns
    - r (str) – The model summary in text format.
    - n (int) – The number of parameters of the model.
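A sketch of what this could look like; the summary file name below is an assumption, not given by the documentation:

```python
import os

def save_model_summary_sketch(output_folder, model):
    r = str(model)  # torch modules render a readable layer-by-layer summary
    n = sum(p.numel() for p in model.parameters() if p.requires_grad)
    # "model_summary.txt" is an assumed file name:
    with open(os.path.join(output_folder, "model_summary.txt"), "w") as f:
        f.write(r)
    return r, n
```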
- bob.med.tb.engine.trainer.static_information_to_csv(static_logfile_name, device, n)[source]
  Save the static information in a CSV file.
  - Parameters
    - static_logfile_name (str) – The static file name, which is a join between the output folder and "constant.csv"
    - device (torch.device) – device to use
    - n (int) – The number of parameters of the model
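A minimal sketch of such a writer; the column names are assumptions:

```python
import csv

def static_information_to_csv_sketch(static_logfile_name, device, n):
    # Writes one row of run facts that never change during training.
    with open(static_logfile_name, "w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(["device", "number_of_parameters"])  # assumed columns
        writer.writerow([str(device), n])
```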
- bob.med.tb.engine.trainer.check_exist_logfile(logfile_name, arguments)[source]
  Check the existence of the logfile (trainlog.csv). If the logfile exists and the number of epochs is still 0, the logfile will be replaced.
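The described behaviour could be sketched as follows, assuming arguments carries the starting epoch under an "epoch" key:

```python
import os

def check_exist_logfile_sketch(logfile_name, arguments):
    # Assumed: arguments["epoch"] holds the starting epoch of this run.
    if os.path.exists(logfile_name) and arguments.get("epoch", 0) == 0:
        os.unlink(logfile_name)  # fresh run: discard the stale log
```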
- bob.med.tb.engine.trainer.create_logfile_fields(valid_loader, extra_valid_loaders, device)[source]
  Create the fields that will appear in the logfile (trainlog.csv).
  - Parameters
    - valid_loader (torch.utils.data.DataLoader) – To be used to validate the model and enable automatic checkpointing. If set to None, then do not validate it.
    - extra_valid_loaders (list of torch.utils.data.DataLoader) – To be used to validate the model, however they do not affect automatic checkpointing. If set to None or empty, nothing else is logged. Otherwise, an extra column with the loss of every dataset in this list is kept in the final training log.
    - device (torch.device) – device to use
  - Returns
    - logfile_fields – The fields that will appear in trainlog.csv
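A sketch of how such a field list could be assembled; all column names below are illustrative assumptions:

```python
def create_logfile_fields_sketch(valid_loader, extra_valid_loaders, device):
    fields = ["epoch", "total_time", "eta", "loss", "learning_rate"]
    if valid_loader is not None:
        fields.append("validation_loss")  # only logged when validating
    if extra_valid_loaders:
        # one extra loss column per extra validation dataset
        fields += [
            f"extra_validation_loss_{i}" for i in range(len(extra_valid_loaders))
        ]
    if device.type == "cuda":
        fields.append("gpu_percent")  # resource columns only make sense with a GPU
    return tuple(fields)
```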
- bob.med.tb.engine.trainer.train_epoch(loader, model, optimizer, device, criterion, batch_chunk_count)[source]
  Trains the model for a single epoch (through all batches).
  - Parameters
    - loader (torch.utils.data.DataLoader) – To be used to train the model
    - model (torch.nn.Module) – Network (e.g. driu, hed, unet)
    - optimizer (torch.optim) – optimizer to use for training
    - device (torch.device) – device to use
    - criterion (torch.nn.modules.loss._Loss) – loss function
    - batch_chunk_count (int) – If this number is different from 1, then each batch will be divided into this number of chunks. Gradients will be accumulated to perform each mini-batch. This is particularly interesting when one has limited RAM on the GPU but would like to keep training with larger batches. One trades this for longer processing times. To better understand gradient accumulation, read https://stackoverflow.com/questions/62067400/understanding-accumulated-gradients-in-pytorch (a sketch follows this entry).
  - Returns
    - loss – A floating-point value corresponding to the weighted average of this epoch's loss
  - Return type
    - float
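To make the gradient-accumulation mechanism concrete, here is a minimal sketch; it assumes the loader yields plain (input, target) pairs, which the real loader may not:

```python
import torch

def train_epoch_sketch(loader, model, optimizer, device, criterion, batch_chunk_count):
    model.train()
    losses = []
    for inputs, targets in loader:
        optimizer.zero_grad()
        # Split the batch into chunks and accumulate gradients over them:
        # .backward() adds into .grad, so one optimizer step per full batch
        # behaves like training with the larger batch size, at lower GPU RAM cost.
        chunks = zip(inputs.chunk(batch_chunk_count), targets.chunk(batch_chunk_count))
        for chunk_in, chunk_tgt in chunks:
            out = model(chunk_in.to(device))
            loss = criterion(out, chunk_tgt.to(device)) / batch_chunk_count
            loss.backward()
            losses.append(loss.item() * batch_chunk_count)
        optimizer.step()
    return sum(losses) / len(losses)  # average loss over the epoch
```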
- bob.med.tb.engine.trainer.validate_epoch(loader, model, device, criterion, pbar_desc)[source]
  Processes input samples and returns the loss (scalar).
  - Parameters
    - loader (torch.utils.data.DataLoader) – To be used to validate the model
    - model (torch.nn.Module) – Network (e.g. driu, hed, unet)
    - device (torch.device) – device to use
    - criterion (torch.nn.modules.loss._Loss) – loss function
    - pbar_desc (str) – A string for the progress bar descriptor
  - Returns
    - loss – A floating-point value corresponding to the weighted average of this epoch's loss
  - Return type
    - float
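A sketch of the validation loop, under the same (input, target) assumption as above:

```python
import torch

def validate_epoch_sketch(loader, model, device, criterion):
    # Evaluation mode plus no gradient bookkeeping while validating.
    model.eval()
    total, count = 0.0, 0
    with torch.no_grad():
        for inputs, targets in loader:
            out = model(inputs.to(device))
            total += criterion(out, targets.to(device)).item() * len(inputs)
            count += len(inputs)
    return total / count  # batch-size-weighted average loss
```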
- bob.med.tb.engine.trainer.checkpointer_process(checkpointer, checkpoint_period, valid_loss, lowest_validation_loss, arguments, epoch, max_epoch)[source]
  Runs the checkpointer, saves the final model and keeps track of the best model.
  - Parameters
    - checkpointer (bob.med.tb.utils.checkpointer.Checkpointer) – checkpointer implementation
    - checkpoint_period (int) – save a checkpoint every n epochs. If set to 0 (zero), then do not save intermediary checkpoints
    - valid_loss (float) – Current epoch validation loss
    - lowest_validation_loss (float) – Keeps track of the best (lowest) validation loss
    - arguments (dict) – start and end epochs
    - epoch (int) – Current epoch
    - max_epoch (int) – End epoch
  - Returns
    - lowest_validation_loss – The lowest validation loss currently observed
  - Return type
    - float
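The decision logic could look like the following sketch; checkpointer.save() and the checkpoint names are assumptions about the Checkpointer API, not confirmed by this documentation:

```python
def checkpointer_process_sketch(checkpointer, checkpoint_period, valid_loss,
                                lowest_validation_loss, arguments, epoch, max_epoch):
    # Periodic checkpoint, skipped entirely when checkpoint_period == 0.
    if checkpoint_period and epoch % checkpoint_period == 0:
        checkpointer.save(f"model_{epoch:03d}", **arguments)  # assumed API
    # Track and persist the best (lowest validation loss) model so far.
    if valid_loss is not None and valid_loss < lowest_validation_loss:
        lowest_validation_loss = valid_loss
        checkpointer.save("model_lowest_valid_loss", **arguments)  # assumed name
    # Always keep the final model.
    if epoch >= max_epoch:
        checkpointer.save("model_final", **arguments)  # assumed name
    return lowest_validation_loss
```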
- bob.med.tb.engine.trainer.write_log_info(epoch, current_time, eta_seconds, loss, valid_loss, extra_valid_losses, optimizer, logwriter, logfile, resource_data)[source]
  Writes log info to trainlog.csv.
  - Parameters
    - epoch (int) – Current epoch
    - current_time (float) – Current training time
    - eta_seconds (float) – Estimated time of arrival, taking into consideration previous epoch performance
    - loss (float) – Current epoch's training loss
    - valid_loss (float, None) – Current epoch's validation loss
    - extra_valid_losses (list of float) – Validation losses from other validation datasets being currently tracked
    - optimizer (torch.optim) – optimizer in use
    - logwriter (csv.DictWriter) – Dictionary writer providing the ability to write to trainlog.csv
    - logfile (io.TextIOWrapper) – open file stream backing the log writer
    - resource_data (tuple) – Monitored resources on the machine (CPU and GPU)
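For illustration, a row for trainlog.csv could be assembled and flushed like this; the column names are the illustrative ones from the earlier field-list sketch:

```python
import csv

fields = ["epoch", "total_time", "eta", "loss", "validation_loss"]
with open("trainlog.csv", "a", newline="") as logfile:
    logwriter = csv.DictWriter(logfile, fieldnames=fields)
    logwriter.writerow({
        "epoch": 10,
        "total_time": 123.4,
        "eta": "00:42:00",
        "loss": 0.37,
        "validation_loss": 0.41,
    })
    logfile.flush()  # keep the on-disk log current during long trainings
```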
- bob.med.tb.engine.trainer.run(model, data_loader, valid_loader, extra_valid_loaders, optimizer, criterion, checkpointer, checkpoint_period, device, arguments, output_folder, monitoring_interval, batch_chunk_count, criterion_valid)[source]
  Fits a CNN model using supervised learning and saves it to disk.
  This method supports periodic checkpointing and the output of a CSV-formatted log with the evolution of some figures during training.
  - Parameters
    - model (torch.nn.Module) – Network (e.g. driu, hed, unet)
    - data_loader (torch.utils.data.DataLoader) – To be used to train the model
    - valid_loader (torch.utils.data.DataLoader) – To be used to validate the model and enable automatic checkpointing. If None, then do not validate it.
    - extra_valid_loaders (list of torch.utils.data.DataLoader) – To be used to validate the model, however they do not affect automatic checkpointing. If empty, nothing else is logged. Otherwise, an extra column with the loss of every dataset in this list is kept in the final training log.
    - optimizer (torch.optim) – optimizer to use for training
    - criterion (torch.nn.modules.loss._Loss) – loss function
    - checkpointer (bob.med.tb.utils.checkpointer.Checkpointer) – checkpointer implementation
    - checkpoint_period (int) – save a checkpoint every n epochs. If set to 0 (zero), then do not save intermediary checkpoints
    - device (torch.device) – device to use
    - arguments (dict) – start and end epochs
    - output_folder (str) – output path
    - monitoring_interval (int, float) – interval, in seconds (or fractions), at which resources should be monitored during training
    - batch_chunk_count (int) – If this number is different from 1, then each batch will be divided into this number of chunks. Gradients will be accumulated to perform each mini-batch. This is particularly interesting when one has limited RAM on the GPU but would like to keep training with larger batches. One trades this for longer processing times.
    - criterion_valid (torch.nn.modules.loss._Loss) – specific loss function for the validation set
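A hypothetical end-to-end invocation; the Checkpointer constructor and the keys expected in arguments are assumptions, and a real application would pass the project's actual CNN and data loaders instead of the toy stand-ins:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

from bob.med.tb.engine.trainer import run
from bob.med.tb.utils.checkpointer import Checkpointer

# Toy stand-ins for a real CNN and real datasets:
model = torch.nn.Linear(16, 1)
train_loader = DataLoader(
    TensorDataset(torch.randn(64, 16), torch.randint(0, 2, (64, 1)).float()),
    batch_size=8,
)
valid_loader = DataLoader(
    TensorDataset(torch.randn(16, 16), torch.randint(0, 2, (16, 1)).float()),
    batch_size=8,
)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

run(
    model=model,
    data_loader=train_loader,
    valid_loader=valid_loader,
    extra_valid_loaders=[],
    optimizer=torch.optim.Adam(model.parameters(), lr=1e-3),
    criterion=torch.nn.BCEWithLogitsLoss(),
    checkpointer=Checkpointer(model, path="results"),  # assumed constructor
    checkpoint_period=10,
    device=device,
    arguments={"epoch": 0, "max_epoch": 50},  # assumed dict keys
    output_folder="results",
    monitoring_interval=5.0,
    batch_chunk_count=1,
    criterion_valid=None,
)
```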