bob.ip.binseg.engine.ssltrainer
Functions
guess_labels | Calculate the average predictions by 2 augmentations: horizontal and vertical flips
linear_rampup | Slowly ramps up lambda_u
mix_up | Applies mix up as described in [MIXMATCH_19].
run | Fits an FCN model using semi-supervised learning and saves it to disk.
square_rampup | Slowly ramps up lambda_u
-
bob.ip.binseg.engine.ssltrainer.mix_up(alpha, input, target, unlabelled_input, unlabled_target)
Applies mix up as described in [MIXMATCH_19].
- Parameters
alpha (float) –
input (torch.Tensor) –
target (torch.Tensor) –
unlabelled_input (torch.Tensor) –
unlabled_target (torch.Tensor) –
- Returns
- Return type
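The mixing step of [MIXMATCH_19] can be sketched as follows. This is a minimal illustration, not the library's implementation: the name `mix_up_sketch`, the pooling and shuffling of the labelled and unlabelled batches, and the order of the returned tensors are all assumptions made for the example.

```python
import torch


def mix_up_sketch(alpha, input, target, unlabelled_input, unlabled_target):
    """Hypothetical MixMatch-style mix up; not the bob.ip.binseg source."""
    # Draw the mixing coefficient from Beta(alpha, alpha) and keep the larger
    # side, so every mixed sample stays closer to its first operand.
    lam = torch.distributions.Beta(alpha, alpha).sample().item()
    lam = max(lam, 1.0 - lam)

    # Pool labelled and unlabelled samples, then pick a random partner for each.
    all_inputs = torch.cat([input, unlabelled_input], dim=0)
    all_targets = torch.cat([target, unlabled_target], dim=0)
    perm = torch.randperm(all_inputs.size(0))

    mixed_inputs = lam * all_inputs + (1.0 - lam) * all_inputs[perm]
    mixed_targets = lam * all_targets + (1.0 - lam) * all_targets[perm]

    # Split the pool back into its labelled and unlabelled parts
    # (assumed return order).
    n = input.size(0)
    return mixed_inputs[:n], mixed_targets[:n], mixed_inputs[n:], mixed_targets[n:]
```

With `alpha` close to zero the coefficient concentrates near 1 and the samples are barely mixed; larger values of `alpha` pull it towards 0.5.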
-
bob.ip.binseg.engine.ssltrainer.square_rampup(current, rampup_length=16)
Slowly ramps up lambda_u on a square (quadratic) schedule.
-
bob.ip.binseg.engine.ssltrainer.linear_rampup(current, rampup_length=16)
Slowly ramps up lambda_u on a linear schedule.
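To make the difference between the two ramp-up schedules concrete, here is a minimal sketch. It assumes that `current` is the current epoch, that the result is a factor clipped to [0, 1] which scales lambda_u (the weight of the unlabelled loss term in MixMatch), and that a `rampup_length` of zero disables the ramp-up. The `_sketch` names are hypothetical and the bodies are not taken from the library source.

```python
def linear_rampup_sketch(current, rampup_length=16):
    """Factor in [0, 1] that grows linearly with `current` (assumed behaviour)."""
    if rampup_length == 0:
        return 1.0  # no ramp-up requested: use the full weight at once
    return min(max(current / rampup_length, 0.0), 1.0)


def square_rampup_sketch(current, rampup_length=16):
    """Factor in [0, 1] that grows quadratically with `current` (assumed behaviour)."""
    if rampup_length == 0:
        return 1.0
    return min(max((current / rampup_length) ** 2, 0.0), 1.0)


# Compare both schedules at a few epochs.
for epoch in (0, 4, 8, 16, 32):
    print(epoch, linear_rampup_sketch(epoch), square_rampup_sketch(epoch))
```

Under these assumptions the square schedule stays lower for longer (0.25 versus 0.5 at epoch 8 with the default length), so the unlabelled loss would gain influence later in training.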
-
bob.ip.binseg.engine.ssltrainer.guess_labels(unlabelled_images, model)
Calculates the average of the predictions over 2 augmentations: horizontal and vertical flips.
- Parameters
unlabelled_images (torch.Tensor) – [n,c,h,w]
model (torch.nn.Module) –
- Returns
shape – [n,c,h,w]
- Return type
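The flip-and-average idea can be sketched as follows. This is an illustration under stated assumptions rather than the library's code: the name `guess_labels_sketch` is hypothetical, the model is assumed to return logits that are turned into probabilities with a sigmoid, and each prediction is assumed to be flipped back before averaging so that it lines up with the original image.

```python
import torch


def guess_labels_sketch(unlabelled_images, model):
    """Average predictions over a horizontal and a vertical flip (hypothetical)."""
    with torch.no_grad():  # guessed labels serve as targets, so no gradients
        # Horizontal flip: reverse the width axis (dim 3), predict, flip back.
        hflip = torch.flip(model(torch.flip(unlabelled_images, dims=[3])), dims=[3])
        # Vertical flip: reverse the height axis (dim 2), predict, flip back.
        vflip = torch.flip(model(torch.flip(unlabelled_images, dims=[2])), dims=[2])
        # Assumption: the model outputs logits, so apply a sigmoid first.
        return (torch.sigmoid(hflip) + torch.sigmoid(vflip)) / 2
```

For a model whose output has the same spatial size as its input, the result keeps the `[n,c,h,w]` layout of the predictions.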
-
bob.ip.binseg.engine.ssltrainer.run(model, data_loader, valid_loader, optimizer, criterion, scheduler, checkpointer, checkpoint_period, device, arguments, output_folder, rampup_length)
Fits an FCN model using semi-supervised learning and saves it to disk.
This method supports periodic checkpointing and the output of a CSV-formatted log with the evolution of some figures during training.
- Parameters
model (torch.nn.Module) – Network (e.g. driu, hed, unet)
data_loader (torch.utils.data.DataLoader) – To be used to train the model
valid_loader (torch.utils.data.DataLoader) – To be used to validate the model and enable automatic checkpointing. If set to None, then do not validate it.
optimizer (torch.optim) –
criterion (torch.nn.modules.loss._Loss) – loss function
scheduler (torch.optim) – learning rate scheduler
checkpointer (bob.ip.binseg.utils.checkpointer.Checkpointer) – checkpointer implementation
checkpoint_period (int) – save a checkpoint every n epochs. If set to 0 (zero), then do not save intermediary checkpoints
device (str) – device to use: 'cpu' or 'cuda:0'
arguments (dict) – start and end epochs
output_folder (str) – output path
rampup_length (int) – rampup epochs
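The effect of checkpoint_period can be illustrated with a small sketch. The helper `checkpoint_epochs_sketch` and its `start_epoch` and `max_epoch` arguments are hypothetical, since the keys of the arguments dictionary are not listed here. The sketch also assumes that epochs are counted from 1, and it says nothing about whether a final checkpoint is written at the end of training.

```python
def checkpoint_epochs_sketch(start_epoch, max_epoch, checkpoint_period):
    """Epochs after which an intermediary checkpoint would be saved (assumed)."""
    if checkpoint_period == 0:
        return []  # zero disables intermediary checkpoints
    return [
        epoch
        for epoch in range(start_epoch + 1, max_epoch + 1)
        if epoch % checkpoint_period == 0
    ]


print(checkpoint_epochs_sketch(0, 10, 4))  # [4, 8]
print(checkpoint_epochs_sketch(0, 10, 0))  # []
```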