1#!/usr/bin/env python
2# coding=utf-8
3
4from torch.utils.data.dataset import ConcatDataset
5
6def _maker(protocol):
7
8 if protocol == "default":
9 from ..montgomery import default as mc
10 from ..shenzhen import default as ch
11 elif protocol == "rgb":
12 from ..montgomery import rgb as mc
13 from ..shenzhen import rgb as ch
14 elif protocol == "fold_0":
15 from ..montgomery import fold_0 as mc
16 from ..shenzhen import fold_0 as ch
17 elif protocol == "fold_1":
18 from ..montgomery import fold_1 as mc
19 from ..shenzhen import fold_1 as ch
20 elif protocol == "fold_2":
21 from ..montgomery import fold_2 as mc
22 from ..shenzhen import fold_2 as ch
23 elif protocol == "fold_3":
24 from ..montgomery import fold_3 as mc
25 from ..shenzhen import fold_3 as ch
26 elif protocol == "fold_4":
27 from ..montgomery import fold_4 as mc
28 from ..shenzhen import fold_4 as ch
29 elif protocol == "fold_5":
30 from ..montgomery import fold_5 as mc
31 from ..shenzhen import fold_5 as ch
32 elif protocol == "fold_6":
33 from ..montgomery import fold_6 as mc
34 from ..shenzhen import fold_6 as ch
35 elif protocol == "fold_7":
36 from ..montgomery import fold_7 as mc
37 from ..shenzhen import fold_7 as ch
38 elif protocol == "fold_8":
39 from ..montgomery import fold_8 as mc
40 from ..shenzhen import fold_8 as ch
41 elif protocol == "fold_9":
42 from ..montgomery import fold_9 as mc
43 from ..shenzhen import fold_9 as ch
44 elif protocol == "fold_0_rgb":
45 from ..montgomery import fold_0_rgb as mc
46 from ..shenzhen import fold_0_rgb as ch
47 elif protocol == "fold_1_rgb":
48 from ..montgomery import fold_1_rgb as mc
49 from ..shenzhen import fold_1_rgb as ch
50 elif protocol == "fold_2_rgb":
51 from ..montgomery import fold_2_rgb as mc
52 from ..shenzhen import fold_2_rgb as ch
53 elif protocol == "fold_3_rgb":
54 from ..montgomery import fold_3_rgb as mc
55 from ..shenzhen import fold_3_rgb as ch
56 elif protocol == "fold_4_rgb":
57 from ..montgomery import fold_4_rgb as mc
58 from ..shenzhen import fold_4_rgb as ch
59 elif protocol == "fold_5_rgb":
60 from ..montgomery import fold_5_rgb as mc
61 from ..shenzhen import fold_5_rgb as ch
62 elif protocol == "fold_6_rgb":
63 from ..montgomery import fold_6_rgb as mc
64 from ..shenzhen import fold_6_rgb as ch
65 elif protocol == "fold_7_rgb":
66 from ..montgomery import fold_7_rgb as mc
67 from ..shenzhen import fold_7_rgb as ch
68 elif protocol == "fold_8_rgb":
69 from ..montgomery import fold_8_rgb as mc
70 from ..shenzhen import fold_8_rgb as ch
71 elif protocol == "fold_9_rgb":
72 from ..montgomery import fold_9_rgb as mc
73 from ..shenzhen import fold_9_rgb as ch
74
75 mc = mc.dataset
76 ch = ch.dataset
77
78 dataset = {}
79 dataset['__train__'] = ConcatDataset([mc["__train__"], ch["__train__"]])
80 dataset['train'] = ConcatDataset([mc["train"], ch["train"]])
81 dataset['__valid__'] = ConcatDataset([mc["__valid__"], ch["__valid__"]])
82 dataset['validation'] = ConcatDataset([mc["validation"], ch["validation"]])
83 dataset['test'] = ConcatDataset([mc["test"], ch["test"]])
84
85 return dataset