70 lines
2.0 KiB
Python
70 lines
2.0 KiB
Python
import random
|
|
import gradio as gr
|
|
|
|
__import__('pysqlite3')
|
|
import sys
|
|
import dotenv
|
|
|
|
from langchain.vectorstores import Chroma
|
|
from langchain.embeddings import OpenAIEmbeddings
|
|
from langchain.prompts import ChatPromptTemplate, SemanticSimilarityExampleSelector
|
|
from langchain.chat_models import ChatOpenAI
|
|
from langchain.schema.runnable import RunnablePassthrough
|
|
|
|
# Populate os.environ from a local .env file; presumably this supplies the
# OpenAI credentials used by the embeddings and chat model below (confirm).
dotenv.load_dotenv()

# Register the already-imported pysqlite3 package under the stdlib name
# `sqlite3`, so any `import sqlite3` executed after this point (presumably
# chromadb's, triggered when Chroma(...) is constructed) receives pysqlite3.
# NOTE(review): modules that imported sqlite3 before this line keep the
# original module; doing the swap right after `__import__('pysqlite3')`
# would be more robust.
_pysqlite3 = sys.modules.pop("pysqlite3")
sys.modules["sqlite3"] = _pysqlite3
|
|
# On-disk Chroma persistence directories (relative to the current working
# directory) and the number of few-shot examples selected per query.
DOCS_PERSIST_DIRECTORY = "./chroma_db.1.0"
EXAMPLES_PERSIST_DIRECTORY = "./chroma_db_examples.1.0"
NUM_FEW_SHOT_EXAMPLES = 2

# A single embeddings client shared by both stores, so the document store and
# the example store embed text with the same configuration.
embeddings = OpenAIEmbeddings()

# Document store; its retriever provides the {context} for each question.
vectorstore = Chroma(
    persist_directory=DOCS_PERSIST_DIRECTORY,
    embedding_function=embeddings,
)
retriever = vectorstore.as_retriever()

# Store of few-shot examples; the selector returns the NUM_FEW_SHOT_EXAMPLES
# entries most similar to the prompt input.
examples_vectorstore = Chroma(
    persist_directory=EXAMPLES_PERSIST_DIRECTORY,
    embedding_function=embeddings,
)
example_selector = SemanticSimilarityExampleSelector(
    vectorstore=examples_vectorstore,
    k=NUM_FEW_SHOT_EXAMPLES,
)
|
|
|
|
from langchain.prompts import (
|
|
ChatPromptTemplate,
|
|
FewShotChatMessagePromptTemplate,
|
|
)
|
|
|
|
# How one selected example is rendered: a human turn carrying the example's
# "input" followed by an AI turn carrying its "output".
_example_prompt = ChatPromptTemplate.from_messages(
    [
        ("human", "{input}"),
        ("ai", "{output}"),
    ]
)

# Few-shot section of the final prompt: examples are chosen by
# `example_selector` based on the "input" variable, then rendered with
# `_example_prompt`.
few_shot_prompt = FewShotChatMessagePromptTemplate(
    example_selector=example_selector,
    example_prompt=_example_prompt,
    input_variables=["input"],
)
|
|
|
|
# System instructions sent with every request. This is runtime text: keep it
# verbatim. NOTE(review): "promting" is probably meant to be "promoting".
_SYSTEM_MESSAGE = (
    "You are a lawyer specialized in French Labour laws and promting Unions. "
    "You want to help but not misguide the user. "
    "If you don't know the answer, just say that you don't know. "
    "Don't hesitate to reformulate and think about the issues. "
    "Cite law articles as much as possible."
)

# Final user turn: the question, then the retrieved context, then an
# "Answer:" cue for the model.
_HUMAN_TEMPLATE = "\nQuestion: {question}\n\nContext: {context}\n\nAnswer:\n"

# Full chat prompt: system message, few-shot examples, then the user turn.
final_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", _SYSTEM_MESSAGE),
        few_shot_prompt,
        ("human", _HUMAN_TEMPLATE),
    ]
)
|
|
|
|
# Chat model that generates the final answer.
llm = ChatOpenAI(model_name="gpt-3.5-turbo", temperature=0.33)


def _format_docs(docs):
    """Join the text of retrieved documents into one context string.

    Args:
        docs: Iterable of objects exposing a ``page_content`` string (the
            documents returned by the retriever).

    Returns:
        The page contents separated by blank lines. Without this step the
        prompt's {context} would receive ``str()`` of the document list,
        i.e. ``Document(page_content=..., metadata=...)`` reprs.
    """
    return "\n\n".join(doc.page_content for doc in docs)


# Retrieval-augmented chain. The raw user message is:
#   - sent to the retriever, and the hits are formatted into {context};
#   - passed through unchanged as {question} (the user turn);
#   - passed through unchanged as {input} (drives few-shot example selection).
rag_chain = (
    {
        "context": retriever | _format_docs,
        "question": RunnablePassthrough(),
        "input": RunnablePassthrough(),
    }
    | final_prompt
    | llm
)
|
|
|
|
def random_response(message, history):
    """Answer one chat message using the RAG chain.

    Used as the handler of ``gr.ChatInterface`` (see ``demo``).

    Args:
        message: The user's latest message; it is the sole input to
            ``rag_chain``.
        history: Prior conversation supplied by the chat UI. Accepted for
            signature compatibility but not used, so answers do not depend
            on earlier turns.

    Returns:
        The ``content`` of the model's reply.

    NOTE(review): despite the name, nothing here is random; the name is kept
    because ``demo`` refers to it.
    """
    return rag_chain.invoke(message).content
|
|
|
|
|
|
# Gradio chat UI whose handler is `random_response`.
demo = gr.ChatInterface(fn=random_response)

if __name__ == "__main__":
    # Start the web server only when executed as a script, not on import.
    demo.launch()
|