딥러닝 이야기 / Bidirectional Autoregressive Transformers (BART) / 2. BART를 이용한 한국어 챗봇 모델 학습

작성일: 2023.04.25
시작하기 앞서 틀린 부분이 있을 수 있으니, 틀린 부분이 있다면 지적해주시면 감사하겠습니다.
이전글에서는 autoencoder의 학습을 이용한 BART 모델에 대해서 알아보았습니다.
이번에는 챗봇 챗로그와 일상 대화 데이터를 수집하여 챗봇을 학습했던 프로젝트에 대해 이야기 하겠습니다.
일상대화 데이터를 사용한만큼 반말 데이터가 대부분이라서 학습 후 반말로 답을 하는 것을 방지하기 위해 LSTM 기반 형태소 분석기를 이용하여 반말의 답변인 경우 존댓말로 바꾸어주는 후처리 알고리즘까지 적용했던 프로젝트입니다.
학습에 사용한 데이터는 공개하지 않을 것이며, 존댓말로 바꾸어주는 후처리 코드또한 공개하지 않겠습니다. 구현은 python의 PyTorch를 이용하였습니다.
그리고 BART에 대한 글은 이전글을 참고하시기 바랍니다.
본 글에서 설명하는 BART 학습 코드는 GitHub에 올려놓았으니 아래 링크를 참고하시기 바랍니다(본 글에서는 모델의 구현에 초점을 맞추고 있기 때문에 데이터 전처리, 토크나이저 학습 등 학습을 위한 전체 코드는 아래 GitHub 링크를 참고하시기 바랍니다).
오늘의 컨텐츠입니다.
- BART 구현
- BART 학습
- BART 학습 결과
BART
”
여기서는 BART를 불러오는 과정을 살펴보겠습니다. 본 코드에서는 pre-trained KoBART와 vanilla BART의 두 모델을 선택하여 학습할 수 있게끔 구성되어있습니다.
from transformers import BartForConditionalGeneration
# BART
class BART(nn.Module):
def __init__(self, config, tokenizer):
super(BART, self).__init__()
self.pretrained = config.pretrained
self.bart = BartForConditionalGeneration.from_pretrained('gogamza/kobart-base-v2')
if not self.pretrained:
bart_config = self.bart.config
self.bart = BartForConditionalGeneration(bart_config)
self.tokenizer = tokenizer
def make_mask(self, src):
mask = torch.where(src==self.tokenizer.pad_token_id, 0, 1)
return mask
def forward(self, src, trg):
enc_mask, dec_mask = self.make_mask(src), self.make_mask(trg)
output = self.bart(input_ids=src, attention_mask=enc_mask, decoder_input_ids=trg, decoder_attention_mask=dec_mask).logits
return output
- 8번째 줄: 한국어 text-CLIP 모델을 제작하기 때문에 pre-trained KoBERT를 불러옴.강건한 모델을 학습하기 위해 source용 BERT, target용 BERT 모델을 따로 두지 않고 하나의 모델에서 source, target의 representation을 모두 학습하도록 하기 위해 하나의 모델만 로드.
- 8 ~ 12째 줄: Pre-trained KoBART 혹은 vanilla BART 모델을 불러오는 부분.
- 22 ~ 24번째 줄: 텍스트 데이터의 [PAD] 토큰 마스킹 및 모델 학습 부분.
”
이제 BART 모델을 학습해보겠습니다.
아래 코드에 self. 이라고 나와있는 부분은 GitHub 코드에 보면 알겠지만 학습하는 코드가 class 내부의 변수이기 때문에 있는 것입니다.
여기서는 무시해도 좋습니다.
그리고 아래 학습 코드는 실제 학습 코드를 간소화한 것입니다. Scheduler 등 전체 학습 코드는 GitHub 코드를 참고하면 됩니다.
self.model = BART(self.config, self.tokenizer).to(self.device)
self.criterion = nn.CrossEntropyLoss(ignore_index=self.tokenizer.pad_token_id)
self.optimizer = optim.Adam(self.model.parameters(), lr=self.lr)
for src, trg in self.dataloaders[phase]:
self.optimizer.zero_grad()
src, trg = src.to(self.device), trg.to(self.device)
with torch.set_grad_enabled(phase=='train'):
output = self.model(src, trg)
loss = self.criterion(output[:, :-1, :].reshape(-1, output.size(-1)), trg[:, 1:].reshape(-1))
loss.backward()
self.optimizer.step()
학습에 필요한 것들 선언
먼저 위에 코드에서 정의한 모델을 불러오고 학습에 필요한 loss function, optimizer 등을 선언하는 부분입니다.
- 1 ~ 3번째 줄: Loss function, transformer 모델 및 optimizer 선언.
모델 학습
- 5 ~ 13번째 줄: Cross entropy loss를 이용하여 모델 학습하는 부분.
”
BART 모델을 학습한 결과를 보겠습니다. 존댓말 변환기를 달지 않은 결과입니다. 그리고 같은 질문이라도 temperature, top-k, top-p 파라미터에 의해 다른 답변을 내어주는 것을 확인할 수 있습니다.
Q: 유재석 요즘 너무 재밌지 않아? A: 유재석은 너무 웃기고 재밌지. Q: 요즘 너무 피곤해서 구론산 쟁여놓고 먹고 있어 A: 응응 구론산 먹으면 몸이 좀 가벼워진다더라! Q: 요즘 너무 피곤해서 구론산 쟁여놓고 먹고 있어 A: 응 구론산 먹으면 몸 좋아진다던데 나도 사야겠어
지금까지 BART를 이용한 간단하지만 성능 좋은 챗봇 학습하는 방법을 알아봤습니다.
학습 과정에 대한 전체 코드는 GitHub에 있으니 참고하시면 될 것 같습니다.