1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3
4import os
5import unittest
6from collections import OrderedDict
7from tempfile import TemporaryDirectory
8
9import torch
10
11from ..utils.checkpointer import Checkpointer
12
13
class TestCheckpointer(unittest.TestCase):
    """Save/load round-trip tests for ``Checkpointer``.

    Each test saves the weights of one model through a ``Checkpointer`` and
    loads them into a second, independently initialised model, then checks
    that the second model ends up with equal (but not shared) parameters.
    """

    def create_model(self):
        """Return a two-layer ``Sequential``: ``Linear(2, 3)`` then ``Linear(3, 1)``."""
        return torch.nn.Sequential(
            torch.nn.Linear(2, 3), torch.nn.Linear(3, 1)
        )

    def create_complex_model(self):
        """Return a nested module together with a hand-built state dict.

        The module registers its layers as ``block1.layer1``, ``layer2`` and
        ``res.layer2``.  The state dict holds random tensors under the keys
        ``layer1.*``, ``layer2.*`` and ``res.layer2.*`` -- i.e. the first
        layer's keys lack the ``block1.`` prefix the module itself uses.
        Nothing is loaded here; both objects are simply returned.

        Returns:
            A ``(module, state_dict)`` tuple.
        """
        m = torch.nn.Module()
        m.block1 = torch.nn.Module()
        m.block1.layer1 = torch.nn.Linear(2, 3)
        m.layer2 = torch.nn.Linear(3, 2)
        m.res = torch.nn.Module()
        m.res.layer2 = torch.nn.Linear(3, 2)

        state_dict = OrderedDict()
        state_dict["layer1.weight"] = torch.rand(3, 2)
        state_dict["layer1.bias"] = torch.rand(3)
        state_dict["layer2.weight"] = torch.rand(2, 3)
        state_dict["layer2.bias"] = torch.rand(2)
        state_dict["res.layer2.weight"] = torch.rand(2, 3)
        state_dict["res.layer2.bias"] = torch.rand(2)

        return m, state_dict

    def _assert_parameters_copied(self, trained_model, loaded_model):
        """Assert both models hold equal parameter values in distinct tensors."""
        trained_params = list(trained_model.parameters())
        loaded_params = list(loaded_model.parameters())
        # zip() stops at the shorter input, so a parameter-count mismatch
        # would otherwise go unnoticed.
        self.assertEqual(len(trained_params), len(loaded_params))
        for trained_p, loaded_p in zip(trained_params, loaded_params):
            # different tensor references
            self.assertIsNot(trained_p, loaded_p)
            # same content
            self.assertTrue(trained_p.equal(loaded_p))

    def test_from_last_checkpoint_model(self):
        """A second checkpointer on the same folder finds and loads the saved file."""
        trained_model = self.create_model()
        fresh_model = self.create_model()
        with TemporaryDirectory() as f:
            checkpointer = Checkpointer(trained_model, path=f)
            checkpointer.save("checkpoint_file")

            # in the same folder
            fresh_checkpointer = Checkpointer(fresh_model, path=f)
            self.assertTrue(fresh_checkpointer.has_checkpoint())
            self.assertEqual(
                fresh_checkpointer.last_checkpoint(),
                os.path.realpath(os.path.join(f, "checkpoint_file.pth")),
            )
            # No explicit file is given to load() here.
            fresh_checkpointer.load()

        self._assert_parameters_copied(trained_model, fresh_model)

    def test_from_name_file_model(self):
        """A checkpointer on an empty folder loads a checkpoint given by explicit path."""
        trained_model = self.create_model()
        fresh_model = self.create_model()
        with TemporaryDirectory() as f:
            checkpointer = Checkpointer(trained_model, path=f)
            checkpointer.save("checkpoint_file")

            # on different folders
            with TemporaryDirectory() as g:
                fresh_checkpointer = Checkpointer(fresh_model, path=g)
                self.assertFalse(fresh_checkpointer.has_checkpoint())
                self.assertIsNone(fresh_checkpointer.last_checkpoint())
                # Must run while the first temporary folder still exists.
                fresh_checkpointer.load(
                    os.path.join(f, "checkpoint_file.pth")
                )

        self._assert_parameters_copied(trained_model, fresh_model)