출처 : 1시간만에 GAN 완전 정복하기

Unsupervised Learning

: gan 모델은 정답이 주어지지 않은 문제에서 스스로 도출한 답을 정답으로 내놓는 unsupervised learning 을 위한 모델이다. gan 은 주로 이미지 생성 등 생성 과제에 쓰이는데, 원 이미지 (=문제) 의 확률 분포를 학습함으로써 답(=가짜 이미지) 을 내놓을 수 있게 된다.

스크린샷 2021-06-13 오후 12 22 28

가령 어떤 이미지들을 담은 파일 📂 이 있다고 하였을 때, 이 파일에 있는 이미지들은 어떠한 확률 분포 (=특징) 을 보일 수 있다.

예를 들어, 안경 쓴 사람이 나온 이미지는 상대적으로 적을 수 있고 반대로, 금발인 여자는 굉장히 많이 등장할 수 있다.

Gan 모델의 목표

: 실제 이미지 데이터의 분포를 잘 근사하는 이미지를 생성해내는 것이 vanilla gan 모델의 목표이다

스크린샷 2021-06-13 오후 12 29 49

Intuition in Gan

Gan 내부에은 두 가지 모델이 존재한다. Discriminator 은 이미지가 가짜로 생성된 이미지인지, 진짜 이미지인지를 구별하는 과제를 맡는다. Generator 은 Discriminator 이 더 이상 판별하지 못할 정도로 진짜 이미지에 유사한 가짜 이미지를 만들어내는 과제를 맡는다.

스크린샷 2021-06-13 오후 12 32 06

수식으로 표현하면,

확률은 0~1 사이이므로 log 를 취했을 때, log(1) 일 때가 max 이고 log(0) 일 때는 음의 무한대까지 내려간다.

x ~ Pdata(x) ; 실제 data 의 분포로부터 샘플링한다

z ~ Pz(z) ; 원 사진과 같은 차원의 가우시안 분포로부터 랜덤하게 샘플링한다

Discriminator 의 loss function 은 다음과 같다. Discriminator 과 Generator 이 같은 loss function 을 공유하는데, Discriminatordms 이 loss function 값을 최대화시켜놓고자한다.

스크린샷 2021-06-13 오후 12 36 00

generator 의 loss function 은 다음과같다. generator 은 이 loss function 값을 최소화시키고자 한다.

스크린샷 2021-06-13 오후 12 40 48

그런데 여기에는 조금의 함정이 있다.

맨 처음 generator 가 학습할 때의 상황을 고려하면, 초반에는 형편없는 이미지를 만들어내므로 Discriminator 한테 바로 딱 걸릴 것이다. D(g(z)) = 0

하지만 이 때 기울기를 찍어보면 생각보다 작다는 것을 알 수 있다 = 형편 없으므로 학습을 엄청 해야하는데 생각보다 많이 학습이 되지 않는 상황 ?

스크린샷 2021-06-13 오후 12 59 46

이를 해결하기 위해서, logD(g(z)) 자체를 최대화시켜주는 쪽으로 학습을 진행하는 방법도 있다. = 다른 말로 하면 -logD(g(z)) 를 최소화시켜주는 방식

(-log(1)=0 -> -log(0)=양의 무한대)

마찬가지로 학습 초반에 D(g(z)) = 0 인데 이 자체를 이용한다면 기울기가 훨씬 커져서 이전 버전보다 좀 더 빨리 좋은 이미지를 만들어내고자 학습을 하게 된다.

스크린샷 2021-06-13 오후 1 02 20

Why does GANs work?

두 확률 분포의 차이를 나타내는 KL divergence ; P(x) * log ( P(x) / Q(x) )

이 때 kl divergence 는 항상 원 확률분포에서의 엔트로피에서 바뀐 확률분포를 가정했을 때 바뀐 엔트로피를 빼주는 순서로 값이 나온다. 즉 Dkl(p,q) != Dkl(q,p) 이다.

Jenson-Shannon Divergence 는 이에 비해 Djs(p,q) == Djs(q,p) 가 되게끔 대칭을 맞춰주었다.

; JSD(P,Q)=1/2 * KL(P,M) + 1/2* KL(Q,M)

이 때 M=1/2*(P+Q)

결국 loss function 을 최적화시키는 과정이 JSD 에서 원 데이터의 분포와 generator 가 만드는 이미지의 분포의 차이를 최소화하는 쪽으로 유도된다는 것이다.

스크린샷 2021-06-13 오후 1 22 18

구체적인 증명은 동빈나님 유튜브 강의를 참고하였다.

출처 : GAN: Generative Adversarial Netsorks

스크린샷 2021-06-13 오후 1 26 48

스크린샷 2021-06-13 오후 1 29 50

GAN pytorch 코드

출처 : 최윤제님 파이토치 튜토리얼

