본문 바로가기
개발/🦾 머신러닝, 딥러닝

DPO란? - RLHF를 개선한 모델 학습 방법론

by 썸머뮤트 2024. 11. 18.

DPO란?

오늘은 지난번에 알아본 RLHF(Reinforcement Learning with Human Feedback)의 효율화된 버전인 DPO(Direct Preference Optimization)에 대해 알아보겠습니다.

RLHF와 그 한계점에 알고싶으신 분들은 해당 포스팅을 참고해보시면 좋을 것 같습니다!

DPO는 RLHF와 마찬가지로 사람의 피드백을 활용하여 언어 모델을 개선하는 방법론입니다.

다만 RLHF의 비용적 한계를 개선하기 위해 좀 더 간단한 접근법을 제시합니다.

 

DPO는 아래와 측면에서 비용적 한계를 개선했습니다.

  1. 직접적인 선호도 최적화: 사용자의 선호도 데이터를 직접 활용하여 언어 모델을 최적화합니다.
  2. 리워드 모델 불필요: RLHF와 달리 별도의 리워드 모델을 학습할 필요가 없습니다.
  3. 간소화된 학습 과정: GPU를 많이 사용해아하는 복잡한 강화학습 단계를 거치지 않고 단순한 이진 분류 문제로 학습을 진행합니다.

 

DPO 적용 단계

DPO도 RLHF와 마찬가지로 아래 3단계로 구성됩니다.
지난 포스팅과 똑같이 재미있는 답변을 생성하는 모델을 만드는 상황이라고 가정해보겠습니다.

1. 파운데이션 모델 준비 혹은 SFT 적용 준비

먼저 초기 답변을 생성할 수 있는 모델을 준비해야합니다. 이 모델은 튜닝이 전혀 적용 되지 않은 Llama같은 파운데이션 모델이 될 수도 있고, 목표로 하는 task에 대해 SFT(Supervised Fine-tuning)가 적용된 모델일 수도 있습니다.

단순히 생각해보면, SFT를 적용한 모델을 사용하는게 더 좋을 것 같긴하지만, 상황에 따라 SFT를 적용하지 않은 그냥 파운데이션 모델을 사용하는게 더 성능이 좋을 때도 있습니다.(결국 해봐야 아는..)

재미있는 답변을 생성하는 모델을 만들기 위해서는 코미디언의 인터뷰 스크립트를 모아서 SFT를 적용해 볼 수도 있겠네요!

2. 선호 데이터 수집

1단계에서 준비된 모델에서 같은 질문에 대해 다양한 답변을 생성해내도록 합니다. Temperature를 높여 다양한 답변을 만들 수도 있고, 모델 자체를 여러개 준비해서 답변을 만들 수도 있겠죠!

그리고 사람 레이블러가 이를 보고 어떤 것이 더 좋은지 순위를 매깁니다.

여기까지는 RLHF와 완전히 똑같습니다. 하지만 다른 점은 바로 리워드 모델을 학습하지 않는다는 점입니다.

3. 로그 손실함수 최적화

이제 최종 튜닝을 적용할 모델을 로그 손실함수 기반으로 학습합니다. 손실함수는 다음과 같습니다.

$$L_{DPO}(\pi_{\theta};\pi_{ref}) = -E_{(x, y_{pref}, y_{other})\sim D}[\log{\sigma(\beta \log{\frac{\pi_{\theta}(y_{pref}|x) }{\pi_{ref}(y_{pref}|x)} } - \beta\log{\frac{\pi_{\theta}(y_{other}|x)}{\pi_{ref}(y_{other}|x)}}  )}]$$

 

$\pi_\theta$: 최적화 대상 모델

$\pi_{ref}$: 참조 모델

$x$: 입력 프롬프트

$y_{pref}$: 선호되는 응답

$y_{other}$: 선호되지 않는 응답

$\sigma$: 시그모이드 함수

 

상당히 복잡해 보이지만 해석해보자면 아래와 같습니다.

  1. 선호도 비교: 함수는 선호되는 응답($y_{pref}$)과 선호되지 않는 응답($y_{other}$)의 확률을 비교합니다.
  2. 로그 확률 차이: $\log \frac{\pi_\theta(y_{pref}|x)}{\pi_{ref}(y_{pref}|x)} - \log \frac{\pi_\theta(y_{other}|x)}{\pi_{ref}(y_{other}|x)}$ 부분은 최적화 대상 모델과 참조 모델 간의 로그 확률 차이를 계산합니다.
  3. 시그모이드 변환: 이 차이는 시그모이드 함수를 통과하여 0과 1 사이의 값으로 변환됩니다.
  4. 기대값 계산: 전체 데이터셋에 대해 이 값의 기대값을 계산합니다.

결론적으로 손실함수는 $y_{pref}$가 $y_{other}$ 보다 더 높은 점수를 가지도록 학습합니다.

 

DPO의 장점

DPO는 RLHF와 비교했을 때 아래와 같은 장점이 있습니다.

  1. 효율성: DPO는 사람이 라벨링한 선호도 데이터를 직접 학습하기 때문에, 기존의 보상 모델 학습보다 효율적입니다.
  2. 해석 가능성: 선호 데이터를 명시적으로 사용하므로 모델의 의사결정 과정을 해석하기 쉽습니다.
  3. 직접적 피드백 활용: 사람의 선호도를 바로 반영하므로, 강화학습에서의 보상 설계의 어려움을 줄일 수 있습니다.

 

DPO vs RLHF

DPO는 RLHF과 비슷한 목적을 가진 학습 방법론 이지만, 다음과 같은 차이가 있습니다!
(이제와서)찾아보니 재미있는 연구가 많이 있는 것 같은데, 다른 연구도 포스팅 해보려고 합니다 😊

특징 DPO  RLHF
손실함수 선호 데이터를 직접 최적화 보상 모델을 학습 후 정책을 최적화
학습 복잡도 상대적으로 낮음 상대적으로 높음
보상 모델 필요 여부 불필요 필수