# CORRECTION KERAS - À EXÉCUTER EN PREMIER import os os.environ['TF_USE_LEGACY_KERAS'] = '1' # Imports standards import re # Pour importer les fichiers PDF import requests import PyPDF2 # Traitement du texte import nltk # Téléchargement silencieux des données NLTK (seulement si nécessaire) try: nltk.data.find('tokenizers/punkt') except LookupError: nltk.download('punkt', quiet=True) try: nltk.data.find('tokenizers/punkt_tab') except LookupError: nltk.download('punkt_tab', quiet=True) from langchain_text_splitters import RecursiveCharacterTextSplitter # Modèle Reranker from FlagEmbedding import FlagReranker # Objet text retriever from langchain_chroma import Chroma from langchain_community.embeddings.sentence_transformer import SentenceTransformerEmbeddings # LLM via HuggingFace Inference API (au lieu de llama-cpp) from huggingface_hub import InferenceClient import gradio as gr """# GESTION DE LA BASE DE DONNÉES - VARIABLES GLOBALES""" # Variables globales pour lazy loading retriever = None is_initialized = False """## Etape 3 : Traitement du texte en chunks propres""" #### FONCTIONS #### # Segmentation du texte de base def splitting_by_numer_of_words(text, chunk_size): """ Découpe un texte en chunks de taille donnée (nombre de caractères). Args: text (str): Le texte à splitter. chunk_size (int): La taille souhaitée des chunks (nombre de mots). Returns: list: Une liste de chunks de texte. """ chunks = [] for phrase in text.split('\n'): words = phrase.split() for i in range(0, len(words), chunk_size): chunks.append(' '.join(words[i:i + chunk_size])) return chunks # Fonction de splitting par phrase def splitting_by_sentences(text): """ Découpe un texte en chunks par phrases. Args: text (str): Le texte à découper. Returns: list: Une liste de chunks de texte (phrases). """ sentences = [] list_paragraph = text.split("\n") for paragraph in list_paragraph: list_sent = paragraph.split(".") sentences = sentences + list_sent return sentences # Nettoyage du contenu de chaque chunk special_chars = [" ", '-', '&', '(', ')', '_', ';', '†', '+', '–', "'", '!', '[', ']', "'", '́', '̀', '\u2009', '\u200b', '\u202f', '©', '£', '§', '°', '@', '€', '$', '\xa0', '~','\n','�'] def remove_char(text, char): """Remove each specific character from the text for each character in the chars list.""" return text.replace(char, ' ') def remove_chars(text, chars): """ Apply remove_char() function to text """ for char in chars: text = remove_char(text, char) return text def remove_multiple_white_spaces(text): """Remove multiple spaces.""" text = re.sub(" +", " ", text) return text def clean_text(text, special_chars=special_chars): """Generate a text without chars expect points and comma and multiple white spaces.""" text = remove_chars(text, special_chars) text = remove_multiple_white_spaces(text) return text # Filtrage des mots vides def contains_mainly_digits(text, threshold=0.5): """ Checks if a text string contains a high percentage of digits compared to letters. Args: text (str): The input text to analyze. threshold (float, optional): The threshold value for the proportion of digits to letters. Defaults to 0.5. Returns: bool: True if the proportion of digits in the text exceeds the threshold, False otherwise. """ if not text: return False letters_count = 0 nbs_count = 0 for char in text: if char.isalpha(): letters_count += 1 elif char.isdigit(): nbs_count += 1 if letters_count + nbs_count > 0: digits_pct = (nbs_count / (letters_count + nbs_count)) else: return True return digits_pct > threshold def remove_mostly_digits_chunks(chunks, threshold=0.5): return [chunk for chunk in chunks if not contains_mainly_digits(chunk['content'])] """# IMPLEMENTATION DU MODELE DE RECHERCHE RETENU""" class TextRetriever: def __init__(self, embedding_model_name="mixedbread-ai/mxbai-embed-large-v1", reranking_model_name="BAAI/bge-reranker-large"): """ Initialise les modèles d'embedding et de reranking. Args: embedding_model_name (str): Nom du modèle d'embedding. reranking_model_name (str): Nom du modèle de reranking. """ print(f"Loading embedding model: {embedding_model_name}") self.embedding_model = SentenceTransformerEmbeddings(model_name=embedding_model_name) print(f"Loading reranker model: {reranking_model_name}") self.reranker_model = FlagReranker(reranking_model_name, use_fp16=True) self.vector_database = None # Initialisation de la base de données vectorielle à None def store_embeddings(self, chunks, path="./chroma_db"): """ Stocke les embeddings des chunks de texte dans une base de données vectorielle. Args: chunks (list of str): Liste de chunks de texte à stocker. path (str): Chemin du répertoire où la base de données sera stockée. """ print(f"Storing embeddings to {path}...") self.vector_database = Chroma.from_texts(chunks, embedding=self.embedding_model, persist_directory=path) print("Embeddings stored successfully") def load_embeddings(self, path): """ Charge les embeddings depuis une base de données vectorielle. Args: path (str): Chemin du répertoire de la base de données. """ print(f"Loading embeddings from {path}...") self.vector_database = Chroma(persist_directory=path, embedding_function=self.embedding_model) print("Embeddings loaded successfully") def get_best_chunks(self, query, top_k=3): """ Recherche les meilleurs chunks correspondant à une requête. Args: query (str): Requête de recherche. top_k (int): Nombre de meilleurs chunks à retourner. Returns: list: Liste des meilleurs chunks correspondant à la requête. """ best_chunks = self.vector_database.similarity_search(query, k=top_k) return best_chunks def rerank_chunks(self, query, chunks): """ Retourne le chunk le plus pertinent pour une requête donnée. Args: query (str): Requête de recherche. chunks (list): Liste des chunks à re-classer. Returns: list: Liste des chunks triés par pertinence. """ best_chunks = self.get_best_chunks(query, top_k=10) rerank_scores = [] chunk_texts = [chunk.page_content if hasattr(chunk, 'page_content') else str(chunk) for chunk in best_chunks] for text in chunk_texts: score = self.reranker_model.compute_score([query, text]) rerank_scores.append(score) return [x for _, x in sorted(zip(rerank_scores, best_chunks), reverse=True)] def get_context(self, query): """ Retourne le chunk le plus pertinent pour une requête donnée. Args: query (str): Requête de recherche. Returns: str: Contenu du chunk le plus pertinent. """ best_chunks = self.get_best_chunks(query, top_k=1) return best_chunks[0].page_content """# FONCTION D'INITIALISATION LAZY""" def initialize_system(): """ Initialise le système RAG de manière lazy (seulement au premier appel). Télécharge les PDFs, extrait le texte, crée les chunks et les embeddings. """ global retriever, is_initialized if is_initialized: return "Système déjà initialisé" try: print("=" * 50) print("INITIALISATION DU SYSTÈME RAG") print("=" * 50) # Etape 1: Téléchargement des PDFs chemin_dossier = "./RAG_IPCC" if not os.path.exists(chemin_dossier): os.makedirs(chemin_dossier) urls = { "6th_report": "https://www.ipcc.ch/report/ar6/syr/downloads/report/IPCC_AR6_SYR_FullVolume.pdf" } for name, url in urls.items(): file_path = os.path.join(chemin_dossier, f"{name}.pdf") if not os.path.exists(file_path): print(f"📥 Téléchargement de {name}...") response = requests.get(url) with open(file_path, 'wb') as file: file.write(response.content) print(f"✅ {name} téléchargé") else: print(f"✅ {name} existe déjà") # Etape 2: Extraction du texte print("\n📄 Extraction du texte des PDFs...") fichiers_pdf = [f for f in os.listdir(chemin_dossier) if f.endswith('.pdf')] extracted_text = [] for pdf in fichiers_pdf: chemin_pdf = os.path.join(chemin_dossier, pdf) with open(chemin_pdf, 'rb') as file: pdf_reader = PyPDF2.PdfReader(file) for page_num in range(len(pdf_reader.pages)): page = pdf_reader.pages[page_num] text = page.extract_text() extracted_text.append({"document": pdf, "page": page_num, "content": text}) print(f"✅ {len(extracted_text)} pages extraites") # Etape 3: Création des chunks print("\n✂️ Création des chunks de texte...") text_splitter = RecursiveCharacterTextSplitter( chunk_size=500, chunk_overlap=20, length_function=len, is_separator_regex=False, ) chunks = [] for page_content in extracted_text: chunks_list = text_splitter.split_text(page_content['content']) for chunk in chunks_list: text = clean_text(chunk) chunks.append({"document": page_content['document'], "page": page_content['page'], "content": text}) chunks = remove_mostly_digits_chunks(chunks) print(f"✅ {len(chunks)} chunks créés") # Etape 4: Initialisation du retriever et des embeddings print("\n🤖 Initialisation du TextRetriever...") retriever = TextRetriever() all_chunks = [chunk['content'] for chunk in chunks] # Vérifier si la base de données existe déjà db_path = "./chroma_db" if os.path.exists(db_path): print("📂 Chargement de la base de données existante...") retriever.load_embeddings(db_path) else: print("🔨 Création de la base de données d'embeddings...") retriever.store_embeddings(all_chunks, db_path) is_initialized = True print("\n" + "=" * 50) print("✅ SYSTÈME INITIALISÉ AVEC SUCCÈS") print("=" * 50) return "✅ Système initialisé avec succès !" except Exception as e: print(f"❌ Erreur lors de l'initialisation: {str(e)}") return f"❌ Erreur: {str(e)}" """# MODELE LLM ## Etape 1 : Generation d'une réponse avec HuggingFace Inference API """ # 🔒 Récupérer le token HF depuis les variables d'environnement (Repository Secrets) HF_API_KEY = os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACEHUB_API_TOKEN") or os.environ.get("HUGGING_FACE_HUB_TOKEN") # Configuration de l'API HuggingFace Inference # Utiliser un modèle plus petit et compatible avec le tier gratuit MODEL_NAME = "mistralai/Mistral-7B-Instruct-v0.2" API_URL = f"/static-proxy?url=https%3A%2F%2Fapi-inference.huggingface.co%2Fmodels%2F%7BMODEL_NAME%7D" headers = {"Authorization": f"Bearer {HF_API_KEY}"} if HF_API_KEY else {} # Initialiser le client d'inférence HuggingFace llm_client = InferenceClient(token=HF_API_KEY) if HF_API_KEY else InferenceClient() ## FONCTIONS # Basic context function. def get_context_from_query(query): chunks = retriever.get_best_chunks(query, 4) # Extraire le texte des chunks context_parts = [] for chunk in chunks: if hasattr(chunk, 'page_content'): context_parts.append(chunk.page_content) else: context_parts.append(str(chunk)) return "\n\n".join(context_parts) """## Etape 2 : Sauvegarde d'un historique limité de conversation""" class ConversationHistoryLoader: def __init__(self, k): self.k=k self.conversation_history = [] # Fonction qui permet créer un prompt (string) sur l'historique de conversation. def create_conversation_history_prompt(self): conversation = '' if self.conversation_history == None or len(self.conversation_history) == 0: return conversation else: for exchange in reversed(self.conversation_history): conversation = conversation + '\nHuman: '+exchange['Human']+'\nAI: '+exchange['AI'] return conversation # Fonction qui permet de mettre à jour l'historique de conversation # à partir de la dernière query et la dernière réponse du LLM. def update_conversation_history(self, query, response): exchange = {'Human': query, 'AI': response} self.conversation_history.insert(0, exchange) if len(self.conversation_history) > self.k: self.conversation_history.pop() # Fonction pour générer une réponse avec le contexte et l'historique def generate_response_with_context(instruction, context, chat_history=""): """ Génère une réponse en utilisant l'API HuggingFace Inference. """ # Construire le prompt complet prompt = f"""Below is an instruction that describes a task. Write a response that appropriately completes the request using the context provided and the previous conversation. Context: {context} {chat_history} Human: [INST] {instruction} [/INST] AI: """ try: # Construire les messages pour le chat system_message = f"""Tu es un assistant expert sur le changement climatique. Réponds aux questions en français en utilisant le contexte fourni des rapports IPCC. Contexte: {context}""" messages = [ {"role": "system", "content": system_message} ] # Ajouter l'historique si présent if chat_history: messages.append({"role": "assistant", "content": f"Historique:\n{chat_history}"}) # Ajouter la question messages.append({"role": "user", "content": instruction}) # Appeler l'API HuggingFace pour générer la réponse # Utilisation de Mistral avec chat_completion response = llm_client.chat_completion( messages=messages, model=MODEL_NAME, max_tokens=300, temperature=0.7, top_p=0.95 ) # Extraire le contenu de la réponse answer = response.choices[0].message.content # Nettoyer la réponse answer = answer.strip() answer = re.sub(r"\[context\..*?\]", "", answer) answer = re.sub(r"Al:\s*", "", answer) answer = re.sub(r"AI:\s*", "", answer) return answer except Exception as e: print(f"Erreur lors de la génération: {str(e)}") import traceback traceback.print_exc() error_msg = str(e) # Messages d'aide selon le type d'erreur if "rate limit" in error_msg.lower(): return f"⏱️ Rate limit atteint. Veuillez réessayer dans quelques instants.\n\nDétails: {error_msg}" elif "loading" in error_msg.lower() or "is currently loading" in error_msg.lower(): return f"⏳ Le modèle est en cours de chargement. Veuillez patienter 20-30 secondes et réessayer.\n\nDétails: {error_msg}" elif "authorization" in error_msg.lower() or "token" in error_msg.lower(): return f"🔒 Problème d'authentification.\n\nDétails: {error_msg}\n\n⚠️ Vérifiez que le token HF_TOKEN dans Settings a les permissions 'read' ou 'inference'." else: return f"❌ Erreur: {error_msg}\n\nConsultez les logs de votre Space pour plus de détails." # Créer l instance de gestion d historique ch = ConversationHistoryLoader(k=3) # Fonction principale pour répondre aux questions def get_response(query): global retriever, is_initialized try: # Initialiser le système au premier appel if not is_initialized: init_message = initialize_system() if "❌" in init_message: return init_message # Vérifier que le retriever est bien initialisé if retriever is None: return "❌ Le système n'est pas correctement initialisé. Veuillez réessayer." # Obtenir le contexte pertinent context = get_context_from_query(query) # Générer la réponse avec contexte et historique chat_history = ch.create_conversation_history_prompt() response = generate_response_with_context(query, context, chat_history) # Mettre à jour l historique ch.update_conversation_history(query, response) return response except Exception as e: import traceback error_details = traceback.format_exc() print(f"Erreur détaillée: {error_details}") return f"Erreur: {str(e)}" # Interface Gradio print("Creating Gradio interface...") iface = gr.Interface( fn=get_response, inputs=gr.Textbox(lines=2, placeholder="Posez votre question sur le climat..."), outputs=gr.Textbox(lines=5, label="Réponse"), title="🌍 RAG Chatbot - Questions Climatiques", description="""Posez vos questions sur le changement climatique basées sur les rapports IPCC. ⚠️ **Note**: Le système s'initialise automatiquement au premier appel (téléchargement du PDF + création des embeddings). La première requête peut prendre 2-3 minutes. Les requêtes suivantes seront rapides !""", examples=[ "Quels sont les principaux impacts du réchauffement climatique ?", "Comment les océans sont-ils affectés par le changement climatique ?", "Quelles sont les solutions pour réduire les émissions ?" ], cache_examples=False # Désactive le cache pour éviter l'initialisation au démarrage ) # Lancer l application if __name__ == "__main__": print("Launching Gradio app...") iface.launch(server_name="0.0.0.0", server_port=7860)