import gradio as gr
import torch
from threading import Thread
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
from huggingface_hub import login
import os
from typing import List, Dict, Any
import time

# Configuration
MODEL_ID = "facebook/MobileLLM-Pro"
MAX_HISTORY_LENGTH = 10
MAX_NEW_TOKENS = 512
DEFAULT_SYSTEM_PROMPT = "You are a helpful, friendly, and intelligent assistant. Provide clear, accurate, and thoughtful responses."

# Login to Hugging Face (if a token is provided)
HF_TOKEN = os.getenv("HF_TOKEN")
if HF_TOKEN:
    try:
        login(token=HF_TOKEN)
        print("Successfully logged in to Hugging Face")
    except Exception as e:
        print(f"Warning: Could not login to Hugging Face: {e}")


class MobileLLMChat:
    def __init__(self):
        self.model = None
        self.tokenizer = None
        self.device = None
        self.model_loaded = False

    def load_model(self, version="instruct"):
        """Load the MobileLLM-Pro model and tokenizer."""
        try:
            print(f"Loading MobileLLM-Pro ({version})...")

            # Load tokenizer
            self.tokenizer = AutoTokenizer.from_pretrained(
                MODEL_ID,
                trust_remote_code=True,
                subfolder=version,
            )

            # Load model
            self.model = AutoModelForCausalLM.from_pretrained(
                MODEL_ID,
                trust_remote_code=True,
                subfolder=version,
                torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
                device_map="auto" if torch.cuda.is_available() else None,
            )

            # Set device
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            if not torch.cuda.is_available():
                self.model.to(self.device)

            self.model.eval()
            self.model_loaded = True
            print(f"Model loaded successfully on {self.device}")
            return True
        except Exception as e:
            print(f"Error loading model: {e}")
            return False

    def format_chat_history(self, history: List[Dict[str, str]], system_prompt: str) -> List[Dict[str, str]]:
        """Prepend the system prompt and keep only user/assistant turns."""
        messages = [{"role": "system", "content": system_prompt}]
        for msg in history:
            if msg["role"] in ["user", "assistant"]:
                messages.append(msg)
        return messages

    def generate_response(self, user_input: str, history: List[Dict[str, str]],
                          system_prompt: str, temperature: float = 0.7,
                          max_new_tokens: int = MAX_NEW_TOKENS) -> str:
        """Generate a single (non-streaming) response from the model."""
        if not self.model_loaded:
            return "Model not loaded. Please try loading the model first."
        try:
            # Add user message to history
            history.append({"role": "user", "content": user_input})

            # Format messages
            messages = self.format_chat_history(history, system_prompt)

            # Apply chat template
            inputs = self.tokenizer.apply_chat_template(
                messages,
                return_tensors="pt",
                add_generation_prompt=True,
            ).to(self.device)

            # Generate response
            with torch.no_grad():
                outputs = self.model.generate(
                    inputs,
                    max_new_tokens=max_new_tokens,
                    temperature=temperature,
                    do_sample=True,
                    pad_token_id=self.tokenizer.eos_token_id,
                    eos_token_id=self.tokenizer.eos_token_id,
                )

            # Decode only the newly generated tokens (everything after the prompt)
            new_tokens = outputs[0][inputs.shape[-1]:]
            response = self.tokenizer.decode(new_tokens, skip_special_tokens=True).strip()

            # Clean up common prefixes the model may still emit
            prefixes_to_remove = ["Assistant:", "assistant:", "Response:", "response:"]
            for prefix in prefixes_to_remove:
                if response.lower().startswith(prefix.lower()):
                    response = response[len(prefix):].strip()

            # Add assistant response to history
            history.append({"role": "assistant", "content": response})
            return response
        except Exception as e:
            return f"Error generating response: {str(e)}"

    def generate_stream(self, user_input: str, history: List[Dict[str, str]],
                        system_prompt: str, temperature: float = 0.7):
        """Generate a streaming response from the model."""
        if not self.model_loaded:
            yield "Model not loaded. Please try loading the model first."
            return

        try:
            # Add user message to history
            history.append({"role": "user", "content": user_input})

            # Format messages and apply chat template
            messages = self.format_chat_history(history, system_prompt)
            inputs = self.tokenizer.apply_chat_template(
                messages,
                return_tensors="pt",
                add_generation_prompt=True,
            ).to(self.device)

            # Stream tokens as they are produced: generate() runs in a background
            # thread and pushes decoded text through a TextIteratorStreamer.
            streamer = TextIteratorStreamer(
                self.tokenizer, skip_prompt=True, skip_special_tokens=True
            )
            generation_kwargs = dict(
                inputs=inputs,
                max_new_tokens=MAX_NEW_TOKENS,
                temperature=temperature,
                do_sample=True,
                pad_token_id=self.tokenizer.eos_token_id,
                eos_token_id=self.tokenizer.eos_token_id,
                streamer=streamer,
            )
            thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
            thread.start()

            response = ""
            for new_text in streamer:
                response += new_text
                yield response.strip()
            thread.join()

            # Add final response to history
            history.append({"role": "assistant", "content": response.strip()})
        except Exception as e:
            yield f"Error generating response: {str(e)}"


