关于LangChain入门,读者可参考文章NLP(五十六)LangChain入门
。
本文将会介绍LangChain中的重连机制,并尝试给出定制化重连方案。
本文以LangChain中的对话功能(ChatOpenAI
)为例。
LangChain中的重连机制
查看LangChain中对话功能(ChatOpenAI
)的重连机制(retry),其源代码如下:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 class ChatOpenAI (BaseChatModel ): ... def _create_retry_decorator (self ) -> Callable [[Any ], Any ]: import openai min_seconds = 1 max_seconds = 60 return retry( reraise=True , stop=stop_after_attempt(self.max_retries), wait=wait_exponential(multiplier=1 , min =min_seconds, max =max_seconds), retry=( retry_if_exception_type(openai.error.Timeout) | retry_if_exception_type(openai.error.APIError) | retry_if_exception_type(openai.error.APIConnectionError) | retry_if_exception_type(openai.error.RateLimitError) | retry_if_exception_type(openai.error.ServiceUnavailableError) ), before_sleep=before_sleep_log(logger, logging.WARNING), ) def completion_with_retry (self, **kwargs: Any ) -> Any : """Use tenacity to retry the completion call.""" retry_decorator = self._create_retry_decorator() @retry_decorator def _completion_with_retry (**kwargs: Any ) -> Any : return self.client.create(**kwargs) return _completion_with_retry(**kwargs)
可以看到,其编码方式为硬编码(hardcore),采用tenacity
模块实现重连机制,对于支持的报错情形,比如openai.error.Timeout, openai.error.APIError
等,会尝试重连,最小等待时间为1s,最大等待时间为60s,每次重连等待时间会乘以2。
简单重连
我们尝试用一个错误的OpenAI key进行对话,代码如下:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 from langchain.chat_models import ChatOpenAIdef chat_bot (input_text: str ): llm = ChatOpenAI(temperature=0 , model_name="gpt-3.5-turbo" , openai_api_key="sk-xxx" , max_retries=5 ) return llm.predict(input_text)if __name__ == '__main__' : text = '中国的首都是哪里?' print (chat_bot(text))
尽管我们在代码中设置了重连最大次数(max_retries
),代码运行时会直接报错,不会重连,原因是LangChain中的对话功能重连机制没有支持openai.error.AuthenticationError
。输出结果如下:
1 openai.error.AuthenticationError: Incorrect API key provided: sk-xxx. You can find your API key at https://platform.openai.com/account/api-keys.
此时,我们尝试在源代码的基础上做简单的定制,使得其支持openai.error.AuthenticationError
错误类型,代码如下:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 import openaifrom typing import Callable , Any from tenacity import ( before_sleep_log, retry, retry_if_exception_type, stop_after_attempt, wait_exponential, )from langchain.chat_models import ChatOpenAIimport logging logger = logging.getLogger(__name__)class MyChatOpenAI (ChatOpenAI ): def _create_retry_decorator (self ) -> Callable [[Any ], Any ]: min_seconds = 1 max_seconds = 60 return retry( reraise=True , stop=stop_after_attempt(self.max_retries), wait=wait_exponential(multiplier=1 , min =min_seconds, max =max_seconds), retry=( retry_if_exception_type(openai.error.Timeout) | retry_if_exception_type(openai.error.APIError) | retry_if_exception_type(openai.error.APIConnectionError) | retry_if_exception_type(openai.error.RateLimitError) | retry_if_exception_type(openai.error.ServiceUnavailableError) | retry_if_exception_type(openai.error.AuthenticationError) ), before_sleep=before_sleep_log(logger, logging.WARNING), ) def completion_with_retry (self, **kwargs: Any ) -> Any : """Use tenacity to retry the completion call.""" retry_decorator = self._create_retry_decorator() @retry_decorator def _completion_with_retry (**kwargs: Any ) -> Any : return self.client.create(**kwargs) return _completion_with_retry(**kwargs)def chat_bot (input_text: str ): llm = MyChatOpenAI(temperature=0 , model_name="gpt-3.5-turbo" , openai_api_key="sk-xxx" , max_retries=5 ) return llm.predict(input_text)if __name__ == '__main__' : text = '中国的首都是哪里?' print (chat_bot(text))
分析上述代码,我们在继承ChatOpenAI类的基础上重新创建MyChatOpenAI类,在_create_retry_decorator中的重连错误情形中加入了openai.error.AuthenticationError
错误类型,此时代码输出结果如下:
1 2 3 4 5 6 7 Retrying __main__.MyChatOpenAI.completion_with_retry.<locals>._completion_with_retry in 1.0 seconds as it raised AuthenticationError: Incorrect API key provided: sk-xxx. You can find your API key at https://platform.openai.com/account/api-keys.. Retrying __main__.MyChatOpenAI.completion_with_retry.<locals>._completion_with_retry in 2.0 seconds as it raised AuthenticationError: Incorrect API key provided: sk-xxx. You can find your API key at https://platform.openai.com/account/api-keys.. Retrying __main__.MyChatOpenAI.completion_with_retry.<locals>._completion_with_retry in 4.0 seconds as it raised AuthenticationError: Incorrect API key provided: sk-xxx. You can find your API key at https://platform.openai.com/account/api-keys.. Retrying __main__.MyChatOpenAI.completion_with_retry.<locals>._completion_with_retry in 8.0 seconds as it raised AuthenticationError: Incorrect API key provided: sk-xxx. You can find your API key at https://platform.openai.com/account/api-keys.. Traceback (most recent call last): ...... openai.error.AuthenticationError: Incorrect API key provided: sk-xxx. You can find your API key at https://platform.openai.com/account/api-keys.
从输出结果中,我们可以看到,该代码确实对openai.error.AuthenticationError
错误类型进行了重连,按照源代码的方式进行重连,一共尝试了5次重连,每次重连等待时间是上一次的两倍。
定制化重连
LangChain中的重连机制也支持定制化。
假设我们的使用场景:某个OpenAI
key在调用过程中失效了,那么在重连时希望能快速切换至某个能正常使用的OpenAI
key,以下为示例代码(仅需要修改completion_with_retry
函数):
1 2 3 4 5 6 7 8 9 10 11 def completion_with_retry (self, **kwargs: Any ) -> Any : """Use tenacity to retry the completion call.""" retry_decorator = self._create_retry_decorator() @retry_decorator def _completion_with_retry (**kwargs: Any ) -> Any : kwargs['api_key' ] = 'right openai key' return self.client.create(**kwargs) return _completion_with_retry(**kwargs)
此时就能进行正常的对话功能了。
总结
本文介绍了LangChain中的重连机制,并尝试给出定制化重连方案,希望能对读者有所帮助。
笔者的个人博客网址为:https://percent4.github.io/
,欢迎大家访问~
欢迎关注我的公众号
NLP奇幻之旅 ,原创技术文章第一时间推送。
欢迎关注我的知识星球“自然语言处理奇幻之旅 ”,笔者正在努力构建自己的技术社区。