[LLM] Support prefix tuning and lora for qwen2 (#8601)
* add unittest for qwen2

* update for tie_word_embeddings

* update qwen2

* update for tokenizer set attr to null

* support prefix training

* fix llm unittest

* fix pipeline and sequence parallel

* add pretrain configs

* add lora, prefix tuning, sft config

* fix pp for lora and recompute
DrownFish19 authored Jun 20, 2024
1 parent 6bfca91 commit da8b9ac
Showing 22 changed files with 1,007 additions and 76 deletions.
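In practice, the new support means a Qwen2 checkpoint can be wrapped with LoRA the same way the other llm models are. Below is a minimal sketch, assuming the paddlenlp.peft interfaces (LoRAConfig, LoRAModel) work for qwen2 as they do in the existing llm examples; the rank and alpha values are illustrative defaults, not values taken from this commit.

    from paddlenlp.peft import LoRAConfig, LoRAModel
    from paddlenlp.transformers import AutoModelForCausalLM

    # Load the base model whose LoRA/prefix support this commit adds.
    model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-7B", dtype="bfloat16")

    # Target patterns mirror the qwen2 branch added to get_lora_target_modules() in llm/utils.py below.
    lora_config = LoRAConfig(
        target_modules=[".*q_proj.*", ".*k_proj.*", ".*v_proj.*", ".*o_proj.*",
                        ".*gate_proj.*", ".*down_proj.*", ".*up_proj.*"],
        r=8,            # illustrative rank
        lora_alpha=16,  # illustrative scaling
        dtype="bfloat16",
    )
    model = LoRAModel(model=model, lora_config=lora_config)
    model.mark_only_lora_as_trainable()
    model.print_trainable_parameters()

The JSON files added below are the argument files that drive this same path through the llm fine-tuning scripts instead of hand-written code.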
32 changes: 32 additions & 0 deletions llm/qwen/lora_argument_qwen2_7b.json
@@ -0,0 +1,32 @@
{
"model_name_or_path": "Qwen/Qwen2-7B",
"dataset_name_or_path": "./data",
"output_dir": "./checkpoints/qwen2_7b__lora_ckpts",
"per_device_train_batch_size": 4,
"gradient_accumulation_steps": 4,
"per_device_eval_batch_size": 8,
"eval_accumulation_steps": 16,
"num_train_epochs": 3,
"learning_rate": 3e-04,
"warmup_steps": 30,
"logging_steps": 1,
"evaluation_strategy": "epoch",
"save_strategy": "epoch",
"src_length": 2048,
"max_length": 4096,
"bf16": true,
"fp16_opt_level": "O2",
"do_train": true,
"do_eval": true,
"disable_tqdm": true,
"load_best_model_at_end": true,
"eval_with_do_generation": false,
"metric_for_best_model": "accuracy",
"recompute": true,
"save_total_limit": 1,
"tensor_parallel_degree": 1,
"pipeline_parallel_degree": 1,
"lora": true,
"zero_padding": false,
"use_flash_attention": false
}
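For orientation, the effective batch size implied by this config can be worked out as follows (derived numbers, not stated in the file); the data-parallel degree is assumed to be 1 here, since tensor_parallel_degree and pipeline_parallel_degree are both 1 and no sharding degree is set.

    # Illustrative arithmetic only; values copied from lora_argument_qwen2_7b.json.
    per_device_train_batch_size = 4
    gradient_accumulation_steps = 4
    data_parallel_degree = 1  # assumption: single-card run

    global_batch_size = per_device_train_batch_size * gradient_accumulation_steps * data_parallel_degree
    print(global_batch_size)  # 16 sequences per optimizer step, each up to max_length = 4096 tokens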
41 changes: 41 additions & 0 deletions llm/qwen/pretrain-qwen1.5_7b-tp2sd4_stage2.json
@@ -0,0 +1,41 @@
{
"model_name_or_path": "Qwen/Qwen1.5-7B",
"tokenizer_name_or_path": "Qwen/Qwen1.5-7B",
"input_dir": "./data",
"output_dir": "./checkpoints/qwen1.5_7b_pretrain_ckpts",
"per_device_train_batch_size": 2,
"gradient_accumulation_steps": 8,
"per_device_eval_batch_size": 2,
"tensor_parallel_degree": 2,
"pipeline_parallel_degree": 1,
"sharding_parallel_degree": 4,
"sharding": "stage2",
"virtual_pp_degree": 1,
"sequence_parallel": 0,
"use_flash_attention": true,
"use_fused_rms_norm": true,
"use_fused_rope": true,
"max_seq_length": 4096,
"learning_rate": 3e-05,
"min_learning_rate": 3e-06,
"warmup_steps": 30,
"logging_steps": 1,
"max_steps": 10000,
"save_steps": 5000,
"eval_steps": 1000,
"weight_decay": 0.01,
"bf16": true,
"fp16_opt_level": "O2",
"warmup_ratio": 0.01,
"max_grad_norm": 1.0,
"dataloader_num_workers": 1,
"continue_training": 1,
"do_train": true,
"do_eval": true,
"do_predict": true,
"disable_tqdm": true,
"recompute": true,
"distributed_dataloader": 1,
"recompute_granularity": "full",
"save_total_limit": 2
}
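As a quick sanity check (derived numbers, not stated in the file), the parallelism degrees above fix the GPU count and the global batch size; the same arithmetic applies to the qwen2 pretrain config that follows, which uses identical degrees.

    # Illustrative arithmetic only; values copied from pretrain-qwen1.5_7b-tp2sd4_stage2.json.
    tensor_parallel_degree = 2
    pipeline_parallel_degree = 1
    sharding_parallel_degree = 4
    num_gpus = tensor_parallel_degree * pipeline_parallel_degree * sharding_parallel_degree
    print(num_gpus)  # 8 GPUs, matching the "tp2sd4" in the file name

    per_device_train_batch_size = 2
    gradient_accumulation_steps = 8
    global_batch_size = per_device_train_batch_size * gradient_accumulation_steps * sharding_parallel_degree
    print(global_batch_size)  # 64 sequences per step, each up to max_seq_length = 4096 tokens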
41 changes: 41 additions & 0 deletions llm/qwen/pretrain-qwen2_7b-tp2sd4_stage2.json
@@ -0,0 +1,41 @@
{
"model_name_or_path": "Qwen/Qwen2-7B",
"tokenizer_name_or_path": "Qwen/Qwen2-7B",
"input_dir": "./data",
"output_dir": "./checkpoints/qwen2_7b_pretrain_ckpts",
"per_device_train_batch_size": 2,
"gradient_accumulation_steps": 8,
"per_device_eval_batch_size": 2,
"tensor_parallel_degree": 2,
"pipeline_parallel_degree": 1,
"sharding_parallel_degree": 4,
"sharding": "stage2",
"virtual_pp_degree": 1,
"sequence_parallel": 0,
"use_flash_attention": true,
"use_fused_rms_norm": true,
"use_fused_rope": true,
"max_seq_length": 4096,
"learning_rate": 3e-05,
"min_learning_rate": 3e-06,
"warmup_steps": 30,
"logging_steps": 1,
"max_steps": 10000,
"save_steps": 5000,
"eval_steps": 1000,
"weight_decay": 0.01,
"bf16": true,
"fp16_opt_level": "O2",
"warmup_ratio": 0.01,
"max_grad_norm": 1.0,
"dataloader_num_workers": 1,
"continue_training": 1,
"do_train": true,
"do_eval": true,
"do_predict": true,
"disable_tqdm": true,
"recompute": false,
"distributed_dataloader": 1,
"recompute_granularity": "full",
"save_total_limit": 2
}
33 changes: 33 additions & 0 deletions llm/qwen/pt_argument_qwen2_7b.json
@@ -0,0 +1,33 @@
{
"model_name_or_path": "Qwen/Qwen2-7B",
"dataset_name_or_path": "./data",
"output_dir": "./checkpoints/qwen2_7b_pt_ckpts",
"per_device_train_batch_size": 4,
"gradient_accumulation_steps": 4,
"per_device_eval_batch_size": 8,
"eval_accumulation_steps": 16,
"num_train_epochs": 3,
"learning_rate": 3e-02,
"warmup_steps": 30,
"logging_steps": 1,
"evaluation_strategy": "epoch",
"save_strategy": "epoch",
"src_length": 2048,
"max_length": 4096,
"bf16": true,
"fp16_opt_level": "O2",
"do_train": true,
"do_eval": true,
"disable_tqdm": true,
"load_best_model_at_end": true,
"eval_with_do_generation": false,
"metric_for_best_model": "accuracy",
"recompute": true,
"save_total_limit": 1,
"tensor_parallel_degree": 1,
"pipeline_parallel_degree": 1,
"prefix_tuning": true,
"zero_padding": false,
"use_flash_attention": false
}
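The much larger learning rate here (3e-02, versus 3e-04 for LoRA and 3e-05 for SFT) reflects that prefix tuning only trains a small prefix encoder while the base model stays frozen. A rough, illustrative estimate of the per-layer past-key-value size it produces; the model dimensions below are assumed for a Qwen2-7B-like model, not taken from this diff.

    # Rough estimate only; all dimensions are assumptions for illustration.
    num_prefix_tokens = 128     # typical default, not set in this json
    num_hidden_layers = 28      # assumed
    num_key_value_heads = 4     # assumed (grouped-query attention)
    head_dim = 128              # assumed

    # Each layer receives a key and a value tensor for the virtual prefix tokens.
    past_kv_values = num_hidden_layers * 2 * num_prefix_tokens * num_key_value_heads * head_dim
    print(past_kv_values)  # ~3.7M values, tiny next to the ~7B frozen weights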

31 changes: 31 additions & 0 deletions llm/qwen/sft_argument_qwen2_7b.json
@@ -0,0 +1,31 @@
{
"model_name_or_path": "Qwen/Qwen2-7B",
"dataset_name_or_path": "./data",
"output_dir": "./checkpoints/qwen2-7b_sft_ckpts",
"per_device_train_batch_size": 1,
"gradient_accumulation_steps": 4,
"per_device_eval_batch_size": 8,
"eval_accumulation_steps":16,
"num_train_epochs": 3,
"learning_rate": 3e-05,
"warmup_steps": 30,
"logging_steps": 1,
"evaluation_strategy": "epoch",
"save_strategy": "epoch",
"src_length": 2048,
"max_length": 4096,
"bf16": true,
"fp16_opt_level": "O2",
"do_train": true,
"do_eval": true,
"disable_tqdm": true,
"load_best_model_at_end": true,
"eval_with_do_generation": false,
"metric_for_best_model": "accuracy",
"recompute": true,
"save_total_limit": 1,
"tensor_parallel_degree": 4,
"pipeline_parallel_degree": 1,
"zero_padding": false,
"use_flash_attention": false
}
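Unlike the LoRA and prefix-tuning configs, this full-parameter SFT config sets tensor_parallel_degree to 4 and drops per_device_train_batch_size to 1. A back-of-the-envelope memory estimate (illustrative numbers, not from this diff) shows why the weights and optimizer state are split across cards:

    # Rough estimate only; a ~7B-parameter model and an Adam-style optimizer are assumed.
    params = 7e9
    bytes_weights = params * 2              # bf16 weights
    bytes_grads = params * 2                # bf16 gradients
    bytes_optimizer = params * (4 + 4 + 4)  # fp32 master weights + two Adam moments

    total_gb = (bytes_weights + bytes_grads + bytes_optimizer) / 1024**3
    print(round(total_gb))      # ~104 GB before activations
    print(round(total_gb / 4))  # ~26 GB per card with tensor_parallel_degree = 4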
21 changes: 20 additions & 1 deletion llm/utils.py
@@ -35,6 +35,7 @@
ChatGLMv2Tokenizer,
LlamaForCausalLMPipe,
PretrainedConfig,
Qwen2ForCausalLMPipe,
)
from paddlenlp.transformers.tokenizer_utils import PretrainedTokenizer
from paddlenlp.utils.log import logger
@@ -67,7 +68,7 @@ def get_prefix_tuning_params(model):
num_hidden_layers = model.config.num_layers
hidden_size = model.config.hidden_size
postprocess_past_key_value = chatglm_postprocess_past_key_value
multi_query_group_num = model.config.multi_query_group_num
multi_query_group_num = model.config.multi_query_group_num # num_key_value_heads
elif model.base_model_prefix == "bloom":
from paddlenlp.peft.prefix import bloom_postprocess_past_key_value

@@ -92,6 +93,14 @@ def get_prefix_tuning_params(model):
hidden_size = model.config.hidden_size
postprocess_past_key_value = qwen_postprocess_past_key_value
multi_query_group_num = None
elif model.base_model_prefix == "qwen2":
from paddlenlp.peft.prefix import qwen_postprocess_past_key_value

num_attention_heads = model.config.num_attention_heads
num_hidden_layers = model.config.num_hidden_layers
hidden_size = model.config.hidden_size
postprocess_past_key_value = qwen_postprocess_past_key_value
multi_query_group_num = model.config.num_key_value_heads # num_key_value_heads
else:
raise ValueError(f"Unknown base_model_prefix: {model.base_model_prefix}. ")
return dict(
@@ -150,6 +159,16 @@ def get_lora_target_modules(model):
".*mlp.w2.*",
".*mlp.c_proj.*",
]
elif model.base_model_prefix == "qwen2" or isinstance(model, Qwen2ForCausalLMPipe):
target_modules = [
".*q_proj.*",
".*k_proj.*",
".*v_proj.*",
".*o_proj.*",
".*gate_proj.*",
".*down_proj.*",
".*up_proj.*",
]
elif model.base_model_prefix == "mixtral":
target_modules = [
".*q_proj.*",
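The qwen2 branch added to get_lora_target_modules() selects layers by regular expression. A self-contained sketch of how those patterns pick out the attention and MLP projections; the parameter names are hypothetical, for illustration only.

    import re

    # Patterns added for qwen2 in get_lora_target_modules() above.
    target_modules = [".*q_proj.*", ".*k_proj.*", ".*v_proj.*", ".*o_proj.*",
                      ".*gate_proj.*", ".*down_proj.*", ".*up_proj.*"]

    # Hypothetical parameter names.
    names = [
        "qwen2.layers.0.self_attn.q_proj.weight",
        "qwen2.layers.0.self_attn.o_proj.weight",
        "qwen2.layers.0.mlp.gate_proj.weight",
        "qwen2.layers.0.input_layernorm.weight",
    ]

    for name in names:
        hit = any(re.fullmatch(pattern, name) for pattern in target_modules)
        print(name, "->", "LoRA" if hit else "frozen")
    # Only the layernorm weight stays frozen; all projection weights get LoRA adapters.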
11 changes: 11 additions & 0 deletions paddlenlp/transformers/model_utils.py
@@ -2348,6 +2348,17 @@ def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
)
pass

# Note:
# 1. PipelineLayer creates parameters for each layer and calls
# `_synchronize_shared_weights()` to synchronize the shared parameters.
# 2. When the model `state_dict` is set, `_synchronize_shared_weights()` is called again to
# synchronize the shared parameters.
# However, if the state dict contains only one copy of a shared parameter, the tied copies
# can end up out of sync with each other, so they are synchronized once more below.

if isinstance(model, PipelineLayer):
model._synchronize_shared_weights()

if paddle.in_dynamic_mode():
return model

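The new _synchronize_shared_weights() call deals with tied parameters (for example tie_word_embeddings) in pipeline models: if the checkpoint stores only one copy of a shared weight, the other copy must be overwritten after loading. A minimal, framework-free illustration of the failure mode; the array names are invented for the sketch.

    import numpy as np

    # Two logically shared parameters, e.g. the input embedding and the output head.
    embedding = np.ones((4, 8))
    lm_head = embedding.copy()  # meant to stay identical to `embedding`

    # A checkpoint that stores only one copy updates just the embedding...
    state_dict = {"embedding": np.full((4, 8), 0.5)}
    embedding[...] = state_dict["embedding"]

    # ...so the tied copy has silently diverged and must be re-synchronized,
    # which is what _synchronize_shared_weights() does for PipelineLayer here.
    assert not np.allclose(embedding, lm_head)
    lm_head[...] = embedding
    assert np.allclose(embedding, lm_head)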
1 change: 1 addition & 0 deletions paddlenlp/transformers/qwen2/__init__.py
@@ -15,4 +15,5 @@

from .configuration import *
from .modeling import *
from .modeling_pp import *
from .tokenizer import *
3 changes: 3 additions & 0 deletions paddlenlp/transformers/qwen2/configuration.py
@@ -150,6 +150,9 @@ def __init__(
self.eos_token_id = eos_token_id

super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)
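This change forwards the special-token ids to the base configuration class (alongside tie_word_embeddings) instead of only storing them as instance attributes. A minimal sketch of why the forwarding matters; the class names and token ids are invented for illustration.

    class BaseConfig:
        def __init__(self, pad_token_id=None, bos_token_id=None, eos_token_id=None,
                     tie_word_embeddings=False, **kwargs):
            # Base-class behavior (serialization, generation defaults, weight tying)
            # only sees what is explicitly passed up to it.
            self.pad_token_id = pad_token_id
            self.bos_token_id = bos_token_id
            self.eos_token_id = eos_token_id
            self.tie_word_embeddings = tie_word_embeddings

    class ToyQwen2Config(BaseConfig):
        def __init__(self, pad_token_id=None, bos_token_id=1, eos_token_id=2,
                     tie_word_embeddings=False, **kwargs):
            self.eos_token_id = eos_token_id
            # Without forwarding, the base class falls back to its own defaults
            # (pad/bos/eos all None) regardless of what the caller passed in.
            super().__init__(
                pad_token_id=pad_token_id,
                bos_token_id=bos_token_id,
                eos_token_id=eos_token_id,
                tie_word_embeddings=tie_word_embeddings,
                **kwargs,
            )

    cfg = ToyQwen2Config(eos_token_id=7, tie_word_embeddings=True)
    print(cfg.eos_token_id, cfg.tie_word_embeddings)  # 7 True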