From 1f0e9c9cdedb1e5ca360ea71190b99bc89df5e15 Mon Sep 17 00:00:00 2001
From: Costa Huang
Date: Fri, 9 Jun 2023 12:48:23 +0000
Subject: [PATCH 1/2] KL Estimator

---
 tests/test_no_peft.py      |  4 ++-
 tests/test_ppo_trainer.py  |  4 ++-
 trl/trainer/ppo_trainer.py | 67 +++++++++++++++++++++-----------------
 3 files changed, 44 insertions(+), 31 deletions(-)

diff --git a/tests/test_no_peft.py b/tests/test_no_peft.py
index 3190b7c85a..fca820106b 100644
--- a/tests/test_no_peft.py
+++ b/tests/test_no_peft.py
@@ -46,7 +46,9 @@ def __getitem__(self, idx):
     "ppo/loss/total",
     "ppo/policy/entropy",
     "ppo/policy/approxkl",
-    "ppo/policy/policykl",
+    "ppo/policy/approxkl_k1",
+    "ppo/policy/approxkl_k2",
+    "ppo/policy/approxkl_k3",
     "ppo/policy/clipfrac",
     "ppo/policy/advantages",
     "ppo/policy/advantages_mean",
diff --git a/tests/test_ppo_trainer.py b/tests/test_ppo_trainer.py
index e5a9ed1bab..19f6f7ce32 100644
--- a/tests/test_ppo_trainer.py
+++ b/tests/test_ppo_trainer.py
@@ -45,7 +45,9 @@
     "ppo/loss/total",
     "ppo/policy/entropy",
     "ppo/policy/approxkl",
-    "ppo/policy/policykl",
+    "ppo/policy/approxkl_k1",
+    "ppo/policy/approxkl_k2",
+    "ppo/policy/approxkl_k3",
     "ppo/policy/clipfrac",
     "ppo/policy/advantages",
     "ppo/policy/advantages_mean",
diff --git a/trl/trainer/ppo_trainer.py b/trl/trainer/ppo_trainer.py
index 1009839324..e4d7c6a97b 100644
--- a/trl/trainer/ppo_trainer.py
+++ b/trl/trainer/ppo_trainer.py
@@ -685,8 +685,8 @@ def collator(data):
                     all_stats.append(train_stats)
 
             if self.config.early_stopping:
-                policykl = train_stats["policy/policykl"]
-                early_stop = self._early_stop(policykl)
+                approxkl = train_stats["policy/approxkl"]
+                early_stop = self._early_stop(approxkl)
                 if early_stop:
                     break
 
@@ -734,7 +734,7 @@ def collator(data):
 
         return stats
 
-    def _early_stop(self, policykl):
+    def _early_stop(self, approxkl):
         r"""
         Handles the early stopping logic. If the policy KL is greater than the target KL, then the gradient is
         zeroed and the optimization step is skipped.
@@ -751,7 +751,7 @@ def _early_stop(self, policykl):
         if not self.config.early_stopping:
             return early_stop
 
-        if not self.is_distributed and policykl > 1.5 * self.config.target_kl:
+        if not self.is_distributed and approxkl > 1.5 * self.config.target_kl:
             self.optimizer.zero_grad()
             early_stop = True
         elif self.is_distributed:
@@ -760,11 +760,11 @@ def _early_stop(self, policykl):
             import torch.distributed as dist
             # Wait for all processes to finish
             dist.barrier()
-            # all gather the policykl
-            dist.all_reduce(policykl, dist.ReduceOp.SUM)
-            policykl /= self.accelerator.num_processes
+            # all gather the approxkl
+            dist.all_reduce(approxkl, dist.ReduceOp.SUM)
+            approxkl /= self.accelerator.num_processes
 
-            if policykl > 1.5 * self.config.target_kl:
+            if approxkl > 1.5 * self.config.target_kl:
                 self.optimizer.zero_grad()
                 early_stop = True
         return early_stop
@@ -1011,16 +1012,17 @@ def loss(
         values = values * mask
         rewards = rewards * mask
 
-        for t in reversed(range(gen_len)):
-            nextvalues = values[:, t + 1] if t < gen_len - 1 else 0.0
-            delta = rewards[:, t] + self.config.gamma * nextvalues - values[:, t]
-            lastgaelam = delta + self.config.gamma * self.config.lam * lastgaelam
-            advantages_reversed.append(lastgaelam)
-        advantages = torch.stack(advantages_reversed[::-1]).transpose(0, 1)
+        with torch.no_grad():
+            for t in reversed(range(gen_len)):
+                nextvalues = values[:, t + 1] if t < gen_len - 1 else 0.0
+                delta = rewards[:, t] + self.config.gamma * nextvalues - values[:, t]
+                lastgaelam = delta + self.config.gamma * self.config.lam * lastgaelam
+                advantages_reversed.append(lastgaelam)
+            advantages = torch.stack(advantages_reversed[::-1]).transpose(0, 1)
 
-        returns = advantages + values
-        advantages = masked_whiten(advantages, mask)
-        advantages = advantages.detach()
+            returns = advantages + values
+            advantages = masked_whiten(advantages, mask)
+            advantages = advantages.detach()
 
         vpredclipped = clip_by_value(
             vpreds, values - self.config.cliprange_value, values + self.config.cliprange_value
@@ -1031,30 +1032,38 @@ def loss(
         vf_loss = 0.5 * masked_mean(torch.max(vf_losses1, vf_losses2), mask)
         vf_clipfrac = masked_mean(torch.gt(vf_losses2, vf_losses1).double(), mask)
 
-        ratio = torch.exp(logprobs - old_logprobs)
+        logratio = logprobs - old_logprobs
+        ratio = torch.exp(logratio)
 
         pg_losses = -advantages * ratio
        pg_losses2 = -advantages * torch.clamp(ratio, 1.0 - self.config.cliprange, 1.0 + self.config.cliprange)
 
         pg_loss = masked_mean(torch.max(pg_losses, pg_losses2), mask)
-        pg_clipfrac = masked_mean(torch.gt(pg_losses2, pg_losses).double(), mask)
 
         loss = pg_loss + self.config.vf_coef * vf_loss
 
-        entropy = masked_mean(entropy_from_logits(logits), mask)
-        approxkl = 0.5 * masked_mean((logprobs - old_logprobs) ** 2, mask)
-        policykl = masked_mean(old_logprobs - logprobs, mask)
+        with torch.no_grad():
+            pg_clipfrac = masked_mean(torch.gt(pg_losses2, pg_losses).double(), mask)
+            entropy = masked_mean(entropy_from_logits(logits), mask)
+            # calculate approx_kl http://joschu.net/blog/kl-approx.html
+            approxkl_k1 = masked_mean(-logratio, mask)
+            approxkl_k2 = 0.5 * masked_mean(logratio ** 2, mask)
+            approxkl_k3 = masked_mean((ratio - 1) - logratio, mask)
+            # by default use `k3`, the estimator which has lowest variance and is unbiased
+            approxkl = approxkl_k3
+
         return_mean, return_var = masked_mean(returns, mask), masked_var(returns, mask)
         value_mean, value_var = masked_mean(values, mask), masked_var(values, mask)
 
         stats = dict(
             loss=dict(policy=pg_loss.detach(), value=vf_loss.detach(), total=loss.detach()),
             policy=dict(
-                entropy=entropy.detach(),
-                approxkl=approxkl.detach(),
-                policykl=policykl.detach(),
-                clipfrac=pg_clipfrac.detach(),
-                advantages=advantages.detach(),
-                advantages_mean=masked_mean(advantages, mask).detach(),
+                entropy=entropy,
+                approxkl=approxkl,
+                approxkl_k1=approxkl_k1,
+                approxkl_k2=approxkl_k2,
+                approxkl_k3=approxkl_k3,
+                clipfrac=pg_clipfrac,
+                advantages=advantages,
+                advantages_mean=masked_mean(advantages, mask),
                 ratio=ratio.detach(),
             ),
             returns=dict(mean=return_mean.detach(), var=return_var.detach()),

From f67eb2432be6c36b38a46d505d8a4b3cf692cf3d Mon Sep 17 00:00:00 2001
From: Costa Huang
Date: Fri, 9 Jun 2023 12:54:36 +0000
Subject: [PATCH 2/2] black

---
 trl/trainer/ppo_trainer.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/trl/trainer/ppo_trainer.py b/trl/trainer/ppo_trainer.py
index e4d7c6a97b..ea9dd6207e 100644
--- a/trl/trainer/ppo_trainer.py
+++ b/trl/trainer/ppo_trainer.py
@@ -1045,9 +1045,9 @@ def loss(
             entropy = masked_mean(entropy_from_logits(logits), mask)
             # calculate approx_kl http://joschu.net/blog/kl-approx.html
             approxkl_k1 = masked_mean(-logratio, mask)
-            approxkl_k2 = 0.5 * masked_mean(logratio ** 2, mask)
+            approxkl_k2 = 0.5 * masked_mean(logratio**2, mask)
             approxkl_k3 = masked_mean((ratio - 1) - logratio, mask)
-            # by default use `k3`, the estimator which has lowest variance and is unbiased 
+            # by default use `k3`, the estimator which has lowest variance and is unbiased
             approxkl = approxkl_k3
 
             return_mean, return_var = masked_mean(returns, mask), masked_var(returns, mask)
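Note (not part of the patch): the three estimators logged above come from John Schulman's note linked in the diff (http://joschu.net/blog/kl-approx.html). With logratio = log(pi_new/pi_old) on tokens sampled from the old policy, k1 = -logratio is unbiased but high-variance, k2 = 0.5 * logratio**2 is low-variance but biased, and k3 = (ratio - 1) - logratio is unbiased with low variance, which is why the patch uses it as the default `approxkl`. A minimal standalone sketch checking the three estimators against the exact KL; the toy categorical distributions, seed, and sample count are made up for illustration:

```python
import torch

torch.manual_seed(0)

# Two categorical "policies" over a toy vocabulary; tokens are sampled
# from the old one, mirroring PPO where rollouts come from the old policy.
old_dist = torch.distributions.Categorical(logits=torch.randn(64))
new_dist = torch.distributions.Categorical(logits=old_dist.logits + 0.1 * torch.randn(64))

x = old_dist.sample((200_000,))
logratio = new_dist.log_prob(x) - old_dist.log_prob(x)  # log(pi_new / pi_old)
ratio = logratio.exp()

# Monte Carlo estimators of KL(pi_old || pi_new):
k1 = (-logratio).mean()               # unbiased, but high variance
k2 = (0.5 * logratio**2).mean()       # low variance, but biased
k3 = ((ratio - 1) - logratio).mean()  # unbiased and low variance; the patch's default

exact = torch.distributions.kl_divergence(old_dist, new_dist)
print(f"exact={exact:.5f}  k1={k1:.5f}  k2={k2:.5f}  k3={k3:.5f}")
```

Since e^t - 1 - t >= 0 for all t, k3 is also nonnegative per sample, which suits its use here as a logged diagnostic and as the early-stopping signal; none of the estimators feed the loss, hence the torch.no_grad() wrapper in the patch.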