출처 : 같이 gan 과기대 대회에 나가는 팀원 승원님과 해당 유튜브 링크의 친절한 설명 덕분에 잘 이해할 수 있었습니다. 🙇‍♀️ CycleGan 논문 리뷰-이재환 님

cycle gan ⛳️

: cycle gan 은 이미지를 각기 다른 2개의 도메인에 맞게 변환시켜주는 모델이다.

해당 그림이 사진이었다면 어떻게 나올까 ?

ex1 . input ) A 도메인 (그림) -> B 도메인 (사진)

해당 말이 얼룩말이었다면 어떻게 생겼을까 ?

ex2 . input ) A 도메인 (말) -> B 도메인 (얼룩말)

물론 B 도메인을 input 으로 넣어서 A 도메인으로 변환할 수도 있다.

STEP 1. pix2pix

선택 영역_193

input ) pixel 별로 특징을 라벨링한 데이터 (ex. 창문, 문 등이 라벨링 되어있음) ->

output ) 이 데이터를 바탕으로 생성된 사진 (라벨링된 부분이 특정 픽셀로 생성되었을 것임)

loss : ||y - G(x) ||

하지만 실제 정답에 비해 명확하지 않은 데이터가 생성되는 한계를 보이는 데, 라벨링 된 부분의 구체적 픽셀 값을 추측하기 어려울 때 중간 값 정도로 생성하기 때문

ex. 흑백 사진에서 새를 라벨링한 데이터 -> 실제 픽셀이 무슨 색일지 추측하기 어렵기 때문에 애매한 중간값을 색으로 가지는 새를 생성 (하지만 실제 새가 채도가 높은 진한 노란색이었다면 ?! 문제가 된다)

STEP 2. pix2pix 에 GAN loss 결합

Gan loss 리마인드

선택 영역_194

저번 설명과 바뀐 점 : 해당 식에선 D(z)=1 이 fake 이다.

pix2pix 에 GAN loss 결합

선택 영역_195

CycleGAN

Gan loss : 진짜 같은 가짜 이미지를 생성

하지만

해당 말이 얼룩말이었다면 어떻게 생겼을까 ? 에 Gan loss 를 적용하면

어떤 말이든 상관없이 모두 똑같은 얼룩말 이미지로 바뀔 수가 있다. 애초에 gan 은 임의의 G(x)==y 로 만드는 것이 목적이므로 input x 이미지에 상관없이 y 값 이미지로 만드는 데에만 집중하기 때문이다.

G(x1)==y

G(x2)==y

인 상황 발생 !

그래서 보완된 CycleGan 의 loss

선택 영역_196

1) G(x1),G(x2) == y ⛳️ : a 도메인의 이미지를 최대한 진짜같은 b 도메인의 이미지로 만들고자 함

여기선 아직 다른 input 이 들어가도 일단 리얼한 b 도메인의 이미지로 변환하는 것만 목표

2) || F(G(x)) -x || ⛳️ : b 도메인으로 변환된 이미지를 다시 a 도메인으로 옮겼을 때 최대한 원래 그대로가 나오도록 함

input 에 따라서 b 도메인 내에서도 각기 다른 이미지로 생성되로록 함

G(x1) -> y1 -> F(y1) == x1

G(x2) -> y2 -> F(y2) == x2

마찬가지로 || G(F(y)) -y || loss 도 적용 !

F(y1) -> x1 -> G(x1) == y1

F(y2) -> x2 -> G(x2) == y2

선택 영역_197

최종 정리해보면 총 loss 4개 (gan loss 2개, 구체적으로 특정 이미지로 바꾸게 하는 loss 2개)

1) Gan loss 2개 (optional : ls gan loss)

G(x1),G(x2) == y ⛳️ : a 도메인의 이미지를 최대한 진짜같은 b 도메인의 이미지로 만들고자 함

G(y1),G(y2) == x ⛳️ : b 도메인의 이미지를 최대한 진짜같은 a 도메인의 이미지로 만들고자 함

