Pytorch学习笔记DCGAN极简入门教程

 更新时间:2021年09月07日 09:42:26   作者:xz1308579340  
网上GAN的教程太多了,这边也谈一下自己的理解,本文给大家介绍一下GAN的两部分组成,有需要的朋友可以借鉴参考下,希望能够有所帮助

1.图片分类网络

这是一个二分类网络,可以是alxnet ,vgg,resnet任何一个,负责对图片进行二分类,区分图片是真实图片还是生成的图片

2.图片生成网络

输入是一个随机噪声,输出是一张图片,使用的是反卷积层

相信学过深度学习的都能写出这两个网络,当然如果你写不出来,没关系,有人替你写好了

首先是图片分类网络:

简单来说就是cnn+relu+sogmid,可以换成任何一个分类网络,比如bgg,resnet等

class Discriminator(nn.Module):
    def __init__(self, ngpu):
        super(Discriminator, self).__init__()
        self.ngpu = ngpu
        self.main = nn.Sequential(
            # input is (nc) x 64 x 64
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf) x 32 x 32
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*2) x 16 x 16
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*4) x 8 x 8
            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*8) x 4 x 4
            nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()
        )
    def forward(self, input):
        return self.main(input)

重点是生成网络

代码如下,其实就是反卷积+bn+relu

