【低照度图像增强系列(6)】IceNet算法详解与代码实现(IEEE)

前言 

☀️ 在低照度场景下进行目标检测任务,常存在图像RGB特征信息少、提取特征困难、目标识别和定位精度低等问题,给检测带来一定的难度。

     🌻使用图像增强模块对原始图像进行画质提升,恢复各类图像信息,再使用目标检测网络对增强图像进行特定目标检测,有效提高检测的精确度。

      ⭐本专栏会介绍传统方法、Retinex、EnlightenGAN、SCI、Zero-DCE、IceNet、RRDNet、URetinex-Net等低照度图像增强算法。

👑完整代码已打包上传至资源→低照度图像增强代码汇总

目录

前言 

🚀一、IceNet介绍  

☀️1.1 IceNet简介  

🚀二、IceNet网络结构及核心代码

☀️2.1 网络结构 

☀️2.2 核心代码 

🚀三、IceNet损失函数及核心代码

☀️3.1 Interactive brightness control loss—交互式亮度控制损失

☀️3.2  Entropy loss—熵损失

☀️3.3 Smoothness loss—平滑损失

☀️3.4 Total loss—总损失

🚀四、IceNet代码复现

☀️4.1 环境配置

☀️4.2 运行过程

☀️4.3 运行效果

🚀一、IceNet介绍  

相关资料: 

  • 论文题目:《IceNet for Interactive Contrast Enhancement》(用于交互式对比度增强的IceNet)
  • 原文地址:export.arxiv.org/pdf/2109.05838v2.pdf
  • 论文详解:IEEE| IceNet《IceNet for Interactive Contrast Enhancement》论文超详细解读(翻译+精读)

☀️1.1 IceNet简介  

本文提出了一种基于 CNN 的交互式对比度增强算法,称为 IceNet,该算法使用户能够根据自己的喜好轻松调整图像对比度。

具体来说,用户提供用于控制全局亮度的参数和两种类型的scribble来使图像中的局部区域变暗或变亮。然后,根据这些注释,IceNet 估计用于逐像素伽玛校正的伽玛图。最后,通过色彩恢复,得到增强后的图像。用户可以迭代地提供注释以获得满意的图像。

IceNet还能够自动生成个性化的增强图像,如果需要的话可以作为进一步调整的基础。

