MedicalGPT:第一阶段:PT/SFT/DPO 单卡跑通指南(持续更新ing)

仓库链接:https://github.com/shibing624/MedicalGPT?tab=readme-ov-file

为什么选择 MedicalGPT

环境搭建成本低,上手速度快

  • 按照 README 安装即可,整体过程相对顺畅。

  • MedicalGPT 的强化学习部分基于 TRL ,依赖更轻、安装更省心;对比某些更“工程化/更复杂”的训练框架,上手门槛更低。

覆盖大模型训练全链路,路径完整

  • 一套仓库能把关键阶段串起来:PT / SFT / RL / RAG

  • RL 侧支持的算法较丰富,便于在同一套工程框架下做对比实验与迭代;

第一阶段:跑通 MedicalGPT

准备工作:环境与底座模型

在开始跑全链路之前,需要准备好环境和一个 小参数量的底座模型(方便在单卡上快速跑通,不追求效果,只验证流程)。

1. 安装依赖

1
2
3
4
git clone --depth 1 https://github.com/shibing624/MedicalGPT.git
cd MedicalGPT
ls
pip install -r requirements.txt

2. 下载轻量底座 :建议先使用 Qwen/Qwen1.5-0.5B-ChatQwen/Qwen2-1.5B-Instruct。参数量小,16GB显存的单卡就能跑通完整 SFT 和 DPO。

PT+SFT+DPO全链路训练

训练步骤如下:

  1. 确认训练集
  2. 执行训练脚本

训练脚本的执行逻辑如下:

  1. 导入依赖包
  2. 设置参数
  3. 定义各函数并加载训练集
  4. 加载模型和tokenizer
  5. 开始训练并评估
  6. 查看训练结果

以下参数可以根据你的GPU实际情况修改,当前参数是根据单卡GPU(16GB显存)配置的

stage 1: 增量预训练(PT)

查看预训练数据:

1
ls ./data/pretrain/

执行预训练脚本:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
python pretraining.py \
--model_name_or_path Qwen/Qwen2.5-0.5B \
--train_file_dir ./data/pretrain \
--validation_file_dir ./data/pretrain \
--per_device_train_batch_size 3 \
--per_device_eval_batch_size 3 \
--do_train \
--do_eval \
--use_peft True \
--seed 42 \
--bf16 \
--max_train_samples 20000 \
--max_eval_samples 10 \
--num_train_epochs 1 \
--learning_rate 2e-4 \
--warmup_ratio 0.05 \
--weight_decay 0.01 \
--logging_strategy steps \
--logging_steps 10 \
--eval_steps 50 \
--eval_strategy steps \
--save_steps 50 \
--save_strategy steps \
--save_total_limit 3 \
--gradient_accumulation_steps 1 \
--preprocessing_num_workers 1 \
--block_size 128 \
--output_dir outputs-pt-v1 \
--overwrite_output_dir \
--ddp_timeout 30000 \
--logging_first_step True \
--target_modules all \
--lora_rank 8 \
--lora_alpha 16 \
--lora_dropout 0.05 \
--torch_dtype bfloat16 \
--device_map auto \
--report_to tensorboard \
--ddp_find_unused_parameters False \
--gradient_checkpointing True

查看输出文件:

1
ls -lh outputs-pt-v1

模型训练结果:

  • 使用lora训练模型,则保存的lora权重是adapter_model.safetensors, lora配置文件是adapter_config.json,合并到base model的方法见merge_peft_adapter.py
  • 日志保存在output_dir/runs目录下,可以使用tensorboard查看,启动tensorboard方式如下:tensorboard --logdir output_dir/runs --host 0.0.0.0 --port 8009

lora模型权重合并到base model,合并后的模型保存在--output_dir目录下,合并方法如下:

1
2
python merge_peft_adapter.py \
--base_model Qwen/Qwen2.5-0.5B --lora_model outputs-pt-v1 --output_dir merged-pt/

查看合并后的输出文件:

1
ls -lh merged-pt/
1
cat merged-pt/config.json

Stage 2: 有监督微调(SFT)

构造指令微调数据集,在预训练模型基础上做指令精调,以对齐指令意图,并注入领域知识

生成模型:使用的是Qwen/Qwen2.5-0.5B 或者 Stage1得到的预训练模型

数据集:SFT阶段使用的是使用的是Belle的1千条抽样数据,位于data/finetune文件夹

查看微调数据:

1
ls ./data/finetune

