NLP(四十五)R-BERT在人物关系分类上的尝试及Keras代码复现

本文将介绍关系分类模型R-BERT和该模型在人物关系数据集上的表现,以及该模型的Keras代码复现。

关系分类任务

关系分类属于NLP任务中的文本分类,不同之处在于,关系分类提供了文本和实体。比如下面的例子:

亲戚 1837年6月20日,威廉四世辞世,他的侄女维多利亚即位。

其中两个实体在文本中用包围着,人物关系为亲戚。

在关系分类中,我们要注重文本特征,更要留意实体特征。常见的英文关系分类的数据集为SemEval 2010 Task 8、New York Times Corpus、WikiData dataset for Sentential Relation Extraction、NYT29、NYT24等,中文的关系分类数据集比较少,而且质量不高。

关于SemEval 2010 Task 8数据集的实现模型及效果,可以参考:http://nlpprogress.com/english/relationship_extraction.html, 其中常见的实现模型如下:

  • Machince Learning: SVM, Word2Vec ...

  • Dependency Models: BRCNN, DRNN ...

  • CNN-based Models: Multi-Attention CNN, Attention CNN, PCNN+ATT ...

  • BERT-based Models: R-BERT, Matching-the-Blanks ...

    本文将介绍R-BERT模型。

模型介绍

R-BERT模型是Alibaba Group (U.S.) Inc的两位研究者在2019年5月的论文Enriching Pre-trained Language Model with Entity Information for Relation Classification,该模型在SemEval 2010 Task 8数据集上的F1值为89.25%,只比现有的SOTA模型低了0.25%。

R-BERT很好地融合了文本特征以及两个实体在文本中的特征,简单来说,该模型主要是BERT模型中的三个向量的融合:

  • [CLS]对应的向量

  • 实体1的平均向量

  • 实体2的平均向量

    下面将详细讲解R-BERT的具体模型结构。

模型结构

R-BERT的具体模型结构如下图:

R-BERT模型结构图

一图胜千言。从上述的模型结构图中,我们将模型结构分解步骤如下:

  1. 将文本接入BERT模型,获取[CLS] token的对应向量、实体1的在BERT输出层中的平均向量、实体2的在BERT输出层中的平均向量;
  2. 将上述三个向量分别接Drouput层、Tanh激活层以及全连接层;
  3. 再将步骤2输出的三个向量进行拼接(concatenate);
  4. 最后接Dropout层和全连接层,用Softmax作为多分类的激活函数。

此外,需要注意的是,输入文本中没有[SEP]这个token。

论文中并没有给出更多的实现细节,需要深入到代码中去查看。网上已经有人给出了Torch框架的实现R-BERT的代码,参考网址为:https://github.com/monologg/R-BERT。

Torch实现

Torch框架的实现R-BERT的代码(模型部分)如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import torch
import torch.nn as nn
from transformers import BertModel, BertPreTrainedModel


class FCLayer(nn.Module):
def __init__(self, input_dim, output_dim, dropout_rate=0.0, use_activation=True):
super(FCLayer, self).__init__()
self.use_activation = use_activation
self.dropout = nn.Dropout(dropout_rate)
self.linear = nn.Linear(input_dim, output_dim)
self.tanh = nn.Tanh()

def forward(self, x):
x = self.dropout(x)
if self.use_activation:
x = self.tanh(x)
return self.linear(x)


class RBERT(BertPreTrainedModel):
def __init__(self, config, args):
super(RBERT, self).__init__(config)
self.bert = BertModel(config=config) # Load pretrained bert

self.num_labels = config.num_labels

self.cls_fc_layer = FCLayer(config.hidden_size, config.hidden_size, args.dropout_rate)
self.entity_fc_layer = FCLayer(config.hidden_size, config.hidden_size, args.dropout_rate)
self.label_classifier = FCLayer(
config.hidden_size * 3,
config.num_labels,
args.dropout_rate,
use_activation=False,
)

