Source code for bob.ip.binseg.modeling.losses

import torch
from torch.nn.modules.loss import _Loss
from torch._jit_internal import weak_script_method




class WeightedBCELogitsLoss(_Loss):
    """
    Class-balanced binary cross entropy computed on logits.

    Implements Equation 1 in `Maninis et al. (2016)`_. Based on
    ``torch.nn.modules.loss.BCEWithLogitsLoss``.

    Every pixel is weighted by the relative frequency of the *opposite*
    class within its own sample, so that the rarer class contributes more
    to the loss.

    Note that the ``weight`` and ``pos_weight`` buffers registered by the
    constructor are not read by :py:meth:`forward`; the weight map is
    derived from ``target`` (and ``masks``) on every call.
    """

    def __init__(self, weight=None, size_average=None, reduce=None,
                 reduction='mean', pos_weight=None):
        super(WeightedBCELogitsLoss, self).__init__(size_average, reduce, reduction)
        self.register_buffer('weight', weight)
        self.register_buffer('pos_weight', pos_weight)

    @weak_script_method
    def forward(self, input, target, masks=None):
        """
        Parameters
        ----------
        input : :py:class:`torch.Tensor`
            Logits; handed to ``binary_cross_entropy_with_logits`` together
            with ``target`` (presumably of the same shape -- TODO confirm).
        target : :py:class:`torch.Tensor`
            4-dimensional ground truth, unpacked as ``(n, c, h, w)``.
            Entries ``<= 0.5`` are treated as negatives when the weight map
            is built.
        masks : :py:class:`torch.Tensor`, optional
            When an object with a ``dtype`` attribute is given, pixels
            outside the mask are removed from the negative count. The mask
            only changes the counts; masked-out pixels still contribute to
            the loss.

        Returns
        -------
        :py:class:`torch.Tensor`
            Weighted loss, reduced according to ``self.reduction``.
        """
        batch, channels, height, width = target.shape
        pixels_per_sample = channels * height * width

        # Per-sample number of foreground pixels, torch.Size([n, 1])
        positives = torch.sum(target, dim=[1, 2, 3]).float().reshape(batch, 1)

        if hasattr(masks, 'dtype'):
            # Pixels lying outside the mask, torch.Size([n, 1])
            outside_mask = pixels_per_sample - torch.sum(
                masks, dim=[1, 2, 3]).float().reshape(batch, 1)
            negatives = pixels_per_sample - positives - outside_mask
        else:
            negatives = pixels_per_sample - positives

        # Class frequencies broadcast from [n, 1, 1, 1] to the target's shape
        pos_fraction = positives / (positives + negatives)
        neg_fraction = negatives / (positives + negatives)
        pos_fraction_map = torch.ones_like(target) * pos_fraction[:, :, None, None]
        neg_fraction_map = torch.ones_like(target) * neg_fraction[:, :, None, None]

        # Negatives are scaled by the positive frequency and vice versa
        pixel_weight = torch.where(target <= 0.5, pos_fraction_map, neg_fraction_map)

        return torch.nn.functional.binary_cross_entropy_with_logits(
            input, target, weight=pixel_weight, reduction=self.reduction)
