Source code for bob.ip.common.utils.checkpointer

#!/usr/bin/env python
# -*- coding: utf-8 -*-

import logging
import os

import torch

logger = logging.getLogger(__name__)


class Checkpointer:
    """A simple PyTorch checkpointer

    Parameters
    ----------

    model : torch.nn.Module
        Network model, possibly loaded from a checkpoint file

    optimizer : :py:class:`torch.optim.Optimizer`, Optional
        Optimizer

    scheduler : :py:mod:`torch.optim.lr_scheduler`, Optional
        Learning rate scheduler

    path : :py:class:`str`, Optional
        Directory where to save checkpoints.

    """

    def __init__(self, model, optimizer=None, scheduler=None, path="."):
        self.model = model
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.path = os.path.realpath(path)
    def save(self, name, **kwargs):
        # gathers the state of every managed object; extra keyword
        # arguments (e.g. the current epoch) are stored alongside it
        data = {}
        data["model"] = self.model.state_dict()
        if self.optimizer is not None:
            data["optimizer"] = self.optimizer.state_dict()
        if self.scheduler is not None:
            data["scheduler"] = self.scheduler.state_dict()
        data.update(kwargs)

        name = f"{name}.pth"
        outf = os.path.join(self.path, name)
        logger.info(f"Saving checkpoint to {outf}")
        torch.save(data, outf)

        # records the name of this checkpoint so ``load()`` can find it
        with open(self._last_checkpoint_filename, "w") as f:
            f.write(name)
    def load(self, f=None):
        """Loads model, optimizer and scheduler from file

        Parameters
        ----------

        f : :py:class:`str`, Optional
            Name of a file (absolute or relative to ``self.path``), that
            contains the checkpoint data to load into the model, and
            optionally into the optimizer and the scheduler.  If not
            specified, loads data from the last checkpoint recorded under
            ``self.path`` (if any).

        Returns
        -------

        dict
            Whatever remains in the checkpoint after the model, optimizer
            and scheduler entries have been consumed (e.g. extra keyword
            arguments originally passed to ``save()``), or an empty
            dictionary if no checkpoint was found.

        """

        if f is None:
            f = self.last_checkpoint()

        if f is None:
            # no checkpoint could be found
            logger.warning("No checkpoint found (and none passed)")
            return {}

        # loads file data into memory
        logger.info(f"Loading checkpoint from {f}...")
        checkpoint = torch.load(f, map_location=torch.device("cpu"))

        # converts model entry to model parameters
        self.model.load_state_dict(checkpoint.pop("model"))

        # the checkpoint may have been saved without an optimizer or a
        # scheduler, so only restore the entries that are actually present
        if self.optimizer is not None and "optimizer" in checkpoint:
            self.optimizer.load_state_dict(checkpoint.pop("optimizer"))
        if self.scheduler is not None and "scheduler" in checkpoint:
            self.scheduler.load_state_dict(checkpoint.pop("scheduler"))

        return checkpoint
    @property
    def _last_checkpoint_filename(self):
        return os.path.join(self.path, "last_checkpoint")
    def has_checkpoint(self):
        return os.path.exists(self._last_checkpoint_filename)
    def last_checkpoint(self):
        if self.has_checkpoint():
            with open(self._last_checkpoint_filename, "r") as fobj:
                return os.path.join(self.path, fobj.read().strip())
        return None
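
# ---------------------------------------------------------------------------
# Minimal usage sketch (editorial illustration, not part of the original
# module): wires a toy model and optimizer to the checkpointer, saves one
# checkpoint, then resumes from it.  The model, optimizer, directory name
# and ``epoch`` key below are assumptions made for the example.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import torch.nn as nn
    import torch.optim as optim

    model = nn.Linear(10, 2)
    optimizer = optim.SGD(model.parameters(), lr=0.01)

    # the target directory must exist before saving (``save()`` does not
    # create it)
    path = "checkpoints"
    os.makedirs(path, exist_ok=True)

    checkpointer = Checkpointer(model, optimizer=optimizer, path=path)

    # saves model and optimizer state, stashing the epoch number alongside
    checkpointer.save("model_epoch_1", epoch=1)

    # resumes from the last recorded checkpoint; leftover entries (here,
    # ``epoch``) are returned so training can pick up where it stopped
    extras = checkpointer.load()
    print(f"resuming from epoch {extras.get('epoch', 0) + 1}")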