import gradio as gr
import torch
from threading import Thread
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
from huggingface_hub import login
import os
from typing import List, Dict, Any
import time

# Configuration
MODEL_ID = "facebook/MobileLLM-Pro"
MAX_HISTORY_LENGTH = 10
MAX_NEW_TOKENS = 512
DEFAULT_SYSTEM_PROMPT = "You are a helpful, friendly, and intelligent assistant. Provide clear, accurate, and thoughtful responses."

# Login to Hugging Face (if a token is provided)
HF_TOKEN = os.getenv("HF_TOKEN")
if HF_TOKEN:
    try:
        login(token=HF_TOKEN)
        print("Successfully logged in to Hugging Face")
    except Exception as e:
        print(f"Warning: Could not login to Hugging Face: {e}")


class MobileLLMChat:
    def __init__(self):
        self.model = None
        self.tokenizer = None
        self.device = None
        self.model_loaded = False

    def load_model(self, version="instruct"):
        """Load the MobileLLM-Pro model and tokenizer."""
        try:
            print(f"Loading MobileLLM-Pro ({version})...")

            # Load tokenizer
            self.tokenizer = AutoTokenizer.from_pretrained(
                MODEL_ID,
                trust_remote_code=True,
                subfolder=version,
            )

            # Load model
            self.model = AutoModelForCausalLM.from_pretrained(
                MODEL_ID,
                trust_remote_code=True,
                subfolder=version,
                torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
                device_map="auto" if torch.cuda.is_available() else None,
            )

            # Set device
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            if not torch.cuda.is_available():
                self.model.to(self.device)

            self.model.eval()
            self.model_loaded = True
            print(f"Model loaded successfully on {self.device}")
            return True

        except Exception as e:
            print(f"Error loading model: {e}")
            return False

    def format_chat_history(self, history: List[Dict[str, str]], system_prompt: str) -> List[Dict[str, str]]:
        """Format chat history for the model."""
        messages = [{"role": "system", "content": system_prompt}]
        for msg in history:
            if msg["role"] in ["user", "assistant"]:
                messages.append(msg)
        return messages

    def generate_response(self, user_input: str, history: List[Dict[str, str]], system_prompt: str,
                          temperature: float = 0.7, max_new_tokens: int = MAX_NEW_TOKENS) -> str:
        """Generate a response from the model."""
        if not self.model_loaded:
            return "Model not loaded. Please try loading the model first."
        try:
            # Add user message to history
            history.append({"role": "user", "content": user_input})

            # Format messages
            messages = self.format_chat_history(history, system_prompt)

            # Apply chat template
            inputs = self.tokenizer.apply_chat_template(
                messages,
                return_tensors="pt",
                add_generation_prompt=True,
            ).to(self.device)

            # Generate response
            with torch.no_grad():
                outputs = self.model.generate(
                    inputs,
                    max_new_tokens=max_new_tokens,
                    temperature=temperature,
                    do_sample=True,
                    pad_token_id=self.tokenizer.eos_token_id,
                    eos_token_id=self.tokenizer.eos_token_id,
                )

            # Decode only the newly generated tokens (everything after the prompt)
            response = self.tokenizer.decode(
                outputs[0][inputs.shape[-1]:], skip_special_tokens=True
            ).strip()

            # Clean up common prefixes the model may still emit
            prefixes_to_remove = ["Assistant:", "assistant:", "Response:", "response:"]
            for prefix in prefixes_to_remove:
                if response.lower().startswith(prefix.lower()):
                    response = response[len(prefix):].strip()

            # Add assistant response to history
            history.append({"role": "assistant", "content": response})

            return response

        except Exception as e:
            return f"Error generating response: {str(e)}"

    def generate_stream(self, user_input: str, history: List[Dict[str, str]], system_prompt: str,
                        temperature: float = 0.7):
        """Generate a streaming response from the model using TextIteratorStreamer."""
        if not self.model_loaded:
            yield "Model not loaded. Please try loading the model first."
            return

        try:
            # Add user message to history
            history.append({"role": "user", "content": user_input})

            # Format messages
            messages = self.format_chat_history(history, system_prompt)

            # Apply chat template
            inputs = self.tokenizer.apply_chat_template(
                messages,
                return_tensors="pt",
                add_generation_prompt=True,
            ).to(self.device)

            # Stream tokens as they are produced: run generate() in a background
            # thread and read decoded text from the streamer. skip_prompt=True
            # ensures only newly generated text is yielded.
            streamer = TextIteratorStreamer(
                self.tokenizer, skip_prompt=True, skip_special_tokens=True
            )
            generation_kwargs = dict(
                inputs=inputs,
                max_new_tokens=MAX_NEW_TOKENS,
                temperature=temperature,
                do_sample=True,
                pad_token_id=self.tokenizer.eos_token_id,
                eos_token_id=self.tokenizer.eos_token_id,
                streamer=streamer,
            )
            thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
            thread.start()

            # Yield the accumulated response as new text arrives
            response = ""
            for new_text in streamer:
                response += new_text
                yield response.strip()
            thread.join()

            # Add final response to history
            history.append({"role": "assistant", "content": response.strip()})

        except Exception as e:
            yield f"Error generating response: {str(e)}"