2) 특정 이미지로 바꾸게 하는 (=cycle) loss 2개

G(x1) -> y1 -> F(y1) == x1

G(x2) -> y2 -> F(y2) == x2

F(y1) -> x1 -> G(x1) == y1

F(y2) -> x2 -> G(x2) == y2

Details

Generator 모델 구성 : ResNet

Gan loss -> LS gan loss (ls gan 도 언젠가 정리해야겠다 ! 😅)

추가적인 loss : identity loss

identity loss ; cycle loss 와 더불어 더 디테일을 살리기 위한 loss

G(y1) -> y1 과 얼마나 비슷할지

G(y2) -> y2 와 얼마나 비슷할지

G(x1) -> x1 과 얼마나 비슷할지

G(x2) -> x2 와 얼마나 비슷할지

최대한 디테일까지 잡는 것이 목표 !

선택 영역_198

Cycle Gan 이 만능은 아니다 !

cycle gan 으로 모양을 바꾸는 것은 아직 한계

확률 분포가 작은 이미지 변환에는 취약

Cycle Gan 코드 (텐서플로우)

모듈 import , init

  • RandomNormal : tf.keras.initializers.RandomNormal는 표준분포를 갖는 텐서 (tensor)를 생성합니다.

    시냅스 가중치 적용하기

  • lambda_ validation : gan loss 에 적용할 가중치
  • lambda_reconstr : cycle loss 에 적용할 가중치
  • lambda_id : identity loss 에 적용할 가중치
  • gen_n_filters : generator 첫 번째 layer 에 들어갈 필터 수
  • disc_n_filters : discriminator 첫 번째 layer 에 들어갈 필터 수
  • deque : 파이썬에서 제공하는 큐 자료구조, maxlen ; maxlen=n은 deque의 최대 길이를 n으로 제한합니다.
from __future__ import print_function, division
import scipy

from keras_contrib.layers.normalization.instancenormalization import InstanceNormalization
from keras.layers import Input, Dense, Reshape, Flatten, Dropout, Concatenate
from keras.layers import BatchNormalization, Activation, ZeroPadding2D, Add
from keras.layers.advanced_activations import LeakyReLU, ELU
from keras.layers.convolutional import UpSampling2D, Conv2D, Conv2DTranspose
from keras.layers.merge import add
from models.layers.layers import ReflectionPadding2D
from keras.models import Sequential, Model
from keras.initializers import RandomNormal
from keras.optimizers import Adam
from keras import backend as K

from keras.utils import plot_model

import datetime
import matplotlib.pyplot as plt
import sys

import numpy as np
import os
import pickle as pkl
import random

from collections import deque

매개변수로 넣어줄 애들 (하이퍼파라미터) :

  • input_dim
  • learning_rate
  • lambda_validation
  • lambda_reconstr
  • lambda_id
  • generator_type
  • gen_n_filters
  • disc_n_filters
  • buffer_max_length
class CycleGAN():
    def __init__(self
        , input_dim
        , learning_rate
        , lambda_validation
        , lambda_reconstr
        , lambda_id
        , generator_type
        , gen_n_filters
        , disc_n_filters
        , buffer_max_length = 50
        ):
        self.input_dim = input_dim
        self.learning_rate = learning_rate
        self.buffer_max_length = buffer_max_length
        self.lambda_validation = lambda_validation
        self.lambda_reconstr = lambda_reconstr
        self.lambda_id = lambda_id
        self.generator_type = generator_type
        self.gen_n_filters = gen_n_filters
        self.disc_n_filters = disc_n_filters

input shape

 		self.img_rows = input_dim[0]
        self.img_cols = input_dim[1]
        self.channels = input_dim[2]
        self.img_shape = (self.img_rows, self.img_cols, self.channels)

loss 를 담아줄 변수, 버퍼

 		self.d_losses = []
        self.g_losses = []
        self.epoch = 0

        self.buffer_A = deque(maxlen = self.buffer_max_length)
        self.buffer_B = deque(maxlen = self.buffer_max_length)
        

