Coverage for src/bob/bio/face/tensorflow/preprocessing.py: 0%

39 statements  


#!/usr/bin/env python
# coding: utf-8

"""
Tensor pre-processing for some face recognition CNNs
"""

import logging

from functools import partial

import tensorflow as tf

from tensorflow.keras import layers

logger = logging.getLogger(__name__)

# STANDARD FEATURES FROM OUR TF-RECORDS
FEATURES = {
    "data": tf.io.FixedLenFeature([], tf.string),
    "label": tf.io.FixedLenFeature([], tf.int64),
    "key": tf.io.FixedLenFeature([], tf.string),
}


def decode_tfrecords(x, data_shape, data_type=tf.uint8):
    """Parse one serialized example and decode its raw image bytes."""
    features = tf.io.parse_single_example(x, FEATURES)
    image = tf.io.decode_raw(features["data"], data_type)
    image = tf.reshape(image, data_shape)
    features["data"] = image
    return features
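

# A minimal writing-side sketch, NOT part of the original module: how an
# example matching FEATURES above could be serialized into a TF-Record that
# decode_tfrecords() can parse. The helper name and arguments are illustrative
# assumptions.
def _make_example(image_array, label, key):
    # image_array: numpy array whose raw bytes match data_shape/data_type
    feature = {
        "data": tf.train.Feature(
            bytes_list=tf.train.BytesList(value=[image_array.tobytes()])
        ),
        "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[label])),
        "key": tf.train.Feature(
            bytes_list=tf.train.BytesList(value=[key.encode("utf-8")])
        ),
    }
    return tf.train.Example(features=tf.train.Features(feature=feature))
    # hypothetical usage:
    #   with tf.io.TFRecordWriter("out.tfrecord") as writer:
    #       writer.write(_make_example(img, 0, "subject-id").SerializeToString())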

def get_preprocessor(output_shape):
    """Build the augmentation and rescaling pipeline applied to image batches."""
    preprocessor = tf.keras.Sequential(
        [
            # rotate before cropping
            # random rotation of up to 5 degrees (the factor is a fraction of 2 pi)
            layers.experimental.preprocessing.RandomRotation(5 / 360),
            layers.experimental.preprocessing.RandomCrop(
                height=output_shape[0], width=output_shape[1]
            ),
            layers.experimental.preprocessing.RandomFlip("horizontal"),
            # FIXED_STANDARDIZATION from https://github.com/davidsandberg/facenet
            # [-0.99609375, 0.99609375]
            # layers.experimental.preprocessing.Rescaling(
            #     scale=1 / 128, offset=-127.5 / 128
            # ),
            layers.experimental.preprocessing.Rescaling(scale=1 / 255, offset=0),
        ]
    )
    return preprocessor
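

# A minimal usage sketch, NOT part of the original module: the Sequential
# returned by get_preprocessor() only applies its random layers when called
# with training=True, which is how preprocess() below toggles augmentation.
# The batch and crop shapes are illustrative assumptions.
def _demo_preprocessor():
    images = tf.zeros([8, 126, 126, 3])  # hypothetical batch of raw face crops
    preprocessor = get_preprocessor(output_shape=(112, 112))
    augmented = preprocessor(images, training=True)  # random rotate/crop/flip
    deterministic = preprocessor(images, training=False)  # random layers inactive
    return augmented, deterministic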

def preprocess(preprocessor, features, augment=False):
    """Apply the preprocessor to a decoded batch and return ``(image, label)``."""
    image = features["data"]
    label = features["label"]
    image = preprocessor(image, training=augment)
    return image, label


def prepare_dataset(
    tf_record_paths,
    batch_size,
    epochs,
    data_shape,
    output_shape,
    shuffle=False,
    augment=False,
    shuffle_buffer=int(2e4),
    ctx=None,
):
    """
    Create batches from a list of TF-Records

    Parameters
    ----------

    tf_record_paths: list
        List of paths of the TF-Records

    batch_size: int
        Size of each batch (converted to a per-replica size when ``ctx`` is given)

    epochs: int
        Number of times the dataset is repeated

    data_shape: tuple
        Shape of the raw images stored in the records

    output_shape: tuple
        Target (height, width) of the images after random cropping

    shuffle: bool
        If True, shuffle both the record files and the decoded examples

    augment: bool
        If True, apply the random augmentations from ``get_preprocessor``

    shuffle_buffer: int
        Size of the example-level shuffle buffer

    ctx: ``tf.distribute.InputContext``
        Input context used to shard the files across replicas
    """

    ds = tf.data.Dataset.list_files(
        tf_record_paths, shuffle=shuffle if ctx is None else False
    )

    # if we're in a distributed setting, shard here and shuffle after sharding
    if ctx is not None:
        batch_size = ctx.get_per_replica_batch_size(batch_size)
        ds = ds.shard(ctx.num_replicas_in_sync, ctx.input_pipeline_id)
        if shuffle:
            ds = ds.shuffle(ds.cardinality())

    ds = tf.data.TFRecordDataset(ds, num_parallel_reads=tf.data.AUTOTUNE)
    if shuffle:
        # ignore order and read files as soon as they come in
        ignore_order = tf.data.Options()
        ignore_order.experimental_deterministic = False
        ds = ds.with_options(ignore_order)

    ds = ds.map(
        partial(decode_tfrecords, data_shape=data_shape),
        num_parallel_calls=tf.data.AUTOTUNE,
    )

    if shuffle:
        ds = ds.shuffle(shuffle_buffer)
    ds = ds.repeat(epochs)

    preprocessor = get_preprocessor(output_shape)
    ds = ds.batch(batch_size).map(
        partial(preprocess, preprocessor, augment=augment),
        num_parallel_calls=tf.data.AUTOTUNE,
    )

    # Use buffered prefetching on all datasets
    return ds.prefetch(buffer_size=tf.data.AUTOTUNE)
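

# A minimal end-to-end sketch, NOT part of the original module. The record
# glob, shapes, and hyper-parameters below are illustrative assumptions.
def _demo_prepare_dataset():
    ds = prepare_dataset(
        tf_record_paths=["/path/to/records/*.tfrecord"],  # hypothetical glob
        batch_size=32,
        epochs=1,
        data_shape=(126, 126, 3),  # shape of the raw stored images
        output_shape=(112, 112),  # (height, width) after random cropping
        shuffle=True,
        augment=True,
    )
    for images, labels in ds.take(1):
        # e.g. images: (32, 112, 112, 3) float32, labels: (32,) int64
        print(images.shape, labels.shape)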