r"""MUSIQ model.
Reference:
Ke, Junjie, Qifei Wang, Yilin Wang, Peyman Milanfar, and Feng Yang.
"Musiq: Multi-scale image quality transformer." In Proceedings of the
IEEE/CVF International Conference on Computer Vision (ICCV), pp. 5148-5157. 2021.
Ref url: https://github.com/google-research/google-research/tree/master/musiq
Re-implemented by: Chaofeng Chen (https://github.com/chaofengc)
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from .arch_util import dist_to_mos, load_pretrained_network
from pyiqa.matlab_utils import ExactPadding2d, exact_padding_2d
from pyiqa.utils.registry import ARCH_REGISTRY
from pyiqa.data.multiscale_trans_util import get_multiscale_patches
from pyiqa.archs.arch_util import get_url_from_name
# Default checkpoint locations, keyed by the dataset identifier that ``MUSIQ``
# accepts as its ``pretrained`` string. Each checkpoint file name is resolved
# through ``get_url_from_name`` (imported from ``pyiqa.archs.arch_util``).
default_model_urls = {
    dataset: get_url_from_name(ckpt_name)
    for dataset, ckpt_name in (
        ('ava', 'musiq_ava_ckpt-e8d3f067.pth'),
        ('koniq10k', 'musiq_koniq_ckpt-e95806b9.pth'),
        ('spaq', 'musiq_spaq_ckpt-358bb6af.pth'),
        ('paq2piq', 'musiq_paq2piq_ckpt-364c0c84.pth'),
        ('imagenet_pretrain', 'musiq_imagenet_pretrain-51d9b0a5.pth'),
    )
}
class StdConv(nn.Conv2d):
    """2-D convolution whose kernel is standardized on every forward pass.

    Reference: https://github.com/joe-siyuan-qiao/WeightStandardization
    """

    def forward(self, x):
        """Pad ``x``, standardize the kernel and convolve.

        The input is padded with ``exact_padding_2d`` in ``'same'`` mode using
        this layer's kernel size and stride. The kernel is then shifted to
        zero mean and divided by its standard deviation (plus ``1e-5``), both
        computed over dims 1-3, i.e. separately for every output filter.

        Note:
            Only ``self.bias`` and ``self.stride`` are forwarded to
            ``F.conv2d``; ``self.padding``, ``self.dilation`` and
            ``self.groups`` are not used here.

        Args:
            x (torch.Tensor): Input feature map.

        Returns:
            torch.Tensor: Result of the convolution with the standardized kernel.
        """
        padded = exact_padding_2d(x, self.kernel_size, self.stride, mode='same')

        reduce_dims = (1, 2, 3)
        centered = self.weight - self.weight.mean(reduce_dims, keepdim=True)
        # The deviation is measured on the already-centered kernel.
        standardized = centered / (centered.std(reduce_dims, keepdim=True) + 1e-5)

        return F.conv2d(padded, standardized, self.bias, self.stride)
class Bottleneck(nn.Module):
    """Residual bottleneck block built from ``StdConv`` and ``GroupNorm`` layers.

    Main path: 1x1 conv -> GroupNorm -> ReLU -> 3x3 conv -> GroupNorm -> ReLU
    -> 1x1 conv -> GroupNorm. The shortcut is the input itself, or a 1x1
    ``StdConv`` + ``GroupNorm`` projection when the channel count or the
    stride changes. A final ReLU is applied to the sum of both branches.

    Args:
        inplanes (int): Number of input channels; also used as the width of
            the two inner convolutions. Every ``GroupNorm`` here uses 32
            groups, so channel counts must be divisible by 32.
        outplanes (int): Number of output channels.
        stride (int): Spatial stride of the block. It is applied by the 3x3
            convolution on the main path and by the shortcut projection, so
            both branches produce the same spatial size. Default: 1.
    """

    def __init__(self, inplanes, outplanes, stride=1):
        super().__init__()

        width = inplanes

        self.conv1 = StdConv(inplanes, width, 1, 1, bias=False)
        self.gn1 = nn.GroupNorm(32, width, eps=1e-4)
        # The stride must be applied on the main path as well; otherwise the
        # strided shortcut and the main branch disagree in spatial size and
        # the residual addition fails for any stride != 1.
        self.conv2 = StdConv(width, width, 3, stride, bias=False)
        self.gn2 = nn.GroupNorm(32, width, eps=1e-4)
        self.conv3 = StdConv(width, outplanes, 1, 1, bias=False)
        self.gn3 = nn.GroupNorm(32, outplanes, eps=1e-4)

        self.relu = nn.ReLU(True)

        # A projection is only needed when the shortcut cannot be added as is.
        self.needs_projection = inplanes != outplanes or stride != 1
        if self.needs_projection:
            self.conv_proj = StdConv(inplanes, outplanes, 1, stride, bias=False)
            self.gn_proj = nn.GroupNorm(32, outplanes, eps=1e-4)

    def forward(self, x):
        """Apply the bottleneck to ``x`` and return the activated residual sum.

        Args:
            x (torch.Tensor): Input feature map with ``inplanes`` channels.

        Returns:
            torch.Tensor: Feature map with ``outplanes`` channels.
        """
        identity = x
        if self.needs_projection:
            identity = self.gn_proj(self.conv_proj(identity))

        x = self.relu(self.gn1(self.conv1(x)))
        x = self.relu(self.gn2(self.conv2(x)))
        x = self.gn3(self.conv3(x))
        out = self.relu(x + identity)

        return out
def drop_path(x, drop_prob: float = 0.0, training: bool = False):
    """Randomly zero whole samples of ``x`` (stochastic depth).

    One keep/drop decision is drawn per entry of the first dimension and
    broadcast over all remaining dimensions. Kept samples are divided by
    ``1 - drop_prob``.

    Args:
        x (torch.Tensor): Input tensor; the first dimension indexes samples.
        drop_prob (float or None): Probability of dropping a sample. ``None``
            and ``0`` disable dropping. Values ``>= 1`` drop every sample.
        training (bool): Dropping is only active when this is True.

    Returns:
        torch.Tensor: ``x`` itself when dropping is inactive, otherwise a new
        tensor of the same shape.
    """
    # ``None`` is treated like 0 so the default-constructed ``DropPath`` is a
    # no-op instead of failing on ``1 - None``.
    if drop_prob is None or drop_prob == 0.0 or not training:
        return x

    keep_prob = 1 - drop_prob
    if keep_prob <= 0.0:
        # Nothing survives; return zeros rather than dividing by zero, which
        # would turn the output into NaN.
        return torch.zeros_like(x)

    # One random draw per sample, broadcastable against any tensor rank.
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
    random_tensor.floor_()  # binarize: 1 with probability keep_prob, else 0
    output = x.div(keep_prob) * random_tensor
    return output
class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    Args:
        drop_prob (float or None): Probability of dropping a sample in
            training mode. ``None`` (the default) disables dropping.
    """

    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        """Return ``x`` with whole samples randomly dropped in training mode.

        Args:
            x (torch.Tensor): Input tensor; the first dimension indexes samples.

        Returns:
            torch.Tensor: ``x`` itself when the module is in eval mode or no
            drop probability is configured, otherwise the result of
            ``drop_path``.
        """
        # With the default ``drop_prob=None`` there is nothing to drop; passing
        # ``None`` on would fail when the keep probability is computed.
        if self.drop_prob is None or not self.training:
            return x
        return drop_path(x, self.drop_prob, self.training)
