pyiqa.archs.msswd_arch
======================

.. py:module:: pyiqa.archs.msswd_arch

.. autoapi-nested-parse::

   Perceptual color difference metric, MS-SWD.

   ::

      @inproceedings{he2024ms-swd,
        title={Multiscale Sliced {Wasserstein} Distances as Perceptual Color Difference Measures},
        author={He, Jiaqi and Wang, Zhihua and Wang, Leon and Liu, Tsein-I and Fang, Yuming and Sun, Qilin and Ma, Kede},
        booktitle={European Conference on Computer Vision},
        pages={1--18},
        year={2024},
        url={http://arxiv.org/abs/2407.10181}
      }

   Reference:

   - Official GitHub: https://github.com/real-hjq/MS-SWD


Module Contents
---------------

.. py:function:: color_space_transform(input_color, fromSpace2toSpace)

   Transform color tensors between supported color spaces.

   :param input_color: Color tensor with shape ``(N, C, H, W)``.
   :type input_color: torch.Tensor
   :param fromSpace2toSpace: Conversion key, for example ``'srgb2lab'``.
   :type fromSpace2toSpace: str
   :returns: Transformed tensor with shape ``(N, C, H, W)``.
   :rtype: torch.Tensor
   :raises ValueError: If the conversion key is not defined.


.. py:class:: MS_SWD_learned(resize_input: bool = True, pretrained: bool = True, pretrained_model_path: str = None, **kwargs)

   Bases: :py:obj:`torch.nn.Module`

   MS-SWD perceptual color difference metric.

   :param resize_input: Whether to resize inputs whose short side is larger than ``256`` before scoring.
   :type resize_input: bool
   :param pretrained: Whether to load pretrained weights.
   :type pretrained: bool
   :param pretrained_model_path: Optional local checkpoint path.
   :type pretrained_model_path: str | None
   :param \*\*kwargs: Reserved compatibility arguments.


   .. py:method:: preprocess_img(x)

      Optionally resize an image batch before feature extraction.


   .. py:method:: forward_once(x)

      Encode one image batch into a sorted SWD feature representation.


   .. py:method:: forward(x, y)

      Compute the MS-SWD distance.

      :param x: Distorted image tensor with shape ``(N, 3, H, W)``.
      :type x: torch.Tensor
      :param y: Reference image tensor with shape ``(N, 3, H, W)``.
      :type y: torch.Tensor
      :returns: Distance scores with shape ``(N,)``.
      :rtype: torch.Tensor
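
In standard colorimetry, an sRGB-to-CIELAB conversion such as the ``'srgb2lab'`` key mentioned for ``color_space_transform`` above is a chain of three steps: undo the sRGB transfer curve, map linear RGB to CIE XYZ, then apply the CIELAB nonlinearity. The NumPy sketch below applies that chain to ``(N, 3, H, W)`` arrays. It assumes sRGB values in ``[0, 1]`` and a D65 reference white; ``color_space_transform_sketch`` and the intermediate keys ``'srgb2linrgb'``, ``'linrgb2xyz'`` and ``'xyz2lab'`` are hypothetical and may not match the keys or constants ``pyiqa`` actually uses.

.. code-block:: python

   import numpy as np

   # Hypothetical sketch (not pyiqa's code): sRGB -> CIELAB for arrays shaped
   # (N, 3, H, W), assuming sRGB values in [0, 1] and a D65 reference white.
   SRGB_TO_XYZ = np.array([
       [0.4124564, 0.3575761, 0.1804375],
       [0.2126729, 0.7151522, 0.0721750],
       [0.0193339, 0.1191920, 0.9503041],
   ])
   D65_WHITE = np.array([0.95047, 1.0, 1.08883])


   def srgb_to_linear(x):
       # Clip to the assumed [0, 1] range, then undo the sRGB transfer curve.
       x = np.clip(x, 0.0, 1.0)
       return np.where(x <= 0.04045, x / 12.92, ((x + 0.055) / 1.055) ** 2.4)


   def linear_to_xyz(x):
       # Apply the 3x3 matrix along the channel axis (axis 1).
       return np.einsum("ij,njhw->nihw", SRGB_TO_XYZ, x)


   def xyz_to_lab(x):
       # Normalize by the reference white, then apply the CIELAB nonlinearity.
       t = x / D65_WHITE.reshape(1, 3, 1, 1)
       delta = 6.0 / 29.0
       f = np.where(t > delta ** 3, np.cbrt(t), t / (3 * delta ** 2) + 4.0 / 29.0)
       L = 116.0 * f[:, 1] - 16.0
       a = 500.0 * (f[:, 0] - f[:, 1])
       b = 200.0 * (f[:, 1] - f[:, 2])
       return np.stack([L, a, b], axis=1)


   # Hypothetical conversion keys; pyiqa's supported keys may differ.
   CONVERSIONS = {
       "srgb2linrgb": srgb_to_linear,
       "linrgb2xyz": linear_to_xyz,
       "xyz2lab": xyz_to_lab,
       "srgb2lab": lambda x: xyz_to_lab(linear_to_xyz(srgb_to_linear(x))),
   }


   def color_space_transform_sketch(input_color, from_space_to_space):
       # Look up the conversion; unknown keys raise ValueError.
       try:
           convert = CONVERSIONS[from_space_to_space]
       except KeyError:
           raise ValueError(f"Undefined conversion: {from_space_to_space!r}") from None
       return convert(input_color)


   white = np.ones((1, 3, 2, 2))
   print(color_space_transform_sketch(white, "srgb2lab")[0, :, 0, 0])  # approx. [100, 0, 0]

Unknown keys raise ``ValueError``, mirroring the documented behavior of ``color_space_transform``; the output keeps the ``(N, C, H, W)`` layout of the input.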
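
The ``resize_input`` option above only states that inputs whose short side exceeds ``256`` are resized. One plausible rule, assumed here, is an aspect-ratio-preserving resize that maps the short side to ``256``, as torchvision's ``resize`` does when given a single integer size. The hypothetical helper ``resized_shape_sketch`` below computes only the target ``(H, W)``; it does not resample pixels and is not taken from ``preprocess_img``.

.. code-block:: python

   def resized_shape_sketch(height, width, short_side=256):
       # Hypothetical helper: images whose short side already fits are unchanged.
       short, long = min(height, width), max(height, width)
       if short <= short_side:
           return height, width
       # Map the short side to ``short_side`` and scale the long side to match.
       new_long = int(short_side * long / short)
       return (short_side, new_long) if height <= width else (new_long, short_side)


   print(resized_shape_sketch(512, 768))  # (256, 384)
   print(resized_shape_sketch(200, 300))  # (200, 300)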
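
``forward_once`` and ``forward`` above describe encoding each image batch into sorted projections and comparing those between the distorted and reference images. As a rough illustration of that idea only, the NumPy sketch below computes a multiscale sliced Wasserstein distance between per-image pixel color distributions. It is not the ``pyiqa`` implementation: random unit projections stand in for MS-SWD's learned ones, 2x average pooling is an assumed way to build the scales, and all function names are hypothetical.

.. code-block:: python

   import numpy as np

   # Hypothetical sketch of a multiscale sliced Wasserstein distance between
   # per-image pixel color distributions. Random unit projections stand in for
   # MS-SWD's learned ones; this is NOT pyiqa's implementation.


   def forward_once_sketch(x, directions):
       # x: (N, 3, H, W); directions: (K, 3) unit vectors.
       n = x.shape[0]
       pixels = x.reshape(n, 3, -1)                          # (N, 3, H*W)
       proj = np.einsum("kc,ncp->nkp", directions, pixels)   # (N, K, H*W)
       return np.sort(proj, axis=-1)                         # sorted 1-D projections


   def swd_sketch(x, y, num_projections=64, seed=0):
       # Same seed -> same directions for x and y, so the distance is symmetric.
       rng = np.random.default_rng(seed)
       d = rng.normal(size=(num_projections, 3))
       d /= np.linalg.norm(d, axis=1, keepdims=True)
       fx, fy = forward_once_sketch(x, d), forward_once_sketch(y, d)
       # For equal-size samples, 1-D Wasserstein-1 is the mean absolute gap
       # between sorted values; averaging over directions gives sliced W1.
       return np.abs(fx - fy).mean(axis=(1, 2))             # (N,)


   def avg_pool2(x):
       # 2x2 average pooling; drops a trailing odd row/column.
       n, c, h, w = x.shape
       h2, w2 = h // 2, w // 2
       x = x[:, :, : 2 * h2, : 2 * w2]
       return x.reshape(n, c, h2, 2, w2, 2).mean(axis=(3, 5))


   def ms_swd_sketch(x, y, num_scales=3, **swd_kwargs):
       # Average single-scale distances over an assumed 2x pooling pyramid.
       scores = []
       for _ in range(num_scales):
           scores.append(swd_sketch(x, y, **swd_kwargs))
           if min(x.shape[2:]) < 2:
               break
           x, y = avg_pool2(x), avg_pool2(y)
       return np.mean(scores, axis=0)                        # (N,)


   rng = np.random.default_rng(0)
   ref = rng.random((2, 3, 32, 32))
   dist = np.clip(ref + 0.05, 0.0, 1.0)
   print(ms_swd_sketch(dist, ref))  # two non-negative scores, shape (2,)

Because only sorted projections are compared, a single-scale distance ignores where colors occur in the image; pooling mixes neighboring colors, so the coarser scales in this sketch also respond to local color arrangement.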