在文章NLP(十五)让模型来告诉你文本中的时间 中,我们已经学会了如何利用kashgari模块来完成序列标注模型的训练与预测,在本文中,我们将会了解如何tensorflow-serving来部署模型。
在kashgari的官方文档中,已经有如何利用tensorflow-serving来部署模型的说明了,网址为:https://kashgari.bmio.net/advance-use/tensorflow-serving/
。
下面,本文将介绍tensorflow-serving以及如何利用tensorflow-serving来部署kashgari的模型。
tensorflow-serving
TensorFlow Serving 是一个用于机器学习模型 serving
的高性能开源库。它可以将训练好的机器学习模型部署到线上,使用 gRPC
作为接口接受外部调用。更加让人眼前一亮的是,它支持模型热更新与自动模型版本管理。这意味着一旦部署
TensorFlow Serving
后,你再也不需要为线上服务操心,只需要关心你的线下模型训练。
TensorFlow
Serving可以方便我们部署TensorFlow模型,本文将使用TensorFlow
Serving的Docker镜像来使用TensorFlow Serving,安装的命令如下:
1 docker pull tensorflow/serving
工程实践
本项目将演示如何利用tensorflow/serving来部署kashgari中的模型,项目结构如下:
本项目的data来自之前笔者标注的时间数据集,即标注出文本中的时间,采用BIO标注系统。chinese_wwm_ext文件夹为哈工大的预训练模型文件。
model_train.py为模型训练的代码,主要功能是完成时间序列标注模型的训练,完整的代码如下:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 import kashgarifrom kashgari import utilsfrom kashgari.corpus import DataReaderfrom kashgari.embeddings import BERTEmbeddingfrom kashgari.tasks.labeling import BiLSTM_CRF_Model train_x, train_y = DataReader().read_conll_format_file('./data/time.train' ) valid_x, valid_y = DataReader().read_conll_format_file('./data/time.dev' ) test_x, test_y = DataReader().read_conll_format_file('./data/time.test' ) bert_embedding = BERTEmbedding('chinese_wwm_ext_L-12_H-768_A-12' , task=kashgari.LABELING, sequence_length=128 ) model = BiLSTM_CRF_Model(bert_embedding) model.fit(train_x, train_y, valid_x, valid_y, batch_size=16 , epochs=1 ) utils.convert_to_saved_model(model, model_path='saved_model/time_entity' , version=1 )
运行该代码,模型训练完后会生成saved_model文件夹,里面含有模型训练好后的文件,方便我们利用tensorflow/serving进行部署。接着我们利用tensorflow/serving来完成模型的部署,命令如下:
1 docker run -t --rm -p 8501:8501 -v "/Users/jclian/PycharmProjects/kashgari_tf_serving/saved_model:/models/" -e MODEL_NAME =time_entity tensorflow/serving
其中需要注意该模型所在的路径,路径需要写完整路径,以及模型的名称(MODEL_NAME),这在训练代码(train.py)中已经给出(saved_model/time_entity)。
接着我们使用tornado来搭建HTTP服务,帮助我们方便地进行模型预测,runServer.py的完整代码如下:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 import requestsfrom kashgari import utilsimport numpy as npfrom model_predict import get_predictimport jsonimport tornado.httpserverimport tornado.ioloopimport tornado.optionsimport tornado.webfrom tornado.options import define, optionsimport tracebackimport tornado.webimport tornado.genimport tornado.concurrentfrom concurrent.futures import ThreadPoolExecutor define("port" , default=16016 , help ="run on the given port" , type =int )class ModelPredictHandler (tornado.web.RequestHandler): executor = ThreadPoolExecutor(max_workers=5 ) @tornado.gen.coroutine def get (self ): origin_text = self.get_argument('text' ) result = yield self.function(origin_text) self.write(json.dumps(result, ensure_ascii=False )) @tornado.concurrent.run_on_executor def function (self, text ): try : text = text.replace(' ' , '' ) x = [_ for _ in text] processor = utils.load_processor(model_path='saved_model/time_entity/1' ) tensor = processor.process_x_dataset([x]) tensor = [{ "Input-Token:0" : i.tolist(), "Input-Segment:0" : np.zeros(i.shape).tolist() } for i in tensor] r = requests.post("http://localhost:8501/v1/models/time_entity:predict" , json={"instances" : tensor}) preds = r.json()['predictions' ] labels = processor.reverse_numerize_label_sequences(np.array(preds).argmax(-1 )) entities = get_predict('TIME' , text, labels[0 ]) return entities except Exception: self.write(traceback.format_exc().replace('\n' , '<br>' ))class HelloHandler (tornado.web.RequestHandler): def get (self ): self.write('Hello from lmj from Daxing Beijing!' )def main (): tornado.options.parse_command_line() app = tornado.web.Application( handlers=[(r'/model_predict' , ModelPredictHandler), (r'/hello' , HelloHandler), ], ) http_server = tornado.httpserver.HTTPServer(app) http_server.listen(options.port) tornado.ioloop.IOLoop.instance().start() main()
我们定义了tornado封装HTTP服务来进行模型预测,运行该脚本,启动模型预测的HTTP服务。接着我们再使用Python脚本才测试下模型的预测效果以及预测时间,预测的代码脚本的完整代码如下:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 import timeimport jsonimport requests t1 = time.time() texts = ['记者从国家发展改革委、商务部相关方面获悉,日前美方已决定对拟于10月1日实施的中国输美商品加征关税措施做出调整,中方支持相关企业从即日起按照市场化原则和WTO规则,自美采购一定数量大豆、猪肉等农产品,国务院关税税则委员会将对上述采购予以加征关税排除。' , '据印度Zee新闻网站12日报道,亚洲新闻国际通讯社援引印度军方消息人士的话说,9月11日的对峙事件发生在靠近班公错北岸的实际控制线一带。' , '儋州市决定,从9月开始,对城市低保、农村低保、特困供养人员、优抚对象、领取失业保险金人员、建档立卡未脱贫人口等低收入群体共3万多人,发放猪肉价格补贴,每人每月发放不低于100元补贴,以后发放标准,将根据猪肉价波动情况进行动态调整。' , '9月11日,华为心声社区发布美国经济学家托马斯.弗里德曼在《纽约时报》上的专栏内容,弗里德曼透露,在与华为创始人任正非最近一次采访中,任正非表示华为愿意与美国司法部展开话题不设限的讨论。' , '造血干细胞移植治疗白血病技术已日益成熟,然而,通过该方法同时治愈艾滋病目前还是一道全球尚在攻克的难题。' , '英国航空事故调查局(AAIB)近日披露,今年2月6日一趟由德国法兰克福飞往墨西哥坎昆的航班上,因飞行员打翻咖啡使操作面板冒烟,导致飞机折返迫降爱尔兰。' , '当地时间周四(9月12日),印度尼西亚财政部长英卓华(Sri Mulyani Indrawati)明确表示:特朗普的推特是风险之一。' , '华中科技大学9月12日通过其官方网站发布通报称,9月2日,我校一硕士研究生不幸坠楼身亡。' , '微博用户@ooooviki 9月12日下午公布发生在自己身上的惊悚遭遇:一个自称网警、名叫郑洋的人利用职务之便,查到她的完备的个人信息,包括但不限于身份证号、家庭地址、电话号码、户籍变动情况等,要求她做他女朋友。' , '今天,贵阳取消了汽车限购,成为目前全国实行限购政策的9个省市中,首个取消限购的城市。' , '据悉,与全球同步,中国区此次将于9月13日于iPhone官方渠道和京东正式开启预售,京东成Apple中国区唯一官方授权预售渠道。' , '根据央行公布的数据,截至2019年6月末,存款类金融机构住户部门短期消费贷款规模为9.11万亿元,2019年上半年该项净增3293.19亿元,上半年增量看起来并不乐观。' , '9月11日,一段拍摄浙江万里学院学生食堂的视频走红网络,视频显示该学校食堂不仅在用餐区域设置了可以看电影、比赛的大屏幕,还推出了“一人食”餐位。' , '当日,在北京举行的2019年国际篮联篮球世界杯半决赛中,西班牙队对阵澳大利亚队。' , ]print (len (texts))for text in texts: url = 'http://localhost:16016/model_predict?text=%s' % text req = requests.get(url) print (json.loads(req.content)) t2 = time.time()print (round (t2-t1, 4 ))
运行该代码,输出的结果如下:(预测文本中的时间)
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 一共预测14 个句子。 ['日前' , '10月1日' , '即日' ] ['12日' , '9月11日' ] ['9月' ] ['9月11日' ] [] ['近日' , '今年2月6日' ] ['当地时间周四(9月12日)' ] ['9月12日' , '9月2日' ] ['9月12日下午' ] ['今天' , '目前' ] ['9月13日' ] ['2019年6月末' , '2019年上半年' , '上半年' ] ['9月11日' ] ['当日' , '2019年' ] 预测耗时: 15.1085 s.
模型预测的效果还是不错的,但平均每句话的预测时间为1秒多,模型预测时间还是稍微偏长,后续笔者将会研究如何缩短模型预测的时间。
总结
本项目主要是介绍了如何利用tensorflow-serving部署kashgari模型,该项目已经上传至github,地址为:https://github.com/percent4/tensorflow-serving_4_kashgari
。
至于如何缩短模型预测的时间,笔者还需要再继续研究,欢迎大家关注~
欢迎关注我的公众号
NLP奇幻之旅 ,原创技术文章第一时间推送。
欢迎关注我的知识星球“自然语言处理奇幻之旅 ”,笔者正在努力构建自己的技术社区。