patch gan

 # Calculate output shape of D (PatchGAN)
        patch = int(self.img_rows / 2**3)
        self.disc_patch = (patch, patch, 1)

weight initialization

self.weight_init = RandomNormal(mean=0., stddev=0.02)

다 되었으면 compile

self.compile_models()

compile_models

G(x1),G(x2) == y ⛳️ : a 도메인의 이미지를 최대한 진짜같은 b 도메인의 이미지로 만들고자 함

사용 모델 : (self.d_B , self.g_AB)

G(y1),G(y2) == x ⛳️ : b 도메인의 이미지를 최대한 진짜같은 a 도메인의 이미지로 만들고자 함

사용 모델 : (self.d_A , self.g_BA)

LS Gan 이므로 loss 는 mse

def compile_models(self):

        # Build and compile the discriminators
        self.d_A = self.build_discriminator()
        self.d_B = self.build_discriminator()
        
        self.d_A.compile(loss='mse',
            optimizer=Adam(self.learning_rate, 0.5),
            metrics=['accuracy'])
        self.d_B.compile(loss='mse',
            optimizer=Adam(self.learning_rate, 0.5),
            metrics=['accuracy'])

# Build the generators
        if self.generator_type == 'unet':
            self.g_AB = self.build_generator_unet()
            self.g_BA = self.build_generator_unet()
        else:
            self.g_AB = self.build_generator_resnet()
            self.g_BA = self.build_generator_resnet()

generator 학습할 때 discriminator 꺼놓기

# For the combined model we will only train the generators
        self.d_A.trainable = False
        self.d_B.trainable = False

input 넣어주기

 # Input images from both domains
        img_A = Input(shape=self.img_shape)
        img_B = Input(shape=self.img_shape)

모델에서 나온 이미지 변수에 할당

G(x1) -> y1 -> F(y1) == x1

G(x2) -> y2 -> F(y2) == x2 ; reconstr_A

F(y1) -> x1 -> G(x1) == y1

F(y2) -> x2 -> G(x2) == y2 ; reconstr_B

G(y1) -> y1 과 얼마나 비슷할지

G(y2) -> y2 와 얼마나 비슷할지 ; img_B_id

G(x1) -> x1 과 얼마나 비슷할지

G(x2) -> x2 와 얼마나 비슷할지 ; img_A_id

 # For the combined model we will only train the generators
        self.d_A.trainable = False
        self.d_B.trainable = False

        # Input images from both domains
        img_A = Input(shape=self.img_shape)
        img_B = Input(shape=self.img_shape)

        # Translate images to the other domain
        fake_B = self.g_AB(img_A)
        fake_A = self.g_BA(img_B)
        # Translate images back to original domain
        reconstr_A = self.g_BA(fake_B)
        reconstr_B = self.g_AB(fake_A)
        # Identity mapping of images
        img_A_id = self.g_BA(img_A)
        img_B_id = self.g_AB(img_B)

        # Discriminators determines validity of translated images
        valid_A = self.d_A(fake_A)
        valid_B = self.d_B(fake_B)

model compile

 # Combined model trains generators to fool discriminators
 self.combined = Model(inputs=[img_A, img_B],
                              outputs=[ valid_A, valid_B,
                                        reconstr_A, reconstr_B,
                                        img_A_id, img_B_id ])
 self.combined.compile(loss=['mse', 'mse',
                                    'mae', 'mae',
                                    'mae', 'mae'],
                            loss_weights=[  self.lambda_validation,                       self.lambda_validation,
                                            self.lambda_reconstr, self.lambda_reconstr,
                                            self.lambda_id, self.lambda_id ],
                            optimizer=Adam(0.0002, 0.5))

compile 다 했으므로 discriminator 다시 켜주기

self.d_A.trainable = True
self.d_B.trainable = True

build generator

  • UpSampling2D: 단순히 이미지의 가로, 세로 크기를 2배씩 늘려주는 기능을 한다

    Nearest Neighbor : Dense 데이터를 그대로 늘려서, 빈 구역에 채워넣는 것입니다.

    1   2  
           
    3   4  
           

    ->

    1 1 2 2
    1 1 2 2
    3 3 4 4
    3 3 4 4

    참고 : upsampling 간단 설명

  • Concatenate: default axis=-1

