import torch
from torch import autograd as autograd
from torch import nn as nn
from torch.nn import functional as F
from pyiqa.utils.registry import LOSS_REGISTRY
from .loss_util import weighted_loss
# Names of the supported reduction modes; interpolated into the ValueError
# messages raised by the loss classes in this file.
_reduction_modes = ['none', 'mean', 'sum']
@weighted_loss
def l1_loss(pred, target):
    """Compute the unreduced absolute error between ``pred`` and ``target``.

    The function is wrapped by ``weighted_loss``; the loss classes in this
    file call the decorated function with extra ``weight`` and ``reduction``
    arguments.

    Args:
        pred (Tensor): Predicted tensor.
        target (Tensor): Ground truth tensor.

    Returns:
        Tensor: Result of ``F.l1_loss`` with ``reduction='none'``.
    """
    elementwise_abs_error = F.l1_loss(pred, target, reduction='none')
    return elementwise_abs_error
@weighted_loss
def mse_loss(pred, target):
    """Compute the unreduced squared error between ``pred`` and ``target``.

    The function is wrapped by ``weighted_loss``; the loss classes in this
    file call the decorated function with extra ``weight`` and ``reduction``
    arguments.

    Args:
        pred (Tensor): Predicted tensor.
        target (Tensor): Ground truth tensor.

    Returns:
        Tensor: Result of ``F.mse_loss`` with ``reduction='none'``.
    """
    elementwise_sq_error = F.mse_loss(pred, target, reduction='none')
    return elementwise_sq_error
@weighted_loss
def cross_entropy(pred, target):
    """Compute the unreduced cross-entropy between ``pred`` and ``target``.

    The function is wrapped by ``weighted_loss``; the loss classes in this
    file call the decorated function with extra ``weight`` and ``reduction``
    arguments.

    Args:
        pred (Tensor): Predictions handed to ``F.cross_entropy`` as its input.
        target (Tensor): Ground truth handed to ``F.cross_entropy``.

    Returns:
        Tensor: Result of ``F.cross_entropy`` with ``reduction='none'``.
    """
    per_sample_loss = F.cross_entropy(pred, target, reduction='none')
    return per_sample_loss
@weighted_loss
def nll_loss(pred, target):
    """Compute the unreduced negative log-likelihood loss.

    The function is wrapped by ``weighted_loss``; the loss classes in this
    file call the decorated function with extra ``weight`` and ``reduction``
    arguments.

    Args:
        pred (Tensor): Predictions handed to ``F.nll_loss`` as its input.
        target (Tensor): Ground truth handed to ``F.nll_loss``.

    Returns:
        Tensor: Result of ``F.nll_loss`` with ``reduction='none'``.
    """
    per_sample_loss = F.nll_loss(pred, target, reduction='none')
    return per_sample_loss
@weighted_loss
def charbonnier_loss(pred, target, eps=1e-12):
    """Compute the unreduced Charbonnier penalty ``sqrt((pred - target)^2 + eps)``.

    The function is wrapped by ``weighted_loss``; ``CharbonnierLoss`` in this
    file calls the decorated function with extra ``weight`` and ``reduction``
    arguments.

    Args:
        pred (Tensor): Predicted tensor.
        target (Tensor): Ground truth tensor.
        eps (float): Constant added under the square root. Default: 1e-12.

    Returns:
        Tensor: Element-wise penalty with the broadcast shape of
            ``pred - target``.
    """
    residual = pred - target
    squared_residual = torch.pow(residual, 2)
    return torch.sqrt(squared_residual + eps)
@LOSS_REGISTRY.register()
class L1Loss(nn.Module):
    """Module form of the L1 (mean absolute error, MAE) loss.

    Args:
        loss_weight (float): Scalar multiplied onto the computed L1 loss.
            Default: 1.0.
        reduction (str): Reduction applied to the output. One of
            'none' | 'mean' | 'sum'. Default: 'mean'.

    Raises:
        ValueError: If ``reduction`` is not a supported mode.
    """

    def __init__(self, loss_weight=1.0, reduction='mean'):
        super().__init__()
        # Reject unknown modes up front so misconfiguration fails at build time.
        if reduction not in _reduction_modes:
            raise ValueError(
                f'Unsupported reduction mode: {reduction}. Supported ones are: {_reduction_modes}'
            )
        self.reduction = reduction
        self.loss_weight = loss_weight

    def forward(self, pred, target, weight=None, **kwargs):
        """Return the weighted L1 loss scaled by ``loss_weight``.

        Args:
            pred (Tensor): of shape (N, C, H, W). Predicted tensor.
            target (Tensor): of shape (N, C, H, W). Ground truth tensor.
            weight (Tensor, optional): of shape (N, C, H, W). Element-wise
                weights. Default: None.
            **kwargs: Accepted for interface compatibility and ignored.

        Returns:
            Tensor: ``loss_weight`` times the result of ``l1_loss`` called
                with ``self.reduction``.
        """
        unscaled = l1_loss(pred, target, weight, reduction=self.reduction)
        return self.loss_weight * unscaled