class SoftJaccardBCELogitsLoss(_Loss):
    """
    Blend of binary cross entropy (on logits) and a soft Jaccard term.

    Implements Equation 3 in `Iglovikov et al. (2018)`_. Based on
    ``torch.nn.modules.loss.BCEWithLogitsLoss``.

    Attributes
    ----------
    alpha : float
        Factor applied to the BCE term; the soft Jaccard term
        ``(1 - jaccard)`` is scaled by ``1 - alpha``. Default: ``0.7``
    """

    def __init__(self, alpha=0.7, size_average=None, reduce=None,
                 reduction='mean', pos_weight=None):
        super(SoftJaccardBCELogitsLoss, self).__init__(size_average, reduce, reduction)
        self.alpha = alpha

    @weak_script_method
    def forward(self, input, target, masks=None):
        """
        Parameters
        ----------
        input : :py:class:`torch.Tensor`
            Logits; a sigmoid is applied to obtain probabilities.
        target : :py:class:`torch.Tensor`
            Ground truth, combined element-wise with the probabilities.
        masks : :py:class:`torch.Tensor`, optional
            Accepted for signature compatibility; not used here.

        Returns
        -------
        :py:class:`torch.Tensor`
            ``alpha * bce + (1 - alpha) * (1 - softjaccard)``
        """
        soft_prediction = input.sigmoid()

        # Soft intersection / union accumulated over every element of the batch
        overlap = torch.sum(soft_prediction * target)
        total_mass = torch.sum(soft_prediction) + torch.sum(target)
        jaccard_index = overlap / (total_mass - overlap + 1e-8)  # epsilon avoids 0/0

        bce_term = torch.nn.functional.binary_cross_entropy_with_logits(
            input, target, weight=None, reduction=self.reduction)
        jaccard_term = 1 - jaccard_index

        return self.alpha * bce_term + (1 - self.alpha) * jaccard_term
class HEDWeightedBCELogitsLoss(_Loss):
    """
    Class-balanced binary cross entropy averaged over several side outputs.

    Implements Equation 2 in `He et al. (2015)`_. Based on
    ``torch.nn.modules.loss.BCEWithLogitsLoss``.

    The same weight map (derived from ``target`` and ``masks``) is applied
    to every side output; the per-output losses are then averaged.

    Note that the ``weight`` and ``pos_weight`` buffers registered by the
    constructor are not read by :py:meth:`forward`.
    """

    def __init__(self, weight=None, size_average=None, reduce=None,
                 reduction='mean', pos_weight=None):
        super(HEDWeightedBCELogitsLoss, self).__init__(size_average, reduce, reduction)
        self.register_buffer('weight', weight)
        self.register_buffer('pos_weight', pos_weight)

    @weak_script_method
    def forward(self, inputlist, target, masks=None):
        """
        Parameters
        ----------
        inputlist : list of :py:class:`torch.Tensor`
            HED uses multiple side-output feature maps for the loss
            calculation. Must contain at least one tensor of logits.
        target : :py:class:`torch.Tensor`
            4-dimensional ground truth, unpacked as ``(n, c, h, w)``.
            Entries ``<= 0.5`` are treated as negatives when the weight map
            is built.
        masks : :py:class:`torch.Tensor`, optional
            When an object with a ``dtype`` attribute is given, pixels
            outside the mask are removed from the negative count. The mask
            only changes the counts; masked-out pixels still contribute to
            the loss.

        Returns
        -------
        :py:class:`torch.Tensor`
            Mean of the weighted losses of all side outputs.
        """
        # The weight map depends only on target/masks, so it is computed
        # once here instead of once per side output.
        n, c, h, w = target.shape
        num_pixels = c * h * w
        num_pos = torch.sum(target, dim=[1, 2, 3]).float().reshape(n, 1)  # torch.Size([n, 1])
        if hasattr(masks, 'dtype'):
            num_mask_neg = num_pixels - torch.sum(
                masks, dim=[1, 2, 3]).float().reshape(n, 1)  # torch.Size([n, 1])
            num_neg = num_pixels - num_pos - num_mask_neg
        else:
            num_neg = num_pixels - num_pos  # torch.Size([n, 1])

        num_total = num_pos + num_neg
        ones = torch.ones_like(target)
        numposnumtotal = ones * (num_pos / num_total).unsqueeze(1).unsqueeze(2)
        numnegnumtotal = ones * (num_neg / num_total).unsqueeze(1).unsqueeze(2)
        # Negatives are scaled by the positive frequency and vice versa
        weight = torch.where(target <= 0.5, numposnumtotal, numnegnumtotal)

        losses = [
            torch.nn.functional.binary_cross_entropy_with_logits(
                side_output, target, weight=weight, reduction=self.reduction)
            for side_output in inputlist
        ]
        return torch.stack(losses).mean()