@staticmethod
def entity_average(hidden_output, e_mask):
"""
Average the entity hidden state vectors (H_i ~ H_j)
:param hidden_output: [batch_size, j-i+1, dim]
:param e_mask: [batch_size, max_seq_len]
e.g. e_mask[0] == [0, 0, 0, 1, 1, 1, 0, 0, ... 0]
:return: [batch_size, dim]
"""
e_mask_unsqueeze = e_mask.unsqueeze(1) # [b, 1, j-i+1]
length_tensor = (e_mask != 0).sum(dim=1).unsqueeze(1) # [batch_size, 1]

# [b, 1, j-i+1] * [b, j-i+1, dim] = [b, 1, dim] -> [b, dim]
sum_vector = torch.bmm(e_mask_unsqueeze.float(), hidden_output).squeeze(1)
avg_vector = sum_vector.float() / length_tensor.float() # broadcasting
return avg_vector

def forward(self, input_ids, attention_mask, token_type_ids, labels, e1_mask, e2_mask):
outputs = self.bert(
input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids
) # sequence_output, pooled_output, (hidden_states), (attentions)
sequence_output = outputs[0]
pooled_output = outputs[1] # [CLS]

# Average
e1_h = self.entity_average(sequence_output, e1_mask)
e2_h = self.entity_average(sequence_output, e2_mask)

# Dropout -> tanh -> fc_layer (Share FC layer for e1 and e2)
pooled_output = self.cls_fc_layer(pooled_output)
e1_h = self.entity_fc_layer(e1_h)
e2_h = self.entity_fc_layer(e2_h)

# Concat -> fc_layer
concat_h = torch.cat([pooled_output, e1_h, e2_h], dim=-1)
logits = self.label_classifier(concat_h)

outputs = (logits,) + outputs[2:] # add hidden states and attention if they are here

# Softmax
if labels is not None:
if self.num_labels == 1:
loss_fct = nn.MSELoss()
loss = loss_fct(logits.view(-1), labels.view(-1))
else:
loss_fct = nn.CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

outputs = (loss,) + outputs

return outputs

该项目是在SemEval 2010 Task 8数据集实现的,笔者将其在自己的人物关系分类数据集上进行测试,最终在测试集上的评估结果如下:

1
2
# Model: chinese-roberta-wwm-ext, weighted avgage F1 = 85.35%
# Model: chinese-roberta-wwm-ext-large, weighted avgage F1 = 87.22%

Model: chinese-roberta-wwm-ext-large, 详细的评估结果如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
                precision    recall  f1-score   support

unknown 0.8756 0.8421 0.8585 209
上下级 0.7297 0.8710 0.7941 31
亲戚 0.8421 0.6667 0.7442 24
兄弟姐妹 0.8333 0.8824 0.8571 34
合作 0.9074 0.8305 0.8673 59
同人 0.9744 0.9744 0.9744 39
同学 0.9130 0.8750 0.8936 24
同门 0.9630 1.0000 0.9811 26
夫妻 0.8372 0.9114 0.8727 79
好友 0.8438 0.9000 0.8710 30
师生 0.8378 0.8378 0.8378 37
情侣 0.8125 0.8387 0.8254 31
父母 0.8931 0.9141 0.9035 128
祖孙 0.9545 0.8400 0.8936 25

accuracy 0.8724 776
macro avg 0.8727 0.8703 0.8696 776
weighted avg 0.8743 0.8724 0.8722 776

R-BERT模型在人物关系数据集上的Github项目为R-BERT_for_people_relation_extraction 。下面将介绍R-BERT模型的Keras框架复现。

Keras复现

R-BERT模型的Keras框架复现(模型部分)的代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
# -*- coding: utf-8 -*-
# main architecture of R-BERT
from keras.models import Model
from keras.utils import plot_model
from keras.layers import Input, Lambda, Dense, Dropout, concatenate, Dot
from keras_bert import load_trained_model_from_checkpoint


