forked from allenai/open-instruct
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathgrpo_vllm_thread_ray_gtrl.py
1736 lines (1595 loc) Β· 77.7 KB
/
grpo_vllm_thread_ray_gtrl.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright 2024 AllenAI. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ---------------------------------------------------------------------
# Part of the code is adapted from https://github.com/OpenRLHF/OpenRLHF
# which has the following license:
# Copyright [yyyy] [name of copyright owner]
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import json
import logging
import os
import random
import shutil
import socket
import subprocess
import threading
import time
from argparse import Namespace
from dataclasses import asdict, dataclass, field
from queue import Empty, Queue
from typing import Any, Callable, Iterator, List, Literal, Optional, Tuple
os.environ["NCCL_CUMEM_ENABLE"] = "0" # NOQA
import deepspeed
import numpy as np
import pandas as pd
import ray
import torch
import torch.distributed as dist
import torch.nn.functional as F
import torch.utils
import torch.utils.data
from datasets import Dataset, DatasetDict
from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
from huggingface_hub import HfApi
from ray.util.placement_group import PlacementGroup, placement_group
from ray.util.queue import Queue as RayQueue
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
from rich.pretty import pprint
from torch.utils.tensorboard import SummaryWriter
from transformers import (
AutoConfig,
AutoModelForCausalLM,
AutoModelForSequenceClassification,
AutoTokenizer,
PreTrainedModel,
PreTrainedTokenizer,
get_scheduler,
)
from transformers.integrations import HfDeepSpeedConfig
from vllm import SamplingParams
from open_instruct.dataset_processor import (
CHAT_TEMPLATES,
DATASET_SOURCE_KEY,
GROUND_TRUTHS_KEY,
INPUT_IDS_PROMPT_KEY,
DatasetConfig,
SFTGroundTruthDatasetProcessor,
SimpleGenerateCollatorWithGroundTruth,
visualize_token,
)
from open_instruct.model_utils import (
ModelConfig,
apply_verifiable_reward,
disable_dropout_in_model,
exact_div,
first_true_indices,
forward,
get_reward,
print_rich_single_line_metrics,
print_rich_table,
push_folder_to_hub,
truncate_response,
)
from open_instruct.utils import (
ArgumentParserPlus,
BeakerRuntimeConfig,
combine_dataset,
get_wandb_tags,
is_beaker_job,
maybe_get_beaker_config,
maybe_use_ai2_hf_entity,
maybe_use_ai2_wandb_entity,
upload_metadata_to_hf,
)
from open_instruct.vllm_utils2 import create_vllm_engines, init_process_group
# Module-level Hugging Face Hub client.
# NOTE(review): `calculate_runtime_args` builds its own `HfApi()` instead of using this one;
# confirm this instance is still needed elsewhere in the file.
api = HfApi()
# Sentinel log-probability value.
# NOTE(review): not referenced in the part of the file reviewed here — presumably used to
# fill masked/padded positions in the training loop; confirm before changing it.
INVALID_LOGPROB = 1.0
@dataclass
class Args:
    """Configuration for the GRPO trainer defined in this file.

    Every field is followed by a string literal describing it. Fields typed
    `Optional[...]` with a `None` default in the "various batch sizes" section
    (`gradient_accumulation_steps`, `world_size`, `micro_batch_size`,
    `rollout_batch_size`, `mini_batch_size`, `num_training_steps`, `eval_freq`,
    `local_dataloader_batch_size`) are filled in by `calculate_runtime_args`.

    `__post_init__` validates `number_samples_per_prompt` and normalizes the
    dataset mixers into both their dict and JSON-string forms.
    """

    # required dataset args
    # NOTE(review): annotated `str` but defaults to None; `process_dataset_mixer` (called from
    # `__post_init__`) raises ValueError unless this is a str or dict, so it is effectively required.
    dataset_mixer: str = None
    """A dictionary of datasets (local or HF) to sample from."""
    dataset_train_splits: List[str] = None
    """The dataset splits to use for training"""
    dataset_eval_mixer: Optional[str] = None
    """A dictionary of datasets (local or HF) to sample from for evaluation"""
    dataset_eval_splits: Optional[List[str]] = None
    """The dataset splits to use for evaluation"""
    # the two `*_dict` fields below are overwritten in `__post_init__`
    dataset_mixer_dict: Optional[dict] = None
    """The dataset mixer as a dictionary"""
    dataset_eval_mixer_dict: Optional[dict] = None
    """The dataset eval mixer as a dictionary"""
    # common args
    # default experiment name is this file's name without the ".py" suffix
    exp_name: str = os.path.basename(__file__)[: -len(".py")]
    """The name of this experiment"""
    seed: int = 1
    """Seed of the experiment"""
    # overwritten by `calculate_runtime_args` with "{exp_name}__{seed}__{unix_time}"
    run_name: Optional[str] = None
    """A unique name of this run"""
    # optimizer args
    eps: float = 1e-5
    """The epsilon value for the optimizer"""
    learning_rate: float = 2e-5
    """The initial learning rate for AdamW optimizer."""
    lr_scheduler_type: Literal[
        "linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"
    ] = "linear"
    """Which scheduler to use"""
    warm_up_steps: int = 0
    """Number of warm up steps for the scheduler"""
    warmup_ratio: float = 0.0
    """Ratio of warmup steps to total steps (takes precedence over `warm_up_steps`)"""
    # various batch sizes
    num_train_epochs: int = 1
    """Number of epochs to train"""
    gradient_accumulation_steps: Optional[int] = None
    """The number of gradient accumulation steps"""
    per_device_train_batch_size: Optional[int] = 1
    """The forward batch size per device (local_micro_batch_size)"""
    per_device_eval_batch_size: Optional[int] = 1
    """The forward batch size per device for evaluation (local_micro_batch_size)"""
    total_episodes: Optional[int] = 100000
    """The total number of episodes in the dataset"""
    world_size: Optional[int] = None
    """The number of processes (GPUs) to use"""
    micro_batch_size: Optional[int] = None
    """The micro batch size across devices (HF's `per_device_train_batch_size` * `world_size`)"""
    local_rollout_batch_size: int = 64
    """The number of rollout episodes per iteration per device"""
    rollout_batch_size: Optional[int] = None
    """The number of rollout episodes per iteration"""
    num_training_steps: Optional[int] = None
    """The number of training_steps to train"""
    num_evals: int = 4
    """The number of evaluations to run throughout training"""
    eval_freq: Optional[int] = None
    """The frequency of evaluation steps"""
    local_dataloader_batch_size: Optional[int] = None
    """The batch size per GPU for the dataloader"""
    save_freq: int = -1
    """How many train steps to save the model"""
    # online settings
    num_epochs: int = 4
    """the number of epochs to train"""
    num_mini_batches: int = 1
    """Number of minibatches to split a batch into"""
    local_mini_batch_size: int = 64
    """the mini batch size per GPU"""
    mini_batch_size: Optional[int] = None
    """the mini batch size across GPUs"""
    local_rollout_forward_batch_size: int = 64
    """per rank no grad forward pass in the rollout phase"""
    reward_model_path: str = "EleutherAI/pythia-160m"
    """the path to the reward model"""
    reward_model_revision: Optional[str] = None
    """the revision of the reward model"""
    # generation config
    response_length: int = 53
    """the length of the response"""
    stop_token: Optional[Literal["eos", "period"]] = None
    """the stop token"""
    stop_token_id: Optional[int] = None
    """the truncation token id"""
    min_response_length: int = 0
    """stop only after this many tokens"""
    temperature: float = 0.7
    """the sampling temperature"""
    penalty_reward_value: float = -1.0
    """the reward value for responses that do not contain `stop_token_id`"""
    non_stop_penalty: bool = False
    """whether to penalize responses that do not contain `stop_token_id`"""
    # NOTE(review): the default of 1 fails the `> 1` assertion in `__post_init__`,
    # so callers must always override this value.
    number_samples_per_prompt: int = 1
    """the number of samples to generate per prompt, useful for easy-star"""
    # online PPO specific args
    beta: float = 0.05
    """the beta value of the RLHF objective (KL coefficient)"""
    whiten_rewards: bool = False
    """whether to whiten the rewards"""
    cliprange: float = 0.2
    """the clip range"""
    gamma: float = 1
    """the discount factor"""
    kl_estimator: Literal["kl1", "kl2", "kl3"] = "kl3"
    """the KL estimator to use"""
    apply_verifiable_reward: bool = False
    """whether to apply verifiable reward"""
    # a value of 0 disables loading the reward model in `PolicyTrainerRayProcess.from_pretrained`
    reward_model_multiplier: float = 1.0
    """the reward model multiplier, for down/upscaling the reward model output"""
    # NOTE(review): undocumented in the original and annotated `str` with a None default;
    # its use is not visible in the reviewed part of the file — confirm its meaning.
    answer_extraction_model: str = None
    # async setting
    async_mode: bool = True
    """Whether to run the generation in async mode which learns from the second latest policy like Cleanba (https://arxiv.org/abs/2310.00036)"""
    # ray
    # `calculate_runtime_args` sums this list to obtain `world_size`
    actor_num_gpus_per_node: List[int] = field(default_factory=lambda: [1])
    """number of gpus per node for actor"""
    vllm_num_engines: int = 1
    """number of vLLM Engines, set to 0 to disable vLLM"""
    vllm_tensor_parallel_size: int = 1
    """tensor parallel size of vLLM Engine for multi-GPU inference"""
    vllm_enforce_eager: bool = False
    """whether to enforce eager mode for vLLM -- slow inference but needed for multi-node"""
    vllm_sync_backend: str = "nccl"
    """DeepSpeed -> vLLM weight sync backend"""
    enable_prefix_caching: bool = False
    """whether to enable prefix caching"""
    deepspeed_stage: int = 0
    """the deepspeed stage"""
    gather_whole_model: bool = True
    """whether to gather the whole model to boardcast (not doable for 70B but can be faster for 8B)"""
    # wandb and HF tracking configs
    with_tracking: bool = False
    """If toggled, this experiment will be tracked with Weights and Biases"""
    wandb_project_name: str = "open_instruct_internal"
    """The wandb's project name"""
    wandb_entity: Optional[str] = None
    """The entity (team) of wandb's project"""
    push_to_hub: bool = True
    """Whether to upload the saved model to huggingface"""
    hf_entity: Optional[str] = None
    """The user or org name of the model repository from the Hugging Face Hub"""
    hf_repo_id: Optional[str] = None
    """The id of the saved model in the Hugging Face Hub (can be autoset if not given)"""
    hf_repo_revision: Optional[str] = None
    """The revision of the saved model in the Hugging Face Hub (can be autoset if not given)"""
    hf_repo_url: Optional[str] = None
    """The url of the saved model in the Hugging Face Hub (will be autoset)"""
    output_dir: Optional[str] = None
    """Where to save the model"""
    checkpoint_output_dir: Optional[str] = None
    """Where to save the model checkpoints in case of preemption"""
    # Ai2 specific settings
    try_launch_beaker_eval_jobs: bool = True
    """Whether to launch beaker evaluation jobs after training"""
    try_launch_beaker_eval_jobs_on_weka: bool = False
    """Whether to launch beaker evaluation jobs after training on weka"""
    try_auto_save_to_beaker: bool = True
    """Whether to try to save the model to Beaker dataset `/output` after training"""
    oe_eval_tasks: Optional[List[str]] = None
    """The beaker evaluation tasks to launch"""
    hf_metadata_dataset: Optional[str] = "allenai/tulu-3-evals"
    """What dataset to upload the metadata to. If unset, don't upload metadata"""

    def __post_init__(self):
        # Validate the sample count and normalize the dataset mixers.
        # NOTE(review): `assert` is stripped under `python -O`, so this check can be skipped.
        assert self.number_samples_per_prompt > 1, "Number of samples per prompt must be greater than 1 for GRPO!"
        # keep both representations in sync: `*_dict` holds the dict, the plain field the JSON string
        self.dataset_mixer_dict, self.dataset_mixer = process_dataset_mixer(self.dataset_mixer)
        # the evaluation mixer is optional and only normalized when provided
        if self.dataset_eval_mixer is not None:
            self.dataset_eval_mixer_dict, self.dataset_eval_mixer = process_dataset_mixer(self.dataset_eval_mixer)
