Coverage for src/deepdraw/configs/datasets/__init__.py: 100%
27 statements
« prev ^ index » next coverage.py v7.4.2, created at 2024-03-29 22:17 +0100
« prev ^ index » next coverage.py v7.4.2, created at 2024-03-29 22:17 +0100
1# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
2#
3# SPDX-License-Identifier: GPL-3.0-or-later
5"""Standard configurations for dataset setup."""
8from ...data.transforms import ColorJitter as _jitter
9from ...data.transforms import RandomHorizontalFlip as _hflip
10from ...data.transforms import RandomRotation as _rotation
11from ...data.transforms import RandomVerticalFlip as _vflip
# Default augmentation pipelines shared by the dataset configurations in this
# package.  Each entry is a list so they can be concatenated directly with
# other transform lists (see ``make_subset``/``augment_subset`` below).
RANDOM_ROTATION = [_rotation()]
"""Shared data augmentation based on random rotation only."""

RANDOM_FLIP_JITTER = [_hflip(), _vflip(), _jitter()]
"""Shared data augmentation transforms without random rotation."""
def make_subset(samples, transforms, prefixes=None, suffixes=None):
    """Creates a new data set, applying transforms.

    .. note::

       This is a convenience function for our own dataset definitions inside
       this module, guaranteeing homogeneity between dataset definitions
       provided in this package.  It assumes certain strategies for data
       augmentation that may not be translatable to other applications.


    Parameters
    ----------

    samples : list
        List of delayed samples

    transforms : list
        A list of transforms that needs to be applied to all samples in the set

    prefixes : :py:class:`list`, Optional
        A list of data augmentation operations that needs to be applied
        **before** the transforms above

    suffixes : :py:class:`list`, Optional
        A list of data augmentation operations that needs to be applied
        **after** the transforms above


    Returns
    -------

    subset : :py:class:`deepdraw.data.utils.SampleListDataset`
        A pre-formatted dataset that can be fed to one of our engines
    """
    # local import to avoid a circular dependency at module load time
    from ...data.utils import SampleListDataset as wrapper

    # avoid mutable default arguments; ``None`` means "no extra transforms"
    prefixes = [] if prefixes is None else prefixes
    suffixes = [] if suffixes is None else suffixes

    return wrapper(samples, prefixes + transforms + suffixes)
def augment_subset(s, rotation_before=False):
    """Creates a new subset set, **with data augmentation**

    Typically, the transforms are chained to a default set of data
    augmentation operations (random rotation, horizontal and vertical flips,
    and color jitter), but a flag allows prefixing the rotation specially
    (useful for some COVD training sets).

    .. note::

       This is a convenience function for our own dataset definitions inside
       this module, guaranteeing homogeneity between dataset definitions
       provided in this package.  It assumes certain strategies for data
       augmentation that may not be translatable to other applications.


    Parameters
    ----------

    s : deepdraw.data.utils.SampleListDataset
        A dataset that will be augmented

    rotation_before : :py:class:`bool`, Optional
        An optional flag allowing you to do a rotation augmentation transform
        **before** the sequence of transforms for this dataset, that will be
        augmented.


    Returns
    -------

    subset : :py:class:`deepdraw.data.utils.SampleListDataset`
        A pre-formatted dataset that can be fed to one of our engines
    """
    # common case: rotation is appended, together with flips and jitter
    if not rotation_before:
        return s.copy(s.transforms + RANDOM_ROTATION + RANDOM_FLIP_JITTER)

    # special case (some COVD training sets): rotation goes first
    return s.copy(RANDOM_ROTATION + s.transforms + RANDOM_FLIP_JITTER)
def make_dataset(subsets, transforms):
    """Creates a new configuration dataset from dictionary and transforms.

    This function takes as input a dictionary as those that can be returned by
    :py:meth:`deepdraw.data.dataset.JSONDataset.subsets`, or
    :py:meth:`deepdraw.data.dataset.CSVDataset.subsets`, mapping protocol
    names (such as ``train``, ``dev`` and ``test``) to
    :py:class:`deepdraw.data.sample.DelayedSample` lists, and a set of
    transforms, and returns a dictionary applying
    :py:class:`deepdraw.data.utils.SampleListDataset` to these
    lists, and our standard data augmentation if a ``train`` set exists.

    For example, if ``subsets`` is composed of two sets named ``train`` and
    ``test``, this function will yield a dictionary with the following entries:

    * ``__train__``: Wraps the ``train`` subset, includes data augmentation
      (note: datasets with names starting with ``_`` (underscore) are excluded
      from prediction and evaluation by default, as they contain data
      augmentation transformations.)
    * ``train``: Wraps the ``train`` subset, **without** data augmentation
    * ``test``: Wraps the ``test`` subset, **without** data augmentation

    .. note::

       This is a convenience function for our own dataset definitions inside
       this module, guaranteeing homogeneity between dataset definitions
       provided in this package.  It assumes certain strategies for data
       augmentation that may not be translatable to other applications.


    Parameters
    ----------

    subsets : dict
        A dictionary that contains the delayed sample lists for a number of
        named lists.  If one of the keys is ``train``, our standard dataset
        augmentation transforms are appended to the definition of that subset.
        All other subsets remain un-augmented.  If one of the keys is
        ``validation``, then this dataset will be also copied to the
        ``__valid__`` hidden dataset and will be used for validation during
        training.  Otherwise, if no ``validation`` subset is available, we set
        ``__valid__`` to be the same as the unaugmented ``train`` subset, if
        one is available.

    transforms : list
        A list of transforms that needs to be applied to all samples in the set


    Returns
    -------

    dataset : dict
        A pre-formatted dataset that can be fed to one of our engines.  It maps
        string names to
        :py:class:`deepdraw.data.utils.SampleListDataset`'s.
    """
    retval = {}

    for key, samples in subsets.items():
        retval[key] = make_subset(samples, transforms=transforms)
        if key == "train":
            # hidden, augmented copy of the training set, used for training
            retval["__train__"] = make_subset(
                samples,
                transforms=transforms,
                suffixes=(RANDOM_ROTATION + RANDOM_FLIP_JITTER),
            )
        elif key == "validation":
            # also use it for validation during training
            retval["__valid__"] = retval[key]

    if (
        ("__train__" in retval)
        and ("train" in retval)
        and ("__valid__" not in retval)
    ):
        # if the dataset does not have a validation set, we use the unaugmented
        # training set as validation set
        retval["__valid__"] = retval["train"]

    return retval