在文章NLP(二十)利用BERT实现文本二分类 中,笔者介绍了如何使用BERT来实现文本二分类功能,以判别是否属于出访类事件为例子。但是呢,利用BERT在做模型预测的时候存在预测时间较长的问题。因此,我们考虑用新出来的预训练模型来加快模型预测速度。
本文将介绍如何利用ALBERT来实现文本二分类。
关于ALBERT
ALBERT的提出时间大约是在2019年10月,其第一作者为谷歌科学家蓝振忠博士。ALBERT的论文地址为:https://openreview.net/pdf?id=H1eA7AEtvS
, Github项目地址为: https://github.com/brightmart/albert_zh
。
简单说来,ALBERT是BERT的一个精简版,它在BERT模型的基础上进行改造,减少了大量参数,使得其在模型训练和模型预测的速度上有很大提升,而模型的效果只会有微小幅度的下降,具体的效果和速度方面的说明可以参考Github项目。
ALBERT相对于BERT的改进如下:
项目说明
本项目的数据和代码主要参考笔者的文章NLP(二十)利用BERT实现文本二分类 ,该项目是想判别输入的句子是否属于政治上的出访类事件。笔者一共收集了340条数据,其中280条用作训练集,60条用作测试集。
项目结构如下图:
项目结构
在这里我们使用ALBERT已经训练好的文件albert_tiny
,借鉴BERT的调用方法,我们在这里给出albert_zh
模块,能够让ALBERT提取文本的特征,具体代码不在这里给出,有兴趣的读者可以访问该项目的Github地址:https://github.com/percent4/ALBERT_text_classification
。
注意,albert_tiny
给出的向量维度为312,我们的模型训练代码(model_train.py)如下:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 import osimport numpy as npfrom load_data import train_df, test_dffrom keras.utils import to_categoricalfrom keras.models import Modelfrom keras.optimizers import Adamfrom keras.layers import Input, BatchNormalization, Denseimport matplotlib.pyplot as pltfrom albert_zh.extract_feature import BertVector bert_model = BertVector(pooling_strategy="REDUCE_MEAN" , max_seq_len=100 )print ('begin encoding' ) f = lambda text: bert_model.encode([text])["encodes" ][0 ] train_df['x' ] = train_df['text' ].apply(f) test_df['x' ] = test_df['text' ].apply(f)print ('end encoding' ) x_train = np.array([vec for vec in train_df['x' ]]) x_test = np.array([vec for vec in test_df['x' ]]) y_train = np.array([vec for vec in train_df['label' ]]) y_test = np.array([vec for vec in test_df['label' ]])print ('x_train: ' , x_train.shape) num_classes = 2 y_train = to_categorical(y_train, num_classes) y_test = to_categorical(y_test, num_classes) x_in = Input(shape=(312 , )) x_out = Dense(32 , activation="relu" )(x_in) x_out = BatchNormalization()(x_out) x_out = Dense(num_classes, activation="softmax" )(x_out) model = Model(inputs=x_in, outputs=x_out)print (model.summary()) model.compile (loss='categorical_crossentropy' , optimizer=Adam(), metrics=['accuracy' ]) history = model.fit(x_train, y_train, validation_data=(x_test, y_test), batch_size=8 , epochs=20 ) model.save('visit_classify.h5' )print (model.evaluate(x_test, y_test)) plt.subplot(2 , 1 , 1 ) epochs = len (history.history['loss' ]) plt.plot(range (epochs), history.history['loss' ], label='loss' ) plt.plot(range (epochs), history.history['val_loss' ], label='val_loss' ) plt.legend() plt.subplot(2 , 1 , 2 ) epochs = len (history.history['acc' ]) plt.plot(range (epochs), history.history['acc' ], label='acc' ) plt.plot(range (epochs), history.history['val_acc' ], label='val_acc' ) plt.legend() plt.savefig("loss_acc.png" )
模型训练的效果很不错,在训练集的acc为0.9857,在测试集上的acc为0.9500,具体如下:
训练过程中的loss和acc图
与BERT的预测对比
接下来我们在模型预测上的时间,与BERT的文本二分类模型预测时间做一个对比,这样有助于提升我们对ALBERT的印象。
BERT的文本二分类模型预测可以参考文章NLP(二十)利用BERT实现文本二分类 ,本文给出的代码与BERT实现的模型预测代码基本一致,只不过BERT提取特征改成ALBERT提取特征。
本文的模型预测代码(model_predict.py)如下:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 import timeimport pandas as pdimport numpy as npfrom albert_zh.extract_feature import BertVectorfrom keras.models import load_model load_model = load_model("visit_classify.h5" ) texts = ['在访问限制中,用户可以选择禁用iPhone的功能,包括Siri、iTunes购买功能、安装/删除应用等,甚至还可以让iPhone变成一台功能手机。以下是访问限制具体可以实现的一些功能' , 'IT之家4月23日消息 近日,谷歌在其官方论坛发布消息表示,他们为Android Auto添加了一项新功能:可以访问完整联系人列表。用户现在可以通过在Auto的电话拨号界面中打开左上角的菜单访问完整的联系人列表。值得注意的是,这一功能仅支持在车辆停止时使用。' , '要通过telnet 访问路由器,需要先通过console 口对路由器进行基本配置,例如:IP地址、密码等。' , 'IT之家3月26日消息 近日反盗版的国际咨询公司MUSO发布了2017年的年度报告,其中的数据显示,去年盗版资源网站访问量达到了3000亿次,比前一年(2016年)提高了1.6%。美国是访问盗版站点次数最多的国家,共有279亿次访问;其后分别是俄罗斯、印度和巴西,中国位列第18。' , '目前A站已经恢复了访问,可以直接登录,网页加载正常,视频已经可以正常播放。' , 'Win7电脑提示无线适配器或访问点有问题怎么办?很多用户在使用无线网连接上网时,发现无线网显示已连接,但旁边却出现了一个黄色感叹号,无法进行网络操作,通过诊断提示电脑无线适配器或访问点有问题,且处于未修复状态,这该怎么办呢?下面小编就和大家分享下Win7电脑提示无线适配器或访问点有问题的解决方法。' , '未开发所有安全组之前访问,FTP可以链接上,但是打开会很慢,需要1-2分钟才能链接上' , 'win7系统电脑的用户,在连接WIFI网络网上时,有时候会遇到突然上不了网,查看连接的WIFI出现“有限的访问权限”的文字提示。' , '2月28日,唐山曹妃甸蓝色海洋科技有限公司董事长赵力军等一行5人到黄海水产研究所交流访问。黄海水产研究所副所长辛福言及相关部门负责人、专家等参加了会议。' , '与标准Mozy一样,Stash文件夹为用户提供了对其备份文件的基于云的访问,但是它们还使他们可以随时,跨多个设备(包括所有计算机,智能手机和平板电脑)访问它们。换句话说,使用浏览器的任何人都可以同时查看文件(如果需要)。操作系统和设备品牌无关。' , '研究表明,每个网页的平均预期寿命为44至100天。当用户通过浏览器访问已消失的网页时,就会看到「Page Not Found」的错误信息。对于这种情况,相信大多数人也只能不了了之。不过有责任心的组织——互联网档案馆为了提供更可靠的Web服务,它联手Brave浏览器专门针对此类网页提供了一键加载存档页面的功能。' , '3日,根据三星电子的消息,李在镕副会长这天访问了位于韩国庆尚北道龟尾市的三星电子工厂。' ] * 10 labels = [] bert_model = BertVector(pooling_strategy="REDUCE_MEAN" , max_seq_len=100 ) init_time = time.time()for text in texts: vec = bert_model.encode([text])["encodes" ][0 ] x_train = np.array([vec]) predicted = load_model.predict(x_train) y = np.argmax(predicted[0 ]) label = 'Y' if y else 'N' labels.append(label) cost_time = time.time() - init_timeprint ("Average cost time: %s." % (cost_time/len (texts)))for text, label in zip (texts, labels): print ('%s\t%s' % (label, text)) df = pd.DataFrame({'句子' :texts, "是否属于出访类事件" : labels}) df.to_excel('./result.xlsx' , index=False )
输出的平均预测时长为:16.98ms
,而BERT版的平均预测时间为:257.31ms
。
我们将模型预测写成HTTP服务,代码(server.py)如下:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 import tornado.httpserverimport tornado.ioloopimport tornado.optionsimport tornado.webfrom tornado.options import define, optionsimport jsonimport numpy as npfrom albert_zh.extract_feature import BertVectorfrom keras.models import load_model define("port" , default=10008 , help ="run on the given port" , type =int ) bert_model = BertVector(pooling_strategy="REDUCE_MEAN" , max_seq_len=100 ) load_model = load_model("visit_classify.h5" )class PredictHandler (tornado.web.RequestHandler): def post (self ): text = self.get_argument("text" ) vec = bert_model.encode([text])["encodes" ][0 ] x_train = np.array([vec]) predicted = load_model.predict(x_train) y = np.argmax(predicted[0 ]) label = '是' if y else "否" result = {"原文" : text, "是否属于出访类事件?" : label} self.write(json.dumps(result, ensure_ascii=False , indent=2 ))def main (): tornado.options.parse_command_line() app = tornado.web.Application( handlers=[(r'/predict' , PredictHandler)] ) http_server = tornado.httpserver.HTTPServer(app) http_server.listen(options.port) tornado.ioloop.IOLoop.instance().start() main()
用Postman进行测试,如下图:
实践证明,用ALBERT做文本特征提取,模型训练的效果基本与BERT差别微小,模型训练速度明显提升,更重要的是,模型预测的速度只有BERT版本的6.6%(不同情况下可能有略微差异),这在生产上是十分有帮助的。
参考网址
中文预训练ALBERT模型来了:小模型登顶GLUE,Base版模型小10倍速度快1倍:
https://zhuanlan.zhihu.com/p/85037097
ALBERT一作蓝振忠:预训练模型应用已成熟,ChineseGLUE要对标GLUE基准:https://tech.sina.com.cn/roll/2019-11-17/doc-iihnzhfy9804802.shtml
。
解读ALBERT:https://blog.csdn.net/weixin_37947156/article/details/101529943
。
ALBERT的Github项目地址:https://github.com/brightmart/albert_zh
。
欢迎关注我的公众号
NLP奇幻之旅 ,原创技术文章第一时间推送。
欢迎关注我的知识星球“自然语言处理奇幻之旅 ”,笔者正在努力构建自己的技术社区。