执行微调脚本:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
python supervised_finetuning.py \
--model_name_or_path merged-pt \
--train_file_dir ./data/finetune \
--validation_file_dir ./data/finetune \
--per_device_train_batch_size 4 \
--per_device_eval_batch_size 4 \
--do_train \
--do_eval \
--use_peft True \
--bf16 \
--max_train_samples 1000 \
--max_eval_samples 10 \
--num_train_epochs 1 \
--learning_rate 2e-5 \
--warmup_steps 50 \
--weight_decay 0.05 \
--logging_strategy steps \
--logging_steps 10 \
--eval_steps 50 \
--eval_strategy steps \
--save_steps 500 \
--save_strategy steps \
--save_total_limit 3 \
--gradient_accumulation_steps 1 \
--preprocessing_num_workers 1 \
--output_dir outputs-sft-v1 \
--ddp_timeout 30000 \
--logging_first_step True \
--target_modules all \
--lora_rank 8 \
--lora_alpha 16 \
--lora_dropout 0.05 \
--torch_dtype bfloat16 \
--device_map auto \
--report_to tensorboard \
--ddp_find_unused_parameters False \
--gradient_checkpointing True

查看输出结果:

1
ls -lh outputs-sft-v1

模型训练结果:

  • 使用lora训练模型,则保存的lora权重是adapter_model.safetensors, lora配置文件是adapter_config.json,合并到base model的方法见merge_peft_adapter.py
  • 日志保存在output_dir/runs目录下,可以使用tensorboard查看,启动tensorboard方式如下:tensorboard --logdir output_dir/runs --host 0.0.0.0 --port 8009

lora模型权重合并到base model,合并后的模型保存在--output_dir目录下,合并方法如下:

1
2
python merge_peft_adapter.py \
--base_model merged-pt --lora_model outputs-sft-v1 --output_dir ./merged-sft

查看合并后的目录

1
2
ls -lh merged-sft/
cat merged-sft/config.json

Stage 3: 直接偏好优化(DPO)

DPO通过直接优化语言模型来实现对其行为的精确控制,而无需使用复杂的强化学习,也可以有效学习到人类偏好,DPO相较于RLHF更容易实现且易于训练,效果更好

生成模型:使用的是Qwen/Qwen2.5-0.5B 或者 Stage2得到的SFT模型

数据集:DPO阶段使用的是医疗reward数据,抽样了500条,位于data/reward文件夹

查看数据集

1
ls ./data/reward/

执行DPO脚本

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
python dpo_training.py \
--model_name_or_path ./merged-sft \
--template_name qwen \
--train_file_dir ./data/reward \
--validation_file_dir ./data/reward \
--per_device_train_batch_size 3 \
--per_device_eval_batch_size 1 \
--do_train \
--do_eval \
--use_peft True \
--max_train_samples 1000 \
--max_eval_samples 500 \
--max_steps 100 \
--eval_steps 10 \
--save_steps 50 \
--max_source_length 256 \
--max_target_length 256 \
--output_dir outputs-dpo-v1 \
--target_modules all \
--lora_rank 8 \
--lora_alpha 16 \
--lora_dropout 0.05 \
--torch_dtype bfloat16 \
--bf16 True \
--fp16 False \
--device_map auto \
--report_to tensorboard \
--remove_unused_columns False \
--gradient_checkpointing True \
--cache_dir ./cache

注:

  • 当前安装的较新版本 trl 中,DPOConfig 类的初始化函数已经不再接受 max_prompt_length 这个参数,在dpo_training.py 中找到 max_prompt_length=xxx, 直接删掉(或者加 # 注释掉)
  • 新版本 Hugging Face 移除了 adamw_hf,在命令行中直接使用 adamw_torch 覆盖代码默认优化器。

查看输出文档

1
ls -lh outputs-dpo-v1

模型训练结果:

  • 使用lora训练模型,则保存的lora权重是adapter_model.safetensors, lora配置文件是adapter_config.json,合并到base model的方法见merge_peft_adapter.py
  • 日志保存在output_dir/runs目录下,可以使用tensorboard查看,启动tensorboard方式如下:tensorboard --logdir output_dir/runs --host 0.0.0.0 --port 8009

lora模型权重合并到base model,合并后的模型保存在--output_dir目录下,合并方法如下:

1
2
python merge_peft_adapter.py \
--base_model merged-sft --lora_model outputs-dpo-v1 --output_dir merged-dpo/
1
2
3
ls -lh merged-dpo/

cat merged-dpo/config.json

至此一个完整的训练流程演示完成。

测试

1
2
3
python inference.py --base_model merged-dpo
# 或在shell中运行
# python inference.py --base_model merged-dpo --interactive

MedicalGPT-inference