딥러닝 이야기 / Deep Q Learning (DQN) / 2. DQN을 이용한 Cart Pole 세우기

DQN을 이용한 Cart Pole 세우기

작성자: 여행 초짜
작성일: 2023.03.02

시작하기 앞서 틀린 부분이 있을 수 있으니, 틀린 부분이 있다면 지적해주시면 감사하겠습니다.

이전글에서는 강화학습 딥러닝 모델의 시초격인 Deep Q Learning (DQN)에 대해 설명하였습니다. 이번글에서는 DQN을 이용한 cart pole 세우는 모델을 학습해보도록 하겠습니다. 본 글에서는 DQN cart pole 학습 방법에 중점을 두었습니다. 실제 학습된 모델을 평가하여 gif로 이미지를 만들거나, scheduler 등의 전체 코드는 GitHub에 올려놓았으니 아래 링크를 참고하시기 바랍니다.

본 DQN의 코드는 PyTorch와 Gym 라이브러리를 이용하여 구현되었습니다.

오늘의 컨텐츠입니다.

  1. 재현 메모리
  2. DQN 모델
  3. DQN 학습
  4. DQN 학습 결과

DQN 구현

재현 메모리
Replay Memory

여기서는 DQN의 핵심 아이디어 중 하나인 재현 메모리를 살펴보겠습니다. 기존 Q-table을 딥러닝으로 대체하려는 시도는 불안정한 경향이 있었기에 이를 해결하기 위해 도입한 것이 바로 재현 메모리 버퍼의 개념입니다. 이렇게 즉각적으로 모델이 선택한 action과 그에 대한 next state를 학습에 반영하지 않고 재현 메모리라는 버퍼에 담아둔 후 랜덤으로 batch 만큼 추출하여 학습하기 위한 코드입니다.

import torch
import random
from collections import deque


class ReplayMemory(object):
    def __init__(self, capacity):
        self.memory = deque([], maxlen=capacity)

    def push(self, state, action, reward, next_state):
        self.memory.append((state, action, torch.FloatTensor([reward]), torch.FloatTensor([next_state])))

    def sample(self, batch_size):
        return random.sample(self.memory, batch_size)

    def __len__(self):
        return len(self.memory)


    
memory = ReplayMemory(10000)

  • 8번째 줄: 재현 메모리 버퍼의 크기 설정.
  • 10 ~ 11번째 줄: 모델이 선택한 action, action을 선택한 시점의 state, 이로 인해 계산되는 next state, reward를 저장하는 함수.
  • 13 ~ 14번째 줄: 학습을 위해 랜덤으로 데이터를 샘플링할 때 사용하는 함수.
  • 21번째 줄: 재현 메모리 버퍼 설정.

DQN 모델

이제 DQN의 모델을 구성하는 코드입니다. Cart pole 세우는 task는 그리 복잡하지 않기때문에 아주 간단한 MLP 모델로 구성합니다.

class DQN(nn.Module):
    def __init__(self, config, case_num, device):
        super(DQN, self).__init__()
        self.hidden_dim = config.hidden_dim
        self.case_num = case_num
        self.device = device
        self.model = nn.Sequential(
            nn.Linear(4, self.hidden_dim),
            nn.ReLU(),
            nn.Linear(self.hidden_dim, case_num)
        )


    def forward(self, x):
        x = x.to(self.device)
        x = self.model(x)
        return x

  • 4번째 줄: 모델의 hidden dimension.
  • 5번째 줄: 모델이 취할 수 있는 action의 종류 수. Cart pole에서는 좌, 우 2개.
  • 7 ~ 11번째 줄: DQN 모델의 MLP. 4라는 숫자는 모델이 현재 state [cart position, cart velocity, pole angle, pole angular velocity]의 4가지 종류의 상태를 받기 때문임.
  • 14 ~ 17번째 줄: 모델이 학습할 때 거치는 부분.

DQN 학습

이제 DQN을 학습하는 코드입니다. 아래 코드의 config.의 부분은 GitHub 코드에 보면 src/config.json이라는 파일에 존재하는 변수 값들을 모델에 적용하여 초기화 하는 것입니다. 그리고 self. 이라고 나와있는 부분은 GitHub 코드에 보면 알겠지만 학습하는 코드가 class 내부의 변수이기 때문에 있는 것입니다.

def select_action(self, state, phase='train'):
    eps_threshold = self.config.eps_end + (self.config.eps_start - self.config.eps_end) * math.exp(-1 * self.steps_done / self.config.eps_decay)
    self.steps_done += 1

    # choose bigger action between left and right
    if random.random() > eps_threshold:
        with torch.no_grad():
            return torch.argmax(self.q_net(state), dim=1, keepdim=True)
    
    # random action between left and right
    return torch.tensor([[random.randrange(self.case_num)]], dtype=torch.long).to(self.device)


def train(self):
    if len(self.memory) >= self.batch_size:
        batch = self.memory.sample(self.batch_size)
        states, actions, rewards, next_states = zip(*batch)

        # batch data and current q
        states = torch.cat(states).to(self.device)
        actions = torch.cat(actions).to(self.device)
        rewards = torch.cat(rewards).to(self.device)
        next_states = torch.cat(next_states).to(self.device)

        # finding current q and max q values
        curr_q = torch.gather(self.q_net(states), dim=1, index=actions)
        max_next_q, _ = torch.max(self.target(next_states), dim=1)

        # target q
        target_q = rewards + max_next_q * self.config.gamma
        target_q = target_q.detach()
        
        # training
        self.optimizer.zero_grad()
        loss = self.criterion(curr_q.squeeze(), target_q)
        loss.backward()
        self.optimizer.step()



# deine memory class
self.memory = ReplayMemory(10000)

