NLP (98) Multi-modal Question Answering over PDF Documents

This article introduces multi-modal question answering over PDF documents, covering the tables, images, and text contained in a PDF.

In earlier posts of this PDF series, "NLP(八十九)PDF文档智能问答入门" and "NLP(九十一)PDF表格问答", I showed how to use large language models to answer questions over, respectively, the plain text of a PDF and its tables (saved as images). A PDF file, however, usually mixes several modalities of data (text, images, and tables), which makes multi-modal question answering especially important.

This article collects some of my thoughts on multi-modal question answering. It focuses on multi-modal QA over text-based (non-scanned) PDF documents, and I hope it gives readers some inspiration.

The article walks through the following steps (an end-to-end sketch of how they chain together follows this list):

  • Extract the tables and figures from the PDF and save them all as images
  • Build the mapping between each table/figure and its caption
  • Compute vector embeddings for the text, tables, and figures
  • Given a user question, retrieve the most relevant text chunks, tables, and figures
  • Answer with a multi-modal large model (e.g., GPT-4V)
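
To make the pipeline concrete, below is a minimal end-to-end sketch of how these steps chain together, using the classes defined later in this post (LayoutAnalysis, TableFigureMatch, ImageEmbedding, TextEmbedding, ContentRetrieval, MultiModelQA). It assumes the Milvus collections have already been created with vector_db_schema.py and that the embedding service introduced below is running; treat it as an orientation aid rather than a drop-in script.

# pipeline_sketch.py: chain the steps of this post together (file name is my own)
from pymilvus import MilvusClient

from layout_analysis import LayoutAnalysis
from get_table_figure_caption import TableFigureMatch
from get_mm_embedding import ImageEmbedding, TextEmbedding
from content_retrieval import ContentRetrieval
from mm_doc_qa import MultiModelQA

pdf_path = "../data/LLaMA.pdf"

# 1. layout analysis: save the tables and figures in the PDF as images
LayoutAnalysis(pdf_file_path=pdf_path).run()
# 2. match each table/figure image with its caption
TableFigureMatch(pdf_file_path=pdf_path).run()

# 3. embed text chunks, tables and figures and insert them into Milvus
client = MilvusClient(uri="http://localhost:19530", db_name="default")
ImageEmbedding(pdf_file_path=pdf_path, milvus_client=client).run()
TextEmbedding(pdf_file_path=pdf_path, milvus_client=client).run()

# 4. retrieve the text chunks, tables and figures most relevant to a query
query = "What is LLaMA-7B's zero-shot accuracy on RACE dataset?"
image_result, text_result = ContentRetrieval(query=query, milvus_client=client).run()

# 5. answer with a multi-modal model (GPT-4V)
answer = MultiModelQA(
    query=query,
    text_chunks=[_["text"] for _ in text_result],
    images=[_["image_path"] for _ in image_result],
    captions=[_["text"] for _ in image_result],
).run()
print(answer)
client.close()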

Preprocessing: Extracting Tables and Figures from the PDF

First, we use PaddleOCR, the layout-analysis toolkit open-sourced by Baidu, to locate the bounding boxes of the tables and figures in the PDF and save those regions as images. The Python implementation is as follows:

# -*- coding: utf-8 -*-
# @place: Pudong, Shanghai
# @file: layout_analysis.py
# @time: 2024/4/3 10:52
# use PP-Structure V2 to get layout analysis for PDF
import os
import re
import json
import cv2
import fitz
from PIL import Image
from paddleocr import PPStructure, save_structure_res


class LayoutAnalysis(object):
    def __init__(self, pdf_file_path, save_folder="../output"):
        self.pdf_file_path = pdf_file_path
        self.save_folder = save_folder
        self.file_name = os.path.basename(self.pdf_file_path).split('.')[0]
        self.save_dir = os.path.join(self.save_folder, f"{self.file_name}")
        if not os.path.exists(self.save_dir):
            os.makedirs(self.save_dir)

    def _convert_to_img(self):
        pdf_document = fitz.open(self.pdf_file_path)
        image_path_list = []
        # Iterate through each page and convert to an image
        for page_number in range(pdf_document.page_count):
            # Get the page
            page = pdf_document[page_number]
            # Convert the page to an image
            pix = page.get_pixmap()
            # Create a Pillow Image object from the pixmap
            image = Image.frombytes("RGB", (pix.width, pix.height), pix.samples)
            # Save the image
            image_path = os.path.join(self.save_dir, f"{self.file_name}_{page_number + 1}.jpg")
            image_path_list.append(image_path)
            image.save(image_path, "JPEG", quality=95)
        # Close the PDF file
        pdf_document.close()
        return image_path_list

    def image_parse(self):
        image_path_list = self._convert_to_img()
        table_engine = PPStructure(table=False, ocr=False, show_log=True)
        for img_idx, img_path in enumerate(image_path_list):
            print(f"Layout analysis for {img_idx+1} image with path: {img_path}")
            img = cv2.imread(img_path)
            result = table_engine(img)
            save_structure_res(result, self.save_folder, self.file_name, img_idx=img_idx+1)
            for line in result:
                line.pop('img')
                print(line)

    # parse the tables on each PDF page and save them as images
    @staticmethod
    def parse_table(pdf_page_image, pdf_res_txt):
        dir_name = os.path.dirname(pdf_page_image)
        page_number = re.findall(r"\d+", pdf_page_image)[-1]
        with open(pdf_res_txt, 'r') as f:
            content = [json.loads(_.strip()) for _ in f.readlines()]

        table_cnt = 1
        for line in content:
            rect_type = line["type"]
            region = line["bbox"]
            # save the table region as an image
            if rect_type == "table":
                with Image.open(pdf_page_image).convert('RGB') as image:
                    region_img = image.crop(region)
                    save_image_path = f"{dir_name}/{page_number}_{table_cnt}_table.jpg"
                    print(f"save table to {save_image_path}")
                    region_img.save(save_image_path, 'JPEG', quality=95)
                table_cnt += 1

    # parse the folder holding the layout-analysis results of the PDF
    def tables_2_images(self):
        for file in os.listdir(self.save_dir):
            if file.startswith(self.file_name):
                res_txt = file.replace(self.file_name, "res").replace("jpg", "txt")
                pdf_page_image_path = os.path.join(self.save_dir, file)
                pdf_res_txt_path = os.path.join(self.save_dir, res_txt)
                self.parse_table(pdf_page_image=pdf_page_image_path,
                                 pdf_res_txt=pdf_res_txt_path)

    def run(self):
        self.image_parse()
        self.tables_2_images()


if __name__ == '__main__':
    pdf_path = "../data/LLaMA.pdf"
    layout_analyzer = LayoutAnalysis(pdf_file_path=pdf_path)
    layout_analyzer.run()

Taking the LLaMA paper PDF (available at https://arxiv.org/abs/2302.13971) as an example, the resulting table and figure files look like this:

multi-modal-pdf-qa-1.png

Table files follow the naming pattern page-number_table-index-on-that-page_table.jpg, while figure files are prefixed with the coordinates of the figure's top-left and bottom-right corners, followed by the page number.

Let's look at the table file 5_1_table.jpg:

2_1_table.jpg

And here is the figure file [67, 68, 527, 349]_8.jpg:

[67, 68, 527, 349]_8.jpg
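
To make the naming convention explicit, here is a small helper of my own (not part of the original code) that parses an output filename back into its page number and either its table index or its bounding box:

import re


def parse_output_name(file_name: str) -> dict:
    """Parse a table/figure image name produced by the layout-analysis step."""
    if file_name.endswith("_table.jpg"):
        # e.g. "5_1_table.jpg" -> the 1st table on page 5
        page_no, table_idx = map(int, file_name[:-len("_table.jpg")].split("_"))
        return {"type": "table", "page": page_no, "index": table_idx}
    # e.g. "[67, 68, 527, 349]_8.jpg" -> bbox [67, 68, 527, 349] on page 8
    bbox_str, page_no = re.match(r"\[(.+)\]_(\d+)\.jpg", file_name).groups()
    return {"type": "figure", "page": int(page_no), "bbox": [int(x) for x in bbox_str.split(", ")]}


print(parse_output_name("5_1_table.jpg"))             # {'type': 'table', 'page': 5, 'index': 1}
print(parse_output_name("[67, 68, 527, 349]_8.jpg"))  # {'type': 'figure', 'page': 8, 'bbox': [67, 68, 527, 349]}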

Preprocessing: Matching Tables and Figures with Their Captions

Next, for each table or figure image produced above, we find its caption. The method: compute the Euclidean distance between the center of the table/figure bounding box and the centers of the candidate text blocks, and take the text block at the shortest distance as the caption.

Note: this matching heuristic is far from optimal, and there is plenty of room to improve it.

The Python implementation is as follows:

# -*- coding: utf-8 -*-
# @place: Pudong, Shanghai
# @file: get_table_figure_caption.py
# @time: 2024/3/28 10:38
# use fitz to get table and its caption from pdf file
import math
import os
import json
import re
from operator import itemgetter
import fitz


# match each figure or table with its caption
class TableFigureMatch(object):
    def __init__(self, pdf_file_path, save_folder="../output"):
        self.pdf_file_path = pdf_file_path
        self.save_folder = save_folder
        self.file_name = os.path.basename(self.pdf_file_path).split('.')[0]
        self.res_dir = os.path.join(self.save_folder, f"{self.file_name}")

    # get table or figure caption in each PDF page
    def get_caption_by_page(self, data_type):
        """
        :param data_type: str, data type from pdf, enumerate: table or figure
        :return: page_dict_list: list[dict], data type dict in each page in list
        """
        assert data_type in ['table', 'figure']
        doc = fitz.open(self.pdf_file_path)
        page_number = doc.page_count
        page_dict_list = []
        for i in range(page_number):
            table_page_dict = {}
            page = doc[i]
            page_dict = page.get_text("dict", sort=True)
            for block in page_dict['blocks']:
                bbox = block['bbox']
                text_list = []
                if 'lines' in block:
                    for spans in block['lines']:
                        for span in spans['spans']:
                            text_list.append(span['text'])
                text = ' '.join(text_list)
                if text.lower().startswith(data_type):
                    table_page_dict['-'.join([str(x) for x in bbox])] = text
            page_dict_list.append(table_page_dict)
        doc.close()
        return page_dict_list

    @staticmethod
    def find_rect_center(rect):
        return (rect[0]+rect[2])/2, (rect[1]+rect[3])/2

    @staticmethod
    def get_euclid_distance(x0, y0, x1, y1):
        return math.sqrt((x1 - x0) ** 2 + (y1 - y0) ** 2)

    def find_neighbor_rect(self, rect1, rect_list):
        if len(rect_list) == 1:
            return rect_list[0]
        center_x_rect1, center_y_rect1 = self.find_rect_center(rect1)
        center_rect_list = []
        for rect in rect_list:
            center_x, center_y = self.find_rect_center(rect)
            center_rect_list.append((center_x, center_y))

        distance_dict = {}
        for i, center in enumerate(center_rect_list):
            distance = self.get_euclid_distance(center_x_rect1, center_y_rect1, center[0], center[1])
            distance_dict[i] = distance
        distance_sort_list = sorted(distance_dict.items(), key=itemgetter(1))
        return rect_list[distance_sort_list[0][0]]

    def match_caption_by_page(self, data_type):
        data_type_caption_dict = {}
        page_dict_list = self.get_caption_by_page(data_type=data_type)
        for file in os.listdir(self.res_dir):
            if file.startswith("res"):
                pg_num = int(re.findall(r'\d+', file)[0])
                res_txt_file_path = os.path.join(self.res_dir, file)
                with open(res_txt_file_path, 'r') as f:
                    content = [json.loads(_.strip()) for _ in f.readlines()]
                cnt = 1
                for line in content:
                    rect_type, pdf_rect_bbox = line['type'], line['bbox']
                    if rect_type == data_type and page_dict_list[pg_num-1]:
                        caption_rect_list = [[float(x) for x in _.split('-')] for _ in page_dict_list[pg_num-1].keys()]
                        neighbor_rect = self.find_neighbor_rect(rect1=pdf_rect_bbox, rect_list=caption_rect_list)
                        if data_type == 'table':
                            file_path = f'{pg_num}_{cnt}_table.jpg'
                            cnt += 1
                        else:
                            file_path = f"[{', '.join([str(_) for _ in pdf_rect_bbox])}]_{pg_num}.jpg"
                        data_type_caption_dict[file_path] = page_dict_list[pg_num-1]['-'.join([str(_) for _ in neighbor_rect])]
        return data_type_caption_dict

    def run(self):
        data_type_dict = {}
        for data_type in ['table', 'figure']:
            data_type_caption_dict = self.match_caption_by_page(data_type=data_type)
            data_type_dict.update(data_type_caption_dict)

        with open(os.path.join(self.res_dir, "table_figure_caption.json"), "w") as f:
            f.write(json.dumps(data_type_dict, ensure_ascii=False, indent=4))


if __name__ == '__main__':
    pdf_path = '../data/LLaMA.pdf'
    table_figure_matcher = TableFigureMatch(pdf_file_path=pdf_path)
    table_figure_matcher.run()

Taking the LLaMA paper PDF as an example, the resulting JSON file mapping tables and figures to their captions is as follows:

{
"3_1_table.jpg": "Table 2: Model sizes, architectures, and optimization hyper-parameters.",
"2_1_table.jpg": "Table 1: Pre-training data. Data mixtures used for pre- training, for each subset we list the sampling propor- tion, number of epochs performed on the subset when training on 1.4T tokens, and disk size. The pre-training runs on 1T tokens have the same sampling proportion.",
"5_1_table.jpg": "Table 6: Reading Comprehension. Zero-shot accu- racy.",
"5_2_table.jpg": "Table 5: TriviaQA. Zero-shot and few-shot exact match performance on the filtered dev set.",
"4_1_table.jpg": "Table 3: Zero-shot performance on Common Sense Reasoning tasks.",
"4_2_table.jpg": "Table 4: NaturalQuestions. Exact match performance.",
"6_1_table.jpg": "Table 8: Model performance for code generation. We report the pass@ score on HumanEval and MBPP. HumanEval generations are done in zero-shot and MBBP with 3-shot prompts similar to Austin et al. ( 2021 ). The values marked with ∗ are read from figures in Chowdhery et al. ( 2022 ).",
"6_2_table.jpg": "Table 7: Model performance on quantitative reason- ing datasets. For majority voting, we use the same setup as Minerva, with k = 256 samples for MATH and k = 100 for GSM8k (Minerva 540B uses k = 64 for MATH and and k = 40 for GSM8k). LLaMA-65B outperforms Minerva 62B on GSM8k, although it has not been fine-tuned on mathematical data.",
"7_1_table.jpg": "Table 9: Massive Multitask Language Understanding (MMLU). Five-shot accuracy.",
"18_1_table.jpg": "Table 16: MMLU. Detailed 5-shot results per domain on the test sets.",
"11_1_table.jpg": "Table 15: Carbon footprint of training different models in the same data center. We follow Wu et al. ( 2022 ) to compute carbon emission of training OPT, BLOOM and our models in the same data center. For the power consumption of a A100-80GB, we take the thermal design power for NVLink systems, that is 400W. We take a PUE of 1.1 and a carbon intensity factor set at the national US average of 0.385 kg CO 2 e per KWh.",
"10_1_table.jpg": "Table 13: WinoGender. Co-reference resolution ac- curacy for the LLaMA models, for different pronouns (“her/her/she” and “his/him/he”). We observe that our models obtain better performance on “their/them/some- one’ pronouns than on “her/her/she” and “his/him/he’, which is likely indicative of biases.",
"10_2_table.jpg": "Table 14: TruthfulQA. We report the fraction of truth- ful and truthful*informative answers, as scored by spe- cially trained models via the OpenAI API. We follow the QA prompt style used in Ouyang et al. ( 2022 ), and report the performance of GPT-3 from the same paper.",
"9_1_table.jpg": "Table 12: CrowS-Pairs. We compare the level of bi- ases contained in LLaMA-65B with OPT-175B and GPT3-175B. Higher score indicates higher bias.",
"8_1_table.jpg": "Table 11: RealToxicityPrompts. We run a greedy de- coder on the 100k prompts from this benchmark. The “respectful” versions are prompts starting with “Com- plete the following sentence in a polite, respectful, and unbiased manner:”, and “Basic” is without it. Scores were obtained using the PerplexityAPI, with higher score indicating more toxic generations.",
"[305, 193, 526, 341]_3.jpg": "Figure 1: Training loss over train tokens for the 7B, 13B, 33B, and 65 models. LLaMA-33B and LLaMA- 65B were trained on 1.4T tokens. The smaller models were trained on 1.0T tokens. All models are trained with a batch size of 4M tokens.",
"[72, 253, 504, 306]_17.jpg": "Figure 3: Formatted dataset example for Natural Questions (left) & TriviaQA (right).",
"[67, 68, 527, 349]_8.jpg": "Figure 2: Evolution of performance on question answering and common sense reasoning during training."
}
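
The embedding and retrieval code further below reads this table_figure_caption.json file directly. As a quick check, you can look up a caption yourself (a trivial snippet of my own; the path assumes the default ../output folder):

import json

with open("../output/LLaMA/table_figure_caption.json", "r") as f:
    caption_dict = json.loads(f.read())

# the caption matched to the figure shown earlier
print(caption_dict["[67, 68, 527, 349]_8.jpg"])
# Figure 2: Evolution of performance on question answering and common sense reasoning during training.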

Embeddings: Text, Tables, and Figures

Now we get into multi-modal RAG proper.

In multi-modal RAG we need vector embeddings for data of every modality, typically text, tables, and images. The preprocessing above already gave us the tables and figures in the PDF, and extracting the text is straightforward; what remains is to compute the embeddings for all of them.

We use Milvus as the vector database. The schemas for text and images (stored in separate collections) are created as follows:

# -*- coding: utf-8 -*-
# @place: Pudong, Shanghai
# @file: vector_db_schema.py
# @time: 2024/3/28 15:02
import time
from pymilvus import MilvusClient, FieldSchema, CollectionSchema, DataType

image_collection_name = "pdf_image_qa"
text_collection_name = "pdf_text_qa"
# Connects to a server
client = MilvusClient(uri="http://localhost:19530", db_name="default")


def create_schema(collect_name, fields, desc):
    schema = CollectionSchema(fields, description=desc)
    index_params = client.prepare_index_params()
    index_params.add_index(
        field_name="embedding",
        index_type="IVF_FLAT",
        metric_type="IP",
        params={"nlist": 128}
    )
    client.create_collection(
        collection_name=collect_name,
        schema=schema,
        index_params=index_params
    )
    time.sleep(3)
    res = client.get_load_state(
        collection_name=collect_name
    )
    print("load state: ", res)


if not client.has_collection(image_collection_name) and not client.has_collection(text_collection_name):
    # Creates an image collection
    images_fields = [
        FieldSchema(name="id", dtype=DataType.INT64, is_primary=True, auto_id=True),
        FieldSchema(name="pdf_path", dtype=DataType.VARCHAR, max_length=100),
        FieldSchema(name="data_type", dtype=DataType.VARCHAR, max_length=20),
        FieldSchema(name="text", dtype=DataType.VARCHAR, max_length=1000),
        FieldSchema(name="image_path", dtype=DataType.VARCHAR, max_length=300),
        FieldSchema(name="embedding", dtype=DataType.FLOAT_VECTOR, dim=1024)
    ]
    image_collection_desc = "image embedding for pdf file"
    create_schema(collect_name=image_collection_name, fields=images_fields, desc=image_collection_desc)
    # Creates a text collection
    text_fields = [
        FieldSchema(name="id", dtype=DataType.INT64, is_primary=True, auto_id=True),
        FieldSchema(name="pdf_path", dtype=DataType.VARCHAR, max_length=100),
        FieldSchema(name="text", dtype=DataType.VARCHAR, max_length=3000),
        FieldSchema(name="page_no", dtype=DataType.INT64),
        FieldSchema(name="embedding", dtype=DataType.FLOAT_VECTOR, dim=1536)
    ]
    text_collection_desc = "text embedding for pdf file"
    create_schema(collect_name=text_collection_name, fields=text_fields, desc=text_collection_desc)

client.close()

In the image schema, the text field stores the corresponding caption.

We use the BAAI/bge-visualized model to embed images, and OpenAI's text-embedding-ada-002 model to embed text chunks.

First, we need to deploy an inference service for the BAAI/bge-visualized model; here is one possible setup:

# -*- coding: utf-8 -*-
# @place: Pudong, Shanghai
# @file: bge_v_embedding_server.py
# @time: 2024/4/2 21:50
import torch
from FlagEmbedding.visual.modeling import Visualized_BGE
import uvicorn
from fastapi import FastAPI
from pydantic import BaseModel
import base64
from PIL import Image
from io import BytesIO

app = FastAPI()


class MultiModal(BaseModel):
    image_base64: str = ""
    text: str = ""


model = Visualized_BGE(model_name_bge="BAAI/bge-m3", model_weight="./models/bge-visualized/Visualized_m3.pth")
model.eval()
print("model loaded!")


@app.get('/')
def home():
    return 'hello world'


@app.post('/mm_embedding')
def get_mm_embedding(multi_modal: MultiModal):
    if multi_modal.image_base64:
        with Image.open(BytesIO(base64.b64decode(multi_modal.image_base64))) as im:
            image_path = 'tmp.png'
            im.save(image_path, 'PNG')
        with torch.no_grad():
            query_emb = model.encode(image=image_path, text=multi_modal.text)
    else:
        with torch.no_grad():
            query_emb = model.encode(text=multi_modal.text)
    print(query_emb)
    return {"embedding": query_emb.tolist()[0]}


if __name__ == '__main__':
    uvicorn.run(app, host='0.0.0.0', port=50074)
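
Once the service is up, a quick sanity check (my own snippet, assuming the server above is listening on localhost:50074 and that the layout-analysis output exists) should return a 1024-dimensional vector, matching the dim=1024 set in the image schema:

import base64
import requests

# any table/figure image produced by the preprocessing step will do
with open("../output/LLaMA/5_1_table.jpg", "rb") as f:
    image_base64 = base64.b64encode(f.read()).decode("utf-8")

resp = requests.post(
    "http://localhost:50074/mm_embedding",
    json={"image_base64": image_base64, "text": "Table 6: Reading Comprehension. Zero-shot accuracy."},
)
print(len(resp.json()["embedding"]))  # expected: 1024 for Visualized_m3 (bge-m3 backbone)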

Next, we compute the embeddings for text, tables, and figures and insert them into Milvus. The Python code is as follows:

# -*- coding: utf-8 -*-
# @place: Pudong, Shanghai
# @file: get_mm_embedding.py
# @time: 2024/4/2 16:03
# get multi-modal embeddings for images and tables, and text embeddings for text chunks
import json
import base64
import os
import fitz
from dotenv import load_dotenv

import requests
from pymilvus import MilvusClient
from langchain.text_splitter import RecursiveCharacterTextSplitter

load_dotenv()


def get_image_base64_str(image_path):
    with open(image_path, "rb") as image_file:
        data = base64.b64encode(image_file.read()).decode('utf-8')
    return data


def get_multi_modal_embedding(image_path=None, text=""):
    if image_path is not None:
        with open(image_path, "rb") as image_file:
            data = base64.b64encode(image_file.read()).decode('utf-8')
    else:
        data = ""

    url = "http://localhost:50074/mm_embedding"
    payload = json.dumps({
        "image_base64": data, "text": text
    })
    headers = {
        'Content-Type': 'application/json'
    }
    response = requests.request("POST", url, headers=headers, data=payload)
    return response.json()['embedding']


def get_text_embedding(text_chunks):
    url = "https://api.openai.com/v1/embeddings"
    payload = json.dumps({
        "model": "text-embedding-ada-002",
        "input": text_chunks,
        "encoding_format": "float"
    })
    headers = {
        'Content-Type': 'application/json',
        'Authorization': f'Bearer {os.getenv("OPENAI_API_KEY")}'
    }
    response = requests.request("POST", url, headers=headers, data=payload)
    embedding = [_["embedding"] for _ in response.json()['data']]
    response.close()
    return embedding


class ImageEmbedding(object):
    def __init__(self, pdf_file_path, milvus_client, save_folder="../output"):
        self.pdf_file_path = pdf_file_path
        self.save_folder = save_folder
        self.file_name = os.path.basename(self.pdf_file_path).split('.')[0]
        self.res_dir = os.path.join(self.save_folder, f"{self.file_name}")
        self.milvus_client = milvus_client

    def run(self):
        with open(os.path.join(self.res_dir, "table_figure_caption.json"), "r") as f:
            table_figure_caption_dict = json.loads(f.read())

        entities = []
        for file in os.listdir(self.res_dir):
            if (file.startswith('[') or 'table' in file) and file.endswith('jpg'):
                file_path = os.path.join(self.res_dir, file)
                caption = table_figure_caption_dict.get(file, "")
                print(f'get embedding for {file_path} with caption: {caption}')
                image_embedding = get_multi_modal_embedding(image_path=file_path, text=caption)
                data_type = "table" if "table" in file else "image"
                entities.append({"pdf_path": self.pdf_file_path,
                                 "data_type": data_type,
                                 "text": caption,
                                 "image_path": file_path,
                                 "embedding": image_embedding})

        # Inserts vectors in the collection
        self.milvus_client.insert(
            collection_name="pdf_image_qa",
            data=entities)


class TextEmbedding(object):
    def __init__(self, pdf_file_path, milvus_client):
        self.pdf_file_path = pdf_file_path
        self.milvus_client = milvus_client

    def get_texts(self):
        text_list = []
        doc = fitz.open(self.pdf_file_path)
        page_number = doc.page_count
        for i in range(page_number):
            page = doc[i]
            text_list.append(page.get_text())
        doc.close()
        return text_list

    @staticmethod
    def get_chunks(text_list):
        text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
            chunk_size=300, chunk_overlap=0, encoding_name="cl100k_base"
        )
        chunks = []
        for page_no, text in enumerate(text_list):
            for chunk in text_splitter.split_text(text):
                chunks.append((chunk, page_no + 1))
        return chunks

    def run(self):
        text_list = self.get_texts()
        chunks = self.get_chunks(text_list=text_list)
        batch_size = 10
        start_no = 0
        chunk_embeddings = []
        while start_no < len(chunks):
            print(f"start no: {start_no}")
            batch_chunk_embeddings = get_text_embedding(
                text_chunks=[_[0] for _ in chunks[start_no:start_no + batch_size]])
            chunk_embeddings.extend(batch_chunk_embeddings)
            start_no += batch_size
        entities = [{"pdf_path": self.pdf_file_path,
                     "text": chunks[i][0],
                     "page_no": chunks[i][1],
                     "embedding": chunk_embeddings[i]} for i in range(len(chunks))]
        self.milvus_client.insert(collection_name="pdf_text_qa", data=entities)


if __name__ == '__main__':
    pdf_path = '../data/LLaMA.pdf'
    client = MilvusClient(uri="http://localhost:19530", db_name="default")
    my_image_embedding = ImageEmbedding(
        pdf_file_path=pdf_path, milvus_client=client)
    my_image_embedding.run()
    text_embedding = TextEmbedding(
        pdf_file_path=pdf_path,
        milvus_client=client)
    text_embedding.run()
    client.close()

Retrieval: Finding the Text, Tables, and Figures Most Similar to the Query

Given a user question (query), we use vector similarity to retrieve the closest text chunks, tables, and figures. For images there are two retrieval paths: one uses the multi-modal embedding model directly; the other is rule-based and collects the images on the pages where the retrieved text chunks live, visiting those pages in order of the chunks' similarity scores. Finally, the images from the two paths are de-duplicated and capped at a fixed limit, e.g. 10.

The Python code for the retrieval stage is as follows:

# -*- coding: utf-8 -*-
# @place: Pudong, Shanghai
# @file: content_retrieval.py
# @time: 2024/4/2 16:42
import os
import json

from pymilvus import MilvusClient

from get_mm_embedding import get_multi_modal_embedding
from get_mm_embedding import get_text_embedding


class ContentRetrieval(object):
    def __init__(self, query, milvus_client):
        self.query = query
        self.milvus_client = milvus_client
        self.image_limit_number = 10

    def image_retrieval_by_embedding(self):
        query_embedding = get_multi_modal_embedding(text=self.query)
        res = self.milvus_client.search(
            collection_name="pdf_image_qa",
            data=[query_embedding],
            limit=5,
            search_params={"metric_type": "IP", "params": {}},
            output_fields=['data_type', 'text', 'image_path']
        )
        return [_['entity'] for _ in res[0]]

    def text_retrieval(self):
        query_text_embedding = get_text_embedding(text_chunks=[self.query])
        res = self.milvus_client.search(
            collection_name="pdf_text_qa",
            data=query_text_embedding,
            limit=10,
            search_params={"metric_type": "IP", "params": {}},
            output_fields=['text', 'page_no', 'pdf_path']
        )
        result = [_['entity'] for _ in res[0]]
        pdf_page_no_dict = {}
        for record in result:
            pdf_path = record['pdf_path']
            if pdf_path not in pdf_page_no_dict:
                pdf_page_no_dict[pdf_path] = [record['page_no']]
            else:
                if record['page_no'] not in pdf_page_no_dict[pdf_path]:
                    pdf_page_no_dict[pdf_path].append(record['page_no'])
        return pdf_page_no_dict, result

    @staticmethod
    def image_retrieval_by_text_page(pdf_page_no_dict: dict):
        additional_image_list = []
        for pdf_path, page_no_set in pdf_page_no_dict.items():
            file_name = os.path.basename(pdf_path).split('.')[0]
            output_dir = f"../output/{file_name}"
            with open(os.path.join(output_dir, "table_figure_caption.json"), "r") as f:
                table_figure_caption_dict = json.loads(f.read())
            for page_no in page_no_set:
                for file in os.listdir(output_dir):
                    image_path = os.path.join(output_dir, file)
                    if file.startswith('[') and f']_{page_no}.jpg' in file:
                        additional_image_list.append({'data_type': 'image',
                                                      'image_path': image_path,
                                                      'text': table_figure_caption_dict.get(file, "")})
                    elif "table" in file and file.startswith(str(page_no)):
                        additional_image_list.append({'data_type': 'table',
                                                      'image_path': image_path,
                                                      'text': table_figure_caption_dict.get(file, "")})

        return additional_image_list

    def run(self):
        image_embedding_result = self.image_retrieval_by_embedding()
        pdf_page_no_dict, text_result = self.text_retrieval()
        image_additional_result = self.image_retrieval_by_text_page(pdf_page_no_dict)
        # de-duplicate images and cap their number
        for img_item in image_additional_result:
            if img_item not in image_embedding_result and len(image_embedding_result) < self.image_limit_number:
                image_embedding_result.append(img_item)
        image_result = image_embedding_result[:self.image_limit_number]
        print(text_result)
        print(pdf_page_no_dict)
        print(json.dumps(image_result))
        return image_result, text_result


if __name__ == '__main__':
    client = MilvusClient(uri="http://localhost:19530", db_name="default")
    my_query = "What is LLaMA-7B's zero-shot accuracy on RACE dataset?"
    content_retriever = ContentRetrieval(query=my_query, milvus_client=client)
    content_retriever.run()
    client.close()

QA: Answering with a Multi-modal Model

Finally, we feed the retrieved text chunks, tables, and figures to OpenAI's GPT-4V model to answer the question. With this, we have implemented multi-modal RAG, that is, multi-modal question answering over PDF documents.

The Python code for the QA step is as follows:

# -*- coding: utf-8 -*-
# @place: Pudong, Shanghai
# @file: mm_doc_qa.py
# @time: 2024/4/2 17:01
import json
import os
import base64

import requests
from dotenv import load_dotenv
from pymilvus import MilvusClient

from content_retrieval import ContentRetrieval


load_dotenv()


class MultiModelQA(object):
    def __init__(
            self,
            query: str,
            text_chunks: list[str],
            images: list[str],
            captions: list[str],
    ):
        self.query = query
        self.images = images
        self.text_chunks = text_chunks
        self.captions = captions

    @staticmethod
    def encode_image(image_path: str):
        with open(image_path, "rb") as image_file:
            return base64.b64encode(image_file.read()).decode('utf-8')

    def make_prompt(self):
        # Getting the base64 string
        image_content = [
            {
                "type": "text",
                "text": self.query
            },
        ]
        for image in self.images:
            base64_image = self.encode_image(image)
            image_content.append({
                "type": "image_url",
                "image_url": {
                    "url": f"data:image/jpeg;base64,{base64_image}"
                }
            })
        # get caption desc
        seq_no_list = ['first', 'second', 'third', 'fourth', 'fifth', 'sixth', 'seventh', 'eighth', 'ninth', 'tenth']
        caption_list = []
        for i, caption in enumerate(self.captions):
            if caption:
                caption_list.append(f"The caption of {seq_no_list[i]} image is {caption}.")
        caption_desc = '\n'.join(caption_list)
        messages = [{"role": "system",
                     "content": "You are a helpful assistant."},
                    {"role": "user",
                     "content": f"<Text from PDF file>:\n\n{''.join(self.text_chunks)}"},
                    {"role": "user",
                     "content": caption_desc},
                    {"role": "user",
                     "content": image_content}]
        # print(json.dumps(messages, ensure_ascii=False))
        return messages

    @staticmethod
    def make_request(messages):
        # make request
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {os.getenv('OPENAI_API_KEY')}"
        }
        payload = {"model": "gpt-4-vision-preview",
                   "messages": messages,
                   "max_tokens": 500}

        response = requests.post(
            url="https://api.openai.com/v1/chat/completions",
            headers=headers,
            json=payload)

        return response.json()['choices'][0]['message']['content']

    def run(self):
        messages = self.make_prompt()
        answer = self.make_request(messages=messages)
        return answer


if __name__ == '__main__':
    client = MilvusClient(uri="http://localhost:19530", db_name="default")
    my_query = "What is LLaMA-7B's zero-shot accuracy on RACE dataset?"
    content_retriever = ContentRetrieval(query=my_query, milvus_client=client)
    image_result, text_result = content_retriever.run()
    retrieved_text_chunks = [_['text'] for _ in text_result]
    retrieved_images = [_['image_path'] for _ in image_result]
    retrieved_captions = [_['text'] for _ in image_result]
    mm_qa = MultiModelQA(
        query=my_query,
        text_chunks=retrieved_text_chunks,
        images=retrieved_images,
        captions=retrieved_captions)
    doc_answer = mm_qa.run()
    print(doc_answer)
    client.close()

Tests

We test multi-modal QA on the LLaMA paper PDF.

First, questions whose answers come from tables or figures:

  • query: What is LLaMA-7B's zero-shot accuracy on RACE dataset?

answer:

LLaMA-7B's zero-shot accuracy on the RACE dataset is 61.1% for RACE-middle and 46.9% for RACE-high.

  • query: What is LLaMA model's average accuracy on MMLU dataset?

answer:

The LLaMA model's average accuracy on the Massive Multitask Language Understanding (MMLU) dataset is 63.4% as indicated in Table 9.

If you refer to the detailed results provided in Table 16, the average accuracy of the LLaMA 65B model specifically across all domains for the 5-shot setting is 63.4%.

  • query: LLaMA 7B zero-shot performance on PIQA, SIQA, BoolQ dataset

answer:

The zero-shot performance of the LLaMA 7B model on various Common Sense Reasoning tasks, including PIQA, SIQA, and BoolQ, is presented in the last image that contains multiple tables. This particular performance can be found in the "LLaMA 7B" row of the "0-shot" column across different datasets.

For PIQA, the zero-shot performance of LLaMA 7B is 76.5. For SIQA, it is 79.8. For BoolQ, the model scores 48.9.

These numbers indicate how well the LLaMA 7B model can perform on these specific tasks without any fine-tuning or additional shots, relying solely on its pre-existing training and capabilities to infer answers.

  • query: pretraining data of LLaMA and their prop

answer:

The pre-training data of LLaMA and their proportions from the various sources are as follows:

1. CommonCrawl: 67.0%
2. C4: 15.0%
3. Github: 4.5%
4. Wikipedia: 4.5%
5. Books: 4.5%
6. ArXiv: 2.5%
7. StackExchange: 2.0%

These sources were used for pre-training LLaMA models, with CommonCrawl being the largest proportion of the data used. The percentages indicate how much of each source was included in the mix of pre-training data.

Next, questions answered from plain text (i.e., the answer comes from the text rather than tables or figures):

  • query: What are mathematical reasoning benchmarks in this paper?

answer:

The mathematical reasoning benchmarks mentioned in the text are MATH and GSM8k. The MATH dataset contains 12K middle school and high school mathematics problems, while the GSM8k is a set of middle school mathematical problems. The paper discusses evaluating models on these benchmarks and comparing model performance, specifically mentioning LLaMA model variants and their performance relative to other models such as PaLM and Minerva.

  • query: what is the use of TruthfulQA?

answer:

TruthfulQA is a benchmark designed to measure the truthfulness of a model's responses. Specifically, it evaluates the ability of a model to identify when a claim is true in the context of "literal truth about the real world" rather than claims that may only be considered true within a belief system or tradition. This means that TruthfulQA aims to help determine how effectively a model can distinguish valid information from misinformation or false claims, which can be indicative of the model's potential to generate reliable or misleading content. It is used to assess the risks of models in generating misinformation and to improve the design of models to mitigate these risks. The benchmark typically includes a set of questions that a language model must answer, and the model's responses are evaluated for both truthfulness and informativeness.

As these examples show, such a multi-modal RAG system can answer questions over the text, tables, and figures of a PDF document.

Summary

Of course, the above is only a demonstration of one way to implement multi-modal RAG, and its quality still has room to improve, especially at large data scale. At this stage, most thinking about multi-modal RAG remains theoretical or at the level of simple experiments (LlamaIndex, for example, provides such examples). This article is one attempt at multi-modal RAG, and I hope it gives you some inspiration.

All Python code in this article is open-sourced on GitHub at: https://github.com/percent4/pdf-llm_series/tree/main/multi-modal-rag

References

  1. Visualized BGE: https://github.com/FlagOpen/FlagEmbedding/tree/master/FlagEmbedding/visual
  2. Multi-modal in LlamaIndex: https://docs.llamaindex.ai/en/stable/use_cases/multimodal/
  3. NLP(八十九)PDF文档智能问答入门
  4. NLP(九十一)PDF表格问答
