I have a simple RAG application and can't figure out how to store memory when streaming. Should save_context be part of the chain? Or do I have to handle it with some callback?

At the end of the example there is answer_chain, where the last step is commented out. I believe something should go there, but I don't know what. I want to run a callback once streaming has finished.

Also, I split the chain into two steps, because with one big streamed chain it sends the documents etc. to stdout, which makes no sense; I only want the message. Is handling it with two separate chains the right approach?

Any ideas?
import uuid
from typing import Iterator
import dotenv
from langchain_core.messages import get_buffer_string
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate, PromptTemplate, format_document
from langchain_core.runnables import RunnableLambda, RunnablePassthrough, RunnableParallel
from langchain_core.runnables.utils import Output
from document_index.vector import get_retriever
from operator import itemgetter
from memory import get_memory
from model import get_model
dotenv.load_dotenv()
model = get_model()
condense_question_prompt = """Given the following conversation and a follow up question, rephrase the follow up question to be a standalone question, in its original language.
Chat History:
{chat_history}
Follow Up Input: {question}
Standalone question:"""
CONDENSE_QUESTION_PROMPT = PromptTemplate.from_template(condense_question_prompt)
initial_prompt = """
You are a helpful AI assistant.
Answer the question based only on the context below.
### Context start ###
{context}
### Context end ###
Question: {question}
"""
ANSWER_PROMPT = ChatPromptTemplate.from_template(initial_prompt)
DEFAULT_DOCUMENT_PROMPT = PromptTemplate.from_template(template="{page_content}")
retriever = get_retriever()
def _get_memory_with_session_id(session_id):
    return get_memory(session_id)

def _combine_documents(docs, document_prompt=DEFAULT_DOCUMENT_PROMPT, document_separator="\n\n"):
    doc_strings = [format_document(doc, document_prompt) for doc in docs]
    return document_separator.join(doc_strings)
def search(session_id, query) -> Iterator[Output]:
    memory = _get_memory_with_session_id(session_id)

    def _save_context(inputs, answer):
        memory.save_context(inputs, {"answer": answer})

    loaded_memory = RunnablePassthrough.assign(
        chat_history=RunnableLambda(memory.load_memory_variables) | itemgetter("history"),
    )
    standalone_question = {
        "standalone_question": {
            "question": lambda x: x["question"],
            "chat_history": lambda x: get_buffer_string(x["chat_history"]),
        }
        | CONDENSE_QUESTION_PROMPT
        | model
        | StrOutputParser()
    }
    retrieved_documents = {
        "docs": itemgetter("standalone_question") | retriever,
        "question": lambda x: x["standalone_question"],
    }
    preparation_chain = loaded_memory | standalone_question | retrieved_documents
inputs = {"question": query}
docs = preparation_chain.invoke(inputs)
answer_chain = (
{"docs": RunnablePassthrough()}
| {
"context": lambda x: _combine_documents(x["docs"]),
"question": itemgetter("question"),
}
| ANSWER_PROMPT
| model
| StrOutputParser()
# | RunnableLambda(_save_context, ????query_argument, ????MODEL_ANSWER)
)
return answer_chain.stream(docs)
if __name__ == "__main__":
    session_id = str(uuid.uuid4())
    query = "Where to buy beer?"
    for result in search(session_id, query):
        print(result, end="")
There is an open issue for this case, so you may need a workaround. I came up with my own solution: a custom runnable lambda that passes its input through the stream and, once the stream is exhausted, calls a function with the collected output. That is where you can save the memory context.

For some reason the memory state did not seem to persist between chain invocations. This is what I came up with; it is rough, early-stage stuff, but it helps you get where you want:
from operator import itemgetter
from typing import Any, Callable, Dict, Iterator, cast

from langchain.memory import ConversationBufferMemory
from langchain_core.callbacks.manager import CallbackManagerForChainRun
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.runnables import RunnableConfig, RunnableLambda, RunnablePassthrough, RunnableSerializable
from langchain_core.runnables.config import call_func_with_variable_args
from langchain_core.runnables.utils import AddableDict, Input, Output
from langchain_openai import ChatOpenAI  # adjust to your model provider

class RunnableCollector(RunnableLambda):
    """Passes stream chunks through unchanged, accumulates them, and calls
    `self.func` with the combined value once the stream is exhausted."""

    def _transform(
        self,
        input: Iterator[Input],
        run_manager: CallbackManagerForChainRun,
        config: RunnableConfig,
        **kwargs: Any,
    ) -> Iterator[Output]:
        final: Input
        got_first_val = False
        for ichunk in input:
            yield cast(Output, ichunk)
            if not got_first_val:
                final = ichunk
                got_first_val = True
            else:
                try:
                    # AddableDict and message chunks support `+` for merging.
                    final = final + ichunk  # type: ignore[operator]
                except TypeError:
                    final = ichunk
        # Streaming has finished: hand the accumulated value to the callback.
        call_func_with_variable_args(
            self.func, cast(Input, final), config, run_manager, **kwargs
        )
class RunnableFilter(RunnableLambda):
    """Only yields stream chunks for which `filter` returns True."""

    def __init__(self, filter: Callable[[Input], bool], **kwargs) -> None:
        super().__init__(func=lambda _: None, **kwargs)
        self.filter = filter

    def _transform(
        self,
        input: Iterator[Input],
        run_manager: CallbackManagerForChainRun,
        config: RunnableConfig,
        **kwargs: Any,
    ) -> Iterator[Output]:
        for ichunk in input:
            if self.filter(ichunk):
                yield ichunk

class RunnableMap(RunnableLambda):
    """Applies `mapping` to every stream chunk."""

    def __init__(self, mapping: Callable[[Input], Output], **kwargs) -> None:
        super().__init__(func=lambda _: None, **kwargs)
        self.mapping = mapping

    def _transform(
        self,
        input: Iterator[Input],
        run_manager: CallbackManagerForChainRun,
        config: RunnableConfig,
        **kwargs: Any,
    ) -> Iterator[Output]:
        for ichunk in input:
            yield self.mapping(ichunk)
SYSTEM_MESSAGE = "You are a helpful AI assistant."  # placeholder, use your own

def setup_chain(
    model="gpt-3.5-turbo",
) -> RunnableSerializable:
    memory = ConversationBufferMemory(memory_key="history", return_messages=True)

    def save_into_mem(input_output: Dict[str, Any]):
        # Receives the accumulated stream: the original input dict plus the
        # model's combined output message.
        message = input_output.pop("output")
        memory.save_context(input_output, {"output": message.content})
        print("\n\n2.\t", memory.load_memory_variables({}))

    chain = (
        RunnablePassthrough.assign(
            history=RunnableLambda(memory.load_memory_variables)
            | itemgetter("history"),
            dummy=RunnableLambda(lambda x: print("input?", x)),  # debug only
        )
        | ChatPromptTemplate.from_messages(
            [
                ("system", SYSTEM_MESSAGE),
                MessagesPlaceholder(variable_name="history"),
                ("user", "{input}"),
            ]
        )
        | ChatOpenAI(
            model=model,
            temperature=0,
            streaming=True,
        )
    )
    chain = (
        RunnablePassthrough.assign(
            output=chain,
        )
        # Keep only dict chunks that carry the input or the model output.
        | RunnableFilter(
            filter=lambda chunk: type(chunk) is AddableDict
            and ("output" in chunk or "input" in chunk)
        )
        # Accumulate the stream; once it ends, save it into memory.
        | RunnableCollector(save_into_mem)
        # Forward only the model-output chunks to the caller.
        | RunnableFilter(
            filter=lambda chunk: type(chunk) is AddableDict
            and "output" in chunk
            and "input" not in chunk
        )
        | RunnableMap(mapping=lambda chunk: cast(Output, chunk["output"]))
    )
    return chain
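To round it off, a quick driver sketch (it assumes OPENAI_API_KEY is set; the questions are just placeholders). Because memory lives in the closure of setup_chain, the second turn should see the first one:

if __name__ == "__main__":
    chain = setup_chain()
    # First turn: the collector saves the exchange once streaming ends.
    for chunk in chain.stream({"input": "My name is Alice."}):
        print(chunk.content, end="", flush=True)
    print()
    # Second turn: the history placeholder is filled from memory,
    # so the model should remember the name.
    for chunk in chain.stream({"input": "What is my name?"}):
        print(chunk.content, end="", flush=True)
    print()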