# environment define
self.env = gym.make('CartPole-v1')
self.case_num = self.env.action_space.n

# model define
self.q_net = DQN(self.config, self.case_num, self.device).to(self.device)
self.target = DQN(self.config, self.case_num, self.device).to(self.device)
self.target.load_state_dict(self.q_net.state_dict())
self.target.eval()

# optimizer and loss function define
self.criterion = nn.SmoothL1Loss()
self.optimizer = optim.Adam(self.q_net.parameters(), lr=self.lr)
self.steps_done = 0

# training
self.q_net.train()
for episode in range(self.episodes):
    state = self.env.reset()
    
    for t in count():
        state = torch.FloatTensor([state])
        action = self.select_action(state)

        next_state, reward, done, _ = self.env.step(action.item())

        if done:
            reward = -1
        
        # push to memory
        self.memory.push(state, action, reward, next_state)

        # update Q networks
        self.train()

        # update state
        state = next_state                

        if done:
            break
        
    if episode % self.config.target_update_duration == 0:
        self.target.load_state_dict(self.q_net.state_dict())
        self.target.eval()


self.env.render()
self.env.close()

Action 선택 함수

  • 1 ~ 11번째 줄: DQN 모델 결과로 action을 선택하는 코드.
  • 2 ~ 8번째 줄: steps_done이라는 변수를 바탕으로 학습이 많이 진행 되었을 때(학습이 많이 되어 DQN 모델의 신뢰도가 높을 때) threshold를 계산하여 모델이 선택한 action을 내어줌.
  • 10 ~ 11번째 줄: steps_done이라는 변수를 바탕으로 학습 초기일 때(학습이 덜 되어 DQN 모델의 신뢰도가 낮을 때) threshold를 계산하여 random 선택을 더 많이 하게 함. 즉 학습 초기에 random으로 여러가지 케이스 탐색을 하게 함.

DQN 모델 파라미터 업데이트 함수
  • 14 ~ 37번째 줄: 모델의 파라미터가 실질적으로 업데이트 되는 함수.
  • 15 ~ 17번째 줄: 재현 메모리에 저장된 데이터들을 batch size만큼 랜덤으로 추출.
  • 19 ~ 23번째 줄: 리스트로 된 데이터를 torch tensor로 변경.
  • 26번째 줄: 추출된 데이터 중 그 당시의 state에 대한 action 값(float value)을 가져오기 위해 현재 0, 1로 이루어진 이산 action 값을 바탕으로 위치 파악 후 action value를 가져옴.
  • 27번째 줄: Next state에 대한 action value를 가져옴.
  • 30 ~ 31번째 줄: 학습에 사용하기 위한 target value를 만들어줌.
  • 34 ~ 37번째 줄: Q network를 loss를 바탕으로 업데이트.

DQN 에피소드 부분
  • 42번째 줄: 최대 10,000의 데이터를 저장할 수 있는 replay memory 클래서 정의(재현 메모리에서 설명한 class).
  • 45 ~ 46번째 줄: gym 환경 정의 및 action case 개수 정의(좌우 2).
  • 49 ~ 52번째 줄: Q network와 target network를 정의. target network는 Q network와 동일하게 초기화
  • 55 ~ 57번째 줄: Optimizer, loss function 정의.
  • 60 ~ 91번째 줄: 학습 Episode가 일어나는 부분.
  • 62번째 줄: 학습 환경 초기화. 어느 수준 랜덤으로 cart pole 위치 및 각도가 초기화 되는 부분.
  • 65 ~ 66번째 줄: 현재 상태에 대한 다음 취할 action을 선택하는 부분.
  • 68번째 줄: 선택한 action을 취했을 때 나타나는 next state, reward, done (cart pole이 넘어졌는지 여부) 데이터를 내어주는 부분.
  • 70 ~ 71번째 줄: Cart pole이 넘어졌으면 reward를 -1로 설정(gym 라이브러리에서는 모든 reward가 1로 설정되어있기 때문).
  • 74번째 줄: 이렇게 얻은 데이터를 재현 메모리에 저장.
  • 77번째 줄: 재현 메모리에 쌓인 데이터를 바탕으로 Q network 학습.
  • 82 ~ 83번째 줄: 만약 cart pole이 넘어졌으면 episode 중단.
  • 85 ~ 87번째 줄: 일정 주기마다 target network를 최신화 해주기 위해 Q network의 파라미터로 clone.

DQN 학습 결과

아래는 학습 episode별로 cart pole이 버틴 step (duration) 수입니다. Episode가 50보다 작을 때는 20 duration도 버티기 버거웠지만 학습이 진행될 수록 오래 버티는 것을 알 수 있습니다. 그리고 gym 라이브러리의 최대 duration 상한선이 500으로 설정 되어있어서, 그 보다 더 버텼더라도 학습이 조기 종료 됩니다. 따라서 최대 duration이 500을 넘어가지 못하는 것을 아래 그래프에서 확인할 수 있습니다.

학습 에피소드별 cart pole 버틴 duration 수


아래는 episode 중 500 step을 버틴 cart pole 강화학습 결과입니다. 학습이 잘 된 것을 확인할 수 있습니다.

Cart pole 학습 결과, 500 duration




지금까지 DQN을 이용하여 cart pole을 강화학습 해보았습니다. 이렇게 간단한 task는 DQN만 사용하더라도 상당히 잘 작동하지만 복잡한 task에 대해서는 여전히 불안정한 감이 있습니다. 여담으로 OpenAI에서 개발한 강화학습 알고리즘인 Proximal Policy Optimization (PPO)는 간단하고 성능도 좋아 최근에 가장 많이 쓰는 기법 중 하나입니다.

태그 #DQN #CartPole #gym
⟨ 이전글
Deep Q Learning (DQN)