본문 바로가기
개발/📑 논문 리뷰

[논문 리뷰] RS-DPO(Rejection Sampling DPO)

by 썸머뮤트 2024. 11. 19.

최근에 RLHF, DPO 같은 모델을 학습하는 방법론에 대해 공부해보는 중인데요,

오늘은 올해 초 아마존에서 제안한 RS-DPO(Rejection Sampling Direct Preference Optimization)에 대해 소개해드리려고 합니다. 

RLHF와 DPO 관련 내용이 궁금하시다면 해당 포스팅 1(RLHF), 2(DPO)를 참고 해보시면 좋을 것 같습니다!

 

RS-DPO란?

RLHF와 DPO 둘 모두 사람의 선호도를 기반으로 언어 모델을 튜닝해서 조금 더 "사람처럼 답변하는" 또는 "사람이 선호하는 답변을 생성하는" Alignment 과정을 위한 학습방법론입니다.

RLHF는 데이터 수집 및 학습에 비용이 많이 들고, 강화학습 방법론의 일종이다 보니 학습이 불안정 할 수 있다는 단점이 있었습니다.

이 한계를 개선한 방법론이 DPO이였습니다.

RS-DPO는 Rejection Sampling(RS)을 DPO와 결합해서 DPO를 더 개선한 하이브리드 학습 방법론 입니다.

 

Rejection Sampling이란?

일반적으로 말하는 ML에서의 Rejection sampling과 RS-DPO에서의 Rejection 샘플링은 엄밀히 말하면 동일하지 않은 것 같네요.

RS-DPO의 rejection은 일단 샘플링 해놓고 기각한다는 관점에서 RS라는 이름을 쓴 것 같고, ML에서 일반적으로 말하는 RS는 따로 있는 것 같은데, 흥미로워 일단 정리해봤습니다.

해당 Section은 관심 있는 분들만 읽어보셔도 될 것 같습니다!

 

RS(Rejection Sampling)는 확률분포에서 샘플을 생성하는 방법 중 하나로, 복잡한 확률분포에서 직접 샘플링하기 어려운 경우에 사용됩니다.

RS는 단순한 샘플링 분포를 사용해서 샘플을 생성해내고, 거부(rejection) 과정을 통해 필요한 샘플을 선택하는 방법입니다.

 

여담이지만, 이를 통해서 정확히 알 수 없는 $p(x)$분포에 해당하는 샘플을 얻어낼 수 있다는 것인데, 진짜 다양한 분야에 활용될 수 있을 것 같네요... 흥미로워요....

 

Rejection Sampling 진행 방법

먼저 용어부터 정리를 해보도록 하겠습니다.

  • $p(x)$: 샘플링하고자 하는 복잡한 확률 분포
  • $q(x)$: 샘플링이 더 쉬운 제안 분포  ex) 가우시안 분포
  • $M$: 스케일링 상수 $p(x) \leq M*q(x)$ 를 언제나 만족해야 합니다.

RS는 두 단계를 반복하며 진행합니다.

  1. 샘플 생성
    먼저 샘플링이 쉬운 분포 $q(x)$에서 샘플 $x_0$를 생성합니다.
    그리고 유니폼 Distribution에서 $u \sim U(0, 1)$를 샘플링 합니다.
  2. 수락 혹은 거부
    만약 $u < \frac{p(x_0)}{M*q(x_0)}$ 이라면 $x_0$을 수락하고, 아니라면 다시 1번 단계로 돌아가 새로운 샘플을 생성합니다.

수학적으로 이 과정을 무수히 반복하면 $q(x)$에서 샘플링된 $x$들이 $p(x)$의 분포를 따른다고 합니다.

 

RS-DPO 적용 단계

자 이제 원래 목적으로 돌아와서 RS-DPO는 어떻게 동작하는지 알아보도록 하겠습니다.

1. SFT가 적용된 모델 준비

1단계는 언제나 그랬듯 초기 답변을 생성할 수 있는 모델을 준비하는 과정입니다.

RS-DPO에서는 목표로 하는 task에 대해 SFT(Supervised Fine-tuning)가 적용된 모델을 준비하라고 합니다.

제 경험상으로는 (파운데이션 모델 + SFT + 선호도 학습)의 조합보다 파운데이션 모델에서 바로 선호도 학습을 적용하는게 성능이 더 좋은 경우가 종종 있었는데, RS-DPO에서는 SFT를 미리 적용하고 선호도를 학습하는걸 아예 상정하고 있네요!

 

2. 리워드 모델 학습

사용자의 질문과 답변이 들어갔을 때 점수를 예측할 수 있는 리워드 모델을 학습해서 준비합니다.

요즘은 이미 구축된 사람의 선호도 데이터 셋이 존재해서 리워드 모델 학습을 비교적 쉽게 할 수 있는 것 같네요.

 

3.  Preference 데이터 구축 및 학습

이제 1번 2번에서 구축한 모델을 활용해 선호도 데이터 셋을 (이론상)무한 생성 할 수 있습니다!

1번에서 준비한 모델로 하나의 프롬프트당 $k$개의 답변을 생성합니다.

그리고 $k \choose 2$조합을 모두 순회하면서 두 개의 답변에 대해 각각 리워드 모델이 예측한 리워드 점수를 측정합니다.

두 리워드 점수의 차이가 $\eta$ 보다 크다면 점수가 작은 쪽을 losing, 큰 쪽을 winning으로 설정하여

(프롬프트, winning 답변, losing 답변)의 쌍으로 이루어진 데이터 셋을 구축합니다.

$\eta$를 기준으로 점수 차이를 필터링 하는 과정 때문에 Rejection Sampling이라는 이름을 붙인 것 같네요.

 

4. DPO 학습

3번에서 구축된 데이터 셋을 활용해서 DPO의 손실함수를 그대로 적용해 1번에서 준비했던 SFT 모델을 학습합니다.

 

생각보다 방법이 아주 간단하죠?

 

RS-DPO의 장점

방법론 자체는 간단하지만 RS-DPO는 아래와 같은 장점이 있습니다.

  1. 안정성: 리워드 모델의 품질 변화에 대해 안정적이고 강건합니다. 논문에서는 이를 실험적으로 증명을 해주었는데, 더 많은 데이터로 잘 학습된 리워드 모델과, 성능이 비교적 떨어지는 리워드 모델을 활용해서 데이터 셋을 구축했을 때 모두 일반적인 PPO를 사용했을 때보다 성능이 좋았다고 합니다. 또한 PPO를 사용했을 때 최종 모델의 성능이 리워드 모델 성능에 대한 의존도가 높다는 것도 실험적으로 보여주었습니다.
  2. 효율성: RLHF의 경우 GPU 메모리에 리워드 모델도 함께 올려 두어야 하는데, RS-DPO는 미리 데이터셋을 만들어 두고 해당 데이터 셋으로 학습을 진행하므로 제한된 리소스 환경에서 효과적으로 언어 모델을 학습할 수 있습니다.
  3. 성능: 실험 결과, PPO, DPO 등 기존 방법들보다 우수한 성능을 보였습니다.

 

'개발 > 📑 논문 리뷰' 카테고리의 다른 글

[논문 리뷰] Segment Anything(SAM)  (1) 2024.10.29