Getting Started with Table Detection and Recognition

This article introduces how to use Microsoft's open-source table detection model table-transformer-detection to get started with table detection and recognition.

A few years ago, while learning computer vision, I worked with OpenCV and also looked into table recognition, and even wrote an article about it: 如何识别图片中的表格数据 (How to Recognize Table Data in Images).

In the years since, I have been working on NLP. Because of my recent work on multimodal models, I have come back to table detection and recognition after five years; time really does fly.

Using Microsoft's open-source Table Transformer models, the rest of this article is organized into three parts:

  • Table detection: locate the regions of an image or PDF file that contain tables
  • Table structure recognition: within each detected table region, identify the rows, columns, and column header, and from these derive the position of each cell
  • Table data extraction: given the table structure, run OCR on each cell to obtain its text and thus recover the full table data

Table Detection

Using Microsoft's open-source table detection model microsoft/table-transformer-detection, we can detect the regions of an image or PDF file that contain tables.

For PDF files, each page can first be converted to an image before running detection; the conversion process was introduced in my earlier article NLP(八十九)PDF文档智能问答入门.
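For reference, here is a minimal sketch of that conversion using PyMuPDF (a choice of mine for illustration; the earlier article may use a different tool, and the file paths below are hypothetical):

import fitz  # PyMuPDF; assumes a recent version that supports the dpi argument of get_pixmap

pdf_path = "./table_detection/demo.pdf"  # hypothetical input PDF
doc = fitz.open(pdf_path)
for page_number, page in enumerate(doc):
    # render each page at 300 dpi so small table text stays legible
    pix = page.get_pixmap(dpi=300)
    pix.save(f"./table_detection/images/page-{page_number}.png")
doc.close()

Each saved page image can then be fed to the detection code below just like any ordinary image.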

Example Python code for table detection:

from transformers import AutoImageProcessor, TableTransformerForObjectDetection
import torch
from PIL import Image

file_path = "./table_detection/images/demo.jpg"
image = Image.open(file_path).convert("RGB")
file_name = file_path.split('/')[-1].split('.')[0]


image_processor = AutoImageProcessor.from_pretrained("./models/table-transformer-detection")
model = TableTransformerForObjectDetection.from_pretrained("./models/table-transformer-detection")

inputs = image_processor(images=image, return_tensors="pt")
outputs = model(**inputs)

# convert outputs (bounding boxes and class logits) to COCO API
target_sizes = torch.tensor([image.size[::-1]])
results = image_processor.post_process_object_detection(outputs, threshold=0.9, target_sizes=target_sizes)[0]

i = 0
for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
    box = [round(i, 2) for i in box.tolist()]
    print(
        f"Detected {model.config.id2label[label.item()]} with confidence "
        f"{round(score.item(), 3)} at location {box}"
    )

    # crop the detected table region and save it as an image
    region = image.crop(box)
    region.save(f'./table_images/{file_name}_{i}.jpg')
    i += 1

The input image demo.jpg is shown below:

demo.jpg

The two detected table regions are shown below:

demo_0.png
demo_1.png

The input paper-page image is shown below:

llama-pdf-10.png

The two detected tables are shown below:

llama-pdf-10_0.png
llama-pdf-10_1.png

As you can see, the model's table detection performance is quite good.

Table Structure Recognition

For the table regions detected in the previous step, we use Microsoft's open-source table structure recognition model microsoft/table-transformer-structure-recognition-v1.1-all to recognize the table's structure.

Example Python code for table structure recognition (outputting only the column header):

import torch
from PIL import Image
from transformers import DetrFeatureExtractor
from transformers import AutoImageProcessor, TableTransformerForObjectDetection

feature_extractor = DetrFeatureExtractor()

file_path = "./table_detection/table_images/demo_0.png"
image = Image.open(file_path).convert("RGB")

encoding = feature_extractor(image, return_tensors="pt")
model = TableTransformerForObjectDetection.from_pretrained("/data-ai/usr/lmj/models/table-transformer-structure-recognition-v1.1-all")
print(model.config.id2label)
# {0: 'table', 1: 'table column', 2: 'table row', 3: 'table column header', 4: 'table projected row header', 5: 'table spanning cell'}

with torch.no_grad():
    outputs = model(**encoding)

target_sizes = [image.size[::-1]]
results = feature_extractor.post_process_object_detection(outputs, threshold=0.6, target_sizes=target_sizes)[0]
print(results)

# label 3 corresponds to 'table column header' (see model.config.id2label above)
header_box_list = [results['boxes'][i].tolist() for i in range(len(results['boxes'])) if results['labels'][i].item() == 3]
crop_image = image.crop(header_box_list[0])
crop_image.save('header.png')

For the first output table image, the recognized column header is shown below:

Column header of demo_0.png

The code above is only a demonstration. In practice, the structure recognition model lets us recover the complete table structure, including the positions of the column header, rows, columns, and cells, as sketched below.
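As an illustration, the sketch below shows one way to turn the model's row and column boxes into per-cell bounding boxes. It assumes `results` is the post-processed output of the structure model shown above; the helper name get_cell_boxes is hypothetical:

def get_cell_boxes(results):
    # label ids from model.config.id2label: 1 = table column, 2 = table row
    columns = [box.tolist() for box, label in zip(results["boxes"], results["labels"]) if label.item() == 1]
    rows = [box.tolist() for box, label in zip(results["boxes"], results["labels"]) if label.item() == 2]
    # sort columns left-to-right and rows top-to-bottom
    columns.sort(key=lambda b: b[0])
    rows.sort(key=lambda b: b[1])
    # each cell is the intersection of a row box and a column box: [x1, y1, x2, y2]
    return [[[col[0], row[1], col[2], row[3]] for col in columns] for row in rows]

The same idea is used in the table_ocr function of the Gradio tool in the next section.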

Table Data Extraction

Once the full table structure is available, it is natural to run OCR on each cell to obtain its text and thereby recover the whole table. In this way, we can build a small table recognition tool of our own out of open-source models. Because of accuracy limitations, however, the tool currently only supports regular row-and-column tables (i.e. without merged cells or other complex layouts).

The Python code below uses Gradio to build a table data extraction tool:

# -*- coding: utf-8 -*-
import cv2
import json
import base64
import pandas as pd
import gradio as gr
from PIL import Image
import requests
from urllib.request import urlretrieve
from uuid import uuid4
from transformers import AutoImageProcessor, TableTransformerForObjectDetection
import torch
from transformers import DetrFeatureExtractor
from paddleocr import PaddleOCR


image_processor = AutoImageProcessor.from_pretrained("./models/table-transformer-detection")
detect_model = TableTransformerForObjectDetection.from_pretrained("./models/table-transformer-detection")
structure_model = TableTransformerForObjectDetection.from_pretrained("./models/table-transformer-structure-recognition-v1.1-all")
print(structure_model.config.id2label)
# {0: 'table', 1: 'table column', 2: 'table row', 3: 'table column header', 4: 'table projected row header', 5: 'table spanning cell'}
feature_extractor = DetrFeatureExtractor()

ocr = PaddleOCR(use_angle_cls=True, lang="ch")


def paddle_ocr(image_path):
    # run PaddleOCR on a single cell image and concatenate the recognized text
    result = ocr.ocr(image_path, cls=True)
    ocr_result = []
    for idx in range(len(result)):
        res = result[idx]
        if res:
            for line in res:
                print(line)
                ocr_result.append(line[1][0])
    return "".join(ocr_result)


def table_detect(image_box, image_url):
    # accept either an uploaded image (numpy array) or an image URL
    if not image_url:
        file_name = str(uuid4())
        image = Image.fromarray(image_box).convert('RGB')
    else:
        image_path = f"./images/{uuid4()}.png"
        file_name = image_path.split('/')[-1].split('.')[0]
        urlretrieve(image_url, image_path)
        image = Image.open(image_path).convert('RGB')
    inputs = image_processor(images=image, return_tensors="pt")
    outputs = detect_model(**inputs)
    # convert outputs (bounding boxes and class logits) to COCO API
    target_sizes = torch.tensor([image.size[::-1]])
    results = image_processor.post_process_object_detection(outputs, threshold=0.9, target_sizes=target_sizes)[0]

    i = 0
    output_images = []
    for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
        box = [round(i, 2) for i in box.tolist()]
        print(
            f"Detected {detect_model.config.id2label[label.item()]} with confidence "
            f"{round(score.item(), 3)} at location {box}"
        )

        # crop and save each detected table region
        region = image.crop(box)
        output_image_path = f'./table_images/{file_name}_{i}.jpg'
        region.save(output_image_path)
        output_images.append(output_image_path)
        i += 1
    return output_images


def table_ocr(output_images, image_index):
    # output_images is the gallery value: a list of (image_path, caption) entries
    output_image = output_images[int(image_index)][0]
    image = Image.open(output_image).convert("RGB")
    encoding = feature_extractor(image, return_tensors="pt")
    with torch.no_grad():
        outputs = structure_model(**encoding)

    target_sizes = [image.size[::-1]]
    results = feature_extractor.post_process_object_detection(outputs, threshold=0.5, target_sizes=target_sizes)[0]
    print(results)
    # get column and row boxes (label 1 = table column, label 2 = table row)
    columns = []
    rows = []
    for i in range(len(results['boxes'])):
        _id = results['labels'][i].item()
        if _id == 1:
            columns.append(results['boxes'][i].tolist())
        elif _id == 2:
            rows.append(results['boxes'][i].tolist())

    sorted_columns = sorted(columns, key=lambda x: x[0])
    sorted_rows = sorted(rows, key=lambda x: x[1])
    # OCR cell by cell: each cell is the intersection of a row box and a column box
    ocr_results = []
    for row in sorted_rows:
        row_result = []
        for col in sorted_columns:
            rect = [col[0], row[1], col[2], row[3]]
            crop_image = image.crop(rect)
            image_path = 'cell.png'
            crop_image.save(image_path)
            row_result.append(paddle_ocr(image_path=image_path))
        print(row_result)
        ocr_results.append(row_result)

    print(ocr_results)
    return ocr_results

if __name__ == '__main__':
    with gr.Blocks() as demo:
        with gr.Row():
            with gr.Column():
                image_box = gr.Image()
                image_urls = gr.TextArea(lines=1, placeholder="Enter image url", label="Images")
                image_index = gr.TextArea(lines=1, placeholder="Image Number", label="No")
            with gr.Column():
                gallery = gr.Gallery(label="Tables", show_label=False, elem_id="gallery", columns=[3], rows=[1], object_fit="contain", height="auto")
                detect = gr.Button("Table Detection")
                submit = gr.Button("Table OCR")
        ocr_outputs = gr.DataFrame(label='Table',
                                   interactive=True,
                                   wrap=True)
        # wire the buttons to the detection and OCR functions
        detect.click(fn=table_detect,
                     inputs=[image_box, image_urls],
                     outputs=gallery)
        submit.click(fn=table_ocr,
                     inputs=[gallery, image_index],
                     outputs=ocr_outputs)
    demo.launch(server_name="0.0.0.0", server_port=50074)

The OCR engine is PaddleOCR with its built-in ch_PP-OCRv4 model, whose text recognition quality is quite good. After launching the program, the interface looks like this:

Web interface for table recognition

The interface supports uploading an image or entering an image URL; after table detection finishes, you can extract the data of the n-th detected table image.

For reasons of space, only a few examples with good results are shown here:

  • Input image URL 1
Example 1
  • Input image URL 2
Example 2
  • Input image URL 3
Example 3
  • Uploading your own image
Example 4

Summary

This article showed how to use Microsoft's open-source table detection and table structure recognition models to implement table detection, structure recognition, and data extraction, and built a web application for conveniently viewing the recognition and extraction results.

The source code is shown in full in this article; it has not yet been published to GitHub, but will be open-sourced later.

References

  1. microsoft/table-transformer-detection: https://huggingface.co/microsoft/table-transformer-detection
  2. Multi-Modal on PDF’s with tables: https://docs.llamaindex.ai/en/v0.10.20/examples/multi_modal/multi_modal_pdf_tables.html

Follow my WeChat official account NLP奇幻之旅 to get my original technical articles as soon as they are published.

You are also welcome to join my Knowledge Planet community 自然语言处理奇幻之旅, where I am building my own technical community.


Original post: https://percent4.github.io/表格检测与识别入门/ · Author: Jclian91 · Published: April 3, 2024