# model structure of R-BERT
class RBERT(object):
def __init__(self, config_path, checkpoint_path, maxlen, num_labels):
self.config_path = config_path
self.checkpoint_path = checkpoint_path
self.maxlen = maxlen
self.num_labels = num_labels

def create_model(self):
# BERT model
bert_model = load_trained_model_from_checkpoint(self.config_path, self.checkpoint_path, seq_len=None)
for layer in bert_model.layers:
layer.trainable = True
x1_in = Input(shape=(self.maxlen,))
x2_in = Input(shape=(self.maxlen,))
bert_layer = bert_model([x1_in, x2_in])

# get three vectors
cls_layer = Lambda(lambda x: x[:, 0])(bert_layer) # 取出[CLS]对应的向量
e1_mask = Input(shape=(self.maxlen,))
e2_mask = Input(shape=(self.maxlen,))
e1_layer = self.entity_average(bert_layer, e1_mask) # 取出实体1对应的向量
e2_layer = self.entity_average(bert_layer, e2_mask) # 取出实体2对应的向量

# dropout -> linear -> concatenate
output_dim = cls_layer.shape[-1].value
cls_fc_layer = self.crate_fc_layer(cls_layer, output_dim, dropout_rate=0.1)
e1_fc_layer = self.crate_fc_layer(e1_layer, output_dim, dropout_rate=0.1)
e2_fc_layer = self.crate_fc_layer(e2_layer, output_dim, dropout_rate=0.1)
concat_layer = concatenate([cls_fc_layer, e1_fc_layer, e2_fc_layer], axis=-1)

# FC layer for classification
output = Dense(self.num_labels, activation="softmax")(concat_layer)
model = Model([x1_in, x2_in, e1_mask, e2_mask], output)
model.summary()
return model

@staticmethod
def crate_fc_layer(input_layer, output_dim, dropout_rate=0.0, activation_func="tanh"):
dropout_layer = Dropout(rate=dropout_rate)(input_layer)
linear_layer = Dense(output_dim, activation=activation_func)(dropout_layer)
return linear_layer

@staticmethod
def entity_average(hidden_output, e_mask):
"""
Average the entity hidden state vectors (H_i ~ H_j)
:param hidden_output: BERT hidden output
:param e_mask:
e.g. e_mask[0] == [0, 0, 0, 1, 1, 1, 0, 0, ... 0]/num_of_ones
:return: entity average layer
"""
avg_layer = Dot(axes=1)([e_mask, hidden_output])
return avg_layer

总结

R-BERT模型再次见证了BERT等预训练模型的强大。该模型的实现思路比较简单,也取得了很不错的效果,是关系分类任务的一大突破。

当然对笔者来说,也有种重要的意义:第一次自己复现了论文代码,虽然有Torch代码可以参考。

本文分享到此结束,感谢阅读~

2021年4月1日于上海杨浦,此日大雾迷城~

参考文献

  • NLP-progress Relation Extraction: http://nlpprogress.com/english/relationship_extraction.html
  • Huggingface Transformers: https://github.com/huggingface/transformers
  • https://github.com/wang-h/bert-relation-classification
  • R-BERT: https://github.com/monologg/R-BERT
  • Enriching Pre-trained Language Model with Entity Information for Relation Classification: https://arxiv.org/pdf/1905.08284.pdf
  • Chinese-BERT-wwm: https://github.com/ymcui/Chinese-BERT-wwm
欢迎关注我的公众号NLP奇幻之旅,原创技术文章第一时间推送。

欢迎关注我的知识星球“自然语言处理奇幻之旅”,笔者正在努力构建自己的技术社区。


NLP(四十五)R-BERT在人物关系分类上的尝试及Keras代码复现
https://percent4.github.io/NLP(四十五)R-BERT在人物关系分类上的尝试及Keras代码复现/
作者
Jclian91
发布于
2023年7月10日
许可协议