Python API

Architectures
class bob.learn.pytorch.architectures.CASIANet(num_cls, drop_rate=0.5)

Bases: torch.nn.modules.module.Module

The class defining the CASIA-Net CNN model.

This class implements the CNN described in: “Learning Face Representation From Scratch”, D. Yi, Z. Lei, S. Liao and S. Z. Li, 2014.

Attributes:
- conv – the output of the convolutional / maxpool layers
- avgpool – the output of the average pooling layer (used as embedding)
- classifier – the output of the last linear layer (logits)

forward(x)
    Propagate data through the network.

    Parameters: x (torch.Tensor) – the data to forward through the network.
    Returns: x (torch.Tensor) – the output of the last layer of the network.
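A minimal usage sketch follows; the 100x100 RGB input resolution is an assumption borrowed from the CASIA-Net paper, since the expected size is not stated here. CNN8 below exposes the same constructor and forward interface.

    import torch
    from bob.learn.pytorch.architectures import CASIANet

    # num_cls sizes the final classification layer.
    net = CASIANet(num_cls=10, drop_rate=0.5)
    net.eval()

    # Batch of two RGB images; the 100x100 resolution is an assumption
    # taken from the CASIA-Net paper, not from this documentation.
    x = torch.randn(2, 3, 100, 100)
    with torch.no_grad():
        logits = net(x)  # output of the last linear layer (logits)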
class bob.learn.pytorch.architectures.CNN8(num_cls, drop_rate=0.5)

Bases: torch.nn.modules.module.Module

The class defining the CNN8 model.

Attributes:
- conv – the output of the convolutional / maxpool layers
- avgpool – the output of the average pooling layer (used as embedding)
- classifier – the output of the last linear layer (logits)

forward(x)
    Propagate data through the network.

    Parameters: x (torch.Tensor) – the data to forward through the network.
    Returns: x (torch.Tensor) – the output of the last layer of the network.
class bob.learn.pytorch.architectures.ConditionalGAN_discriminator(conditional_dim, channels=3, ngpu=1)

Bases: torch.nn.modules.module.Module

Class implementing the conditional GAN discriminator.

Attributes:
- main – the sequential container
class bob.learn.pytorch.architectures.ConditionalGAN_generator(noise_dim, conditional_dim, channels=3, ngpu=1)

Bases: torch.nn.modules.module.Module

Class implementing the conditional GAN generator.

This network is introduced in the following publication: Mehdi Mirza, Simon Osindero: “Conditional Generative Adversarial Nets”.

Attributes:
- main – the sequential container
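A hedged construction sketch: the constructor arguments come from the signatures above, but the shapes of the noise and one-hot tensors, and the two-argument forward call, are assumptions inferred from the ConditionalGANTrainer attributes (fixed_noise, fixed_one_hot) rather than from this page.

    import torch
    from bob.learn.pytorch.architectures import (
        ConditionalGAN_discriminator,
        ConditionalGAN_generator,
    )

    netG = ConditionalGAN_generator(noise_dim=100, conditional_dim=13)
    netD = ConditionalGAN_discriminator(conditional_dim=13)

    # DCGAN-style latent input (N x noise_dim x 1 x 1) plus a one-hot
    # pose condition with 13 classes -- both shapes are assumptions.
    noise = torch.randn(4, 100, 1, 1)
    one_hot = torch.zeros(4, 13, 1, 1)
    one_hot[:, 0] = 1  # condition every sample on class 0

    # The (noise, condition) call convention is an assumption; check
    # the forward() signature in the source before relying on it.
    fake = netG(noise, one_hot)
    score = netD(fake, one_hot)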
class bob.learn.pytorch.architectures.ConvAutoencoder(return_latent_embedding=False)

Bases: torch.nn.modules.module.Module

A class defining a simple convolutional autoencoder.

Attributes:
- return_latent_embedding (bool) – if True, forward() returns the encoder output; otherwise it returns the reconstructed image

forward(x)
    Propagate data through the network.

    Parameters: x (torch.Tensor) – the data to forward through the network.
    Returns: either the encoder output or the reconstructed image, depending on return_latent_embedding.
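A short sketch of the two output modes; the 3x64x64 input resolution is an assumption for illustration, as the documentation does not state the expected size.

    import torch
    from bob.learn.pytorch.architectures import ConvAutoencoder

    x = torch.randn(1, 3, 64, 64)  # input resolution is an assumption

    # Default mode: forward() returns the reconstructed image.
    ae = ConvAutoencoder(return_latent_embedding=False)
    reconstruction = ae(x)

    # Embedding mode: forward() returns the encoder output instead.
    encoder = ConvAutoencoder(return_latent_embedding=True)
    embedding = encoder(x)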
class bob.learn.pytorch.architectures.DCGAN_discriminator(ngpu)

Bases: torch.nn.modules.module.Module

Class implementing the discriminator part of the Deep Convolutional GAN (DCGAN).

This network is introduced in the following publication: Alec Radford, Luke Metz, Soumith Chintala: “Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks”, ICLR 2016. Most of the code is based on: https://github.com/pytorch/examples/tree/master/dcgan

forward(input)
    Forward function.

    Parameters: input (torch.Tensor) – the image to classify.
    Returns: the output of the discriminator (i.e. a real/fake score per image).
class bob.learn.pytorch.architectures.DCGAN_generator(ngpu)

Bases: torch.nn.modules.module.Module

Class implementing the generator part of the Deep Convolutional GAN (DCGAN).

This network is introduced in the following publication: Alec Radford, Luke Metz, Soumith Chintala: “Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks”, ICLR 2016. Most of the code is based on: https://github.com/pytorch/examples/tree/master/dcgan

forward(input)
    Forward function.

    Parameters: input (torch.Tensor) – the input noise.
    Returns: the output of the generator (i.e. an image).
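A usage sketch for the generator/discriminator pair; the Nx100x1x1 noise shape follows the upstream pytorch/examples DCGAN this code is based on, and is an assumption here.

    import torch
    from bob.learn.pytorch.architectures import (
        DCGAN_discriminator,
        DCGAN_generator,
    )

    netG = DCGAN_generator(ngpu=1)
    netD = DCGAN_discriminator(ngpu=1)

    # Standard DCGAN latent input: one 100-d noise vector per sample,
    # shaped N x 100 x 1 x 1 (an assumption from the upstream example).
    noise = torch.randn(8, 100, 1, 1)
    fake = netG(noise)   # generated images
    score = netD(fake)   # discriminator output, one score per image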
class bob.learn.pytorch.architectures.DeepMSPAD(pretrained=True, num_channels=4)

Bases: torch.nn.modules.module.Module

Deep multispectral PAD algorithm.

The initialization uses the cross-modality pre-training idea from the following paper: Wang L, Xiong Y, Wang Z, Qiao Y, Lin D, Tang X, Van Gool L. “Temporal Segment Networks: Towards Good Practices for Deep Action Recognition.” In: European Conference on Computer Vision, 2016 (pp. 20-36). Springer, Cham.

Attributes:
- pretrained (bool) – if set to True, loads the pretrained VGG16 model
- vgg (torch.nn.Module) – the VGG16 model
- relu (torch.nn.Module) – ReLU activation
- enc (torch.nn.Module) – the layers used for feature extraction
- linear1 (torch.nn.Module) – fully connected layer
- linear2 (torch.nn.Module) – fully connected layer
- dropout (torch.nn.Module) – dropout layer
- sigmoid (torch.nn.Module) – sigmoid activation

forward(x)
    Propagate data through the network.

    Parameters: x (torch.Tensor) – the data to forward through the network.
    Returns: x (torch.Tensor) – the output of the last layer of the network.
class bob.learn.pytorch.architectures.DeepPixBiS(pretrained=True)

Bases: torch.nn.modules.module.Module

The class defining Deep Pixelwise Binary Supervision for Face Presentation Attack Detection.

Reference: Anjith George and Sébastien Marcel. “Deep Pixel-wise Binary Supervision for Face Presentation Attack Detection.” In 2019 International Conference on Biometrics (ICB). IEEE, 2019.

Attributes:
- pretrained (bool) – if set to True, uses the pretrained DenseNet model as the base; if set to False, the network is trained from scratch. Default: True.

forward(x)
    Propagate data through the network.

    Parameters: x (torch.Tensor) – the data to forward through the network; expects an RGB image of size 3x224x224.
    Returns: dec (torch.Tensor) – binary map of size 1x14x14; op (torch.Tensor) – final binary score.
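A minimal inference sketch using the documented input and output shapes:

    import torch
    from bob.learn.pytorch.architectures import DeepPixBiS

    net = DeepPixBiS(pretrained=True)
    net.eval()

    # RGB input of size 3x224x224, as stated in the forward() docs.
    x = torch.randn(1, 3, 224, 224)
    with torch.no_grad():
        dec, op = net(x)
    # dec: 1x14x14 pixel-wise binary map, op: final binary score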
class bob.learn.pytorch.architectures.FASNet(pretrained=True)

Bases: torch.nn.modules.module.Module

PyTorch reimplementation of Lucena, Oeslle, et al. “Transfer learning using convolutional neural networks for face anti-spoofing.” International Conference on Image Analysis and Recognition. Springer, Cham, 2017. Referenced from the Keras implementation: https://github.com/OeslleLucena/FASNet

Attributes:
- pretrained (bool) – if set to True, loads the pretrained VGG16 model
- vgg (torch.nn.Module) – the VGG16 model
- relu (torch.nn.Module) – ReLU activation
- enc (torch.nn.Module) – the layers used for feature extraction
- linear1 (torch.nn.Module) – fully connected layer
- linear2 (torch.nn.Module) – fully connected layer
- dropout (torch.nn.Module) – dropout layer
- sigmoid (torch.nn.Module) – sigmoid activation

forward(x)
    Propagate data through the network.

    Parameters: x (torch.Tensor) – the data to forward through the network.
    Returns: x (torch.Tensor) – the output of the last layer of the network.
class bob.learn.pytorch.architectures.LightCNN29(block=<class 'bob.learn.pytorch.architectures.utils.resblock'>, layers=[1, 2, 3, 4], num_classes=79077)

Bases: torch.nn.modules.module.Module

The class defining the light CNN with 29 layers.

This class implements the CNN described in: “A Light CNN for Deep Face Representation with Noisy Labels”, Wu, Xiang and He, Ran and Sun, Zhenan and Tan, Tieniu, IEEE Transactions on Information Forensics and Security, vol. 13, issue 11, 2018.

forward(x)
    Propagate data through the network.

    Parameters: x (torch.Tensor) – the data to forward through the network; image of size 1x128x128.
    Returns: out (torch.Tensor) – class probabilities; x (torch.Tensor) – output of the penultimate layer (i.e. the embedding).
class bob.learn.pytorch.architectures.LightCNN29v2(block=<class 'bob.learn.pytorch.architectures.utils.resblock'>, layers=[1, 2, 3, 4], num_classes=79077)

Bases: torch.nn.modules.module.Module

The class defining the light CNN with 29 layers (version 2).

This class implements the CNN described in: “A Light CNN for Deep Face Representation with Noisy Labels”, Wu, Xiang and He, Ran and Sun, Zhenan and Tan, Tieniu, IEEE Transactions on Information Forensics and Security, vol. 13, issue 11, 2018.

forward(x)
    Propagate data through the network.

    Parameters: x (torch.Tensor) – the data to forward through the network; image of size 1x128x128.
    Returns: out (torch.Tensor) – class probabilities; x (torch.Tensor) – output of the penultimate layer (i.e. the embedding).
class bob.learn.pytorch.architectures.LightCNN9(num_classes=79077)

Bases: torch.nn.modules.module.Module

The class defining the light CNN with 9 layers.

This class implements the CNN described in: “A Light CNN for Deep Face Representation with Noisy Labels”, Wu, Xiang and He, Ran and Sun, Zhenan and Tan, Tieniu, IEEE Transactions on Information Forensics and Security, vol. 13, issue 11, 2018.

Attributes:
- features – the output of the convolutional / maxpool layers
- avgpool – the output of the average pooling layer (used as embedding)
- classifier – the output of the last linear layer (logits)

forward(x)
    Propagate data through the network.

    Parameters: x (torch.Tensor) – the data to forward through the network; image of size 1x128x128.
    Returns: out (torch.Tensor) – class probabilities; x (torch.Tensor) – output of the penultimate layer (i.e. the embedding).
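A minimal embedding-extraction sketch using the documented 1x128x128 gray-scale input; the same call pattern applies to LightCNN29 and LightCNN29v2.

    import torch
    from bob.learn.pytorch.architectures import LightCNN9

    net = LightCNN9()  # num_classes defaults to 79077
    net.eval()

    # Gray-scale input of size 1x128x128, as stated in the forward() docs.
    x = torch.randn(1, 1, 128, 128)
    with torch.no_grad():
        out, emb = net(x)
    # out: class probabilities, emb: penultimate-layer embedding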
class bob.learn.pytorch.architectures.MCCNN(block=<class 'bob.learn.pytorch.architectures.utils.resblock'>, layers=[1, 2, 3, 4], num_channels=4, verbosity_level=2, use_sigmoid=True)

Bases: torch.nn.modules.module.Module

The class defining the MCCNN.

This class implements the MCCNN for multi-channel PAD.

Attributes:
- module_dict – a dictionary containing module names and torch.nn.Module elements as key, value pairs
- layer_dict – PyTorch class containing the modules as a dictionary

forward(img)
    Propagate data through the network.

    Parameters: img (torch.Tensor) – the data to forward through the network; image of size num_channels x 128 x 128.
    Returns: output – the score.
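A minimal sketch with the default four channels; the input size follows the forward() docs.

    import torch
    from bob.learn.pytorch.architectures import MCCNN

    net = MCCNN(num_channels=4)
    net.eval()

    # Multi-channel input of size num_channels x 128 x 128.
    img = torch.randn(1, 4, 128, 128)
    with torch.no_grad():
        score = net(img)  # PAD score; sigmoid output with use_sigmoid=True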
class bob.learn.pytorch.architectures.MCCNNv2(block=<class 'bob.learn.pytorch.architectures.utils.resblock'>, layers=[1, 2, 3, 4], num_channels=4, adapted_layers='conv1-block1-group1-ffc', verbosity_level=2)

Bases: torch.nn.modules.module.Module

The class defining the MCCNNv2. The difference from MCCNN is that layers which are not adapted are shared, which avoids replicating them.

Attributes:
- module_dict – a dictionary containing module names and torch.nn.Module elements as key, value pairs
- layer_dict – PyTorch class containing the modules as a dictionary
- adapted_layers (str) – the layers to be adapted in training, separated by ‘-’. Example: ‘conv1-block1-group1-ffc’; ‘ffc’ denotes the final fully connected layers, which are adapted in all cases.

forward(img)
    Propagate data through the network.

    Parameters: img (torch.Tensor) – the data to forward through the network; image of size num_channels x 128 x 128.
    Returns: output – the score.
class bob.learn.pytorch.architectures.MCDeepPixBiS(pretrained=True, num_channels=4)

Bases: torch.nn.modules.module.Module

The class defining Multi-Channel Deep Pixelwise Binary Supervision for Face Presentation Attack Detection.

This extends the following paper to multi-channel / multi-spectral images with cross-modality pretraining. Reference: Anjith George and Sébastien Marcel. “Deep Pixel-wise Binary Supervision for Face Presentation Attack Detection.” In 2019 International Conference on Biometrics (ICB). IEEE, 2019.

The initialization uses the cross-modality pre-training idea from the following paper: Wang L, Xiong Y, Wang Z, Qiao Y, Lin D, Tang X, Van Gool L. “Temporal Segment Networks: Towards Good Practices for Deep Action Recognition.” In: European Conference on Computer Vision, 2016 (pp. 20-36). Springer, Cham.

Attributes:
- pretrained (bool) – if set to True, uses the pretrained DenseNet model as the base; if set to False, the network is trained from scratch. Default: True.

forward(x)
    Propagate data through the network.

    Parameters: x (torch.Tensor) – the data to forward through the network; expects multi-channel images of size num_channels x 224 x 224.
    Returns: dec (torch.Tensor) – binary map of size 1x14x14; op (torch.Tensor) – final binary score.
bob.learn.pytorch.architectures.weights_init(m)

Initialize the weights in the different layers of the network.

Parameters: m (torch.nn.Conv2d) – the layer to initialize.
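weights_init is meant to be applied module-by-module; a typical way to do that is torch.nn.Module.apply, which visits every submodule recursively.

    from bob.learn.pytorch.architectures import DCGAN_generator, weights_init

    netG = DCGAN_generator(ngpu=1)

    # apply() calls weights_init once per submodule, so the layers it
    # handles (e.g. convolutions) get re-initialized in place.
    netG.apply(weights_init)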
Datasets
class bob.learn.pytorch.datasets.CasiaDataset(root_dir, transform=None, start_index=0)

Bases: torch.utils.data.dataset.Dataset

Class representing the CASIA-WebFace dataset.

Note that in this class, two labels are provided with each image: identity and pose. Pose labels have been automatically inferred using the ROC face recognition SDK from RankOne. There are 13 pose labels, corresponding to clusters of 15 degrees, ranging from -90 degrees (left profile) to 90 degrees (right profile).

Attributes:
- transform (torchvision.transforms) – the transform(s) to apply to the face images
class bob.learn.pytorch.datasets.CasiaWebFaceDataset(root_dir, transform=None, start_index=0)

Bases: torch.utils.data.dataset.Dataset

Class representing the CASIA-WebFace dataset. Note that here the only label is the identity.

Attributes:
- transform (torchvision.transforms) – the transform(s) to apply to the face images
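Both dataset classes plug directly into a torch DataLoader; '/path/to/casia-webface' below is a placeholder for your local copy of the data.

    from torch.utils.data import DataLoader
    from torchvision import transforms
    from bob.learn.pytorch.datasets import CasiaWebFaceDataset

    dataset = CasiaWebFaceDataset(
        root_dir='/path/to/casia-webface',  # placeholder path
        transform=transforms.ToTensor())

    loader = DataLoader(dataset, batch_size=64, shuffle=True)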
class bob.learn.pytorch.datasets.ChannelSelect(selected_channels=[0, 1, 2, 3])

Bases: object

Subselects or re-orders channels in a multi-channel image. Expects a numpy.ndarray of size HxWxnum_channels as input and returns an image of size HxWxlen(selected_channels), where the last dimension is subselected using the indexes in the list selected_channels.

Attributes:
- img (numpy.ndarray) – a multi-channel image, HxWxnum_channels
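A small sketch; applying the transform by calling the instance mirrors torchvision's transform convention and is an assumption here.

    import numpy
    from bob.learn.pytorch.datasets import ChannelSelect

    # A 4-channel image in HxWxnum_channels layout, as documented.
    img = numpy.random.rand(128, 128, 4)

    # Keep only channels 0 and 2, in that order.
    select = ChannelSelect(selected_channels=[0, 2])
    out = select(img)  # shape: (128, 128, 2)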
class bob.learn.pytorch.datasets.ConcatDataset(datasets)

Bases: torch.utils.data.dataset.Dataset

Class to concatenate two or more datasets for DR-GAN training.

Parameters:
- datasets (list) – the list of datasets (as torch.utils.data.Dataset)
class bob.learn.pytorch.datasets.DataFolder(data_folder, transform=None, extension='.hdf5', bob_hldi_instance=None, hldi_type='pad', groups=['train', 'dev', 'eval'], protocol='grandtest', purposes=['real', 'attack'], allow_missing_files=True, **kwargs)

Bases: torch.utils.data.dataset.Dataset

A generic data loader compatible with Bob High Level Database Interfaces (HLDI). Only HLDIs of bob.pad.face are currently supported.

The basic functionality is composed of two steps: load the data from an hdf5 file, and transform it using a user-defined transformation function. Two types of user-defined transformations are supported:

1. An instance of the Compose transformation class from the torchvision package.
2. A custom transformation function, which takes a numpy.ndarray as input and returns a transformed Tensor. The dimensionality of the output tensor must match the format expected by the network to be trained.

Note: if no special transformation is needed, the transform must at least convert an input numpy array to a Tensor.

Attributes:
- data_folder (str) – a directory containing the training data. Note that the training data must be stored as FrameContainers written to hdf5 files; other formats are currently not supported.
- transform – a function that takes an input numpy.ndarray sample/image and returns a transformed version as a Tensor. Default: None.
- extension (str) – extension of the data files. Default: “.hdf5”. Note: this is the only extension supported at the moment.
- bob_hldi_instance – an instance of the HLDI interface. Only HLDIs of bob.pad.face are currently supported.
- hldi_type (str) – string defining the type of the HLDI. Default: “pad”. Note: this is the only option currently supported.
- groups (list) – the groups for which the clients should be returned. Usually, groups are one or more elements of [‘train’, ‘dev’, ‘eval’]. Default: [‘train’, ‘dev’, ‘eval’].
- purposes (list) – the purposes for which File objects should be retrieved. Usually either ‘real’ or ‘attack’. Default: [‘real’, ‘attack’].
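A construction sketch; the bob.pad.face database object is left as a placeholder, since its instantiation depends on the database you use.

    from torchvision import transforms
    from bob.learn.pytorch.datasets import DataFolder

    database = ...  # placeholder: any bob.pad.face HLDI instance

    dataset = DataFolder(
        data_folder='/path/to/preprocessed',  # FrameContainers in .hdf5 files
        transform=transforms.ToTensor(),      # must at least produce a Tensor
        bob_hldi_instance=database,
        hldi_type='pad',
        groups=['train'],
        protocol='grandtest',
        purposes=['real', 'attack'])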
class bob.learn.pytorch.datasets.DataFolderGeneric(data_folder, transform=None, extension='.hdf5', bob_hldi_instance=None, hldi_type='pad', groups=['train', 'dev', 'eval'], protocol='grandtest', purposes=['real', 'attack'], allow_missing_files=True, custom_func=None, **kwargs)

Bases: torch.utils.data.dataset.Dataset

A generic data loader compatible with Bob High Level Database Interfaces (HLDI). Only HLDIs of bob.pad.face are currently supported.

The basic functionality is composed of two steps: load the data from an hdf5 file, and transform it using a user-defined transformation function. Two types of user-defined transformations are supported:

1. An instance of the Compose transformation class from the torchvision package.
2. A custom transformation function, which takes a numpy.ndarray as input and returns a transformed Tensor. The dimensionality of the output tensor must match the format expected by the network to be trained.

Note: if no special transformation is needed, the transform must at least convert an input numpy array to a Tensor.

Attributes:
- data_folder (str) – a directory containing the training data. Note that the training data must be stored as FrameContainers written to hdf5 files; other formats are currently not supported.
- transform – a function that takes an input numpy.ndarray sample/image and returns a transformed version as a Tensor. Default: None.
- extension (str) – extension of the data files. Default: “.hdf5”. Note: this is the only extension supported at the moment.
- bob_hldi_instance – an instance of the HLDI interface. Only HLDIs of bob.pad.face are currently supported.
- hldi_type (str) – string defining the type of the HLDI. Default: “pad”. Note: this is the only option currently supported.
- groups (list) – the groups for which the clients should be returned. Usually, groups are one or more elements of [‘train’, ‘dev’, ‘eval’]. Default: [‘train’, ‘dev’, ‘eval’].
- purposes (list) – the purposes for which File objects should be retrieved. Usually either ‘real’ or ‘attack’. Default: [‘real’, ‘attack’].
class bob.learn.pytorch.datasets.FaceCropAlign(face_size, rgb_output_flag=False, use_face_alignment=True, alignment_type='lightcnn', face_detection_method='mtcnn')

Bases: object

Wrapper for the FaceCropAlign preprocessor of bob.pad.face.
class bob.learn.pytorch.datasets.FaceCropper(cropped_height, cropped_width, color_channel='rgb')

Bases: object

Class to crop a face, based on the eye positions.
class bob.learn.pytorch.datasets.RandomHorizontalFlipImage(p=0.5)

Bases: object

Flips the image horizontally with probability p; works on numpy arrays.
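These preprocessing classes are object-style callables on numpy arrays, so chaining them with torchvision's Compose is a natural pattern; treating them as directly composable is an assumption in the sketch below.

    from torchvision import transforms
    from bob.learn.pytorch.datasets import (
        FaceCropper,
        RandomHorizontalFlipImage,
    )

    preprocess = transforms.Compose([
        FaceCropper(cropped_height=128, cropped_width=128,
                    color_channel='rgb'),
        RandomHorizontalFlipImage(p=0.5),
        transforms.ToTensor(),  # convert the numpy result to a Tensor
    ])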
Scripts

Trainers
class bob.learn.pytorch.trainers.CNNTrainer(network, batch_size=64, use_gpu=False, verbosity_level=2, num_classes=2)

Bases: object

Class to train a CNN.

Attributes:
- network – the network to train

load_and_initialize_model(model_filename)
    Loads and initializes a model.

    Parameters: model_filename (str) – the filename of the model to load.

train(dataloader, n_epochs=20, learning_rate=0.01, output_dir='out', model=None)
    Performs the training.

    Parameters:
    - dataloader (torch.utils.data.DataLoader) – the dataloader for your data
    - n_epochs (int) – the number of epochs you would like to train for
    - learning_rate (float) – the learning rate for the SGD optimizer
    - output_dir (str) – the directory where you would like to save models
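An end-to-end training sketch combining the pieces above; the identity count (CASIA-WebFace has 10575 identities) and the placeholder path are assumptions.

    from torch.utils.data import DataLoader
    from torchvision import transforms
    from bob.learn.pytorch.architectures import CASIANet
    from bob.learn.pytorch.datasets import CasiaWebFaceDataset
    from bob.learn.pytorch.trainers import CNNTrainer

    network = CASIANet(num_cls=10575)  # one class per identity
    trainer = CNNTrainer(network, batch_size=64, use_gpu=False,
                         num_classes=10575)

    dataset = CasiaWebFaceDataset('/path/to/casia-webface',
                                  transform=transforms.ToTensor())
    dataloader = DataLoader(dataset, batch_size=64, shuffle=True)

    # Trains with SGD and saves checkpoints under output_dir.
    trainer.train(dataloader, n_epochs=20, learning_rate=0.01,
                  output_dir='out')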
class bob.learn.pytorch.trainers.ConditionalGANTrainer(netG, netD, image_size, batch_size=64, noise_dim=100, conditional_dim=13, use_gpu=False, verbosity_level=2)

Bases: object

Class to train a Conditional GAN.

Attributes:
- generator – the generator network
- discriminator – the discriminator network
- fixed_noise – the fixed input noise to the generator
- fixed_one_hot – the set of fixed one-hot encoded conditioning variables
- criterion – the binary cross-entropy loss

train(dataloader, n_epochs=10, learning_rate=0.0002, beta1=0.5, output_dir='out')
    Trains the Conditional GAN.

    Parameters:
    - dataloader (torch.utils.data.DataLoader) – the dataloader for your data
    - n_epochs (int) – the number of epochs you would like to train for
    - learning_rate (float) – the learning rate for the Adam optimizer
    - beta1 (float) – the beta1 for the Adam optimizer
    - output_dir (str) – the directory where you would like to output images and models
class bob.learn.pytorch.trainers.DCGANTrainer(netG, netD, batch_size=64, noise_dim=100, use_gpu=False, verbosity_level=2)

Bases: object

Class to train a DCGAN.

Attributes:
- netG – the generator network
- netD – the discriminator network
- input – the input image
- noise – the input noise to the generator
- fixed_noise – the fixed input noise to the generator, used for generating images to save
- label – the label for real/fake images
- criterion – the binary cross-entropy loss

train(dataloader, n_epochs=10, learning_rate=0.0002, beta1=0.5, output_dir='out')
    Trains the DCGAN.

    Parameters:
    - dataloader (torch.utils.data.DataLoader) – the dataloader for your data
    - n_epochs (int) – the number of epochs you would like to train for
    - learning_rate (float) – the learning rate for the Adam optimizer
    - beta1 (float) – the beta1 for the Adam optimizer
    - output_dir (str) – the directory where you would like to output images and models
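A training-setup sketch; the dataloader construction is omitted, and initializing both networks with weights_init follows the upstream DCGAN recipe rather than anything stated here.

    from bob.learn.pytorch.architectures import (
        DCGAN_discriminator,
        DCGAN_generator,
        weights_init,
    )
    from bob.learn.pytorch.trainers import DCGANTrainer

    netG = DCGAN_generator(ngpu=1)
    netD = DCGAN_discriminator(ngpu=1)
    netG.apply(weights_init)
    netD.apply(weights_init)

    trainer = DCGANTrainer(netG, netD, batch_size=64, noise_dim=100,
                           use_gpu=False)

    # `dataloader` should yield batches of face images (construction
    # omitted here):
    # trainer.train(dataloader, n_epochs=10, learning_rate=0.0002,
    #               beta1=0.5, output_dir='out')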
class bob.learn.pytorch.trainers.FASNetTrainer(network, batch_size=64, use_gpu=False, verbosity_level=2, tf_logdir='tf_logs', do_crossvalidation=False)

Bases: object

Class to train the FASNet.

Attributes:
- network – the network to train

load_model(model_filename)
    Loads an existing model.

    Parameters: model_filename (str) – the filename of the model to load.
    Returns: start_epoch (int) – the epoch to start with; start_iteration (int) – the iteration to start with; losses (list(float)) – the list of losses from previous training.

train(dataloader, n_epochs=25, learning_rate=0.0001, output_dir='out', model=None)
    Performs the training.

    Parameters:
    - dataloader (torch.utils.data.DataLoader) – the dataloader for your data
    - n_epochs (int) – the number of epochs you would like to train for
    - learning_rate (float) – the learning rate for the Adam optimizer
    - output_dir (str) – the directory where you would like to save models
    - model (str) – the path to a pretrained model file to start training from; this is the PAD model, not the LightCNN model
class bob.learn.pytorch.trainers.GenericTrainer(network, optimizer, compute_loss, learning_rate=0.0001, device='cpu', verbosity_level=2, tf_logdir='tf_logs', do_crossvalidation=False, save_interval=5)

Bases: object

Class to train a generic NN; all the parameters are provided in configs.

Attributes:
- network – the network to train
- optimizer – the optimizer object to be used, initialized in the config file

load_model(model_filename)
    Loads an existing model.

    Parameters: model_filename (str) – the filename of the model to load.
    Returns: start_epoch (int) – the epoch to start with; start_iteration (int) – the iteration to start with; losses (list(float)) – the list of losses from previous training.

train(dataloader, n_epochs=25, output_dir='out', model=None)
    Performs the training.

    Parameters:
    - dataloader (torch.utils.data.DataLoader) – the dataloader for your data
    - n_epochs (int) – the number of epochs you would like to train for
    - output_dir (str) – the directory where you would like to save models
    - model (str) – the path to a pretrained model file to start training from
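A configuration sketch; the signature assumed for compute_loss below is hypothetical (the real one is fixed by the package's config files), and the loss itself is only an illustration.

    import torch
    from bob.learn.pytorch.architectures import DeepPixBiS
    from bob.learn.pytorch.trainers import GenericTrainer

    network = DeepPixBiS(pretrained=True)
    optimizer = torch.optim.Adam(network.parameters(), lr=0.0001)

    # Hypothetical compute_loss signature -- check the shipped config
    # files for the one the trainer actually expects.
    def compute_loss(network, img, labels, device):
        dec, op = network(img.to(device))
        return torch.nn.functional.binary_cross_entropy(
            op.squeeze(1), labels.to(device).float())

    trainer = GenericTrainer(network, optimizer, compute_loss,
                             learning_rate=0.0001, device='cpu')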
class bob.learn.pytorch.trainers.MCCNNTrainer(network, batch_size=64, use_gpu=False, adapted_layers='conv1-block1-group1-ffc', adapt_reference_channel=False, verbosity_level=2, tf_logdir='tf_logs', do_crossvalidation=False)

Bases: object

Class to train the MCCNN.

Attributes:
- network – the network to train

load_model(model_filename)
    Loads an existing model.

    Parameters: model_filename (str) – the filename of the model to load.
    Returns: start_epoch (int) – the epoch to start with; start_iteration (int) – the iteration to start with; losses (list(float)) – the list of losses from previous training.

train(dataloader, n_epochs=25, learning_rate=0.0001, output_dir='out', model=None)
    Performs the training.

    Parameters:
    - dataloader (torch.utils.data.DataLoader) – the dataloader for your data
    - n_epochs (int) – the number of epochs you would like to train for
    - learning_rate (float) – the learning rate for the Adam optimizer
    - output_dir (str) – the directory where you would like to save models
    - model (str) – the path to a pretrained model file to start training from; this is the PAD model, not the LightCNN model