bob.med.tb.engine.trainer
Functions
run(model, data_loader, valid_loader, ...) | Fits a CNN model using supervised learning and saves it to disk.
torch_evaluation(model) | Context manager to turn ON/OFF model evaluation.
- bob.med.tb.engine.trainer.torch_evaluation(model)
Context manager to turn ON/OFF model evaluation.
This context manager turns evaluation mode ON on entry and turns it OFF when exiting the with statement block.
- Parameters
model (torch.nn.Module) – Network (e.g. driu, hed, unet)
- Yields
model (torch.nn.Module) – Network (e.g. driu, hed, unet)
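A minimal sketch of such a context manager, assuming it only toggles model.eval() on entry and model.train() on exit (the actual implementation may differ, for instance by also disabling gradient tracking). DummyModule is a hypothetical stand-in that exposes the same eval()/train() methods as torch.nn.Module, so the example runs without PyTorch installed.

```python
import contextlib


@contextlib.contextmanager
def torch_evaluation(model):
    """Sketch: evaluation mode inside the with block, training mode afterwards."""
    model.eval()  # evaluation mode (for torch modules: dropout off, batch-norm uses running stats)
    try:
        yield model
    finally:
        model.train()  # restore training mode, even if the block raised an exception


class DummyModule:
    """Hypothetical stand-in with the same eval()/train() interface as torch.nn.Module."""

    def __init__(self):
        self.training = True

    def eval(self):
        self.training = False
        return self

    def train(self, mode=True):
        self.training = mode
        return self


net = DummyModule()
with torch_evaluation(net) as m:
    inside = m.training  # False while inside the block
after = net.training  # True again once the block has exited
```

The try/finally is a design choice of this sketch: it guarantees the model returns to training mode even when evaluation code fails, which matters if training resumes afterwards.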
- bob.med.tb.engine.trainer.run(model, data_loader, valid_loader, optimizer, criterion, checkpointer, checkpoint_period, device, arguments, output_folder, criterion_valid=None)
Fits a CNN model using supervised learning and saves it to disk.
This method supports periodic checkpointing and writes a CSV-formatted log that tracks the evolution of several figures during training.
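The periodic checkpointing and CSV logging described above can be sketched in plain Python. This is a hypothetical illustration, not the trainer's code: train_one_epoch, start_epoch, end_epoch and the CSV columns are assumed names, and a real implementation would call the checkpointer where this sketch merely records the epoch.

```python
import csv
import io


def train_epochs(train_one_epoch, start_epoch, end_epoch, checkpoint_period, log_file):
    """Hypothetical epoch loop with periodic checkpoints and a CSV log.

    train_one_epoch(epoch) is an assumed callback returning the epoch's training loss.
    Returns the epochs at which an intermediary checkpoint would be saved.
    """
    writer = csv.writer(log_file)
    writer.writerow(["epoch", "train_loss"])  # CSV header
    saved = []
    for epoch in range(start_epoch, end_epoch):
        loss = train_one_epoch(epoch)
        writer.writerow([epoch, f"{loss:.6f}"])  # one CSV row per epoch
        # checkpoint_period == 0 disables intermediary checkpoints entirely
        if checkpoint_period and (epoch + 1) % checkpoint_period == 0:
            saved.append(epoch)  # a real trainer would save a checkpoint here
    return saved


# Toy usage: a fake training step whose loss decreases with the epoch number.
log = io.StringIO()
saved = train_epochs(lambda e: 1.0 / (e + 1), start_epoch=0, end_epoch=6,
                     checkpoint_period=2, log_file=log)
```

With checkpoint_period=2 the sketch checkpoints after every second epoch (epochs 1, 3 and 5 when counting from 0), and passing 0 turns intermediary checkpoints off while still logging every epoch.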
- Parameters
model (torch.nn.Module) – Network (e.g. pasa)
data_loader (torch.utils.data.DataLoader) – To be used to train the model
valid_loader (torch.utils.data.DataLoader) – To be used to validate the model and enable automatic checkpointing. If set to None, then the model is not validated.
optimizer (torch.optim.Optimizer) – optimizer used to update the model parameters
criterion (torch.nn.modules.loss._Loss) – loss function
checkpointer (bob.med.tb.utils.checkpointer.Checkpointer) – checkpointer implementation
checkpoint_period (int) – save a checkpoint every n epochs. If set to 0 (zero), then no intermediary checkpoints are saved
device (str) – device to use ('cpu' or 'cuda:0')
arguments (dict) – start and end epochs
output_folder (str) – output path
criterion_valid (torch.nn.modules.loss._Loss, optional) – specific loss function for the validation set (defaults to None)
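One common way to implement the automatic checkpointing enabled by valid_loader is to compute a validation loss after each epoch and save a "best model" checkpoint whenever that loss reaches a new minimum. The sketch below assumes that behaviour and also assumes that criterion_valid=None falls back to criterion; neither is confirmed by this page. validation_loss and BestModelTracker are hypothetical names, and plain functions and scalars stand in for the network, loss functions and tensors.

```python
def validation_loss(batches, model_fn, criterion, criterion_valid=None):
    """Hypothetical: mean validation loss, or None when validation is disabled."""
    if batches is None:
        return None  # mirrors valid_loader=None: the model is not validated
    # Assumption: without a dedicated validation loss, reuse the training criterion.
    loss_fn = criterion_valid if criterion_valid is not None else criterion
    losses = [loss_fn(model_fn(x), y) for x, y in batches]
    return sum(losses) / len(losses)


class BestModelTracker:
    """Hypothetical helper deciding when a "best model" checkpoint is due."""

    def __init__(self):
        self.lowest = float("inf")

    def update(self, valid_loss):
        """Return True when valid_loss is a new minimum (i.e. save a checkpoint)."""
        if valid_loss is None or valid_loss >= self.lowest:
            return False
        self.lowest = valid_loss
        return True


def identity(x):
    return x  # stand-in for the network's forward pass


def squared_error(pred, target):
    return (pred - target) ** 2  # stand-in for the training criterion


def absolute_error(pred, target):
    return abs(pred - target)  # stand-in for a dedicated criterion_valid


batches = [(1.0, 2.0), (2.0, 4.0)]  # (input, target) pairs
train_style = validation_loss(batches, identity, squared_error)  # (1 + 4) / 2
custom = validation_loss(batches, identity, squared_error, absolute_error)  # (1 + 2) / 2

tracker = BestModelTracker()
decisions = [tracker.update(v) for v in [0.9, 0.7, 0.8, 0.5]]
```

Here the tracker requests a save after the first, second and fourth epochs, but not after the third, whose loss (0.8) does not improve on 0.7.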