Python API
Architectures
class bob.learn.pytorch.architectures.CASIANet(num_cls, drop_rate=0.5)
Bases: torch.nn.modules.module.Module
The class defining the CASIA-Net CNN model.
This class implements the CNN described in “Learning Face Representation From Scratch”, D. Yi, Z. Lei, S. Liao and S. Z. Li, 2014.
Attributes
- num_classes (int) – The number of classes.
- drop_rate (float) – The dropout probability.
- conv (torch.nn.Module) – The convolutional and max-pooling layers.
- avgpool (torch.nn.Module) – The average pooling layer, whose output is used as the embedding.
- classifier (torch.nn.Module) – The final linear layer, which outputs the logits.
-
forward
(x)[source]¶ Propagate data through the network
Parameters: x ( torch.Tensor
) – The data to forward through the networkReturns: x – The last layer of the network Return type: torch.Tensor
-
-
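As a rough illustration of how the documented attributes fit together, the sketch below passes feature maps through average pooling (the embedding), dropout with probability drop_rate, and a linear classifier that produces one logit per class. It uses plain Python lists instead of tensors and assumes that the pooling reduces each feature map to a single value and that dropout sits between the embedding and the classifier; neither detail is stated in this documentation. The helper names (avg_pool, dropout, linear, head_forward) are hypothetical and not part of bob.learn.pytorch. CNN8 documents the same attributes.

```python
import random
from typing import List, Optional, Tuple

Matrix = List[List[float]]


def avg_pool(feature_maps: List[Matrix]) -> List[float]:
    """Average each H x W feature map to one value, giving one entry per channel."""
    return [sum(map(sum, fmap)) / (len(fmap) * len(fmap[0])) for fmap in feature_maps]


def dropout(x: List[float], p: float, training: bool, rng: random.Random) -> List[float]:
    """Inverted dropout: during training, zero each unit with probability p
    and rescale the survivors by 1 / (1 - p); identity at evaluation time."""
    if not training or p == 0.0:
        return list(x)
    return [0.0 if rng.random() < p else v / (1.0 - p) for v in x]


def linear(x: List[float], weights: Matrix, bias: List[float]) -> List[float]:
    """Fully connected layer: one output (logit) per row of `weights`, i.e. per class."""
    return [sum(w * v for w, v in zip(row, x)) + b for row, b in zip(weights, bias)]


def head_forward(feature_maps: List[Matrix], weights: Matrix, bias: List[float],
                 drop_rate: float = 0.5, training: bool = False,
                 rng: Optional[random.Random] = None) -> Tuple[List[float], List[float]]:
    """Hypothetical head: avgpool -> dropout -> classifier (assumed ordering)."""
    rng = rng or random.Random(0)
    embedding = avg_pool(feature_maps)                     # role of `avgpool`
    hidden = dropout(embedding, drop_rate, training, rng)  # uses `drop_rate`
    logits = linear(hidden, weights, bias)                 # role of `classifier`
    return logits, embedding


maps = [[[1.0, 3.0], [5.0, 7.0]],   # channel 0
        [[0.0, 0.0], [2.0, 2.0]]]   # channel 1
logits, embedding = head_forward(maps, [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]], [0.0, 0.0, 0.5])
print(embedding)  # [4.0, 1.0]
print(logits)     # [4.0, 1.0, 5.5]
```

Returning the embedding alongside the logits is a choice made for this sketch only. The 1 / (1 - p) rescaling matches what torch.nn.Dropout does in training mode.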
class bob.learn.pytorch.architectures.CNN8(num_cls, drop_rate=0.5)
Bases: torch.nn.modules.module.Module
The class defining the CNN8 model.
Attributes
- num_classes (int) – The number of classes.
- drop_rate (float) – The dropout probability.
- conv (torch.nn.Module) – The convolutional and max-pooling layers.
- avgpool (torch.nn.Module) – The average pooling layer, whose output is used as the embedding.
- classifier (torch.nn.Module) – The final linear layer, which outputs the logits.
- forward(x)
  Propagate data through the network.
  Parameters: x (torch.Tensor) – The data to forward through the network.
  Returns: x – The output of the last layer of the network.
  Return type: torch.Tensor
Datasets
class bob.learn.pytorch.datasets.CasiaDataset(root_dir, frontal_only=False, transform=None, start_index=0)
Bases: torch.utils.data.dataset.Dataset
Class representing the CASIA WebFace dataset.
Parameters
- root_dir (path) – The path to the data.
- frontal_only (bool) – If True, only frontal faces are used.
- transform (torchvision.transforms) – The transform(s) to apply to the face images.
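A dataset handed to a torch.utils.data.DataLoader is indexed with dataset[i] and usually queried with len(dataset). The sketch below shows one plausible way frontal_only and transform could behave: frontal filtering happens once at construction, while the transform runs lazily in __getitem__. It reads from an in-memory list of hypothetical (image, label, is_frontal) records rather than from root_dir, and the class name FaceRecordDataset is made up for illustration; it is not the CasiaDataset implementation.

```python
from typing import Any, Callable, List, Optional, Tuple

# Hypothetical in-memory record: (image, identity_label, is_frontal).
Record = Tuple[Any, int, bool]


class FaceRecordDataset:
    """Map-style dataset sketch exposing __len__ and __getitem__,
    the two methods a DataLoader relies on for a map-style Dataset."""

    def __init__(self, records: List[Record], frontal_only: bool = False,
                 transform: Optional[Callable[[Any], Any]] = None):
        # Filter once, up front, so indices 0..len-1 stay contiguous.
        self.records = [r for r in records if r[2]] if frontal_only else list(records)
        self.transform = transform

    def __len__(self) -> int:
        return len(self.records)

    def __getitem__(self, idx: int) -> Tuple[Any, int]:
        image, label, _ = self.records[idx]
        if self.transform is not None:
            image = self.transform(image)  # applied lazily, one sample at a time
        return image, label


records = [("img_a", 0, True), ("img_b", 0, False), ("img_c", 1, True)]
ds = FaceRecordDataset(records, frontal_only=True, transform=str.upper)
print(len(ds), ds[1])  # 2 ('IMG_C', 1)
```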
class bob.learn.pytorch.datasets.CasiaWebFaceDataset(root_dir, transform=None, start_index=0)
Bases: torch.utils.data.dataset.Dataset
Class representing the CASIA WebFace dataset, for CNN training.
Parameters
- root_dir (path) – The path to the data.
- transform (torchvision.transforms) – The transform(s) to apply to the face images.
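Training a classifier with a loss such as torch.nn.CrossEntropyLoss requires each identity to be an integer class index in 0..C-1, where C would be the num_cls given to the network. Here is a minimal sketch of building such labels. It assumes, hypothetically, that samples are stored as '<identity>/<image file>' relative paths and that labels follow sorted identity order. The function names are illustrative and are not taken from CasiaWebFaceDataset.

```python
from typing import Dict, Iterable, List, Tuple


def build_label_map(identities: Iterable[str]) -> Dict[str, int]:
    """Assign each distinct identity a contiguous index in 0..N-1.
    Sorting makes the assignment deterministic across runs and machines."""
    return {name: idx for idx, name in enumerate(sorted(set(identities)))}


def label_samples(paths: List[str]) -> Tuple[List[Tuple[str, int]], int]:
    """Label hypothetical relative paths of the form '<identity>/<image file>'.
    Returns the (path, label) pairs and the number of classes."""
    label_map = build_label_map(p.split("/")[0] for p in paths)
    return [(p, label_map[p.split("/")[0]]) for p in paths], len(label_map)


paths = ["0000107/001.jpg", "0000045/003.jpg", "0000107/002.jpg"]
samples, num_cls = label_samples(paths)
print(num_cls)   # 2
print(samples)   # [('0000107/001.jpg', 1), ('0000045/003.jpg', 0), ('0000107/002.jpg', 1)]
```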
class bob.learn.pytorch.datasets.ConcatDataset(datasets)
Bases: torch.utils.data.dataset.Dataset
Class to concatenate two or more datasets for DR-GAN training.
Parameters
- datasets (list) – The list of datasets (each a torch.utils.data.Dataset).
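Concatenating datasets amounts to mapping a global index onto a (dataset, local index) pair. The sketch below keeps the cumulative dataset sizes and uses binary search to find the right dataset, which is also the approach taken by torch.utils.data.ConcatDataset. This documentation does not say whether bob.learn.pytorch's ConcatDataset works the same way. ConcatSketch is a hypothetical name, and plain lists stand in for Dataset objects.

```python
import bisect
from itertools import accumulate
from typing import Any, List, Sequence


class ConcatSketch:
    """Expose several map-style datasets as one contiguous index space."""

    def __init__(self, datasets: List[Sequence[Any]]):
        if not datasets:
            raise ValueError("datasets should not be an empty list")
        self.datasets = datasets
        # cumulative_sizes[i] = number of samples in datasets[0] .. datasets[i]
        self.cumulative_sizes = list(accumulate(len(d) for d in datasets))

    def __len__(self) -> int:
        return self.cumulative_sizes[-1]

    def __getitem__(self, idx: int) -> Any:
        if idx < 0:
            idx += len(self)
        if not 0 <= idx < len(self):
            raise IndexError(idx)
        # The first dataset whose cumulative size exceeds idx holds the sample;
        # bisect_right also skips over empty datasets.
        ds_idx = bisect.bisect_right(self.cumulative_sizes, idx)
        offset = self.cumulative_sizes[ds_idx - 1] if ds_idx > 0 else 0
        return self.datasets[ds_idx][idx - offset]


both = ConcatSketch([["a0", "a1"], [], ["b0", "b1", "b2"]])
print(len(both), both[2], both[-1])  # 5 b0 b2
```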
class bob.learn.pytorch.datasets.FaceCropper(cropped_height, cropped_width)
Bases: object
Class to crop a face based on the eye positions.
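This documentation does not say how the eye positions determine the crop. One common approach, shown here as a hypothetical sketch, scales the image so the eyes end up a fixed distance apart in the cropped_height x cropped_width output, then positions the crop relative to the eye midpoint. The target fractions (eye_row_frac, eye_dist_frac), the (y, x) point ordering and the function name eye_crop_box are all assumptions. In-plane rotation is ignored, and the actual FaceCropper may differ.

```python
import math
from typing import Tuple

Point = Tuple[float, float]  # (y, x) in pixels; this ordering is an assumption


def eye_crop_box(reye: Point, leye: Point, cropped_height: int, cropped_width: int,
                 eye_row_frac: float = 1 / 3, eye_dist_frac: float = 1 / 3
                 ) -> Tuple[float, float, float, float]:
    """Return (top, left, height, width) of the source region that, once resized
    to cropped_height x cropped_width, puts the eye midpoint at row
    eye_row_frac * cropped_height (horizontally centred) and the eyes
    eye_dist_frac * cropped_width apart. In-plane rotation is ignored."""
    eye_dist = math.hypot(leye[0] - reye[0], leye[1] - reye[1])
    if eye_dist == 0:
        raise ValueError("the two eye positions must differ")
    scale = eye_dist / (eye_dist_frac * cropped_width)  # source px per output px
    height, width = cropped_height * scale, cropped_width * scale
    center_y = (reye[0] + leye[0]) / 2
    center_x = (reye[1] + leye[1]) / 2
    top = center_y - eye_row_frac * height
    left = center_x - width / 2
    return top, left, height, width


box = eye_crop_box(reye=(100, 80), leye=(100, 120), cropped_height=64, cropped_width=64)
print([round(v, 3) for v in box])  # [60.0, 40.0, 120.0, 120.0]
```

The returned region would then be cut out and resized to the output size. Rotating the image so that the eyes are level is a natural extension, not shown here.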
Scripts
Trainers
class bob.learn.pytorch.trainers.CNNTrainer(network, batch_size=64, use_gpu=False, verbosity_level=2)
Bases: object
Class to train a CNN.
Attributes
- network (torch.nn.Module) – The network to train.
- batch_size (int) – The minibatch size.
- use_gpu (bool) – Whether to use the GPU.
- verbosity_level (int) – The level of verbosity of the output to stdout.
- load_model(model_filename)
  Loads an existing model.
  Parameters: model_filename (str) – The filename of the model to load.
  Returns:
  - start_epoch (int) – The epoch to start with.
  - start_iteration (int) – The iteration to start with.
  - losses (list(float)) – The list of losses from previous training.
- save_model(output_dir, epoch=0, iteration=0, losses=None)
  Saves the trained network.
  Parameters:
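load_model returns the epoch, iteration and losses to resume from, which suggests that save_model stores that progress next to the network weights. The following is a minimal sketch of such a round trip, using pickle on a plain dict so it runs without PyTorch. A PyTorch trainer would more typically call torch.save on the model's state_dict(). The function names and the model_<epoch>_<iteration>.pkl file naming are hypothetical.

```python
import os
import pickle
import tempfile
from typing import List, Optional, Tuple


def save_checkpoint(output_dir: str, state: dict, epoch: int = 0, iteration: int = 0,
                    losses: Optional[List[float]] = None) -> str:
    """Store the model state together with the training progress in one file."""
    os.makedirs(output_dir, exist_ok=True)
    path = os.path.join(output_dir, f"model_{epoch}_{iteration}.pkl")  # hypothetical naming
    with open(path, "wb") as f:
        pickle.dump({"state": state, "epoch": epoch, "iteration": iteration,
                     "losses": list(losses) if losses is not None else []}, f)
    return path


def load_checkpoint(model_filename: str) -> Tuple[dict, int, int, List[float]]:
    """Return the stored state plus (start_epoch, start_iteration, losses).
    Only unpickle files you trust."""
    with open(model_filename, "rb") as f:
        ckpt = pickle.load(f)
    return ckpt["state"], ckpt["epoch"], ckpt["iteration"], ckpt["losses"]


with tempfile.TemporaryDirectory() as out_dir:
    path = save_checkpoint(out_dir, {"w": [0.5, -1.0]}, epoch=3, iteration=120, losses=[0.9, 0.4])
    state, start_epoch, start_iteration, losses = load_checkpoint(path)
    print(start_epoch, start_iteration, losses)  # 3 120 [0.9, 0.4]
```

Whether the real trainer resumes at the stored epoch or the one after it is not specified here. This sketch simply returns the stored values.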
- train(dataloader, n_epochs=20, learning_rate=0.01, output_dir='out', model=None)
  Performs the training.
  Parameters:
  - dataloader (torch.utils.data.DataLoader) – The dataloader for your data.
  - n_epochs (int) – The number of epochs to train for.
  - learning_rate (float) – The learning rate for the SGD optimizer.
  - output_dir (str) – The directory where the models are saved.
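The train method is documented as running for n_epochs over a dataloader with an SGD optimizer at a given learning_rate. The sketch below shows that loop structure on a toy problem. Each epoch, it shuffles the data into minibatches, computes the loss, takes a gradient step scaled by the learning rate, and records the loss. A one-dimensional linear model with hand-written gradients stands in for the CNN and backpropagation. This illustrates minibatch SGD, not CNNTrainer itself, and all names are hypothetical.

```python
import random
from typing import Iterator, List, Sequence, Tuple

Sample = Tuple[float, float]  # (x, y)


def minibatches(data: Sequence[Sample], batch_size: int,
                rng: random.Random) -> Iterator[List[Sample]]:
    """Yield the data in shuffled minibatches, like a shuffling DataLoader."""
    order = list(range(len(data)))
    rng.shuffle(order)
    for start in range(0, len(order), batch_size):
        yield [data[i] for i in order[start:start + batch_size]]


def train_linear_sgd(data: Sequence[Sample], n_epochs: int = 20, learning_rate: float = 0.01,
                     batch_size: int = 64, seed: int = 0) -> Tuple[float, float, List[float]]:
    """Minibatch SGD on a toy 1-D model y ~ w * x + b with mean squared error."""
    rng = random.Random(seed)
    w, b = 0.0, 0.0
    losses: List[float] = []
    for _ in range(n_epochs):
        for batch in minibatches(data, batch_size, rng):
            errors = [(w * x + b) - y for x, y in batch]
            n = len(batch)
            losses.append(sum(e * e for e in errors) / n)
            # Gradients of the minibatch MSE with respect to w and b.
            grad_w = 2.0 * sum(e * x for e, (x, _) in zip(errors, batch)) / n
            grad_b = 2.0 * sum(errors) / n
            # SGD step: move against the gradient, scaled by the learning rate.
            w -= learning_rate * grad_w
            b -= learning_rate * grad_b
    return w, b, losses


data = [(i / 10, 2.0 * (i / 10) + 1.0) for i in range(11)]  # noiseless y = 2x + 1
w, b, losses = train_linear_sgd(data, n_epochs=500, learning_rate=0.1, batch_size=4)
print(round(w, 3), round(b, 3))  # close to the true values w = 2, b = 1
```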