使用 langchain RetrievalQA 解决内存 ram 问题

问题描述 投票:0回答:1

嗨,我正在使用 chainlit、langchain 和 FAISS 制作一个具有多个向量数据库的 RAG 系统。几天前,我看到 RAG 使用了大量内存,例如 10GB,所以我想修复它,但我不知道 langchain 是否有 close 方法或类似的方法,然后我可以用来关闭检索QA 的过程.

如果有人知道如何解决这个问题,请告诉我,我将分享我的代码。

from typing import Dict, Optional
from langchain import PromptTemplate
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain.chains import RetrievalQA
import chainlit as cl
from langchain_groq import ChatGroq
from dotenv import load_dotenv
import os

from google.oauth2 import id_token
from google.auth.transport import requests
from utils.Retrival import Retrival; 



load_dotenv()
temp = 0
groq_api = os.getenv("GROQ_KEY")



db_faiss = ""
GROQ_MODEL = "llama-3.1-70b-versatile"

#DB_FAISS_PATH = "./vectorestore_sinHU"
custom_prompt_template ="""


Context: {context}
Question: {question}

Recuerda no puedes creear una respuesta, solamnete te pudes basar en el contexto.
la mejor respuesta:
"""




def set_custom_prompt():
    prompt = PromptTemplate(template=custom_prompt_template, input_variables=['context','question'])
    return prompt

def retrieval_qa_chain(llm, prompt, db, kw = 80):
    qa_chain = RetrievalQA.from_chain_type(llm=llm,
                                       chain_type='stuff',
                                       retriever=db.as_retriever(search_kwargs={'k': kw}),
                                       return_source_documents=True,
                                       chain_type_kwargs={'prompt': prompt}
                                       )
    return qa_chain

def qa_bot(db,kw):
    embeddings = HuggingFaceEmbeddings(model_name="mixedbread-ai/mxbai-embed-large-v1",
                                       model_kwargs={'device': 'cpu'})
    db = FAISS.load_local(db, embeddings,allow_dangerous_deserialization=True)
    llm = ChatGroq(groq_api_key = groq_api,temperature=temp, model_name=GROQ_MODEL)
    qa_prompt = set_custom_prompt()
    qa = retrieval_qa_chain(llm, qa_prompt, db,kw)

    return qa  


retrivalAux = Retrival() 