@LOSS_REGISTRY.register()
class MSELoss(nn.Module):
    """Module form of the MSE (L2) loss.

    Args:
        loss_weight (float): Scalar multiplied onto the computed MSE loss.
            Default: 1.0.
        reduction (str): Reduction applied to the output. One of
            'none' | 'mean' | 'sum'. Default: 'mean'.

    Raises:
        ValueError: If ``reduction`` is not a supported mode.
    """

    def __init__(self, loss_weight=1.0, reduction='mean'):
        super().__init__()
        # Reject unknown modes up front so misconfiguration fails at build time.
        if reduction not in _reduction_modes:
            raise ValueError(
                f'Unsupported reduction mode: {reduction}. Supported ones are: {_reduction_modes}'
            )
        self.reduction = reduction
        self.loss_weight = loss_weight

    def forward(self, pred, target, weight=None, **kwargs):
        """Return the weighted MSE loss scaled by ``loss_weight``.

        Args:
            pred (Tensor): of shape (N, C, H, W). Predicted tensor.
            target (Tensor): of shape (N, C, H, W). Ground truth tensor.
            weight (Tensor, optional): of shape (N, C, H, W). Element-wise
                weights. Default: None.
            **kwargs: Accepted for interface compatibility and ignored.

        Returns:
            Tensor: ``loss_weight`` times the result of ``mse_loss`` called
                with ``self.reduction``.
        """
        unscaled = mse_loss(pred, target, weight, reduction=self.reduction)
        return self.loss_weight * unscaled
@LOSS_REGISTRY.register()
class CrossEntropyLoss(nn.Module):
    """Cross-entropy loss.

    Module wrapper around the ``cross_entropy`` function of this file, which
    evaluates ``F.cross_entropy`` without reduction before the weighting and
    reduction requested here are applied.

    Args:
        loss_weight (float): Loss weight for cross-entropy loss. Default: 1.0.
        reduction (str): Specifies the reduction to apply to the output.
            Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'.

    Raises:
        ValueError: If ``reduction`` is not a supported mode.
    """

    def __init__(self, loss_weight=1.0, reduction='mean'):
        super().__init__()
        # Validate against the shared constant so the check and the message
        # below can never disagree.
        if reduction not in _reduction_modes:
            raise ValueError(
                f'Unsupported reduction mode: {reduction}. Supported ones are: {_reduction_modes}'
            )
        self.loss_weight = loss_weight
        self.reduction = reduction

    def forward(self, pred, target, weight=None, **kwargs):
        """Return the weighted cross-entropy loss scaled by ``loss_weight``.

        Args:
            pred (Tensor): Class scores handed to ``F.cross_entropy`` as its
                input, e.g. of shape (N, C) or (N, C, d1, ..., dK).
            target (Tensor): Ground truth accepted by ``F.cross_entropy``:
                class indices of shape (N) or (N, d1, ..., dK), or class
                probabilities with the same shape as ``pred``.
            weight (Tensor, optional): Element-wise weights forwarded to the
                decorated loss function. Default: None.
                NOTE(review): the required shape is decided by
                ``weighted_loss`` (not visible here); presumably it must be
                compatible with the unreduced loss rather than with
                ``pred`` -- confirm.
            **kwargs: Accepted for interface compatibility and ignored.

        Returns:
            Tensor: ``loss_weight`` times the result of ``cross_entropy``
                called with ``self.reduction``.
        """
        return self.loss_weight * cross_entropy(
            pred, target, weight, reduction=self.reduction
        )
@LOSS_REGISTRY.register()
class NLLLoss(nn.Module):
    """Negative log-likelihood (NLL) loss.

    Module wrapper around the ``nll_loss`` function of this file, which
    evaluates ``F.nll_loss`` without reduction before the weighting and
    reduction requested here are applied.

    Args:
        loss_weight (float): Loss weight for NLL loss. Default: 1.0.
        reduction (str): Specifies the reduction to apply to the output.
            Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'.

    Raises:
        ValueError: If ``reduction`` is not a supported mode.
    """

    def __init__(self, loss_weight=1.0, reduction='mean'):
        super().__init__()
        # Validate against the shared constant so the check and the message
        # below can never disagree.
        if reduction not in _reduction_modes:
            raise ValueError(
                f'Unsupported reduction mode: {reduction}. Supported ones are: {_reduction_modes}'
            )
        self.loss_weight = loss_weight
        self.reduction = reduction

    def forward(self, pred, target, weight=None, **kwargs):
        """Return the weighted NLL loss scaled by ``loss_weight``.

        Args:
            pred (Tensor): Log-probabilities handed to ``F.nll_loss`` as its
                input, e.g. of shape (N, C) or (N, C, d1, ..., dK).
            target (Tensor): Class indices accepted by ``F.nll_loss``, of
                shape (N) or (N, d1, ..., dK).
            weight (Tensor, optional): Element-wise weights forwarded to the
                decorated loss function. Default: None.
                NOTE(review): the required shape is decided by
                ``weighted_loss`` (not visible here); presumably it must be
                compatible with the unreduced loss rather than with
                ``pred`` -- confirm.
            **kwargs: Accepted for interface compatibility and ignored.

        Returns:
            Tensor: ``loss_weight`` times the result of ``nll_loss`` called
                with ``self.reduction``.
        """
        return self.loss_weight * nll_loss(
            pred, target, weight, reduction=self.reduction
        )
