code-travail/rag_build_db.py
2024-01-03 22:40:11 +01:00

102 lines
2.8 KiB
Python

# Set env var OPENAI_API_KEY or load from a .env file
import sys

# Replace the stdlib sqlite3 entry in sys.modules with pysqlite3 before anything
# else imports it — NOTE(review): presumably needed so the Chroma vector store
# below gets a sufficiently recent SQLite; confirm against chromadb requirements.
import pysqlite3  # noqa: F401  (imported for its sys.modules entry)
sys.modules["sqlite3"] = sys.modules.pop("pysqlite3")

import dotenv

dotenv.load_dotenv()
# Load documents
# (alternative: fetch a web page instead of the local XML export)
# from langchain.document_loaders import WebBaseLoader
# loader = WebBaseLoader("https://lilianweng.github.io/posts/2023-06-23-agent/")
import os

from langchain.schema import Document
from lxml import etree

# Metadata sub-elements expected under each <entry>/<metadata> node.
METADATA_FIELDS = ("partie", "livre", "titre", "chapitre", "section", "article")

# Portable path construction instead of "/"-string concatenation.
text_file = os.path.join(os.path.abspath("."), "sources", "code-travail.xml")
with open(text_file, "rb") as xml_file:
    xml_data = xml_file.read()

# One Document per <entry>: the article text plus its six location fields.
splits = []
root = etree.fromstring(xml_data)
for entry in root.findall("entry"):
    metadata = entry.find("metadata")
    page_content = entry.find("page_content").text
    splits.append(Document(
        page_content=page_content,
        # Raises AttributeError if a field is missing from the XML — a missing
        # field indicates a malformed export, so failing loudly is intended.
        metadata={field: metadata.find(field).text for field in METADATA_FIELDS},
    ))
# Split documents
#
# Alternative approach: instead of relying on the pre-split <entry> elements in
# the XML export, re-split freshly loaded documents with a character splitter:
#
# from langchain.text_splitter import RecursiveCharacterTextSplitter
#
# text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=0)
# splits = text_splitter.split_documents(loader.load())
#
# print("splits length: {}".format(len(splits)))
# print(splits)
# Embed and store splits
from langchain.vectorstores import Chroma
from langchain.embeddings import OpenAIEmbeddings

# One-shot bulk variant (kept for reference):
# vectorstore = Chroma.from_documents(persist_directory="./chroma_db", documents=splits, embedding=OpenAIEmbeddings())
vectorstore = Chroma(persist_directory="./chroma_db", embedding_function=OpenAIEmbeddings())

import time

# Best-effort insertion: documents are added one by one so a single failing
# embedding call is logged and skipped rather than aborting the whole build.
for document in splits:
    print(document.metadata)
    try:
        vectorstore.add_documents([document])
    except Exception as error:
        print("ERR : {}".format(error))
    # Brief pause between calls — NOTE(review): presumably rate-limit relief.
    time.sleep(0.001)

retriever = vectorstore.as_retriever()
print("Got embeddings")
# Prompt
# https://smith.langchain.com/hub/rlm/rag-prompt
from langchain import hub

rag_prompt = hub.pull("rlm/rag-prompt")
print("Got rag")

# LLM
from langchain.chat_models import ChatOpenAI

llm = ChatOpenAI(model_name="gpt-3.5-turbo", temperature=0)
print("Got llm")

# RAG chain: the retriever fills in "context" while the user's question is
# forwarded untouched into the prompt, which then feeds the LLM.
from langchain.schema.runnable import RunnablePassthrough

chain_inputs = {"context": retriever, "question": RunnablePassthrough()}
rag_chain = chain_inputs | rag_prompt | llm

answer = rag_chain.invoke("Quel article de loi concerne les salariés étrangers?")
print(answer)
# AIMessage(content='Task decomposition is the process of breaking down a task into smaller subgoals or steps. It can be done using simple prompting, task-specific instructions, or human inputs.')