출처 : 동빈나님 유튜브

star gan

: 다중 도메인에서 이미지 변환을 가능하게 한 단일 (=Star🌟) 모델(=GAN)이다

STEP 1. pix2pix 복습

pix2pix 에서 이미지의 특징을 라벨링하여서 넣어준다는 것이 조금 햇갈렸는데 pix2pix 도 cGAN 의 범주에 속하고, 그래서 condition 에 이런 특징을 담은 이미지라는 이미지를 넣어주는 것이라고 다시 이해하게 되었다.

즉 mnist 에서 “7의” 이미지를 만들어줘 가 pix2pix 에서는 “구두의” 이미지를 만들어줘 와 같은 원리인 것이다

“구두의” 를 설명하기 위해서 구두의 특징이 라벨링된 조건 정보를 넣어준다

선택 영역_202

선택 영역_203

StarGAN

선택 영역_204

하나의 Generator -> a 도메인

​ -> b 도메인

​ -> c 도메인

​ -> d 도메인

​ -> e 도메인

선택 영역_205

1) discriminator 의 수행 과제 2개

  • 이미지가 진짜인지 가짜인지 (binary classification)
  • domain classification

2) generator 학습 과정

  • input 형식 : Target domain (변환하고자 하는 도메인) , 인풋 이미지

    항상 이 2개가 input 으로 들어가야 함 !

    2-1) 인풋 (; 변환하고자 하는 도메인 b,원본 이미지) -> generator -> 도메인 b 로 변환된 이미지

    2-2) 인풋 (; 다시 복원하고자 하는 원 도메인 a, 2-1 을 통해 b 도메인으로 변환된 이미지) -> generator

    -> a 도메인의 원본 이미지 (;reconstructed image)

    => 즉 2-1, 2-2 과정은 cycle !

    2-3) 2-1 에서 변환된 이미지는 또한 판별자에 들어가서 판별자를 속이도록 학습.

최종 loss

선택 영역_206

여기서 c’ 는 리얼 이미지의 도메인 a 를 의미

Mask Vector

domain class 에 mask vector 을 포함할 수 있다

mask vector 은 여러 데이터 셋 중 어떤 데이터 셋을 사용할지 명시하는 벡터이다

예를 들어, 해당 데이터 셋 CelebA 를 사용한다고 하면 속성은 흑발,금발,갈색,남성,어리다 가 있을 것이고

(1 은 예, 0 은 아니요)

[흑발인지 1/0,금발인지 1/0,갈색머리인지 1/0,남성인지 1/0,어린지 1/0,10(RaFD가 아니라 CelebA 이므로)]

이렇게 domain class vector 가 구성되게 된다.

선택 영역_207

선택 영역_208

파이토치 코드

출처 : 최윤제님 깃허브

꼼꼼히 보지는 못했는데 (나중에 식곤증 없는 타이밍에 다시 봐야겠다 !) x 이미지와 c 도메인 벡터를 같이 받아서 shape 을 정리해주고 concat 한 후 차례로 generator ,discriminator 모델에 들어가는 구조인 것 같다

모듈 import

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

모델에 계속해서 사용할 residual block 미리 만들어두기

class ResidualBlock(nn.Module):
    """Residual Block with instance normalization."""
    def __init__(self, dim_in, dim_out):
        super(ResidualBlock, self).__init__()
        self.main = nn.Sequential(
            nn.Conv2d(dim_in, dim_out, kernel_size=3, stride=1, padding=1, bias=False),
            nn.InstanceNorm2d(dim_out, affine=True, track_running_stats=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(dim_out, dim_out, kernel_size=3, stride=1, padding=1, bias=False),
            nn.InstanceNorm2d(dim_out, affine=True, track_running_stats=True))

    def forward(self, x):
        return x + self.main(x)

Generator

class Generator(nn.Module):
    """Generator network."""
    def __init__(self, conv_dim=64, c_dim=5, repeat_num=6):
        super(Generator, self).__init__()

        layers = []
        layers.append(nn.Conv2d(3+c_dim, conv_dim, kernel_size=7, stride=1, padding=3, bias=False))
        layers.append(nn.InstanceNorm2d(conv_dim, affine=True, track_running_stats=True))
        layers.append(nn.ReLU(inplace=True))

        # Down-sampling layers.
        curr_dim = conv_dim
        for i in range(2):
            layers.append(nn.Conv2d(curr_dim, curr_dim*2, kernel_size=4, stride=2, padding=1, bias=False))
            layers.append(nn.InstanceNorm2d(curr_dim*2, affine=True, track_running_stats=True))
            layers.append(nn.ReLU(inplace=True))
            curr_dim = curr_dim * 2

        # Bottleneck layers.
        for i in range(repeat_num):
            layers.append(ResidualBlock(dim_in=curr_dim, dim_out=curr_dim))

        # Up-sampling layers.
        for i in range(2):
            layers.append(nn.ConvTranspose2d(curr_dim, curr_dim//2, kernel_size=4, stride=2, padding=1, bias=False))
            layers.append(nn.InstanceNorm2d(curr_dim//2, affine=True, track_running_stats=True))
            layers.append(nn.ReLU(inplace=True))
            curr_dim = curr_dim // 2

        layers.append(nn.Conv2d(curr_dim, 3, kernel_size=7, stride=1, padding=3, bias=False))
        layers.append(nn.Tanh())
        self.main = nn.Sequential(*layers)

    def forward(self, x, c):
        # Replicate spatially and concatenate domain information.
        # Note that this type of label conditioning does not work at all if we use reflection padding in Conv2d.
        # This is because instance normalization ignores the shifting (or bias) effect.
        c = c.view(c.size(0), c.size(1), 1, 1)
        c = c.repeat(1, 1, x.size(2), x.size(3))
        x = torch.cat([x, c], dim=1)
        return self.main(x)

Discriminator

class Discriminator(nn.Module):
    """Discriminator network with PatchGAN."""
    def __init__(self, image_size=128, conv_dim=64, c_dim=5, repeat_num=6):
        super(Discriminator, self).__init__()
        layers = []
        layers.append(nn.Conv2d(3, conv_dim, kernel_size=4, stride=2, padding=1))
        layers.append(nn.LeakyReLU(0.01))

        curr_dim = conv_dim
        for i in range(1, repeat_num):
            layers.append(nn.Conv2d(curr_dim, curr_dim*2, kernel_size=4, stride=2, padding=1))
            layers.append(nn.LeakyReLU(0.01))
            curr_dim = curr_dim * 2

        kernel_size = int(image_size / np.power(2, repeat_num))
        self.main = nn.Sequential(*layers)
        self.conv1 = nn.Conv2d(curr_dim, 1, kernel_size=3, stride=1, padding=1, bias=False)
        self.conv2 = nn.Conv2d(curr_dim, c_dim, kernel_size=kernel_size, bias=False)
        
    def forward(self, x):
        h = self.main(x)
        out_src = self.conv1(h)
        out_cls = self.conv2(h)
        return out_src, out_cls.view(out_cls.size(0), out_cls.size(1))