NLP (100): Evaluating the Math Ability of Large Language Models

This article describes how the Yi-1.5-34B model was fine-tuned on a self-built math dataset with LLaMA-Factory, and how the resulting model was evaluated on the GSM8K and MATH datasets, giving a complete picture of the methodology and results for evaluating the math ability of large language models.

Introduction

In the article NLP (97): A Preliminary Exploration of LLM Math Problem-Solving, I made a first attempt at the math problem-solving ability of large models and showed that, after fine-tuning on roughly 500 samples, a large model already acquires basic problem-solving skills.

In NLP (99): Fine-tuning and Evaluating the Math Ability of Large Models, I described how to build a math problem-solving dataset (2,402 samples in total), fine-tuned the Qwen1.5-32B model on it, and evaluated the fine-tuned model's math ability.

For this article, the training set has been expanded to 3,480 samples. Its construction, briefly touched on in NLP (99), mainly follows the PoT (Program of Thought) and CoT (Chain of Thought) paradigms; a dedicated article on the construction process will follow. The training data is open-sourced on GitHub at: https://github.com/percent4/llm_math_solver/blob/main/data_systhesis/data/train_data.json .

Using this dataset and the LLaMA-Factory framework, I fine-tuned the recently open-sourced Yi-1.5-34B model from 01.AI and exported the merged fine-tuned model as Yi-1.5-34B-math.
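The export step itself is not shown in this post; for reference, with LLaMA-Factory it is typically a llamafactory-cli export call with a merge config roughly like the sketch below (the adapter path and export directory are placeholders, not taken from the actual run):

### model
model_name_or_path: 01-ai/Yi-1.5-34B
adapter_name_or_path: saves/yi-1.5-34b-math/lora/sft   # hypothetical adapter path
template: yi
finetuning_type: lora

### export
export_dir: /models/Yi-1.5-34B-math
export_size: 2
export_legacy_format: false

The merged model is then served with the following command (a single GPU is sufficient for inference):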

CUDA_VISIBLE_DEVICES=0 API_PORT=8000 nohup llamafactory-cli api examples/inference/yi15.yaml &

where the yi15.yaml file is configured as follows:

model_name_or_path: /models/Yi-1.5-34B-math
template: yi
infer_backend: vllm
vllm_enforce_eager: true
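Once the service is up, a quick smoke test against the OpenAI-compatible endpoint confirms it is reachable (a minimal sketch; the test question is made up):

# minimal smoke test for the local inference service (assumes it listens on localhost:8000)
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="0")
resp = client.chat.completions.create(
    model="Yi-1.5-34B-math",
    messages=[{"role": "user", "content": "题目:1+2+...+100等于多少?"}],
    temperature=0.0,
)
print(resp.choices[0].message.content)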

Next, we evaluate the model on the GSM8K and MATH datasets and walk through the evaluation procedure and results in detail.

GSM8K Evaluation

The GSM8K test set contains 1,319 samples. Using the inference service above, we run predictions on all of them and save the results to a prediction file.

The evaluation script for the GSM8K test set is as follows:

import os
import re
import json
import subprocess
from rich.progress import track
from openai import OpenAI
import logging
from retry import retry
from random import choices

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s')
logger = logging.getLogger(__name__)

# point the OpenAI client at the local inference service
os.environ["OPENAI_BASE_URL"] = "http://localhost:8000/v1"
os.environ["OPENAI_API_KEY"] = "0"
client = OpenAI()

execution_desc = ["运行以上代码,输出会是: ",
                  "现在将上面的代码复制到Python环境中运行,运行结果为:",
                  "执行上述Python代码,运行结果将是:",
                  "上面的Python代码执行结果为:",
                  "运行上述代码,我们可以得到题目要求的答案。输出结果将是:"]


@retry(exceptions=Exception, tries=3, delay=2)
def question_answer(query):
    messages = [{"role": "system", "content": "你是一个数学解题大师,请解决下面的数学题,给出思考过程,必要时需要给出解题过程中的Python代码。正确答案的数值用\\boxed{}包围起来,最终的答案以因此开头,不要讲多余的废话。"}]
    messages.append({"role": "user", "content": f"题目:{query}"})
    result = client.chat.completions.create(messages=messages,
                                            model="Yi-1.5-34B-math",
                                            temperature=0.0,
                                            stream=True)
    reply_message = ""
    for chunk in result:
        if hasattr(chunk, "choices") and chunk.choices[0].delta.content:
            reply_message += chunk.choices[0].delta.content

    # find python code and execute the code
    if '```python' in reply_message and '\n```' in reply_message:
        messages.append({"role": "assistant", "content": '```'.join(reply_message.split('```')[:-1]) + '```'})
        python_code_string = re.findall(r'```python\n(.*?)\n```', reply_message, re.S)[0]
        python_file_path = 'temp.py'
        with open(python_file_path, 'w') as f:
            f.write(python_code_string)
        # the timeout guards against infinite loops; any failure triggers a retry via @retry
        python_code_run = subprocess.run(['python3', python_file_path], stdout=subprocess.PIPE, timeout=10)
        if python_code_run.returncode:
            raise RuntimeError("生成的Python代码无法运行!")
        python_code_execution = python_code_run.stdout.decode('utf-8')
        os.remove(python_file_path)
        # feed the execution result back to the model to obtain the final answer
        code_reply_str = choices(execution_desc, k=1)[0]
        code_reply = f"\n{code_reply_str}```{python_code_execution.strip()}```\n"
        reply_message += code_reply
        messages.append({"role": "user", "content": code_reply})
        result = client.chat.completions.create(messages=messages,
                                                model="Yi-1.5-34B-math",
                                                temperature=0.0,
                                                stream=True)

        final_reply = ""
        for chunk in result:
            if hasattr(chunk, "choices") and chunk.choices[0].delta.content:
                reply_message += chunk.choices[0].delta.content
                final_reply += chunk.choices[0].delta.content
        return final_reply
    else:
        return reply_message


with open('gsm8k_test.jsonl', 'r') as f:
    content = f.readlines()

total_cnt = 0
correct_cnt = 0
for line in track(content):
    question, answer = json.loads(line.strip())['question'], json.loads(line.strip())['answer']
    # the GSM8K reference answer is the number after '####'
    true_answer_number = answer.split('####')[-1].strip().replace(',', '')
    try:
        pred_answer = question_answer(question)
    except Exception:
        pred_answer = 'ERROR'
    # the predicted answer is the number inside the last \boxed{}
    if re.findall(r'\\boxed\{.+?}', pred_answer) and re.findall(r'\d+', re.findall(r'\\boxed\{.+?}', pred_answer)[-1].replace(',', '')):
        pred_answer_number = re.findall(r'\d+', re.findall(r'\\boxed\{.+?}', pred_answer)[-1].replace(',', ''))[0]
    else:
        pred_answer_number = ''
    total_cnt += 1
    logger.info("*" * 50)
    logger.info('--- {} {} {}'.format(true_answer_number, pred_answer_number, repr(pred_answer)))
    if true_answer_number == pred_answer_number:
        correct_cnt += 1

    with open('eval_result_yi_15_34b.json', 'a', encoding='utf-8') as f:
        f.write(json.dumps({"is_correct": true_answer_number == pred_answer_number, "question": question, "answer": answer, "pred_answer": pred_answer}, ensure_ascii=False) + "\n")

logger.info('--- {} {} {}'.format(total_cnt, correct_cnt, correct_cnt/total_cnt))

For each test sample, the fine-tuned model predicts a solution. If the prediction contains Python code, the code wrapped in ```python ... ``` fences is extracted with a regular expression, saved as a Python script, and executed with the subprocess module. If execution fails or times out (e.g., an infinite loop), the whole prediction is retried, up to 3 times.

In the GSM8K dataset, the ground-truth answer of a test sample follows the #### marker and is usually an integer, while our fine-tuned model wraps the final answer in \boxed{} in its last reply. We can therefore compare the two numbers directly to obtain a preliminary accuracy.
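As a toy illustration of the two extraction rules (the sample texts below are made up):

import re

# reference: the number after '####'; prediction: the number inside the last \boxed{}
gold = "Natalia sold 48 clips in April and 24 clips in May.\n#### 72"
pred = "……因此,答案为 \\boxed{72}。"

gold_num = gold.split('####')[-1].strip().replace(',', '')
last_boxed = re.findall(r'\\boxed\{.+?}', pred)[-1]
pred_num = re.findall(r'\d+', last_boxed.replace(',', ''))[0]
print(gold_num == pred_num)  # True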

However, some predictions either lack a \boxed{} answer or failed outright (prediction recorded as ERROR), so these samples need to be re-checked by hand. I built a small Gradio web page for this manual confirmation; the Python code is as follows:

import json
import gradio as gr


def read_samples():
    with open("eval_result_yi_15_34b.json", "r") as f:
        data = f.readlines()

    content = []
    cnt = 0
    for sample in data:
        sample_dict = json.loads(sample.strip())
        if '\\boxed' not in sample_dict['pred_answer']:
            cnt += 1
            content.append([cnt, sample_dict['question'],
                            sample_dict['answer'].split('####')[-1].strip(),
                            sample_dict['pred_answer'],
                            0])
    return content


def get_human_eval(df):
    # get model evaluation
    with open("eval_result_yi_15_34b.json", "r") as f:
        data = f.readlines()

    model_true_cnt = 0
    for sample in data:
        sample_dict = json.loads(sample.strip())
        if '\\boxed' in sample_dict['pred_answer'] and sample_dict['is_correct']:
            model_true_cnt += 1
    # get human evaluation
    human_true_cnt = 0
    for i, row in df.iterrows():
        if row['Human Evaluation']:
            human_true_cnt += 1
    return (f"Update {human_true_cnt} samples with human evaluation, \n"
            f"Total Accuracy: {model_true_cnt + human_true_cnt}/{len(data)} = {(model_true_cnt + human_true_cnt)/len(data)}")


with gr.Blocks() as demo:
    with gr.Column():
        with gr.Row():
            table = gr.DataFrame(label='Table',
                                 value=read_samples(),
                                 headers=['No.', 'Question', 'Answer', 'Prediction', 'Human Evaluation'],
                                 interactive=True,
                                 wrap=True)
        with gr.Row():
            output = gr.Textbox(label='Human Evaluation')
            submit = gr.Button("Search")

    submit.click(fn=get_human_eval,
                 inputs=table,
                 outputs=output)

demo.launch()

The manual confirmation page looks like this:

(Figure: GSM8K evaluation, manual confirmation page)

After this manual pass, we obtain the fine-tuned model's accuracy on the GSM8K test set.

I ran this evaluation procedure on the GSM8K test set twice; the accuracy was 83.47% both times (1,101 of 1,319 samples), with 8 samples failing prediction altogether (recorded as ERROR).

MATH Evaluation

Model Prediction

The MATH test set contains 5,000 samples. Again using the inference service above, we generate predictions for all of them. Inference takes quite a while: roughly half a day to a day on my single A100.

The prediction script is as follows:

import os
import re
import json
import subprocess
from rich.progress import track
from openai import OpenAI
import logging
from retry import retry
from random import choices

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s')
logger = logging.getLogger(__name__)

os.environ["OPENAI_BASE_URL"] = "http://localhost:8000/v1"
os.environ["OPENAI_API_KEY"] = "0"
client = OpenAI()

execution_desc = ["运行以上代码,输出会是: ",
                  "现在将上面的代码复制到Python环境中运行,运行结果为:",
                  "执行上述Python代码,运行结果将是:",
                  "上面的Python代码执行结果为:",
                  "运行上述代码,我们可以得到题目要求的答案。输出结果将是:"]


@retry(exceptions=Exception, tries=3, delay=2)
def question_answer(query):
    messages = [{"role": "system", "content": "你是一个数学解题大师,请解决下面的数学题,给出思考过程,必要时需要给出解题过程中的Python代码。正确答案的数值用\\boxed{}包围起来,最终的答案以因此开头,不要讲多余的废话。"}]
    messages.append({"role": "user", "content": f"题目:{query}"})
    result = client.chat.completions.create(messages=messages,
                                            model="Yi-1.5-34B-math",
                                            temperature=0.2,
                                            stream=True)
    reply_message = ""
    for chunk in result:
        if hasattr(chunk, "choices") and chunk.choices[0].delta.content:
            reply_message += chunk.choices[0].delta.content

    # find python code and execute the code
    if '```python' in reply_message and '\n```' in reply_message:
        messages.append({"role": "assistant", "content": reply_message})
        python_code_string = re.findall(r'```python\n(.*?)\n```', reply_message, re.S)[0]
        python_file_path = 'temp.py'
        with open(python_file_path, 'w') as f:
            f.write(python_code_string)
        python_code_run = subprocess.run(['python3', python_file_path], stdout=subprocess.PIPE, timeout=10)
        if python_code_run.returncode:
            raise RuntimeError("生成的Python代码无法运行!")
        python_code_execution = python_code_run.stdout.decode('utf-8')
        os.remove(python_file_path)
        code_reply_str = choices(execution_desc, k=1)[0]
        code_reply = f"\n{code_reply_str}```{python_code_execution.strip()}```\n"
        reply_message += code_reply
        messages.append({"role": "user", "content": code_reply})
        result = client.chat.completions.create(messages=messages,
                                                model="Yi-1.5-34B-math",
                                                temperature=0.2,
                                                stream=True)

        final_reply = ""
        for chunk in result:
            if hasattr(chunk, "choices") and chunk.choices[0].delta.content:
                reply_message += chunk.choices[0].delta.content
                final_reply += chunk.choices[0].delta.content
        return final_reply
    else:
        return reply_message


with open('math_test.jsonl', 'r') as f:
    content = f.readlines()

i = 1
for line in track(content):
    data = json.loads(line.strip())
    question, answer = data['problem'], data['solution']
    try:
        pred_answer = question_answer(question)
    except Exception:
        pred_answer = "ERROR"
    data.update({"predict_answer": pred_answer})
    logger.info("*" * 50)
    logger.info('--- {} true: {}'.format(i, repr(answer)))
    logger.info('--- {} pred: {}'.format(i, repr(pred_answer)))
    i += 1
    with open('math_eval_result.json', 'a', encoding='utf-8') as f:
        f.write(json.dumps(data, ensure_ascii=False) + "\n")

As you can see, this prediction script is largely identical to the GSM8K one, except that it skips the accuracy-counting step. The reason is that, although the final answers in the MATH test set are also wrapped in \boxed{}, they take many forms, some quite complex: square roots, fractions, expressions containing pi, polynomials, intervals, and so on. (The fine-tuned model likewise wraps its final answer in \boxed{}.)

Deciding whether a predicted answer matches the reference answer is therefore a genuinely hard problem.
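For instance (a toy illustration, not taken from the test set), three strings that all denote one half fail a naive string comparison:

# the same value in three surface forms; plain string equality treats them as pairwise different
answers = ["\\frac{1}{2}", "1/2", "0.5"]
print(answers[0] == answers[1], answers[1] == answers[2])  # False False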

Answer Equivalence Checking

The GitHub project hendrycks/math provides a script, math_equivalence.py, for deciding whether two final answers are equal: https://github.com/hendrycks/math/blob/main/modeling/math_equivalence.py . I adapted it slightly, adding a check for equality between a fraction and a decimal (with a tolerance of 1e-6); the code is as follows:

# @file: math_equivalence.py

import re


def _fix_fracs(string):
    substrs = string.split("\\frac")
    new_str = substrs[0]
    if len(substrs) > 1:
        substrs = substrs[1:]
        for substr in substrs:
            new_str += "\\frac"
            if substr[0] == "{":
                new_str += substr
            else:
                try:
                    assert len(substr) >= 2
                except:
                    return string
                a = substr[0]
                b = substr[1]
                if b != "{":
                    if len(substr) > 2:
                        post_substr = substr[2:]
                        new_str += "{" + a + "}{" + b + "}" + post_substr
                    else:
                        new_str += "{" + a + "}{" + b + "}"
                else:
                    if len(substr) > 2:
                        post_substr = substr[2:]
                        new_str += "{" + a + "}" + b + post_substr
                    else:
                        new_str += "{" + a + "}" + b
    string = new_str
    return string


def _fix_a_slash_b(string):
    if len(string.split("/")) != 2:
        return string
    a = string.split("/")[0]
    b = string.split("/")[1]
    try:
        a = int(a)
        b = int(b)
        assert string == "{}/{}".format(a, b)
        new_string = "\\frac{" + str(a) + "}{" + str(b) + "}"
        return new_string
    except:
        return string


def _remove_right_units(string):
    # "\\text{ " only ever occurs (at least in the val set) when describing units
    if "\\text{ " in string:
        splits = string.split("\\text{ ")
        assert len(splits) == 2
        return splits[0]
    else:
        return string


def _fix_sqrt(string):
    if "\\sqrt" not in string:
        return string
    splits = string.split("\\sqrt")
    new_string = splits[0]
    for split in splits[1:]:
        if split[0] != "{":
            a = split[0]
            new_substr = "\\sqrt{" + a + "}" + split[1:]
        else:
            new_substr = "\\sqrt" + split
        new_string += new_substr
    return new_string


def _strip_string(string):
    # linebreaks
    string = string.replace("\n", "")

    # remove inverse spaces
    string = string.replace("\\!", "")

    # replace \\ with \
    string = string.replace("\\\\", "\\")

    # replace tfrac and dfrac with frac
    string = string.replace("tfrac", "frac")
    string = string.replace("dfrac", "frac")

    # remove \left and \right
    string = string.replace("\\left", "")
    string = string.replace("\\right", "")

    # Remove circ (degrees)
    string = string.replace("^{\\circ}", "")
    string = string.replace("^\\circ", "")

    # remove dollar signs
    string = string.replace("\\$", "")

    # remove units (on the right)
    string = _remove_right_units(string)

    # remove percentage
    string = string.replace("\\%", "")
    string = string.replace(r"\%", "")

    # " 0." equivalent to " ." and "{0." equivalent to "{." Alternatively, add "0" if "." is the start of the string
    string = string.replace(" .", " 0.")
    string = string.replace("{.", "{0.")
    # if empty, return empty string
    if len(string) == 0:
        return string
    if string[0] == ".":
        string = "0" + string
    # remove .0, .00, .000 in float number
    if re.match(r'\d+\.0+$', string):
        string = string.split('.')[0]

    # to consider: get rid of e.g. "k = " or "q = " at beginning
    if len(string.split("=")) == 2:
        if len(string.split("=")[0]) <= 2:
            string = string.split("=")[1]

    # fix sqrt3 --> sqrt{3}
    string = _fix_sqrt(string)

    # remove spaces
    string = string.replace(" ", "")

    # \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2}, etc. Even works with \frac1{72} (but not \frac{72}1). Also does a/b --> \\frac{a}{b}
    string = _fix_fracs(string)

    # manually change 0.5 --> \frac{1}{2}
    if string == "0.5":
        string = "\\frac{1}{2}"

    # NOTE: X/Y changed to \frac{X}{Y} in dataset, but in simple cases fix in case the model output is X/Y
    string = _fix_a_slash_b(string)

    return string


def is_equiv(str1, str2, verbose=False):
    if str1 is None and str2 is None:
        print("WARNING: Both None")
        return True
    if str1 is None or str2 is None:
        return False
    # convert a fraction to its decimal value and compare numerically (tolerance: 1e-6)
    if 'frac' in str1 and re.match(r'\d+\.?\d*$', str2):
        result = re.findall(r'frac{(.*?)}{(.*?)}', str1)
        numerator, denominator = result[0][0], result[0][1]
        true_value = float(numerator) / float(denominator)
        pred_value = float(str2)
        return abs(true_value - pred_value) < 1e-6
    if 'frac' in str2 and re.match(r'\d+\.?\d*$', str1):
        result = re.findall(r'frac{(.*?)}{(.*?)}', str2)
        numerator, denominator = result[0][0], result[0][1]
        true_value = float(numerator) / float(denominator)
        pred_value = float(str1)
        return abs(true_value - pred_value) < 1e-6

    try:
        ss1 = _strip_string(str1)
        ss2 = _strip_string(str2)
        if verbose:
            print(ss1, ss2)
        return ss1 == ss2
    except:
        return str1 == str2
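As a quick, illustrative sanity check of the added fraction/decimal branch (assuming math_equivalence.py is importable from the current directory):

from math_equivalence import is_equiv

print(is_equiv("\\frac{1}{83}", "0.012048192771084338"))  # True: |1/83 - 0.0120481927...| < 1e-6
print(is_equiv("\\frac{1}{2}", "0.51"))                   # False: the difference exceeds the tolerance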

The equivalence checker is covered by unit tests; the test script is as follows:

from .math_equivalence import is_equiv
import unittest


class TestIsEquiv(unittest.TestCase):

    def test_fractions(self):
        test_in = "\\tfrac{1}{2} + \\frac1{72}"
        test_out = "\\\\frac{1}{2} + 2/3"
        self.assertFalse(is_equiv(test_in, test_out))

    def test_order(self):
        test_in = "10, 4, -2"
        test_out = "4, 10, -2"
        self.assertFalse(is_equiv(test_in, test_out))

    def test_order2(self):
        test_in = "10, 4, 2"
        test_out = "4, 12, 2"
        self.assertFalse(is_equiv(test_in, test_out))

    def test_dfrac(self):
        test_in = "\\tfrac{1}{2} +\\! \\frac1{72}"
        test_out = "\\\\dfrac{1}{2} +\\frac{1}{72}"
        self.assertTrue(is_equiv(test_in, test_out))

    def test_units(self):
        test_in = "10\\text{ units}"
        test_out = "10 "
        self.assertTrue(is_equiv(test_in, test_out))

    def test_units2(self):
        test_in = "10\\text{ units}"
        test_out = "100 "
        self.assertFalse(is_equiv(test_in, test_out))

    def test_dollars(self):
        test_in = "10"
        test_out = "\\$10"
        self.assertTrue(is_equiv(test_in, test_out))

    def test_parentheses(self):
        test_in = "\\left(x-2\\right)\\left(x+2\\right)"
        test_out = "(x-2)(x+2)"
        self.assertTrue(is_equiv(test_in, test_out))

    def test_decimals(self):
        test_in = "0.1, 4, 2"
        test_out = "4, .1, 2"
        self.assertFalse(is_equiv(test_in, test_out))

    def test_decimals2(self):
        test_in = "0.1"
        test_out = ".1"
        self.assertTrue(is_equiv(test_in, test_out))

    def test_percentage(self):
        test_in = "10\\%"
        test_out = "10"
        self.assertTrue(is_equiv(test_in, test_out))

    def test_sqrt(self):
        test_in = "10\\sqrt{3} + \\sqrt4"
        test_out = "10\\sqrt3 + \\sqrt{4}"
        self.assertTrue(is_equiv(test_in, test_out))

    def test_frac(self):
        test_in = "\\frac34i"
        test_out = "\\frac{3}{4}i"
        self.assertTrue(is_equiv(test_in, test_out))

    def test_tfrac(self):
        test_in = "\\tfrac83"
        test_out = "\\frac{8}{3}"
        self.assertTrue(is_equiv(test_in, test_out))

    def test_expression(self):
        test_in = "5x - 7y + 11z + 4 = 0"
        test_out = "x + y - z + 2 = 0"
        self.assertFalse(is_equiv(test_in, test_out))

    def test_half(self):
        test_in = "1/2"
        test_out = "\\frac{1}{2}"
        self.assertTrue(is_equiv(test_in, test_out))

    def test_frac_float(self):
        test_in = "\\frac{1}{83}"
        test_out = "0.012048192771084338"
        self.assertTrue(is_equiv(test_in, test_out))

    def test_frac_float_2(self):
        test_in = "\\frac{1}{6}"
        test_out = "0.166666666667"
        self.assertTrue(is_equiv(test_in, test_out))

    def test_frac_float_3(self):
        test_in = "\\frac{8}{45}"
        test_out = "0.17777777778"
        self.assertTrue(is_equiv(test_in, test_out))

    def test_frac_float_4(self):
        test_in = "0.015"
        test_out = "\\frac{3}{200}"
        self.assertTrue(is_equiv(test_in, test_out))

    def test_two_float(self):
        test_in = "8"
        test_out = "8.000000"
        self.assertTrue(is_equiv(test_in, test_out))

    def test_case_1(self):
        test_in = "4x+5"
        test_out = "4x + 5"
        self.assertTrue(is_equiv(test_in, test_out))

Running python3 -m unittest MATH/math_equivalence_test.py -v executes all 22 test cases; every one of them passes.

Preliminary Accuracy

Using the equivalence checker above, we compute a preliminary accuracy over the fine-tuned model's prediction file; the Python code is as follows:

# -*- coding: utf-8 -*-
# @file: math_eval_update.py
# add an is_correct flag to each evaluated prediction
import json

from math_equivalence import is_equiv


def last_boxed_only_string(string):
    idx = string.rfind("\\boxed")
    if idx < 0:
        idx = string.rfind("\\fbox")
        if idx < 0:
            return None

    i = idx
    right_brace_idx = None
    num_left_braces_open = 0
    while i < len(string):
        if string[i] == "{":
            num_left_braces_open += 1
        if string[i] == "}":
            num_left_braces_open -= 1
            if num_left_braces_open == 0:
                right_brace_idx = i
                break
        i += 1

    if right_brace_idx is None:
        retval = None
    else:
        retval = string[idx:right_brace_idx + 1]

    return retval


def remove_boxed(s):
    left = "\\boxed{"
    try:
        assert s[:len(left)] == left
        assert s[-1] == "}"
        return s[len(left):-1]
    except:
        return None


if __name__ == '__main__':
    with open("math_eval_result.json", "r", encoding="utf-8") as f:
        data = f.readlines()

    correct_cnt = 0
    content = []
    for i, line in enumerate(data):
        is_correct = False
        sample = json.loads(line.strip())
        true_answer, pred_answer = sample["solution"], sample["predict_answer"]
        try:
            true_answer_str = remove_boxed(last_boxed_only_string(true_answer))
            pred_answer_str = remove_boxed(last_boxed_only_string(pred_answer))
            if pred_answer_str is not None and is_equiv(true_answer_str, pred_answer_str):
                correct_cnt += 1
                is_correct = True
                print(i, true_answer_str, pred_answer_str, correct_cnt, i + 1, correct_cnt/(i+1))
        except Exception as e:
            print(e)
        sample.update({"is_correct": is_correct})
        content.append(sample)

    with open("math_eval_result_update.json", "w", encoding="utf-8") as f:
        for _ in content:
            f.write(json.dumps(_, ensure_ascii=False) + "\n")
Running this code gives a preliminary accuracy of 47.84% (2,392 of 5,000 samples).

Manual Confirmation

As with GSM8K, the samples whose predicted answers disagree with the reference must be confirmed by hand, because the equivalence check above does not cover every case.

The Python code for manual confirmation is as follows:

# -*- coding: utf-8 -*-
# @file: human_eval_server.py
# manually verify the model's evaluation results, implemented with Gradio
import json
import gradio as gr

from math_eval_update import last_boxed_only_string, remove_boxed


def read_samples():
    with open("math_eval_result_update.json", "r") as f:
        data = f.readlines()

    content = []
    for i, sample in enumerate(data):
        sample_dict = json.loads(sample.strip())
        question, true_answer, pred_answer = sample_dict["problem"], sample_dict["solution"], sample_dict["predict_answer"]
        try:
            true_answer_str = remove_boxed(last_boxed_only_string(true_answer))
            pred_answer_str = remove_boxed(last_boxed_only_string(pred_answer))
            if not sample_dict["is_correct"] and pred_answer_str:
                content.append([i, true_answer_str, pred_answer_str, 0])
        except Exception:
            continue
    return content


def get_human_eval(df):
    # get model evaluation
    with open("math_eval_result_update.json", "r") as f:
        data = f.readlines()

    model_true_cnt = 0
    for sample in data:
        sample_dict = json.loads(sample.strip())
        if sample_dict["is_correct"]:
            model_true_cnt += 1
    # get human evaluation
    human_true_cnt = 0
    for i, row in df.iterrows():
        if row['Human Evaluation']:
            human_true_cnt += 1
    # save human evaluation to json file
    final_result = [json.loads(line.strip()) for line in data]
    for i, row in df.iterrows():
        if row['Human Evaluation']:
            final_result[row['No.']]["is_correct"] = True
    with open("math_eval_result_final.json", "w") as g:
        for _ in final_result:
            g.write(json.dumps(_, ensure_ascii=False) + '\n')

    return (f"Update {human_true_cnt} samples with human evaluation, \n"
            f"Total Accuracy: {model_true_cnt + human_true_cnt}/{len(data)} = {(model_true_cnt + human_true_cnt)/len(data)}")


with gr.Blocks() as demo:
    with gr.Column():
        with gr.Row():
            table = gr.DataFrame(label='Table',
                                 value=read_samples(),
                                 headers=['No.', 'True_Answer_Number', 'Pred_Answer_Number', 'Human Evaluation'],
                                 interactive=True,
                                 wrap=True)
        with gr.Row():
            output = gr.Textbox(label='Human Evaluation')
            submit = gr.Button("Search")

    submit.click(fn=get_human_eval,
                 inputs=table,
                 outputs=output)

demo.launch()

After manual confirmation, the fine-tuned Yi-1.5-34B-math model's accuracy on the MATH test set is revised to 52.76% (2,638 of 5,000 samples); this includes 89 samples whose prediction failed outright (recorded as ERROR).

On MATH accuracy, Yi-1.5-34B-math beats the base Yi-1.5-34B model by 11.76 percentage points and the fine-tuned Qwen1.5-32B-math model by 9.18 percentage points, and comes very close to the accuracy of the original GPT-4 release (52.90%).

(Figure: MATH dataset leaderboard)

Metrics Summary

The math-ability scores of the different fine-tuned models are summarized in the table below:

Model          GSM8K     MATH
Qwen1.5-32B    79.68%    43.58%
Yi-1.5-34B     83.47%    52.76%

Note: Qwen1.5-32B and Yi-1.5-34B were fine-tuned on training sets of different sizes (2,402 and 3,480 samples, respectively), which should be kept in mind when comparing them; all other settings were identical. A fair comparison of more models on the same training set will follow.

The MATH results merit closer analysis; the metric-visualization script is as follows:

# -*- coding: utf-8 -*-
# @file: eval_visualization.py
import json
import plotly.graph_objects as go
from collections import defaultdict
from random import shuffle
from operator import itemgetter

# read the final evaluation results
with open("math_eval_result_final.json", "r", encoding="utf-8") as f:
    data = f.readlines()

type_dict = defaultdict(int)
level_dict = defaultdict(int)
for line in data:
    sample = json.loads(line.strip())
    if sample['is_correct']:
        type_dict[sample['type']] += 1
        level_dict[sample['level']] += 1

# pie chart: distribution of correct answers by type
fig1 = go.Figure(data=go.Pie(labels=list(type_dict.keys()), values=list(type_dict.values())))
fig1.update_layout(
    title="Type Distribution of Correct Answers in MATH",
    font=dict(size=20)
)
# fig1.show()
# pie chart: distribution of correct answers by level
fig2 = go.Figure(data=go.Pie(labels=list(level_dict.keys()), values=list(level_dict.values())))
fig2.update_layout(
    title="Level Distribution of Correct Answers in MATH",
    font=dict(size=20)
)
# fig2.show()

# compute accuracy per type and per level
type_cnt_dict = defaultdict(int)
level_cnt_dict = defaultdict(int)
for line in data:
    sample = json.loads(line.strip())
    type_cnt_dict[sample['type'] + '_total'] += 1
    level_cnt_dict[sample['level'] + '_total'] += 1
    if sample['is_correct']:
        type_cnt_dict[sample['type'] + '_correct'] += 1
        level_cnt_dict[sample['level'] + '_correct'] += 1

type_correct_ratio = {key: type_cnt_dict[f'{key}_correct']/type_cnt_dict[f'{key}_total'] for key in type_dict.keys()}
level_correct_ratio = {key: level_cnt_dict[f'{key}_correct']/level_cnt_dict[f'{key}_total'] for key in level_dict.keys()}
sorted_type_correct_ratio = {k: round(v, 4) for k, v in sorted(type_correct_ratio.items(), key=itemgetter(1), reverse=True)}
sorted_level_correct_ratio = {k: round(v, 4) for k, v in sorted(level_correct_ratio.items(), key=itemgetter(1), reverse=True)}

# bar chart: accuracy by type
colors = ['red', 'blue', 'green', 'purple', 'orange', 'pink', 'brown']
shuffle(colors)
fig3 = go.Figure(data=[go.Bar(x=list(sorted_type_correct_ratio.keys()),
                              y=list(sorted_type_correct_ratio.values()),
                              text=list(sorted_type_correct_ratio.values()),
                              textposition='auto',
                              marker=dict(color=colors[:len(type_dict.keys())]),
                              textfont=dict(size=20)
                              )])
fig3.update_layout(
    title="Type Correct Ratio of MATH",
    xaxis_title="Type",
    yaxis_title="Correct Ratio",
    legend_title="Type",
    font=dict(size=20)
)
# fig3.show()
# bar chart: accuracy by level
fig4 = go.Figure(data=[go.Bar(x=list(sorted_level_correct_ratio.keys()),
                              y=list(sorted_level_correct_ratio.values()),
                              text=list(sorted_level_correct_ratio.values()),
                              textposition='auto',
                              marker=dict(color=colors[:len(level_dict.keys())]),
                              textfont=dict(size=20)
                              )])
fig4.update_layout(
    title="Level Correct Ratio of MATH",
    xaxis_title="Level",
    yaxis_title="Correct Ratio",
    legend_title="Level",
    font=dict(size=20)
)
fig4.show()
The script produces four figures:

  • Distribution of correct answers across Levels (as a share of all correct answers)

  • Distribution of correct answers across Types (as a share of all correct answers)

  • Bar chart of accuracy by Type

  • Bar chart of accuracy by Level

Summary

This article is the third in the series on LLM math ability. It described fine-tuning the Yi-1.5-34B model on a self-built math dataset with LLaMA-Factory and evaluating the fine-tuned model on the GSM8K and MATH datasets, giving a complete view of the evaluation methodology and results.

The evaluation code used in this article has been open-sourced on GitHub at: https://github.com/percent4/llm_math_solver .

In a follow-up article, I will describe how the training set for LLM math problem-solving was constructed.

References

  1. LLM leaderboard: https://www.vellum.ai/llm-leaderboard
  2. NLP (97): A Preliminary Exploration of LLM Math Problem-Solving
  3. NLP (99): Fine-tuning and Evaluating the Math Ability of Large Models
  4. MathEval benchmark datasets: https://matheval.ai/dataset/
  5. 01.AI (Lingyi Wanwu) official site: https://www.lingyiwanwu.com/
  6. hendrycks/math: https://github.com/hendrycks/math