U-Net 구조는 다음과 같다고 한다. 뭐 이렇게 해보면 이런 구조대로 잘 되지 않을까 ㅋㅋ😅

선택 영역_199

def build_generator_unet(self):

        def downsample(layer_input, filters, f_size=4):
            d = Conv2D(filters, kernel_size=f_size, strides=2, padding='same')(layer_input)
            d = InstanceNormalization(axis = -1, center = False, scale = False)(d)
            d = Activation('relu')(d)
            
            return d

        def upsample(layer_input, skip_input, filters, f_size=4, dropout_rate=0):
            u = UpSampling2D(size=2)(layer_input)
            u = Conv2D(filters, kernel_size=f_size, strides=1, padding='same')(u)
            u = InstanceNormalization(axis = -1, center = False, scale = False)(u)
            u = Activation('relu')(u)
            if dropout_rate:
                u = Dropout(dropout_rate)(u)

            u = Concatenate()([u, skip_input])
            return u

        # Image input
        img = Input(shape=self.img_shape)

        # Downsampling
        d1 = downsample(img, self.gen_n_filters) 
        d2 = downsample(d1, self.gen_n_filters*2)
        d3 = downsample(d2, self.gen_n_filters*4)
        d4 = downsample(d3, self.gen_n_filters*8)

        # Upsampling
        u1 = upsample(d4, d3, self.gen_n_filters*4)
        u2 = upsample(u1, d2, self.gen_n_filters*2)
        u3 = upsample(u2, d1, self.gen_n_filters)

        u4 = UpSampling2D(size=2)(u3)
        output_img = Conv2D(self.channels, kernel_size=4, strides=1, padding='same', activation='tanh')(u4)

        return Model(img, output_img)

resnet 으로 만든 generator 은 생략하였다 😅 집에 가고 싶다 ㅠ

build discriminator

discriminator 은 훨씬 더 간단하다

선택 영역_200

def build_discriminator(self):

        def conv4(layer_input,filters, stride = 2, norm=True):
            y = Conv2D(filters, kernel_size=(4,4), strides=stride, padding='same', kernel_initializer = self.weight_init)(layer_input)
            
            if norm:
                y = InstanceNormalization(axis = -1, center = False, scale = False)(y)

            y = LeakyReLU(0.2)(y)
           
            return y

        img = Input(shape=self.img_shape)

        y = conv4(img, self.disc_n_filters, stride = 2, norm = False)
        y = conv4(y, self.disc_n_filters*2, stride = 2)
        y = conv4(y, self.disc_n_filters*4, stride = 2)
        y = conv4(y, self.disc_n_filters*8, stride = 1)

        output = Conv2D(1, kernel_size=4, strides=1, padding='same',kernel_initializer = self.weight_init)(y)

        return Model(img, output)

train discriminator

