Hachicha Adam committed on
Commit 6810502 · 1 Parent(s): 364f1ef

Upload New File

Information_Retrieval_2_élèves__2_.ipynb ADDED
@@ -0,0 +1,217 @@
+ {
+ "nbformat": 4,
+ "nbformat_minor": 0,
+ "metadata": {
+ "colab": {
+ "provenance": []
+ },
+ "kernelspec": {
+ "name": "python3",
+ "display_name": "Python 3"
+ },
+ "language_info": {
+ "name": "python"
+ }
+ },
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "source": [
+ "# Information Retrieval 2/2"
+ ],
+ "metadata": {
+ "id": "IAuoFHU30qeb"
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "In this part, the goal is to put the best Information Retrieval methods to work in your RAG tool.\n",
+ "\n",
+ "To do so, you will need to:\n",
+ "\n",
+ "1. Choose the most appropriate models (building on the previous part)\n",
+ "2. For the embedding models, choose an appropriate vector database\n",
+ "3. Implement the class below.\n",
+ "\n",
+ "Tips:\n",
+ "* The *langchain* framework makes these tasks fairly straightforward.\n",
+ "* You can test the methods on the test passages from the previous part and on the documents from the first session.\n",
+ "* Try several approaches (different parameters, returning one passage or several concatenated passages...)"
+ ],
+ "metadata": {
+ "id": "vy9VyS8IURLS"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "# langchain 0.2 no longer bundles langchain-community, and the FAISS\n",
+ "# vector store also needs faiss-cpu and sentence-transformers.\n",
+ "!pip install langchain langchain-community faiss-cpu sentence-transformers\n",
+ "\n",
+ "import numpy as np\n",
+ "from langchain_community.embeddings import HuggingFaceEmbeddings\n",
+ "from langchain_community.vectorstores import FAISS\n",
+ "\n",
+ "class TextRetriever:\n",
+ "    def __init__(self, model_name=\"sentence-transformers/all-MiniLM-L6-v2\"):\n",
+ "        \"\"\"\n",
+ "        Initialise the embedding model and an empty vector store.\n",
+ "\n",
+ "        Args:\n",
+ "            model_name (str): name of the sentence-embedding model to use.\n",
+ "        \"\"\"\n",
+ "        self.embeddings = HuggingFaceEmbeddings(model_name=model_name)\n",
+ "        self.db = None  # FAISS vector store, built or loaded later\n",
+ "\n",
+ "    def store_embeddings(self, chunks, path=\"faiss_index\"):\n",
+ "        \"\"\"\n",
+ "        Embed the text chunks, index them in a FAISS vector store and\n",
+ "        persist the store to disk.\n",
+ "\n",
+ "        Args:\n",
+ "            chunks (list of str): text chunks to store.\n",
+ "            path (str): directory where the vector store is saved.\n",
+ "        \"\"\"\n",
+ "        self.db = FAISS.from_texts(chunks, self.embeddings)\n",
+ "        self.db.save_local(path)\n",
+ "\n",
+ "    def load_embeddings(self, path=\"faiss_index\"):\n",
+ "        \"\"\"\n",
+ "        Load a previously saved vector store.\n",
+ "\n",
+ "        Args:\n",
+ "            path (str): directory where the vector store was saved.\n",
+ "        \"\"\"\n",
+ "        self.db = FAISS.load_local(path, self.embeddings,\n",
+ "                                   allow_dangerous_deserialization=True)\n",
+ "\n",
+ "    def get_best_chunks(self, query, top_k=10):\n",
+ "        \"\"\"\n",
+ "        Return the indexed chunks most similar to the query.\n",
+ "\n",
+ "        Args:\n",
+ "            query (str): search query.\n",
+ "            top_k (int): number of chunks to return.\n",
+ "\n",
+ "        Returns:\n",
+ "            list of str: the top_k chunks, most relevant first.\n",
+ "        \"\"\"\n",
+ "        docs = self.db.similarity_search(query, k=top_k)\n",
+ "        return [doc.page_content for doc in docs]\n",
+ "\n",
+ "    def rerank_chunks(self, query, chunks):\n",
+ "        \"\"\"\n",
+ "        Re-sort a list of chunks by decreasing relevance to the query.\n",
+ "\n",
+ "        Args:\n",
+ "            query (str): search query.\n",
+ "            chunks (list of str): chunks to rerank.\n",
+ "\n",
+ "        Returns:\n",
+ "            list of str: the chunks sorted by relevance, best first.\n",
+ "        \"\"\"\n",
+ "        query_vec = np.array(self.embeddings.embed_query(query))\n",
+ "        chunk_vecs = np.array(self.embeddings.embed_documents(chunks))\n",
+ "        # Cosine similarity between the query and every chunk.\n",
+ "        scores = chunk_vecs @ query_vec / (\n",
+ "            np.linalg.norm(chunk_vecs, axis=1) * np.linalg.norm(query_vec))\n",
+ "        return [chunks[i] for i in np.argsort(scores)[::-1]]\n",
+ "\n",
+ "    def get_context(self, query):\n",
+ "        \"\"\"\n",
+ "        Return a single text containing the passages relevant to the query.\n",
+ "\n",
+ "        Args:\n",
+ "            query (str): search query.\n",
+ "\n",
+ "        Returns:\n",
+ "            str: concatenation of the most relevant chunks.\n",
+ "        \"\"\"\n",
+ "        return \" \".join(self.get_best_chunks(query))\n"
+ ],
+ "metadata": {
+ "id": "uVSyXkiUNnVC",
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "outputId": "fb29e054-96a2-423a-d601-8602073a4f6f"
+ },
+ "execution_count": 2,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "Collecting langchain\n",
+ " Downloading langchain-0.2.1-py3-none-any.whl (973 kB)\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m973.5/973.5 kB\u001b[0m \u001b[31m11.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[?25hRequirement already satisfied: PyYAML>=5.3 in /usr/local/lib/python3.10/dist-packages (from langchain) (6.0.1)\n",
+ "Requirement already satisfied: SQLAlchemy<3,>=1.4 in /usr/local/lib/python3.10/dist-packages (from langchain) (2.0.30)\n",
+ "Requirement already satisfied: aiohttp<4.0.0,>=3.8.3 in /usr/local/lib/python3.10/dist-packages (from langchain) (3.9.5)\n",
+ "Requirement already satisfied: async-timeout<5.0.0,>=4.0.0 in /usr/local/lib/python3.10/dist-packages (from langchain) (4.0.3)\n",
+ "Collecting langchain-core<0.3.0,>=0.2.0 (from langchain)\n",
+ " Downloading langchain_core-0.2.3-py3-none-any.whl (310 kB)\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m310.2/310.2 kB\u001b[0m \u001b[31m31.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[?25hCollecting langchain-text-splitters<0.3.0,>=0.2.0 (from langchain)\n",
+ " Downloading langchain_text_splitters-0.2.0-py3-none-any.whl (23 kB)\n",
+ "Collecting langsmith<0.2.0,>=0.1.17 (from langchain)\n",
+ " Downloading langsmith-0.1.69-py3-none-any.whl (124 kB)\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m124.4/124.4 kB\u001b[0m \u001b[31m13.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[?25hRequirement already satisfied: numpy<2,>=1 in /usr/local/lib/python3.10/dist-packages (from langchain) (1.25.2)\n",
+ "Requirement already satisfied: pydantic<3,>=1 in /usr/local/lib/python3.10/dist-packages (from langchain) (2.7.2)\n",
+ "Requirement already satisfied: requests<3,>=2 in /usr/local/lib/python3.10/dist-packages (from langchain) (2.31.0)\n",
+ "Requirement already satisfied: tenacity<9.0.0,>=8.1.0 in /usr/local/lib/python3.10/dist-packages (from langchain) (8.3.0)\n",
+ "Requirement already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.10/dist-packages (from aiohttp<4.0.0,>=3.8.3->langchain) (1.3.1)\n",
+ "Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp<4.0.0,>=3.8.3->langchain) (23.2.0)\n",
+ "Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.10/dist-packages (from aiohttp<4.0.0,>=3.8.3->langchain) (1.4.1)\n",
+ "Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.10/dist-packages (from aiohttp<4.0.0,>=3.8.3->langchain) (6.0.5)\n",
+ "Requirement already satisfied: yarl<2.0,>=1.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp<4.0.0,>=3.8.3->langchain) (1.9.4)\n",
+ "Collecting jsonpatch<2.0,>=1.33 (from langchain-core<0.3.0,>=0.2.0->langchain)\n",
+ " Downloading jsonpatch-1.33-py2.py3-none-any.whl (12 kB)\n",
+ "Collecting packaging<24.0,>=23.2 (from langchain-core<0.3.0,>=0.2.0->langchain)\n",
+ " Downloading packaging-23.2-py3-none-any.whl (53 kB)\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m53.0/53.0 kB\u001b[0m \u001b[31m5.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[?25hCollecting orjson<4.0.0,>=3.9.14 (from langsmith<0.2.0,>=0.1.17->langchain)\n",
+ " Downloading orjson-3.10.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (142 kB)\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m142.5/142.5 kB\u001b[0m \u001b[31m13.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[?25hRequirement already satisfied: annotated-types>=0.4.0 in /usr/local/lib/python3.10/dist-packages (from pydantic<3,>=1->langchain) (0.7.0)\n",
+ "Requirement already satisfied: pydantic-core==2.18.3 in /usr/local/lib/python3.10/dist-packages (from pydantic<3,>=1->langchain) (2.18.3)\n",
+ "Requirement already satisfied: typing-extensions>=4.6.1 in /usr/local/lib/python3.10/dist-packages (from pydantic<3,>=1->langchain) (4.12.0)\n",
+ "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests<3,>=2->langchain) (3.3.2)\n",
+ "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests<3,>=2->langchain) (3.7)\n",
+ "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests<3,>=2->langchain) (2.0.7)\n",
+ "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests<3,>=2->langchain) (2024.2.2)\n",
+ "Requirement already satisfied: greenlet!=0.4.17 in /usr/local/lib/python3.10/dist-packages (from SQLAlchemy<3,>=1.4->langchain) (3.0.3)\n",
+ "Collecting jsonpointer>=1.9 (from jsonpatch<2.0,>=1.33->langchain-core<0.3.0,>=0.2.0->langchain)\n",
+ " Downloading jsonpointer-2.4-py2.py3-none-any.whl (7.8 kB)\n",
+ "Installing collected packages: packaging, orjson, jsonpointer, jsonpatch, langsmith, langchain-core, langchain-text-splitters, langchain\n",
+ " Attempting uninstall: packaging\n",
+ " Found existing installation: packaging 24.0\n",
+ " Uninstalling packaging-24.0:\n",
+ " Successfully uninstalled packaging-24.0\n",
+ "Successfully installed jsonpatch-1.33 jsonpointer-2.4 langchain-0.2.1 langchain-core-0.2.3 langchain-text-splitters-0.2.0 langsmith-0.1.69 orjson-3.10.3 packaging-23.2\n"
+ ]
+ }
+ ]
+ },
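+ {
+ "cell_type": "markdown",
+ "source": [
+ "Below is a minimal usage sketch of `TextRetriever`, assuming the FAISS-backed implementation above. The chunks, the query, and the `faiss_index` path are made-up examples for illustration, not part of the exercise data."
+ ],
+ "metadata": {}
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "# Hypothetical sample chunks, for illustration only.\n",
+ "chunks = [\n",
+ "    \"FAISS performs efficient similarity search over dense vectors.\",\n",
+ "    \"Cosine similarity measures the angle between two embedding vectors.\",\n",
+ "    \"RAG pairs a retriever with a language model to ground its answers.\",\n",
+ "]\n",
+ "\n",
+ "# Build and persist the vector store, then query it.\n",
+ "retriever = TextRetriever()\n",
+ "retriever.store_embeddings(chunks, path=\"faiss_index\")\n",
+ "\n",
+ "query = \"What does cosine similarity measure?\"\n",
+ "best = retriever.get_best_chunks(query, top_k=2)\n",
+ "reranked = retriever.rerank_chunks(query, best)\n",
+ "print(retriever.get_context(query))"
+ ],
+ "metadata": {},
+ "execution_count": null,
+ "outputs": []
+ },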
+ {
+ "cell_type": "markdown",
+ "source": [
+ "## Questions\n",
+ "\n",
+ "* Describe and motivate the method you chose (models used, number of passages returned...)\n",
+ "* How would you adapt the solution to a larger database?\n",
+ "* What are the advantages of using a vector database to store the embeddings?\n"
+ ],
+ "metadata": {
+ "id": "8YhkcgBwWpC5"
+ }
+ }
+ ]
+ }