生成对抗网络‌ 对抗生成网络GAN系列——GANomaly原理及源码解析

默认分类5小时前发布 admin
4,484 0
星河超算AI数字人

对抗生成网络GAN系列——原理及源码解析 写在前面

​ 在前面,我已经介绍过好几篇有关GAN的文章,链接如下:

这篇文章我将来为大家介绍,论文名为:Semi- via 。这篇文章同样是实现缺陷检测的,因此在阅读本文之前建议你对使用GAN网络实现缺陷检测有一定的了解,可以参考上文链接中的和。

准备好了吗,嘟嘟嘟,开始发车。

原理解析

【阅读此部分前建议对GAN的原理及GAN在缺陷检测上的应用有所了解,详情点击写在前面中的链接查看,本篇文章我不会再介绍GAN的一些先验知识。】

结构

​ 这部分为大家介绍的原理,其实我们一起来看下图就足够了:

图1 结构图

我们还是先来对上图中的结构做一些解释。从直观的颜色上来看,我们可以分成两类,一类是红色的结构,一类是蓝色的结构。主要就是降维的作用啦,如将一张张图片数据压缩成一个个潜在向量;相反,就是升维的作用,如将一个个潜在向量重建成一张张图片。按照论文描述的结构来分,可以分成三个子结构,分别为生成器网络G,编码器网络E和判别器网络D。下面分别来介绍介绍这三个子结构:

损失函数

的损失函数分为两部分,第一部分是生成器损失,第二部分为判别器损失,下面我们分别来进行介绍:

生成器总的损失是上述三种损失的加权和,如下:

L = w a d v L a d v + w c o n L c o n + w e n c L e n c L=w_{adv}L_{adv}+w_{con}L_{con}+w_{enc}L_{enc} L=wadv​Ladv​+wcon​Lcon​+wenc​Lenc​

在论文提供的源码中,默认 w c o n = 50 , w a d v = w e n c = 1 w_{con}=50,w_{adv}=w_{enc}=1 wcon​=50,wadv​=wenc​=1。

测试阶段

在上一小节,为大家介绍了的损失函数,这是在测试阶段使用的。针对的是异常检测任务,在测试阶段我们会对输入的数据进行评分,根据评分的结果来判定输入是否异常。在中使用的评分函数就是我们上一小节介绍的 Loss,对于一个测试数据x,用 A ( x ) A(x) A(x)表示其异常得分,则:

​ A ( x ) = ∣ ∣ G E ( x ) − E ( G ( x ) ) ∣ ∣ 2 A(x)=||G_E(x)-E(G(x))||_2 A(x)=∣∣GE​(x)−E(G(x))∣∣2​

这里大家需要注意以下,论文中 A ( x ) A(x) A(x)的表达式使用的是L1范数,但是从我阅读论文提供的源码来看,代码中使用的是L2范数。这里保持和源码一致,使用L2范数。代码中关于此部分的描述如下:

# latent_i表示G_E(x),latent_o表示E(G(x))。torch.pow(m,2)=m^2
error = torch.mean(torch.pow((latent_i-latent_o), 2), dim=1)

源码解析

​ 这里直接使用论文中提供的源码地址:源码

模型搭建

​  其实通过我前文的讲解,不知道大家能否感受到模型其实是不复杂的。需要注意的是在介绍结构时我们将模型分为了三个子结构,分别为生成器网络G、编码器网络E、判别器网络D。但是在代码中我们将生成器网络G和编码器网络E合并在一块儿了,也称为生成器网络G。

下面我给出这部分的代码,大家注意一下这里面的超参数比较多,为了方便大家阅读,我把这里用到超参数的整理出来,如下图所示:

