Direct Preference Optimization: Your Language Model is Secretly a Reward Model
키워드 | LLM |
---|---|
year | 2023 |
저자 | Rafael Rafailov et al. |
Venue | ArXiv |
Memo | DPO. |
분류 | 연구 |
DONE | |
생성 일시 | |
최종 편집 일시 | |
Working |
@article{Rafailov2023DirectPO,
title={Direct Preference Optimization: Your Language Model is Secretly a Reward Model},
author={Rafael Rafailov and Archit Sharma and Eric Mitchell and Stefano Ermon and Christopher D. Manning and Chelsea Finn},
journal={ArXiv},
year={2023},
volume={abs/2305.18290},
url={https://api.semanticscholar.org/CorpusID:258959321}
}
Published in arXiv.org 29 May 2023
Introduction
Why Reinforcement Learning from Human Feedback(RLHF)?
Large unsupervised LM의 문제 : Training data unreflect Human Preference
LLM이 학습에 사용하는 training data는 사람에 의해 다양한 목표, 우선순위, 기술집합들(goals, priorities, skillsets)로 생성된다. 하지만 이러한 데이터는 목표나 기술집합에 따라서 LLM이 모방하기에 적절하지 않은 정보들을 포함하기도 한다.
예시
- 우리는 AI coding assistant를 사용할 때 일반적인 프로그래밍 실수를 이해하고 그 실수를 수정해주는 것을 기대하면서도, 코드를 생성할 때 이 모델을 training data에 존재하는 고품질의 코딩 능력에 편향 시키는 것을 원한다.
- LLM이 50%의 사람들이 믿는 잘못된 생각에 주의 하길 바랄 때, 우리는 그 모델이 이 잘못된 생각에 대한 query의 50%만 맞다고 답하기를 원하지 않는다.
위의 예시가 잘 와닿지 않아서 다른 블로그 포스팅 예시 참고
Training data에는 욕설이나 편향적 발언, 부정확한 정보 같은 부적절한 데이터도 다수 포함되어 있다. 물론 정제 및 필터링 과정을 통해서 대부분의 부적절한 데이터를 training data에서 제거하지만, 사람의 기준에서 적절하지 않은 데이터를 모두 제거하기는 어렵다. 그래서 모델은 문맥에 따라 적절하지 않은 문장을 생성하는 경우가 발생하게 된다. 또한, LLM은 Next-token-prediction 방식 (주어진 문맥에 대해 다음에 나올 토큰들을 예측하는 방식)으로 학습하기 때문에, 생성모델은 학습된 정보에 기반 하여 가장 ‘그럴듯한’ 문장을 생성한다. 때문에 LLM은 사람의 의도가 반영되지 않은 비윤리적 답변이나 환각 현상(Hallucination) 등의 문제를 겪게 된다.
Match Human Preferences using Reinforcement Learning(RL)
LLM을 적절하게 활용하기 위해서 사람이 의도한 방향에 맞게 모델을 통제할 수 있어야 한다.
이를 위해서 PLM들은 Supervised Fine-tuning(SFT) 방식과 Reinforcement Learning from Human Feedback (RLHF) 방식을 통해 안전하고 유용한 모델로 만든다.
- SFT
주어진 문맥에 모범 답안을 제공해서 모델이 올바른 답변을 모사하도록 학습
- Human Preference Alignment (Learning from Human Feedback)
주어진 문맥에 대해 모델이 답변을 생성하면, 답변에 대한 피드백을 제공해서 사람의 선호도를 학습. 현재 가장 대표적인 방법 : RLHF.
- Unsupervised Learning (Pre-training): 사전 학습을 통해서 대형 생성 모델(PLM)을 만듭니다. 대형 생성 모델은 길들여지지 않은 괴물과 같이 거대하고 강력하지만 사람이 원하는 의도대로 동작하기 어려운 경우가 많기 때문에 서비스에 바로 적용하기에는 어렵습니다.
- Supervised Fine-tuning: 특정 도메인의 데이터 혹은 크라우드 소싱 등을 통해 구축한 양질의 (Prompt, Response) 데이터를 구축하여 fine-tuning하는 과정입니다. 이를 통해 입력 프롬프트에 대해 사람의 의도에 맞는 문장을 생성하는 방법을 학습합니다.
- RLHF: SFT 모델에 추가적으로 강화 학습을 적용하여 사람의 의도에 맞게 파인튜닝을 하는 과정입니다.
Direct Preference Optimization(DPO)
위에서 소개한 RLHF는 놀라운 대화 능력을 만들어내지만, RLHF의 구조는 supervised learning 보다 복잡해서 심각한 computational costs를 야기한다. 그래서 최근, RLHF의 구현 복잡성 및 학습 불안정성 등의 문제를 보완하기 위해 RLHF를 대체하는 방법론들이 제안되고 있다. 이 논문에서는 RLHF를 대체하는 방법으로 human preference를 RL을 거치지 않고 바로 모델에 최적화시키는 방법을 제안한다.
DPO는 RLHF의 objective와 동일한 objective를 함축적으로(implicitly) 최적화 하지만 간단하고 학습에 수월한 알고리즘이다.
DPO
- 동적인 per-example importance weight를 통합해서 model degeneration을 방지한다.
- 현존 기법처럼 theoretical preference model (ex. Bradley-Terry model) 에 의존하지만, 현존 기법들과 달리 policy training을 위해 reward model을 사용하지 않고, variables를 바꿔 preference loss를 바로 policy function으로 정의한다. 덕분에 DPO는 reward model 훈련 없이 간단한 binary cross entropy objective를 통해 policy를 최적화 할 수 있다.
Contribution
- DPO : simple RL-free algorithm for training LM for preferences.
Methodology
- DPO의 목표 :
Preference 을 직접적으로 사용해서 간단하게 policy optimization을 얻는 것.
(derive a simple approach for policy optimization using preferences directly.)
- Key insight :
분석적인 mapping (Reward functions → Optimal policies) 활용을 통해 reward function을 사용한 loss function을 policy 사용 loss function으로 전환.
Deriving the DPO objective.
- RL objective :
- Optimal solution to the KL-constrained reward maximization objective in Eq. 3 :
- Partition function
partition function은 와 에 대한 함수이지만, Policy 에 independent 하다.
- Partition function
- RL objective(eq.3) 을 policy에 대한 식으로 바꾸기
- valid probability distribution for all .
- 는 에 대한 함수가 아님.
- 가 에 대한 함수가 아니므로 eq.14는 KL term을 최소화 하는 것과 같다.
- Gibbs’ inequality 에 따라 두 분포가 동일하면 KL-divergence는 0에서 최소화 된다.
따라서 모든 에 대해,
- Reward model 없는 reward function 얻기
Eq.4 양변에 log를 취한다.
Eq.5 를 ground-truth reward 와 optimal model 에 적용할 수 있다.
- Preference model under Bradley-Terry model
- 상호 비교 데이터를 모델링 하기 위한 확률적인 모델 중 하나이다. 두 개의 대상간의 상대적인 강도를 추정하고, 이를 통해 전체 모집단에서 순위를 예측하는 데 사용된다. 이 모델은 MLE나 Baysian 방법을 사용해서 파라미터를 추정할 수 있다.
- 두 completions에 대한 rewards 의 차이에만 의존한다.
- eq.1 에 eq.5를 적용해서 파라미터를 전환 시키면, human preference probability를 reward model term 없이 표현할 수 있다.
- 따라서 BT 모델에 기반한 optimal RLHF policy 는 다음 preference model을 만족한다 :
- Maximum likelihood objective for parametrized policy (language model policy).
- reward model(eq.2) 접근과 유사하게, policy objective :
- pytorch code for DPO loss
def dpo_loss(pi_logps, ref_logps, yw_idxs, yl_idxs, beta): """ pi_logps: policy logprobs, shape (B,) ref_logps: reference model logprobs, shape (B,) yw_idxs: preferred completion indices in [0, B-1], shape (T,) yl_idxs: dispreferred completion indices in [0, B-1], shape (T,) beta: temperature controlling strength of KL penalty Each pair of (yw_idxs[i], yl_idxs[i]) represents the indices of a single preference pair. """ pi_yw_logps, pi_yl_logps = pi_logps[yw_idxs], pi_logps[yl_idxs] ref_yw_logps, ref_yl_logps = ref_logps[yw_idxs], ref_logps[yl_idxs] pi_logratios = pi_yw_logps - pi_yl_logps ref_logratios = ref_yw_logps - ref_yl_logps losses = -F.logsigmoid(beta * (pi_logratios - ref_logratios)) rewards = beta * (pi_logps - ref_logps).detach() return losses, rewards
- reward model(eq.2) 접근과 유사하게, policy objective :
- 이 방법은 재파라미터화 된 BT model을 fitting 한 것과 동일하기 때문에, preference data distribution의 적절한 가정 하에서 consistency 같은 이론적 특징을 갖는다.
What does the DPO update do?
DPO의 매커니즘 이해를 위해 loss function 의 gradient를 분석해보자면,
- reward, language model 과 reference model 에 의해 함축적으로 정의된 reward
직관적으로 loss function 의 gradient는 선호되는 completions 의 likelihood를 증가시키고 비선호 되는 completions 의 likelihood를 감소시킨다. 주목해야할 부분은 함축된 reward model 이 를 얼마나 높이 평가했는지, 즉, reward model이 얼마나 잘못 추론했는지에 따라 examples에 가중치가 매겨진다. 이렇게 각 example마다 weight를 적용해서 LM이 degenerate 하는 것을 방지했다.
DPO outline
DPO pipeline
- Sample completions for every prompt , label with human preferences for offline dataset of preferences .
- Optimize the LM to minimize .
In practice, reuse preference datasets
- When available, .
- When unavailable, .
이 과정은 분포가 unavailable한 true reference 분포와 DPO가 사용한 사이의 shift를 완화해준다.