code-travail/web-ui.py
2024-01-06 18:25:42 +01:00

109 lines
3.5 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from __future__ import annotations

import copy
import random
import sys
from typing import Iterable

# pysqlite3 is imported eagerly so it can later be swapped in for sqlite3
# through sys.modules.
__import__('pysqlite3')

import dotenv
import gradio as gr
from gradio.themes.base import Base
from gradio.themes.utils import colors, fonts, sizes
from langchain.chat_models import ChatOpenAI
from langchain.embeddings import OpenAIEmbeddings
from langchain.prompts import (
    ChatPromptTemplate,
    FewShotChatMessagePromptTemplate,
    SemanticSimilarityExampleSelector,
)
from langchain.schema.runnable import RunnablePassthrough
from langchain.vectorstores import Chroma
# Load settings such as the OpenAI API key from a local .env file.
dotenv.load_dotenv()
# Make every later `import sqlite3` resolve to pysqlite3.
# NOTE(review): presumably needed because chromadb requires a newer SQLite
# than the system build; keep this before the first Chroma(...) below.
sys.modules['sqlite3'] = sys.modules.pop('pysqlite3')

# Single embeddings client shared by both vector stores below.
embeddings = OpenAIEmbeddings()

# Persistent store of reference documents; its retriever fills {context}.
vectorstore = Chroma(persist_directory="./chroma_db.1.0", embedding_function=embeddings)
retriever = vectorstore.as_retriever()

# Persistent store of few-shot examples (records with "input"/"output"
# fields, matching the example template below).
examples_vectorstore = Chroma(persist_directory="./chroma_db_examples.1.0", embedding_function=embeddings)
# Chooses the 2 stored examples most similar to the incoming input.
example_selector = SemanticSimilarityExampleSelector(
    vectorstore=examples_vectorstore,
    k=2,
)

# Renders the selected examples as alternating human / AI messages.
few_shot_prompt = FewShotChatMessagePromptTemplate(
    input_variables=["input"],
    example_selector=example_selector,
    example_prompt=ChatPromptTemplate.from_messages(
        [("human", "{input}"), ("ai", "{output}")]
    ),
)

# System persona, then the few-shot examples, then the actual question
# together with the retrieved context.
final_prompt = ChatPromptTemplate.from_messages(
    [
        ("system",
         "You are a lawyer specialized in French Labour laws that promotes Unions and regulation authorities. You want to help but not misguide the user. If you don't know the answer, just say that you don't know. Don't hesitate to reformulate and think about the issues. Cite law articles as much as possible."),
        few_shot_prompt,
        ("human", """
Question: {question}
Context: {context}
Answer:
"""),
    ]
)

llm = ChatOpenAI(model_name="gpt-3.5-turbo", temperature=0.33)

# The chain is invoked with a single string. That string is sent to the
# retriever to fill {context}, and passed through unchanged as both
# {question} and {input} (the latter drives few-shot example selection).
# NOTE(review): the retriever's result is interpolated into {context} as-is,
# i.e. via its default string form -- confirm that is the intended format.
rag_chain = (
    {"context": retriever, "question": RunnablePassthrough(), "input": RunnablePassthrough()}
    | final_prompt
    | llm
)
def merge_history(message: str, reversed_history: list, *, max_words: int = 1499) -> str:
    """Build the text sent to the RAG chain from a question and prior turns.

    Wraps *message* in ``<question>`` tags and appends earlier user messages
    inside ``<history>`` tags, in the order given (newest first when called
    from ``llm_response``). Appending stops once the text reaches *max_words*
    whitespace-separated words. The closing ``</history>`` tag is added on
    top of that budget.

    Args:
        message: The user's current question.
        reversed_history: ``(user message, response)`` pairs, most recent
            first. Only the first element of each pair is used; responses
            are ignored. Empty or ``None`` user messages are skipped.
        max_words: Word budget for the question prefix plus history.

    Returns:
        *message* unchanged when there is no history, otherwise the merged
        text with words joined by single spaces.
    """
    if not reversed_history:
        return message

    merged_words = (
        f"Ma question est: <question>{message}</question>. "
        "L'historique de la discussion est <history>"
    ).split()
    for user_msg, _response in reversed_history:
        remaining = max_words - len(merged_words)
        if remaining <= 0:
            break
        if not user_msg:
            # Nothing to add; also avoids calling .split() on None.
            continue
        # split() with no argument splits on any whitespace run, so repeated
        # spaces or newlines do not produce empty "words" that eat the budget.
        merged_words.extend(user_msg.split()[:remaining])
    merged_words.append("</history>")
    return " ".join(merged_words)
def llm_response(message, history):
    """Answer *message* with the RAG chain, taking prior chat turns into account.

    Args:
        message: The user's latest message from the chat UI.
        history: Previous turns supplied by ``gr.ChatInterface``; presumably
            oldest first, since they are reversed before merging.

    Returns:
        The text content of the model's reply.
    """
    # merge_history only reads the list, so a reversed shallow copy is enough.
    # The caller's history object is left untouched.
    newest_first = list(reversed(history))
    merged_message = merge_history(message, newest_first)
    res = rag_chain.invoke(merged_message)
    return res.content
# Chat UI: each submitted message (plus the chat history) goes to llm_response.
demo = gr.ChatInterface(
    llm_response,
    examples=[
        # Sample question: ending a fixed-term contract during the trial period.
        "J'ai été embauchée en Contrat à Durée Déterminée d'une durée de 3 mois pour un poste de vendeuse dans un supermarché. J'ai déjà fait 2 semaines mais le poste ne me convient pas. Comment rompre ma période d'essai?",
        # Sample question: public-sector employer changing agreed leave dates.
        "Je suis dans la fonction publique territoriale, plus précisément agent de maîtrise pour une commune. Est-ce que mon employeur peut modifier mes dates de départ en congé alors qu'elles étaient définies depuis plus d'un mois? ",
    ],
    theme=gr.themes.Soft(),
    analytics_enabled=False,
)

if __name__ == "__main__":
    # show_error=True makes Gradio display errors raised by the handler in the UI.
    demo.launch(
        show_error=True,
    )