@cl.on_chat_start
async def start():
    global retrivalAux
    
    chat_profile = cl.user_session.get('chat_profile')

    msg = cl.Message(content="Cargando chatbot por favor espere...")
    await msg.send()

    match chat_profile:
        case "chat1":           
            msg.content = "Hola."
            db_faiss = ""
            if retrivalAux.getNombre() != "chat1" or not retrivalAux:
                cl.user_session.set("chain", None)
                if retrivalAux:
                    retrivalAux.destroy()
                retrivalAux.setNombre(nombre="chat1")                retrivalAux.setRetrival(retrivalQA=qa_bot(db="./vectorstores/chat1",kw=20))

            cl.user_session.set("chain", retrivalAux.getRetriever())
            
        case "chat2":

            msg.content = "Hola"
            if retrivalAux.getNombre() != "chat2" or not retrivalAux:
                cl.user_session.set("chain", None)
                if retrivalAux:
                    retrivalAux.destroy()
                retrivalAux.setNombre(nombre="chat2")
                retrivalAux.setRetrival(retrivalQA=qa_bot(db=db_faiss,kw=20))
            cl.user_session.set("chain", retrivalAux.getRetriever())
            
        case "chat3":
            cl.user_session.set("chain", None)
            msg.content = "Hola"
            db_faiss =""
            if retrivalAux.getNombre() != "chat3" or not retrivalAux:
                cl.user_session.set("chain", None)
                if retrivalAux:
                    retrivalAux.destroy()
                retrivalAux.setNombre(nombre=cimabot_SISCRED_ALU.name)
                retrivalAux.setRetrival(retrivalQA=qa_bot(db=db_faiss,kw=20))
            cl.user_session.set("chain", retrivalAux.getRetriever())
            
            
        case "chat4":
            
            msg.content = "Hola"
            db_faiss ="./vectorstores/chat4"
            if retrivalAux.getNombre() != "chat4" or not retrivalAux:
                cl.user_session.set("chain", None)
                if retrivalAux:
                    retrivalAux.destroy()
                retrivalAux.setNombre(nombre="chat4")
                retrivalAux.setRetrival(retrivalQA=qa_bot(db=db_faiss,kw=20))
            cl.user_session.set("chain", retrivalAux.getRetriever())
        case "chat5":
            cl.user_session.set("chain", None)
            msg.content = "hi"
            #chain =  qa_bot(db="./vectorstores/chat5",kw=20)
            db_faiss ="./vectorstores/SIMA/alumno"
            if retrivalAux.getNombre() != chat5 or not retrivalAux:
                cl.user_session.set("chain", None)
                if retrivalAux:
                    retrivalAux.destroy()
                retrivalAux.setNombre(nombre=chat5)
                retrivalAux.setRetrival(retrivalQA=qa_bot(db=db_faiss,kw=20))
            cl.user_session.set("chain", retrivalAux.getRetriever())
        case "chat6":

            msg.content = "hi"
            db_faiss = "./vectorstores/chat6"
            if retrivalAux.getNombre() != "chat6" or not retrivalAux:
                cl.user_session.set("chain", None)
                if retrivalAux:
                    retrivalAux.destroy()
                retrivalAux.setNombre(nombre="chat6")
                retrivalAux.setRetrival(retrivalQA=qa_bot(db=db_faiss,kw=20))
            cl.user_session.set("chain", retrivalAux.getRetriever())
        case "chat7":

            msg.content = "Hola"
            db_faiss = "./vectorstores/chat7"

            if retrivalAux.getNombre() != "chat7" or not retrivalAux:
                cl.user_session.set("chain", None)
                if retrivalAux:
                    retrivalAux.destroy()
                retrivalAux.setNombre(nombre="chat7")
                retrivalAux.setRetrival(retrivalQA=qa_bot(db=db_faiss,kw=20))
            cl.user_session.set("chain", retrivalAux.getRetriever())
    await msg.update()

    

@cl.on_message
async def main(message: cl.Message):

    chain = cl.user_session.get("chain")

    cb = cl.AsyncLangchainCallbackHandler(
        stream_final_answer=False, answer_prefix_tokens=["FINAL", "ANSWER"]
    )
    cb.answer_reached = True
    
    res = await chain.acall(message.content, callbacks=[cb])
    answer = res["result"]

    
    await cl.Message(content=answer).send()
    
@cl.on_chat_end
def end():
    print("--------------adios--------")
    global retrivalAux
    cl.user_session.set("chain", None)
    
    retrivalAux.destroy()
    
    
if __name__ == "__main__":
    from chainlit.cli import run_chainlit
    run_chainlit(__file__)

检索.py



class Retrival:
    
    def __init__(self, nombre=None,retrivalQA = None):
        
        self.retriever = retrivalQA;
        self.nombre = nombre;
        print(nombre)
    
    def getRetriever(self):
        return self.retriever
    
    def getNombre(self):
        return self.nombre
    
    def setNombre(self, nombre=None):
        del self.nombre
        self.nombre = nombre
    
    def setRetrival(self,retrivalQA = None):
        self.retriever = retrivalQA
    
    
    
    def destroy(self):
        if self.retriever:
            print("Memory--------------")
            print(self.retriever.memory)
        
        
        del self.retriever
        del self.nombre
        
        self.nombre = None
        self.retriever = None
        print(self.retriever)

我想要一个在使用RetrievalQA后删除它的过程的解决方案。

python-3.x langchain rag chainlit retrievalqa
1个回答
0
投票

我解决了这个问题,我移出了 qa_bot 函数的嵌入,然后我修复了,因为当我更改聊天配置文件时,它会创建另一个嵌入,但所有聊天配置文件都使用相同的嵌入 XD,所以我移出并修复了它。

© www.soinside.com 2019 - 2024. All rights reserved.