딥러닝 이야기 / Generative Adversarial Network (GAN) / 2. Generative Adversarial Network (GAN) 구현 및 MNIST 생성

Generative Adversarial Network (GAN) 구현 및 MNIST 생성

작성자: 여행 초짜
작성일: 2022.04.01

시작하기 앞서 틀린 부분이 있을 수 있으니, 틀린 부분이 있다면 지적해주시면 감사하겠습니다.

이전글에서는 generative adversarial network (GAN)에 대해 설명하였습니다. 이번글에서는 linear layer로 이루어진 vanilla GAN의 구현에 대해 설명하도록 하겠습니다. 학습에 사용한 데이터는 MNIST 데이터를 사용하였으며, 구현은 python의 PyTorch를 이용하였습니다. 그리고 GAN을 학습하면서 GAN이 학습 epoch별로 생성한 이미지를 살펴보겠습니다.

그리고 GAN에 대한 설명은 이전글을 참고하시기 바랍니다. 이렇게 구현한 GAN의 코드는 GitHub에 올려놓았으니 아래 링크를 참고하시기 바랍니다(본 글에서는 모델의 구현에 초점을 맞추고 있기 때문에, 데이터 전처리 및 학습을 위한 전체 코드는 아래 GitHub 링크를 참고하시기 바랍니다).

오늘의 컨텐츠입니다.

  1. Vanilla GAN 구현
  2. GAN 학습
  3. GAN 학습 결과

GAN 구현

Vanilla GAN 구현

여기서는 기본적인 vanilla GAN의 구현 코드를 살펴보겠습니다. 코드는 PyTorch로 작성 되었으며, vanilla GAN을 학습하기 위해서는 linear layer로 이루어진 discriminator와 generator 두 모델이 필요합니다. GAN의 모델 자체는 단순하지만 GAN의 학습 방법이 다른 모델들과 다르기 때문에 주의깊게 살펴봐야 합니다. 즉 학습 방법이 GAN의 핵심이라고 볼 수 있습니다. 한 줄씩 자세한 설명은 코드 아래쪽에 설명을 참고하시기 바랍니다.

class Generator(nn.Module):
    def __init__(self, config:Config, color_channel:int):
        super(Generator, self).__init__()
        self.height = config.height
        self.width = config.width
        self.hidden_dim = config.hidden_dim
        self.noise_init_size = config.noise_init_size
        self.color_channel = color_channel

        self.generator = nn.Sequential(
            nn.Linear(self.noise_init_size, self.hidden_dim),
            nn.ReLU(),
            nn.Linear(self.hidden_dim, self.hidden_dim),
            nn.ReLU(),
            nn.Linear(self.hidden_dim, self.width*self.height*self.color_channel),
            nn.Sigmoid()
        )


    def forward(self, x):
        batch_size = x.size(0)
        output = self.generator(x)
        output = output.view(batch_size, -1, self.height, self.width)
        return output



class Discriminator(nn.Module):
    def __init__(self, config:Config, color_channel:int):
        super(Discriminator, self).__init__()
        self.height = config.height
        self.width = config.width
        self.hidden_dim = config.hidden_dim
        self.color_channel = color_channel

        self.discriminator = nn.Sequential(
            nn.Linear(self.width*self.height*self.color_channel, self.hidden_dim),
            nn.LeakyReLU(0.2),
            nn.Linear(self.hidden_dim, int(self.hidden_dim/4)),
            nn.LeakyReLU(0.2),
            nn.Linear(int(self.hidden_dim/4), 1),
            nn.Sigmoid()
        )


    def forward(self, x):
        batch_size = x.size(0)
        x = x.view(batch_size, -1)
        output = self.discriminator(x)
        return output

Generator
먼저 generator 부분입니다. 여기서 나오는 config 부분은 GitHub 코드에 보면 config.json이라는 파일에 존재하는 변수 값들을 모델에 적용하여 초기화 하는 것입니다.

  • 4, 5번째 줄: 학습 이미지를 모두 같은 크기로 전처리 하였을 때의 세로 가로 크기.
  • 6번째 줄: hidden layer의 차원.
  • 7번째 줄: Generator의 생성 시작에 사용하는 noise 초기 차원.
  • 8번째 줄: 이미지 전처리를 하였을 때, color channel 수(흑백으로 처리를 했다면 1, 칼라로 처리 했다면 3).
  • 10 ~ 17번째 줄: 여러 linear layer를 가지는 generator를 정의.
  • 20 ~ 24번째 줄: Generator를 거치는 부분, 마지막에서는 데이터의 크기를 (batch size * channel size * height * width)로 변경.

