bob.ip.binseg.engine.trainer
Functions
check_exist_logfile | Check whether the logfile (trainlog.csv) exists. If it exists and the epoch number is still 0, the logfile is replaced. |
check_gpu | Check the device type and the availability of a GPU. |
checkpointer_process | Process the checkpointer, save the final model and keep track of the best model. |
create_logfile_fields | Create the fields that will appear in the logfile. |
run | Fits an FCN model using supervised learning and saves it to disk. |
save_model_summary | Save a short summary of the model to a text file. |
static_information_to_csv | Save the static information to a CSV file. |
torch_evaluation | Context manager to turn model evaluation mode ON/OFF. |
train_sample_process | Process the training inputs (images, ground truth, masks) and apply backpropagation to update the training losses. |
valid_sample_process | Process the validation inputs (images, ground truth, masks) and update the validation losses. |
write_log_info | Write log information to trainlog.csv. |
- bob.ip.binseg.engine.trainer.torch_evaluation(model)
Context manager to turn model evaluation mode ON/OFF.
This context manager turns evaluation mode ON on entry and turns it OFF when exiting the with statement block.
- Parameters
model (torch.nn.Module) – Network (e.g. driu, hed, unet)
- Yields
model (torch.nn.Module) – Network (e.g. driu, hed, unet)
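The entry/exit behaviour described for this context manager can be sketched with contextlib.contextmanager. This is a minimal illustration, not the package's implementation: evaluation_mode and FakeModel are hypothetical names, and the stand-in class only mimics the eval()/train() methods and the training flag that torch.nn.Module provides, so the block runs without PyTorch. Restoring training mode in a finally clause is also an assumption.

```python
import contextlib


class FakeModel:
    """Hypothetical stand-in exposing the eval()/train() interface of torch.nn.Module."""

    def __init__(self):
        self.training = True

    def train(self, mode=True):
        self.training = mode
        return self

    def eval(self):
        return self.train(False)


@contextlib.contextmanager
def evaluation_mode(model):
    """Switch ``model`` to evaluation mode for the duration of the with block."""
    model.eval()
    try:
        yield model
    finally:
        # Assumption: training mode is restored even if the block raises.
        model.train()


model = FakeModel()
with evaluation_mode(model) as m:
    inside = m.training  # evaluation mode is active inside the block
after = model.training   # training mode is back once the block exits
```

With a real network, the same pattern would typically wrap the validation pass so that layers such as dropout and batch normalization behave deterministically while validation losses are computed.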
- bob.ip.binseg.engine.trainer.check_gpu(device)
Check the device type and the availability of a GPU.
- Parameters
device (torch.device) – device to use
- bob.ip.binseg.engine.trainer.save_model_summary(output_folder, model)
Save a short summary of the model to a text file.
- Parameters
output_folder (str) – output path
model (torch.nn.Module) – Network (e.g. driu, hed, unet)
- Returns
r (str) – The model summary in text format.
n (int) – The number of parameters of the model.
- bob.ip.binseg.engine.trainer.static_information_to_csv(static_logfile_name, device, n)
Save the static information to a CSV file.
- Parameters
static_logfile_name (str) – The static logfile name, formed by joining the output folder and “constant.csv”
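As a rough illustration of writing run-invariant facts to “constant.csv”, the sketch below stores a header row and a single data row. The function name write_static_info and every column are hypothetical; the actual columns written by the package are not documented here. It also assumes that n is the parameter count returned by save_model_summary and passes the device as a plain string.

```python
import csv
import os
import platform
import tempfile


def write_static_info(static_logfile_name, device_name, n_parameters):
    """Write facts that do not change during training as one CSV row."""
    # Assumption: these columns are illustrative, not the package's actual ones.
    info = {
        "platform": platform.system(),
        "cpu_count": os.cpu_count(),
        "device": device_name,
        "model_parameters": n_parameters,
    }
    with open(static_logfile_name, "w", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=list(info))
        writer.writeheader()
        writer.writerow(info)
    return info


output_folder = tempfile.mkdtemp()
path = os.path.join(output_folder, "constant.csv")  # folder joined with the file name
write_static_info(path, "cpu", 1234)

# Read the file back to confirm its layout.
with open(path, newline="") as f:
    rows = list(csv.DictReader(f))
```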
- bob.ip.binseg.engine.trainer.check_exist_logfile(logfile_name, arguments)
Check whether the logfile (trainlog.csv) exists. If it exists and the epoch number is still 0, the logfile is replaced.
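The rule above (an existing log is only discarded when training starts from epoch 0) can be sketched as follows. This is a simplified illustration under stated assumptions: replace_stale_logfile is a hypothetical name, the current epoch is assumed to live under an "epoch" key in arguments, and the old file is simply deleted here, whereas the package may handle the replacement differently.

```python
import os
import tempfile


def replace_stale_logfile(logfile_name, arguments):
    """Remove an existing log when training starts from scratch.

    Returns True if a file was removed, False otherwise.
    """
    # Assumption: the current epoch is stored under the "epoch" key.
    if arguments["epoch"] == 0 and os.path.exists(logfile_name):
        os.unlink(logfile_name)
        return True
    return False


folder = tempfile.mkdtemp()
logfile = os.path.join(folder, "trainlog.csv")
with open(logfile, "w") as f:
    f.write("epoch,loss\n")

# Resuming at epoch 5: the existing log is kept so new rows can be appended.
kept = replace_stale_logfile(logfile, {"epoch": 5, "max_epoch": 10})
# Fresh run at epoch 0: the stale log is dropped.
removed = replace_stale_logfile(logfile, {"epoch": 0, "max_epoch": 10})
```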
- bob.ip.binseg.engine.trainer.create_logfile_fields(valid_loader, device)
Create the fields that will appear in the logfile.
- Parameters
valid_loader (torch.utils.data.DataLoader) – To be used to validate the model and enable automatic checkpointing. If set to None, then do not validate it.
device (torch.device) – device to use
- Returns
logfile_fields – The fields that will appear in trainlog.csv
- Return type
- bob.ip.binseg.engine.trainer.train_sample_process(samples, model, optimizer, losses, device, criterion)
Process the training inputs (images, ground truth, masks) and apply backpropagation to update the training losses.
- Parameters
samples (list) –
model (torch.nn.Module) – Network (e.g. driu, hed, unet)
optimizer (torch.optim) –
losses (bob.ip.binseg.utils.measure.SmoothedValue) –
device (torch.device) – device to use
criterion (torch.nn.modules.loss._Loss) – loss function
- Returns
optimizer (torch.optim)
- bob.ip.binseg.engine.trainer.valid_sample_process(samples, model, valid_losses, device, criterion)
Process the validation inputs (images, ground truth, masks) and update the validation losses.
- Parameters
samples (list) –
model (torch.nn.Module) – Network (e.g. driu, hed, unet)
optimizer (torch.optim) –
valid_losses (bob.ip.binseg.utils.measure.SmoothedValue) –
device (torch.device) – device to use
criterion (torch.nn.modules.loss._Loss) – loss function
- Returns
valid_losses
- Return type
- bob.ip.binseg.engine.trainer.checkpointer_process(checkpointer, checkpoint_period, valid_losses, lowest_validation_loss, arguments, epoch, max_epoch)
Process the checkpointer, save the final model and keep track of the best model.
- Parameters
checkpointer (bob.ip.binseg.utils.checkpointer.Checkpointer) – checkpointer implementation
checkpoint_period (int) – save a checkpoint every n epochs. If set to 0 (zero), then do not save intermediary checkpoints
valid_losses (bob.ip.binseg.utils.measure.SmoothedValue) –
lowest_validation_loss (float) – Keep track of the best (lowest) validation loss
arguments (dict) – start and end epochs
max_epoch (int) – end epoch
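The three responsibilities named in this entry (periodic checkpoints, best-model tracking and the final save) can be sketched in plain Python. Everything here is an assumption made for illustration: process_checkpoint and RecordingCheckpointer are hypothetical names, the checkpoint file names are invented, the stand-in only mimics a save(name, **kwargs) method, and the validation loss is passed as a plain float rather than a SmoothedValue.

```python
class RecordingCheckpointer:
    """Hypothetical stand-in that records the names it is asked to save."""

    def __init__(self):
        self.saved = []

    def save(self, name, **kwargs):
        self.saved.append(name)


def process_checkpoint(checkpointer, checkpoint_period, valid_loss,
                       lowest_validation_loss, arguments, epoch, max_epoch):
    """Save periodic, best and final checkpoints; return the updated best loss."""
    # Periodic checkpoint; a period of 0 disables intermediary checkpoints.
    if checkpoint_period and epoch % checkpoint_period == 0:
        checkpointer.save(f"model_{epoch:03d}", **arguments)

    # Best-model tracking, skipped when no validation loss is available.
    if valid_loss is not None and valid_loss < lowest_validation_loss:
        lowest_validation_loss = valid_loss
        checkpointer.save("model_lowest_valid_loss", **arguments)

    # Final model, saved once the last epoch is reached.
    if epoch >= max_epoch:
        checkpointer.save("model_final", **arguments)

    return lowest_validation_loss


ckpt = RecordingCheckpointer()
best = float("inf")
# Assumed per-epoch validation losses for a 4-epoch run with a period of 2.
for epoch, loss in enumerate([0.9, 0.5, 0.7, 0.6], start=1):
    best = process_checkpoint(ckpt, 2, loss, best, {"epoch": epoch}, epoch, 4)
```

Returning the updated lowest loss lets the caller carry it from one epoch to the next without shared state.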
- bob.ip.binseg.engine.trainer.write_log_info(epoch, current_time, eta_seconds, losses, valid_losses, optimizer, logwriter, logfile, device)
Write log information to trainlog.csv.
- Parameters
epoch (int) – Current epoch
current_time (float) – Current training time
eta_seconds (float) – estimated time of arrival (ETA), taking previous epoch performance into consideration
losses (bob.ip.binseg.utils.measure.SmoothedValue) –
valid_losses (bob.ip.binseg.utils.measure.SmoothedValue) –
optimizer (torch.optim) –
logwriter (csv.DictWriter) – Dictionary writer used to write to trainlog.csv
logfile (io.TextIOWrapper) –
device (torch.device) – device to use
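Writing one epoch's figures through a csv.DictWriter might look like the sketch below. It is not the package's implementation: write_log_row is a hypothetical name, the column names are invented, and the losses and learning rate are passed as plain floats instead of SmoothedValue objects and an optimizer. An io.StringIO buffer stands in for the open trainlog.csv handle.

```python
import csv
import datetime
import io


def write_log_row(epoch, current_time, eta_seconds, loss, valid_loss,
                  learning_rate, logwriter, logfile):
    """Append one epoch's figures to the log and flush them."""
    row = {
        "epoch": epoch,
        "total_time": datetime.timedelta(seconds=int(current_time)),
        "eta": datetime.timedelta(seconds=int(eta_seconds)),
        "loss": f"{loss:.6f}",
        "learning_rate": f"{learning_rate:.6f}",
    }
    # The validation column stays empty when no validation was run.
    if valid_loss is not None:
        row["validation_loss"] = f"{valid_loss:.6f}"
    logwriter.writerow(row)
    logfile.flush()  # keep the file readable while training is still running
    return row


# Assumption: illustrative column names, fixed once when the writer is created.
fields = ["epoch", "total_time", "eta", "loss", "learning_rate", "validation_loss"]
logfile = io.StringIO()  # stands in for the open trainlog.csv handle
logwriter = csv.DictWriter(logfile, fieldnames=fields)
logwriter.writeheader()
write_log_row(1, 75.2, 3600.0, 0.4321, 0.5, 0.001, logwriter, logfile)
lines = logfile.getvalue().splitlines()
```

Flushing after every row is a design choice that lets the log be inspected or plotted while a long training run is still in progress.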
- bob.ip.binseg.engine.trainer.run(model, data_loader, valid_loader, optimizer, criterion, scheduler, checkpointer, checkpoint_period, device, arguments, output_folder)
Fits an FCN model using supervised learning and saves it to disk.
This method supports periodic checkpointing and the output of a CSV-formatted log with the evolution of some figures during training.
- Parameters
model (torch.nn.Module) – Network (e.g. driu, hed, unet)
data_loader (torch.utils.data.DataLoader) – To be used to train the model
valid_loader (torch.utils.data.DataLoader) – To be used to validate the model and enable automatic checkpointing. If set to None, then do not validate it.
optimizer (torch.optim) –
criterion (torch.nn.modules.loss._Loss) – loss function
scheduler (torch.optim) – learning rate scheduler
checkpointer (bob.ip.binseg.utils.checkpointer.Checkpointer) – checkpointer implementation
checkpoint_period (int) – save a checkpoint every n epochs. If set to 0 (zero), then do not save intermediary checkpoints
device (torch.device) – device to use
arguments (dict) – start and end epochs
output_folder (str) – output path