对抗生成网络GAN系列——原理及源码解析 写在前面
在前面,我已经介绍过好几篇有关GAN的文章,链接如下:
这篇文章我将来为大家介绍,论文名为:Semi- via 。这篇文章同样是实现缺陷检测的,因此在阅读本文之前建议你对使用GAN网络实现缺陷检测有一定的了解,可以参考上文链接中的和。
准备好了吗,嘟嘟嘟,开始发车。
原理解析
【阅读此部分前建议对GAN的原理及GAN在缺陷检测上的应用有所了解,详情点击写在前面中的链接查看,本篇文章我不会再介绍GAN的一些先验知识。】
结构
这部分为大家介绍的原理,其实我们一起来看下图就足够了:
图1 结构图
我们还是先来对上图中的结构做一些解释。从直观的颜色上来看,我们可以分成两类,一类是红色的结构,一类是蓝色的结构。主要就是降维的作用啦,如将一张张图片数据压缩成一个个潜在向量;相反,就是升维的作用,如将一个个潜在向量重建成一张张图片。按照论文描述的结构来分,可以分成三个子结构,分别为生成器网络G,编码器网络E和判别器网络D。下面分别来介绍介绍这三个子结构:
损失函数
的损失函数分为两部分,第一部分是生成器损失,第二部分为判别器损失,下面我们分别来进行介绍:
生成器总的损失是上述三种损失的加权和,如下:
L = w a d v L a d v + w c o n L c o n + w e n c L e n c L=w_{adv}L_{adv}+w_{con}L_{con}+w_{enc}L_{enc} L=wadvLadv+wconLcon+wencLenc
在论文提供的源码中,默认 w c o n = 50 , w a d v = w e n c = 1 w_{con}=50,w_{adv}=w_{enc}=1 wcon=50,wadv=wenc=1。
测试阶段
在上一小节,为大家介绍了的损失函数,这是在测试阶段使用的。针对的是异常检测任务,在测试阶段我们会对输入的数据进行评分,根据评分的结果来判定输入是否异常。在中使用的评分函数就是我们上一小节介绍的 Loss,对于一个测试数据x,用 A ( x ) A(x) A(x)表示其异常得分,则:
A ( x ) = ∣ ∣ G E ( x ) − E ( G ( x ) ) ∣ ∣ 2 A(x)=||G_E(x)-E(G(x))||_2 A(x)=∣∣GE(x)−E(G(x))∣∣2
这里大家需要注意以下,论文中 A ( x ) A(x) A(x)的表达式使用的是L1范数,但是从我阅读论文提供的源码来看,代码中使用的是L2范数。这里保持和源码一致,使用L2范数。代码中关于此部分的描述如下:
# latent_i表示G_E(x),latent_o表示E(G(x))。torch.pow(m,2)=m^2
error = torch.mean(torch.pow((latent_i-latent_o), 2), dim=1)
源码解析
这里直接使用论文中提供的源码地址:源码
模型搭建
其实通过我前文的讲解,不知道大家能否感受到模型其实是不复杂的。需要注意的是在介绍结构时我们将模型分为了三个子结构,分别为生成器网络G、编码器网络E、判别器网络D。但是在代码中我们将生成器网络G和编码器网络E合并在一块儿了,也称为生成器网络G。
下面我给出这部分的代码,大家注意一下这里面的超参数比较多,为了方便大家阅读,我把这里用到超参数的整理出来,如下图所示:
""" Network architectures.
"""
# pylint: disable=W0221,W0622,C0103,R0913
##
import torch
import torch.nn as nn
import torch.nn.parallel
from options import Options
##
def weights_init(mod):
"""
Custom weights initialization called on netG, netD and netE
:param m:
:return:
"""
classname = mod.__class__.__name__
if classname.find('Conv') != -1:
mod.weight.data.normal_(0.0, 0.02)
elif classname.find('BatchNorm') != -1:
mod.weight.data.normal_(1.0, 0.02)
mod.bias.data.fill_(0)
###
class Encoder(nn.Module):
"""
DCGAN ENCODER NETWORK
"""
def __init__(self, isize, nz, nc, ndf, ngpu, n_extra_layers=0, add_final_conv=True):
super(Encoder, self).__init__()
self.ngpu = ngpu
assert isize % 16 == 0, "isize has to be a multiple of 16"
main = nn.Sequential()
# input is nc x isize x isize
main.add_module('initial-conv-{0}-{1}'.format(nc, ndf),
nn.Conv2d(nc, ndf, 4, 2, 1, bias=False))
main.add_module('initial-relu-{0}'.format(ndf),
nn.LeakyReLU(0.2, inplace=True))
csize, cndf = isize / 2, ndf # csize=16,cndf=64
# Extra layers
for t in range(n_extra_layers):
main.add_module('extra-layers-{0}-{1}-conv'.format(t, cndf),
nn.Conv2d(cndf, cndf, 3, 1, 1, bias=False))
main.add_module('extra-layers-{0}-{1}-batchnorm'.format(t, cndf),
nn.BatchNorm2d(cndf))
main.add_module('extra-layers-{0}-{1}-relu'.format(t, cndf),
nn.LeakyReLU(0.2, inplace=True))
while csize > 4:
in_feat = cndf