def process_dataset_mixer(value) -> Tuple[Optional[dict], Optional[str]]:
    """Normalize a dataset mixer into a ``(dict, json_string)`` pair.

    Args:
        value: Either a JSON string (the form received from the CLI) or a
            dictionary (the form received from a YAML config).

    Returns:
        A tuple ``(mixer_dict, mixer_json)`` holding both representations.

    Raises:
        ValueError: If `value` is neither a string nor a dictionary.
    """
    # YAML path: the mixer is already a dict, so only the string form is derived
    if isinstance(value, dict):
        return value, json.dumps(value)
    # CLI path: the mixer arrives as a JSON string, so only the dict form is derived
    if isinstance(value, str):
        return json.loads(value), value
    raise ValueError("Input must be either a string or a dictionary")
def calculate_runtime_args(args: Args, model_config: ModelConfig):
    """Fill in, in place, the fields of `args` that are derived at runtime.

    Sets the run name, gradient accumulation steps, world size, the global
    micro/rollout/mini batch sizes, the number of training steps, the
    evaluation frequency and the dataloader batch size. When `push_to_hub`
    is set it also resolves the Hugging Face repo id, revision and URL, and
    when `with_tracking` is set it resolves the wandb entity.

    Args:
        args: The experiment configuration; mutated in place.
        model_config: Unused by this function.

    Raises:
        AssertionError: If `whiten_rewards` is set and `local_mini_batch_size` is below 8.
    """
    # the run name embeds the current unix time so every launch gets a distinct name
    args.run_name = f"{args.exp_name}__{args.seed}__{int(time.time())}"
    args.gradient_accumulation_steps = exact_div(
        args.local_mini_batch_size,
        args.per_device_train_batch_size,
        "`local_mini_batch_size` must be a multiple of `per_device_train_batch_size`",
    )

    # one training process per actor GPU; global batch sizes are the per-rank sizes times that count
    args.world_size = sum(args.actor_num_gpus_per_node)
    args.micro_batch_size = int(args.per_device_train_batch_size * args.world_size)
    args.rollout_batch_size = int(args.local_rollout_batch_size * args.world_size)
    args.mini_batch_size = int(args.local_mini_batch_size * args.world_size)

    # every prompt yields `number_samples_per_prompt` episodes per rollout
    episodes_per_step = args.rollout_batch_size * args.number_samples_per_prompt
    args.num_training_steps = args.total_episodes // episodes_per_step
    args.eval_freq = max(1, args.num_training_steps // args.num_evals)

    # PPO logic: do checks and set up dataloader batch size
    if args.whiten_rewards:
        assert (
            args.local_mini_batch_size >= 8
        ), f"Per-rank minibatch size {args.local_mini_batch_size} is insufficient for whitening"
    args.local_dataloader_batch_size = args.rollout_batch_size

    if args.push_to_hub:
        # auto-generate a repo name when none was given
        if args.hf_repo_id is None:
            args.hf_repo_id = "open_instruct_dev"
        # entity resolution: explicit value, then the AI2 entity, then the logged-in user
        if args.hf_entity is None:
            args.hf_entity = maybe_use_ai2_hf_entity()
        if args.hf_entity is None:
            args.hf_entity = HfApi().whoami()["name"]
        args.hf_repo_id = f"{args.hf_entity}/{args.hf_repo_id}"
        # the revision defaults to the run name
        if args.hf_repo_revision is None:
            args.hf_repo_revision = args.run_name
        args.hf_repo_url = f"https://huggingface.co/{args.hf_repo_id}/tree/{args.hf_repo_revision}"

    if args.with_tracking and args.wandb_entity is None:
        args.wandb_entity = maybe_use_ai2_wandb_entity()
def get_train_ds_config(
    offload,
    adam_offload=False,
    stage=0,
    bf16=True,
    max_norm=1.0,
    zpg=8,
    grad_accum_dtype=None,
    disable_trace_cache=True,
):
    """Build the DeepSpeed config dictionary used for the trainable model.

    Args:
        offload: If truthy, parameters are offloaded to "cpu", otherwise "none".
        adam_offload: If truthy, optimizer state is offloaded to "cpu".
        stage: ZeRO optimization stage.
        bf16: Whether bf16 is enabled.
        max_norm: Value used for "gradient_clipping".
        zpg: Unused (the ZeRO++ settings it fed are disabled).
        grad_accum_dtype: Gradient accumulation dtype; "fp32" when falsy.
        disable_trace_cache: If truthy, the stage-3 prefetch bucket size,
            max live parameters and max reuse distance are set to 0 instead
            of "auto".

    Returns:
        The DeepSpeed configuration as a plain dictionary.
    """
    # the three stage-3 knobs share one value: 0 when the trace cache is disabled
    stage3_knob = 0 if disable_trace_cache else "auto"
    param_device = "cpu" if offload else "none"
    optimizer_device = "cpu" if adam_offload else "none"

    zero_optimization = {
        "stage": stage,
        "offload_param": {"device": param_device},
        "offload_optimizer": {
            "device": optimizer_device,
            "pin_memory": True,
        },
        "sub_group_size": "auto",
        "stage3_max_live_parameters": stage3_knob,
        "stage3_max_reuse_distance": stage3_knob,
        "stage3_param_persistence_threshold": "auto",
        "stage3_prefetch_bucket_size": stage3_knob,
        "reduce_bucket_size": "auto",
    }

    config = {
        "steps_per_print": 100,
        "zero_optimization": zero_optimization,
        "bf16": {"enabled": bf16},
        "gradient_clipping": max_norm,
        "prescale_gradients": False,
        "wall_clock_breakdown": False,
        "data_types": {"grad_accum_dtype": grad_accum_dtype or "fp32"},
    }
    return config
def get_eval_ds_config(
    offload,
    stage=0,
    bf16=True,
):
    """Build the DeepSpeed config dictionary used for inference-only models.

    Args:
        offload: If truthy, parameters are offloaded to "cpu", otherwise "none".
        stage: ZeRO optimization stage.
        bf16: Whether bf16 is enabled.

    Returns:
        The DeepSpeed configuration as a plain dictionary.
    """
    param_device = "cpu" if offload else "none"
    return {
        "steps_per_print": 100,
        "zero_optimization": {
            "stage": stage,
            "stage3_param_persistence_threshold": "auto",
            "offload_param": {
                "device": param_device,
                "pin_memory": True,
            },
        },
        "bf16": {"enabled": bf16},
        "prescale_gradients": False,
        "wall_clock_breakdown": False,
    }
def get_optimizer_grouped_parameters(
    model,
    weight_decay,
    no_decay_name_list=("bias", "layer_norm.weight", "layernorm.weight", "norm.weight", "ln_f.weight"),
):
    """Split a model's trainable parameters into weight-decay and no-decay groups.

    A parameter goes into the no-decay group when any entry of
    `no_decay_name_list` occurs as a substring of its name. Parameters with
    `requires_grad` set to False are left out of both groups.

    Args:
        model: Object exposing `named_parameters()` yielding `(name, param)` pairs.
        weight_decay: Weight decay applied to the first (decayed) group.
        no_decay_name_list: Name fragments marking parameters that must not be
            decayed. The default is an immutable tuple so it cannot be mutated
            across calls; any iterable of strings that can be iterated
            repeatedly (list, tuple, ...) is accepted.

    Returns:
        A two-element list of optimizer parameter groups:
        `[{"params": decayed, "weight_decay": weight_decay},
        {"params": not_decayed, "weight_decay": 0.0}]`, each preserving the
        order of `model.named_parameters()`.
    """
    decayed, not_decayed = [], []
    # single pass over the parameters instead of one full scan per group
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        if any(fragment in name for fragment in no_decay_name_list):
            not_decayed.append(param)
        else:
            decayed.append(param)
    return [
        {"params": decayed, "weight_decay": weight_decay},
        {"params": not_decayed, "weight_decay": 0.0},
    ]
def _z3_params_to_fetch(param_list):
return [p for p in param_list if hasattr(p, "ds_id") and p.ds_status == ZeroParamStatus.NOT_AVAILABLE]
def masked_mean(values: torch.Tensor, mask: torch.Tensor, axis: Optional[int] = None) -> torch.Tensor:
    """Compute the mean of `values` weighted by `mask`.

    Args:
        values: Tensor to average.
        mask: Tensor multiplied element-wise with `values`; its sum is the
            denominator of the mean.
        axis: Dimension to reduce over. When None the reduction covers all
            elements and a scalar tensor is returned. (Previously annotated as
            `Optional[bool]`, although it is forwarded as a dimension index.)

    Returns:
        `(values * mask).sum(...) / mask.sum(...)`. No guard is applied for a
        mask that sums to zero, so the division result is returned unchanged
        in that case.
    """
    weighted = values * mask
    if axis is None:
        return weighted.sum() / mask.sum()
    return weighted.sum(axis=axis) / mask.sum(axis=axis)
def masked_var(values: torch.Tensor, mask: torch.Tensor, unbiased: bool = True) -> torch.Tensor:
    """Compute the variance of `values` over the positions selected by `mask`.

    Args:
        values: Tensor whose variance is computed.
        mask: Tensor weighting each element of `values`.
        unbiased: If True, apply Bessel's correction `n / (n - 1)` with
            `n = mask.sum()`.

    Returns:
        The masked variance as a scalar tensor.

    Raises:
        ValueError: If `unbiased` is True and the mask sums to zero.
    """
    if unbiased:
        mask_sum = mask.sum()
        # fail before computing a mean that would already be a 0/0 division
        if mask_sum == 0:
            raise ValueError(
                "The sum of the mask is zero, which can happen when `mini_batch_size=1`; "
                "try increasing the `mini_batch_size` or `gradient_accumulation_steps`"
            )
    mean = masked_mean(values, mask)
    centered_values = values - mean
    variance = masked_mean(centered_values**2, mask)
    if unbiased:
        # note that if mask_sum == 1, then there is a division by zero issue
        # to avoid it you just need to use a larger minibatch_size
        bessel_correction = mask_sum / (mask_sum - 1)
        variance = variance * bessel_correction
    return variance
def masked_whiten(values: torch.Tensor, mask: torch.Tensor, shift_mean: bool = True) -> torch.Tensor:
    """Whiten `values` using the masked mean and masked variance.

    Args:
        values: Tensor to whiten.
        mask: Tensor passed to `masked_mean` / `masked_var` to select positions.
        shift_mean: If True the result is centered; if False the masked mean
            is added back after scaling.

    Returns:
        `(values - mean) * rsqrt(var + 1e-8)`, plus `mean` when `shift_mean` is False.
    """
    mean = masked_mean(values, mask)
    var = masked_var(values, mask)
    # the 1e-8 keeps rsqrt finite when the variance is zero
    whitened = (values - mean) * torch.rsqrt(var + 1e-8)
    if shift_mean:
        return whitened
    whitened += mean
    return whitened
def remove_padding(sequences, pad_token_id):
    """Return a copy of `sequences` with every occurrence of `pad_token_id` removed.

    Each inner sequence becomes a new list containing, in order, the items
    that are not equal to `pad_token_id`.
    """
    stripped = []
    for sequence in sequences:
        stripped.append([token for token in sequence if token != pad_token_id])
    return stripped
class ShufflingIterator:
    """Endless iterator yielding fixed-size batches from a shuffled copy of `data`.

    The data is shuffled once at construction. After `effective_size` items
    (the largest multiple of `batch_size` that fits in the data) have been
    served, the copy is reshuffled and iteration restarts from the beginning.
    `__next__` never raises `StopIteration`.
    """

    def __init__(self, data: np.ndarray, batch_size: int, seed: Optional[int] = None):
        # work on a private copy so the caller's array is never reordered
        self.data = data.copy()
        self.batch_size = batch_size
        self.index = 0
        self.rng = np.random.default_rng(seed)
        self.rng.shuffle(self.data)

        # Ensure the effective dataset size is divisible by batch_size
        leftover = len(self.data) % batch_size
        self.effective_size = len(self.data) - leftover

    def __iter__(self) -> Iterator[List[int]]:
        return self

    def __next__(self) -> List[int]:
        # all full batches of this pass were served: reshuffle and start over
        if self.index >= self.effective_size:
            self.index = 0
            self.rng.shuffle(self.data)

        start = self.index
        self.index = start + self.batch_size
        return self.data[start : self.index].tolist()
class RayProcess:
    """Base class holding the distributed-rank bookkeeping of a worker process.

    The constructor configures logging, stores the rank information, exports
    the `MASTER_ADDR` / `MASTER_PORT` / `WORLD_SIZE` / `RANK` / `LOCAL_RANK`
    environment variables and seeds `random`, NumPy and PyTorch with the rank.
    """

    def __init__(self, world_size, rank, local_rank, master_addr, master_port):
        logging.basicConfig(
            format="%(asctime)s %(levelname)-8s %(message)s",
            level=logging.INFO,
            datefmt="%Y-%m-%d %H:%M:%S",
        )
        self.world_size = world_size
        self.rank = rank
        self.local_rank = local_rank
        # a falsy address/port means "discover it on this node"
        self.master_addr = master_addr or self.get_current_node_ip()
        self.master_port = master_port or self.get_free_port()

        # NOTE: Ray will automatically set the CUDA_VISIBLE_DEVICES
        # environment variable for each actor, so always set device to 0
        os.environ.update(
            {
                "MASTER_ADDR": self.master_addr,
                "MASTER_PORT": str(self.master_port),
                "WORLD_SIZE": str(self.world_size),
                "RANK": str(self.rank),
                "LOCAL_RANK": "0",
            }
        )

        # every RNG in use is seeded with the rank
        for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
            seed_fn(self.rank)

    @staticmethod
    def get_current_node_ip():
        """Return this node's IP address as reported by Ray, without surrounding brackets."""
        # strip ipv6 address
        return ray._private.services.get_node_ip_address().strip("[]")

    @staticmethod
    def get_free_port():
        """Return a port number the OS handed out when binding a temporary socket to port 0."""
        with socket.socket() as probe:
            probe.bind(("", 0))
            _, port = probe.getsockname()
            return port

    def get_master_addr_port(self):
        """Return the `(master_addr, master_port)` pair stored on this process."""
        return self.master_addr, self.master_port

    def empty_cache(self) -> None:
        """Call `torch.cuda.empty_cache()`."""
        torch.cuda.empty_cache()
@ray.remote(num_gpus=1)
class PolicyTrainerRayProcess(RayProcess):
def from_pretrained(
    self, args: Args, model_config: ModelConfig, beaker_config: BeakerRuntimeConfig, wandb_url: str
):
    """Load the policy, reference policy and (optionally) reward model under DeepSpeed.

    Stores the configuration objects on `self`, initializes the distributed
    backend, then builds:

    * `self.model` / `self.optimizer` / `self.scheduler`: the trainable policy
      wrapped by `deepspeed.initialize`, with AdamW and an LR scheduler;
    * `self.ref_policy`: a second copy of the same checkpoint in eval mode;
    * `self.reward_model`: only when `args.reward_model_multiplier` is non-zero.

    Args:
        args: Experiment configuration.
        model_config: Provides `model_name_or_path` and `model_revision`.
        beaker_config: Stored on `self`; not otherwise used here.
        wandb_url: Stored on `self`; not otherwise used here.

    Raises:
        AssertionError: If `reward_model_multiplier` is zero and
            `apply_verifiable_reward` is False (no reward source at all).
    """
    self.args = args
    self.model_config = model_config
    self.beaker_config = beaker_config
    self.wandb_url = wandb_url
    torch.cuda.set_device(self.local_rank)
    deepspeed.init_distributed()

    ds_config = get_train_ds_config(
        offload=False,
        adam_offload=False,
        stage=args.deepspeed_stage,
        bf16=True,
    )
    ds_config["train_micro_batch_size_per_gpu"] = args.per_device_train_batch_size
    ds_config["train_batch_size"] = args.mini_batch_size
    # Costa: MAGIC: it's actually needed to initialize this `dschf`, so
    # https://huggingface.co/docs/transformers/deepspeed#non-trainer-deepspeed-integration
    # next line instructs transformers to partition the model directly over multiple gpus using
    # deepspeed.zero.Init when model's `from_pretrained` method is called.
    if ds_config is not None and ds_config["zero_optimization"]["stage"] == 3:
        dschf = HfDeepSpeedConfig(ds_config)
    else:
        dschf = None
    print(f"{dschf=}")

    self.original_tokenizer = AutoTokenizer.from_pretrained(
        model_config.model_name_or_path, revision=model_config.model_revision
    )
    self.policy: PreTrainedModel = AutoModelForCausalLM.from_pretrained(
        model_config.model_name_or_path,
        revision=model_config.model_revision,
        torch_dtype=torch.bfloat16,
        attn_implementation="flash_attention_2",
        use_cache=False,
    )
    disable_dropout_in_model(self.policy)
    self.policy.gradient_checkpointing_enable()
    # AdamOptimizer = DeepSpeedCPUAdam if self.adam_offload else FusedAdam
    # AdamOptimizer = FusedAdam
    # weight_decay = 0.0
    # optim_params = get_optimizer_grouped_parameters(self.policy, weight_decay)
    # self.optimizer = AdamOptimizer(optim_params, lr=args.learning_rate)
    self.optimizer = torch.optim.AdamW(self.policy.parameters(), lr=args.learning_rate)
    num_training_steps = args.num_training_steps * args.num_train_epochs * args.num_epochs
    warm_up_steps = args.warm_up_steps
    # `warmup_ratio` only takes precedence when it is actually set. The previous `>= 0.0`
    # test was true for the default 0.0 and silently discarded `warm_up_steps`.
    if args.warmup_ratio > 0.0:
        warm_up_steps = int(num_training_steps * args.warmup_ratio)
    scheduler = get_scheduler(
        args.lr_scheduler_type,
        optimizer=self.optimizer,
        num_warmup_steps=warm_up_steps,
        num_training_steps=num_training_steps,
    )
    print(ds_config)
    self.model, self.optimizer, _, self.scheduler = deepspeed.initialize(
        model=self.policy,
        optimizer=self.optimizer,
        config=ds_config,
        lr_scheduler=scheduler,
        dist_init_required=True,
    )
    self.model.train()

    # reference model
    ds_config = get_eval_ds_config(
        offload=False,
        stage=args.deepspeed_stage,
        bf16=True,
    )
    ds_config["train_micro_batch_size_per_gpu"] = args.per_device_train_batch_size
    ds_config["train_batch_size"] = args.mini_batch_size
    # Costa: MAGIC: it's actually needed to initialize this `dschf`, so
    # https://huggingface.co/docs/transformers/deepspeed#non-trainer-deepspeed-integration
    # next line instructs transformers to partition the model directly over multiple gpus using
    # deepspeed.zero.Init when model's `from_pretrained` method is called.
    if ds_config is not None and ds_config["zero_optimization"]["stage"] == 3:
        dschf = HfDeepSpeedConfig(ds_config)
    else:
        dschf = None
    print(f"{dschf=}")

    self.ref_policy: PreTrainedModel = AutoModelForCausalLM.from_pretrained(
        model_config.model_name_or_path,
        revision=model_config.model_revision,
        torch_dtype=torch.bfloat16,
        attn_implementation="flash_attention_2",
        use_cache=False,
    )
    disable_dropout_in_model(self.ref_policy)
    self.ref_policy, *_ = deepspeed.initialize(model=self.ref_policy, config=ds_config)
    self.ref_policy.eval()

    # reward model (skipped entirely when its multiplier is zero)
    if args.reward_model_multiplier:
        self.reward_model: PreTrainedModel = AutoModelForSequenceClassification.from_pretrained(
            args.reward_model_path,
            revision=args.reward_model_revision,
            num_labels=1,
            torch_dtype=torch.bfloat16,
            attn_implementation="flash_attention_2",
            use_cache=False,
        )
        disable_dropout_in_model(self.reward_model)
        ds_config = get_eval_ds_config(
            offload=False,
            stage=args.deepspeed_stage,
            bf16=True,
        )
        ds_config["train_micro_batch_size_per_gpu"] = args.per_device_train_batch_size
        ds_config["train_batch_size"] = args.mini_batch_size
        self.reward_model, *_ = deepspeed.initialize(model=self.reward_model, config=ds_config)
        self.reward_model.eval()

    # at least one reward source must be configured
    assert (
        args.reward_model_multiplier or args.apply_verifiable_reward
    ), "Either `reward_model_multiplier` must be non-zero or `apply_verifiable_reward` must be True."
def get_vocab_size(self):
    """Return the `vocab_size` recorded in the policy model's config."""
    policy_config = self.policy.config
    return policy_config.vocab_size
def forward(
    self,
    query_response: torch.LongTensor,
    response: torch.LongTensor,
    pad_token_id: int,
    context_length: int,
    temperature: float,
) -> torch.Tensor:
    """Return the policy's log-probabilities for the tokens in `response`.

    Runs the module-level `forward` helper on `self.model`, keeps the logits
    from position `context_length - 1` up to (but excluding) the last
    position, scales them in place by `temperature + 1e-7`, applies
    `log_softmax` over the last dimension and gathers the entry indexed by
    each token of `response`.

    NOTE(review): presumably `query_response` is the prompt of length
    `context_length` concatenated with `response`, so each kept logit row
    scores the response token that follows it — confirm against the caller.
    """
    model_output = forward(self.model, query_response, pad_token_id)
    response_logits = model_output.logits[:, context_length - 1 : -1]
    # in-place division, as before; the 1e-7 guards against a zero temperature
    response_logits /= temperature + 1e-7
    token_logprobs = F.log_softmax(response_logits, dim=-1)
    gather_index = response.unsqueeze(-1)
    return torch.gather(token_logprobs, 2, gather_index).squeeze(-1)
def train(
self,
train_dataset: Dataset,
eval_dataset: Dataset,
tokenizer: PreTrainedTokenizer,
vllm_engines: List[ray.actor.ActorHandle],
metrics_queue: RayQueue,
data_collator: Callable,
):
torch.set_printoptions(precision=4, sci_mode=False)
args = self.args
accelerator = Namespace()
accelerator.process_index = self.rank
accelerator.num_processes = self.world_size
accelerator.is_main_process = self.rank == 0
torch.distributed.barrier()
if self.rank == 0:
master_address = ray._private.services.get_node_ip_address()
with socket.socket() as sock:
sock.bind(("", 0))
master_port = sock.getsockname()[1]
vllm_num_engines, vllm_tensor_parallel_size = (
args.vllm_num_engines,
args.vllm_tensor_parallel_size,
)
world_size = vllm_num_engines * vllm_tensor_parallel_size + 1
backend = args.vllm_sync_backend
# https://github.com/OpenRLHF/OpenRLHF/issues/313
# if vllm.__version__ > "0.4.2" and os.getenv("NCCL_P2P_DISABLE", "0") == "0":
# backend = "gloo"
# print(
# "Warning: using --vllm_sync_backend=gloo for vLLM version > 0.4.2 (or export NCCL_P2P_DISABLE=1)"
# )
refs = [
engine.init_process_group.remote(
master_address,
master_port,
i * vllm_tensor_parallel_size + 1,
world_size,
"openrlhf",
backend=backend,
)
for i, engine in enumerate(vllm_engines)
]
self.model_update_group = init_process_group(
backend=backend,
init_method=f"tcp://{master_address}:{master_port}",
world_size=world_size,
rank=0,
group_name="openrlhf",
)
ray.get(refs)
torch.distributed.barrier()
def broadcast_to_vllm():
# avoid OOM
torch.cuda.empty_cache()
model = self.model.module
count, num_params = 0, len(list(model.named_parameters()))
refss = []
if args.gather_whole_model:
with deepspeed.zero.GatheredParameters(model.parameters(), enabled=args.deepspeed_stage == 3):
for name, param in model.named_parameters():
count += 1 # empty_cache at last param
# Fire all vllm engines for broadcast
if torch.distributed.get_rank() == 0:
shape = param.shape if args.deepspeed_stage != 3 else param.ds_shape
refs = [
engine.update_weight.remote(
name, dtype=param.dtype, shape=shape, empty_cache=count == num_params
)
for engine in vllm_engines
]
refss.extend(refs)
if torch.distributed.get_rank() == 0:
torch.distributed.broadcast(param.data, 0, group=self.model_update_group)
else: # broadcast each parameter independently
for name, param in model.named_parameters():
count += 1
if torch.distributed.get_rank() == 0:
shape = param.shape if args.deepspeed_stage != 3 else param.ds_shape
refs = [
engine.update_weight.remote(
name, dtype=param.dtype, shape=shape, empty_cache=count == num_params
)
for engine in vllm_engines
]
refss.extend(refs)
with deepspeed.zero.GatheredParameters([param], enabled=args.deepspeed_stage == 3):
if torch.distributed.get_rank() == 0:
torch.distributed.broadcast(param.data, 0, group=self.model_update_group)
if torch.distributed.get_rank() == 0:
ray.get(refss)
# broadcast_to_vllm()
if args.stop_token:
if args.stop_token == "eos":
args.stop_token_id = tokenizer.eos_token_id
if args.stop_token == "period":
args.stop_token_id = tokenizer.encode(".")[0]
# data_collator = SimpleGenerateCollator(pad_token_id=tokenizer.pad_token_id)
train_dataset_idxs = np.arange(len(train_dataset))
shuffling_iter = ShufflingIterator(train_dataset_idxs, args.rollout_batch_size, seed=args.seed)
# hack to left pad
def repeat_generator():
while True:
batch_idxs = next(shuffling_iter)
yield [train_dataset[i] for i in batch_idxs]
iter_dataloader = iter(repeat_generator())
generation_config = SamplingParams(
temperature=args.temperature,
top_p=1.0,
max_tokens=args.response_length,
include_stop_str_in_output=True,
n=args.number_samples_per_prompt,
)
# print("setup async queues")
param_prompt_Q = None
response_ids_Q = None
evaluation_Q = None
response_ids_Q = Queue(maxsize=1)
param_prompt_Q = Queue(maxsize=1)
evaluation_Q = Queue(maxsize=1)
num_eval_samples = 32
sample_evaluation_prompt_token_ids = None
if eval_dataset is not None:
sample_evaluation_prompt_token_ids = eval_dataset[:num_eval_samples][INPUT_IDS_PROMPT_KEY]
def vllm_generate(
    generation_config: SamplingParams,
    response_ids_Q: Queue,
    param_prompt_Q: Queue,
    num_training_steps: int,
    sample_evaluation_prompt_token_ids: Optional[List[List[int]]],
    evaluation_Q: Queue,
    eval_freq: int,
    resume_training_step: int,
):
    """Worker loop: turn queued prompts into vLLM generations.

    For each step from ``resume_training_step`` through ``num_training_steps``
    (inclusive) the loop waits for an item on ``param_prompt_Q``, generates
    with the first engine in ``vllm_engines``, and puts the result on
    ``response_ids_Q``.

    Args:
        generation_config: Sampling parameters used for both rollout and
            evaluation generation.
        response_ids_Q: Receives, per step, a flat list with one token-id list
            per sampled completion (completions of the same prompt adjacent).
        param_prompt_Q: Source of work items. Each item is a 2-tuple whose
            second element is the batch of prompt token ids; the first element
            is ignored. A ``None`` item stops the loop.
        num_training_steps: Last step index to serve (inclusive).
        sample_evaluation_prompt_token_ids: Batch of evaluation prompts (one
            token-id list per prompt), or ``None`` to disable evaluation.
        evaluation_Q: Receives evaluation generations, one token-id list per
            evaluation prompt.
        eval_freq: Evaluate on steps where ``(step - 1) % eval_freq == 0``.
            A value of 0 disables evaluation instead of raising
            ``ZeroDivisionError`` inside the worker.
        resume_training_step: First step index to serve.
    """
    llm = vllm_engines[0]
    for training_step in range(resume_training_step, num_training_steps + 1):
        items = param_prompt_Q.get()
        if items is None:
            # sentinel from the producer: stop generating
            break
        # The first element of the item is not needed for generation.
        _, g_queries_list = items

        generation_start_time = time.time()
        outputs = ray.get(
            llm.generate.remote(
                sampling_params=generation_config, prompt_token_ids=g_queries_list, use_tqdm=False
            )
        )
        # Flatten: every sampled completion of every prompt becomes one entry.
        response_ids = [list(out.token_ids) for output in outputs for out in output.outputs]
        print(f"🔥🔥🔥 Generation time: {time.time() - generation_start_time:.2f} seconds")
        response_ids_Q.put(response_ids)

        should_evaluate = (
            sample_evaluation_prompt_token_ids is not None
            and eval_freq != 0  # guard the modulo below against division by zero
            and (training_step - 1) % eval_freq == 0
        )
        if should_evaluate:
            outputs = ray.get(
                llm.generate.remote(
                    prompt_token_ids=sample_evaluation_prompt_token_ids,
                    sampling_params=generation_config,
                    use_tqdm=False,
                )
            )
            # for evaluation, even if we have multiple outputs, we only look at one of them for simplicity
            response_ids = [list(output.outputs[0].token_ids) for output in outputs]
            evaluation_Q.put(response_ids)
# Training always starts from step 1 here.
resume_training_step = 1
if accelerator.is_main_process:
    # Only the main process launches the generation worker. Arguments are
    # passed by parameter name so the mapping to `vllm_generate` is explicit.
    thread = threading.Thread(
        target=vllm_generate,
        kwargs={
            "generation_config": generation_config,
            "response_ids_Q": response_ids_Q,
            "param_prompt_Q": param_prompt_Q,
            "num_training_steps": args.num_training_steps,
            "sample_evaluation_prompt_token_ids": sample_evaluation_prompt_token_ids,
            "evaluation_Q": evaluation_Q,
            "eval_freq": args.eval_freq,
            "resume_training_step": resume_training_step,
        },
    )
    thread.start()
    print("vllm generate thread starts")
# Device for this rank plus the zero-initialised buffers and counters created
# before the training loop starts.
device = torch.device(self.local_rank)

# Integer buffer with one row per sampled response and `response_length` columns.
g_vllm_responses = torch.zeros(
    (args.rollout_batch_size * args.number_samples_per_prompt, args.response_length),
    dtype=torch.long,
    device=device,
)

# All per-update statistics share the same 3-D shape.
stats_shape = (
    args.num_epochs,
    args.num_mini_batches * args.number_samples_per_prompt,
    args.gradient_accumulation_steps,
)
# Eight independent zero tensors, one per tracked statistic.
(
    non_score_reward_sum_stats,
    approxkl_stats,
    pg_clipfrac_stats,
    pg_loss_stats,
    reward_mean,
    reward_std,
    entropy_stats,
    ratio_stats,
) = (torch.zeros(stats_shape, device=device) for _ in range(8))

local_metrics = torch.zeros((20,), device=device)
# Episode counter, offset by the steps before `resume_training_step`.
episode = args.rollout_batch_size * (resume_training_step - 1)
# training loop
start_time = time.time()

# Fetch the first global batch. Each rank collates only its own contiguous
# slice of `local_rollout_batch_size` examples.
global_data = next(iter_dataloader)
rank_slice = slice(
    self.rank * args.local_rollout_batch_size,
    (self.rank + 1) * args.local_rollout_batch_size,
)
data = data_collator(global_data[rank_slice])

# The prompts of the whole batch are collated as well; they are what gets sent
# for generation below.
# can be simplified since we `remove_padding` later anyway
global_queries = data_collator(global_data)[INPUT_IDS_PROMPT_KEY].tolist()

queries_next = data[INPUT_IDS_PROMPT_KEY].to(device)
ground_truths_next, datasets_next = data[GROUND_TRUTHS_KEY], data[DATASET_SOURCE_KEY]

if accelerator.is_main_process:
    # Enqueue the first generation request: (None, unpadded prompts).
    param_prompt_Q.put((None, remove_padding(global_queries, tokenizer.pad_token_id)))

answer_extraction_model = answer_extraction_tokenizer = None
# for _ in range(1, resume_training_step): # we didn't store scheduler state
# scheduler.step()
for training_step in range(resume_training_step, args.num_training_steps + 1):
episode += args.rollout_batch_size * args.number_samples_per_prompt # each sample is an episode
queries = queries_next
ground_truths = ground_truths_next
datasets = datasets_next
if accelerator.is_main_process:
df = None
try:
evaluation_responses = evaluation_Q.get(timeout=0.01)
print("π₯π₯π₯ Evaluation responses received")
table = {}
table["prompt"] = tokenizer.batch_decode(sample_evaluation_prompt_token_ids)
table["response"] = tokenizer.batch_decode(evaluation_responses)
table["response"] = [item.replace(tokenizer.pad_token, "") for item in table["response"]]
df = pd.DataFrame(table)
del table
except Empty:
print("π Evaluation responses not received")
# (optionally) evaluate the model
if args.async_mode:
if training_step != 1:
global_data = next(iter_dataloader)
data = data_collator(
global_data[
self.rank * args.local_rollout_batch_size : (self.rank + 1) * args.local_rollout_batch_size
]
)
global_queries = data_collator(global_data)[INPUT_IDS_PROMPT_KEY].tolist()
queries_next = data[INPUT_IDS_PROMPT_KEY].to(device)
ground_truths_next = data[GROUND_TRUTHS_KEY]
datasets_next = data[DATASET_SOURCE_KEY]
start_time = time.time()
broadcast_to_vllm()
if accelerator.is_main_process:
print(
f"π₯π₯π₯ Loading weights using shared memory; Time to load weights: {time.time() - start_time:.2f} seconds"
)
param_prompt_Q.put((None, remove_padding(global_queries, tokenizer.pad_token_id)))
else:
if training_step != 1:
# NOTE: important: the indent here is different for sync mode
# we also set to use `queries = queries_next` immediately
global_data = next(iter_dataloader)
data = data_collator(
global_data[
self.rank * args.local_rollout_batch_size : (self.rank + 1) * args.local_rollout_batch_size
]
)
global_queries = data_collator(global_data)[INPUT_IDS_PROMPT_KEY].tolist()
queries_next = data[INPUT_IDS_PROMPT_KEY].to(device)
ground_truths_next = data[GROUND_TRUTHS_KEY]
datasets_next = data[DATASET_SOURCE_KEY]
start_time = time.time()
broadcast_to_vllm()
if accelerator.is_main_process:
print(
f"π₯π₯π₯ Loading weights using shared memory; Time to load weights: {time.time() - start_time:.2f} seconds"
)
param_prompt_Q.put((None, remove_padding(global_queries, tokenizer.pad_token_id)))
queries = queries_next
ground_truths = ground_truths_next
datasets = datasets_next
torch.cuda.empty_cache()
# print('get reward stuff starts')
# if we generate multiple samples per prompt, we need to repeat the queries and ground truths