Keras入门(七)使用Flask+Keras-bert构建模型预测服务

在文章NLP(三十四)使用keras-bert实现序列标注任务中,我们介绍了如何使用keras-bert模块,利用BERT中文预训练模型来实现序列标注任务的模型训练、模型评估和模型预测。其中,模型预测是通过加载生成的h5文件来实现的。

本文将会介绍如何使用Flask构建模型预测的HTTP服务。

我们遵循正常的思路,即先使用Keras加载保存后的h5模型文件,利用Flask对新输入的文本进行模型预测,最后给出预测结果。我们对人民日报命名实体实体数据集进行模型训练,采用文章NLP(三十四)使用keras-bert实现序列标注任务中的模型,训练后得到example_ner.h5文件,模型预测的HTTP服务脚本如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
# -*- coding: utf-8 -*-
import json
import traceback
import numpy as np
from keras.models import load_model
from keras_bert import get_custom_objects
from keras_contrib.layers import CRF
from keras_contrib.losses import crf_loss
from keras_contrib.metrics import crf_accuracy
from flask import Flask, request

from model_train import PreProcessInputData, id_label_dict


# 将BIO标签转化为方便阅读的json格式
def bio_to_json(string, tags):
item = {"string": string, "entities": []}
entity_name = ""
entity_start = 0
iCount = 0
entity_tag = ""

for c_idx in range(min(len(string), len(tags))):
c, tag = string[c_idx], tags[c_idx]
if c_idx < len(tags)-1:
tag_next = tags[c_idx+1]
else:
tag_next = ''

if tag[0] == 'B':
entity_tag = tag[2:]
entity_name = c
entity_start = iCount
if tag_next[2:] != entity_tag:
item["entities"].append({"word": c, "start": iCount, "end": iCount + 1, "type": tag[2:]})
elif tag[0] == "I":
if tag[2:] != tags[c_idx-1][2:] or tags[c_idx-1][2:] == 'O':
tags[c_idx] = 'O'
pass
else:
entity_name = entity_name + c
if tag_next[2:] != entity_tag:
item["entities"].append({"word": entity_name, "start": entity_start, "end": iCount + 1, "type": entity_tag})
entity_name = ''
iCount += 1
return item


app = Flask(__name__)


@app.route("/model/ner", methods=["GET", "POST"])
def get_geo():
return_result = {"code": 200, "message": "success", "data": []}
try:
text = request.get_json()["text"].replace(" ", "")
word_labels, seq_types = PreProcessInputData([text])

# 模型预测
predicted = ner_model.predict([word_labels, seq_types])
y = np.argmax(predicted[0], axis=1)
tag = [id_label_dict[_] for _ in y]

# 输出预测结果
result = bio_to_json(text, tag[1:-1])
return_result["data"] = result

except Exception:
return_result["code"] = 400
return_result["message"] = traceback.format_exc()

return json.dumps(return_result, ensure_ascii=False, indent=2)


if __name__ == '__main__':
# 加载训练好的模型
custom_objects = get_custom_objects()
for key, value in {'CRF': CRF, 'crf_loss': crf_loss, 'crf_accuracy': crf_accuracy}.items():
custom_objects[key] = value
ner_model = load_model("example_ner.h5", custom_objects=custom_objects)
# 启动HTTP服务
app.run(host="0.0.0.0", port=25000)

看上去上面的服务并没有什么问题,但当我们进行HTTP请求时,报错如下:

1
2
3
File "/home/jclian91/.conda/envs/py3-lmj/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 3875, in _as_graph_element_locked
raise ValueError("Tensor %s is not an element of this graph." % obj)
ValueError: Tensor Tensor("crf_1/cond/Merge:0", shape=(?, ?, 7), dtype=float32) is not an element of this graph.

上网搜资料,发现这种错误非常常见,其中一种解决方法如下:

导入模块:

1
2
import tensorflow as tf
from keras.backend import set_session

在加载模型(load_model)的代码前,加几行代码如下:

1
2
3
sess = tf.Session()
graph = tf.get_default_graph()
set_session(sess)

同时在HTTP服务的模型预测(ner_model.predict)前,加几行代码如下:

1
2
3
4
5
6
# 模型预测
global sess
global graph
with graph.as_default():
set_session(sess)
predicted = ner_model.predict([word_labels, seq_types])

这样再次启动模型预测HTTP脚本,可以发现模型预测的HTTP请求是正常的。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
$ curl --location --request POST 'http://192.168.1.193:25000/model/ner' \
> --header 'Content-Type: application/json' \
> --data-raw '{
> "text": "美国卫生部长阿扎尔辞职 原因曝光"
> }'
{
"code": 200,
"message": "success",
"data": {
"string": "美国卫生部长阿扎尔辞职原因曝光",
"entities": [
{
"word": "美国卫生部",
"start": 0,
"end": 5,
"type": "ORG"
},
{
"word": "阿扎尔",
"start": 6,
"end": 9,
"type": "PER"
}
]
}
}

该脚本已上传至Github,网址为:https://github.com/percent4/keras_bert_sequence_labeling/blob/master/model_server.py

感谢阅读~

2021年1月16日于上海浦东

欢迎关注我的公众号NLP奇幻之旅,原创技术文章第一时间推送。

欢迎关注我的知识星球“自然语言处理奇幻之旅”,笔者正在努力构建自己的技术社区。


Keras入门(七)使用Flask+Keras-bert构建模型预测服务
https://percent4.github.io/Keras入门(七)使用Flask-Keras-bert构建模型预测服务/
作者
Jclian91
发布于
2023年8月8日
许可协议