class Generator(nn.Module):
    def __init__(self, ngpu):
        super(Generator, self).__init__()
        self.ngpu = ngpu
        self.main = nn.Sequential(
            # input is Z, going into a convolution
            nn.ConvTranspose2d( nz, ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
            # state size. (ngf*8) x 4 x 4
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            # state size. (ngf*4) x 8 x 8
            nn.ConvTranspose2d( ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            # state size. (ngf*2) x 16 x 16
            nn.ConvTranspose2d( ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            # state size. (ngf) x 32 x 32
            nn.ConvTranspose2d( ngf, nc, 4, 2, 1, bias=False),
            nn.Tanh()
            # state size. (nc) x 64 x 64
        )
    def forward(self, input):
        return self.main(input)


讲道理,以上两个网络都挺简单。

真正的重点到了,怎么训练

每一个step分为三个步骤:

  • 训练二分类网络
    1.输入真实图片,经过二分类,希望判定为真实图片,更新二分类网络
    2.输入噪声,进过生成网络,生成一张图片,输入二分类网络,希望判定为虚假图片,更新二分类网络
  • 训练生成网络
    3.输入噪声,进过生成网络,生成一张图片,输入二分类网络,希望判定为真实图片,更新生成网络

不多说直接上代码

for epoch in range(num_epochs):
    # For each batch in the dataloader
    for i, data in enumerate(dataloader, 0):
        ############################
        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
        ###########################
        ## Train with all-real batch
        netD.zero_grad()
        # Format batch
        real_cpu = data[0].to(device)
        b_size = real_cpu.size(0)
        label = torch.full((b_size,), real_label, device=device)
        # Forward pass real batch through D
        output = netD(real_cpu).view(-1)
        # Calculate loss on all-real batch
        errD_real = criterion(output, label)
        # Calculate gradients for D in backward pass
        errD_real.backward()
        D_x = output.mean().item()
        ## Train with all-fake batch
        # Generate batch of latent vectors
        noise = torch.randn(b_size, nz, 1, 1, device=device)
        # Generate fake image batch with G
        fake = netG(noise)
        label.fill_(fake_label)
        # Classify all fake batch with D
        output = netD(fake.detach()).view(-1)
        # Calculate D's loss on the all-fake batch
        errD_fake = criterion(output, label)
        # Calculate the gradients for this batch
        errD_fake.backward()
        D_G_z1 = output.mean().item()
        # Add the gradients from the all-real and all-fake batches
        errD = errD_real + errD_fake
        # Update D
        optimizerD.step()
        ############################
        # (2) Update G network: maximize log(D(G(z)))
        ###########################
        netG.zero_grad()
        label.fill_(real_label)  # fake labels are real for generator cost
        # Since we just updated D, perform another forward pass of all-fake batch through D
        output = netD(fake).view(-1)
        # Calculate G's loss based on this output
        errG = criterion(output, label)
        # Calculate gradients for G
        errG.backward()
        D_G_z2 = output.mean().item()
        # Update G
        optimizerG.step()
        # Output training stats
        if i % 50 == 0:
            print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
                  % (epoch, num_epochs, i, len(dataloader),
                     errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))
        # Save Losses for plotting later
        G_losses.append(errG.item())
        D_losses.append(errD.item())
        # Check how the generator is doing by saving G's output on fixed_noise
        if (iters % 500 == 0) or ((epoch == num_epochs-1) and (i == len(dataloader)-1)):
            with torch.no_grad():
                fake = netG(fixed_noise).detach().cpu()
            img_list.append(vutils.make_grid(fake, padding=2, normalize=True))
        iters += 1

以上就是Pytorch学习笔记DCGAN极简入门教程的详细内容,更多关于Pytorch学习DCGAN入门教程的资料请关注脚本之家其它相关文章!

相关文章

  • Python模块结构与布局操作方法实例分析

    Python模块结构与布局操作方法实例分析

    这篇文章主要介绍了Python模块结构与布局操作方法,结合实例形式分析了Python模块与布局的相关概念、使用方法与相关注意事项,需要的朋友可以参考下
    2017-07-07
  • Python中序列的修改、散列与切片详解

    Python中序列的修改、散列与切片详解

    在Python中,最基本的数据结构是序列(sequence)。下面这篇文章主要给大家介绍了关于Python中序列的修改、散列与切片的相关资料文中通过示例代码介绍的非常详细,需要的朋友可以参考,下面来一起看看吧。
    2017-08-08
  • Python开发之基于模板匹配的信用卡数字识别功能

    Python开发之基于模板匹配的信用卡数字识别功能

    这篇文章主要介绍了基于模板匹配的信用卡数字识别功能,本文给大家介绍的非常详细,具有一定的参考借鉴价值,需要的朋友可以参考下
    2020-01-01
  • Python调用golang代码详解

    Python调用golang代码详解

    这篇文章主要给大家介绍了关于Python调用golang代码,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2024-02-02
  • python实现贪吃蛇双人大战

    python实现贪吃蛇双人大战

    这篇文章主要为大家详细介绍了python实现贪吃蛇双人大战,文中示例代码介绍的非常详细,具有一定的参考价值,感兴趣的小伙伴们可以参考一下
    2020-04-04
  • Python中的远程调试与性能优化技巧分享

    Python中的远程调试与性能优化技巧分享

    Python 是一种简单易学、功能强大的编程语言,广泛应用于各种领域,包括网络编程、数据分析、人工智能等,在开发过程中,我们经常会遇到需要远程调试和性能优化的情况,本文将介绍如何利用远程调试工具和性能优化技巧来提高 Python 应用程序的效率和性能
    2024-05-05
  • Python实现将字典内容写入json文件

    Python实现将字典内容写入json文件

    这篇文章主要为大家详细介绍了如何利用Python语言实现将字典内容写入json文件,文中的示例代码讲解详细,感兴趣的小伙伴可以了解一下
    2022-08-08
  • python 系统调用的实例详解

    python 系统调用的实例详解

    这篇文章主要介绍了python 系统调用的实例详解的相关资料,需要的朋友可以参考下
    2017-07-07
  • 一文教你用Pyecharts做交互图表

    一文教你用Pyecharts做交互图表

    Echarts 是一个由百度开源的数据可视化,凭借着良好的交互性,精巧的图表设计,得到了众多开发者的认可,本文介绍了Pyecharts交互图表,感兴趣的可以了解一下
    2021-05-05
  • python实用的快捷语法技巧大全

    python实用的快捷语法技巧大全

    初识Python语言,觉得python满足了我上学时候对编程语言的所有要求,下面这篇文章主要给大家介绍了关于python实用的快捷语法技巧的相关资料,文中通过实例代码介绍的非常详细,需要的朋友可以参考下
    2022-02-02

最新评论