KL Estimator #423

Closed · wants to merge 2 commits
tests/test_no_peft.py (3 additions & 1 deletion)

@@ -46,7 +46,9 @@ def __getitem__(self, idx):
"ppo/loss/total",
"ppo/policy/entropy",
"ppo/policy/approxkl",
"ppo/policy/policykl",
"ppo/policy/approxkl_k1",
"ppo/policy/approxkl_k2",
"ppo/policy/approxkl_k3",
"ppo/policy/clipfrac",
"ppo/policy/advantages",
"ppo/policy/advantages_mean",

tests/test_ppo_trainer.py (3 additions & 1 deletion)

@@ -45,7 +45,9 @@
"ppo/loss/total",
"ppo/policy/entropy",
"ppo/policy/approxkl",
"ppo/policy/policykl",
"ppo/policy/approxkl_k1",
"ppo/policy/approxkl_k2",
"ppo/policy/approxkl_k3",
"ppo/policy/clipfrac",
"ppo/policy/advantages",
"ppo/policy/advantages_mean",

trl/trainer/ppo_trainer.py (38 additions & 29 deletions)

@@ -685,8 +685,8 @@ def collator(data):
                 all_stats.append(train_stats)

             if self.config.early_stopping:
-                policykl = train_stats["policy/policykl"]
-                early_stop = self._early_stop(policykl)
+                approxkl = train_stats["policy/approxkl"]
+                early_stop = self._early_stop(approxkl)
                 if early_stop:
                     break

@@ -734,7 +734,7 @@ def collator(data):

         return stats

-    def _early_stop(self, policykl):
+    def _early_stop(self, approxkl):
         r"""
         Handles the early stopping logic. If the policy KL is greater than the target KL, then the gradient is zeroed and
         the optimization step is skipped.
@@ -751,7 +751,7 @@ def _early_stop(self, policykl):
         if not self.config.early_stopping:
             return early_stop

-        if not self.is_distributed and policykl > 1.5 * self.config.target_kl:
+        if not self.is_distributed and approxkl > 1.5 * self.config.target_kl:
             self.optimizer.zero_grad()
             early_stop = True
         elif self.is_distributed:
@@ -760,11 +760,11 @@
             # Wait for all processes to finish
             dist.barrier()

-            # all gather the policykl
-            dist.all_reduce(policykl, dist.ReduceOp.SUM)
-            policykl /= self.accelerator.num_processes
+            # all-reduce the approxkl across processes, then average it
+            dist.all_reduce(approxkl, dist.ReduceOp.SUM)
+            approxkl /= self.accelerator.num_processes

-            if policykl > 1.5 * self.config.target_kl:
+            if approxkl > 1.5 * self.config.target_kl:
                 self.optimizer.zero_grad()
                 early_stop = True
         return early_stop
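
For reference, the distributed branch just averages the per-rank KL estimate so every process reaches the same decision. A minimal standalone sketch of that logic, assuming an initialized torch.distributed process group; the function name and the num_processes argument are illustrative (standing in for self.accelerator.num_processes), not part of the trainer API:

import torch
import torch.distributed as dist

def early_stop_on_kl(approxkl: torch.Tensor, target_kl: float, num_processes: int) -> bool:
    # Make sure every rank has finished its backward pass before synchronizing.
    dist.barrier()
    # Sum the per-rank KL estimates, then average them.
    dist.all_reduce(approxkl, op=dist.ReduceOp.SUM)
    approxkl /= num_processes
    # Skip the optimization step once the policy drifts 1.5x past the target.
    return approxkl.item() > 1.5 * target_kl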
@@ -1011,16 +1011,17 @@ def loss(
         values = values * mask
         rewards = rewards * mask

-        for t in reversed(range(gen_len)):
-            nextvalues = values[:, t + 1] if t < gen_len - 1 else 0.0
-            delta = rewards[:, t] + self.config.gamma * nextvalues - values[:, t]
-            lastgaelam = delta + self.config.gamma * self.config.lam * lastgaelam
-            advantages_reversed.append(lastgaelam)
-        advantages = torch.stack(advantages_reversed[::-1]).transpose(0, 1)
+        with torch.no_grad():
+            for t in reversed(range(gen_len)):
+                nextvalues = values[:, t + 1] if t < gen_len - 1 else 0.0
+                delta = rewards[:, t] + self.config.gamma * nextvalues - values[:, t]
+                lastgaelam = delta + self.config.gamma * self.config.lam * lastgaelam
+                advantages_reversed.append(lastgaelam)
+            advantages = torch.stack(advantages_reversed[::-1]).transpose(0, 1)

-        returns = advantages + values
-        advantages = masked_whiten(advantages, mask)
-        advantages = advantages.detach()
+            returns = advantages + values
+            advantages = masked_whiten(advantages, mask)
+            advantages = advantages.detach()

         vpredclipped = clip_by_value(
             vpreds, values - self.config.cliprange_value, values + self.config.cliprange_value
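
The loop above is generalized advantage estimation, A_t = delta_t + gamma * lam * A_{t+1} with delta_t = r_t + gamma * V_{t+1} - V_t; wrapping it in torch.no_grad() keeps the recursion out of the autograd graph, which is safe because the advantages are treated as constants in the policy loss. A minimal single-sequence sketch of the same recursion (a hypothetical helper with assumed default gamma and lam; the trainer itself operates on batched 2-D tensors):

import torch

def gae(rewards: torch.Tensor, values: torch.Tensor, gamma: float = 1.0, lam: float = 0.95) -> torch.Tensor:
    # rewards, values: shape (T,) for a single response.
    gen_len = rewards.shape[0]
    lastgaelam = 0.0
    advantages_reversed = []
    with torch.no_grad():
        for t in reversed(range(gen_len)):
            nextvalues = values[t + 1] if t < gen_len - 1 else 0.0
            delta = rewards[t] + gamma * nextvalues - values[t]
            lastgaelam = delta + gamma * lam * lastgaelam
            advantages_reversed.append(lastgaelam)
    return torch.stack(advantages_reversed[::-1])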
@@ -1031,30 +1032,38 @@
         vf_loss = 0.5 * masked_mean(torch.max(vf_losses1, vf_losses2), mask)
         vf_clipfrac = masked_mean(torch.gt(vf_losses2, vf_losses1).double(), mask)

-        ratio = torch.exp(logprobs - old_logprobs)
+        logratio = logprobs - old_logprobs
+        ratio = torch.exp(logratio)
         pg_losses = -advantages * ratio
         pg_losses2 = -advantages * torch.clamp(ratio, 1.0 - self.config.cliprange, 1.0 + self.config.cliprange)

         pg_loss = masked_mean(torch.max(pg_losses, pg_losses2), mask)
-        pg_clipfrac = masked_mean(torch.gt(pg_losses2, pg_losses).double(), mask)

         loss = pg_loss + self.config.vf_coef * vf_loss

-        entropy = masked_mean(entropy_from_logits(logits), mask)
-        approxkl = 0.5 * masked_mean((logprobs - old_logprobs) ** 2, mask)
-        policykl = masked_mean(old_logprobs - logprobs, mask)
+        with torch.no_grad():
+            pg_clipfrac = masked_mean(torch.gt(pg_losses2, pg_losses).double(), mask)
+            entropy = masked_mean(entropy_from_logits(logits), mask)
+            # calculate approx_kl http://joschu.net/blog/kl-approx.html
+            approxkl_k1 = masked_mean(-logratio, mask)
+            approxkl_k2 = 0.5 * masked_mean(logratio**2, mask)
+            approxkl_k3 = masked_mean((ratio - 1) - logratio, mask)
+            # by default use `k3`, which is unbiased and has lower variance than `k1`
+            approxkl = approxkl_k3

         return_mean, return_var = masked_mean(returns, mask), masked_var(returns, mask)
         value_mean, value_var = masked_mean(values, mask), masked_var(values, mask)

         stats = dict(
             loss=dict(policy=pg_loss.detach(), value=vf_loss.detach(), total=loss.detach()),
             policy=dict(
-                entropy=entropy.detach(),
-                approxkl=approxkl.detach(),
-                policykl=policykl.detach(),
-                clipfrac=pg_clipfrac.detach(),
-                advantages=advantages.detach(),
-                advantages_mean=masked_mean(advantages, mask).detach(),
+                entropy=entropy,
+                approxkl=approxkl,
+                approxkl_k1=approxkl_k1,
+                approxkl_k2=approxkl_k2,
+                approxkl_k3=approxkl_k3,
+                clipfrac=pg_clipfrac,
+                advantages=advantages,
+                advantages_mean=masked_mean(advantages, mask),
                 ratio=ratio.detach(),
             ),
             returns=dict(mean=return_mean.detach(), var=return_var.detach()),
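
The three estimators follow John Schulman's note (http://joschu.net/blog/kl-approx.html). With logratio = log pi_new - log pi_old evaluated on samples from the old policy, k1 = -logratio is unbiased but high-variance, k2 = logratio**2 / 2 is low-variance but biased, and k3 = (ratio - 1) - logratio is unbiased with lower variance than k1 because it is non-negative pointwise. A quick self-contained sanity check on two Gaussians, separate from the trainer code:

import torch

torch.manual_seed(0)
# Two Gaussians standing in for the old and new policies;
# true KL(old || new) = 0.1**2 / 2 = 0.005.
old = torch.distributions.Normal(0.0, 1.0)
new = torch.distributions.Normal(0.1, 1.0)

x = old.sample((100_000,))                    # samples come from the old policy, as in PPO
logratio = new.log_prob(x) - old.log_prob(x)  # log pi_new - log pi_old
ratio = torch.exp(logratio)

k1 = (-logratio).mean()               # unbiased, high variance
k2 = 0.5 * (logratio**2).mean()       # biased, low variance
k3 = ((ratio - 1) - logratio).mean()  # unbiased, lower variance than k1

true_kl = torch.distributions.kl_divergence(old, new)
print(k1.item(), k2.item(), k3.item(), true_kl.item())  # all three land near 0.005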