# Initialize chat model
chat_model = MobileLLMChat()


def load_model_button(version):
    """Load the model when the button is clicked."""
    success = chat_model.load_model(version)
    if success:
        return (
            gr.update(visible=False),
            gr.update(visible=True),
            gr.update(value="Model loaded successfully!"),
        )
    else:
        return (
            gr.update(visible=True),
            gr.update(visible=False),
            gr.update(value="Failed to load model. Please check the logs."),
        )
def clear_chat():
    """Clear the chat history."""
    return [], []


def chat_fn(message, history, system_prompt, temperature, model_version):
    """Main chat function."""
    if not chat_model.model_loaded:
        return "Please load the model first using the button above."

    # Convert history format
    formatted_history = []
    for user_msg, assistant_msg in history:
        formatted_history.append({"role": "user", "content": user_msg})
        if assistant_msg:
            formatted_history.append({"role": "assistant", "content": assistant_msg})

    # Generate response
    response = chat_model.generate_response(message, formatted_history, system_prompt, temperature)
    return response


def chat_stream_fn(message, history, system_prompt, temperature, model_version):
    """Streaming chat function."""
    if not chat_model.model_loaded:
        yield "Please load the model first using the button above."
        return

    # Convert history format
    formatted_history = []
    for user_msg, assistant_msg in history:
        formatted_history.append({"role": "user", "content": user_msg})
        if assistant_msg:
            formatted_history.append({"role": "assistant", "content": assistant_msg})

    # Generate streaming response
    for chunk in chat_model.generate_stream(message, formatted_history, system_prompt, temperature):
        yield chunk


# Create the Gradio interface
with gr.Blocks(
    title="MobileLLM-Pro Chat",
    theme=gr.themes.Soft(),
    css="""
    .gradio-container {
        max-width: 900px !important;
        margin: auto !important;
    }
    .message {
        padding: 12px !important;
        border-radius: 8px !important;
        margin-bottom: 8px !important;
    }
    .user-message {
        background-color: #e3f2fd !important;
        margin-left: 20% !important;
    }
    .assistant-message {
        background-color: #f5f5f5 !important;
        margin-right: 20% !important;
    }
    """,
) as demo:
    # Header
    gr.HTML("""
        <p>Built with anycoder</p>
        <p>Chat with Facebook's MobileLLM-Pro model optimized for on-device inference</p>
        <p>⚠️ Note: This model requires significant computational resources. Loading may take a few minutes.</p>
        <p>Model: facebook/MobileLLM-Pro</p>