buffer : “너가 예전에 이렇게 말했었어 !” discriminator 을 더 잘 훈련시키기 위한 과거 기억 저장

  • self.buffer_B : generator 이 A 도메인 이미지를 받아서 B 도메인 이미지로 변환한 것의 모음

    fake_B = self.g_AB.predict(imgs_A)

  • self.buffer_A : generator 이 B 도메인 이미지를 받아서 A 도메인 이미지로 변환한 것의 모음

    fake_A = self.g_BA.predict(imgs_B)

  • random.sample(컬렉션,샘플 수): random 모듈에서 sample(컬렉션, 샘플수) 함수는 지정된 컬렉션으로부터 샘플수만큼 랜덤 추출을 하는 함수

  • model.train_on_batch : as the name implies, trains only one batch

    Returns

    Scalar training loss (if the model has a single output and no metrics) or list of scalars (if the model has multiple outputs and/or metrics).

     def train_discriminators(self, imgs_A, imgs_B, valid, fake):
      
            # Translate images to opposite domain
            fake_B = self.g_AB.predict(imgs_A)
            fake_A = self.g_BA.predict(imgs_B)
      
            self.buffer_B.append(fake_B)
            self.buffer_A.append(fake_A)
      
            fake_A_rnd = random.sample(self.buffer_A, min(len(self.buffer_A), len(imgs_A)))
            fake_B_rnd = random.sample(self.buffer_B, min(len(self.buffer_B), len(imgs_B)))
      
            # Train the discriminators (original images = real / translated = Fake)
            dA_loss_real = self.d_A.train_on_batch(imgs_A, valid)
            dA_loss_fake = self.d_A.train_on_batch(fake_A_rnd, fake)
            dA_loss = 0.5 * np.add(dA_loss_real, dA_loss_fake)
      
            dB_loss_real = self.d_B.train_on_batch(imgs_B, valid)
            dB_loss_fake = self.d_B.train_on_batch(fake_B_rnd, fake)
            dB_loss = 0.5 * np.add(dB_loss_real, dB_loss_fake)
      
            # Total disciminator loss
            d_loss_total = 0.5 * np.add(dA_loss, dB_loss)
            #[0] : loss , [1] : accuracy
            return (
                d_loss_total[0]
                , dA_loss[0], dA_loss_real[0], dA_loss_fake[0]
                , dB_loss[0], dB_loss_real[0], dB_loss_fake[0]
                , d_loss_total[1]
                , dA_loss[1], dA_loss_real[1], dA_loss_fake[1]
                , dB_loss[1], dB_loss_real[1], dB_loss_fake[1]
            )
      
    

train generator

outputs=[ valid_A, valid_B,reconstr_A, reconstr_B,img_A_id, img_B_id ]

이었으므로 loss 6개 나옴 ! gan loss 2개, cycle loss 2개, identity loss 2개

def train_generators(self, imgs_A, imgs_B, valid):

        # Train the generators
        return self.combined.train_on_batch([imgs_A, imgs_B],
                                                [valid, valid,
                                                imgs_A, imgs_B,
                                                imgs_A, imgs_B])

최종 train 함수

def train(self, data_loader, run_folder, epochs, test_A_file, test_B_file, batch_size=1, sample_interval=50):

        start_time = datetime.datetime.now()

        # Adversarial loss ground truths
        valid = np.ones((batch_size,) + self.disc_patch)
        fake = np.zeros((batch_size,) + self.disc_patch)

        for epoch in range(self.epoch, epochs):
            for batch_i, (imgs_A, imgs_B) in enumerate(data_loader.load_batch()):

                d_loss = self.train_discriminators(imgs_A, imgs_B, valid, fake)
                g_loss = self.train_generators(imgs_A, imgs_B, valid)

                elapsed_time = datetime.datetime.now() - start_time

                # Plot the progress
                print ("[Epoch %d/%d] [Batch %d/%d] [D loss: %f, acc: %3d%%] [G loss: %05f, adv: %05f, recon: %05f, id: %05f] time: %s " \
                    % ( self.epoch, epochs,
                        batch_i, data_loader.n_batches,
                        d_loss[0], 100*d_loss[7],
                        g_loss[0],
                        np.sum(g_loss[1:3]),
                        np.sum(g_loss[3:5]),
                        np.sum(g_loss[5:7]),
                        elapsed_time))

                self.d_losses.append(d_loss)
                self.g_losses.append(g_loss)

                # If at save interval => save generated image samples
                if batch_i % sample_interval == 0:
                    self.sample_images(data_loader, batch_i, run_folder, test_A_file, test_B_file)
                    self.combined.save_weights(os.path.join(run_folder, 'weights/weights-%d.h5' % (self.epoch)))
                    self.combined.save_weights(os.path.join(run_folder, 'weights/weights.h5'))
                    self.save_model(run_folder)

                
            self.epoch += 1

모델 구성 코드 보기도 버거워서 그 이후 부가적인 save 나 이미지 plotting 부분은 생략하였다 ! 🤪