NLP(一百一十二)微调LLM,解锁推理能力——GRPO算法如何提升人物关系分类

本文将介绍如何使用最新流行的GRPO强化学习算法,对LLM进行微调,以提升其在人物关系分类任务上的推理能力。

简介

在两年前的文章《NLP(六十三)使用Baichuan-7b模型微调人物关系分类任务》中,笔者介绍了如何使用Baichuan-7B模型对人物关系分类任务进行微调(SFT),并且相比BERT时代的模型取得了进步。

如今已步入DeepSeek R1时代,GRPO等强化学习算法正引领新一轮技术潮流。特别是DeepSeek R1 Zero模型的横空出世,证明了纯强化学习(RL)训练同样能赋予模型强大的推理能力,颠覆了OpenAI传统的大模型训练三阶段范式(SFT -> RM -> RLHF)。这一创新正是DeepSeek的独特之处,使其在近几个月的LLM竞赛中占据领先地位,引领行业发展方向。

本文将使用最新流行的GRPO强化学习算法,对LLM进行微调,以提升其在人物关系分类任务上的推理能力。

数据集

人物关系分类指的是对文本中的两个人物,在特定的关系列表中,判断他们之间的人物关系。以样本「1837年6月20日,威廉四世辞世,他的侄女维多利亚即位。」为例,其人物关系标签为亲戚,其中威廉四世为实体1,维多利亚为实体2。

笔者自己利用业余时间标注的样本数据有3881条,分布如下图:

人物关系分布图

对上述数据集进行划分,训练集与测试集的比例为8:2,其中训练集3105条,测试集776条。

笔者已将上述数据集上传至Hugging Face,网址为:https://huggingface.co/datasets/jclian91/people_relation_classification .
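
下面给出一个简单的数据加载示意(假设本地已安装datasets模块,字段名text、people1、people2、label以实际数据集为准):

from collections import Counter

from datasets import load_dataset

# 从Hugging Face Hub加载人物关系分类数据集的train切分
dataset = load_dataset("jclian91/people_relation_classification", split="train")

# 查看一条样本,字段包括text、people1、people2、label
print(dataset[0])

# 统计各人物关系标签的样本数量,与上文的人物关系分布图对应
print(Counter(dataset["label"]))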

奖励函数

奖励函数(Reward Function)是GRPO训练过程中很关键的因素:一个好的奖励函数能引导模型朝既定的能力方向演进,增强模型能力;而一个坏的奖励函数则可能使模型的表现更加糟糕。因此,奖励函数需要根据任务来精心设计。

笔者在此次训练过程中,设计了两个奖励函数,分别为format_reward_func和label_reward_func。

对于format_reward_func函数,其要求是模型的回复要符合<think>...</think>\n<answer>...</answer>格式,其中推理过程用<think>...</think>标记,最终的答案用<answer>...</answer>标记,最终答案是人物关系的标签。如果格式正确且答案中的人物关系在给定的人物关系列表中,则得分为1;如果格式正确但答案不在给定的人物关系列表中,则得分为0.5;否则得分为0。实现函数如下:

def format_reward_func(completions, label, **kwargs):
    """
    Format: <think>...</think>\n<answer>...</answer>
    Args:
        completions (list[str]): Generated outputs
        label (list[str]): Expected answers

    Returns:
        list[float]: Reward scores
    """
    rewards = []

    for completion, gt in zip(completions, label):
        try:
            # add synthetic <think> as its already part of the prompt and prefilled for the assistant to more easily match the regex
            completion = "<think>" + completion
            if random.random() < 0.1:  # 10% chance to write samples into a file
                os.makedirs("completion_samples", exist_ok=True)
                log_file = os.path.join("completion_samples", "completion_samples.txt")
                with open(log_file, "a") as f:
                    f.write(f"\n\n==============\n")
                    f.write(completion)

            # Check if the format is correct
            regex = r"^<think>([^<]*(?:<(?!/?think>)[^<]*)*)<\/think>\n<answer>([\s\S]*?)<\/answer>$"
            match = re.search(regex, completion, re.DOTALL)
            # Check if the answer is in the predict labels
            answer_label = re.findall(r'<answer>(.*?)</answer>', completion)
            predict_labels = ['不知道', '夫妻', '父母', '兄弟姐妹', '上下级', '师生', '好友', '同学', '合作', '同一个人', '情侣', '祖孙', '同门', '亲戚']
            # if the format is not correct, reward is 0
            if match is None or len(match.groups()) != 2:
                rewards.append(0.0)
            else:
                if answer_label and answer_label[0] in predict_labels:
                    rewards.append(1.0)
                else:
                    rewards.append(0.5)
        except Exception as err:  # noqa
            rewards.append(0.0)
    return rewards

以下是上述奖励函数的单元测试:

from src.train.run_r1_grpo import format_reward_func


def test_format_reward_func_case_1():
    # right format with right predict label
    completion = ["哈哈哈</think>\n<answer>不知道</answer>"]
    target = ['不知道']
    reward_list = format_reward_func(completion, target)
    assert reward_list == [1.0]


def test_format_reward_func_case_2():
    # right format with wrong predict label
    completion = ["哈哈哈</think>\n<answer>不确定</answer>"]
    target = ['不知道']
    reward_list = format_reward_func(completion, target)
    assert reward_list == [0.5]


def test_format_reward_func_case_3():
    # wrong format
    completion = ["哈哈哈</think>\n不知道</answer>"]
    target = ['不知道']
    reward_list = format_reward_func(completion, target)
    assert reward_list == [0.0]


def test_format_reward_func_case_4():
    # wrong format: duplicated <think> tags
    completion = ["哈哈哈</think>\n<think>...</think>\n<answer>不知道</answer>"]
    target = ['不知道']
    reward_list = format_reward_func(completion, target)
    assert reward_list == [0.0]

对于label_reward_func函数,其要求是答案中的人物关系与真实标签一致。如果两者一致,则得分为1,否则为0。该函数实现代码如下:

def label_reward_func(completions, label, text, people1, people2, **kwargs):
    """
    Evaluates completions based on:
    - whether the answer in completions matches the true_label

    Args:
        completions (list[str]): Generated outputs
        label: Expected answers

    Returns:
        list[float]: Reward scores
    """
    rewards = []
    for completion, gt in zip(completions, label):
        try:
            # add synthetic <think> as its already part of the prompt and prefilled for the assistant to more easily match the regex
            completion = "<think>" + completion
            # Check if the format is correct
            match = re.search(r"<answer>(.*?)<\/answer>", completion)
            if match is None:
                rewards.append(0.0)
                continue
            # Extract the "answer" part from the completion
            answer_label = re.findall(r'<answer>(.*?)</answer>', completion)
            if answer_label and answer_label[0] == gt:
                rewards.append(1.0)
                if random.random() < 0.10:  # 10% chance to write fully successful samples into a file
                    os.makedirs("completion_samples", exist_ok=True)
                    log_file = os.path.join("completion_samples", "success_completion_samples.txt")
                    with open(log_file, "a") as f:
                        f.write(f"\n\n==============\n")
                        f.write(f"文本:{text}\n人物1:{people1}\n人物2:{people2}\n")
                        f.write(completion)
            else:
                rewards.append(0.0)
        except Exception as err:  # noqa
            rewards.append(0.0)
    return rewards
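
仿照format_reward_func的单元测试,也可以为label_reward_func补充类似的测试用例。以下为笔者给出的示意测试(假设该函数同样位于src.train.run_r1_grpo模块中):

from src.train.run_r1_grpo import label_reward_func


def test_label_reward_func_right_label():
    # 答案与真实标签一致,得分为1.0
    completion = ["推理过程</think>\n<answer>父母</answer>"]
    reward_list = label_reward_func(completion, ["父母"], ["文本"], ["人物1"], ["人物2"])
    assert reward_list == [1.0]


def test_label_reward_func_wrong_label():
    # 答案与真实标签不一致,得分为0.0
    completion = ["推理过程</think>\n<answer>夫妻</answer>"]
    reward_list = label_reward_func(completion, ["父母"], ["文本"], ["人物1"], ["人物2"])
    assert reward_list == [0.0]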

模型训练

Hugging Face开源的trl模块已支持GRPOTrainer,因此使用该模块来进行强化微调(Reinforcement Fine-Tuning)。所需Python第三方模块如下:

datasets==3.1.0
pandas==2.2.3
Requests==2.32.3
rich==13.9.4
scikit_learn==1.6.1
torch==2.5.1
transformers==4.48.1
trl==0.14.0
gradio==5.20.0

我们以Qwen2.5-7B-Instruct为基座模型进行微调,其对话模板生成的prompt如下:

"prompt": "<|im_start|>system,判断这两个人物之间的关系,人物关系只能是['不知道', '夫妻', '父母', '兄弟姐妹', '上下级', '师生', '好友', '同学', '合作', '同一个人', '情侣', '祖孙', '同门', '亲戚']的一个。整体格式为......,推理过程用...标记,最终的答案用...标记,最终答案是人物关系的标签。<|im_end|><|im_start|>user:王玉宝的孙子王方轶大学毕业后,也成了通化广播电视台的记者,而孙女王梓怡在填报高考志愿时,毫不犹豫地报考了辽宁一所传媒院校的新闻专业。:王玉宝:王梓怡<|im_end|><|im_start|>assistant让我一步一步来思考解决。"

注意,<think>、<answer>等标签并不在Qwen2.5-7B-Instruct的特殊token列表中,而我们希望该模型能学会这种回复模式,使其具备推理能力。
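
可以用下面的小脚本验证这一点(示意代码,假设本地已有Qwen2.5-7B-Instruct的tokenizer,路径与后文训练配置一致):

from transformers import AutoTokenizer

# 加载Qwen2.5-7B-Instruct的tokenizer
tokenizer = AutoTokenizer.from_pretrained("/mnt/models/Qwen2.5-7B-Instruct")

# <think>、<answer>并非特殊token,会被切分为若干普通子词
print(tokenizer.tokenize("<think></think>"))
print(tokenizer.tokenize("<answer></answer>"))

# 对比之下,<|im_start|>是对话模板中的特殊token
print(tokenizer.tokenize("<|im_start|>"))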

我们使用DeepSpeed在4张A800-SXM4-80GB显卡上进行训练,命令如下:

CUDA_VISIBLE_DEVICES=4,5,6,7 accelerate launch --main_process_port 29501 --num_processes 3 --config_file ./deepspeed_zero3.yaml ./run_relation_grpo.py --config ./grpo-qwen-2.5-7b-deepseek-r1-relation.yaml

其中run_relation_grpo.py脚本为主要的训练脚本,训练逻辑在此完成,代码如下:

import logging
import os
from dataclasses import dataclass
from datetime import datetime

os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"

import random
import re

import torch
from transformers.trainer_utils import get_last_checkpoint
from transformers import AutoTokenizer
from datasets import load_dataset
from trl import GRPOConfig, GRPOTrainer, get_peft_config, ModelConfig, TrlParser


########################
# Custom dataclasses
########################
@dataclass
class ScriptArguments:
    dataset_id_or_path: str = "jclian91/people_relation_classification"
    dataset_splits: str = "train"
    tokenizer_name_or_path: str = None


########################
# Setup logging
########################
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
handler = logging.StreamHandler()
handler.setFormatter(
    logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
)
logger.addHandler(handler)

########################
# Helper functions
########################


def format_reward_func(completions, label, **kwargs):
    """
    Format: <think>...</think>\n<answer>...</answer>
    Args:
        completions (list[str]): Generated outputs
        label (list[str]): Expected answers

    Returns:
        list[float]: Reward scores
    """
    rewards = []

    for completion, gt in zip(completions, label):
        try:
            # add synthetic <think> as its already part of the prompt and prefilled for the assistant to more easily match the regex
            completion = "<think>" + completion
            if random.random() < 0.1:  # 10% chance to write samples into a file
                os.makedirs("completion_samples", exist_ok=True)
                log_file = os.path.join("completion_samples", "completion_samples.txt")
                with open(log_file, "a") as f:
                    f.write(f"\n\n==============\n")
                    f.write(completion)

            # Check if the format is correct
            regex = r"^<think>([^<]*(?:<(?!/?think>)[^<]*)*)<\/think>\n<answer>([\s\S]*?)<\/answer>$"
            match = re.search(regex, completion, re.DOTALL)
            # Check if the answer is in the predict labels
            answer_label = re.findall(r'<answer>(.*?)</answer>', completion)
            predict_labels = ['不知道', '夫妻', '父母', '兄弟姐妹', '上下级', '师生', '好友', '同学', '合作', '同一个人', '情侣', '祖孙', '同门', '亲戚']
            # if the format is not correct, reward is 0
            if match is None or len(match.groups()) != 2:
                rewards.append(0.0)
            else:
                if answer_label and answer_label[0] in predict_labels:
                    rewards.append(1.0)
                else:
                    rewards.append(0.5)
        except Exception as err:  # noqa
            rewards.append(0.0)
    return rewards


def label_reward_func(completions, label, text, people1, people2, **kwargs):
    """
    Evaluates completions based on:
    - whether the answer in completions matches the true_label

    Args:
        completions (list[str]): Generated outputs
        label: Expected answers

    Returns:
        list[float]: Reward scores
    """
    rewards = []
    for completion, gt in zip(completions, label):
        try:
            # add synthetic <think> as its already part of the prompt and prefilled for the assistant to more easily match the regex
            completion = "<think>" + completion
            # Check if the format is correct
            match = re.search(r"<answer>(.*?)<\/answer>", completion)
            if match is None:
                rewards.append(0.0)
                continue
            # Extract the "answer" part from the completion
            answer_label = re.findall(r'<answer>(.*?)</answer>', completion)
            if answer_label and answer_label[0] == gt:
                rewards.append(1.0)
                if random.random() < 0.10:  # 10% chance to write fully successful samples into a file
                    os.makedirs("completion_samples", exist_ok=True)
                    log_file = os.path.join("completion_samples", "success_completion_samples.txt")
                    with open(log_file, "a") as f:
                        f.write(f"\n\n==============\n")
                        f.write(f"文本:{text}\n人物1:{people1}\n人物2:{people2}\n")
                        f.write(completion)
            else:
                rewards.append(0.0)
        except Exception as err:  # noqa
            rewards.append(0.0)
    return rewards


def get_checkpoint(training_args: GRPOConfig):
    last_checkpoint = None
    if os.path.isdir(training_args.output_dir):
        last_checkpoint = get_last_checkpoint(training_args.output_dir)
    return last_checkpoint


def grpo_function(
    model_args: ModelConfig, script_args: ScriptArguments, training_args: GRPOConfig
):
    #########################
    # Log parameters
    #########################
    logger.info(f"Model parameters {model_args}")
    logger.info(f"Training/evaluation parameters {training_args}")

    ################
    # Load tokenizer
    ################
    tokenizer = AutoTokenizer.from_pretrained(
        (
            script_args.tokenizer_name_or_path
            if script_args.tokenizer_name_or_path
            else model_args.model_name_or_path
        ),
        revision=model_args.model_revision,
        trust_remote_code=model_args.trust_remote_code,
    )
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    ###############
    # Load datasets
    ###############
    # Load dataset from Hugging Face Hub
    dataset = load_dataset(script_args.dataset_id_or_path, split=script_args.dataset_splits)
    dataset = dataset.shuffle(seed=42)

    #####################
    # Prepare and format dataset
    #####################

    # generate r1 prompt with a prefix for the model to already start with the thinking process
    def generate_r1_prompt(text, people1, people2):
        r1_prefix = [
            {
                "role": "system",
                "content": "给定下面的文本和文本中的两个人物,仅根据文本内容来判断这两个人物之间的关系,人物关系只能是['不知道', '夫妻', '父母', '兄弟姐妹', '上下级', '师生', '好友', '同学', '合作', '同一个人', '情侣', '祖孙', '同门', '亲戚']的一个。"
                           "整体格式为<think>...</think>\n<answer>...</answer>,推理过程用<think>...</think>标记,最终的答案用<answer>...</answer>标记,最终答案是人物关系的标签。"
            },
            {
                "role": "user",
                "content": f"文本:{text}\n人物1:{people1}\n人物2:{people2}"
            },
            {
                "role": "assistant",
                "content": "让我一步一步来思考解决。\n<think>"
            }
        ]
        return {"prompt": tokenizer.apply_chat_template(r1_prefix, tokenize=False, continue_final_message=True)}

    # convert our dataset to the r1 prompt
    dataset = dataset.map(lambda x: generate_r1_prompt(x["text"], x["people1"], x["people2"]))
    print("first dataset:", dataset[0])

    # split the dataset into train and test
    train_test_split = dataset.train_test_split(test_size=0.1)

    train_dataset = train_test_split["train"]
    test_dataset = train_test_split["test"]

    #########################
    # Instantiate GRPO trainer
    #########################

    trainer = GRPOTrainer(
        model=model_args.model_name_or_path,
        reward_funcs=[format_reward_func, label_reward_func],
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=test_dataset,
        peft_config=get_peft_config(model_args),
    )

    ###############
    # Training loop
    ###############
    # Check for last checkpoint
    last_checkpoint = get_checkpoint(training_args)
    if last_checkpoint is not None and training_args.resume_from_checkpoint is None:
        logger.info(f"Checkpoint detected, resuming training at {last_checkpoint}.")

    # Train the model
    logger.info(
        f'*** Starting training {datetime.now().strftime("%Y-%m-%d %H:%M:%S")} for {training_args.num_train_epochs} epochs***'
    )
    train_result = trainer.train(resume_from_checkpoint=last_checkpoint)
    # Log and save metrics
    metrics = train_result.metrics
    metrics["train_samples"] = len(train_dataset)
    trainer.log_metrics("train", metrics)
    trainer.save_metrics("train", metrics)
    trainer.save_state()

    logger.info("*** Training complete ***")

    ##################################
    # Save model and create model card
    ##################################

    logger.info("*** Save model ***")
    trainer.model.config.use_cache = True
    trainer.save_model(training_args.output_dir)
    logger.info(f"Model saved to {training_args.output_dir}")
    training_args.distributed_state.wait_for_everyone()  # wait for all processes to load

    tokenizer.save_pretrained(training_args.output_dir)
    logger.info(f"Tokenizer saved to {training_args.output_dir}")

    # Save everything else on main process
    if trainer.accelerator.is_main_process:
        trainer.create_model_card({"tags": ["rl", "grpo"]})
    # push to hub if needed
    if training_args.push_to_hub is True:
        logger.info("Pushing to hub...")
        trainer.push_to_hub()

    logger.info("*** Training complete! ***")


def main():
    parser = TrlParser((ModelConfig, ScriptArguments, GRPOConfig))
    model_args, script_args, training_args = parser.parse_args_and_config()

    # Run the main training loop
    grpo_function(model_args, script_args, training_args)


if __name__ == "__main__":
    main()

训练参数配置如下(grpo-qwen-2.5-7b-deepseek-r1-relation.yaml):

# Model arguments
model_name_or_path: /mnt/models/Qwen2.5-7B-Instruct
# model_revision: main
torch_dtype: bfloat16
attn_implementation: flash_attention_2
bf16: true
tf32: true
output_dir: runs/qwen-2.5-7b-r1-people-relation

# Dataset arguments
dataset_id_or_path: jclian91/people_relation_classification

# Lora Arguments
# No LoRA is used here

# Training arguments
max_steps: 500
per_device_train_batch_size: 1
gradient_accumulation_steps: 8
gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: false
learning_rate: 5.0e-7 # 1.0e-6 as in the deepseek math paper 5-e7 from https://hijkzzz.notion.site/unraveling-rlhf-and-its-variants-engineering-insights#147d9a33ecc9806090f3d5c749d31f05
lr_scheduler_type: cosine
warmup_ratio: 0.03
# GRPO specific parameters
beta: 0.001 # 0.04 as in the deepseek math paper 0.001 from https://hijkzzz.notion.site/unraveling-rlhf-and-its-variants-engineering-insights#147d9a33ecc9806090f3d5c749d31f05
max_prompt_length: 256
max_completion_length: 1024
num_generations: 8
use_vllm: true
# vllm_device: "cuda:3"
vllm_gpu_memory_utilization: 0.5

# Logging arguments
logging_strategy: steps
logging_steps: 2
report_to:
  - tensorboard
save_strategy: "steps"
save_steps: 25
seed: 42

# Hugging Face Hub
push_to_hub: false
# hub_model_id: llama-3-1-8b-math-orca-qlora-10k-ep1 # if not defined same as output_dir
hub_strategy: every_save

deepspeed配置文件如下(deepspeed_zero3.yaml):

compute_environment: LOCAL_MACHINE
debug: false
deepspeed_config:
  deepspeed_multinode_launcher: standard
  offload_optimizer_device: none
  offload_param_device: none
  zero3_init_flag: true
  zero3_save_16bit_model: true
  zero_stage: 3
distributed_type: DEEPSPEED
downcast_bf16: 'no'
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 8
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false

笔者在4张A800-SXM4-80GB显卡上训练,耗时约4.5小时。训练过程中的指标在TensorBoard上查看如下:

(TensorBoard训练曲线截图:整体情况、回复长度、format_reward_func奖励、label_reward_func奖励)

可以看到,随着训练步数的增加,模型的平均回复长度从一开始的90多逐渐稳定在60多,format_reward_func奖励渐渐稳定在1左右,label_reward_func奖励则稳定在0.85~0.90之间。由此可见,模型对输出格式已能严格遵守,而标签预测的准确性仍有提升空间。

模型评估

我们使用vllm框架来部署上述训练后的模型,命令如下:

CUDA_VISIBLE_DEVICES=4,5,6,7 python -m vllm.entrypoints.openai.api_server --model /mnt/qwen-2.5-7b-r1-people-relation --served-model-name qwen-2.5-7b-r1-cls

对人物关系测试集(776条样本)进行评估,代码如下:

# -*- coding: utf-8 -*-
# @file: evaluate_r1_grpo.py
# @time: 2025/3/4 21:39
import re
import json
import requests
from rich.progress import track
import pandas as pd
from sklearn.metrics import classification_report, accuracy_score


def predict(text, people1, people2, label):
    url = "http://0.0.0.0:8000/v1/chat/completions"
    headers = {'Content-Type': 'application/json'}

    json_data = {
        'model': "qwen-2.5-7b-r1-cls",
        'messages': [
            {
                "role": "system",
                "content": "给定下面的文本和文本中的两个人物,仅根据文本内容来判断这两个人物之间的关系,人物关系只能是['不知道', '夫妻', '父母', '兄弟姐妹', '上下级', '师生', '好友', '同学', '合作', '同一个人', '情侣', '祖孙', '同门', '亲戚']的一个。"
                           "整体格式为<think>...</think>\n<answer>...</answer>,推理过程用<think>...</think>标记,最终的答案用<answer>...</answer>标记,最终答案是人物关系的标签。"
            },
            {
                "role": "user",
                "content": f"文本:{text}\n人物1:{people1}\n人物2:{people2}"
            },
            {
                "role": "assistant",
                "content": "让我一步一步来思考解决。\n<think>"
            }
        ],
        'temperature': 0.0
    }

    response = requests.post(url, headers=headers, json=json_data)
    print(repr(response.json()["choices"][0]["message"]["content"]))
    result = response.json()["choices"][0]["message"]["content"]
    # 使用正则表达式提取出<answer>...</answer>中的内容
    answer = re.findall(r"<answer>(.*?)</answer>", result, re.S)
    if answer:
        with open("predict.jsonl", "a") as f:
            f.write(json.dumps(
                {
                    "text": text,
                    "people1": people1,
                    "people2": people2,
                    "label": label,
                    "predict_label": answer[0],
                    "predict_content": "<think>" + result
                }, ensure_ascii=False
            ) + "\n")
        return answer[0]
    else:
        return ""


if __name__ == '__main__':
    df = pd.read_csv('test.csv')
    true_labels, pred_labels = [], []
    # 使用rich模块加入进度条
    for i, record in track(df.iterrows(), description="[red]Predicting...", total=len(df)):
        true_label = record["label"]
        true_labels.append(true_label)
        pred_label = predict(record["text"], record["people1"], record["people2"], record["label"])
        pred_labels.append(pred_label)
        print(f"Processing Row {i+1}: true_label: {true_label}, pred_label: {pred_label}")

    print(classification_report(true_labels, pred_labels, digits=4))
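
评估脚本中引入的accuracy_score也可以用来单独打印整体准确率,作为classification_report的补充(示意代码,可追加在主程序末尾):

# 整体准确率,应与classification_report中的accuracy一致
acc = accuracy_score(true_labels, pred_labels)
print(f"accuracy: {acc:.4f}")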

评估结果如下:

              precision    recall  f1-score   support

         上下级     0.7742    0.7742    0.7742        31
         不知道     0.8945    0.8517    0.8725       209
          亲戚     0.8421    0.6667    0.7442        24
        兄弟姐妹     0.8378    0.9118    0.8732        34
          合作     0.8182    0.9153    0.8640        59
        同一个人     1.0000    0.9487    0.9737        39
          同学     0.9130    0.8750    0.8936        24
          同门     0.9524    0.7692    0.8511        26
          夫妻     0.9367    0.9367    0.9367        79
          好友     0.7742    0.8000    0.7869        30
          师生     0.8684    0.8919    0.8800        37
          情侣     0.9032    0.9032    0.9032        31
          父母     0.9197    0.9844    0.9509       128
          祖孙     0.8889    0.9600    0.9231        25

    accuracy                         0.8892       776
   macro avg     0.8802    0.8706    0.8734       776
weighted avg     0.8903    0.8892    0.8884       776

人物关系分类任务是笔者自创的任务,至今已有四五年。这期间数据集固定不变,而模型训练算法早已一日千里,在此汇总如下:

| 模型方法 | 基座模型 | F1值 | 说明 |
| --- | --- | --- | --- |
| BERT向量提取+BiGRU+Attention | BiGRU+Attention | 78.97% | BERT模型作为特征提取处理 |
| BERT cls finetuning | BERT | 82.69% | 当作文本分类任务处理 |
| R-BERT | chinese-roberta-wwm-ext | 85.35% | BERT时代的关系分类模型代表 |
| R-BERT | chinese-roberta-wwm-ext-large | 87.22% | BERT时代的关系分类模型代表 |
| QLora | Baichuan-7B | 88.25% | max_seq_length=256, epoch=5, … |
| QLora | Baichuan-7B | 89.15% | max_seq_length=256, epoch=10, … |
| GRPO with LLM | Qwen2.5-7B-Instruct | 88.84% | 基于LLM的纯强化学习方案, steps=500 |

需要对上述结果做些说明:GRPO在测试集上的F1值达到了88.84%,虽然比Baichuan-7B的SFT结果略低,但这并不能说明基于GRPO的强化微调比SFT效果差,因为两者的方案并不相同,在最大文本长度、是否具备推理能力等方面都存在差异。基于GRPO强化微调的模型具有推理能力,在给出关系分类结果的同时,会给出推理过程(即分类理由);而SFT后的模型只输出关系分类结果,没有中间的推理过程。

为了说明上述观点,笔者挑选了2个预测结果正确的例子作为说明。

  • 正确例子1

文本:周海婴:鲁迅与许广平之子(1929—2011),无线电专家。
人物1:周海婴
人物2:许广平
真实标签:父母
预测:<think>根据文本内容,"周海婴:鲁迅与许广平之子",可以确定周海婴是许广平的儿子,因此他们之间是父母关系。</think> <answer>父母</answer>

  • 正确例子2

文本:就在十年前,金池和曹格都还只是涂惠源老师门下的弟子,与他们一同在音乐路上追梦的同门还有谭维维、曾一鸣和今年在"中国好声音"亮相的魏雪漫。
人物1:曹格
人物2:谭维维
真实标签:同门
预测:<think>根据文本内容,金池、曹格、谭维维、曾一鸣和魏雪漫都是涂惠源老师门下的弟子,他们是在音乐路上一起追梦的同门。因此,曹格和谭维维之间是同门关系。</think> <answer>同门</answer>

但也存在着预测错误的例子,列举如下:

  • 错误例子1

文本:梁启超的弟子中有胡适,徐志摩和王庚。
人物1:徐志摩
人物2:王庚
真实标签:同门
预测:<think>根据文本内容,梁启超的弟子包括徐志摩和王庚。这表明徐志摩和王庚都是梁启超的弟子,但文本中没有提到他们之间的任何关系。因此,我们只能确定他们有共同的老师,但没有足够的信息来确定他们之间的具体关系。</think> <answer>不知道</answer>

  • 错误例子2

文本:徐佛苏简介:在日本与当时留学生的同盟会,日知会等交往.曾担任或黄兴先生的秘书;
人物1:徐佛苏
人物2:黄兴
真实标签:合作
预测:<think>根据文本内容,徐佛苏曾担任黄兴先生的秘书。在历史上,秘书通常是为上司或重要人物服务的职位,这表明徐佛苏和黄兴之间存在上下级关系或至少是较为正式的工作关系。</think> <answer>上下级</answer>

从第2个错误例子来看,真实标签本身存在错误,导致模型的预测被判为出错。因此,训练后模型的实际推理能力(或评估指标)应比评估结果显示的更强。

模型预测

在新的数据上进行测试,验证模型的泛化能力:

(新样本预测结果截图:新的例子1~新的例子5)

从上述例子中,我们可以得到一点启发:模型在给出正确的人物关系标签的同时,也给出了推理过程(think阶段),而这个推理过程是可以拿来作为解释的。这是不是另一种形式的模型可解释性呢?

以往,我们对大模型(或者神经网络)的可解释性存在困惑,觉得它的可解释性不行;但有了模型的推理过程后,给出解释应该不难做到,虽然这并不是严格意义上的模型可解释性。

总结

笔者已将上述代码及结果在Notion上发布,网址为:https://local-dugout-3c9.notion.site/People-Relation-Classification-with-GRPO-training-1ac8b5de853080aaa7a6c2ab85f09dde

文末列出了不少参考文献,大多数都是DeepSeek R1模型的复现和实践,也是最近一段时间里在R1复现任务中比较有代表性的工作,值得细细品味。

参考文献

  1. NLP(六十三)使用Baichuan-7b模型微调人物关系分类任务: https://zhuanlan.zhihu.com/p/655360024
  2. Mini-R1: Reproduce Deepseek R1 „aha moment" a RL tutorial: https://www.philschmid.de/mini-deepseek-r1
  3. Open-R1: a fully open reproduction of DeepSeek-R1: https://huggingface.co/blog/open-r1
  4. TinyZero: https://github.com/Jiayi-Pan/TinyZero
  5. Logic-RL: https://github.com/Unakar/Logic-RL
  6. DeepSeek-R1: Incentivizing Reasoning Capability in LLMs via Reinforcement Learning: https://arxiv.org/abs/2501.12948
  7. Train your own R1 reasoning model with Unsloth (GRPO): https://unsloth.ai/blog/r1-reasoning
  8. Training with GRPOTrainer: https://www.stephendiehl.com/posts/grpotrainer/
  9. GRPO Training Script for Qwen Model on GSM8K Dataset: https://github.com/kossisoroyce/train_grpo.py
  10. LLaMA Factory:微调DeepSeek-R1-Distill-Qwen-7B模型实现新闻标题分类器: https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory_deepseek_r1_distill_7b

欢迎关注我的公众号NLP奇幻之旅,原创技术文章第一时间推送。

欢迎关注我的知识星球“自然语言处理奇幻之旅”,笔者正在努力构建自己的技术社区。


NLP(一百一十二)微调LLM,解锁推理能力——GRPO算法如何提升人物关系分类
https://percent4.github.io/NLP(一百一十二)微调LLM,解锁推理能力——GRPO算法如何提升人物关系分类/
作者
Jclian91
发布于
2025年4月27日
许可协议