out_feat = cndf * 2
main.add_module('pyramid-{0}-{1}-conv'.format(in_feat, out_feat),
nn.Conv2d(in_feat, out_feat, 4, 2, 1, bias=False))
main.add_module('pyramid-{0}-batchnorm'.format(out_feat),
nn.BatchNorm2d(out_feat))
main.add_module('pyramid-{0}-relu'.format(out_feat),
nn.LeakyReLU(0.2, inplace=True))
cndf = cndf * 2
csize = csize / 2
# state size. K x 4 x 4
if add_final_conv:
main.add_module('final-{0}-{1}-conv'.format(cndf, 1),
nn.Conv2d(cndf, nz, 4, 1, 0, bias=False))
self.main = main
def forward(self, input):
if self.ngpu > 1:
output = nn.parallel.data_parallel(self.main, input, range(self.ngpu))
else:
output = self.main(input)
return output
##
class Decoder(nn.Module):
"""
DCGAN DECODER NETWORK
"""
def __init__(self, isize, nz, nc, ngf, ngpu, n_extra_layers=0):
super(Decoder, self).__init__()
self.ngpu = ngpu
assert isize % 16 == 0, "isize has to be a multiple of 16"
cngf, tisize = ngf // 2, 4 #cngf=32 ,tisize=4
while tisize != isize:
cngf = cngf * 2
tisize = tisize * 2
main = nn.Sequential()
# input is Z, going into a convolution
main.add_module('initial-{0}-{1}-convt'.format(nz, cngf),
nn.ConvTranspose2d(nz, cngf, 4, 1, 0, bias=False))
main.add_module('initial-{0}-batchnorm'.format(cngf),
nn.BatchNorm2d(cngf))
main.add_module('initial-{0}-relu'.format(cngf),
nn.ReLU(True))
csize, _ = 4, cngf
while csize < isize // 2:
main.add_module('pyramid-{0}-{1}-convt'.format(cngf, cngf // 2),
nn.ConvTranspose2d(cngf, cngf // 2, 4, 2, 1, bias=False))
main.add_module('pyramid-{0}-batchnorm'.format(cngf // 2),
nn.BatchNorm2d(cngf // 2))
main.add_module('pyramid-{0}-relu'.format(cngf // 2),
nn.ReLU(True))
cngf = cngf // 2
csize = csize * 2
# Extra layers
for t in range(n_extra_layers):
main.add_module('extra-layers-{0}-{1}-conv'.format(t, cngf),
nn.Conv2d(cngf, cngf, 3, 1, 1, bias=False))
main.add_module('extra-layers-{0}-{1}-batchnorm'.format(t, cngf),
nn.BatchNorm2d(cngf))
main.add_module('extra-layers-{0}-{1}-relu'.format(t, cngf),
nn.ReLU(True))
main.add_module('final-{0}-{1}-convt'.format(cngf, nc),
nn.ConvTranspose2d(cngf, nc, 4, 2, 1, bias=False))
main.add_module('final-{0}-tanh'.format(nc),
nn.Tanh())
self.main = main
def forward(self, input):
if self.ngpu > 1:
output = nn.parallel.data_parallel(self.main, input, range(self.ngpu))
else:
output = self.main(input)
return output
## 判别器网络结构

class NetD(nn.Module):
"""
DISCRIMINATOR NETWORK
"""
def __init__(self, opt):
super(NetD, self).__init__()
model = Encoder(opt.isize, 1, opt.nc, opt.ngf, opt.ngpu, opt.extralayers)
layers = list(model.main.children())
self.features = nn.Sequential(*layers[:-1])
self.classifier = nn.Sequential(layers[-1])
self.classifier.add_module('Sigmoid', nn.Sigmoid())
def forward(self, x):
features = self.features(x)
features = features
classifier = self.classifier(features)
classifier = classifier.view(-1, 1).squeeze(1)
return classifier, features
## 生成器网络结构
class NetG(nn.Module):
"""
GENERATOR NETWORK
"""
def __init__(self, opt):
super(NetG, self).__init__()
self.encoder1 = Encoder(opt.isize, opt.nz, opt.nc, opt.ngf, opt.ngpu, opt.extralayers)
self.decoder = Decoder(opt.isize, opt.nz, opt.nc, opt.ngf, opt.ngpu, opt.extralayers)
self.encoder2 = Encoder(opt.isize, opt.nz, opt.nc, opt.ngf, opt.ngpu, opt.extralayers)
def forward(self, x):
latent_i = self.encoder1(x)
gen_imag = self.decoder(latent_i)
latent_o = self.encoder2(gen_imag)
return gen_imag, latent_i, latent_o
损失函数
我们在理论部分已经介绍了的损失函数,那么在代码上它们都是一一对应的,实现起来也很简单,如下:
## 定义L1 Loss
def l1_loss(input, target):
return torch.mean(torch.abs(input - target))
## 定义L2 Loss
def l2_loss(input, target, size_average=True):
if size_average:
return torch.mean(torch.pow((input-target), 2))
else:
return torch.pow((input-target), 2)
self.l_adv = l2_loss
self.l_con = nn.L1Loss()
self.l_enc = l2_loss
self.err_g_adv = self.l_adv(self.netd(self.input)[1], self.netd(self.fake)[1])
self.err_g_con = self.l_con(self.fake, self.input)
self.err_g_enc = self.l_enc(self.latent_o, self.latent_i)
self.err_g = self.err_g_adv * self.opt.w_adv +
self.err_g_con * self.opt.w_con +
self.err_g_enc * self.opt.w_enc
上述代码为生成器损失函数代码,判别器的损失函数代码已经在理论部分为大家介绍了,这里就不在赘述了。
小结
这里我并没有很详细的为大家解读代码,但是把一些关键的部分都给大家介绍了。会了这些其实你完全可以自己实现一个网络,或者对我之前在中的代码稍加改造也可以达到一样的效果。论文中提供的源码感兴趣的大家可以自己去调试一下,代码量也不算多,但有的地方理解起来也有一定的困难,总之大家加油吧!!!
参考链接
: Semi- via
异常检测的经典之作|ACCV 2018
如若文章对你有所帮助,那就
转载:
323AI导航网发布