{ "nbformat": 4, "nbformat_minor": 0, "metadata": { "colab": { "provenance": [] }, "kernelspec": { "name": "python3", "display_name": "Python 3" }, "language_info": { "name": "python" } }, "cells": [ { "cell_type": "markdown", "source": [ "# Information Retrieval 2/2" ], "metadata": { "id": "IAuoFHU30qeb" } }, { "cell_type": "markdown", "source": [ "Dans cette partie, l'objectif est de mettre en application les meilleurs méthodes d'Information Retrieval pour votre outil de RAG.\n", "\n", "\n", "\n", "Pour cela vous devrez :\n", "\n", "\n", "1. Déterminer les modèles les plus appropriés (s'appuyer sur la partie précédente)\n", "2. Pour les modèles d'embeddings, déterminer une base de donnée vectorielle appropriée\n", "3. Implémenter la classe si dessous.\n", "\n", "Conseils :\n", "* Le framework *langchain* permet de réaliser ces tâches assez simplement.\n", "* Vous pouvez tester les méthodes sur les passages test de la partie précédente et sur les documents de la première séance.\n", "* Testez plusieurs méthodes (différents paramètres, renvoie d'un ou plusieurs passages concaténés...)" ], "metadata": { "id": "vy9VyS8IURLS" } }, { "cell_type": "code", "source": [ "!pip install langchain\n", "import langchain as lc\n", "\n", "class TextRetriever:\n", " def __init__(self, model_name=\"facebook/bart-base\", embedding_path=\"paraphrase-distilbert-base\"):\n", " \"\"\"\n", " Initialise les modèles et la base de données vectorielle.\n", "\n", " Args:\n", " model_name (str): Nom du modèle de langage à utiliser (ex: \"facebook/bart-base\").\n", " embedding_path (str): Chemin du répertoire où la base de données vectorielle est stockée.\n", " \"\"\"\n", " self.model = lc.load(model_name)\n", " self.embeddings = lc.load_embeddings(embedding_path)\n", "\n", " def store_embeddings(self, chunks, path=\"embeddings.pkl\"):\n", " \"\"\"\n", " Stocke les embeddings des chunks de texte dans une base de données vectorielle.\n", "\n", " Args:\n", " chunks (list of str): Liste de chunks de texte à stocker.\n", " path (str): Chemin du répertoire où la base de données sera stockée.\n", " \"\"\"\n", " embeddings = self.embeddings.encode(chunks)\n", " lc.save_embeddings(embeddings, path)\n", "\n", " def load_embeddings(self, path=\"embeddings.pkl\"):\n", " \"\"\"\n", " Charge les embeddings depuis une base de données vectorielle.\n", "\n", " Args:\n", " path (str): Chemin du répertoire de la base de données.\n", " \"\"\"\n", " self.embeddings = lc.load_embeddings(path)\n", "\n", " def get_best_chunks(self, query, top_k=10):\n", " \"\"\"\n", " Recherche les meilleurs chunks correspondant à une requête.\n", "\n", " Args:\n", " query (str): Requête de recherche.\n", " top_k (int): Nombre de meilleurs chunks à retourner.\n", "\n", " Returns:\n", " list: Liste des meilleurs chunks correspondant à la requête.\n", " \"\"\"\n", " encoded_query = self.embeddings.encode([query])\n", " scores = self.embeddings.cosine_similarity(encoded_query, self.embeddings.vectors)\n", " top_k_indices = scores.argsort(axis=1)[:,-top_k:]\n", " top_chunks = [self.embeddings.docs[i] for i in top_k_indices.flatten()]\n", " return top_chunks\n", "\n", " def rerank_chunks(self, query, chunks):\n", " \"\"\"\n", " Retrie les chunks par pertinence\n", "\n", " Args:\n", " query (str): Requête de recherche.\n", " chunks (list of str): Liste des chunks à reclasser.\n", "\n", " Returns:\n", " list: liste triée des chunks par pertinence\n", " \"\"\"\n", " encoded_query = self.embeddings.encode([query])\n", " scores = self.embeddings.cosine_similarity(encoded_query, self.embeddings.encode(chunks))\n", " sorted_chunks = [chunks[i] for i in scores.argsort(axis=0)[:-1]]\n", " return sorted_chunks\n", "\n", " def get_context(self, query):\n", " \"\"\"\n", " Retourne un texte contenant les informations pertinentes pour la requête.\n", "\n", " Args:\n", " query (str): Requête de recherche.\n", "\n", " Returns:\n", " str: texte pertinent pour répondre\n", " \"\"\"\n", " top_chunks = self.get_best_chunks(query)\n", " return \" \".join(top_chunks)\n", "\n" ], "metadata": { "id": "uVSyXkiUNnVC", "colab": { "base_uri": "https://localhost:8080/" }, "outputId": "fb29e054-96a2-423a-d601-8602073a4f6f" }, "execution_count": 2, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Collecting langchain\n", " Downloading langchain-0.2.1-py3-none-any.whl (973 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m973.5/973.5 kB\u001b[0m \u001b[31m11.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hRequirement already satisfied: PyYAML>=5.3 in /usr/local/lib/python3.10/dist-packages (from langchain) (6.0.1)\n", "Requirement already satisfied: SQLAlchemy<3,>=1.4 in /usr/local/lib/python3.10/dist-packages (from langchain) (2.0.30)\n", "Requirement already satisfied: aiohttp<4.0.0,>=3.8.3 in /usr/local/lib/python3.10/dist-packages (from langchain) (3.9.5)\n", "Requirement already satisfied: async-timeout<5.0.0,>=4.0.0 in /usr/local/lib/python3.10/dist-packages (from langchain) (4.0.3)\n", "Collecting langchain-core<0.3.0,>=0.2.0 (from langchain)\n", " Downloading langchain_core-0.2.3-py3-none-any.whl (310 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m310.2/310.2 kB\u001b[0m \u001b[31m31.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hCollecting langchain-text-splitters<0.3.0,>=0.2.0 (from langchain)\n", " Downloading langchain_text_splitters-0.2.0-py3-none-any.whl (23 kB)\n", "Collecting langsmith<0.2.0,>=0.1.17 (from langchain)\n", " Downloading langsmith-0.1.69-py3-none-any.whl (124 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m124.4/124.4 kB\u001b[0m \u001b[31m13.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hRequirement already satisfied: numpy<2,>=1 in /usr/local/lib/python3.10/dist-packages (from langchain) (1.25.2)\n", "Requirement already satisfied: pydantic<3,>=1 in /usr/local/lib/python3.10/dist-packages (from langchain) (2.7.2)\n", "Requirement already satisfied: requests<3,>=2 in /usr/local/lib/python3.10/dist-packages (from langchain) (2.31.0)\n", "Requirement already satisfied: tenacity<9.0.0,>=8.1.0 in /usr/local/lib/python3.10/dist-packages (from langchain) (8.3.0)\n", "Requirement already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.10/dist-packages (from aiohttp<4.0.0,>=3.8.3->langchain) (1.3.1)\n", "Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp<4.0.0,>=3.8.3->langchain) (23.2.0)\n", "Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.10/dist-packages (from aiohttp<4.0.0,>=3.8.3->langchain) (1.4.1)\n", "Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.10/dist-packages (from aiohttp<4.0.0,>=3.8.3->langchain) (6.0.5)\n", "Requirement already satisfied: yarl<2.0,>=1.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp<4.0.0,>=3.8.3->langchain) (1.9.4)\n", "Collecting jsonpatch<2.0,>=1.33 (from langchain-core<0.3.0,>=0.2.0->langchain)\n", " Downloading jsonpatch-1.33-py2.py3-none-any.whl (12 kB)\n", "Collecting packaging<24.0,>=23.2 (from langchain-core<0.3.0,>=0.2.0->langchain)\n", " Downloading packaging-23.2-py3-none-any.whl (53 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m53.0/53.0 kB\u001b[0m \u001b[31m5.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hCollecting orjson<4.0.0,>=3.9.14 (from langsmith<0.2.0,>=0.1.17->langchain)\n", " Downloading orjson-3.10.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (142 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m142.5/142.5 kB\u001b[0m \u001b[31m13.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hRequirement already satisfied: annotated-types>=0.4.0 in /usr/local/lib/python3.10/dist-packages (from pydantic<3,>=1->langchain) (0.7.0)\n", "Requirement already satisfied: pydantic-core==2.18.3 in /usr/local/lib/python3.10/dist-packages (from pydantic<3,>=1->langchain) (2.18.3)\n", "Requirement already satisfied: typing-extensions>=4.6.1 in /usr/local/lib/python3.10/dist-packages (from pydantic<3,>=1->langchain) (4.12.0)\n", "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests<3,>=2->langchain) (3.3.2)\n", "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests<3,>=2->langchain) (3.7)\n", "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests<3,>=2->langchain) (2.0.7)\n", "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests<3,>=2->langchain) (2024.2.2)\n", "Requirement already satisfied: greenlet!=0.4.17 in /usr/local/lib/python3.10/dist-packages (from SQLAlchemy<3,>=1.4->langchain) (3.0.3)\n", "Collecting jsonpointer>=1.9 (from jsonpatch<2.0,>=1.33->langchain-core<0.3.0,>=0.2.0->langchain)\n", " Downloading jsonpointer-2.4-py2.py3-none-any.whl (7.8 kB)\n", "Installing collected packages: packaging, orjson, jsonpointer, jsonpatch, langsmith, langchain-core, langchain-text-splitters, langchain\n", " Attempting uninstall: packaging\n", " Found existing installation: packaging 24.0\n", " Uninstalling packaging-24.0:\n", " Successfully uninstalled packaging-24.0\n", "Successfully installed jsonpatch-1.33 jsonpointer-2.4 langchain-0.2.1 langchain-core-0.2.3 langchain-text-splitters-0.2.0 langsmith-0.1.69 orjson-3.10.3 packaging-23.2\n" ] } ] }, { "cell_type": "markdown", "source": [ "## Questions\n", "\n", "\n", "* Décrivez et motivez la méthode choisie (modèles utilisés, nombre de passages renvoyés...)\n", "* Comment adapter la solution en cas de base de données plus grande?\n", "* Quels sont les avantages à utiliser une base de données vectorielle pour stocker les embeddings?\n", "\n", "\n", "\n" ], "metadata": { "id": "8YhkcgBwWpC5" } } ] }