譯者 | Sambodhi??
生成對抗網(wǎng)絡(luò)(Generative Adversarial Network,GAN)由 Goodfellow 等人在 2014 年提出,它徹底改變了計(jì)算機(jī)視覺中的圖像生成領(lǐng)域:沒有人能夠相信這些令人驚嘆而生動的圖像實(shí)際上是純粹由機(jī)器生成的。
事實(shí)上,人們曾經(jīng)認(rèn)為生成的任務(wù)是不可能的,并且被 GAN 的力量所震驚,因?yàn)閭鹘y(tǒng)上,根本沒有任何事實(shí)可以比較我們生成的圖像。
本文介紹了創(chuàng)建 GAN 背后的簡單直覺,然后介紹了通過 PyTorch 實(shí)現(xiàn)的卷積 GAN 及其訓(xùn)練過程。
GAN 背后的直覺
不同于傳統(tǒng)分類方法,我們的網(wǎng)絡(luò)預(yù)測可以直接與事實(shí)的正確答案相比較,而生成圖像的“正確性”是很難定義和衡量的。Goodfellow 等人在他們的原創(chuàng)論文《生成對抗網(wǎng)絡(luò)》(Generative Adversarial Network)中提出了一個(gè)有趣的想法:使用經(jīng)過訓(xùn)練的分類器來區(qū)分生成的圖像和實(shí)際圖像。如果存在這樣的分類器,我們可以創(chuàng)建并訓(xùn)練一個(gè)生成器網(wǎng)絡(luò),直到它輸出的圖像能完全騙過分類器。

圖 1 GAN 管道
GAN 是這一過程的產(chǎn)物:它包含一個(gè)根據(jù)給定的數(shù)據(jù)集生成圖像的生成器,以及一個(gè)區(qū)分圖像是真實(shí)的還是生成的判別器(分類器)。GAN 的詳細(xì)管道見圖 1。
損失函數(shù)
對生成器和判別器進(jìn)行優(yōu)化都很困難,因?yàn)檎缒闼胂蟮哪菢?,這兩個(gè)網(wǎng)絡(luò)的目標(biāo)完全相反:生成器希望盡可能地創(chuàng)造出真實(shí)的東西,但判別器希望區(qū)分生成的材料。
為了說明這一點(diǎn),我們讓 D(x) 是判別器的輸出,也就是 x 是真實(shí)圖像的概率,而 G(z) 是我們的生成器的輸出。判別器類似于一個(gè)二元分類器,因此判別器的目標(biāo)是使函數(shù)最大化:
本質(zhì)上是二元交叉熵?fù)p失,沒有開頭的負(fù)號。另一方面,生成器的目標(biāo)是使判別器做出正確判斷的機(jī)會最小化,因此它的目標(biāo)是最小化函數(shù)。所以,最終的損失函數(shù)將是兩個(gè)分類器之間的一個(gè)極小極大博弈(minimax game),具體如下:

從理論上講,這將收斂到判別器,預(yù)測所有事件的概率為 0.5。
但在實(shí)踐中,極小極大博弈往往會導(dǎo)致網(wǎng)絡(luò)無法收斂,因此仔細(xì)調(diào)整訓(xùn)練過程非常重要。像學(xué)習(xí)率這樣的超參數(shù)對于訓(xùn)練 GAN 時(shí)顯然更為重要:一個(gè)微小的變化會導(dǎo)致 GAN 產(chǎn)生一個(gè)輸出,而與輸入噪聲無關(guān)。
運(yùn)算環(huán)境
庫
我們通過 PyTorch 庫(包括 torchvision)來構(gòu)建整個(gè)程序。GAN 的生成結(jié)果的可視化是通過 Matplotlib 庫繪制的。下面的代碼導(dǎo)入了所有的庫:
importGAN.py
""" Import necessary libraries to create a generative adversarial network The code is mainly developed using the PyTorch library """ import time import torch import torch.nn as nn import torch.optim as optim from torch.utils.data import DataLoader from torchvision import datasets from torchvision.transforms import transforms from model import discriminator, generator import numpy as np import matplotlib.pyplot as plt
數(shù)據(jù)集
在 GAN 訓(xùn)練中,數(shù)據(jù)集是一個(gè)重要方面。圖像的非結(jié)構(gòu)化性質(zhì)意味著任何給定的類別(如狗、貓或手寫的數(shù)字)都可以有一個(gè)可能的數(shù)據(jù)分布,而這種分布最終是 GAN 生成內(nèi)容的基礎(chǔ)。
為了演示,本文將使用最簡單的 MNIST 數(shù)據(jù)集,其中包含 60000 張從 0 到 9 的手寫數(shù)字圖像。事實(shí)上,像 MNIST 這樣的非結(jié)構(gòu)化數(shù)據(jù)集可以在 Graviti 上找到。這是一家年輕的創(chuàng)業(yè)公司,他們希望通過非結(jié)構(gòu)化數(shù)據(jù)集為社區(qū)提供幫助,在他們的 平臺 上有一些最好的公共非結(jié)構(gòu)化數(shù)據(jù)集,包括 MNIST。
硬件要求
最好的方法是用 GPU 訓(xùn)練神經(jīng)網(wǎng)絡(luò),它可以顯著地提高訓(xùn)練速度。但是,如果只有 CPU 可用,你仍然可以測試程序。要使你的程序能夠自行確定硬件,你可以使用以下方法:
torchDevice.py
""" Determine if any GPUs are available """ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
實(shí)施
網(wǎng)絡(luò)架構(gòu)
由于數(shù)字的簡單性,這兩種架構(gòu)——判別器和生成器,都是由全連接層構(gòu)建的。請注意,在某些情況下,全連接的 GAN 也比 DCGAN 略微容易收斂。
以下是兩種架構(gòu)的 PyTorch 實(shí)現(xiàn):
GANArchitecture.py
"""
Network Architectures
The following are the discriminator and generator architectures
"""
class discriminator(nn.Module):
def __init__(self):
super(discriminator, self).__init__()
self.fc1 = nn.Linear(784, 512)
self.fc2 = nn.Linear(512, 1)
self.activation = nn.LeakyReLU(0.1)
def forward(self, x):
x = x.view(-1, 784)
x = self.activation(self.fc1(x))
x = self.fc2(x)
return nn.Sigmoid()(x)
class generator(nn.Module):
def __init__(self):
super(generator, self).__init__()
self.fc1 = nn.Linear(128, 1024)
self.fc2 = nn.Linear(1024, 2048)
self.fc3 = nn.Linear(2048, 784)
self.activation = nn.ReLU()
def forward(self, x):
x = self.activation(self.fc1(x))
x = self.activation(self.fc2(x))
x = self.fc3(x)
x = x.view(-1, 1, 28, 28)
return nn.Tanh()(x)
訓(xùn)練
在訓(xùn)練 GAN 時(shí),我們優(yōu)化了判別器的結(jié)果,同時(shí)也改進(jìn)了我們的生成器。這樣,在每次迭代過程中會有兩個(gè)相互矛盾的損失來同時(shí)優(yōu)化它們。我們送入生成器的是隨機(jī)噪聲,而生成器理應(yīng)根據(jù)給定噪聲的微小差異來生成圖像:
trainGAN.py
"""
Network training procedure
Every step both the loss for disciminator and generator is updated
Discriminator aims to classify reals and fakes
Generator aims to generate images as realistic as possible
"""
for epoch in range(epochs):
for idx, (imgs, _) in enumerate(train_loader):
idx += 1
# Training the discriminator
# Real inputs are actual images of the MNIST dataset
# Fake inputs are from the generator
# Real inputs should be classified as 1 and fake as 0
real_inputs = imgs.to(device)
real_outputs = D(real_inputs)
real_label = torch.ones(real_inputs.shape[0], 1).to(device)
noise = (torch.rand(real_inputs.shape[0], 128) - 0.5) / 0.5
noise = noise.to(device)
fake_inputs = G(noise)
fake_outputs = D(fake_inputs)
fake_label = torch.zeros(fake_inputs.shape[0], 1).to(device)
outputs = torch.cat((real_outputs, fake_outputs), 0)
targets = torch.cat((real_label, fake_label), 0)
D_loss = loss(outputs, targets)
D_optimizer.zero_grad()
D_loss.backward()
D_optimizer.step()
# Training the generator
# For generator, goal is to make the discriminator believe everything is 1
noise = (torch.rand(real_inputs.shape[0], 128)-0.5)/0.5
noise = noise.to(device)
fake_inputs = G(noise)
fake_outputs = D(fake_inputs)
fake_targets = torch.ones([fake_inputs.shape[0], 1]).to(device)
G_loss = loss(fake_outputs, fake_targets)
G_optimizer.zero_grad()
G_loss.backward()
G_optimizer.step()
if idx % 100 == 0 or idx == len(train_loader):
print('Epoch {} Iteration {}: discriminator_loss {:.3f} generator_loss {:.3f}'.format(epoch, idx, D_loss.item(), G_loss.item()))
if (epoch+1) % 10 == 0:
torch.save(G, 'Generator_epoch_{}.pth'.format(epoch))
print('Model saved.')
?結(jié)? 果
當(dāng) 100 個(gè)輪數(shù)(epoch)之后,我們可以繪制數(shù)據(jù)集,并看到從隨機(jī)噪音中生成的數(shù)字的結(jié)果:

圖 2:GAN 生成的結(jié)
如上圖所示,生成的結(jié)果看起來確實(shí)相當(dāng)像真實(shí)的結(jié)果。鑒于網(wǎng)絡(luò)非常簡單,所以結(jié)果看起來確實(shí)很有希望!
超越單純的內(nèi)容創(chuàng)作
GAN 的創(chuàng)造與計(jì)算機(jī)視覺領(lǐng)域的先前工作如此不同。隨后的眾多應(yīng)用使學(xué)術(shù)界對深度網(wǎng)絡(luò)的能力感到驚訝。下面將介紹一些令人驚訝的工作。
CycleGAN
Zhu 等人的 CycleGAN 引入了一種概念,它無需配對樣本就可以將圖像從 X 域翻譯成 Y 域。馬被轉(zhuǎn)化為斑馬,夏日的陽光被轉(zhuǎn)化為暴風(fēng)雪,CycleGAN 的結(jié)果令人驚訝且準(zhǔn)確。
GauGAN
Nvidia 利用 GAN 的力量,把簡單的繪畫,根據(jù)畫筆的語義,轉(zhuǎn)換成優(yōu)雅而逼真的照片。盡管訓(xùn)練資源的計(jì)算成本很高,但它創(chuàng)造了一個(gè)全新的研究和應(yīng)用領(lǐng)域。
AdvGAN
GAN 還擴(kuò)展到清理對抗性圖像,并將其轉(zhuǎn)化為不會欺騙分類器的干凈樣本。關(guān)于對抗性攻擊和防御的更多信息可以在 這里 到。
結(jié)? 語
所以,你已經(jīng)擁有了它!希望這篇文章對如何構(gòu)建 GAN 提供了一個(gè)概覽。
作者簡介:
Ta-ying Cheng,中國香港人,牛津大學(xué)哲學(xué)博士新生,愛好 3D 視覺、深度學(xué)習(xí)。
編輯:黃飛
?
電子發(fā)燒友App













評論