@LOSS_REGISTRY.register()
class CharbonnierLoss(nn.Module):
    """Charbonnier loss, a differentiable robust variant of the L1 loss.

    Described in "Deep Laplacian Pyramid Networks for Fast and Accurate
    Super-Resolution".

    Args:
        loss_weight (float): Scalar multiplied onto the computed loss.
            Default: 1.0.
        reduction (str): Reduction applied to the output. One of
            'none' | 'mean' | 'sum'. Default: 'mean'.
        eps (float): A value used to control the curvature near zero.
            Default: 1e-12.

    Raises:
        ValueError: If ``reduction`` is not a supported mode.
    """

    def __init__(self, loss_weight=1.0, reduction='mean', eps=1e-12):
        super().__init__()
        # Reject unknown modes up front so misconfiguration fails at build time.
        if reduction not in _reduction_modes:
            raise ValueError(
                f'Unsupported reduction mode: {reduction}. Supported ones are: {_reduction_modes}'
            )
        self.eps = eps
        self.reduction = reduction
        self.loss_weight = loss_weight

    def forward(self, pred, target, weight=None, **kwargs):
        """Return the weighted Charbonnier loss scaled by ``loss_weight``.

        Args:
            pred (Tensor): of shape (N, C, H, W). Predicted tensor.
            target (Tensor): of shape (N, C, H, W). Ground truth tensor.
            weight (Tensor, optional): of shape (N, C, H, W). Element-wise
                weights. Default: None.
            **kwargs: Accepted for interface compatibility and ignored.

        Returns:
            Tensor: ``loss_weight`` times the result of ``charbonnier_loss``
                called with ``self.eps`` and ``self.reduction``.
        """
        unscaled = charbonnier_loss(
            pred, target, weight, eps=self.eps, reduction=self.reduction
        )
        return self.loss_weight * unscaled
@LOSS_REGISTRY.register()
class WeightedTVLoss(L1Loss):
    """Weighted total-variation (TV) loss built on top of ``L1Loss``.

    The loss is the sum of two L1 terms: one between ``pred`` and itself
    shifted by one position along dim 2, and one shifted along dim 3.

    Args:
        loss_weight (float): Loss weight. Default: 1.0.
        reduction (str): Reduction applied to each L1 term. Only
            'mean' | 'sum' are accepted. Default: 'mean'.

    Raises:
        ValueError: If ``reduction`` is neither 'mean' nor 'sum'.
    """

    def __init__(self, loss_weight=1.0, reduction='mean'):
        # Stricter than the parent: 'none' is rejected before L1Loss is set up.
        if reduction not in ('mean', 'sum'):
            raise ValueError(
                f'Unsupported reduction mode: {reduction}. Supported ones are: mean | sum'
            )
        super().__init__(loss_weight=loss_weight, reduction=reduction)

    def forward(self, pred, weight=None):
        """Return the sum of the L1 differences along dims 2 and 3.

        Args:
            pred (Tensor): 4-D tensor (indexed as ``[:, :, i, j]``) whose
                neighbouring entries are compared.
            weight (Tensor, optional): 4-D weights; the last row / column is
                dropped so each slice lines up with the shifted difference.
                Default: None.

        Returns:
            Tensor: ``x_term + y_term``, each computed by ``L1Loss.forward``.
        """
        weighted = weight is not None
        # Drop the trailing row / column so the weights line up with the diffs.
        rows_weight = weight[:, :, :-1, :] if weighted else None
        cols_weight = weight[:, :, :, :-1] if weighted else None

        # Difference between each row and the next one (dim 2).
        vertical_term = super().forward(
            pred[:, :, :-1, :], pred[:, :, 1:, :], weight=rows_weight
        )
        # Difference between each column and the next one (dim 3).
        horizontal_term = super().forward(
            pred[:, :, :, :-1], pred[:, :, :, 1:], weight=cols_weight
        )
        return horizontal_term + vertical_term