REALM: Retrieval-Augmented Language Model Pre-Training

REALM: Retrieval-Augmented Language Model Pre-Training
2020.02

  • retriever 도 학습을 하면 QA 성능이 매우 높아짐
  • retriever 와 reader를 한번에 학습할 수 있음
  • retriever를 pretraining에서 수행하는 모델 제안

main contribution

  • retriever 와 reader 를 한번에 학습하는 end-to-end 모델
  • 쿼리를 넣어 답을 찾는 과정을 두 단계로 분리
    • neural knowledge retriever

abstract

REALM의 핵심: unsupervised text의 performance-based signal을 사용하여 retriever를 학습시킨다.
가장 적절한 문서를 검색하는 것이 주요 목적
비지도 학습 방식으로는 처음으로 사전 학습하는 방법을 제시: MLM, retrieval step을 역전파 학습
언어모델의 perplexity를 개선하는 검색은 보상을 주고 정보가 없는 검색은 패널티를 주는 방식
Open-domain Question Answering 태스크에서 효과적, sota \

일반적으로 딥러닝 모델은 역전파(backpropagation)라는 방법을 사용해서 학습
모델의 출력을 올바른 정답과 비교하고, 그 차이를 줄이도록 모델의 내부 가중치(parameters)를 조정하는 방식
그런데 REALM에서는 단순히 신경망의 매개변수만 학습하는 게 아니라, 문서를 검색하는 과정(retrieval step)도 함께 학습한다. 즉, 모델이 “어떤 정보를 검색해야 하는지”까지 최적화하는 것 \

method

  • 알맞은 문서인지 평가하는 방법: MIPS(Maximum Inner Product Search)
    • document z와 Input x의 임베딩 벡터 간의 내적값을 계산
  • 알맞은 retrieve 방법을 latent variable language model로 모델링
    • marginal likelihood를 최적화하여 학습
  • Knowledge Augmented Encoder
    • 외부지식을 추가적으로 활용하는 인코더

예를 들어 위의 그림에서 model이 “the ___ at the top of the pyramid”의 빈칸을 채워야하는 경우, retriever는 “The pyramidion on top allows for less material higher up the pyramid.”라는 document를 선택하면 보상을 받는다. \ (Retrieved Document, Query and document 사이)

검색 과정을 하나의 확률적 선택 과정으로 보고, 학습을 통해 점점 더 좋은 검색 결과를 찾을 수 있도록 만듦.

REALM’s generative process

REALM은 입력 x가 주어지면 가능한 출력 y에 대한 확률분포 p(y|x)를 학습

  • pre-training 단계
    • MLM(Masked Language Modeling) 수행
    • 입력 x는 일부 단어가 가려진 문장
    • 모델은 가려진 단어 y를 예측하는 작업을 학습
  • fine tuning 단계
    • Open Domain QA 문제를 학습
    • 입력 x는 질문, 출력 y는 정답

어떤 문서 z를 참고해야 할지를 확률적으로 결정하는 과정

확률 분포 \(p(y|x)\)를 두 단계로 분해하여 학습: retrieve–then-predict를 공식화

  • retrieve step (문서 검색)
    • 입력 x가 주어지면 지식 코퍼스 Z에서 관련 문서 z를 검색
    • \[p(z|x)\]
  • predict step (답변 생성)
    • 검색된 문서 z와 원래 입력 x를 바탕으로 답변 y를 생성
    • \[p(y|z,x)\]
  • 이때 어떤 문서 z가 가장 좋은 문서인지 미리 알 수 없기 때문에 모든 가능한 문서 z에 대해 확률을 합산(marginalization)

=> 먼저 여러 개의 문서 z 를 검색하고, 각 문서별로 답변을 생성한 뒤, 각 문서에 대한 확률 \(p(z∣x)\)을 곱해서 최종적으로 가장 가능성 높은 정답을 고르는 과정

Knowledge Retriever: \(p(z|x)\) 정의

  1. 검색 모델 정의 Retrieval 과정은 내적을 기반으로 문서와 질문의 관련성을 측정하는 모델을 사용: Dense Inner Product Model
    • \[f(x,z)\]
      • 입력 x와 문서 z 사이의 유사도 점수
      • softmax 분포: 문서 전체에 대해 유사도를 계산한 후 softmax를 적용하여 확률분포로 변환
    • 유사도 점수 계산
      • Embed_input(x): 입력 x를 벡터로 변환하는 함수
      • Embed_doc(z): 문서 z를 벡터로 변환하는 함수
      • 둘을 내적

  1. 임베딩 방법 BERT 기반 Transformer를 사용하여 질문과 문서를 벡터로 변환
    • 입력 질문을 BERT 형식으로 변환
    • 두 개의 텍스트를 하나로 합침

BERT 모델에 통과시킨 후 선형변환을 적용하여 최종 임베딩 변환

  • 입력 (질문) 임베딩 변환
  • 문서 (후보 문서) 임베딩 변환
  • W: 차원을 줄이기 위한 선형 변환 행렬
  • 모델의 학습 가능 매개변수: BERT Transformer 파라미터, Projection Matrices

Knowledge-Augmented Encoder: \(p(y|z, x)\) 정의

입력 x와 검색된 문서 z를 하나의 시퀀스로 결합한 후 Transformer를 사용하여 답을 생성

  • pre-training의 경우
    • MLM 사용
    • Transformer의 출력 벡터와 단어 임베딩의 내적을 계산하여 MASK 토큰 예측
  • \(J\): 가려진 MASK 토큰 개수
  • \(y_j\): j번째 MASK 토큰이 원래 갖고 있던 단어

입력 x와 검색된 문서 z를 하나의 시퀀스로 결합한 후 Transformer를 사용하여 답을 생성

  • fine-tuning의 경우
    • 답변이 검색된 문서 z 안에 존재한다고 가정
    • 정답이 연속된 단어(스팬)로 이루어져 있다고 가정
    • MLP(Multi-Layer Perceptron), 즉 피드포워드 신경망을 사용하여 정답 스팬 예측

\(h_{start}\): 시작 위치 예측
\(h_{end}\): 끝 위치 예측
\(S(z, y)\): 문서 z에서 정답 y와 일치하는 Span들의 집합 \

training

정답 y의 log-likelihood log p(y|x)를 최대화
발생한 이슈들 2가지 소개 \

  1. 지식 검색 과정 계산 문제
  • 모든 확률 합산: 문서 전체 집합 Z가 크면 계산량이 너무 많음
  • 해결: 검색 확률이 높은 상위 k개 문서만 고려
  1. 검색을 위한 MIPS
    • 효율적으로 사용하기 위해 모든 문서의 벡터를 미리 만들어서 빠르게 검색할 수 있도록 인덱싱 해야 함
    • 모든 문서 z를 벡터로 변환해서(Embed_doc(z)) MIPS 검색 인덱스에 저장
    • 하지만 학습이 계속될수록 문서 임베딩 함수의 파라미터가 계속 업데이트되기 때문에 최신 모델과 불일치(stale) 문제 발생
    • 해결: 검색 인덱스 주기적 업데이트

experiment

T5는 REALM과 달리 사전학습에서 SQuAD의 추가 MRC dataset에 접근한다는 점을 유의

결론

  • 기존의 언어모델보다 30배 작은 크기
  • 검색 과정도 학습 가능한 방식으로 통합 (end-to-end learning)
  • 검색과 답변 생성을 하나의 신경망으로 연결
  • 검색 과정도 비지도 학습을 활용해 학습