Discriminator
다음은 discriminator 부분입니다. 여기서도 config 부분은 GitHub 코드에 보면 config.json이라는 파일에 존재하는 변수 값들을 모델에 적용하여 초기화 하는 것입니다.
  • 31, 32번째 줄: 학습 이미지를 모두 같은 크기로 전처리 하였을 때의 세로 가로 크기.
  • 33번째 줄: hidden layer의 차원.
  • 34번째 줄: 이미지 전처리를 하였을 때, color channel 수(흑백으로 처리를 했다면 1, 칼라로 처리 했다면 3).
  • 36 ~ 43번째 줄: 여러 linear layer를 가지는 discriminator를 정의, 마지막에 데이터가 진짜인지 가짜인지 1, 0으로 판별해야하므로 sigmoid activation function 사용.
  • 46 ~ 50번째 줄: Discriminator를 거치는 부분, sigmoid 함수를 거쳐서 0 ~ 1 사이의 값으로 반환.

GAN 학습

이제 GAN의 핵심이라고 할 수 있는 학습 방법의 코드를 통해 어떻게 학습이 이루어지는지 살펴보겠습니다. 아래 코드에 self. 이라고 나와있는 부분은 GitHub 코드에 보면 알겠지만 학습하는 코드가 class 내부의 변수이기 때문에 있는 것입니다. 여기서는 무시해도 좋습니다.

self.G_model = Generator(self.config, self.color_channel).to(self.device)
self.D_model = Discriminator(self.config, self.color_channel).to(self.device)
self.criterion = nn.BCELoss()
self.G_optimizer = optim.Adam(self.G_model..parameters(), lr=self.lr)
self.D_optimizer = optim.Adam(self.D_model..parameters(), lr=self.lr)

for epoch in range(self.epochs):
    print(epoch+1, '/', self.epochs)
    print('-'*10)
    
    for phase in ['train', 'val']:
        if phase == 'train':
            self.G_model.train()
            self.D_model.train()
        else:
            self.G_model.eval()
            self.D_model.eval()

        G_total_loss, D_total_loss, Dx, D_G1, D_G2 = 0, 0, 0, 0, 0
        for i, (real_data, _) in enumerate(self.dataloaders[phase]):
            batch_size = real_data.size(0)
            self.G_optimizer.zero_grad()
            self.D_optimizer.zero_grad()

            with torch.set_grad_enabled(phase=='train'):
                ###################################### Discriminator #########################################
                # training discriminator for real data
                real_data = real_data.to(self.device)
                output_real = self.D_model(real_data)
                target = torch.ones(batch_size, 1).to(self.device)
                D_loss_real = self.criterion(output_real, target)
                Dx += output_real.sum().item()

                # training discriminator for fake data
                fake_data = self.G_model(torch.randn(batch_size, self.noise_init_size)).to(self.device)
                output_fake = self.D_model(fake_data.detach())  # for ignoring backprop of the generator
                target = torch.zeros(batch_size, 1).to(self.device)
                D_loss_fake = self.criterion(output_fake, target)
                D_loss = D_loss_real + D_loss_fake
                D_G1 += output_fake.sum().item()

                if phase == 'train':
                    D_loss.backward()
                    self.D_optimizer.step()
                ##############################################################################################


                ########################################## Generator #########################################
                # training generator by interrupting discriminator
                output_fake = self.D_model(fake_data)
                target = torch.ones(batch_size, 1).to(self.device)
                G_loss = self.criterion(output_fake, target)
                D_G2 += output_fake.sum().item()

                if phase == 'train':
                    G_loss.backward()
                    self.G_optimizer.step()
                ##############################################################################################

            D_total_loss += D_loss.item() * batch_size
            G_total_loss += G_loss.item() * batch_size

