{ "cells": [ { "cell_type": "markdown", "source": [ "# Lesson 6: Toxicity" ], "metadata": { "id": "ND9o13T1WTVV" } }, { "cell_type": "markdown", "source": [ "In this notebook, we will explore how machine learning can be used to detect toxicity in text and potential threats contained in queries, more specifically with Natural Language Inference (NLI) models and Large Language Models (LLMs).\n", "\n", "NLI models are machine learning models that determine the relationship between two sentences (entailment, contradiction, or neutral), which helps capture context and nuance in language. LLMs, for their part, learn the structure and subtleties of a language, which helps them understand text at a deeper level.\n", "\n", "But first, let's have a little fun 😊" ], "metadata": { "id": "LVOMzP1dZHLg" } }, { "cell_type": "markdown", "source": [ "## How to prevent toxicity in AI\n", "\n", "Go to https://gandalf.lakera.ai/ and try to crack the password the LLM is keeping secret. Get as far as you can. What techniques can be used to hack an LLM?" ], "metadata": { "id": "5-hjrE22WTVZ" } }, { "cell_type": "markdown", "source": [ "## How to detect toxicity in AI?\n", "In this part, we will look at how to detect toxic language and potential threats, in particular with language models trained for classification.\n", "Before that, import the libraries and functions we will need in the rest of the notebook." 
], "metadata": { "id": "KHVRWHMsWTVZ" } }, { "cell_type": "code", "execution_count": 3, "source": [ "!pip install transformers\n", "!pip install torch\n", "!pip install detoxify\n", "!pip install datasets\n", "!pip install scikit-learn\n", "!pip install evaluate" ], "outputs": [], "metadata": { "id": "KVxXyw8LWTVa", "colab": { "base_uri": "https://localhost:8080/" }, "outputId": "3f7ddefa-4324-4e19-91b7-0ba30d5460bb" } }, { "cell_type": "code", "execution_count": 4, "source": [ "import matplotlib.pyplot as plt\n", "import pandas as pd\n", "import numpy as np\n", "import torch\n", "import evaluate\n", "\n", "from transformers import GPT2LMHeadModel, GPT2Tokenizer\n", "from detoxify import Detoxify\n" ], "outputs": [], "metadata": { "id": "B-J9XfgGWTVb" } }, { "cell_type": "markdown", "source": [ "Import the JIGSAW dataset available on the GitLab.\n", "\n", "The Jigsaw dataset was created by Jigsaw and Google's Counter-Abuse Technology team. It contains a large collection of online comments from various news sites. Humans annotated these comments for toxicity, meaning how unpleasant or offensive they might be to an average reader. The comments are labeled with different types of toxicity, such as insults, obscenity, hate speech, and threats.\n", "\n", "Throughout this lab, we will use it to evaluate the classification algorithms. The file used here contains 50 comments: the first 25 are not toxic and the last 25 are." ], "metadata": { "id": "_ticzaL1WTVd" } }, { "cell_type": "code", "execution_count": 5, "source": [ "dataset = pd.read_csv('./JIGSAW.csv')" ], "outputs": [], "metadata": { "id": "I6ivHkaWs-gB" } }, { "cell_type": "markdown", "source": [ "Now we will try to predict the toxicity of a sentence from the JIGSAW dataset using the Detoxify library.\n", "Detoxify is a Python library that provides pretrained models for detecting toxicity in text. These models were trained on several datasets of online comments and can predict several types of toxicity, including hate speech, obscenity, and insults." 
], "metadata": { "id": "516aAOqUuHj4" } }, { "cell_type": "code", "execution_count": 6, "source": [ "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", "def check_toxicity(text):\n", " results = Detoxify('original',device=device).predict(text)\n", " return results\n", "\n", "toxicity_results = check_toxicity(dataset['comment_text'][25])\n", "print('Toxicity results:\\n', toxicity_results)" ], "outputs": [ { "output_type": "stream", "name": "stderr", "text": [ "/usr/local/lib/python3.10/dist-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n", " warnings.warn(\n", "/usr/local/lib/python3.10/dist-packages/huggingface_hub/utils/_token.py:89: UserWarning: \n", "The secret `HF_TOKEN` does not exist in your Colab secrets.\n", "To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.\n", "You will be able to reuse this secret in all of your notebooks.\n", "Please note that authentication is recommended but still optional to access public models or datasets.\n", " warnings.warn(\n" ] }, { "output_type": "stream", "name": "stdout", "text": [ "Toxicity results:\n", " {'toxicity': 0.8165003, 'severe_toxicity': 0.0019058662, 'obscene': 0.06713744, 'threat': 0.003179473, 'insult': 0.2952709, 'identity_attack': 0.008143629}\n" ] } ], "metadata": { "id": "P_sYDfW4WTVd", "colab": { "base_uri": "https://localhost:8080/" }, "outputId": "91651a11-7a57-404e-c0cb-55c7c6bc017f" } }, { "cell_type": "markdown", "source": [ "Créez une fonction qui prend en entrée le dictionnaire issu de la fonction check_toxicity qui contient des labels et les scores de toxicité , et qui permet de visualiser ces scores sous forme de diagramme à barres horizontales.\n", "\n" ], 
"metadata": { "id": "qQHFbqjZwZ89" } }, { "cell_type": "code", "execution_count": 9, "source": [ "def visualize_toxicity(results):\n", " import matplotlib.pyplot as plt\n", " plt.clf()\n", " plt.barh(list(results.keys()),[results[key] for key in results])\n", " plt.show()\n", "visualize_toxicity(toxicity_results)" ], "outputs": [ { "output_type": "display_data", "data": { "text/plain": [ "
" ], "image/png": "\n" }, "metadata": {} } ], "metadata": { "id": "nCIqrkixWTVe", "outputId": "0a7a1644-ec43-452f-e204-e5ac9dbed514", "colab": { "base_uri": "https://localhost:8080/", "height": 430 } } }, { "cell_type": "markdown", "source": [ "Ces scores reflettent -t-il bien les labels des phrases présents dans JIGSAW? (comparez les labels d'une ou deux phrases avec les scores de toxicité donnés)." ], "metadata": { "id": "3Wj6xXyTxDC1" } }, { "cell_type": "markdown", "source": [ "Pour savoir si Detoxify est un bon algorithme de classification, nous allons nous intéresser à la courbe AUC PR(Area Under the Curve - Precision Recall). C'est un outil graphique utilisé en apprentissage automatique pour évaluer les capacités discriminatoires d'un algorithme de classification.\n", "\n", "La courbe AUC PR est construite en traçant la précision (Precision) en fonction du rappel (Recall) à différents seuils de classification.\n", "\n", "La précision est la proportion de vrais positifs parmi tous les exemples classés comme positifs, tandis que le rappel (également appelé sensibilité) est la proportion de vrais positifs parmi tous les exemples réellement positifs.\n", "\n", "L'aire sous la courbe AUC PR (AUC pour Area Under the Curve) donne une mesure unique de la performance du modèle qui résume la qualité de la précision et du rappel pour tous les seuils possibles. 
Un AUC de 1.0 indique une performance parfaite, tandis qu'un AUC de 0.5 indique une performance équivalente à une classification aléatoire.\n", "\n", "N'hesitez pas à aller faire un tour sur ce site pour bien comprendre de quoi il s'agit 😉:\n", "https://kobia.fr/classification-metrics-precision-recall/\n", "\n", "Comme vous l'aurez sans doute compris, nous allons tracer la courbe AUC-PR de Detoxify pour voir à quel point le model est performant.🤯\n" ], "metadata": { "id": "_SOXwaJ0xzzW" } }, { "cell_type": "markdown", "source": [ "Dans un premier temps, pour tous les commentaires de JIGSAW, calculez le score de toxicité et stocker ces scores dans un array.(on ne s'interessera qu'au label toxicity)\n", "Dans un second temps, stocker dans un autre array les labels \"toxic\" des commentaire de Jigsaw." ], "metadata": { "id": "UDgMMhdr2_ic" } }, { "cell_type": "code", "execution_count": 10, "source": [ "SCORE = []\n", "labels = []\n", "i = 0\n", "for comment in dataset['comment_text']:\n", " toxicity_result = check_toxicity(comment)['toxicity']\n", " SCORE.append(toxicity_result)\n", " if toxicity_result > 0.5 :\n", " labels.append(1)\n", " else :\n", " labels.append(0)\n", "print(labels)\n" ], "outputs": [ { "output_type": "stream", "name": "stderr", "text": [ "/usr/local/lib/python3.10/dist-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. 
If you want to force a new download, use `force_download=True`.\n", " warnings.warn(\n" ] }, { "output_type": "stream", "name": "stdout", "text": [ "[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1]\n" ] } ], "metadata": { "id": "MODrXeMu2-5e", "colab": { "base_uri": "https://localhost:8080/" }, "outputId": "3f023ddd-14a3-4d52-9a06-7de537f3ab8b" } }, { "cell_type": "markdown", "source": [ "En utilisant la librairie scikit-learn, tracer AUC PR de Detoxify. Que vaut l'air sous cette courbe ? Concluez.\n" ], "metadata": { "id": "dlv5bQjj4ADM" } }, { "cell_type": "code", "execution_count": 11, "source": [ "!pip install detoxify\n", "\n", "\n", "import numpy as np\n", "from detoxify import Detoxify\n", "from sklearn.metrics import precision_recall_curve, auc\n", "import matplotlib.pyplot as plt\n", "\n", "# Calculer la courbe PR et l'aire sous la courbe\n", "precision, recall, _ = precision_recall_curve([0 for _ in range(25)]+[1 for i in range(25)], SCORE)\n", "auc_pr = auc(recall, precision)\n", "\n", "# Tracer la courbe PR\n", "plt.figure()\n", "plt.plot(recall, precision, label=f'AUC PR = {auc_pr:.2f}')\n", "plt.xlabel('Recall')\n", "plt.ylabel('Precision')\n", "plt.title('Precision-Recall Curve')\n", "plt.legend(loc='best')\n", "plt.show()\n", "\n", "print(f'L\\'aire sous la courbe PR (AUC PR) est de : {auc_pr:.2f}')\n" ], "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Requirement already satisfied: detoxify in /usr/local/lib/python3.10/dist-packages (0.5.2)\n", "Requirement already satisfied: transformers in /usr/local/lib/python3.10/dist-packages (from detoxify) (4.41.1)\n", "Requirement already satisfied: torch>=1.7.0 in /usr/local/lib/python3.10/dist-packages (from detoxify) (2.3.0+cu121)\n", "Requirement already satisfied: sentencepiece>=0.1.94 in /usr/local/lib/python3.10/dist-packages (from detoxify) (0.1.99)\n", "Requirement already 
satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from torch>=1.7.0->detoxify) (3.14.0)\n", "Requirement already satisfied: typing-extensions>=4.8.0 in /usr/local/lib/python3.10/dist-packages (from torch>=1.7.0->detoxify) (4.12.0)\n", "Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch>=1.7.0->detoxify) (1.12.1)\n", "Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch>=1.7.0->detoxify) (3.3)\n", "Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch>=1.7.0->detoxify) (3.1.4)\n", "Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from torch>=1.7.0->detoxify) (2023.6.0)\n", "Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.1.105 in /usr/local/lib/python3.10/dist-packages (from torch>=1.7.0->detoxify) (12.1.105)\n", "Requirement already satisfied: nvidia-cuda-runtime-cu12==12.1.105 in /usr/local/lib/python3.10/dist-packages (from torch>=1.7.0->detoxify) (12.1.105)\n", "Requirement already satisfied: nvidia-cuda-cupti-cu12==12.1.105 in /usr/local/lib/python3.10/dist-packages (from torch>=1.7.0->detoxify) (12.1.105)\n", "Requirement already satisfied: nvidia-cudnn-cu12==8.9.2.26 in /usr/local/lib/python3.10/dist-packages (from torch>=1.7.0->detoxify) (8.9.2.26)\n", "Requirement already satisfied: nvidia-cublas-cu12==12.1.3.1 in /usr/local/lib/python3.10/dist-packages (from torch>=1.7.0->detoxify) (12.1.3.1)\n", "Requirement already satisfied: nvidia-cufft-cu12==11.0.2.54 in /usr/local/lib/python3.10/dist-packages (from torch>=1.7.0->detoxify) (11.0.2.54)\n", "Requirement already satisfied: nvidia-curand-cu12==10.3.2.106 in /usr/local/lib/python3.10/dist-packages (from torch>=1.7.0->detoxify) (10.3.2.106)\n", "Requirement already satisfied: nvidia-cusolver-cu12==11.4.5.107 in /usr/local/lib/python3.10/dist-packages (from torch>=1.7.0->detoxify) (11.4.5.107)\n", "Requirement already satisfied: 
nvidia-cusparse-cu12==12.1.0.106 in /usr/local/lib/python3.10/dist-packages (from torch>=1.7.0->detoxify) (12.1.0.106)\n", "Requirement already satisfied: nvidia-nccl-cu12==2.20.5 in /usr/local/lib/python3.10/dist-packages (from torch>=1.7.0->detoxify) (2.20.5)\n", "Requirement already satisfied: nvidia-nvtx-cu12==12.1.105 in /usr/local/lib/python3.10/dist-packages (from torch>=1.7.0->detoxify) (12.1.105)\n", "Requirement already satisfied: triton==2.3.0 in /usr/local/lib/python3.10/dist-packages (from torch>=1.7.0->detoxify) (2.3.0)\n", "Requirement already satisfied: nvidia-nvjitlink-cu12 in /usr/local/lib/python3.10/dist-packages (from nvidia-cusolver-cu12==11.4.5.107->torch>=1.7.0->detoxify) (12.5.40)\n", "Requirement already satisfied: huggingface-hub<1.0,>=0.23.0 in /usr/local/lib/python3.10/dist-packages (from transformers->detoxify) (0.23.2)\n", "Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.10/dist-packages (from transformers->detoxify) (1.25.2)\n", "Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.10/dist-packages (from transformers->detoxify) (24.0)\n", "Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.10/dist-packages (from transformers->detoxify) (6.0.1)\n", "Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.10/dist-packages (from transformers->detoxify) (2024.5.15)\n", "Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from transformers->detoxify) (2.32.3)\n", "Requirement already satisfied: tokenizers<0.20,>=0.19 in /usr/local/lib/python3.10/dist-packages (from transformers->detoxify) (0.19.1)\n", "Requirement already satisfied: safetensors>=0.4.1 in /usr/local/lib/python3.10/dist-packages (from transformers->detoxify) (0.4.3)\n", "Requirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.10/dist-packages (from transformers->detoxify) (4.66.4)\n", "Requirement already satisfied: MarkupSafe>=2.0 in 
/usr/local/lib/python3.10/dist-packages (from jinja2->torch>=1.7.0->detoxify) (2.1.5)\n", "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests->transformers->detoxify) (3.3.2)\n", "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests->transformers->detoxify) (3.7)\n", "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests->transformers->detoxify) (2.0.7)\n", "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests->transformers->detoxify) (2024.2.2)\n", "Requirement already satisfied: mpmath<1.4.0,>=1.1.0 in /usr/local/lib/python3.10/dist-packages (from sympy->torch>=1.7.0->detoxify) (1.3.0)\n" ] }, { "output_type": "display_data", "data": { "text/plain": [ "
" ], "image/png": "\n" }, "metadata": {} }, { "output_type": "stream", "name": "stdout", "text": [ "L'aire sous la courbe PR (AUC PR) est de : 0.99\n" ] } ], "metadata": { "id": "vcps_UPP7-5y", "colab": { "base_uri": "https://localhost:8080/", "height": 1000 }, "outputId": "df141f02-e090-41af-cde8-b9dc1a08d600" } }, { "cell_type": "markdown", "source": [ "Below, we load two other toxicity detection models: `s-nlp/roberta_toxicity_classifier` and the `toxicity` measurement from the `evaluate` library. For each model, plot the precision-recall curve, compare the curves and the areas under them, and conclude which method is best suited." ], "metadata": { "id": "gl4ifVPk9KEy" } }, { "cell_type": "code", "execution_count": 7, "source": [ "import evaluate\n", "from transformers import pipeline\n", "\n", "toxic_rob = pipeline(\n", "    \"text-classification\", model=\"s-nlp/roberta_toxicity_classifier\", device=device\n", ")\n", "toxic_r4 = evaluate.load(\"toxicity\")\n", "\n", "test = 'Fuck you'\n", "print(toxic_r4.compute(predictions=[test], aggregation=None), toxic_rob(test))" ], "outputs": [ { "output_type": "stream", "name": "stderr", "text": [ "Some weights of the model checkpoint at s-nlp/roberta_toxicity_classifier were not used when initializing RobertaForSequenceClassification: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']\n", "- This IS expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. 
initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", "- This IS NOT expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n", "WARNING:evaluate_modules.metrics.evaluate-measurement--toxicity.2390290fa0bf6d78480143547c6b08f3d4f8805b249df8c7a8e80d0ce8e3778b.toxicity:Using default facebook/roberta-hate-speech-dynabench-r4-target checkpoint\n" ] }, { "output_type": "stream", "name": "stdout", "text": [ "{'toxicity': [0.0027796211652457714]} [{'label': 'toxic', 'score': 0.9995612502098083}]\n" ] } ], "metadata": { "id": "v5Z6lsn89IVt", "colab": { "base_uri": "https://localhost:8080/" }, "outputId": "15d66dc6-fbbd-462a-91c2-07e3f3b00018" } }, { "cell_type": "markdown", "source": [ "## Can toxicity be detected with an LLM?\n", "\n", "In this part, we explore how Large Language Models (LLMs) can be applied to detecting toxicity in text. To do so, run the next cell: it downloads a quantised version of Mistral-7B-Instruct (GGUF format) and defines a function that generates the model's answers.\n" ], "metadata": { "id": "FUmJM-Lm-pnT" } }, { "cell_type": "code", "execution_count": 8, "source": [ "!pip install langchain\n", "!pip install langchain_community\n", "!pip install -U \"huggingface_hub[cli]\"\n", "!pip install llama-cpp-python\n", "!huggingface-cli download TheBloke/Mistral-7B-Instruct-v0.2-GGUF mistral-7b-instruct-v0.2.Q4_K_M.gguf --local-dir .\n", "\n", 
"import os\n", "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\"\n", "\n", "from langchain_community.llms import LlamaCpp\n", "from langchain_core.prompts import PromptTemplate\n", "from langchain_core.output_parsers import StrOutputParser\n", "\n", "def create_chain(model_path):\n", "    # Mistral-7B (GGUF) run through llama.cpp: greedy decoding (temperature=0) and a\n", "    # short token budget, since we only expect a \"yes\" or \"no\" answer\n", "    llm = LlamaCpp(\n", "        model_path=model_path, stop=[\"Question\"], max_tokens=10, temperature=0,\n", "        n_ctx=8000, n_batch=1024, n_gpu_layers=-1,\n", "    )\n", "    # {query}: the question asked to the model; {prompt}: the text to analyse\n", "    template = \"\"\"Query:{query}\n", "    {prompt}\n", "    Answer yes/no:\n", "    \"\"\"\n", "    prompt = PromptTemplate(\n", "        input_variables=[\"query\", \"prompt\"], template=template\n", "    )\n", "\n", "    # Chain: prompt -> LLM -> raw string\n", "    llm_chain = prompt | llm | StrOutputParser()\n", "    return llm_chain\n", "\n", "llm = create_chain('mistral-7b-instruct-v0.2.Q4_K_M.gguf')\n", "\n", "def gen(query, prompt):\n", "    # Returns the raw text generated by the model\n", "    return llm.invoke({\"query\": query, \"prompt\": prompt})\n", "\n" ], "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Requirement already satisfied: langchain in /usr/local/lib/python3.10/dist-packages (0.2.2)\n", "Requirement already satisfied: PyYAML>=5.3 in /usr/local/lib/python3.10/dist-packages (from langchain) (6.0.1)\n", "Requirement already satisfied: SQLAlchemy<3,>=1.4 in /usr/local/lib/python3.10/dist-packages (from langchain) (2.0.30)\n", "Requirement already satisfied: aiohttp<4.0.0,>=3.8.3 in /usr/local/lib/python3.10/dist-packages (from langchain) (3.9.5)\n", "Requirement already satisfied: async-timeout<5.0.0,>=4.0.0 in /usr/local/lib/python3.10/dist-packages (from langchain) (4.0.3)\n", "Requirement already satisfied: langchain-core<0.3.0,>=0.2.0 in /usr/local/lib/python3.10/dist-packages (from langchain) (0.2.4)\n", "Requirement already satisfied: langchain-text-splitters<0.3.0,>=0.2.0 in /usr/local/lib/python3.10/dist-packages (from langchain) (0.2.1)\n", "Requirement already satisfied: langsmith<0.2.0,>=0.1.17 in 
/usr/local/lib/python3.10/dist-packages (from langchain) (0.1.71)\n", "Requirement already satisfied: numpy<2,>=1 in /usr/local/lib/python3.10/dist-packages (from langchain) (1.25.2)\n", "Requirement already satisfied: pydantic<3,>=1 in /usr/local/lib/python3.10/dist-packages (from langchain) (2.7.2)\n", "Requirement already satisfied: requests<3,>=2 in /usr/local/lib/python3.10/dist-packages (from langchain) (2.32.3)\n", "Requirement already satisfied: tenacity<9.0.0,>=8.1.0 in /usr/local/lib/python3.10/dist-packages (from langchain) (8.3.0)\n", "Requirement already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.10/dist-packages (from aiohttp<4.0.0,>=3.8.3->langchain) (1.3.1)\n", "Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp<4.0.0,>=3.8.3->langchain) (23.2.0)\n", "Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.10/dist-packages (from aiohttp<4.0.0,>=3.8.3->langchain) (1.4.1)\n", "Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.10/dist-packages (from aiohttp<4.0.0,>=3.8.3->langchain) (6.0.5)\n", "Requirement already satisfied: yarl<2.0,>=1.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp<4.0.0,>=3.8.3->langchain) (1.9.4)\n", "Requirement already satisfied: jsonpatch<2.0,>=1.33 in /usr/local/lib/python3.10/dist-packages (from langchain-core<0.3.0,>=0.2.0->langchain) (1.33)\n", "Requirement already satisfied: packaging<24.0,>=23.2 in /usr/local/lib/python3.10/dist-packages (from langchain-core<0.3.0,>=0.2.0->langchain) (23.2)\n", "Requirement already satisfied: orjson<4.0.0,>=3.9.14 in /usr/local/lib/python3.10/dist-packages (from langsmith<0.2.0,>=0.1.17->langchain) (3.10.3)\n", "Requirement already satisfied: annotated-types>=0.4.0 in /usr/local/lib/python3.10/dist-packages (from pydantic<3,>=1->langchain) (0.7.0)\n", "Requirement already satisfied: pydantic-core==2.18.3 in /usr/local/lib/python3.10/dist-packages (from 
pydantic<3,>=1->langchain) (2.18.3)\n", "Requirement already satisfied: typing-extensions>=4.6.1 in /usr/local/lib/python3.10/dist-packages (from pydantic<3,>=1->langchain) (4.12.0)\n", "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests<3,>=2->langchain) (3.3.2)\n", "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests<3,>=2->langchain) (3.7)\n", "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests<3,>=2->langchain) (2.0.7)\n", "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests<3,>=2->langchain) (2024.2.2)\n", "Requirement already satisfied: greenlet!=0.4.17 in /usr/local/lib/python3.10/dist-packages (from SQLAlchemy<3,>=1.4->langchain) (3.0.3)\n", "Requirement already satisfied: jsonpointer>=1.9 in /usr/local/lib/python3.10/dist-packages (from jsonpatch<2.0,>=1.33->langchain-core<0.3.0,>=0.2.0->langchain) (2.4)\n", "Requirement already satisfied: langchain_community in /usr/local/lib/python3.10/dist-packages (0.2.2)\n", "Requirement already satisfied: PyYAML>=5.3 in /usr/local/lib/python3.10/dist-packages (from langchain_community) (6.0.1)\n", "Requirement already satisfied: SQLAlchemy<3,>=1.4 in /usr/local/lib/python3.10/dist-packages (from langchain_community) (2.0.30)\n", "Requirement already satisfied: aiohttp<4.0.0,>=3.8.3 in /usr/local/lib/python3.10/dist-packages (from langchain_community) (3.9.5)\n", "Requirement already satisfied: dataclasses-json<0.7,>=0.5.7 in /usr/local/lib/python3.10/dist-packages (from langchain_community) (0.6.6)\n", "Requirement already satisfied: langchain<0.3.0,>=0.2.0 in /usr/local/lib/python3.10/dist-packages (from langchain_community) (0.2.2)\n", "Requirement already satisfied: langchain-core<0.3.0,>=0.2.0 in /usr/local/lib/python3.10/dist-packages (from langchain_community) (0.2.4)\n", "Requirement 
already satisfied: langsmith<0.2.0,>=0.1.0 in /usr/local/lib/python3.10/dist-packages (from langchain_community) (0.1.71)\n", "Requirement already satisfied: numpy<2,>=1 in /usr/local/lib/python3.10/dist-packages (from langchain_community) (1.25.2)\n", "Requirement already satisfied: requests<3,>=2 in /usr/local/lib/python3.10/dist-packages (from langchain_community) (2.32.3)\n", "Requirement already satisfied: tenacity<9.0.0,>=8.1.0 in /usr/local/lib/python3.10/dist-packages (from langchain_community) (8.3.0)\n", "Requirement already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.10/dist-packages (from aiohttp<4.0.0,>=3.8.3->langchain_community) (1.3.1)\n", "Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp<4.0.0,>=3.8.3->langchain_community) (23.2.0)\n", "Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.10/dist-packages (from aiohttp<4.0.0,>=3.8.3->langchain_community) (1.4.1)\n", "Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.10/dist-packages (from aiohttp<4.0.0,>=3.8.3->langchain_community) (6.0.5)\n", "Requirement already satisfied: yarl<2.0,>=1.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp<4.0.0,>=3.8.3->langchain_community) (1.9.4)\n", "Requirement already satisfied: async-timeout<5.0,>=4.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp<4.0.0,>=3.8.3->langchain_community) (4.0.3)\n", "Requirement already satisfied: marshmallow<4.0.0,>=3.18.0 in /usr/local/lib/python3.10/dist-packages (from dataclasses-json<0.7,>=0.5.7->langchain_community) (3.21.2)\n", "Requirement already satisfied: typing-inspect<1,>=0.4.0 in /usr/local/lib/python3.10/dist-packages (from dataclasses-json<0.7,>=0.5.7->langchain_community) (0.9.0)\n", "Requirement already satisfied: langchain-text-splitters<0.3.0,>=0.2.0 in /usr/local/lib/python3.10/dist-packages (from langchain<0.3.0,>=0.2.0->langchain_community) (0.2.1)\n", "Requirement already 
satisfied: pydantic<3,>=1 in /usr/local/lib/python3.10/dist-packages (from langchain<0.3.0,>=0.2.0->langchain_community) (2.7.2)\n", "Requirement already satisfied: jsonpatch<2.0,>=1.33 in /usr/local/lib/python3.10/dist-packages (from langchain-core<0.3.0,>=0.2.0->langchain_community) (1.33)\n", "Requirement already satisfied: packaging<24.0,>=23.2 in /usr/local/lib/python3.10/dist-packages (from langchain-core<0.3.0,>=0.2.0->langchain_community) (23.2)\n", "Requirement already satisfied: orjson<4.0.0,>=3.9.14 in /usr/local/lib/python3.10/dist-packages (from langsmith<0.2.0,>=0.1.0->langchain_community) (3.10.3)\n", "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests<3,>=2->langchain_community) (3.3.2)\n", "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests<3,>=2->langchain_community) (3.7)\n", "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests<3,>=2->langchain_community) (2.0.7)\n", "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests<3,>=2->langchain_community) (2024.2.2)\n", "Requirement already satisfied: typing-extensions>=4.6.0 in /usr/local/lib/python3.10/dist-packages (from SQLAlchemy<3,>=1.4->langchain_community) (4.12.0)\n", "Requirement already satisfied: greenlet!=0.4.17 in /usr/local/lib/python3.10/dist-packages (from SQLAlchemy<3,>=1.4->langchain_community) (3.0.3)\n", "Requirement already satisfied: jsonpointer>=1.9 in /usr/local/lib/python3.10/dist-packages (from jsonpatch<2.0,>=1.33->langchain-core<0.3.0,>=0.2.0->langchain_community) (2.4)\n", "Requirement already satisfied: annotated-types>=0.4.0 in /usr/local/lib/python3.10/dist-packages (from pydantic<3,>=1->langchain<0.3.0,>=0.2.0->langchain_community) (0.7.0)\n", "Requirement already satisfied: pydantic-core==2.18.3 in /usr/local/lib/python3.10/dist-packages (from 
pydantic<3,>=1->langchain<0.3.0,>=0.2.0->langchain_community) (2.18.3)\n", "Requirement already satisfied: mypy-extensions>=0.3.0 in /usr/local/lib/python3.10/dist-packages (from typing-inspect<1,>=0.4.0->dataclasses-json<0.7,>=0.5.7->langchain_community) (1.0.0)\n", "Requirement already satisfied: huggingface-cli in /usr/local/lib/python3.10/dist-packages (0.1)\n", "Requirement already satisfied: llama-cpp-python in /usr/local/lib/python3.10/dist-packages (0.2.77)\n", "Requirement already satisfied: typing-extensions>=4.5.0 in /usr/local/lib/python3.10/dist-packages (from llama-cpp-python) (4.12.0)\n", "Requirement already satisfied: numpy>=1.20.0 in /usr/local/lib/python3.10/dist-packages (from llama-cpp-python) (1.25.2)\n", "Requirement already satisfied: diskcache>=5.6.1 in /usr/local/lib/python3.10/dist-packages (from llama-cpp-python) (5.6.3)\n", "Requirement already satisfied: jinja2>=2.11.3 in /usr/local/lib/python3.10/dist-packages (from llama-cpp-python) (3.1.4)\n", "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2>=2.11.3->llama-cpp-python) (2.1.5)\n", "/usr/local/lib/python3.10/dist-packages/huggingface_hub/commands/download.py:132: FutureWarning: Ignoring --local-dir-use-symlinks. Downloading to a local directory does not use symlinks anymore.\n", " warnings.warn(\n", "mistral-7b-instruct-v0.2.Q4_K_M.gguf\n" ] }, { "output_type": "stream", "name": "stderr", "text": [ "llama_model_loader: loaded meta data with 24 key-value pairs and 291 tensors from mistral-7b-instruct-v0.2.Q4_K_M.gguf (version GGUF V3 (latest))\n", "llama_model_loader: Dumping metadata keys/values. 
Note: KV overrides do not apply in this output.\n", "llama_model_loader: - kv 0: general.architecture str = llama\n", "llama_model_loader: - kv 1: general.name str = mistralai_mistral-7b-instruct-v0.2\n", "llama_model_loader: - kv 2: llama.context_length u32 = 32768\n", "llama_model_loader: - kv 3: llama.embedding_length u32 = 4096\n", "llama_model_loader: - kv 4: llama.block_count u32 = 32\n", "llama_model_loader: - kv 5: llama.feed_forward_length u32 = 14336\n", "llama_model_loader: - kv 6: llama.rope.dimension_count u32 = 128\n", "llama_model_loader: - kv 7: llama.attention.head_count u32 = 32\n", "llama_model_loader: - kv 8: llama.attention.head_count_kv u32 = 8\n", "llama_model_loader: - kv 9: llama.attention.layer_norm_rms_epsilon f32 = 0.000010\n", "llama_model_loader: - kv 10: llama.rope.freq_base f32 = 1000000.000000\n", "llama_model_loader: - kv 11: general.file_type u32 = 15\n", "llama_model_loader: - kv 12: tokenizer.ggml.model str = llama\n", "llama_model_loader: - kv 13: tokenizer.ggml.tokens arr[str,32000] = [\"\", \"\", \"\", \"<0x00>\", \"<...\n", "llama_model_loader: - kv 14: tokenizer.ggml.scores arr[f32,32000] = [0.000000, 0.000000, 0.000000, 0.0000...\n", "llama_model_loader: - kv 15: tokenizer.ggml.token_type arr[i32,32000] = [2, 3, 3, 6, 6, 6, 6, 6, 6, 6, 6, 6, ...\n", "llama_model_loader: - kv 16: tokenizer.ggml.bos_token_id u32 = 1\n", "llama_model_loader: - kv 17: tokenizer.ggml.eos_token_id u32 = 2\n", "llama_model_loader: - kv 18: tokenizer.ggml.unknown_token_id u32 = 0\n", "llama_model_loader: - kv 19: tokenizer.ggml.padding_token_id u32 = 0\n", "llama_model_loader: - kv 20: tokenizer.ggml.add_bos_token bool = true\n", "llama_model_loader: - kv 21: tokenizer.ggml.add_eos_token bool = false\n", "llama_model_loader: - kv 22: tokenizer.chat_template str = {{ bos_token }}{% for message in mess...\n", "llama_model_loader: - kv 23: general.quantization_version u32 = 2\n", "llama_model_loader: - type f32: 65 tensors\n", "llama_model_loader: - 
type q4_K: 193 tensors\n", "llama_model_loader: - type q6_K: 33 tensors\n", "llm_load_vocab: special tokens cache size = 259\n", "llm_load_vocab: token to piece cache size = 0.1637 MB\n", "llm_load_print_meta: format = GGUF V3 (latest)\n", "llm_load_print_meta: arch = llama\n", "llm_load_print_meta: vocab type = SPM\n", "llm_load_print_meta: n_vocab = 32000\n", "llm_load_print_meta: n_merges = 0\n", "llm_load_print_meta: n_ctx_train = 32768\n", "llm_load_print_meta: n_embd = 4096\n", "llm_load_print_meta: n_head = 32\n", "llm_load_print_meta: n_head_kv = 8\n", "llm_load_print_meta: n_layer = 32\n", "llm_load_print_meta: n_rot = 128\n", "llm_load_print_meta: n_embd_head_k = 128\n", "llm_load_print_meta: n_embd_head_v = 128\n", "llm_load_print_meta: n_gqa = 4\n", "llm_load_print_meta: n_embd_k_gqa = 1024\n", "llm_load_print_meta: n_embd_v_gqa = 1024\n", "llm_load_print_meta: f_norm_eps = 0.0e+00\n", "llm_load_print_meta: f_norm_rms_eps = 1.0e-05\n", "llm_load_print_meta: f_clamp_kqv = 0.0e+00\n", "llm_load_print_meta: f_max_alibi_bias = 0.0e+00\n", "llm_load_print_meta: f_logit_scale = 0.0e+00\n", "llm_load_print_meta: n_ff = 14336\n", "llm_load_print_meta: n_expert = 0\n", "llm_load_print_meta: n_expert_used = 0\n", "llm_load_print_meta: causal attn = 1\n", "llm_load_print_meta: pooling type = 0\n", "llm_load_print_meta: rope type = 0\n", "llm_load_print_meta: rope scaling = linear\n", "llm_load_print_meta: freq_base_train = 1000000.0\n", "llm_load_print_meta: freq_scale_train = 1\n", "llm_load_print_meta: n_yarn_orig_ctx = 32768\n", "llm_load_print_meta: rope_finetuned = unknown\n", "llm_load_print_meta: ssm_d_conv = 0\n", "llm_load_print_meta: ssm_d_inner = 0\n", "llm_load_print_meta: ssm_d_state = 0\n", "llm_load_print_meta: ssm_dt_rank = 0\n", "llm_load_print_meta: model type = 7B\n", "llm_load_print_meta: model ftype = Q4_K - Medium\n", "llm_load_print_meta: model params = 7.24 B\n", "llm_load_print_meta: model size = 4.07 GiB (4.83 BPW) \n", 
"llm_load_print_meta: general.name = mistralai_mistral-7b-instruct-v0.2\n", "llm_load_print_meta: BOS token = 1 ''\n", "llm_load_print_meta: EOS token = 2 ''\n", "llm_load_print_meta: UNK token = 0 ''\n", "llm_load_print_meta: PAD token = 0 ''\n", "llm_load_print_meta: LF token = 13 '<0x0A>'\n", "llm_load_tensors: ggml ctx size = 0.15 MiB\n", "llm_load_tensors: CPU buffer size = 4165.37 MiB\n", ".................................................................................................\n", "llama_new_context_with_model: n_ctx = 8000\n", "llama_new_context_with_model: n_batch = 1024\n", "llama_new_context_with_model: n_ubatch = 512\n", "llama_new_context_with_model: flash_attn = 0\n", "llama_new_context_with_model: freq_base = 10000.0\n", "llama_new_context_with_model: freq_scale = 1\n", "llama_kv_cache_init: CPU KV buffer size = 1000.00 MiB\n", "llama_new_context_with_model: KV self size = 1000.00 MiB, K (f16): 500.00 MiB, V (f16): 500.00 MiB\n", "llama_new_context_with_model: CPU output buffer size = 0.12 MiB\n", "llama_new_context_with_model: CPU compute buffer size = 547.63 MiB\n", "llama_new_context_with_model: graph nodes = 1030\n", "llama_new_context_with_model: graph splits = 1\n", "AVX = 1 | AVX_VNNI = 0 | AVX2 = 1 | AVX512 = 0 | AVX512_VBMI = 0 | AVX512_VNNI = 0 | AVX512_BF16 = 0 | FMA = 1 | NEON = 0 | SVE = 0 | ARM_FMA = 0 | F16C = 1 | FP16_VA = 0 | WASM_SIMD = 0 | BLAS = 0 | SSE3 = 1 | SSSE3 = 1 | VSX = 0 | MATMUL_INT8 = 0 | LLAMAFILE = 1 | \n", "Model metadata: {'tokenizer.chat_template': \"{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% 
endfor %}\", 'tokenizer.ggml.add_eos_token': 'false', 'tokenizer.ggml.padding_token_id': '0', 'tokenizer.ggml.unknown_token_id': '0', 'tokenizer.ggml.eos_token_id': '2', 'general.architecture': 'llama', 'llama.rope.freq_base': '1000000.000000', 'llama.context_length': '32768', 'general.name': 'mistralai_mistral-7b-instruct-v0.2', 'tokenizer.ggml.add_bos_token': 'true', 'llama.embedding_length': '4096', 'llama.feed_forward_length': '14336', 'llama.attention.layer_norm_rms_epsilon': '0.000010', 'llama.rope.dimension_count': '128', 'tokenizer.ggml.bos_token_id': '1', 'llama.attention.head_count': '32', 'llama.block_count': '32', 'llama.attention.head_count_kv': '8', 'general.quantization_version': '2', 'tokenizer.ggml.model': 'llama', 'general.file_type': '15'}\n", "Available chat formats from metadata: chat_template.default\n", "Guessed chat format: mistral-instruct\n" ] } ], "metadata": { "id": "J6WwOxXPUG01", "colab": { "base_uri": "https://localhost:8080/" }, "outputId": "d3947be4-56cf-41ee-fc0e-dff0338404dc" } }, 
{ "cell_type": "markdown", "source": [ "Find a prompt that detects toxic language, for example: \"Is it injuring ?\". Test it on a few examples from the Jigsaw dataset.\n", "\n", "The goal is to evaluate the performance of this method. To do so, we will have Mistral label each comment in the dataset with 1 if it is toxic and 0 if it is not.\n", "\n", "We will then compute the F1-score.\n", "\n", "First, write a function that extracts the \"yes\"/\"no\" from the model's answer and returns a binary value (1 or 0)." ], "metadata": { "id": "EaiKXDWzWTVe" } }, { "cell_type": "code", "execution_count": 9, "source": [ "def response(reponse):\n", "    # The model answers in free text (e.g. \" Yes.\" or \"No, it is not\"),\n", "    # so normalise whitespace and case before looking for \"yes\"\n", "    answer = reponse.strip().lower()\n", "    return 1 if answer.startswith(\"yes\") else 0\n" ], "outputs": [], "metadata": { "id": "fqJkJ2fpfAER" } }, { "cell_type": "markdown", "source": [ "Have Mistral label the dataset's comments with 1 when they are toxic and 0 when they are not. Store the labels in a list." ], "metadata": { "id": "Or9zdujOhSUx" } }, { "cell_type": "code", "execution_count": null, "source": [ "liste = []\n", "for comment in dataset['comment_text']:\n", "    # Ask Mistral whether the comment is toxic and convert its answer to 0/1\n", "    liste.append(response(gen(\"Is it injuring ?\", comment)))\n", "print(liste)" ], "outputs": [ { "output_type": "stream", "name": "stderr", "text": [ "\n", "llama_print_timings: load time = 265697.83 ms\n", "llama_print_timings: sample time = 6.23 ms / 10 runs ( 0.62 ms per token, 1606.17 tokens per second)\n", "llama_print_timings: prompt eval time = 265688.59 ms / 381 tokens ( 697.35 ms per token, 1.43 tokens per second)\n", "llama_print_timings: eval time = 10709.36 ms / 9 runs ( 1189.93 ms per token, 0.84 tokens per second)\n", "llama_print_timings: total time = 276436.20 ms / 390 tokens\n", "Llama.generate: prefix-match hit\n", "\n", "llama_print_timings: load time = 265697.83 ms\n", "llama_print_timings: sample time = 7.00 ms / 10 runs ( 0.70 ms per token, 1427.55 tokens per second)\n", "llama_print_timings: prompt eval time = 24281.34 ms / 33 tokens ( 735.80 ms per token, 1.36 tokens per second)\n", "llama_print_timings: eval time = 10438.11 ms / 9 runs ( 1159.79 ms per token, 0.86 tokens per second)\n", "llama_print_timings: total time = 34752.53 ms / 42 
tokens\n", "Llama.generate: prefix-match hit\n", "\n", "llama_print_timings: load time = 265697.83 ms\n", "llama_print_timings: sample time = 6.92 ms / 10 runs ( 0.69 ms per token, 1445.50 tokens per second)\n", "llama_print_timings: prompt eval time = 8957.96 ms / 21 tokens ( 426.57 ms per token, 2.34 tokens per second)\n", "llama_print_timings: eval time = 7521.59 ms / 9 runs ( 835.73 ms per token, 1.20 tokens per second)\n", "llama_print_timings: total time = 16500.05 ms / 30 tokens\n", "Llama.generate: prefix-match hit\n", "\n", "llama_print_timings: load time = 265697.83 ms\n", "llama_print_timings: sample time = 6.08 ms / 10 runs ( 0.61 ms per token, 1643.39 tokens per second)\n", "llama_print_timings: prompt eval time = 16266.95 ms / 36 tokens ( 451.86 ms per token, 2.21 tokens per second)\n", "llama_print_timings: eval time = 7736.29 ms / 9 runs ( 859.59 ms per token, 1.16 tokens per second)\n", "llama_print_timings: total time = 24020.95 ms / 45 tokens\n", "Llama.generate: prefix-match hit\n", "\n", "llama_print_timings: load time = 265697.83 ms\n", "llama_print_timings: sample time = 6.68 ms / 10 runs ( 0.67 ms per token, 1497.01 tokens per second)\n", "llama_print_timings: prompt eval time = 34243.51 ms / 77 tokens ( 444.72 ms per token, 2.25 tokens per second)\n", "llama_print_timings: eval time = 7637.30 ms / 9 runs ( 848.59 ms per token, 1.18 tokens per second)\n", "llama_print_timings: total time = 41899.72 ms / 86 tokens\n", "Llama.generate: prefix-match hit\n", "\n", "llama_print_timings: load time = 265697.83 ms\n", "llama_print_timings: sample time = 7.36 ms / 10 runs ( 0.74 ms per token, 1358.70 tokens per second)\n", "llama_print_timings: prompt eval time = 32143.66 ms / 66 tokens ( 487.03 ms per token, 2.05 tokens per second)\n", "llama_print_timings: eval time = 9700.41 ms / 9 runs ( 1077.82 ms per token, 0.93 tokens per second)\n", "llama_print_timings: total time = 41864.61 ms / 75 tokens\n", "Llama.generate: prefix-match hit\n", "\n", 
"llama_print_timings: load time = 265697.83 ms\n", "llama_print_timings: sample time = 6.86 ms / 10 runs ( 0.69 ms per token, 1457.94 tokens per second)\n", "llama_print_timings: prompt eval time = 32399.87 ms / 69 tokens ( 469.56 ms per token, 2.13 tokens per second)\n", "llama_print_timings: eval time = 6854.11 ms / 9 runs ( 761.57 ms per token, 1.31 tokens per second)\n", "llama_print_timings: total time = 39270.58 ms / 78 tokens\n", "Llama.generate: prefix-match hit\n", "\n", "llama_print_timings: load time = 265697.83 ms\n", "llama_print_timings: sample time = 6.68 ms / 10 runs ( 0.67 ms per token, 1496.33 tokens per second)\n", "llama_print_timings: prompt eval time = 18102.71 ms / 30 tokens ( 603.42 ms per token, 1.66 tokens per second)\n", "llama_print_timings: eval time = 11512.61 ms / 9 runs ( 1279.18 ms per token, 0.78 tokens per second)\n", "llama_print_timings: total time = 29641.54 ms / 39 tokens\n", "Llama.generate: prefix-match hit\n" ] } ], "metadata": { "id": "kdLS9fmigCKR", "colab": { "base_uri": "https://localhost:8080/" }, "outputId": "8f9bb0b9-20d1-4d65-b703-c44a9b93c6e1" } }, { "cell_type": "markdown", "source": [ "Compute the F1-score and draw conclusions about how effective the method is." ], "metadata": { "id": "UOl2Q138h1RG" } }, { "cell_type": "code", "execution_count": null, "source": [ "from sklearn.metrics import f1_score\n", "\n", "# Ground truth: the first 25 comments are non-toxic (0), the last 25 are toxic (1)\n", "labels_test = [0] * 25 + [1] * 25\n", "predicted_labels = liste\n", "f1 = f1_score(labels_test, predicted_labels)\n", "print(f1)\n" ], "outputs": [], "metadata": { "id": "yrJ60HN0hqcy" } }, { "cell_type": "markdown", "source": [ "To wrap up this lab, use the method of your choice to add a toxicity filter to the RAG system you implemented in previous sessions. The filter must block toxic language coming from the user." 
], "metadata": { "id": "6kmocZvpiMNg" } }, { "cell_type": "code", "execution_count": null, "source": [ "from detoxify import Detoxify\n", "from transformers import AutoTokenizer, AutoModelForSeq2SeqLM\n", "\n", "# Load the Detoxify model\n", "detoxify_model = Detoxify('original')\n", "\n", "# Load the RAG model (make sure to load your own RAG model here)\n", "rag_model_name = \"facebook/rag-token-base\"\n", "tokenizer = AutoTokenizer.from_pretrained(rag_model_name)\n", "model = AutoModelForSeq2SeqLM.from_pretrained(rag_model_name)\n", "\n", "TOXICITY_THRESHOLD = 0.5  # Toxicity threshold (can be tuned)\n", "BLOCKED_MESSAGE = \"Your message was flagged as toxic and will not be forwarded.\"\n", "\n", "# Toxicity detection\n", "def is_toxic(text):\n", "    predictions = detoxify_model.predict(text)\n", "    return predictions['toxicity'] > TOXICITY_THRESHOLD\n", "\n", "# RAG answer guarded by the toxicity filter\n", "def rag_response(input_text):\n", "    # Block toxic input before it reaches the model. Calling is_toxic() directly\n", "    # avoids matching on the wording of the rejection message.\n", "    if is_toxic(input_text):\n", "        return BLOCKED_MESSAGE\n", "\n", "    # Prepare the input for the RAG model\n", "    inputs = tokenizer([input_text], return_tensors='pt')\n", "\n", "    # Generate the answer with RAG\n", "    outputs = model.generate(**inputs)\n", "    return tokenizer.decode(outputs[0], skip_special_tokens=True)\n", "\n", "# Example interaction (system_reply avoids overwriting the response() helper defined above)\n", "user_message = input(\"Enter your message: \")\n", "system_reply = rag_response(user_message)\n", "print(\"System reply:\", system_reply)\n" ], "outputs": [], "metadata": { "id": "0xgyS61riIHj" } }, { "cell_type": "markdown", "source": [ "## Bonus\n", "\n", "Using the method of your choice, determine the optimal threshold for filtering out as much toxic language as possible with RoBERTa." 
], "metadata": { "id": "qoWo-wrpirwm" } }, { "cell_type": "code", "execution_count": null, "source": [], "outputs": [], "metadata": { "id": "e8pXYU2EjSYH" } } ], "metadata": { "kernelspec": { "name": "python3", "display_name": "Python 3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.12" }, "colab": { "provenance": [] }, "interpreter": { "hash": "1a85d0143f93c4f67f394d062034989b2c9fca60c1f84a050bdb6946630bae93" } }, "nbformat": 4, "nbformat_minor": 0 }