출처 : 같이 gan 과기대 대회에 나가는 팀원 승원님과 해당 유튜브 링크의 친절한 설명 덕분에 잘 이해할 수 있었습니다. 🙇♀️ CycleGan 논문 리뷰-이재환 님
cycle gan ⛳️
: cycle gan 은 이미지를 각기 다른 2개의 도메인에 맞게 변환시켜주는 모델이다.
해당 그림이 사진이었다면 어떻게 나올까 ?
ex1 . input ) A 도메인 (그림) -> B 도메인 (사진)
해당 말이 얼룩말이었다면 어떻게 생겼을까 ?
ex2 . input ) A 도메인 (말) -> B 도메인 (얼룩말)
물론 B 도메인을 input 으로 넣어서 A 도메인으로 변환할 수도 있다.
STEP 1. pix2pix
input ) pixel 별로 특징을 라벨링한 데이터 (ex. 창문, 문 등이 라벨링 되어있음) ->
output ) 이 데이터를 바탕으로 생성된 사진 (라벨링된 부분이 특정 픽셀로 생성되었을 것임)
loss : ||y - G(x) ||
하지만 실제 정답에 비해 명확하지 않은 데이터가 생성되는 한계를 보이는 데, 라벨링 된 부분의 구체적 픽셀 값을 추측하기 어려울 때 중간 값 정도로 생성하기 때문
ex. 흑백 사진에서 새를 라벨링한 데이터 -> 실제 픽셀이 무슨 색일지 추측하기 어렵기 때문에 애매한 중간값을 색으로 가지는 새를 생성 (하지만 실제 새가 채도가 높은 진한 노란색이었다면 ?! 문제가 된다)
STEP 2. pix2pix 에 GAN loss 결합
Gan loss 리마인드
저번 설명과 바뀐 점 : 해당 식에선 D(z)=1 이 fake 이다.
pix2pix 에 GAN loss 결합
CycleGAN
Gan loss : 진짜 같은 가짜 이미지를 생성
하지만
해당 말이 얼룩말이었다면 어떻게 생겼을까 ? 에 Gan loss 를 적용하면
어떤 말이든 상관없이 모두 똑같은 얼룩말 이미지로 바뀔 수가 있다. 애초에 gan 은 임의의 G(x)==y 로 만드는 것이 목적이므로 input x 이미지에 상관없이 y 값 이미지로 만드는 데에만 집중하기 때문이다.
G(x1)==y
G(x2)==y
인 상황 발생 !
그래서 보완된 CycleGan 의 loss
1) G(x1),G(x2) == y ⛳️ : a 도메인의 이미지를 최대한 진짜같은 b 도메인의 이미지로 만들고자 함
여기선 아직 다른 input 이 들어가도 일단 리얼한 b 도메인의 이미지로 변환하는 것만 목표
2) || F(G(x)) -x || ⛳️ : b 도메인으로 변환된 이미지를 다시 a 도메인으로 옮겼을 때 최대한 원래 그대로가 나오도록 함
input 에 따라서 b 도메인 내에서도 각기 다른 이미지로 생성되로록 함
G(x1) -> y1 -> F(y1) == x1
G(x2) -> y2 -> F(y2) == x2
마찬가지로 || G(F(y)) -y || loss 도 적용 !
F(y1) -> x1 -> G(x1) == y1
F(y2) -> x2 -> G(x2) == y2
최종 정리해보면 총 loss 4개 (gan loss 2개, 구체적으로 특정 이미지로 바꾸게 하는 loss 2개)
1) Gan loss 2개 (optional : ls gan loss)
G(x1),G(x2) == y ⛳️ : a 도메인의 이미지를 최대한 진짜같은 b 도메인의 이미지로 만들고자 함
G(y1),G(y2) == x ⛳️ : b 도메인의 이미지를 최대한 진짜같은 a 도메인의 이미지로 만들고자 함
2) 특정 이미지로 바꾸게 하는 (=cycle) loss 2개
G(x1) -> y1 -> F(y1) == x1
G(x2) -> y2 -> F(y2) == x2
F(y1) -> x1 -> G(x1) == y1
F(y2) -> x2 -> G(x2) == y2
Details
Generator 모델 구성 : ResNet
Gan loss -> LS gan loss (ls gan 도 언젠가 정리해야겠다 ! 😅)
추가적인 loss : identity loss
identity loss ; cycle loss 와 더불어 더 디테일을 살리기 위한 loss
G(y1) -> y1 과 얼마나 비슷할지
G(y2) -> y2 와 얼마나 비슷할지
G(x1) -> x1 과 얼마나 비슷할지
G(x2) -> x2 와 얼마나 비슷할지
최대한 디테일까지 잡는 것이 목표 !
Cycle Gan 이 만능은 아니다 !
cycle gan 으로 모양을 바꾸는 것은 아직 한계
확률 분포가 작은 이미지 변환에는 취약
Cycle Gan 코드 (텐서플로우)
모듈 import , init
-
RandomNormal : tf.keras.initializers.RandomNormal는 표준분포를 갖는 텐서 (tensor)를 생성합니다.
- lambda_ validation : gan loss 에 적용할 가중치
- lambda_reconstr : cycle loss 에 적용할 가중치
- lambda_id : identity loss 에 적용할 가중치
- gen_n_filters : generator 첫 번째 layer 에 들어갈 필터 수
- disc_n_filters : discriminator 첫 번째 layer 에 들어갈 필터 수
- deque : 파이썬에서 제공하는 큐 자료구조, maxlen ; maxlen=n은 deque의 최대 길이를 n으로 제한합니다.
from __future__ import print_function, division
import scipy
from keras_contrib.layers.normalization.instancenormalization import InstanceNormalization
from keras.layers import Input, Dense, Reshape, Flatten, Dropout, Concatenate
from keras.layers import BatchNormalization, Activation, ZeroPadding2D, Add
from keras.layers.advanced_activations import LeakyReLU, ELU
from keras.layers.convolutional import UpSampling2D, Conv2D, Conv2DTranspose
from keras.layers.merge import add
from models.layers.layers import ReflectionPadding2D
from keras.models import Sequential, Model
from keras.initializers import RandomNormal
from keras.optimizers import Adam
from keras import backend as K
from keras.utils import plot_model
import datetime
import matplotlib.pyplot as plt
import sys
import numpy as np
import os
import pickle as pkl
import random
from collections import deque
매개변수로 넣어줄 애들 (하이퍼파라미터) :
- input_dim
- learning_rate
- lambda_validation
- lambda_reconstr
- lambda_id
- generator_type
- gen_n_filters
- disc_n_filters
- buffer_max_length
class CycleGAN():
def __init__(self
, input_dim
, learning_rate
, lambda_validation
, lambda_reconstr
, lambda_id
, generator_type
, gen_n_filters
, disc_n_filters
, buffer_max_length = 50
):
self.input_dim = input_dim
self.learning_rate = learning_rate
self.buffer_max_length = buffer_max_length
self.lambda_validation = lambda_validation
self.lambda_reconstr = lambda_reconstr
self.lambda_id = lambda_id
self.generator_type = generator_type
self.gen_n_filters = gen_n_filters
self.disc_n_filters = disc_n_filters
input shape
self.img_rows = input_dim[0]
self.img_cols = input_dim[1]
self.channels = input_dim[2]
self.img_shape = (self.img_rows, self.img_cols, self.channels)
loss 를 담아줄 변수, 버퍼
self.d_losses = []
self.g_losses = []
self.epoch = 0
self.buffer_A = deque(maxlen = self.buffer_max_length)
self.buffer_B = deque(maxlen = self.buffer_max_length)
patch gan
# Calculate output shape of D (PatchGAN)
patch = int(self.img_rows / 2**3)
self.disc_patch = (patch, patch, 1)
weight initialization
self.weight_init = RandomNormal(mean=0., stddev=0.02)
다 되었으면 compile
self.compile_models()
compile_models
G(x1),G(x2) == y ⛳️ : a 도메인의 이미지를 최대한 진짜같은 b 도메인의 이미지로 만들고자 함
사용 모델 : (self.d_B , self.g_AB)
G(y1),G(y2) == x ⛳️ : b 도메인의 이미지를 최대한 진짜같은 a 도메인의 이미지로 만들고자 함
사용 모델 : (self.d_A , self.g_BA)
LS Gan 이므로 loss 는 mse
def compile_models(self):
# Build and compile the discriminators
self.d_A = self.build_discriminator()
self.d_B = self.build_discriminator()
self.d_A.compile(loss='mse',
optimizer=Adam(self.learning_rate, 0.5),
metrics=['accuracy'])
self.d_B.compile(loss='mse',
optimizer=Adam(self.learning_rate, 0.5),
metrics=['accuracy'])
# Build the generators
if self.generator_type == 'unet':
self.g_AB = self.build_generator_unet()
self.g_BA = self.build_generator_unet()
else:
self.g_AB = self.build_generator_resnet()
self.g_BA = self.build_generator_resnet()
generator 학습할 때 discriminator 꺼놓기
# For the combined model we will only train the generators
self.d_A.trainable = False
self.d_B.trainable = False
input 넣어주기
# Input images from both domains
img_A = Input(shape=self.img_shape)
img_B = Input(shape=self.img_shape)
모델에서 나온 이미지 변수에 할당
G(x1) -> y1 -> F(y1) == x1
G(x2) -> y2 -> F(y2) == x2 ; reconstr_A
F(y1) -> x1 -> G(x1) == y1
F(y2) -> x2 -> G(x2) == y2 ; reconstr_B
G(y1) -> y1 과 얼마나 비슷할지
G(y2) -> y2 와 얼마나 비슷할지 ; img_B_id
G(x1) -> x1 과 얼마나 비슷할지
G(x2) -> x2 와 얼마나 비슷할지 ; img_A_id
# For the combined model we will only train the generators
self.d_A.trainable = False
self.d_B.trainable = False
# Input images from both domains
img_A = Input(shape=self.img_shape)
img_B = Input(shape=self.img_shape)
# Translate images to the other domain
fake_B = self.g_AB(img_A)
fake_A = self.g_BA(img_B)
# Translate images back to original domain
reconstr_A = self.g_BA(fake_B)
reconstr_B = self.g_AB(fake_A)
# Identity mapping of images
img_A_id = self.g_BA(img_A)
img_B_id = self.g_AB(img_B)
# Discriminators determines validity of translated images
valid_A = self.d_A(fake_A)
valid_B = self.d_B(fake_B)
model compile
# Combined model trains generators to fool discriminators
self.combined = Model(inputs=[img_A, img_B],
outputs=[ valid_A, valid_B,
reconstr_A, reconstr_B,
img_A_id, img_B_id ])
self.combined.compile(loss=['mse', 'mse',
'mae', 'mae',
'mae', 'mae'],
loss_weights=[ self.lambda_validation, self.lambda_validation,
self.lambda_reconstr, self.lambda_reconstr,
self.lambda_id, self.lambda_id ],
optimizer=Adam(0.0002, 0.5))
compile 다 했으므로 discriminator 다시 켜주기
self.d_A.trainable = True
self.d_B.trainable = True
build generator
-
UpSampling2D: 단순히 이미지의 가로, 세로 크기를 2배씩 늘려주는 기능을 한다
Nearest Neighbor : Dense 데이터를 그대로 늘려서, 빈 구역에 채워넣는 것입니다.
1 2 3 4 ->
1 1 2 2 1 1 2 2 3 3 4 4 3 3 4 4 참고 : upsampling 간단 설명
-
Concatenate: default axis=-1
U-Net 구조는 다음과 같다고 한다. 뭐 이렇게 해보면 이런 구조대로 잘 되지 않을까 ㅋㅋ😅
def build_generator_unet(self):
def downsample(layer_input, filters, f_size=4):
d = Conv2D(filters, kernel_size=f_size, strides=2, padding='same')(layer_input)
d = InstanceNormalization(axis = -1, center = False, scale = False)(d)
d = Activation('relu')(d)
return d
def upsample(layer_input, skip_input, filters, f_size=4, dropout_rate=0):
u = UpSampling2D(size=2)(layer_input)
u = Conv2D(filters, kernel_size=f_size, strides=1, padding='same')(u)
u = InstanceNormalization(axis = -1, center = False, scale = False)(u)
u = Activation('relu')(u)
if dropout_rate:
u = Dropout(dropout_rate)(u)
u = Concatenate()([u, skip_input])
return u
# Image input
img = Input(shape=self.img_shape)
# Downsampling
d1 = downsample(img, self.gen_n_filters)
d2 = downsample(d1, self.gen_n_filters*2)
d3 = downsample(d2, self.gen_n_filters*4)
d4 = downsample(d3, self.gen_n_filters*8)
# Upsampling
u1 = upsample(d4, d3, self.gen_n_filters*4)
u2 = upsample(u1, d2, self.gen_n_filters*2)
u3 = upsample(u2, d1, self.gen_n_filters)
u4 = UpSampling2D(size=2)(u3)
output_img = Conv2D(self.channels, kernel_size=4, strides=1, padding='same', activation='tanh')(u4)
return Model(img, output_img)
resnet 으로 만든 generator 은 생략하였다 😅 집에 가고 싶다 ㅠ
build discriminator
discriminator 은 훨씬 더 간단하다
def build_discriminator(self):
def conv4(layer_input,filters, stride = 2, norm=True):
y = Conv2D(filters, kernel_size=(4,4), strides=stride, padding='same', kernel_initializer = self.weight_init)(layer_input)
if norm:
y = InstanceNormalization(axis = -1, center = False, scale = False)(y)
y = LeakyReLU(0.2)(y)
return y
img = Input(shape=self.img_shape)
y = conv4(img, self.disc_n_filters, stride = 2, norm = False)
y = conv4(y, self.disc_n_filters*2, stride = 2)
y = conv4(y, self.disc_n_filters*4, stride = 2)
y = conv4(y, self.disc_n_filters*8, stride = 1)
output = Conv2D(1, kernel_size=4, strides=1, padding='same',kernel_initializer = self.weight_init)(y)
return Model(img, output)
train discriminator
buffer : “너가 예전에 이렇게 말했었어 !” discriminator 을 더 잘 훈련시키기 위한 과거 기억 저장
-
self.buffer_B : generator 이 A 도메인 이미지를 받아서 B 도메인 이미지로 변환한 것의 모음
fake_B = self.g_AB.predict(imgs_A)
-
self.buffer_A : generator 이 B 도메인 이미지를 받아서 A 도메인 이미지로 변환한 것의 모음
fake_A = self.g_BA.predict(imgs_B)
-
random.sample(컬렉션,샘플 수): random 모듈에서 sample(컬렉션, 샘플수) 함수는 지정된 컬렉션으로부터 샘플수만큼 랜덤 추출을 하는 함수
-
model.train_on_batch : as the name implies, trains only one batch
Returns
Scalar training loss (if the model has a single output and no metrics) or list of scalars (if the model has multiple outputs and/or metrics).
def train_discriminators(self, imgs_A, imgs_B, valid, fake): # Translate images to opposite domain fake_B = self.g_AB.predict(imgs_A) fake_A = self.g_BA.predict(imgs_B) self.buffer_B.append(fake_B) self.buffer_A.append(fake_A) fake_A_rnd = random.sample(self.buffer_A, min(len(self.buffer_A), len(imgs_A))) fake_B_rnd = random.sample(self.buffer_B, min(len(self.buffer_B), len(imgs_B))) # Train the discriminators (original images = real / translated = Fake) dA_loss_real = self.d_A.train_on_batch(imgs_A, valid) dA_loss_fake = self.d_A.train_on_batch(fake_A_rnd, fake) dA_loss = 0.5 * np.add(dA_loss_real, dA_loss_fake) dB_loss_real = self.d_B.train_on_batch(imgs_B, valid) dB_loss_fake = self.d_B.train_on_batch(fake_B_rnd, fake) dB_loss = 0.5 * np.add(dB_loss_real, dB_loss_fake) # Total disciminator loss d_loss_total = 0.5 * np.add(dA_loss, dB_loss) #[0] : loss , [1] : accuracy return ( d_loss_total[0] , dA_loss[0], dA_loss_real[0], dA_loss_fake[0] , dB_loss[0], dB_loss_real[0], dB_loss_fake[0] , d_loss_total[1] , dA_loss[1], dA_loss_real[1], dA_loss_fake[1] , dB_loss[1], dB_loss_real[1], dB_loss_fake[1] )
train generator
outputs=[ valid_A, valid_B,reconstr_A, reconstr_B,img_A_id, img_B_id ]
이었으므로 loss 6개 나옴 ! gan loss 2개, cycle loss 2개, identity loss 2개
def train_generators(self, imgs_A, imgs_B, valid):
# Train the generators
return self.combined.train_on_batch([imgs_A, imgs_B],
[valid, valid,
imgs_A, imgs_B,
imgs_A, imgs_B])
최종 train 함수
def train(self, data_loader, run_folder, epochs, test_A_file, test_B_file, batch_size=1, sample_interval=50):
start_time = datetime.datetime.now()
# Adversarial loss ground truths
valid = np.ones((batch_size,) + self.disc_patch)
fake = np.zeros((batch_size,) + self.disc_patch)
for epoch in range(self.epoch, epochs):
for batch_i, (imgs_A, imgs_B) in enumerate(data_loader.load_batch()):
d_loss = self.train_discriminators(imgs_A, imgs_B, valid, fake)
g_loss = self.train_generators(imgs_A, imgs_B, valid)
elapsed_time = datetime.datetime.now() - start_time
# Plot the progress
print ("[Epoch %d/%d] [Batch %d/%d] [D loss: %f, acc: %3d%%] [G loss: %05f, adv: %05f, recon: %05f, id: %05f] time: %s " \
% ( self.epoch, epochs,
batch_i, data_loader.n_batches,
d_loss[0], 100*d_loss[7],
g_loss[0],
np.sum(g_loss[1:3]),
np.sum(g_loss[3:5]),
np.sum(g_loss[5:7]),
elapsed_time))
self.d_losses.append(d_loss)
self.g_losses.append(g_loss)
# If at save interval => save generated image samples
if batch_i % sample_interval == 0:
self.sample_images(data_loader, batch_i, run_folder, test_A_file, test_B_file)
self.combined.save_weights(os.path.join(run_folder, 'weights/weights-%d.h5' % (self.epoch)))
self.combined.save_weights(os.path.join(run_folder, 'weights/weights.h5'))
self.save_model(run_folder)
self.epoch += 1
모델 구성 코드 보기도 버거워서 그 이후 부가적인 save 나 이미지 plotting 부분은 생략하였다 ! 🤪