How do I add memory to a MultiQueryRetriever inside a SequentialChain?

Problem description — votes: 0, answers: 1

I am trying to build a chatbot with Streamlit that knows about data stored in Pinecone. I can already ask questions and get answers, but I am having trouble setting up memory so that it has the context of my previous messages.

Here is part of my code:

import os

import pinecone
import streamlit as st
from langchain.chains import SequentialChain, TransformChain
from langchain.chains.question_answering import load_qa_chain
from langchain.chat_models import ChatOpenAI
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.memory import ConversationSummaryBufferMemory
from langchain.prompts import PromptTemplate
from langchain.retrievers.multi_query import MultiQueryRetriever
from langchain.vectorstores import Pinecone

embeddings = OpenAIEmbeddings()
pinecone.init(api_key=os.environ.get('PINECONE_API_KEY'), environment=os.environ.get('PINECONE_ENV'))
vectorstore = Pinecone.from_existing_index(index_name, embeddings)
llm = ChatOpenAI(temperature=0.9, model_name="gpt-4-0125-preview", max_tokens=4000)
retriever = MultiQueryRetriever.from_llm(
    retriever=vectorstore.as_retriever(), llm=llm
)
user_prompt = st.chat_input()

QA_PROMPT = PromptTemplate(
    input_variables=["query", "contexts","chat_history"],
    template="""You are a professional grant proposal writer. Your task is to create grant proposal content
    for a particular grant proposal topic that the user wants. You can use the context provided to make the grant proposal content
    detailed, comprehensive and accurate. If the question cannot be answered using the
    information provided, use your own knowledge to answer the query.

    chat_history: {chat_history}
    Contexts:
    {contexts}

    Question: {query}
    AI:"""
    )
memory = ConversationSummaryBufferMemory(
    llm=llm,
    memory_key="chat_history",
    input_key="query",
    max_token_limit=5000 # after 5000 tokens, a summary of the conversation is created and stored in moving_summary_buffer
    )
qa_chain = load_qa_chain(llm=llm, chain_type="stuff", memory=memory, prompt=QA_PROMPT)

def retrieval_transform(inputs: dict) -> dict:
    docs = retriever.get_relevant_documents(query=inputs["user_prompt"])
    docs = [d.page_content for d in docs]
    docs_dict = {
        "query": inputs["user_prompt"],
        "contexts": "\n---\n".join(docs)
    }
    return docs_dict

retrieval_chain = TransformChain(
    input_variables=["user_prompt"],
    output_variables=["query", "contexts"],
    transform=retrieval_transform
)

rag_chain = SequentialChain(
    chains=[retrieval_chain, qa_chain],
    input_variables=["user_prompt"],  # we need to name differently to output "query"
    output_variables=["query", "contexts", "text"],
    verbose=True
)
if "messages" not in st.session_state.keys():
    st.session_state.messages = [
        {"role": "assistant", "content": "Hello there, enter your question for grant proposal assistance."}
    ]
if "requests" not in st.session_state:
    st.session_state["requests"] = []
if "responses" not in st.session_state:
    st.session_state["responses"] = []

for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.write(message["content"])

if user_prompt is not None:
    st.session_state.messages.append({"role": "user", "content": user_prompt})
    with st.chat_message("user"):
        st.write(user_prompt)

if st.session_state.messages[-1]["role"] != "assistant":
    with st.chat_message("assistant"):
        with st.spinner("Thinking..."):
            # retrieval_chain fills in "contexts"; docs_dict is local to
            # retrieval_transform and not visible here
            out = rag_chain({"user_prompt": user_prompt})
            ai_response = out["text"]
            st.write(ai_response)
            st.session_state["messages"].append({"role": "assistant", "content": ai_response})

I tried to incorporate the memory into load_qa_chain, but I get an input_variables error.

python openai-api langchain
1 Answer (0 votes)

Did you manage to solve this? I have the same problem.
