Streaming Output for Model Inference with the transformers Module

This post explains how to implement streaming output for model inference with the transformers module.

When running model inference with the transformers module, you can use its built-in streamer classes to stream the output token by token. Alternatively, model-serving frameworks such as vLLM or TGI offer more robust support for streaming inference.

The following sections describe in detail how to stream model-inference output with the transformers module.

Streaming Output

  • Use the TextStreamer class built into the transformers module to stream output to the terminal

Sample code:

from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer

model_id = "./models/Qwen1.5-7B-Chat"
model = AutoModelForCausalLM.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)
message = [{"role": "user", "content": "沈阳一共有几条地铁?"}]
conversion = tokenizer.apply_chat_template(message, add_generation_prompt=True, tokenize=False)
print(conversion)
encoding = tokenizer(conversion, return_tensors="pt")
streamer = TextStreamer(tokenizer)
model.generate(**encoding, max_new_tokens=500, temperature=0.2, do_sample=True, streamer=streamer, pad_token_id=tokenizer.eos_token_id)

Output:

截至2022年,沈阳市已经开通运营了4条地铁线路,分别是1号线、2号线、3号线和9号线。未来沈阳地铁建设规划还在进行中,预计线路将进一步延伸和扩展。
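TextStreamer decodes tokens as they are generated and prints them to stdout as soon as a displayable chunk (roughly, a complete word) is available. As a rough, model-free illustration of that buffering idea, here is a minimal sketch; the class and method names are hypothetical, not the actual transformers implementation:

```python
# Minimal sketch of the buffering idea behind TextStreamer: incoming text
# pieces are accumulated and flushed only at word boundaries (or stream end).
class WordBufferPrinter:
    def __init__(self):
        self.buffer = ""
        self.printed = []  # flushed chunks, captured here for demonstration

    def put(self, text_piece):
        """Receive one decoded piece; flush everything up to the last space."""
        self.buffer += text_piece
        cut = self.buffer.rfind(" ")
        if cut != -1:
            self.printed.append(self.buffer[:cut + 1])
            self.buffer = self.buffer[cut + 1:]

    def end(self):
        """Generation finished: flush whatever is still buffered."""
        if self.buffer:
            self.printed.append(self.buffer)
            self.buffer = ""

printer = WordBufferPrinter()
for piece in ["Shen", "yang ", "has ", "4 ", "met", "ro ", "lines."]:
    printer.put(piece)
printer.end()
print("".join(printer.printed))  # → Shenyang has 4 metro lines.
```

The real TextStreamer works on token IDs and handles tokenizer-specific details (byte-level BPE, special tokens), but the flush-on-word-boundary behavior is the same reason its terminal output appears in readable chunks rather than one character at a time.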
  • Use the TextIteratorStreamer class built into the transformers module to customize streaming output

Sample code:

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from threading import Thread

model_id = "./models/Qwen1.5-7B-Chat"
model = AutoModelForCausalLM.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)
message = [{"role": "user", "content": "沈阳一共有几条地铁?"}]
conversion = tokenizer.apply_chat_template(message, add_generation_prompt=True, tokenize=False)
print(conversion)
encoding = tokenizer(conversion, return_tensors="pt")
streamer = TextIteratorStreamer(tokenizer)
# Run the generation in a separate thread, so that we can fetch the generated text in a non-blocking way.
generation_kwargs = dict(encoding, streamer=streamer, max_new_tokens=100, do_sample=True, temperature=0.2)
thread = Thread(target=model.generate, kwargs=generation_kwargs)
thread.start()

generated_text = ""
for new_text in streamer:
    # The streamer also yields the prompt text first, so strip it out
    output = new_text.replace(conversion, '')
    if output:
        generated_text += output
        print(output)

Output:

截至
2022
,沈阳

已经
开通
运营

4
地铁
线路
,分别为
1号线
2号线
3号线

9号线
。未来
规划

还有


线路

建设

,如
6号线
8号线

。<|im_end|>
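The key idea behind TextIteratorStreamer is a producer/consumer pattern: model.generate runs in a worker thread and pushes decoded chunks into a queue, while the main thread iterates over them without blocking generation. The following stdlib-only sketch illustrates that pattern; the class and the fake_generate producer are illustrative stand-ins, not the transformers implementation:

```python
from queue import Queue
from threading import Thread

class IteratorStreamer:
    """Toy version of the queue-backed streamer pattern."""
    _SENTINEL = object()  # marks the end of the stream

    def __init__(self):
        self.queue = Queue()

    def put(self, text):
        # Called from the producer (generation) thread for each chunk.
        self.queue.put(text)

    def end(self):
        # Called once generation finishes, to unblock the consumer.
        self.queue.put(self._SENTINEL)

    def __iter__(self):
        # Consumer side: block on the queue until the sentinel arrives.
        while True:
            item = self.queue.get()
            if item is self._SENTINEL:
                return
            yield item

def fake_generate(streamer):
    # Stand-in for model.generate(..., streamer=streamer)
    for chunk in ["As of ", "2022, ", "Shenyang ", "runs ", "4 metro lines."]:
        streamer.put(chunk)
    streamer.end()

streamer = IteratorStreamer()
Thread(target=fake_generate, args=(streamer,)).start()
pieces = [p for p in streamer]
print("".join(pieces))  # → As of 2022, Shenyang runs 4 metro lines.
```

This is why the real example above must start generation in a separate Thread: iterating the streamer in the same thread as model.generate would deadlock, since generate only returns after the whole sequence is produced.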
  • Use gradio to build a Web UI with streaming output
# -*- coding: utf-8 -*-
import gradio as gr
from threading import Thread
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

model_id = "./models/Qwen1.5-7B-Chat"
model = AutoModelForCausalLM.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)

def question_answer(query):
    message = [{"role": "user", "content": query}]
    conversion = tokenizer.apply_chat_template(message, add_generation_prompt=True, tokenize=False)
    encoding = tokenizer(conversion, return_tensors="pt")
    streamer = TextIteratorStreamer(tokenizer)
    generation_kwargs = dict(encoding, streamer=streamer, max_new_tokens=1000, do_sample=True, temperature=0.2)
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    generate_text = ''
    for new_text in streamer:
        output = new_text.replace(conversion, '')
        if output:
            generate_text += output
            yield generate_text


demo = gr.Interface(
fn=question_answer,
inputs=gr.Textbox(lines=3, placeholder="your question...", label="Question"),
outputs="text",
)

demo.launch(server_name="0.0.0.0", server_port=50079, share=True)

The effect is shown in the screenshot here:

https://mp.weixin.qq.com/s?__biz=MzU2NTYyMDk5MQ==&mid=2247486716&idx=1&sn=e8f6099d48682299485bb04715fd10ca&chksm=fcb9b56ccbce3c7a2cb21dad93c030c55256222e7ea6cf4f80d44c096d7ffcdc54440dd65eb1&token=321761101&lang=zh_CN#rd
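Gradio streams output because question_answer is a generator: each yield hands gradio the full accumulated text so far, and the output box is re-rendered on every yield. A model-free sketch of that contract (function name illustrative):

```python
def accumulate(chunks):
    """Yield the running concatenation after each chunk, as a gradio
    streaming handler does."""
    text = ""
    for chunk in chunks:
        text += chunk
        yield text  # gradio re-renders the output component on each yield

snapshots = list(accumulate(["Hello", ", ", "world"]))
print(snapshots)  # → ['Hello', 'Hello, ', 'Hello, world']
```

Yielding the accumulated string (rather than just the new chunk) matters: gradio replaces the displayed value with whatever is yielded, so yielding only the delta would show a single word at a time instead of growing text.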

References

  1. Streamers in transformers: https://huggingface.co/docs/transformers/v4.39.3/en/internal/generation_utils#transformers.TextStreamer

You are welcome to follow my Knowledge Planet community "自然语言处理奇幻之旅" (A Fantastic Journey in NLP), where I am working to build my own technical community.


https://percent4.github.io/transformers模块中的模型推理流式输出/
Author: Jclian91
Published: May 3, 2024