前言
☀️ 在低照度场景下进行目标检测任务,常存在图像RGB特征信息少、提取特征困难、目标识别和定位精度低等问题,给检测带来一定的难度。
🌻使用图像增强模块对原始图像进行画质提升,恢复各类图像信息,再使用目标检测网络对增强图像进行特定目标检测,有效提高检测的精确度。
⭐本专栏会介绍传统方法、Retinex、EnlightenGAN、SCI、Zero-DCE、IceNet、RRDNet、URetinex-Net等低照度图像增强算法。
👑完整代码已打包上传至资源→低照度图像增强代码汇总
目录
前言
🚀一、IceNet介绍
☀️1.1 IceNet简介
🚀二、IceNet网络结构及核心代码
☀️2.1 网络结构
☀️2.2 核心代码
🚀三、IceNet损失函数及核心代码
☀️3.1 Interactive brightness control loss—交互式亮度控制损失
☀️3.2 Entropy loss—熵损失
☀️3.3 Smoothness loss—平滑损失
☀️3.4 Total loss—总损失
🚀四、IceNet代码复现
☀️4.1 环境配置
☀️4.2 运行过程
☀️4.3 运行效果
🚀一、IceNet介绍
相关资料:
- 论文题目:《IceNet for Interactive Contrast Enhancement》(用于交互式对比度增强的IceNet)
- 原文地址:export.arxiv.org/pdf/2109.05838v2.pdf
- 论文详解:IEEE| IceNet《IceNet for Interactive Contrast Enhancement》论文超详细解读(翻译+精读)
☀️1.1 IceNet简介
本文提出了一种基于 CNN 的交互式对比度增强算法,称为 IceNet,该算法使用户能够根据自己的喜好轻松调整图像对比度。
具体来说,用户提供用于控制全局亮度的参数和两种类型的scribble来使图像中的局部区域变暗或变亮。然后,根据这些注释,IceNet 估计用于逐像素伽玛校正的伽玛图。最后,通过色彩恢复,得到增强后的图像。用户可以迭代地提供注释以获得满意的图像。
IceNet还能够自动生成个性化的增强图像,如果需要的话可以作为进一步调整的基础。
本文主要贡献
- 本文提出了第一个基于CNN的交互式CE算法,称为IceNet,它可以根据用户的偏好,自适应地生成增强的图像,也可以自动生成无需交互的图像。
- 本文用提出的三个可微损失函数训练IceNet,从而实现与用户的交互,并产生效果不错的增强图像。
- 本文通过各种实验结果,证明IceNet可以为用户提供满意的结果,明显优于传统算法。
🚀二、IceNet网络结构及核心代码
☀️2.1 网络结构
通过检查输入图像I,用户提供一个曝光等级η∈[0,1]来控制全局亮度和两种scribble类型(红色和蓝色的scribble分别表示用户想要使相应的局部区域变暗或变亮。步骤如下:
- 首先,在scribble图S中分别记为−1和1,其余像素赋值为0。
- 接着,将RGB彩色图像I转换到YCbCr空间,只调整亮度分量Y,同时保留色度分量。
- 然后,估计一张伽马图Γ,用于y的像素级伽马校正。
- 最后,通过颜色恢复,得到增强后的图像J。
☀️2.2 核心代码
import torch import torch.nn as nn import torch.nn.functional as F import math import numpy as np class IceNet(nn.Module): def __init__(self): super(IceNet, self).__init__() self.relu = nn.ReLU(inplace=True) # 7个卷积层用于提取特征 self.e_conv1 = nn.Conv2d(2,32,3,1,1,bias=True) self.e_conv2 = nn.Conv2d(32,32,3,1,1,bias=True) self.e_conv3 = nn.Conv2d(32,32,3,1,1,bias=True) self.e_conv4 = nn.Conv2d(32,32,3,1,1,bias=True) self.e_conv5 = nn.Conv2d(64,32,3,1,1,bias=True) self.e_conv6 = nn.Conv2d(64,32,3,1,1,bias=True) self.e_conv7 = nn.Conv2d(64,32,3,1,1,bias=True) # 两个全连接层 (fc1 和 fc2),用于生成自适应向量。 self.fc1 = nn.Linear(1, 32) self.fc2 = nn.Linear(32, 32) def forward(self, y, maps, e, lowlight=None, is_train=False): b, _, h, w = y.shape x_ = torch.cat([y, maps], 1) # y 和 maps 是输入的张量 # generate adaptive vector according to eta W = self.relu(self.fc1(e)) # e是一个向量,用于生成自适应增强的参数。 W = self.fc2(W) # feature extractor x1 = self.relu(self.e_conv1(x_)) x2 = self.relu(self.e_conv2(x1)) x3 = self.relu(self.e_conv3(x2)) x4 = self.relu(self.e_conv4(x3)) x5 = self.relu(self.e_conv5(torch.cat([x3,x4],1))) x6 = self.relu(self.e_conv6(torch.cat([x2,x5],1))) x_r = self.relu(self.e_conv7(torch.cat([x1,x6],1))) # 使用自适应增强方法 AGEB 对 x_r 进行增强。 x_r = F.conv2d(x_r.view(1, b * 32, h, w), W.view(b, 32, 1, 1), groups=b) x_r = torch.sigmoid(x_r).view(b, 1, h, w) * 10 # 进行 gamma 校正,得到增强后的图像 enhanced_Y。 enhanced_Y = torch.pow(y,x_r) if is_train: return enhanced_Y, x_r # 如果处于训练模式,返回增强后的图像和增强参数 x_r; else: # color restoration enhanced_image = torch.clip(enhanced_Y*(lowlight/y), 0, 1) return enhanced_image # 否则,对增强后的图像进行颜色还原,得到 enhanced_image,并返回。
步骤如下:
- 首先,对输入的 y 和 maps 进行concat,生成新的输入 x_。
- 然后,通过全连接层计算自适应向量 W,并使用 ReLU 激活函数。
- 接下来,通过多个卷积层提取特征,得到 x_r。再使用 AGEB 方法对 x_r 进行自适应增强。
- 最后,对增强后的图像进行 gamma 校正,得到 enhanced_Y。
- 如果是训练模式,则返回增强后的图像和增强参数 x_r;否则,对增强后的图像进行颜色还原,得到 enhanced_image,并返回。
🚀三、IceNet损失函数及核心代码
这篇文章的贡献也是,主要在损失函数上。
☀️3.1 Interactive brightness control loss—交互式亮度控制损失
- 首先,在输入亮度Y上加上scribble的S
- 接着,将Y归一化到[0,1]的范围,得到
- 最后,使用双边伽玛调整方案来提高细节在暗和亮区域的可见性
- 式(7):暗区增强
- 式(8):亮区增强
- 式(9):将这两种结果结合起来,同时保留暗区和亮区细节
☀️3.2 Entropy loss—熵损失
设计目的
最大熵是通过均匀分布实现的,因此通过均衡输出图像的直方图,采用熵损失来增加全局对比度。
直方图由于其不可微性,不能直接使用。所以本文设计了一个软直方图。
图4(a)显示了σ = 5、10或20时的软映射函数。
- σ控制映射函数的宽度和高度之间的权衡。随着σ的增大,映射函数变窄变高。
- 在本文中,设定σ = 10。
通过对所有像素的贡献求和,得到软直方图
定义为熵的倒数
☀️3.3 Smoothness loss—平滑损失
引入目的: 为了促进式(1)中伽马图Γ中相邻值之间的平滑变化
☀️3.4 Total loss—总损失
总损失定义为三种损失的加权和
- 首先,使IceNet能够控制全局和局部亮度。
- 其次,促进一个平滑的直方图的形成,这可以增加整体对比度。
- 第三,平滑伽马图。
图4(b) ~ (e)说明了每一次损失的效果
代码如下:
import torch import torch.nn as nn import torch.nn.functional as F import math from torchvision.models.vgg import vgg16 import numpy as np from numpy.testing import assert_almost_equal class L_ent(nn.Module): def __init__(self, bins, min, max, sigma): super(L_ent, self).__init__() self.bins = bins # 表示直方图的分箱数量 self.min = min # 表示直方图的最小值 self.max = max # 表示直方图的最大值 self.sigma = sigma # 超参数 self.delta = float(max - min) / float(bins) # 计算直方图分箱的间隔大小 self.centers = float(min) + self.delta * (torch.arange(bins).float().cuda() + 0.5) # 计算直方图的中心 def forward(self, y): b, _, h, w = y.shape # 获取输入张量 y 的形状信息 y = y.reshape(b, 1, -1) # -1 表示自动推断维度。 c = self.centers.reshape(1, -1, 1).repeat(b, 1, 1) # 计算直方图的中心c x = y - c # 计算差值 x,即每个像素值与直方图中心的距离。 # (x + self.delta/2) 和 (x - self.delta/2) 分别对应直方图箱的右边界和左边界。 x = torch.sigmoid(self.sigma * (x + self.delta/2)) - torch.sigmoid(self.sigma * (x - self.delta/2)) hist = torch.sum(x, 2) p = hist / (h * w) + 1e-6 # 计算直方图 hist,并归一化为概率 p。 d = torch.sum((-p * torch.log(p))) # 计算交叉熵损失 d。 return 1/d class L_int(nn.Module): def __init__(self): super(L_int, self).__init__() def forward(self, x, mean_val, labels): b,c,h,w = x.shape x = torch.mean(x,1,keepdim=True) d = torch.mean(torch.pow(x- labels,2)) # 计算平均值与标签之间的均方误差 d return d class L_smo(nn.Module): def __init__(self): super(L_smo,self).__init__() def forward(self,x): batch_size = x.size()[0] h_x = x.size()[2] w_x = x.size()[3] count_h = (x.size()[2]-1) * x.size()[3] count_w = x.size()[2] * (x.size()[3] - 1) h_tv = torch.pow((x[:,:,1:,:]-x[:,:,:h_x-1,:]),2).sum() # h_tv 表示水平方向上的总变化数 w_tv = torch.pow((x[:,:,:,1:]-x[:,:,:,:w_x-1]),2).sum() # w_tv 表示垂直方向上的总变化数。 return 2*(h_tv/count_h+w_tv/count_w)/batch_size # 计算总的平滑损失值,并除以批量大小,得到最终的损失值。
🚀四、IceNet代码复现
☀️4.1 环境配置
- Python 3.7
- Pytorch 1.0.0
- opencv
- torchvision 0.2.1
- cuda 10.0
☀️4.2 运行过程
这个也是运行比较简单,配好环境就行 。不再过多叙述~
☀️4.3 运行效果
- 最后,使用双边伽玛调整方案来提高细节在暗和亮区域的可见性