class HEDSoftJaccardBCELogitsLoss(_Loss):
    """
    Soft Jaccard + BCE loss averaged over several side outputs.

    Implements Equation 3 in `Iglovikov et al. (2018)`_ for the hed network.
    Based on ``torch.nn.modules.loss.BCEWithLogitsLoss``.

    Attributes
    ----------
    alpha : float
        Factor applied to the BCE term; the soft Jaccard term
        ``(1 - jaccard)`` is scaled by ``1 - alpha``. Default: ``0.3``
    """

    def __init__(self, alpha=0.3, size_average=None, reduce=None,
                 reduction='mean', pos_weight=None):
        super(HEDSoftJaccardBCELogitsLoss, self).__init__(size_average, reduce, reduction)
        self.alpha = alpha

    @weak_script_method
    def forward(self, inputlist, target, masks=None):
        """
        Parameters
        ----------
        inputlist : list of :py:class:`torch.Tensor`
            Side-output logits; the loss is evaluated for each of them.
            Must contain at least one tensor.
        target : :py:class:`torch.Tensor`
            Ground truth shared by all side outputs.
        masks : :py:class:`torch.Tensor`, optional
            Accepted for signature compatibility; not used here.

        Returns
        -------
        :py:class:`torch.Tensor`
            Mean of the per-side-output losses.
        """
        eps = 1e-8  # avoids 0/0 in the Jaccard ratio
        loss_over_all_inputs = []
        for side_output in inputlist:
            probabilities = torch.sigmoid(side_output)
            intersection = (probabilities * target).sum()
            sums = probabilities.sum() + target.sum()
            softjaccard = intersection / (sums - intersection + eps)

            bceloss = torch.nn.functional.binary_cross_entropy_with_logits(
                side_output, target, weight=None, reduction=self.reduction)
            loss_over_all_inputs.append(
                self.alpha * bceloss + (1 - self.alpha) * (1 - softjaccard))

        # Return the average over all side outputs. The previous code
        # computed this value but returned only the last side output's loss.
        return torch.stack(loss_over_all_inputs).mean()
class MixJacLoss(_Loss):
    """
    Sum of a labeled loss and a scaled unlabeled loss.

    The labeled part is a :py:class:`SoftJaccardBCELogitsLoss`, the
    unlabeled part a ``torch.nn.BCEWithLogitsLoss``.

    Attributes
    ----------
    lambda_u : int
        Factor applied (together with the ramp-up factor passed to
        :py:meth:`forward`) to the unlabeled loss. Default: ``100``
    """

    def __init__(self, lambda_u=100, jacalpha=0.7, size_average=None,
                 reduce=None, reduction='mean', pos_weight=None):
        super(MixJacLoss, self).__init__(size_average, reduce, reduction)
        self.lambda_u = lambda_u
        self.labeled_loss = SoftJaccardBCELogitsLoss(alpha=jacalpha)
        self.unlabeled_loss = torch.nn.BCEWithLogitsLoss()

    @weak_script_method
    def forward(self, input, target, unlabeled_input, unlabeled_traget, ramp_up_factor):
        """
        Parameters
        ----------
        input : :py:class:`torch.Tensor`
            Logits for the labeled samples.
        target : :py:class:`torch.Tensor`
            Ground truth for the labeled samples.
        unlabeled_input : :py:class:`torch.Tensor`
            Logits for the unlabeled samples.
        unlabeled_traget : :py:class:`torch.Tensor`
            Target the unlabeled logits are compared against.
        ramp_up_factor : float
            Additional multiplier on the unlabeled loss.

        Returns
        -------
        tuple
            ``(loss, labeled_loss, unlabeled_loss)`` where
            ``loss = labeled_loss + lambda_u * ramp_up_factor * unlabeled_loss``
        """
        supervised = self.labeled_loss(input, target)
        unsupervised = self.unlabeled_loss(unlabeled_input, unlabeled_traget)
        unlabeled_scale = self.lambda_u * ramp_up_factor
        combined = supervised + unlabeled_scale * unsupervised
        return combined, supervised, unsupervised