本文主要贡献

  • 本文提出了第一个基于CNN的交互式CE算法,称为IceNet,它可以根据用户的偏好,自适应地生成增强的图像,也可以自动生成无需交互的图像。
  • 本文用提出的三个可微损失函数训练IceNet,从而实现与用户的交互,并产生效果不错的增强图像。
  • 本文通过各种实验结果,证明IceNet可以为用户提供满意的结果,明显优于传统算法。

    🚀二、IceNet网络结构及核心代码

    ☀️2.1 网络结构 

    通过检查输入图像I,用户提供一个曝光等级η∈[0,1]来控制全局亮度和两种scribble类型(红色和蓝色的scribble分别表示用户想要使相应的局部区域变暗或变亮。步骤如下:

    • 首先,在scribble图S中分别记为−1和1,其余像素赋值为0。
    • 接着,将RGB彩色图像I转换到YCbCr空间,只调整亮度分量Y,同时保留色度分量。
    • 然后,估计一张伽马图Γ,用于y的像素级伽马校正。
    • 最后,通过颜色恢复,得到增强后的图像J。

      ☀️2.2 核心代码 

      import torch
      import torch.nn as nn
      import torch.nn.functional as F
      import math
      import numpy as np
      class IceNet(nn.Module):
      	def __init__(self):
      		super(IceNet, self).__init__()
      		self.relu = nn.ReLU(inplace=True)
      		# 7个卷积层用于提取特征
      		self.e_conv1 = nn.Conv2d(2,32,3,1,1,bias=True) 
      		self.e_conv2 = nn.Conv2d(32,32,3,1,1,bias=True) 
      		self.e_conv3 = nn.Conv2d(32,32,3,1,1,bias=True) 
      		self.e_conv4 = nn.Conv2d(32,32,3,1,1,bias=True) 
      		self.e_conv5 = nn.Conv2d(64,32,3,1,1,bias=True) 
      		self.e_conv6 = nn.Conv2d(64,32,3,1,1,bias=True) 
      		self.e_conv7 = nn.Conv2d(64,32,3,1,1,bias=True)
              # 两个全连接层 (fc1 和 fc2),用于生成自适应向量。
      		self.fc1 = nn.Linear(1, 32)
      		self.fc2 = nn.Linear(32, 32)
      	def forward(self, y, maps, e, lowlight=None, is_train=False):
      		b, _, h, w = y.shape
      		x_ = torch.cat([y, maps], 1) # y 和 maps 是输入的张量
      		# generate adaptive vector according to eta
      		W = self.relu(self.fc1(e)) # e是一个向量,用于生成自适应增强的参数。
      		W = self.fc2(W)
      		# feature extractor
      		x1 = self.relu(self.e_conv1(x_))
      		x2 = self.relu(self.e_conv2(x1))
      		x3 = self.relu(self.e_conv3(x2))
      		x4 = self.relu(self.e_conv4(x3))
      		x5 = self.relu(self.e_conv5(torch.cat([x3,x4],1)))
      		x6 = self.relu(self.e_conv6(torch.cat([x2,x5],1)))
      		x_r = self.relu(self.e_conv7(torch.cat([x1,x6],1)))
      		# 使用自适应增强方法 AGEB 对 x_r 进行增强。
      		x_r = F.conv2d(x_r.view(1, b * 32, h, w),
      				W.view(b, 32, 1, 1), groups=b)
      		x_r = torch.sigmoid(x_r).view(b, 1, h, w) * 10
      		# 进行 gamma 校正,得到增强后的图像 enhanced_Y。
      		enhanced_Y = torch.pow(y,x_r)
      		if is_train:
      			return enhanced_Y, x_r # 如果处于训练模式,返回增强后的图像和增强参数 x_r;
      		else:
      			# color restoration
      			enhanced_image = torch.clip(enhanced_Y*(lowlight/y), 0, 1)
      			return enhanced_image # 否则,对增强后的图像进行颜色还原,得到 enhanced_image,并返回。

       步骤如下:

      • 首先,对输入的 y 和 maps 进行concat,生成新的输入 x_。
      • 然后,通过全连接层计算自适应向量 W,并使用 ReLU 激活函数。
      • 接下来,通过多个卷积层提取特征,得到 x_r。再使用 AGEB 方法对 x_r 进行自适应增强。
      • 最后,对增强后的图像进行 gamma 校正,得到 enhanced_Y。
      • 如果是训练模式,则返回增强后的图像和增强参数 x_r;否则,对增强后的图像进行颜色还原,得到 enhanced_image,并返回。

        🚀三、IceNet损失函数及核心代码

        这篇文章的贡献也是,主要在损失函数上。

        ☀️3.1 Interactive brightness control loss—交互式亮度控制损失

        • 首先,在输入亮度Y上加上scribble的S
        • 接着,将Y归一化到[0,1]的范围,得到

          • 最后,使用双边伽玛调整方案来提高细节在暗和亮区域的可见性

            • 式(7):暗区增强
            • 式(8):亮区增强
            • 式(9):将这两种结果结合起来,同时保留暗区和亮区细节  

              ☀️3.2  Entropy loss—熵损失

              设计目的

              最大熵是通过均匀分布实现的,因此通过均衡输出图像的直方图,采用熵损失来增加全局对比度。

              直方图由于其不可微性,不能直接使用。所以本文设计了一个软直方图。

              图4(a)显示了σ = 5、10或20时的软映射函数。

              • σ控制映射函数的宽度和高度之间的权衡。随着σ的增大,映射函数变窄变高。
              • 在本文中,设定σ = 10。

                通过对所有像素的贡献求和,得到软直方图

                定义为熵的倒数


                ☀️3.3 Smoothness loss—平滑损失

                引入目的: 为了促进式(1)中伽马图Γ中相邻值之间的平滑变化


                ☀️3.4 Total loss—总损失

                总损失定义为三种损失的加权和

                • 首先,使IceNet能够控制全局和局部亮度。
                • 其次,促进一个平滑的直方图的形成,这可以增加整体对比度。
                • 第三,平滑伽马图。

                  图4(b) ~ (e)说明了每一次损失的效果

                   代码如下:

                  import torch
                  import torch.nn as nn
                  import torch.nn.functional as F
                  import math
                  from torchvision.models.vgg import vgg16
                  import numpy as np
                  from numpy.testing import assert_almost_equal
                  class L_ent(nn.Module):
                      def __init__(self, bins, min, max, sigma):
                          super(L_ent, self).__init__()
                          self.bins = bins # 表示直方图的分箱数量
                          self.min = min # 表示直方图的最小值
                          self.max = max # 表示直方图的最大值
                          self.sigma = sigma # 超参数
                          self.delta = float(max - min) / float(bins) # 计算直方图分箱的间隔大小
                          self.centers = float(min) + self.delta * (torch.arange(bins).float().cuda() + 0.5) # 计算直方图的中心
                      def forward(self, y):
                          b, _, h, w = y.shape # 获取输入张量 y 的形状信息
                          y = y.reshape(b, 1, -1) # -1 表示自动推断维度。
                          c = self.centers.reshape(1, -1, 1).repeat(b, 1, 1) # 计算直方图的中心c
                          x = y - c # 计算差值 x,即每个像素值与直方图中心的距离。
                          # (x + self.delta/2) 和 (x - self.delta/2) 分别对应直方图箱的右边界和左边界。
                          x = torch.sigmoid(self.sigma * (x + self.delta/2)) - torch.sigmoid(self.sigma * (x - self.delta/2))
                          hist = torch.sum(x, 2)
                          p = hist / (h * w) + 1e-6 # 计算直方图 hist,并归一化为概率 p。
                          d = torch.sum((-p * torch.log(p))) # 计算交叉熵损失 d。
                          return 1/d
                  class L_int(nn.Module):
                      def __init__(self):
                          super(L_int, self).__init__()
                      def forward(self, x, mean_val, labels):
                          b,c,h,w = x.shape
                          x = torch.mean(x,1,keepdim=True)
                          d = torch.mean(torch.pow(x- labels,2)) # 计算平均值与标签之间的均方误差 d
                          return d
                          
                  class L_smo(nn.Module):
                      def __init__(self):
                          super(L_smo,self).__init__()
                      def forward(self,x):
                          batch_size = x.size()[0]
                          h_x = x.size()[2]
                          w_x = x.size()[3]
                          count_h =  (x.size()[2]-1) * x.size()[3]
                          count_w = x.size()[2] * (x.size()[3] - 1)
                          h_tv = torch.pow((x[:,:,1:,:]-x[:,:,:h_x-1,:]),2).sum() # h_tv 表示水平方向上的总变化数
                          w_tv = torch.pow((x[:,:,:,1:]-x[:,:,:,:w_x-1]),2).sum() # w_tv 表示垂直方向上的总变化数。
                          return 2*(h_tv/count_h+w_tv/count_w)/batch_size # 计算总的平滑损失值,并除以批量大小,得到最终的损失值。
                  

                  🚀四、IceNet代码复现

                  ☀️4.1 环境配置

                  • Python 3.7
                  • Pytorch 1.0.0
                  • opencv
                  • torchvision 0.2.1
                  • cuda 10.0

                    ☀️4.2 运行过程

                    这个也是运行比较简单,配好环境就行 。不再过多叙述~


                    ☀️4.3 运行效果