class Mlp(nn.Module):
    """Two-layer feed-forward network: Linear -> activation -> Dropout -> Linear -> Dropout.

    Args:
        in_features (int): Size of the last input dimension.
        hidden_features (int, optional): Width of the hidden layer. Falls back
            to ``in_features`` when falsy.
        out_features (int, optional): Size of the last output dimension. Falls
            back to ``in_features`` when falsy.
        act_layer (callable): Zero-argument factory for the activation module.
            Default: ``nn.GELU``.
        drop (float): Dropout probability. The same ``nn.Dropout`` module is
            applied after the activation and after the second linear layer.
    """

    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        act_layer=nn.GELU,
        drop=0.0,
    ):
        super().__init__()
        hidden_dim = hidden_features or in_features
        output_dim = out_features or in_features

        self.fc1 = nn.Linear(in_features, hidden_dim)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_dim, output_dim)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        """Run ``x`` through both linear layers and return the result."""
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))
class MultiHeadAttention(nn.Module):
    """Multi-head self-attention with separate query/key/value projections.

    Args:
        dim (int): Embedding size of the input tokens; must be divisible by
            ``num_heads``.
        num_heads (int): Number of attention heads. Default: 6.
        bias (bool): Whether the query/key/value projections use a bias. The
            output projection always has one. Default: False.
        attn_drop (float): Dropout probability on the attention weights.
        out_drop (float): Dropout probability on the projected output.
    """

    def __init__(self, dim, num_heads=6, bias=False, attn_drop=0.0, out_drop=0.0):
        super().__init__()
        assert dim % num_heads == 0, 'dim should be divisible by num_heads'

        self.num_heads = num_heads
        # Scores are scaled by 1 / sqrt(per-head dimension).
        self.scale = (dim // num_heads) ** -0.5

        self.query = nn.Linear(dim, dim, bias=bias)
        self.key = nn.Linear(dim, dim, bias=bias)
        self.value = nn.Linear(dim, dim, bias=bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.out = nn.Linear(dim, dim)
        self.out_drop = nn.Dropout(out_drop)

    def forward(self, x, mask=None):
        """Attend over the token axis of ``x``.

        Args:
            x (torch.Tensor): Tokens of shape ``(B, N, C)``.
            mask (torch.Tensor, optional): Per-token validity mask holding
                ``B * N`` entries, zero/False marking invalid tokens. A
                query/key pair is kept only when both tokens are valid; all
                other scores are replaced by ``-1e3`` before the softmax.

        Returns:
            torch.Tensor: Tensor of shape ``(B, N, C)``.
        """
        B, N, C = x.shape
        heads = self.num_heads
        head_dim = C // heads

        def split_heads(t):
            # (B, N, C) -> (B, heads, N, head_dim)
            return t.reshape(B, N, heads, head_dim).permute(0, 2, 1, 3)

        q = split_heads(self.query(x))
        k = split_heads(self.key(x))
        v = split_heads(self.value(x))

        scores = (q @ k.transpose(-2, -1)) * self.scale

        if mask is not None:
            # Outer product of the token mask with itself gives the pair mask.
            pair_mask = mask.reshape(B, 1, N, 1) * mask.reshape(B, 1, 1, N)
            scores = scores.masked_fill(pair_mask == 0, -1e3)

        weights = self.attn_drop(scores.softmax(dim=-1))

        # (B, heads, N, head_dim) -> (B, N, C)
        merged = (weights @ v).transpose(1, 2).reshape(B, N, C)
        return self.out_drop(self.out(merged))
class AddHashSpatialPositionEmbs(nn.Module):
    """Adds learnable hash-based spatial embeddings to the inputs.

    Args:
        spatial_pos_grid_size (int): Side length of the position grid; the
            embedding table holds ``spatial_pos_grid_size ** 2`` rows.
        dim (int): Embedding size of every row.
    """

    def __init__(self, spatial_pos_grid_size, dim):
        super().__init__()
        num_positions = spatial_pos_grid_size * spatial_pos_grid_size
        table = torch.randn(1, num_positions, dim)
        self.position_emb = nn.parameter.Parameter(table)
        # Re-initialize the table with a small standard deviation.
        nn.init.normal_(self.position_emb, std=0.02)

    def forward(self, inputs, inputs_positions):
        """Add the embedding row selected by each position index.

        Args:
            inputs (torch.Tensor): Token features whose last dimension is ``dim``.
            inputs_positions (torch.Tensor): Position index of every token; it
                is cast to ``long`` before the lookup.

        Returns:
            torch.Tensor: ``inputs`` plus the looked-up embeddings.
        """
        row_index = inputs_positions.long()
        # Drop the leading singleton dimension, then gather one row per token.
        return inputs + self.position_emb[0][row_index]
class AddScaleEmbs(nn.Module):
    """Adds learnable scale embeddings to the inputs.

    Args:
        num_scales (int): Number of rows in the embedding table.
        dim (int): Embedding size of every row.
    """

    def __init__(self, num_scales, dim):
        super().__init__()
        table = torch.randn(num_scales, dim)
        self.scale_emb = nn.parameter.Parameter(table)
        # Re-initialize the table with a small standard deviation.
        nn.init.normal_(self.scale_emb, std=0.02)

    def forward(self, inputs, inputs_scale_positions):
        """Add the embedding row selected by each scale index.

        Args:
            inputs (torch.Tensor): Token features whose last dimension is ``dim``.
            inputs_scale_positions (torch.Tensor): Scale index of every token;
                it is cast to ``long`` before the lookup.

        Returns:
            torch.Tensor: ``inputs`` plus the looked-up embeddings.
        """
        scale_index = inputs_scale_positions.long()
        return inputs + self.scale_emb[scale_index]
@ARCH_REGISTRY.register()
class MUSIQ(nn.Module):
    """
    MUSIQ model architecture.

    Args:
        - patch_size (int): Size of the patches to extract from the images.
        - num_class (int): Number of classes to predict. Overridden to 10 when the default AVA checkpoint is selected.
        - hidden_size (int): Size of the hidden layer in the transformer encoder.
        - mlp_dim (int): Size of the feedforward layer in the transformer encoder.
        - attention_dropout_rate (float): Dropout rate for the attention layer in the transformer encoder.
        - dropout_rate (float): Dropout rate for the transformer encoder.
        - num_heads (int): Number of attention heads in the transformer encoder.
        - num_layers (int): Number of layers in the transformer encoder.
        - num_scales (int): Number of scales to use in the transformer encoder.
        - spatial_pos_grid_size (int): Size of the spatial position grid in the transformer encoder.
        - use_scale_emb (bool): Whether to use scale embeddings in the transformer encoder.
        - use_sinusoid_pos_emb (bool): Whether to use sinusoidal position embeddings in the transformer encoder.
        - pretrained (bool or str): Whether to use a pretrained model. ``True`` selects the ``'ava'`` checkpoint; a str is
          used as a key into ``default_model_urls`` (e.g. ``'koniq10k'``). Ignored when ``pretrained_model_path`` is given.
        - pretrained_model_path (str): Path to the pretrained model. Takes precedence over ``pretrained``.
        - longer_side_lengths (list): List of longer side lengths to use for multiscale evaluation.
        - max_seq_len_from_original_res (int): Maximum sequence length to use for multiscale evaluation.

    Attributes:
        - conv_root (StdConv): Convolutional layer for the root of the network.
        - gn_root (nn.GroupNorm): Group normalization layer for the root of the network.
        - root_pool (nn.Sequential): ReLU, exact 'same' padding and max pooling for the root of the network.
        - block1 (Bottleneck): First bottleneck block in the network.
        - embedding (nn.Linear): Linear layer for the transformer encoder input.
        - transformer_encoder (TransformerEncoder): Transformer encoder.
        - head (nn.Sequential or nn.Linear): Output layer of the network.

    Methods:
        forward(x, return_mos=True, return_dist=False): Forward pass of the network.
    """

    def __init__(
        self,
        patch_size=32,
        num_class=1,
        hidden_size=384,
        mlp_dim=1152,
        attention_dropout_rate=0.0,
        dropout_rate=0,
        num_heads=6,
        num_layers=14,
        num_scales=3,
        spatial_pos_grid_size=10,
        use_scale_emb=True,
        use_sinusoid_pos_emb=False,
        pretrained=True,
        pretrained_model_path=None,
        # data opts
        longer_side_lengths=[224, 384],
        max_seq_len_from_original_res=-1,
    ):
        super(MUSIQ, self).__init__()

        resnet_token_dim = 64
        self.patch_size = patch_size

        # Keyword arguments forwarded to ``get_multiscale_patches`` in ``forward``.
        # NOTE: the default ``longer_side_lengths`` list object is shared by all
        # instances; this class only reads it.
        self.data_preprocess_opts = {
            'patch_size': patch_size,
            'patch_stride': patch_size,
            'hse_grid_size': spatial_pos_grid_size,
            'longer_side_lengths': longer_side_lengths,
            'max_seq_len_from_original_res': max_seq_len_from_original_res,
        }

        # set num_class to 10 if pretrained model used AVA dataset
        # if not specified pretrained dataset, use AVA for default
        if pretrained_model_path is None and pretrained:
            url_key = 'ava' if isinstance(pretrained, bool) else pretrained
            num_class = 10 if url_key == 'ava' else num_class
            pretrained_model_path = default_model_urls[url_key]

        self.conv_root = StdConv(3, resnet_token_dim, 7, 2, bias=False)
        self.gn_root = nn.GroupNorm(32, resnet_token_dim, eps=1e-6)
        self.root_pool = nn.Sequential(
            nn.ReLU(True),
            ExactPadding2d(3, 2, mode='same'),
            nn.MaxPool2d(3, 2),
        )

        # The root conv and the max pool each use stride 2, so a patch is
        # reduced by a factor of 4 per side before it is flattened into a token.
        token_patch_size = patch_size // 4
        self.block1 = Bottleneck(resnet_token_dim, resnet_token_dim * 4)

        self.embedding = nn.Linear(
            resnet_token_dim * 4 * token_patch_size**2, hidden_size
        )
        self.transformer_encoder = TransformerEncoder(
            hidden_size,
            mlp_dim,
            attention_dropout_rate,
            dropout_rate,
            num_heads,
            num_layers,
            num_scales,
            spatial_pos_grid_size,
            use_scale_emb,
            use_sinusoid_pos_emb,
        )

        # Multi-class heads output a softmax distribution, otherwise a single value.
        if num_class > 1:
            self.head = nn.Sequential(
                nn.Linear(hidden_size, num_class),
                nn.Softmax(dim=-1),
            )
        else:
            self.head = nn.Linear(hidden_size, num_class)

        if pretrained_model_path is not None:
            load_pretrained_network(self, pretrained_model_path, True)

    def forward(self, x, return_mos=True, return_dist=False):
        """
        Forward pass of the MUSIQ network.

        Args:
            x (torch.Tensor): Input tensor. In eval mode it is mapped with ``(x - 0.5) * 2``
                (i.e. values in [0, 1] become [-1, 1]); in training mode it is used as given.
            return_mos (bool): Whether to return the mean opinion score (MOS).
            return_dist (bool): Whether to return the predicted distribution.

        Returns:
            torch.Tensor or list: If only one of return_mos and return_dist is True, returns a tensor.
            If both are True, returns the list ``[mos, dist]``. At least one of the two flags must be
            True; with both False an ``IndexError`` is raised.
        """
        # normalize inputs to [-1, 1] as the official code
        if not self.training:
            x = (x - 0.5) * 2

        x = get_multiscale_patches(x, **self.data_preprocess_opts)

        assert len(x.shape) in [3, 4]
        if len(x.shape) == 4:
            # Multiple crops per image: fold the crop axis into the batch axis.
            b, num_crops, seq_len, dim = x.shape
            x = x.reshape(b * num_crops, seq_len, dim)
        else:
            b, seq_len, dim = x.shape
            num_crops = 1

        # The last three features of every token carry its spatial position
        # index, scale index and validity mask; the rest are the patch pixels.
        inputs_spatial_positions = x[:, :, -3]
        inputs_scale_positions = x[:, :, -2]
        inputs_masks = x[:, :, -1].bool()
        x = x[:, :, :-3]

        x = x.reshape(-1, 3, self.patch_size, self.patch_size)
        x = self.conv_root(x)
        x = self.gn_root(x)
        x = self.root_pool(x)
        x = self.block1(x)
        # to match tensorflow channel order
        x = x.permute(0, 2, 3, 1)
        # Regroup the patch features into sequences. Crops were folded into
        # the batch axis above, so there are b * num_crops sequences; using
        # b alone would mix the crops of one image into a single token.
        x = x.reshape(b * num_crops, seq_len, -1)
        x = self.embedding(x)
        x = self.transformer_encoder(
            x, inputs_spatial_positions, inputs_scale_positions, inputs_masks
        )
        # Predict from the first token of every sequence.
        q = self.head(x[:, 0])

        q = q.reshape(b, num_crops, -1)
        q = q.mean(dim=1)  # for multiple crops evaluation
        mos = dist_to_mos(q)

        return_list = []
        if return_mos:
            return_list.append(mos)
        if return_dist:
            return_list.append(q)

        if len(return_list) > 1:
            return return_list
        else:
            return return_list[0]