출처 : 동빈나님 유튜브
star gan
: 다중 도메인에서 이미지 변환을 가능하게 한 단일 (=Star🌟) 모델(=GAN)이다
STEP 1. pix2pix 복습
pix2pix 에서 이미지의 특징을 라벨링하여서 넣어준다는 것이 조금 햇갈렸는데 pix2pix 도 cGAN 의 범주에 속하고, 그래서 condition 에 이런 특징을 담은 이미지라는 이미지를 넣어주는 것이라고 다시 이해하게 되었다.
즉 mnist 에서 “7의” 이미지를 만들어줘 가 pix2pix 에서는 “구두의” 이미지를 만들어줘 와 같은 원리인 것이다
“구두의” 를 설명하기 위해서 구두의 특징이 라벨링된 조건 정보를 넣어준다
StarGAN
하나의 Generator -> a 도메인
-> b 도메인
-> c 도메인
-> d 도메인
-> e 도메인
1) discriminator 의 수행 과제 2개
- 이미지가 진짜인지 가짜인지 (binary classification)
- domain classification
2) generator 학습 과정
-
input 형식 : Target domain (변환하고자 하는 도메인) , 인풋 이미지
항상 이 2개가 input 으로 들어가야 함 !
2-1) 인풋 (; 변환하고자 하는 도메인 b,원본 이미지) -> generator -> 도메인 b 로 변환된 이미지
2-2) 인풋 (; 다시 복원하고자 하는 원 도메인 a, 2-1 을 통해 b 도메인으로 변환된 이미지) -> generator
-> a 도메인의 원본 이미지 (;reconstructed image)
=> 즉 2-1, 2-2 과정은 cycle !
2-3) 2-1 에서 변환된 이미지는 또한 판별자에 들어가서 판별자를 속이도록 학습.
최종 loss
여기서 c’ 는 리얼 이미지의 도메인 a 를 의미
Mask Vector
domain class 에 mask vector 을 포함할 수 있다
mask vector 은 여러 데이터 셋 중 어떤 데이터 셋을 사용할지 명시하는 벡터이다
예를 들어, 해당 데이터 셋 CelebA 를 사용한다고 하면 속성은 흑발,금발,갈색,남성,어리다 가 있을 것이고
(1 은 예, 0 은 아니요)
[흑발인지 1/0,금발인지 1/0,갈색머리인지 1/0,남성인지 1/0,어린지 1/0,10(RaFD가 아니라 CelebA 이므로)]
이렇게 domain class vector 가 구성되게 된다.
파이토치 코드
출처 : 최윤제님 깃허브
꼼꼼히 보지는 못했는데 (나중에 식곤증 없는 타이밍에 다시 봐야겠다 !) x 이미지와 c 도메인 벡터를 같이 받아서 shape 을 정리해주고 concat 한 후 차례로 generator ,discriminator 모델에 들어가는 구조인 것 같다
모듈 import
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
모델에 계속해서 사용할 residual block 미리 만들어두기
class ResidualBlock(nn.Module):
"""Residual Block with instance normalization."""
def __init__(self, dim_in, dim_out):
super(ResidualBlock, self).__init__()
self.main = nn.Sequential(
nn.Conv2d(dim_in, dim_out, kernel_size=3, stride=1, padding=1, bias=False),
nn.InstanceNorm2d(dim_out, affine=True, track_running_stats=True),
nn.ReLU(inplace=True),
nn.Conv2d(dim_out, dim_out, kernel_size=3, stride=1, padding=1, bias=False),
nn.InstanceNorm2d(dim_out, affine=True, track_running_stats=True))
def forward(self, x):
return x + self.main(x)
Generator
class Generator(nn.Module):
"""Generator network."""
def __init__(self, conv_dim=64, c_dim=5, repeat_num=6):
super(Generator, self).__init__()
layers = []
layers.append(nn.Conv2d(3+c_dim, conv_dim, kernel_size=7, stride=1, padding=3, bias=False))
layers.append(nn.InstanceNorm2d(conv_dim, affine=True, track_running_stats=True))
layers.append(nn.ReLU(inplace=True))
# Down-sampling layers.
curr_dim = conv_dim
for i in range(2):
layers.append(nn.Conv2d(curr_dim, curr_dim*2, kernel_size=4, stride=2, padding=1, bias=False))
layers.append(nn.InstanceNorm2d(curr_dim*2, affine=True, track_running_stats=True))
layers.append(nn.ReLU(inplace=True))
curr_dim = curr_dim * 2
# Bottleneck layers.
for i in range(repeat_num):
layers.append(ResidualBlock(dim_in=curr_dim, dim_out=curr_dim))
# Up-sampling layers.
for i in range(2):
layers.append(nn.ConvTranspose2d(curr_dim, curr_dim//2, kernel_size=4, stride=2, padding=1, bias=False))
layers.append(nn.InstanceNorm2d(curr_dim//2, affine=True, track_running_stats=True))
layers.append(nn.ReLU(inplace=True))
curr_dim = curr_dim // 2
layers.append(nn.Conv2d(curr_dim, 3, kernel_size=7, stride=1, padding=3, bias=False))
layers.append(nn.Tanh())
self.main = nn.Sequential(*layers)
def forward(self, x, c):
# Replicate spatially and concatenate domain information.
# Note that this type of label conditioning does not work at all if we use reflection padding in Conv2d.
# This is because instance normalization ignores the shifting (or bias) effect.
c = c.view(c.size(0), c.size(1), 1, 1)
c = c.repeat(1, 1, x.size(2), x.size(3))
x = torch.cat([x, c], dim=1)
return self.main(x)
Discriminator
class Discriminator(nn.Module):
"""Discriminator network with PatchGAN."""
def __init__(self, image_size=128, conv_dim=64, c_dim=5, repeat_num=6):
super(Discriminator, self).__init__()
layers = []
layers.append(nn.Conv2d(3, conv_dim, kernel_size=4, stride=2, padding=1))
layers.append(nn.LeakyReLU(0.01))
curr_dim = conv_dim
for i in range(1, repeat_num):
layers.append(nn.Conv2d(curr_dim, curr_dim*2, kernel_size=4, stride=2, padding=1))
layers.append(nn.LeakyReLU(0.01))
curr_dim = curr_dim * 2
kernel_size = int(image_size / np.power(2, repeat_num))
self.main = nn.Sequential(*layers)
self.conv1 = nn.Conv2d(curr_dim, 1, kernel_size=3, stride=1, padding=1, bias=False)
self.conv2 = nn.Conv2d(curr_dim, c_dim, kernel_size=kernel_size, bias=False)
def forward(self, x):
h = self.main(x)
out_src = self.conv1(h)
out_cls = self.conv2(h)
return out_src, out_cls.view(out_cls.size(0), out_cls.size(1))