토니의 연습장

RAG - CacheBackedEmbeddings 본문

언어 AI (NLP)/LLM & RAG & Agent

RAG - CacheBackedEmbeddings

bellmake 2025. 8. 1. 15:27

CacheBackedEmbeddings

Embeddings는 재계산을 피하기 위해 저장되거나 일시적으로 캐시될 수 있습니다.

Embeddings를 캐싱하는 것은 CacheBackedEmbeddings를 사용하여 수행될 수 있습니다. 캐시 지원 embedder는 embeddings를 키-값 저장소에 캐싱하는 embedder 주변에 래퍼입니다. 텍스트는 해시되고 이 해시는 캐시에서 키로 사용됩니다.

CacheBackedEmbeddings를 초기화하는 주요 지원 방법은 from_bytes_store입니다. 이는 다음 매개변수를 받습니다:

  • underlying_embeddings: 임베딩을 위해 사용되는 embedder.
  • document_embedding_cache: 문서 임베딩을 캐싱하기 위한 ByteStore 중 하나.
  • namespace: (선택 사항, 기본값은 "") 문서 캐시를 위해 사용되는 네임스페이스. 이 네임스페이스는 다른 캐시와의 충돌을 피하기 위해 사용됩니다. 예를 들어, 사용된 임베딩 모델의 이름으로 설정하세요.

주의: 동일한 텍스트가 다른 임베딩 모델을 사용하여 임베딩될 때 충돌을 피하기 위해 namespace 매개변수를 설정하는 것이 중요합니다.

-> RAG 프로젝트에 적용 예시 - 약 32개의 전문 사양서 (Adaptive AUTOSAR) 임베딩에 10-20분 가량 소요되는데,
    OllamaEmbeddings 로 embedding 시에 Cache 저장/불러오기 활용하여 첫 임베딩 이후 실행시마다 시간 단축

 

class RagChatChain(BaseChain):
    """
    RAG 기반 대화형 체인
    """
    def __init__(self, model: str = "exaone-deep:32b", temperature: float = 0.3, system_prompt: Optional[str] = None, **kwargs):
        super().__init__(model, temperature, **kwargs)
        # self.system_prompt = "You are a Automotive Software Expert. Always answer in Korean. Your name is 'joseph'."
        self.system_prompt = "너는 차량 소프트웨어 분야 전문가야. 언제나 너의 생각과 답변 모두 반드시 항상 한글로 답해주고, 마지막에 출처 문서와 페이지 번호를 반드시 알려줘."
        if "file_paths" in kwargs:
            self.file_paths = kwargs.pop("file_paths")
        elif "file_path" in kwargs:
            self.file_paths = [kwargs.pop("file_path")]
        else:
            raise ValueError("file_path(s) is required")
        self.vectorstore = None

    def setup(self):
        if not self.file_paths:
            raise ValueError("file_paths is required")
        print("RagChatChain setup")
        # 문서 로딩 및 분할
        raw_docs = []
        for file_path in self.file_paths:
            loader = PDFPlumberLoader(file_path)
            docs_from_file = loader.load()
            for i, doc in enumerate(docs_from_file):
                doc.metadata["page_number"] = doc.metadata.get("page_number", i + 1)
                doc.metadata["source_file"] = file_path
            raw_docs.extend(docs_from_file)
        text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
        docs = text_splitter.split_documents(raw_docs)
        # 캐시 가능한 임베딩 모델 설정 및 FAISS 인덱스 생성/로드
        embeddings = OllamaEmbeddings(model="bge-m3")
        cache_dir = Path.cwd() / "data" / "embedding_cache"
        faiss_index = cache_dir / "index.faiss"
        pkl_index = cache_dir / "index.pkl"
        if faiss_index.exists() and pkl_index.exists():
            # 캐시된 인덱스 로드
            print(f"Loading cached FAISS index from {faiss_index}")
            self.vectorstore = FAISS.load_local(
                str(cache_dir),
                embeddings,
                allow_dangerous_deserialization=True
            )
        else:
            print(f"Creating new FAISS index in {cache_dir}")
            # 새 인덱스 생성 및 저장
            self.vectorstore = FAISS.from_documents(docs, embedding=embeddings)
            cache_dir.mkdir(parents=True, exist_ok=True)
            self.vectorstore.save_local(str(cache_dir))
        # prompt = load_prompt("prompts/rag-llama.yaml", encoding="utf-8")
        # 1) 시스템 프롬프트를 직접 주입하는 PromptTemplate
        prompt = ChatPromptTemplate.from_messages([
           ("system", self.system_prompt),
           ("user",   "Context:\n{context}\n\nQuestion:\n{question}")
        ])

        llm = ChatOllama(
            model=self.model,
            temperature=self.temperature,
            callback_manager=callback_manager,
            streaming=True
        )
        def combine_messages(input_dict):
            messages = input_dict["messages"]
            last_user_message = next((msg for msg in reversed(messages) if isinstance(msg, HumanMessage)), None)
            if last_user_message:
                context_docs = self.vectorstore.as_retriever().get_relevant_documents(last_user_message.content)
                return {
                    "question": last_user_message.content,
                    "context": format_docs(context_docs),
                }
            else:
                return {"question": "", "context": ""}
        chain = RunnablePassthrough() | combine_messages | prompt | llm | StrOutputParser()
        return chain

 

InmemoryByteStore 사용 (비영구적)

다른 방식으로 비영구적인 cache 방식으로서 ByteStore를 사용하기 위해서는 CacheBackedEmbeddings를 생성할 때 해당 ByteStore를 사용하면 됩니다.

 

'언어 AI (NLP) > LLM & RAG & Agent' 카테고리의 다른 글

get_batch( )  (0) 2025.08.26
Unsloth  (0) 2025.08.20
프로젝트 내에서 data 경로 지정 관련  (0) 2025.07.29
RAG - reranker  (0) 2025.07.23
Azure AI Search 활용 예시  (0) 2025.05.19