학습에 필요한 것들 선언
먼저 위에 코드에서 정의한 모델을 불러오고 학습에 필요한 loss function, optimizer 등을 선언하는 부분입니다.

  • 1 ~ 5번째 줄: Loss function, generator, discriminator 모델 선언 및 각각의 optimizer 선언.
  • 19번째 줄: 학습에 필요한 변수 선언.

Discriminator 학습
다음은 discriminator 학습 부분입니다. 코드상에서는 26 ~ 45번째 줄에 해당하는 부분입니다. 그중에서도 27 ~ 32번째 줄은 실제 우리가 가지고 있는 데이터를 학습하는 부분이고, 34 ~ 40번째 줄은 generator가 생성한 데이터를 학습하는 부분입니다.
  • 28 ~ 31번째 줄: 실제 우리가 가지고있는 학습 데이터를 discriminator에게 1로 예측하게 해주는 부분.
  • 32번째 줄: Dx는 실제 학습 데이터를 discriminator가 어떻게 예측했는지 판단하는 척도가 됨. Dx가 1에 가까울수록 discriminator는 실제 데이터를 1로 잘 판단했다는 뜻이며, 이론적으로 학습이 진행되면 될수록 discriminator는 가짜 데이터와 진짜 데이터를 구분하지 못해야하므로 Dx의 값은 0.5에 가까워져야함.
  • 35번째 줄: Discriminator에게 판단을 맡길 generator가 가짜 데이터를 생성하는 부분.
  • 36번째 줄: Discriminator에게 generator가 생성한 가짜 데이터를 넘겨주는 부분. Generator와 이어지는 gradient를 끊기 위해 detach를 한 후 넣어줌.
  • 37 ~ 39번째 줄: Discriminator의 최종 loss를 계산하는 부분.
  • 40번째 줄: D_G1은 discriminator가 업데이트 되기 전, generator가 생성한 가짜 데이터를 어떻게 판단하는지 확인하는 척도. 학습 초기에는 discriminator가 가짜 데이터를 잘 구분하기 때문에 0에 가까움. 우리는 생성을 잘하는 generator를 만드는게 목적이므로 학습을 할수록 진짜같은 데이터를 만들어서 1에 가까워지는 것이 좋음.
  • 42 ~ 45번째 줄: Discriminator의 loss를 바탕으로 discriminator를 업데이트하는 부분.

Generator 학습
다음은 generator 학습 부분입니다. 코드상에서는 48 ~ 58번째 줄에 해당하는 부분입니다.
  • 50 ~ 52번째 줄: Generator가 생성한 가짜 데이터를 discriminator가 진짜로 구분하게끔 1로 학습시키는 부분.
  • 53번째 줄: D_G2는 위에서 업데이트가 된 discriminator가 현재 generator가 생성한 가짜 데이터를 어떻게 판단하는지 확인하는 척도가 됨. Discriminator가 위에서 가짜 데이터에 대해 학습 했기 때문에 D_G2는 D_G1에 비해 작은 값이 나옴.
  • 55 ~ 57번째 줄: Generator의 loss를 바탕으로 generator를 업데이트하는 부분.
  • 60 ~ 61번째 줄: Discriminator와 generator의 전체 loss를 구하는 부분.

GAN 학습 결과

아래는 generator가 초기에 생성한 데이터와 학습 중간, 마지막 학습에서 생성한 모습입니다. 1 epoch의 학습 결과 generator는 노이즈만 생성하는 것을 알 수 있지만 갈수록 깔끔한 숫자 형태의 이미지를 생성하는 것을 확인할 수 있습니다.

Generator가 생성한 이미지


아래는 generator가 생성하는 데이터 변화를 모든 epoch 별로 만든 이미지입니다.

Generator가 생성한 이미지 변화




지금까지 vanilla GAN 구현 코드와 GAN의 핵심인 학습 방법에 대해 살펴보았습니다. 학습 과정에 대한 전체 코드는 GitHub에 있으니 참고하시면 될 것 같습니다. 다음에는 GAN의 문제점을 해결하고, 이미지의 특징을 잘 파악하여 데이터를 생성하는 모델인 deep convolutional GAN (DCGAN)에 대해 알아보도록 하겠습니다.

태그 #VanillaGAN #MNIST
⟨ 이전글
Generative Adversarial Network (GAN)