""" Network architectures.
"""
# pylint: disable=W0221,W0622,C0103,R0913
##
import torch
import torch.nn as nn
import torch.nn.parallel
from options import Options
##
def weights_init(mod):
    """
    Custom weights initialization called on netG, netD and netE
    :param m:
    :return:
    """
    classname = mod.__class__.__name__
    if classname.find('Conv') != -1:
        mod.weight.data.normal_(0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        mod.weight.data.normal_(1.0, 0.02)
        mod.bias.data.fill_(0)
###
class Encoder(nn.Module):
    """
    DCGAN ENCODER NETWORK
    """
    def __init__(self, isize, nz, nc, ndf, ngpu, n_extra_layers=0, add_final_conv=True):
        super(Encoder, self).__init__()
        self.ngpu = ngpu
        assert isize % 16 == 0, "isize has to be a multiple of 16"
        main = nn.Sequential()
        # input is nc x isize x isize
        main.add_module('initial-conv-{0}-{1}'.format(nc, ndf),
                        nn.Conv2d(nc, ndf, 4, 2, 1, bias=False))
        main.add_module('initial-relu-{0}'.format(ndf),
                        nn.LeakyReLU(0.2, inplace=True))
        csize, cndf = isize / 2, ndf     # csize=16,cndf=64
        # Extra layers
        for t in range(n_extra_layers):
            main.add_module('extra-layers-{0}-{1}-conv'.format(t, cndf),
                            nn.Conv2d(cndf, cndf, 3, 1, 1, bias=False))
            main.add_module('extra-layers-{0}-{1}-batchnorm'.format(t, cndf),
                            nn.BatchNorm2d(cndf))
            main.add_module('extra-layers-{0}-{1}-relu'.format(t, cndf),
                            nn.LeakyReLU(0.2, inplace=True))
        while csize > 4:
            in_feat = cndf

生成对抗网络‌ 对抗生成网络GAN系列——GANomaly原理及源码解析

out_feat = cndf * 2 main.add_module('pyramid-{0}-{1}-conv'.format(in_feat, out_feat), nn.Conv2d(in_feat, out_feat, 4, 2, 1, bias=False)) main.add_module('pyramid-{0}-batchnorm'.format(out_feat), nn.BatchNorm2d(out_feat)) main.add_module('pyramid-{0}-relu'.format(out_feat), nn.LeakyReLU(0.2, inplace=True)) cndf = cndf * 2 csize = csize / 2 # state size. K x 4 x 4 if add_final_conv: main.add_module('final-{0}-{1}-conv'.format(cndf, 1), nn.Conv2d(cndf, nz, 4, 1, 0, bias=False)) self.main = main def forward(self, input): if self.ngpu > 1: output = nn.parallel.data_parallel(self.main, input, range(self.ngpu)) else: output = self.main(input) return output ## class Decoder(nn.Module): """ DCGAN DECODER NETWORK """ def __init__(self, isize, nz, nc, ngf, ngpu, n_extra_layers=0): super(Decoder, self).__init__() self.ngpu = ngpu assert isize % 16 == 0, "isize has to be a multiple of 16" cngf, tisize = ngf // 2, 4 #cngf=32 ,tisize=4 while tisize != isize: cngf = cngf * 2 tisize = tisize * 2 main = nn.Sequential() # input is Z, going into a convolution main.add_module('initial-{0}-{1}-convt'.format(nz, cngf), nn.ConvTranspose2d(nz, cngf, 4, 1, 0, bias=False)) main.add_module('initial-{0}-batchnorm'.format(cngf), nn.BatchNorm2d(cngf)) main.add_module('initial-{0}-relu'.format(cngf), nn.ReLU(True)) csize, _ = 4, cngf while csize < isize // 2: main.add_module('pyramid-{0}-{1}-convt'.format(cngf, cngf // 2), nn.ConvTranspose2d(cngf, cngf // 2, 4, 2, 1, bias=False)) main.add_module('pyramid-{0}-batchnorm'.format(cngf // 2), nn.BatchNorm2d(cngf // 2)) main.add_module('pyramid-{0}-relu'.format(cngf // 2), nn.ReLU(True)) cngf = cngf // 2 csize = csize * 2 # Extra layers for t in range(n_extra_layers): main.add_module('extra-layers-{0}-{1}-conv'.format(t, cngf), nn.Conv2d(cngf, cngf, 3, 1, 1, bias=False)) main.add_module('extra-layers-{0}-{1}-batchnorm'.format(t, cngf), nn.BatchNorm2d(cngf)) main.add_module('extra-layers-{0}-{1}-relu'.format(t, cngf), nn.ReLU(True)) main.add_module('final-{0}-{1}-convt'.format(cngf, nc), nn.ConvTranspose2d(cngf, nc, 4, 2, 1, bias=False)) main.add_module('final-{0}-tanh'.format(nc), nn.Tanh()) self.main = main def forward(self, input): if self.ngpu > 1: output = nn.parallel.data_parallel(self.main, input, range(self.ngpu)) else: output = self.main(input) return output ## 判别器网络结构

生成对抗网络‌ 对抗生成网络GAN系列——GANomaly原理及源码解析

class NetD(nn.Module): """ DISCRIMINATOR NETWORK """ def __init__(self, opt): super(NetD, self).__init__() model = Encoder(opt.isize, 1, opt.nc, opt.ngf, opt.ngpu, opt.extralayers) layers = list(model.main.children()) self.features = nn.Sequential(*layers[:-1]) self.classifier = nn.Sequential(layers[-1]) self.classifier.add_module('Sigmoid', nn.Sigmoid()) def forward(self, x): features = self.features(x) features = features classifier = self.classifier(features) classifier = classifier.view(-1, 1).squeeze(1) return classifier, features ## 生成器网络结构 class NetG(nn.Module): """ GENERATOR NETWORK """ def __init__(self, opt): super(NetG, self).__init__() self.encoder1 = Encoder(opt.isize, opt.nz, opt.nc, opt.ngf, opt.ngpu, opt.extralayers) self.decoder = Decoder(opt.isize, opt.nz, opt.nc, opt.ngf, opt.ngpu, opt.extralayers) self.encoder2 = Encoder(opt.isize, opt.nz, opt.nc, opt.ngf, opt.ngpu, opt.extralayers) def forward(self, x): latent_i = self.encoder1(x) gen_imag = self.decoder(latent_i) latent_o = self.encoder2(gen_imag) return gen_imag, latent_i, latent_o

损失函数

​  我们在理论部分已经介绍了的损失函数,那么在代码上它们都是一一对应的,实现起来也很简单,如下:

## 定义L1 Loss
def l1_loss(input, target):
    return torch.mean(torch.abs(input - target))
## 定义L2 Loss
def l2_loss(input, target, size_average=True):
    if size_average:
        return torch.mean(torch.pow((input-target), 2))
    else:
        return torch.pow((input-target), 2)
self.l_adv = l2_loss
self.l_con = nn.L1Loss()
self.l_enc = l2_loss
self.err_g_adv = self.l_adv(self.netd(self.input)[1], self.netd(self.fake)[1])
self.err_g_con = self.l_con(self.fake, self.input)
self.err_g_enc = self.l_enc(self.latent_o, self.latent_i)
self.err_g = self.err_g_adv * self.opt.w_adv + 
             self.err_g_con * self.opt.w_con + 
             self.err_g_enc * self.opt.w_enc

 

上述代码为生成器损失函数代码,判别器的损失函数代码已经在理论部分为大家介绍了,这里就不在赘述了。

小结

这里我并没有很详细的为大家解读代码,但是把一些关键的部分都给大家介绍了。会了这些其实你完全可以自己实现一个网络,或者对我之前在中的代码稍加改造也可以达到一样的效果。论文中提供的源码感兴趣的大家可以自己去调试一下,代码量也不算多,但有的地方理解起来也有一定的困难,总之大家加油吧!!!

参考链接

: Semi- via

异常检测的经典之作|ACCV 2018

如若文章对你有所帮助,那就

转载:

323AI导航网发布

© 版权声明

相关文章

星河超算AI数字人

暂无评论

暂无评论...