Coverage for src/bob/bio/face/tensorflow/preprocessing.py: 0%
39 statements
#!/usr/bin/env python
# coding: utf-8

"""
Tensor pre-processing for some face recognition CNNs
"""

import logging

from functools import partial

import tensorflow as tf

from tensorflow.keras import layers

logger = logging.getLogger(__name__)

# STANDARD FEATURES FROM OUR TF-RECORDS
FEATURES = {
    "data": tf.io.FixedLenFeature([], tf.string),
    "label": tf.io.FixedLenFeature([], tf.int64),
    "key": tf.io.FixedLenFeature([], tf.string),
}


def decode_tfrecords(x, data_shape, data_type=tf.uint8):
    """Parses one serialized example and reshapes its raw image bytes to ``data_shape``."""
    features = tf.io.parse_single_example(x, FEATURES)
    image = tf.io.decode_raw(features["data"], data_type)
    image = tf.reshape(image, data_shape)
    features["data"] = image
    return features
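

# Example (not part of the original module): a minimal sketch of how a single
# serialized example matching FEATURES above could be written, assuming images
# are stored as raw uint8 bytes. The helper name and shapes are illustrative
# only, not part of this package.
def _write_example_sketch(image, label, key):
    # image: uint8 numpy array, label: int, key: str (all hypothetical inputs)
    feature = {
        "data": tf.train.Feature(
            bytes_list=tf.train.BytesList(value=[image.tobytes()])
        ),
        "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[int(label)])),
        "key": tf.train.Feature(
            bytes_list=tf.train.BytesList(value=[key.encode("utf-8")])
        ),
    }
    example = tf.train.Example(features=tf.train.Features(feature=feature))
    return example.SerializeToString()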


def get_preprocessor(output_shape):
    """Builds a ``tf.keras.Sequential`` of augmentation and rescaling layers.

    The pipeline applies a small random rotation, a random crop to
    ``output_shape``, a random horizontal flip, and rescales pixel values
    from [0, 255] to [0, 1].
    """
    preprocessor = tf.keras.Sequential(
        [
            # rotate before cropping
            # 5 random degree rotation
            layers.experimental.preprocessing.RandomRotation(5 / 360),
            layers.experimental.preprocessing.RandomCrop(
                height=output_shape[0], width=output_shape[1]
            ),
            layers.experimental.preprocessing.RandomFlip("horizontal"),
            # FIXED_STANDARDIZATION from https://github.com/davidsandberg/facenet
            # [-0.99609375, 0.99609375]
            # layers.experimental.preprocessing.Rescaling(
            #     scale=1 / 128, offset=-127.5 / 128
            # ),
            layers.experimental.preprocessing.Rescaling(scale=1 / 255, offset=0),
        ]
    )
    return preprocessor
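

# Usage sketch (not part of the original module): running the preprocessor on a
# dummy batch; the 126x126 input and 112x112 crop sizes are assumptions for
# illustration only.
def _preprocessor_demo_sketch():
    preprocessor = get_preprocessor((112, 112))
    # fake float batch standing in for the uint8 images read from the TF-Records
    batch = tf.random.uniform((8, 126, 126, 3), maxval=255, dtype=tf.float32)
    # training=True enables the random rotation/crop/flip augmentations;
    # result has shape (8, 112, 112, 3) with values rescaled to [0, 1]
    return preprocessor(batch, training=True)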


def preprocess(preprocessor, features, augment=False):
    """Applies the preprocessor to one decoded batch and returns ``(image, label)``."""
    image = features["data"]
    label = features["label"]
    image = preprocessor(image, training=augment)
    return image, label


def prepare_dataset(
    tf_record_paths,
    batch_size,
    epochs,
    data_shape,
    output_shape,
    shuffle=False,
    augment=False,
    shuffle_buffer=int(2e4),
    ctx=None,
):
    """
    Create batches from a list of TF-Records

    Parameters
    ----------

    tf_record_paths: list
        List of paths of the TF-Records

    batch_size: int
        Global batch size (divided per replica when ``ctx`` is given)

    epochs: int
        Number of times the dataset is repeated

    data_shape: tuple
        Shape of the raw images stored in the TF-Records

    output_shape: tuple
        Target (height, width) of the random crop

    shuffle: bool
        If True, shuffle both the files and the samples

    augment: bool
        If True, run the preprocessor in training mode (random augmentations)

    shuffle_buffer: int
        Size of the sample shuffle buffer

    ctx: ``tf.distribute.InputContext``
        Distribution context used to shard the files across workers
    """
    ds = tf.data.Dataset.list_files(
        tf_record_paths, shuffle=shuffle if ctx is None else False
    )

    # if we're in a distributed setting, shard here and shuffle after sharding
    if ctx is not None:
        batch_size = ctx.get_per_replica_batch_size(batch_size)
        ds = ds.shard(ctx.num_replicas_in_sync, ctx.input_pipeline_id)
        if shuffle:
            ds = ds.shuffle(ds.cardinality())

    ds = tf.data.TFRecordDataset(ds, num_parallel_reads=tf.data.AUTOTUNE)
    if shuffle:
        # ignore order and read files as soon as they come in
        ignore_order = tf.data.Options()
        ignore_order.experimental_deterministic = False
        ds = ds.with_options(ignore_order)

    ds = ds.map(
        partial(decode_tfrecords, data_shape=data_shape),
        num_parallel_calls=tf.data.AUTOTUNE,
    )

    if shuffle:
        ds = ds.shuffle(shuffle_buffer)
    ds = ds.repeat(epochs)

    preprocessor = get_preprocessor(output_shape)
    ds = ds.batch(batch_size).map(
        partial(preprocess, preprocessor, augment=augment),
        num_parallel_calls=tf.data.AUTOTUNE,
    )

    # Use buffered prefetching on all datasets
    return ds.prefetch(buffer_size=tf.data.AUTOTUNE)
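

# Example (not part of the original module): a hedged sketch of wiring
# ``prepare_dataset`` into a Keras training loop. The glob pattern, shapes,
# batch size, and the model are placeholders, not values defined by this
# package.
def _training_demo_sketch(model, tf_record_glob="/path/to/records/*.tfrecord"):
    ds = prepare_dataset(
        tf_record_glob,
        batch_size=32,
        epochs=1,
        data_shape=(126, 126, 3),
        output_shape=(112, 112),
        shuffle=True,
        augment=True,
    )
    # ``model`` is assumed to be a compiled tf.keras.Model taking (image, label) batches
    return model.fit(ds)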