Text Generation
Transformers
Safetensors
English
gemma3_text
conversational
text-generation-inference
pankajmathur commited on
Commit
bb0c0fa
·
verified ·
1 Parent(s): 41949b5

Upload Mimma_3_1b_chat.ipynb

Browse files
Files changed (1) hide show
  1. Mimma_3_1b_chat.ipynb +197 -0
Mimma_3_1b_chat.ipynb ADDED
@@ -0,0 +1,197 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "nbformat": 4,
3
+ "nbformat_minor": 0,
4
+ "metadata": {
5
+ "colab": {
6
+ "provenance": [],
7
+ "gpuType": "T4"
8
+ },
9
+ "kernelspec": {
10
+ "name": "python3",
11
+ "display_name": "Python 3"
12
+ },
13
+ "language_info": {
14
+ "name": "python"
15
+ },
16
+ "accelerator": "GPU"
17
+ },
18
+ "cells": [
19
+ {
20
+ "cell_type": "code",
21
+ "execution_count": 1,
22
+ "metadata": {
23
+ "id": "TFu_ibC1eYrz"
24
+ },
25
+ "outputs": [],
26
+ "source": [
27
+ "!pip install torch transformers -q"
28
+ ]
29
+ },
30
+ {
31
+ "cell_type": "code",
32
+ "source": [
33
+ "import torch\n",
34
+ "from transformers import pipeline\n",
35
+ "from IPython.display import clear_output\n",
36
+ "from google.colab import output"
37
+ ],
38
+ "metadata": {
39
+ "id": "Zs7QNs0Tet6r"
40
+ },
41
+ "execution_count": null,
42
+ "outputs": []
43
+ },
44
+ {
45
+ "cell_type": "code",
46
+ "source": [
47
class ChatBot:
    """Singleton wrapper around a transformers text-generation pipeline.

    Keeps one shared instance (and one loaded model) per process; the
    instance is rebuilt only when a different model slug is requested.
    Also holds the running chat history and the generation parameters.
    """

    _instance = None        # the shared ChatBot instance
    _current_model = None   # slug of the model currently loaded

    def __init__(self, model_slug=None):
        # Only (re)load when a slug is given and differs from the loaded one.
        if model_slug and model_slug != ChatBot._current_model:
            self.load_model(model_slug)
            ChatBot._current_model = model_slug

        # Conversation history and default generation parameters.
        self.messages = []
        self.max_tokens = 512
        self.temperature = 0.01
        self.top_k = 64
        self.top_p = 0.95
        self.min_p = 0.0
        self.repetition_penalty = 1.3

    @classmethod
    def get_instance(cls, model_slug=None):
        """Return the shared instance, rebuilding it if the model changed."""
        if not cls._instance or (model_slug and model_slug != cls._current_model):
            cls._instance = cls(model_slug)
        return cls._instance

    def load_model(self, model_slug):
        """Load a text-generation pipeline for *model_slug* onto the GPU."""
        print(f"Loading model {model_slug}...")
        self.pipeline = pipeline(
            "text-generation",
            model=model_slug,
            device_map="cuda",
            torch_dtype=torch.bfloat16
        )
        clear_output()
        print("Model loaded successfully!")

    def reset_conversation(self, system_message):
        """Reset the conversation with a new system message."""
        self.messages = [{"role": "system", "content": system_message}]

    def get_response(self, user_input):
        """Append *user_input*, generate a reply, and return its text.

        Raises:
            RuntimeError: if no model was ever loaded (e.g. the instance was
                created without a model_slug) — previously this surfaced as
                an opaque AttributeError on ``self.pipeline``.
        """
        if not hasattr(self, "pipeline"):
            raise RuntimeError(
                "No model loaded; create the ChatBot with a model_slug first."
            )
        self.messages.append({"role": "user", "content": user_input})
        outputs = self.pipeline(
            self.messages,
            max_new_tokens=self.max_tokens,
            do_sample=True,
            temperature=self.temperature,
            top_k=self.top_k,
            top_p=self.top_p,
            min_p=self.min_p,
            repetition_penalty=self.repetition_penalty,
        )
        # Pipeline returns the full chat transcript; the last message is the
        # newly generated assistant turn.
        response = outputs[0]["generated_text"][-1]
        content = response.get('content', 'No content available')
        self.messages.append({"role": "assistant", "content": content})
        return content

    def update_params(self, max_tokens=None, temperature=None, top_k=None, top_p=None, min_p=None, repetition_penalty=None):
        """Update generation parameters; a ``None`` leaves the value unchanged."""
        if max_tokens is not None:
            self.max_tokens = max_tokens
        if temperature is not None:
            self.temperature = temperature
        if top_k is not None:
            self.top_k = top_k
        if top_p is not None:
            self.top_p = top_p
        if min_p is not None:
            self.min_p = min_p
        if repetition_penalty is not None:
            self.repetition_penalty = repetition_penalty
117
+ ],
118
+ "metadata": {
119
+ "id": "v4uIN6uIeyl3"
120
+ },
121
+ "execution_count": null,
122
+ "outputs": []
123
+ },
124
+ {
125
+ "cell_type": "code",
126
+ "source": [
127
def run_chatbot(
    model=None,
    system_message="You are Mimma, You are expert in python coding, Think step by step before coming up with final python code",
    max_tokens=None,
    temperature=None,
    top_k=None,
    top_p=None,
    min_p=None,
    repetition_penalty=None,
):
    """Run an interactive chat REPL backed by the singleton ChatBot.

    Args:
        model: model slug to load (reuses the current model when None).
        system_message: system prompt used to reset the conversation.
        max_tokens / temperature / top_k / top_p / min_p /
        repetition_penalty: optional generation-parameter overrides;
            None leaves the current value unchanged. (min_p and
            repetition_penalty are new, backward-compatible parameters —
            ChatBot.update_params already supported them.)

    Type 'quit' at the prompt to exit.
    """
    try:
        # Get or create chatbot instance
        chatbot = ChatBot.get_instance(model)

        # Check `is not None` rather than truthiness so legitimate falsy
        # overrides (e.g. temperature=0.0, min_p=0.0) are not dropped.
        overrides = (max_tokens, temperature, top_k, top_p, min_p, repetition_penalty)
        if any(v is not None for v in overrides):
            chatbot.update_params(max_tokens, temperature, top_k, top_p, min_p, repetition_penalty)

        # Reset conversation with new system message
        if system_message:
            chatbot.reset_conversation(system_message)

        print("Chatbot: Hi! Type 'quit' to exit.")

        while True:
            user_input = input("You: ").strip()
            if user_input.lower() == 'quit':
                break
            try:
                response = chatbot.get_response(user_input)
                print("Chatbot:", response)
            except Exception as e:
                # Keep the REPL alive on a failed generation.
                print(f"Chatbot: An error occurred: {str(e)}")
                print("Please try again.")

    except Exception as e:
        print(f"Error in chatbot: {str(e)}")
162
+ ],
163
+ "metadata": {
164
+ "id": "H2n_6Xcue3Vn"
165
+ },
166
+ "execution_count": null,
167
+ "outputs": []
168
+ },
169
+ {
170
+ "cell_type": "code",
171
+ "source": [
172
+ "run_chatbot(model=\"pankajmathur/Mimma-3-1b\")"
173
+ ],
174
+ "metadata": {
175
+ "id": "JEqgoAH2fC6h"
176
+ },
177
+ "execution_count": null,
178
+ "outputs": []
179
+ },
180
+ {
181
+ "cell_type": "code",
182
+ "source": [
183
+ "# # change system message\n",
184
+ "# run_chatbot(\n",
185
+ "# system_message=\"You are Orca Mini, You are expert in logic, Think step by step before coming up with final answer\",\n",
186
+ "# max_tokens=1024,\n",
187
+ "# temperature=0.3\n",
188
+ "# )"
189
+ ],
190
+ "metadata": {
191
+ "id": "tGW8wsfAfHDf"
192
+ },
193
+ "execution_count": null,
194
+ "outputs": []
195
+ }
196
+ ]
197
+ }