모듈 import

import os
import torch
import torchvision
import torch.nn as nn
from torchvision import transforms
from torchvision.utils import save_image

cuda 연결하기

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

하이퍼파라미터

# Hyper-parameters
latent_size = 64
hidden_size = 256
image_size = 784
num_epochs = 200
batch_size = 100
sample_dir = 'samples'

transform 으로 normalize 시킨 mnist 이미지 받아오기

transform = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.5],   # 1 for greyscale channels
                                     std=[0.5])])

# MNIST dataset
mnist = torchvision.datasets.MNIST(root='../../data/',
                                   train=True,
                                   transform=transform,
                                   download=True)

data_loader = torch.utils.data.DataLoader(dataset=mnist,
                                          batch_size=batch_size, 
                                          shuffle=True)

Discriminator

처음에 이미지를 Linear 층을 통해 펴주고, 중간 중간 activation function 을 넣어주고 discriminator 의 주 임무는 판별이므로 마지막엔 sigmoid 를 붙여준다

# Discriminator
D = nn.Sequential(
    nn.Linear(image_size, hidden_size),
    nn.LeakyReLU(0.2),
    nn.Linear(hidden_size, hidden_size),
    nn.LeakyReLU(0.2),
    nn.Linear(hidden_size, 1),
    nn.Sigmoid())

Generator

처음에 특정 차원 벡터로부터 점점 이미지 사이즈에 맞는 벡터를 생성해내도록 activation function 과 linear layer 을 조합해준다.

# Generator 
G = nn.Sequential(
    nn.Linear(latent_size, hidden_size),
    nn.ReLU(),
    nn.Linear(hidden_size, hidden_size),
    nn.ReLU(),
    nn.Linear(hidden_size, image_size),
    nn.Tanh())

(cpu 혹은) gpu 를 할당해준다

# Device setting
D = D.to(device)
G = G.to(device)

Loss function, 최적화 방법을 지정해준다. 동일 Loss function 을 공유하고 있을 뿐, 각각 목적에 맞게 최적화를 해줄 것이기 때문에 자신의 model 에 해당하는 parameters 만 최적화시키겠다고 설정해준다.

criterion = nn.BCELoss()
d_optimizer = torch.optim.Adam(D.parameters(), lr=0.0002)
g_optimizer = torch.optim.Adam(G.parameters(), lr=0.0002)

pytorch 에셔는 잊지 말고 zero_grad() 를 계속 시켜줘야하므로 미리 편하도록 함수 코드를 만들어주었다. denorm 코드도 미리 만들어주었다.

def denorm(x):
    out = (x + 1) / 2
    return out.clamp(0, 1)

def reset_grad():
    d_optimizer.zero_grad()
    g_optimizer.zero_grad()

에폭 넣어주고 for 문 돌려서 Discriminator, generator 훈련시키기

total_step = len(data_loader)
for epoch in range(num_epochs):
    for i, (images, _) in enumerate(data_loader):
        images = images.reshape(batch_size, -1).to(device)
        
        # Create the labels which are later used as input for the BCE loss
        real_labels = torch.ones(batch_size, 1).to(device)
        fake_labels = torch.zeros(batch_size, 1).to(device)

스크린샷 2021-06-13 오후 12 36 00

먼저 Discriminator 훈련 과정을 살펴보면,

outputs = D(images)
d_loss_real = criterion(outputs, real_labels)
real_score = outputs
        
# Compute BCELoss using fake images
# First term of the loss is always zero since fake_labels == 0
z = torch.randn(batch_size, latent_size).to(device)
fake_images = G(z)
outputs = D(fake_images)
d_loss_fake = criterion(outputs, fake_labels)
fake_score = outputs
        
# Backprop and optimize
d_loss = d_loss_real + d_loss_fake
reset_grad()
d_loss.backward()
d_optimizer.step()

Generator 훈련 과정을 살펴보면,

스크린샷 2021-06-13 오후 12 40 48

# Compute loss with fake images
z = torch.randn(batch_size, latent_size).to(device)
fake_images = G(z)
outputs = D(fake_images)
        
# We train G to maximize log(D(G(z)) instead of minimizing log(1-D(G(z)))
# For the reason, see the last paragraph of section 3. https://arxiv.org/pdf/1406.2661.pdf
g_loss = criterion(outputs, real_labels)
        
# Backprop and optimize
reset_grad()
g_loss.backward()
g_optimizer.step()
        if (i+1) % 200 == 0:
            print('Epoch [{}/{}], Step [{}/{}], d_loss: {:.4f}, g_loss: {:.4f}, D(x): {:.2f}, D(G(z)): {:.2f}' 
                  .format(epoch, num_epochs, i+1, total_step, d_loss.item(), g_loss.item(), 
                          real_score.mean().item(), fake_score.mean().item()))