# Initialize chat model
chat_model = MobileLLMChat()


def load_model_button(version):
    """Load the model when the button is clicked."""
    success = chat_model.load_model(version)
    if success:
        return gr.update(visible=False), gr.update(value="Model loaded successfully!")
    return gr.update(visible=True), gr.update(value="Failed to load model. Please check the logs.")
def clear_chat():
    """Clear the chat history and the input box."""
    return [], ""


def chat_fn(message, history, system_prompt, temperature, model_version):
    """Non-streaming chat function (history uses the 'messages' dict format)."""
    if not chat_model.model_loaded:
        return "Please load the model first using the button above."

    # History arrives as a list of {"role", "content"} dicts; copy the turns we need
    formatted_history = [
        {"role": m["role"], "content": m["content"]}
        for m in history
        if m.get("role") in ("user", "assistant")
    ]

    # Generate response
    return chat_model.generate_response(message, formatted_history, system_prompt, temperature)


def chat_stream_fn(message, history, system_prompt, temperature, model_version):
    """Streaming chat function (history uses the 'messages' dict format)."""
    if not chat_model.model_loaded:
        yield "Please load the model first using the button above."
        return

    formatted_history = [
        {"role": m["role"], "content": m["content"]}
        for m in history
        if m.get("role") in ("user", "assistant")
    ]

    # Generate streaming response
    for chunk in chat_model.generate_stream(message, formatted_history, system_prompt, temperature):
        yield chunk


# Create the Gradio interface
with gr.Blocks(
    title="MobileLLM-Pro Chat",
    theme=gr.themes.Soft(),
    css="""
    .gradio-container { max-width: 900px !important; margin: auto !important; }
    .message { padding: 12px !important; border-radius: 8px !important; margin-bottom: 8px !important; }
    .user-message { background-color: #e3f2fd !important; margin-left: 20% !important; }
    .assistant-message { background-color: #f5f5f5 !important; margin-right: 20% !important; }
    """,
) as demo:
    # Header
    gr.HTML("""

        <div style="text-align: center; margin-bottom: 1rem;">
            <h1>🤖 MobileLLM-Pro Chat</h1>
            <p>Built with anycoder</p>
            <p>Chat with Facebook's MobileLLM-Pro model optimized for on-device inference</p>
        </div>
""") # Model loading section with gr.Row(): with gr.Column(scale=1): model_version = gr.Dropdown( choices=["instruct", "base"], value="instruct", label="Model Version", info="Choose between instruct (chat) or base model" ) load_btn = gr.Button("🚀 Load Model", variant="primary", size="lg") with gr.Column(scale=2): model_status = gr.Textbox( label="Model Status", value="Model not loaded", interactive=False ) # Configuration section with gr.Accordion("⚙️ Configuration", open=False): with gr.Row(): system_prompt = gr.Textbox( value=DEFAULT_SYSTEM_PROMPT, label="System Prompt", lines=3, info="Customize the AI's behavior and personality" ) with gr.Row(): temperature = gr.Slider( minimum=0.1, maximum=2.0, value=0.7, step=0.1, label="Temperature", info="Controls randomness (higher = more creative)" ) streaming = gr.Checkbox( value=True, label="Enable Streaming", info="Show responses as they're being generated" ) # Chat interface chatbot = gr.Chatbot( label="Chat History", height=500, show_copy_button=True, bubble_full_width=False, type="messages" ) with gr.Row(): msg = gr.Textbox( label="Your Message", placeholder="Type your message here...", scale=4, container=False ) submit_btn = gr.Button("Send", variant="primary", scale=1) clear_btn = gr.Button("Clear", scale=0) # Event handlers load_btn.click( load_model_button, inputs=[model_version], outputs=[load_btn, model_status, model_status] ) # Handle chat submission def handle_chat(message, history, system_prompt, temperature, model_version, streaming): if streaming: return chat_stream_fn(message, history, system_prompt, temperature, model_version) else: return chat_fn(message, history, system_prompt, temperature, model_version) msg.submit( handle_chat, inputs=[msg, chatbot, system_prompt, temperature, model_version, streaming], outputs=[chatbot] ) submit_btn.click( handle_chat, inputs=[msg, chatbot, system_prompt, temperature, model_version, streaming], outputs=[chatbot] ) clear_btn.click( clear_chat, outputs=[chatbot, msg] ) # Examples gr.Examples( examples=[ ["What are the benefits of on-device AI models?"], ["Explain quantum computing in simple terms."], ["Write a short poem about technology."], ["What's the difference between machine learning and deep learning?"], ["How can I improve my productivity?"], ], inputs=[msg], label="Example Prompts" ) # Footer gr.HTML("""

        <div style="text-align: center; font-size: 0.9em; color: #666;">
            <p>⚠️ Note: This model requires significant computational resources. Loading may take a few minutes.</p>
            <p>Model: facebook/MobileLLM-Pro</p>
        </div>
""") # Launch the app if __name__ == "__main__": demo.launch( share=True, show_error=True, show_tips=True, debug=True )