Spaces:
Sleeping
Sleeping
| # CORRECTION KERAS - À EXÉCUTER EN PREMIER | |
| import os | |
| os.environ['TF_USE_LEGACY_KERAS'] = '1' | |
| # Imports standards | |
| import re | |
| # Pour importer les fichiers PDF | |
| import requests | |
| import PyPDF2 | |
| # Traitement du texte | |
| import nltk | |
| # Téléchargement silencieux des données NLTK (seulement si nécessaire) | |
| try: | |
| nltk.data.find('tokenizers/punkt') | |
| except LookupError: | |
| nltk.download('punkt', quiet=True) | |
| try: | |
| nltk.data.find('tokenizers/punkt_tab') | |
| except LookupError: | |
| nltk.download('punkt_tab', quiet=True) | |
| from langchain_text_splitters import RecursiveCharacterTextSplitter | |
| # Modèle Reranker | |
| from FlagEmbedding import FlagReranker | |
| # Objet text retriever | |
| from langchain_chroma import Chroma | |
| from langchain_community.embeddings.sentence_transformer import SentenceTransformerEmbeddings | |
| # LLM via HuggingFace Inference API (au lieu de llama-cpp) | |
| from huggingface_hub import InferenceClient | |
| import gradio as gr | |
| """# GESTION DE LA BASE DE DONNÉES - VARIABLES GLOBALES""" | |
| # Variables globales pour lazy loading | |
| retriever = None | |
| is_initialized = False | |
| """## Etape 3 : Traitement du texte en chunks propres""" | |
| #### FONCTIONS #### | |
| # Segmentation du texte de base | |
| def splitting_by_numer_of_words(text, chunk_size): | |
| """ | |
| Découpe un texte en chunks de taille donnée (nombre de caractères). | |
| Args: | |
| text (str): Le texte à splitter. | |
| chunk_size (int): La taille souhaitée des chunks (nombre de mots). | |
| Returns: | |
| list: Une liste de chunks de texte. | |
| """ | |
| chunks = [] | |
| for phrase in text.split('\n'): | |
| words = phrase.split() | |
| for i in range(0, len(words), chunk_size): | |
| chunks.append(' '.join(words[i:i + chunk_size])) | |
| return chunks | |
| # Fonction de splitting par phrase | |
| def splitting_by_sentences(text): | |
| """ | |
| Découpe un texte en chunks par phrases. | |
| Args: | |
| text (str): Le texte à découper. | |
| Returns: | |
| list: Une liste de chunks de texte (phrases). | |
| """ | |
| sentences = [] | |
| list_paragraph = text.split("\n") | |
| for paragraph in list_paragraph: | |
| list_sent = paragraph.split(".") | |
| sentences = sentences + list_sent | |
| return sentences | |
| # Nettoyage du contenu de chaque chunk | |
| special_chars = [" ", '-', '&', '(', ')', '_', ';', '†', '+', '–', "'", '!', '[', ']', "'", '́', '̀', '\u2009', '\u200b', '\u202f', '©', '£', '§', '°', '@', '€', '$', '\xa0', '~','\n','�'] | |
| def remove_char(text, char): | |
| """Remove each specific character from the text for each character in the chars list.""" | |
| return text.replace(char, ' ') | |
| def remove_chars(text, chars): | |
| """ Apply remove_char() function to text """ | |
| for char in chars: | |
| text = remove_char(text, char) | |
| return text | |
| def remove_multiple_white_spaces(text): | |
| """Remove multiple spaces.""" | |
| text = re.sub(" +", " ", text) | |
| return text | |
| def clean_text(text, special_chars=special_chars): | |
| """Generate a text without chars expect points and comma and multiple white spaces.""" | |
| text = remove_chars(text, special_chars) | |
| text = remove_multiple_white_spaces(text) | |
| return text | |
| # Filtrage des mots vides | |
| def contains_mainly_digits(text, threshold=0.5): | |
| """ | |
| Checks if a text string contains a high percentage of digits compared to letters. | |
| Args: | |
| text (str): The input text to analyze. | |
| threshold (float, optional): The threshold value for the proportion of digits to letters. | |
| Defaults to 0.5. | |
| Returns: | |
| bool: True if the proportion of digits in the text exceeds the threshold, False otherwise. | |
| """ | |
| if not text: | |
| return False | |
| letters_count = 0 | |
| nbs_count = 0 | |
| for char in text: | |
| if char.isalpha(): | |
| letters_count += 1 | |
| elif char.isdigit(): | |
| nbs_count += 1 | |
| if letters_count + nbs_count > 0: | |
| digits_pct = (nbs_count / (letters_count + nbs_count)) | |
| else: | |
| return True | |
| return digits_pct > threshold | |
| def remove_mostly_digits_chunks(chunks, threshold=0.5): | |
| return [chunk for chunk in chunks if not contains_mainly_digits(chunk['content'])] | |
| """# IMPLEMENTATION DU MODELE DE RECHERCHE RETENU""" | |
| class TextRetriever: | |
| def __init__(self, embedding_model_name="mixedbread-ai/mxbai-embed-large-v1", reranking_model_name="BAAI/bge-reranker-large"): | |
| """ | |
| Initialise les modèles d'embedding et de reranking. | |
| Args: | |
| embedding_model_name (str): Nom du modèle d'embedding. | |
| reranking_model_name (str): Nom du modèle de reranking. | |
| """ | |
| print(f"Loading embedding model: {embedding_model_name}") | |
| self.embedding_model = SentenceTransformerEmbeddings(model_name=embedding_model_name) | |
| print(f"Loading reranker model: {reranking_model_name}") | |
| self.reranker_model = FlagReranker(reranking_model_name, use_fp16=True) | |
| self.vector_database = None # Initialisation de la base de données vectorielle à None | |
| def store_embeddings(self, chunks, path="./chroma_db"): | |
| """ | |
| Stocke les embeddings des chunks de texte dans une base de données vectorielle. | |
| Args: | |
| chunks (list of str): Liste de chunks de texte à stocker. | |
| path (str): Chemin du répertoire où la base de données sera stockée. | |
| """ | |
| print(f"Storing embeddings to {path}...") | |
| self.vector_database = Chroma.from_texts(chunks, embedding=self.embedding_model, persist_directory=path) | |
| print("Embeddings stored successfully") | |
| def load_embeddings(self, path): | |
| """ | |
| Charge les embeddings depuis une base de données vectorielle. | |
| Args: | |
| path (str): Chemin du répertoire de la base de données. | |
| """ | |
| print(f"Loading embeddings from {path}...") | |
| self.vector_database = Chroma(persist_directory=path, embedding_function=self.embedding_model) | |
| print("Embeddings loaded successfully") | |
| def get_best_chunks(self, query, top_k=3): | |
| """ | |
| Recherche les meilleurs chunks correspondant à une requête. | |
| Args: | |
| query (str): Requête de recherche. | |
| top_k (int): Nombre de meilleurs chunks à retourner. | |
| Returns: | |
| list: Liste des meilleurs chunks correspondant à la requête. | |
| """ | |
| best_chunks = self.vector_database.similarity_search(query, k=top_k) | |
| return best_chunks | |
| def rerank_chunks(self, query, chunks): | |
| """ | |
| Retourne le chunk le plus pertinent pour une requête donnée. | |
| Args: | |
| query (str): Requête de recherche. | |
| chunks (list): Liste des chunks à re-classer. | |
| Returns: | |
| list: Liste des chunks triés par pertinence. | |
| """ | |
| best_chunks = self.get_best_chunks(query, top_k=10) | |
| rerank_scores = [] | |
| chunk_texts = [chunk.page_content if hasattr(chunk, 'page_content') else str(chunk) for chunk in best_chunks] | |
| for text in chunk_texts: | |
| score = self.reranker_model.compute_score([query, text]) | |
| rerank_scores.append(score) | |
| return [x for _, x in sorted(zip(rerank_scores, best_chunks), reverse=True)] | |
| def get_context(self, query): | |
| """ | |
| Retourne le chunk le plus pertinent pour une requête donnée. | |
| Args: | |
| query (str): Requête de recherche. | |
| Returns: | |
| str: Contenu du chunk le plus pertinent. | |
| """ | |
| best_chunks = self.get_best_chunks(query, top_k=1) | |
| return best_chunks[0].page_content | |
| """# FONCTION D'INITIALISATION LAZY""" | |
| def initialize_system(): | |
| """ | |
| Initialise le système RAG de manière lazy (seulement au premier appel). | |
| Télécharge les PDFs, extrait le texte, crée les chunks et les embeddings. | |
| """ | |
| global retriever, is_initialized | |
| if is_initialized: | |
| return "Système déjà initialisé" | |
| try: | |
| print("=" * 50) | |
| print("INITIALISATION DU SYSTÈME RAG") | |
| print("=" * 50) | |
| # Etape 1: Téléchargement des PDFs | |
| chemin_dossier = "./RAG_IPCC" | |
| if not os.path.exists(chemin_dossier): | |
| os.makedirs(chemin_dossier) | |
| urls = { "6th_report": "https://www.ipcc.ch/report/ar6/syr/downloads/report/IPCC_AR6_SYR_FullVolume.pdf" } | |
| for name, url in urls.items(): | |
| file_path = os.path.join(chemin_dossier, f"{name}.pdf") | |
| if not os.path.exists(file_path): | |
| print(f"📥 Téléchargement de {name}...") | |
| response = requests.get(url) | |
| with open(file_path, 'wb') as file: | |
| file.write(response.content) | |
| print(f"✅ {name} téléchargé") | |
| else: | |
| print(f"✅ {name} existe déjà") | |
| # Etape 2: Extraction du texte | |
| print("\n📄 Extraction du texte des PDFs...") | |
| fichiers_pdf = [f for f in os.listdir(chemin_dossier) if f.endswith('.pdf')] | |
| extracted_text = [] | |
| for pdf in fichiers_pdf: | |
| chemin_pdf = os.path.join(chemin_dossier, pdf) | |
| with open(chemin_pdf, 'rb') as file: | |
| pdf_reader = PyPDF2.PdfReader(file) | |
| for page_num in range(len(pdf_reader.pages)): | |
| page = pdf_reader.pages[page_num] | |
| text = page.extract_text() | |
| extracted_text.append({"document": pdf, "page": page_num, "content": text}) | |
| print(f"✅ {len(extracted_text)} pages extraites") | |
| # Etape 3: Création des chunks | |
| print("\n✂️ Création des chunks de texte...") | |
| text_splitter = RecursiveCharacterTextSplitter( | |
| chunk_size=500, | |
| chunk_overlap=20, | |
| length_function=len, | |
| is_separator_regex=False, | |
| ) | |
| chunks = [] | |
| for page_content in extracted_text: | |
| chunks_list = text_splitter.split_text(page_content['content']) | |
| for chunk in chunks_list: | |
| text = clean_text(chunk) | |
| chunks.append({"document": page_content['document'], | |
| "page": page_content['page'], | |
| "content": text}) | |
| chunks = remove_mostly_digits_chunks(chunks) | |
| print(f"✅ {len(chunks)} chunks créés") | |
| # Etape 4: Initialisation du retriever et des embeddings | |
| print("\n🤖 Initialisation du TextRetriever...") | |
| retriever = TextRetriever() | |
| all_chunks = [chunk['content'] for chunk in chunks] | |
| # Vérifier si la base de données existe déjà | |
| db_path = "./chroma_db" | |
| if os.path.exists(db_path): | |
| print("📂 Chargement de la base de données existante...") | |
| retriever.load_embeddings(db_path) | |
| else: | |
| print("🔨 Création de la base de données d'embeddings...") | |
| retriever.store_embeddings(all_chunks, db_path) | |
| is_initialized = True | |
| print("\n" + "=" * 50) | |
| print("✅ SYSTÈME INITIALISÉ AVEC SUCCÈS") | |
| print("=" * 50) | |
| return "✅ Système initialisé avec succès !" | |
| except Exception as e: | |
| print(f"❌ Erreur lors de l'initialisation: {str(e)}") | |
| return f"❌ Erreur: {str(e)}" | |
| """# MODELE LLM | |
| ## Etape 1 : Generation d'une réponse avec HuggingFace Inference API | |
| """ | |
| # 🔒 Récupérer le token HF depuis les variables d'environnement (Repository Secrets) | |
| HF_API_KEY = os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACEHUB_API_TOKEN") or os.environ.get("HUGGING_FACE_HUB_TOKEN") | |
| # Configuration de l'API HuggingFace Inference | |
| # Utiliser un modèle plus petit et compatible avec le tier gratuit | |
| MODEL_NAME = "mistralai/Mistral-7B-Instruct-v0.2" | |
| API_URL = f"/static-proxy?url=https%3A%2F%2Fapi-inference.huggingface.co%2Fmodels%2F%3Cspan class="hljs-subst">{MODEL_NAME}" | |
| headers = {"Authorization": f"Bearer {HF_API_KEY}"} if HF_API_KEY else {} | |
| # Initialiser le client d'inférence HuggingFace | |
| llm_client = InferenceClient(token=HF_API_KEY) if HF_API_KEY else InferenceClient() | |
| ## FONCTIONS | |
| # Basic context function. | |
| def get_context_from_query(query): | |
| chunks = retriever.get_best_chunks(query, 4) | |
| # Extraire le texte des chunks | |
| context_parts = [] | |
| for chunk in chunks: | |
| if hasattr(chunk, 'page_content'): | |
| context_parts.append(chunk.page_content) | |
| else: | |
| context_parts.append(str(chunk)) | |
| return "\n\n".join(context_parts) | |
| """## Etape 2 : Sauvegarde d'un historique limité de conversation""" | |
| class ConversationHistoryLoader: | |
| def __init__(self, k): | |
| self.k=k | |
| self.conversation_history = [] | |
| # Fonction qui permet créer un prompt (string) sur l'historique de conversation. | |
| def create_conversation_history_prompt(self): | |
| conversation = '' | |
| if self.conversation_history == None or len(self.conversation_history) == 0: | |
| return conversation | |
| else: | |
| for exchange in reversed(self.conversation_history): | |
| conversation = conversation + '\nHuman: '+exchange['Human']+'\nAI: '+exchange['AI'] | |
| return conversation | |
| # Fonction qui permet de mettre à jour l'historique de conversation | |
| # à partir de la dernière query et la dernière réponse du LLM. | |
| def update_conversation_history(self, query, response): | |
| exchange = {'Human': query, 'AI': response} | |
| self.conversation_history.insert(0, exchange) | |
| if len(self.conversation_history) > self.k: | |
| self.conversation_history.pop() | |
| # Fonction pour générer une réponse avec le contexte et l'historique | |
| def generate_response_with_context(instruction, context, chat_history=""): | |
| """ | |
| Génère une réponse en utilisant l'API HuggingFace Inference. | |
| """ | |
| # Construire le prompt complet | |
| prompt = f"""Below is an instruction that describes a task. Write a response that appropriately completes the request using the context provided and the previous conversation. | |
| Context: {context} | |
| {chat_history} | |
| Human: [INST] {instruction} [/INST] | |
| AI: | |
| """ | |
| try: | |
| # Construire les messages pour le chat | |
| system_message = f"""Tu es un assistant expert sur le changement climatique. Réponds aux questions en français en utilisant le contexte fourni des rapports IPCC. | |
| Contexte: {context}""" | |
| messages = [ | |
| {"role": "system", "content": system_message} | |
| ] | |
| # Ajouter l'historique si présent | |
| if chat_history: | |
| messages.append({"role": "assistant", "content": f"Historique:\n{chat_history}"}) | |
| # Ajouter la question | |
| messages.append({"role": "user", "content": instruction}) | |
| # Appeler l'API HuggingFace pour générer la réponse | |
| # Utilisation de Mistral avec chat_completion | |
| response = llm_client.chat_completion( | |
| messages=messages, | |
| model=MODEL_NAME, | |
| max_tokens=300, | |
| temperature=0.7, | |
| top_p=0.95 | |
| ) | |
| # Extraire le contenu de la réponse | |
| answer = response.choices[0].message.content | |
| # Nettoyer la réponse | |
| answer = answer.strip() | |
| answer = re.sub(r"\[context\..*?\]", "", answer) | |
| answer = re.sub(r"Al:\s*", "", answer) | |
| answer = re.sub(r"AI:\s*", "", answer) | |
| return answer | |
| except Exception as e: | |
| print(f"Erreur lors de la génération: {str(e)}") | |
| import traceback | |
| traceback.print_exc() | |
| error_msg = str(e) | |
| # Messages d'aide selon le type d'erreur | |
| if "rate limit" in error_msg.lower(): | |
| return f"⏱️ Rate limit atteint. Veuillez réessayer dans quelques instants.\n\nDétails: {error_msg}" | |
| elif "loading" in error_msg.lower() or "is currently loading" in error_msg.lower(): | |
| return f"⏳ Le modèle est en cours de chargement. Veuillez patienter 20-30 secondes et réessayer.\n\nDétails: {error_msg}" | |
| elif "authorization" in error_msg.lower() or "token" in error_msg.lower(): | |
| return f"🔒 Problème d'authentification.\n\nDétails: {error_msg}\n\n⚠️ Vérifiez que le token HF_TOKEN dans Settings a les permissions 'read' ou 'inference'." | |
| else: | |
| return f"❌ Erreur: {error_msg}\n\nConsultez les logs de votre Space pour plus de détails." | |
| # Créer l instance de gestion d historique | |
| ch = ConversationHistoryLoader(k=3) | |
| # Fonction principale pour répondre aux questions | |
| def get_response(query): | |
| global retriever, is_initialized | |
| try: | |
| # Initialiser le système au premier appel | |
| if not is_initialized: | |
| init_message = initialize_system() | |
| if "❌" in init_message: | |
| return init_message | |
| # Vérifier que le retriever est bien initialisé | |
| if retriever is None: | |
| return "❌ Le système n'est pas correctement initialisé. Veuillez réessayer." | |
| # Obtenir le contexte pertinent | |
| context = get_context_from_query(query) | |
| # Générer la réponse avec contexte et historique | |
| chat_history = ch.create_conversation_history_prompt() | |
| response = generate_response_with_context(query, context, chat_history) | |
| # Mettre à jour l historique | |
| ch.update_conversation_history(query, response) | |
| return response | |
| except Exception as e: | |
| import traceback | |
| error_details = traceback.format_exc() | |
| print(f"Erreur détaillée: {error_details}") | |
| return f"Erreur: {str(e)}" | |
| # Interface Gradio | |
| print("Creating Gradio interface...") | |
| iface = gr.Interface( | |
| fn=get_response, | |
| inputs=gr.Textbox(lines=2, placeholder="Posez votre question sur le climat..."), | |
| outputs=gr.Textbox(lines=5, label="Réponse"), | |
| title="🌍 RAG Chatbot - Questions Climatiques", | |
| description="""Posez vos questions sur le changement climatique basées sur les rapports IPCC. | |
| ⚠️ **Note**: Le système s'initialise automatiquement au premier appel (téléchargement du PDF + création des embeddings). | |
| La première requête peut prendre 2-3 minutes. Les requêtes suivantes seront rapides !""", | |
| examples=[ | |
| "Quels sont les principaux impacts du réchauffement climatique ?", | |
| "Comment les océans sont-ils affectés par le changement climatique ?", | |
| "Quelles sont les solutions pour réduire les émissions ?" | |
| ], | |
| cache_examples=False # Désactive le cache pour éviter l'initialisation au démarrage | |
| ) | |
| # Lancer l application | |
| if __name__ == "__main__": | |
| print("Launching Gradio app...") | |
| iface.launch(server_name="0.0.0.0", server_port=7860) | |