Coverage for src/bob/bio/face/pytorch/facexzoo/EfficientNets.py: 65%


1""" 

2@author: Jun Wang 

3@date: 20201019 

4@contact: jun21wangustc@gmail.com 

5""" 

6 

7# based on: 

8# Github repo: https://github.com/lukemelas/EfficientNet-PyTorch 

9 

10import collections 

11import math 

12import re 

13 

14from functools import partial 

15 

16import torch 

17 

18from torch import nn 

19from torch.nn import ( 

20 BatchNorm1d, 

21 BatchNorm2d, 

22 Dropout, 

23 Linear, 

24 Module, 

25 Sequential, 

26) 

27from torch.nn import functional as F 

28from torch.utils import model_zoo 

29 

30################################################################################ 

31# Help functions for model architecture 

32################################################################################ 

33 

34# GlobalParams and BlockArgs: Two namedtuples 

35# Swish and MemoryEfficientSwish: Two implementations of the method 

36# round_filters and round_repeats: 

37# Functions to calculate params for scaling model width and depth ! ! ! 

38# get_width_and_height_from_size and calculate_output_image_size 

39# drop_connect: A structural design 

40# get_same_padding_conv2d: 

41# Conv2dDynamicSamePadding 

42# Conv2dStaticSamePadding 

43# get_same_padding_maxPool2d: 

44# MaxPool2dDynamicSamePadding 

45# MaxPool2dStaticSamePadding 

46# It's an additional function, not used in EfficientNet, 

47# but can be used in other model (such as EfficientDet). 

48 

49# Parameters for the entire model (stem, all blocks, and head) 

50GlobalParams = collections.namedtuple( 

51 "GlobalParams", 

52 [ 

53 "width_coefficient", 

54 "depth_coefficient", 

55 "image_size", 

56 "dropout_rate", 

57 "num_classes", 

58 "batch_norm_momentum", 

59 "batch_norm_epsilon", 

60 "drop_connect_rate", 

61 "depth_divisor", 

62 "min_depth", 

63 "include_top", 

64 ], 

65) 

66 

67# Parameters for an individual model block 

68BlockArgs = collections.namedtuple( 

69 "BlockArgs", 

70 [ 

71 "num_repeat", 

72 "kernel_size", 

73 "stride", 

74 "expand_ratio", 

75 "input_filters", 

76 "output_filters", 

77 "se_ratio", 

78 "id_skip", 

79 ], 

80) 

81 

82# Set GlobalParams and BlockArgs's defaults 

83GlobalParams.__new__.__defaults__ = (None,) * len(GlobalParams._fields) 

84BlockArgs.__new__.__defaults__ = (None,) * len(BlockArgs._fields) 

85 

86 

87# An ordinary implementation of Swish function 

88class Swish(nn.Module): 

89 def forward(self, x): 

90 return x * torch.sigmoid(x) 

91 

92 

93# A memory-efficient implementation of Swish function 

94class SwishImplementation(torch.autograd.Function): 

95 @staticmethod 

96 def forward(ctx, i): 

97 result = i * torch.sigmoid(i) 

98 ctx.save_for_backward(i) 

99 return result 

100 

101 @staticmethod 

102 def backward(ctx, grad_output): 

103 i = ctx.saved_tensors[0] 

104 sigmoid_i = torch.sigmoid(i) 

105 return grad_output * (sigmoid_i * (1 + i * (1 - sigmoid_i))) 

106 

107 

108class MemoryEfficientSwish(nn.Module): 

109 def forward(self, x): 

110 return SwishImplementation.apply(x) 

111 

112 

113def round_filters(filters, global_params): 

114 """Calculate and round number of filters based on width multiplier. 

115 Use width_coefficient, depth_divisor and min_depth of global_params. 

116 

117 Args: 

118 filters (int): Filters number to be calculated. 

119 global_params (namedtuple): Global params of the model. 

120 

121 Returns: 

122 new_filters: New filters number after calculating. 

123 """ 

124 multiplier = global_params.width_coefficient 

125 if not multiplier: 

126 return filters 

127 # TODO: modify the params names. 

128 # maybe the names (width_divisor,min_width) 

129 # are more suitable than (depth_divisor,min_depth). 

130 divisor = global_params.depth_divisor 

131 min_depth = global_params.min_depth 

132 filters *= multiplier 

133 min_depth = ( 

134 min_depth or divisor 

135 ) # pay attention to this line when using min_depth 

136 # follow the formula transferred from official TensorFlow implementation 

137 new_filters = max( 

138 min_depth, int(filters + divisor / 2) // divisor * divisor 

139 ) 

140 if new_filters < 0.9 * filters: # prevent rounding by more than 10% 

141 new_filters += divisor 

142 return int(new_filters) 
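
# Worked example (following the formula above): with a GlobalParams gp where
# width_coefficient=1.1 and depth_divisor=8, round_filters(40, gp) computes
# 40 * 1.1 = 44.0, then int(44.0 + 4) // 8 * 8 = 48; since 48 >= 0.9 * 44.0,
# the result is 48, the nearest multiple of 8 that is not more than 10% below
# the scaled value.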


def round_repeats(repeats, global_params):
    """Calculate a block's repeat number based on the depth multiplier.
    Uses depth_coefficient of global_params.

    Args:
        repeats (int): num_repeat to be calculated.
        global_params (namedtuple): Global params of the model.

    Returns:
        new repeat: New repeat number after calculating.
    """
    multiplier = global_params.depth_coefficient
    if not multiplier:
        return repeats
    # follow the formula transferred from the official TensorFlow implementation
    return int(math.ceil(multiplier * repeats))
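
# Worked example: with depth_coefficient=1.2, round_repeats(3, gp) returns
# math.ceil(1.2 * 3) = 4 repeated blocks.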


def drop_connect(inputs, p, training):
    """Drop connect.

    Args:
        inputs (tensor: BCWH): Input of this structure.
        p (float: 0.0~1.0): Probability of drop connection.
        training (bool): The running mode.

    Returns:
        output: Output after drop connection.
    """
    assert 0 <= p <= 1, "p must be in range of [0,1]"

    if not training:
        return inputs

    batch_size = inputs.shape[0]
    keep_prob = 1 - p

    # generate binary_tensor mask according to probability (p for 0, 1-p for 1)
    random_tensor = keep_prob
    random_tensor += torch.rand(
        [batch_size, 1, 1, 1], dtype=inputs.dtype, device=inputs.device
    )
    binary_tensor = torch.floor(random_tensor)

    output = inputs / keep_prob * binary_tensor
    return output
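
# Behaviour sketch: each sample in the batch is zeroed with probability p and
# the survivors are rescaled by 1 / (1 - p), so the expected value of the
# output equals the input; e.g. with p=0.2 roughly 80% of the samples survive
# and are scaled by 1.25.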


def get_width_and_height_from_size(x):
    """Obtain height and width from x.

    Args:
        x (int, tuple or list): Data size.

    Returns:
        size: A tuple or list (H, W).
    """
    if isinstance(x, int):
        return x, x
    if isinstance(x, (list, tuple)):
        return x
    else:
        raise TypeError()


def calculate_output_image_size(input_image_size, stride):
    """Calculates the output image size when using Conv2dSamePadding with a stride.
    Necessary for static padding. Thanks to mannatsingh for pointing this out.

    Args:
        input_image_size (int, tuple or list): Size of input image.
        stride (int, tuple or list): Conv2d operation's stride.

    Returns:
        output_image_size: A list [H, W].
    """
    if input_image_size is None:
        return None
    image_height, image_width = get_width_and_height_from_size(input_image_size)
    stride = stride if isinstance(stride, int) else stride[0]
    image_height = int(math.ceil(image_height / stride))
    image_width = int(math.ceil(image_width / stride))
    return [image_height, image_width]
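
# Worked example: calculate_output_image_size(112, 2) -> [56, 56]; chaining
# four stride-2 stages reduces a 112x112 input to 7x7.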


# Note:
# The following 'SamePadding' functions make the output size equal
# ceil(input size / stride). Only when stride equals 1 can the output size
# be the same as the input size. Don't be confused by their function names!


def get_same_padding_conv2d(image_size=None):
    """Chooses static padding if you have specified an image size, and dynamic padding otherwise.
    Static padding is necessary for ONNX exporting of models.

    Args:
        image_size (int or tuple): Size of the image.

    Returns:
        Conv2dDynamicSamePadding or Conv2dStaticSamePadding.
    """
    if image_size is None:
        return Conv2dDynamicSamePadding
    else:
        return partial(Conv2dStaticSamePadding, image_size=image_size)


class Conv2dDynamicSamePadding(nn.Conv2d):
    """2D Convolutions like TensorFlow, for a dynamic image size.
    The padding is computed dynamically in the forward pass.
    """

    # Tips for 'SAME' mode padding.
    #     Given the following:
    #         i: width or height
    #         s: stride
    #         k: kernel size
    #         d: dilation
    #         p: padding
    #     Output after Conv2d:
    #         o = floor((i + p - ((k - 1) * d + 1)) / s + 1)
    # If o equals i, i.e. i = floor((i + p - ((k - 1) * d + 1)) / s + 1),
    # => p = (i - 1) * s + ((k - 1) * d + 1) - i
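    #
    # Worked example: for i=7, s=2, k=3, d=1, the target output is
    # o = ceil(7 / 2) = 4, so p = (4 - 1) * 2 + ((3 - 1) * 1 + 1) - 7 = 2,
    # split as one pixel on each side by F.pad in forward() below.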

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        dilation=1,
        groups=1,
        bias=True,
    ):
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            0,
            dilation,
            groups,
            bias,
        )
        self.stride = (
            self.stride if len(self.stride) == 2 else [self.stride[0]] * 2
        )

    def forward(self, x):
        ih, iw = x.size()[-2:]
        kh, kw = self.weight.size()[-2:]
        sh, sw = self.stride
        oh, ow = math.ceil(ih / sh), math.ceil(
            iw / sw
        )  # change the output size according to stride
        pad_h = max(
            (oh - 1) * self.stride[0] + (kh - 1) * self.dilation[0] + 1 - ih, 0
        )
        pad_w = max(
            (ow - 1) * self.stride[1] + (kw - 1) * self.dilation[1] + 1 - iw, 0
        )
        if pad_h > 0 or pad_w > 0:
            x = F.pad(
                x,
                [
                    pad_w // 2,
                    pad_w - pad_w // 2,
                    pad_h // 2,
                    pad_h - pad_h // 2,
                ],
            )
        return F.conv2d(
            x,
            self.weight,
            self.bias,
            self.stride,
            self.padding,
            self.dilation,
            self.groups,
        )


class Conv2dStaticSamePadding(nn.Conv2d):
    """2D Convolutions like TensorFlow's 'SAME' mode, with the given input image size.
    The padding module is calculated in the constructor, then used in forward.
    """

    # Same calculation as Conv2dDynamicSamePadding.

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        image_size=None,
        **kwargs,
    ):
        super().__init__(
            in_channels, out_channels, kernel_size, stride, **kwargs
        )
        self.stride = (
            self.stride if len(self.stride) == 2 else [self.stride[0]] * 2
        )

        # Calculate padding based on image size and save it
        assert image_size is not None
        ih, iw = (
            (image_size, image_size)
            if isinstance(image_size, int)
            else image_size
        )
        kh, kw = self.weight.size()[-2:]
        sh, sw = self.stride
        oh, ow = math.ceil(ih / sh), math.ceil(iw / sw)
        pad_h = max(
            (oh - 1) * self.stride[0] + (kh - 1) * self.dilation[0] + 1 - ih, 0
        )
        pad_w = max(
            (ow - 1) * self.stride[1] + (kw - 1) * self.dilation[1] + 1 - iw, 0
        )
        if pad_h > 0 or pad_w > 0:
            self.static_padding = nn.ZeroPad2d(
                (pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2)
            )
        else:
            self.static_padding = nn.Identity()

    def forward(self, x):
        x = self.static_padding(x)
        x = F.conv2d(
            x,
            self.weight,
            self.bias,
            self.stride,
            self.padding,
            self.dilation,
            self.groups,
        )
        return x


def get_same_padding_maxPool2d(image_size=None):
    """Chooses static padding if you have specified an image size, and dynamic padding otherwise.
    Static padding is necessary for ONNX exporting of models.

    Args:
        image_size (int or tuple): Size of the image.

    Returns:
        MaxPool2dDynamicSamePadding or MaxPool2dStaticSamePadding.
    """
    if image_size is None:
        return MaxPool2dDynamicSamePadding
    else:
        return partial(MaxPool2dStaticSamePadding, image_size=image_size)


class MaxPool2dDynamicSamePadding(nn.MaxPool2d):
    """2D MaxPooling like TensorFlow's 'SAME' mode, with a dynamic image size.
    The padding is computed dynamically in the forward pass.
    """

    def __init__(
        self,
        kernel_size,
        stride,
        padding=0,
        dilation=1,
        return_indices=False,
        ceil_mode=False,
    ):
        super().__init__(
            kernel_size, stride, padding, dilation, return_indices, ceil_mode
        )
        self.stride = (
            [self.stride] * 2 if isinstance(self.stride, int) else self.stride
        )
        self.kernel_size = (
            [self.kernel_size] * 2
            if isinstance(self.kernel_size, int)
            else self.kernel_size
        )
        self.dilation = (
            [self.dilation] * 2
            if isinstance(self.dilation, int)
            else self.dilation
        )

    def forward(self, x):
        ih, iw = x.size()[-2:]
        kh, kw = self.kernel_size
        sh, sw = self.stride
        oh, ow = math.ceil(ih / sh), math.ceil(iw / sw)
        pad_h = max(
            (oh - 1) * self.stride[0] + (kh - 1) * self.dilation[0] + 1 - ih, 0
        )
        pad_w = max(
            (ow - 1) * self.stride[1] + (kw - 1) * self.dilation[1] + 1 - iw, 0
        )
        if pad_h > 0 or pad_w > 0:
            x = F.pad(
                x,
                [
                    pad_w // 2,
                    pad_w - pad_w // 2,
                    pad_h // 2,
                    pad_h - pad_h // 2,
                ],
            )
        return F.max_pool2d(
            x,
            self.kernel_size,
            self.stride,
            self.padding,
            self.dilation,
            self.ceil_mode,
            self.return_indices,
        )


class MaxPool2dStaticSamePadding(nn.MaxPool2d):
    """2D MaxPooling like TensorFlow's 'SAME' mode, with the given input image size.
    The padding module is calculated in the constructor, then used in forward.
    """

    def __init__(self, kernel_size, stride, image_size=None, **kwargs):
        super().__init__(kernel_size, stride, **kwargs)
        self.stride = (
            [self.stride] * 2 if isinstance(self.stride, int) else self.stride
        )
        self.kernel_size = (
            [self.kernel_size] * 2
            if isinstance(self.kernel_size, int)
            else self.kernel_size
        )
        self.dilation = (
            [self.dilation] * 2
            if isinstance(self.dilation, int)
            else self.dilation
        )

        # Calculate padding based on image size and save it
        assert image_size is not None
        ih, iw = (
            (image_size, image_size)
            if isinstance(image_size, int)
            else image_size
        )
        kh, kw = self.kernel_size
        sh, sw = self.stride
        oh, ow = math.ceil(ih / sh), math.ceil(iw / sw)
        pad_h = max(
            (oh - 1) * self.stride[0] + (kh - 1) * self.dilation[0] + 1 - ih, 0
        )
        pad_w = max(
            (ow - 1) * self.stride[1] + (kw - 1) * self.dilation[1] + 1 - iw, 0
        )
        if pad_h > 0 or pad_w > 0:
            self.static_padding = nn.ZeroPad2d(
                (pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2)
            )
        else:
            self.static_padding = nn.Identity()

    def forward(self, x):
        x = self.static_padding(x)
        x = F.max_pool2d(
            x,
            self.kernel_size,
            self.stride,
            self.padding,
            self.dilation,
            self.ceil_mode,
            self.return_indices,
        )
        return x


################################################################################
# Helper functions for loading model params
################################################################################

# BlockDecoder: A class for encoding and decoding BlockArgs
# efficientnet_params: A function to query compound coefficients
# get_model_params and efficientnet:
#     Functions to get BlockArgs and GlobalParams for efficientnet
# url_map and url_map_advprop: Dicts of urls for pretrained weights
# load_pretrained_weights: A function to load pretrained weights


class BlockDecoder(object):
    """Block Decoder for readability,
    straight from the official TensorFlow repository.
    """

    @staticmethod
    def _decode_block_string(block_string):
        """Get a block through a string notation of arguments.

        Args:
            block_string (str): A string notation of arguments.
                Example: 'r1_k3_s11_e1_i32_o16_se0.25_noskip'.

        Returns:
            BlockArgs: The namedtuple defined at the top of this file.
        """
        assert isinstance(block_string, str)

        ops = block_string.split("_")
        options = {}
        for op in ops:
            splits = re.split(r"(\d.*)", op)
            if len(splits) >= 2:
                key, value = splits[:2]
                options[key] = value

        # Check stride
        assert ("s" in options and len(options["s"]) == 1) or (
            len(options["s"]) == 2 and options["s"][0] == options["s"][1]
        )

        return BlockArgs(
            num_repeat=int(options["r"]),
            kernel_size=int(options["k"]),
            stride=[int(options["s"][0])],
            expand_ratio=int(options["e"]),
            input_filters=int(options["i"]),
            output_filters=int(options["o"]),
            se_ratio=float(options["se"]) if "se" in options else None,
            id_skip=("noskip" not in block_string),
        )
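
    # Example: _decode_block_string('r1_k3_s11_e1_i32_o16_se0.25') returns
    # BlockArgs(num_repeat=1, kernel_size=3, stride=[1], expand_ratio=1,
    #           input_filters=32, output_filters=16, se_ratio=0.25,
    #           id_skip=True).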

    @staticmethod
    def _encode_block_string(block):
        """Encode a block to a string.

        Args:
            block (namedtuple): A BlockArgs type argument.

        Returns:
            block_string: A string form of BlockArgs.
        """
        args = [
            "r%d" % block.num_repeat,
            "k%d" % block.kernel_size,
            # 'stride' is stored as a one-element list of equal strides,
            # so the two-digit notation repeats it
            "s%d%d" % (block.stride[0], block.stride[0]),
            "e%s" % block.expand_ratio,
            "i%d" % block.input_filters,
            "o%d" % block.output_filters,
        ]
        if block.se_ratio is not None and 0 < block.se_ratio <= 1:
            args.append("se%s" % block.se_ratio)
        if block.id_skip is False:
            args.append("noskip")
        return "_".join(args)

    @staticmethod
    def decode(string_list):
        """Decode a list of string notations to specify blocks inside the network.

        Args:
            string_list (list[str]): A list of strings, each string is a notation of a block.

        Returns:
            blocks_args: A list of BlockArgs namedtuples of block args.
        """
        assert isinstance(string_list, list)
        blocks_args = []
        for block_string in string_list:
            blocks_args.append(BlockDecoder._decode_block_string(block_string))
        return blocks_args

    @staticmethod
    def encode(blocks_args):
        """Encode a list of BlockArgs to a list of strings.

        Args:
            blocks_args (list[namedtuple]): A list of BlockArgs namedtuples of block args.

        Returns:
            block_strings: A list of strings, each string is a notation of a block.
        """
        block_strings = []
        for block in blocks_args:
            block_strings.append(BlockDecoder._encode_block_string(block))
        return block_strings


def efficientnet_params(model_name):
    """Map EfficientNet model name to parameter coefficients.

    Args:
        model_name (str): Model name to be queried.

    Returns:
        params_dict[model_name]: A (width,depth,res,dropout) tuple.
    """
    # Original ImageNet coefficients, kept for reference:
    #     'efficientnet-b0': (1.0, 1.0, 224, 0.2),
    #     'efficientnet-b1': (1.0, 1.1, 240, 0.2),
    #     'efficientnet-b2': (1.1, 1.2, 260, 0.3),
    #     'efficientnet-b3': (1.2, 1.4, 300, 0.3),
    #     'efficientnet-b4': (1.4, 1.8, 380, 0.4),
    #     'efficientnet-b5': (1.6, 2.2, 456, 0.4),
    #     'efficientnet-b6': (1.8, 2.6, 528, 0.5),
    #     'efficientnet-b7': (2.0, 3.1, 600, 0.5),
    #     'efficientnet-b8': (2.2, 3.6, 672, 0.5),
    #     'efficientnet-l2': (4.3, 5.3, 800, 0.5),
    params_dict = {
        # Coefficients: width,depth,res,dropout
        "efficientnet-b0": (1.0, 1.0, 112, 0.2),
        "efficientnet-b1": (1.0, 1.1, 112, 0.2),
        "efficientnet-b2": (1.1, 1.2, 112, 0.3),
        "efficientnet-b3": (1.2, 1.4, 112, 0.3),
        "efficientnet-b4": (1.4, 1.8, 112, 0.4),
        "efficientnet-b5": (1.6, 2.2, 112, 0.4),
        "efficientnet-b6": (1.8, 2.6, 112, 0.5),
        "efficientnet-b7": (2.0, 3.1, 112, 0.5),
        "efficientnet-b8": (2.2, 3.6, 112, 0.5),
        "efficientnet-l2": (4.3, 5.3, 112, 0.5),
    }
    return params_dict[model_name]
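
# Example: efficientnet_params('efficientnet-b0') -> (1.0, 1.0, 112, 0.2):
# no width or depth scaling, 112x112 inputs (face crops rather than the
# original ImageNet resolutions listed above) and dropout 0.2.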


def efficientnet(
    width_coefficient=None,
    depth_coefficient=None,
    image_size=None,
    dropout_rate=0.2,
    drop_connect_rate=0.2,
    num_classes=1000,
    include_top=True,
):
    """Create BlockArgs and GlobalParams for an efficientnet model.

    Args:
        width_coefficient (float)
        depth_coefficient (float)
        image_size (int)
        dropout_rate (float)
        drop_connect_rate (float)
        num_classes (int)

        Meaning as the names suggest.

    Returns:
        blocks_args, global_params.
    """

    # Block args for the whole model (efficientnet-b0 by default).
    # They are modified in the construction of the EfficientNet class
    # according to the model.
    blocks_args = [
        "r1_k3_s11_e1_i32_o16_se0.25",
        "r2_k3_s22_e6_i16_o24_se0.25",
        "r2_k5_s22_e6_i24_o40_se0.25",
        "r3_k3_s22_e6_i40_o80_se0.25",
        "r3_k5_s11_e6_i80_o112_se0.25",
        "r4_k5_s22_e6_i112_o192_se0.25",
        "r1_k3_s11_e6_i192_o320_se0.25",
    ]
    blocks_args = BlockDecoder.decode(blocks_args)

    global_params = GlobalParams(
        width_coefficient=width_coefficient,
        depth_coefficient=depth_coefficient,
        image_size=image_size,
        dropout_rate=dropout_rate,
        num_classes=num_classes,
        batch_norm_momentum=0.99,
        batch_norm_epsilon=1e-3,
        drop_connect_rate=drop_connect_rate,
        depth_divisor=8,
        min_depth=None,
        include_top=include_top,
    )

    return blocks_args, global_params
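
# Usage sketch: efficientnet(width_coefficient=1.0, depth_coefficient=1.0,
# image_size=112) returns the seven decoded BlockArgs above together with a
# GlobalParams carrying batch_norm_momentum=0.99, batch_norm_epsilon=1e-3 and
# depth_divisor=8.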


def get_model_params(model_name, override_params):
    """Get the block args and global params for a given model name.

    Args:
        model_name (str): Model's name.
        override_params (dict): A dict to modify global_params.

    Returns:
        blocks_args, global_params
    """
    if model_name.startswith("efficientnet"):
        w, d, s, p = efficientnet_params(model_name)
        # note: all models have drop connect rate = 0.2
        blocks_args, global_params = efficientnet(
            width_coefficient=w,
            depth_coefficient=d,
            dropout_rate=p,
            image_size=s,
        )
    else:
        raise NotImplementedError(
            "model name is not pre-defined: {}".format(model_name)
        )
    if override_params:
        # ValueError will be raised here if override_params has fields not included in global_params.
        global_params = global_params._replace(**override_params)
    return blocks_args, global_params


# trained with standard methods
# check more details in the paper (EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks)
url_map = {
    "efficientnet-b0": "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b0-355c32eb.pth",
    "efficientnet-b1": "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b1-f1951068.pth",
    "efficientnet-b2": "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b2-8bb594d6.pth",
    "efficientnet-b3": "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b3-5fb5a3c3.pth",
    "efficientnet-b4": "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b4-6ed6700e.pth",
    "efficientnet-b5": "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b5-b6417697.pth",
    "efficientnet-b6": "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b6-c76e70fd.pth",
    "efficientnet-b7": "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b7-dcc49843.pth",
}

# trained with Adversarial Examples (AdvProp)
# check more details in the paper (Adversarial Examples Improve Image Recognition)
url_map_advprop = {
    "efficientnet-b0": "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b0-b64d5a18.pth",
    "efficientnet-b1": "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b1-0f3ce85a.pth",
    "efficientnet-b2": "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b2-6e9d97e5.pth",
    "efficientnet-b3": "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b3-cdd7c0f4.pth",
    "efficientnet-b4": "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b4-44fb3a87.pth",
    "efficientnet-b5": "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b5-86493f6b.pth",
    "efficientnet-b6": "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b6-ac80338e.pth",
    "efficientnet-b7": "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b7-4652b6dd.pth",
    "efficientnet-b8": "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b8-22a8fe65.pth",
}

# TODO: add the pretrained weights url map of 'efficientnet-l2'


def load_pretrained_weights(
    model, model_name, weights_path=None, load_fc=True, advprop=False
):
    """Loads pretrained weights from a local weights path, or downloads them by url.

    Args:
        model (Module): The whole model of efficientnet.
        model_name (str): Model name of efficientnet.
        weights_path (None or str):
            str: path to pretrained weights file on the local disk.
            None: use pretrained weights downloaded from the Internet.
        load_fc (bool): Whether to load pretrained weights for the fc layer at the end of the model.
        advprop (bool): Whether to load pretrained weights
                        trained with advprop (valid when weights_path is None).
    """
    if isinstance(weights_path, str):
        state_dict = torch.load(weights_path)
    else:
        # AutoAugment or AdvProp (different preprocessing)
        url_map_ = url_map_advprop if advprop else url_map
        state_dict = model_zoo.load_url(url_map_[model_name])

    if load_fc:
        ret = model.load_state_dict(state_dict, strict=False)
        assert (
            not ret.missing_keys
        ), "Missing keys when loading pretrained weights: {}".format(
            ret.missing_keys
        )
    else:
        state_dict.pop("_fc.weight")
        state_dict.pop("_fc.bias")
        ret = model.load_state_dict(state_dict, strict=False)
        assert set(ret.missing_keys) == set(
            ["_fc.weight", "_fc.bias"]
        ), "Missing keys when loading pretrained weights: {}".format(
            ret.missing_keys
        )
    assert (
        not ret.unexpected_keys
    ), "Unexpected keys when loading pretrained weights: {}".format(
        ret.unexpected_keys
    )

    print("Loaded pretrained weights for {}".format(model_name))
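
# Usage sketch (the local path is hypothetical):
#     load_pretrained_weights(model, 'efficientnet-b0') downloads the url_map
#     checkpoint via model_zoo, while
#     load_pretrained_weights(model, 'efficientnet-b0',
#                             weights_path='/path/to/weights.pth',
#                             load_fc=False)
#     loads a local file and skips the classifier weights.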


class Flatten(Module):
    def forward(self, input):
        return input.view(input.size(0), -1)


# backbone ##################################################################

VALID_MODELS = (
    "efficientnet-b0",
    "efficientnet-b1",
    "efficientnet-b2",
    "efficientnet-b3",
    "efficientnet-b4",
    "efficientnet-b5",
    "efficientnet-b6",
    "efficientnet-b7",
    "efficientnet-b8",
    # Support the construction of 'efficientnet-l2' without pretrained weights
    "efficientnet-l2",
)


class MBConvBlock(nn.Module):
    """Mobile Inverted Residual Bottleneck Block.

    Args:
        block_args (namedtuple): BlockArgs, defined at the top of this file.
        global_params (namedtuple): GlobalParams, defined at the top of this file.
        image_size (tuple or list): [image_height, image_width].

    References:
        [1] https://arxiv.org/abs/1704.04861 (MobileNet v1)
        [2] https://arxiv.org/abs/1801.04381 (MobileNet v2)
        [3] https://arxiv.org/abs/1905.02244 (MobileNet v3)
    """

    def __init__(self, block_args, global_params, image_size=None):
        super().__init__()
        self._block_args = block_args
        self._bn_mom = (
            1 - global_params.batch_norm_momentum
        )  # pytorch's difference from tensorflow
        self._bn_eps = global_params.batch_norm_epsilon
        self.has_se = (self._block_args.se_ratio is not None) and (
            0 < self._block_args.se_ratio <= 1
        )
        self.id_skip = (
            block_args.id_skip
        )  # whether to use skip connection and drop connect

        # Expansion phase (Inverted Bottleneck)
        inp = self._block_args.input_filters  # number of input channels
        oup = (
            self._block_args.input_filters * self._block_args.expand_ratio
        )  # number of output channels
        if self._block_args.expand_ratio != 1:
            Conv2d = get_same_padding_conv2d(image_size=image_size)
            self._expand_conv = Conv2d(
                in_channels=inp, out_channels=oup, kernel_size=1, bias=False
            )
            self._bn0 = nn.BatchNorm2d(
                num_features=oup, momentum=self._bn_mom, eps=self._bn_eps
            )
            # image_size = calculate_output_image_size(image_size, 1) <-- this wouldn't modify image_size

        # Depthwise convolution phase
        k = self._block_args.kernel_size
        s = self._block_args.stride
        Conv2d = get_same_padding_conv2d(image_size=image_size)
        self._depthwise_conv = Conv2d(
            in_channels=oup,
            out_channels=oup,
            groups=oup,  # groups makes it depthwise
            kernel_size=k,
            stride=s,
            bias=False,
        )
        self._bn1 = nn.BatchNorm2d(
            num_features=oup, momentum=self._bn_mom, eps=self._bn_eps
        )
        image_size = calculate_output_image_size(image_size, s)

        # Squeeze and Excitation layer, if desired
        if self.has_se:
            Conv2d = get_same_padding_conv2d(image_size=(1, 1))
            num_squeezed_channels = max(
                1,
                int(self._block_args.input_filters * self._block_args.se_ratio),
            )
            self._se_reduce = Conv2d(
                in_channels=oup,
                out_channels=num_squeezed_channels,
                kernel_size=1,
            )
            self._se_expand = Conv2d(
                in_channels=num_squeezed_channels,
                out_channels=oup,
                kernel_size=1,
            )

        # Pointwise convolution phase
        final_oup = self._block_args.output_filters
        Conv2d = get_same_padding_conv2d(image_size=image_size)
        self._project_conv = Conv2d(
            in_channels=oup, out_channels=final_oup, kernel_size=1, bias=False
        )
        self._bn2 = nn.BatchNorm2d(
            num_features=final_oup, momentum=self._bn_mom, eps=self._bn_eps
        )
        self._swish = MemoryEfficientSwish()

    def forward(self, inputs, drop_connect_rate=None):
        """MBConvBlock's forward function.

        Args:
            inputs (tensor): Input tensor.
            drop_connect_rate (float): Drop connect rate (between 0 and 1).

        Returns:
            Output of this block after processing.
        """

        # Expansion and Depthwise Convolution
        x = inputs
        if self._block_args.expand_ratio != 1:
            x = self._expand_conv(inputs)
            x = self._bn0(x)
            x = self._swish(x)

        x = self._depthwise_conv(x)
        x = self._bn1(x)
        x = self._swish(x)

        # Squeeze and Excitation
        if self.has_se:
            x_squeezed = F.adaptive_avg_pool2d(x, 1)
            x_squeezed = self._se_reduce(x_squeezed)
            x_squeezed = self._swish(x_squeezed)
            x_squeezed = self._se_expand(x_squeezed)
            x = torch.sigmoid(x_squeezed) * x

        # Pointwise Convolution
        x = self._project_conv(x)
        x = self._bn2(x)

        # Skip connection and drop connect
        input_filters, output_filters = (
            self._block_args.input_filters,
            self._block_args.output_filters,
        )
        if (
            self.id_skip
            and self._block_args.stride == 1
            and input_filters == output_filters
        ):
            # The combination of skip connection and drop connect brings about stochastic depth.
            if drop_connect_rate:
                x = drop_connect(x, p=drop_connect_rate, training=self.training)
            x = x + inputs  # skip connection
        return x

    def set_swish(self, memory_efficient=True):
        """Sets swish function as memory efficient (for training) or standard (for export).

        Args:
            memory_efficient (bool): Whether to use the memory-efficient version of swish.
        """
        self._swish = MemoryEfficientSwish() if memory_efficient else Swish()


class EfficientNet(nn.Module):
    """EfficientNet model.
    Most easily loaded with the .from_name or .from_pretrained methods.

    Args:
        blocks_args (list[namedtuple]): A list of BlockArgs to construct blocks.
        global_params (namedtuple): A set of GlobalParams shared between blocks.

    References:
        [1] https://arxiv.org/abs/1905.11946 (EfficientNet)

    Example (upstream API):
        >>> import torch
        >>> from efficientnet.model import EfficientNet
        >>> inputs = torch.rand(1, 3, 224, 224)
        >>> model = EfficientNet.from_pretrained('efficientnet-b0')
        >>> model.eval()
        >>> outputs = model(inputs)
    """

    def __init__(
        self, out_h, out_w, feat_dim, blocks_args=None, global_params=None
    ):
        super().__init__()
        assert isinstance(blocks_args, list), "blocks_args should be a list"
        assert len(blocks_args) > 0, "block args must be greater than 0"
        self._global_params = global_params
        self._blocks_args = blocks_args

        # Batch norm parameters
        bn_mom = 1 - self._global_params.batch_norm_momentum
        bn_eps = self._global_params.batch_norm_epsilon

        # Get stem static or dynamic convolution depending on image size
        image_size = global_params.image_size
        Conv2d = get_same_padding_conv2d(image_size=image_size)

        # Stem
        in_channels = 3  # rgb
        out_channels = round_filters(
            32, self._global_params
        )  # number of output channels
        # self._conv_stem = Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=False)
        self._conv_stem = Conv2d(
            in_channels, out_channels, kernel_size=3, stride=1, bias=False
        )
        self._bn0 = nn.BatchNorm2d(
            num_features=out_channels, momentum=bn_mom, eps=bn_eps
        )
        # kept from the upstream stride-2 stem, although the stem above now
        # uses stride=1
        image_size = calculate_output_image_size(image_size, 2)

        # Build blocks
        self._blocks = nn.ModuleList([])
        for block_args in self._blocks_args:
            # Update block input and output filters based on depth multiplier.
            block_args = block_args._replace(
                input_filters=round_filters(
                    block_args.input_filters, self._global_params
                ),
                output_filters=round_filters(
                    block_args.output_filters, self._global_params
                ),
                num_repeat=round_repeats(
                    block_args.num_repeat, self._global_params
                ),
            )

            # The first block needs to take care of stride and filter size increase.
            self._blocks.append(
                MBConvBlock(
                    block_args, self._global_params, image_size=image_size
                )
            )
            image_size = calculate_output_image_size(
                image_size, block_args.stride
            )
            if (
                block_args.num_repeat > 1
            ):  # modify block_args to keep same output size
                block_args = block_args._replace(
                    input_filters=block_args.output_filters, stride=1
                )
            for _ in range(block_args.num_repeat - 1):
                self._blocks.append(
                    MBConvBlock(
                        block_args, self._global_params, image_size=image_size
                    )
                )
                # image_size = calculate_output_image_size(image_size, block_args.stride)  # stride = 1

        # Head
        in_channels = block_args.output_filters  # output of final block
        out_channels = round_filters(1280, self._global_params)
        # out_channels = round_filters(512, self._global_params)
        Conv2d = get_same_padding_conv2d(image_size=image_size)
        self._conv_head = Conv2d(
            in_channels, out_channels, kernel_size=1, bias=False
        )
        self._bn1 = nn.BatchNorm2d(
            num_features=out_channels, momentum=bn_mom, eps=bn_eps
        )

        # Final linear layer
        self._avg_pooling = nn.AdaptiveAvgPool2d(1)
        self._dropout = nn.Dropout(self._global_params.dropout_rate)
        self._fc = nn.Linear(out_channels, self._global_params.num_classes)
        self._swish = MemoryEfficientSwish()
        self.output_layer = Sequential(
            BatchNorm2d(1280),
            # BatchNorm2d(512),
            Dropout(self._global_params.dropout_rate),
            Flatten(),
            Linear(1280 * out_h * out_w, feat_dim),
            # Linear(512 * out_h * out_w, feat_dim),
            BatchNorm1d(feat_dim),
        )


    def set_swish(self, memory_efficient=True):
        """Sets swish function as memory efficient (for training) or standard (for export).

        Args:
            memory_efficient (bool): Whether to use the memory-efficient version of swish.
        """
        self._swish = MemoryEfficientSwish() if memory_efficient else Swish()
        for block in self._blocks:
            block.set_swish(memory_efficient)

    def extract_endpoints(self, inputs):
        """Use convolution layers to extract features
        from reduction levels i in [1, 2, 3, 4, 5].

        Args:
            inputs (tensor): Input tensor.

        Returns:
            Dictionary of last intermediate features
            with reduction levels i in [1, 2, 3, 4, 5].
            Example (upstream shapes, for a 224x224 input and a stride-2 stem):
                >>> import torch
                >>> from efficientnet.model import EfficientNet
                >>> inputs = torch.rand(1, 3, 224, 224)
                >>> model = EfficientNet.from_pretrained('efficientnet-b0')
                >>> endpoints = model.extract_endpoints(inputs)
                >>> print(endpoints['reduction_1'].shape)  # torch.Size([1, 16, 112, 112])
                >>> print(endpoints['reduction_2'].shape)  # torch.Size([1, 24, 56, 56])
                >>> print(endpoints['reduction_3'].shape)  # torch.Size([1, 40, 28, 28])
                >>> print(endpoints['reduction_4'].shape)  # torch.Size([1, 112, 14, 14])
                >>> print(endpoints['reduction_5'].shape)  # torch.Size([1, 1280, 7, 7])
        """
        endpoints = dict()

        # Stem
        x = self._swish(self._bn0(self._conv_stem(inputs)))
        prev_x = x

        # Blocks
        for idx, block in enumerate(self._blocks):
            drop_connect_rate = self._global_params.drop_connect_rate
            if drop_connect_rate:
                drop_connect_rate *= float(idx) / len(
                    self._blocks
                )  # scale drop connect rate
            x = block(x, drop_connect_rate=drop_connect_rate)
            if prev_x.size(2) > x.size(2):
                endpoints["reduction_{}".format(len(endpoints) + 1)] = prev_x
            prev_x = x

        # Head
        x = self._swish(self._bn1(self._conv_head(x)))
        endpoints["reduction_{}".format(len(endpoints) + 1)] = x

        return endpoints

    def extract_features(self, inputs):
        """Use convolution layers to extract features.

        Args:
            inputs (tensor): Input tensor.

        Returns:
            Output of the final convolution
            layer in the efficientnet model.
        """
        # Stem
        x = self._swish(self._bn0(self._conv_stem(inputs)))
        # Blocks
        for idx, block in enumerate(self._blocks):
            drop_connect_rate = self._global_params.drop_connect_rate
            if drop_connect_rate:
                drop_connect_rate *= float(idx) / len(
                    self._blocks
                )  # scale drop connect rate
            x = block(x, drop_connect_rate=drop_connect_rate)

        # Head
        x = self._swish(self._bn1(self._conv_head(x)))

        return x

    def forward(self, inputs):
        """EfficientNet's forward function.
        Calls extract_features to extract features, then applies the embedding head.

        Args:
            inputs (tensor): Input tensor.

        Returns:
            Output of this model after processing.
        """
        # Convolution layers
        x = self.extract_features(inputs)
        # Pooling and final linear layer of the upstream classifier, disabled
        # in this face-recognition variant in favour of output_layer below:
        # x = self._avg_pooling(x)
        # if self._global_params.include_top:
        #     x = x.flatten(start_dim=1)
        #     x = self._dropout(x)
        #     x = self._fc(x)
        x = self.output_layer(x)
        return x


    @classmethod
    def from_name(cls, model_name, in_channels=3, **override_params):
        """Create an efficientnet model according to its name.

        Args:
            model_name (str): Name for efficientnet.
            in_channels (int): Input data's channel number.
            override_params (other key word params):
                Params to override model's global_params.
                Optional key:
                    'width_coefficient', 'depth_coefficient',
                    'image_size', 'dropout_rate',
                    'num_classes', 'batch_norm_momentum',
                    'batch_norm_epsilon', 'drop_connect_rate',
                    'depth_divisor', 'min_depth'

        Returns:
            An efficientnet model.
        """
        cls._check_model_name_is_valid(model_name)
        blocks_args, global_params = get_model_params(
            model_name, override_params
        )
        # NOTE: kept from the upstream repo; the adapted __init__ above also
        # expects out_h, out_w and feat_dim, so this call does not match it.
        # Construct the model directly (see the example at the end of this
        # file) instead of relying on this classmethod.
        model = cls(blocks_args, global_params)
        model._change_in_channels(in_channels)
        return model

    @classmethod
    def from_pretrained(
        cls,
        model_name,
        weights_path=None,
        advprop=False,
        in_channels=3,
        num_classes=1000,
        **override_params,
    ):
        """Create a pretrained efficientnet model according to its name.

        Args:
            model_name (str): Name for efficientnet.
            weights_path (None or str):
                str: path to pretrained weights file on the local disk.
                None: use pretrained weights downloaded from the Internet.
            advprop (bool):
                Whether to load pretrained weights
                trained with advprop (valid when weights_path is None).
            in_channels (int): Input data's channel number.
            num_classes (int):
                Number of categories for classification.
                It controls the output size for the final linear layer.
            override_params (other key word params):
                Params to override model's global_params.
                Optional key:
                    'width_coefficient', 'depth_coefficient',
                    'image_size', 'dropout_rate',
                    'batch_norm_momentum',
                    'batch_norm_epsilon', 'drop_connect_rate',
                    'depth_divisor', 'min_depth'

        Returns:
            A pretrained efficientnet model.
        """
        model = cls.from_name(
            model_name, num_classes=num_classes, **override_params
        )
        load_pretrained_weights(
            model,
            model_name,
            weights_path=weights_path,
            load_fc=(num_classes == 1000),
            advprop=advprop,
        )
        model._change_in_channels(in_channels)
        return model

    @classmethod
    def get_image_size(cls, model_name):
        """Get the input image size for a given efficientnet model.

        Args:
            model_name (str): Name for efficientnet.

        Returns:
            Input image size (resolution).
        """
        cls._check_model_name_is_valid(model_name)
        _, _, res, _ = efficientnet_params(model_name)
        return res

    @classmethod
    def _check_model_name_is_valid(cls, model_name):
        """Validates the model name.

        Args:
            model_name (str): Name for efficientnet.

        Raises:
            ValueError: If model_name is not one of VALID_MODELS.
        """
        if model_name not in VALID_MODELS:
            raise ValueError(
                "model_name should be one of: " + ", ".join(VALID_MODELS)
            )

    def _change_in_channels(self, in_channels):
        """Adjust the model's first convolution layer to in_channels, if in_channels does not equal 3.

        Args:
            in_channels (int): Input data's channel number.
        """
        if in_channels != 3:
            Conv2d = get_same_padding_conv2d(
                image_size=self._global_params.image_size
            )
            out_channels = round_filters(32, self._global_params)
            # stride=2 follows the upstream stem; note that the stem built in
            # __init__ above uses stride=1
            self._conv_stem = Conv2d(
                in_channels, out_channels, kernel_size=3, stride=2, bias=False
            )
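

if __name__ == "__main__":
    # Minimal smoke test; a sketch, not part of the upstream FaceXZoo file.
    # The constructor of this adapted class takes (out_h, out_w, feat_dim, ...),
    # so the model is built directly from get_model_params rather than through
    # from_name. With 112x112 inputs, a stride-1 stem and four stride-2 blocks,
    # the head feature map is 7x7 (out_h = out_w = 7); feat_dim=512 is an
    # assumption, not a value fixed by this file.
    blocks_args, global_params = get_model_params("efficientnet-b0", None)
    net = EfficientNet(
        out_h=7,
        out_w=7,
        feat_dim=512,
        blocks_args=blocks_args,
        global_params=global_params,
    )
    net.eval()
    with torch.no_grad():
        embeddings = net(torch.rand(2, 3, 112, 112))
    print(embeddings.shape)  # torch.Size([2, 512])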