import streamlit as st
import pdfplumber
import os
import tempfile
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain.llms.base import LLM
from typing import Optional, List, Mapping, Any
import json
import urllib.request
import urllib.error


class AzureEndpointLLM(LLM):
    endpoint_url: str
    api_key: str

    def __init__(self, **kwargs):
        # Pull secrets from the environment
        kwargs["endpoint_url"] = os.environ.get("AZURE_ENDPOINT_URL")
        kwargs["api_key"] = os.environ.get("AZURE_API_KEY")
        if not kwargs["endpoint_url"] or not kwargs["api_key"]:
            raise ValueError("Missing Azure endpoint URL or API key")
        super().__init__(**kwargs)

    @property
    def _llm_type(self) -> str:
        return "azure-endpoint-llm"

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[Any] = None,
        **kwargs: Any,
    ) -> str:
        # Request body format expected by the Azure LLM endpoint
        input_data = {
            "input_data": {
                "input_string": [{"role": "user", "content": prompt}],
                "parameters": {
                    "temperature": 0.7,
                    "max_tokens": 1024
                }
            }
        }
        body = str.encode(json.dumps(input_data))
        headers = {
            'Content-Type': 'application/json',
            'Accept': 'application/json',
            'Authorization': 'Bearer ' + self.api_key
        }
        req = urllib.request.Request(self.endpoint_url, body, headers)
        try:
            response = urllib.request.urlopen(req)
            result = response.read().decode("utf-8")
            parsed = json.loads(result)
            # Adjust this depending on how Azure returns generated text
            if isinstance(parsed, dict) and "output" in parsed:
                return parsed["output"]
            elif isinstance(parsed, dict):
                return json.dumps(parsed)
            return str(parsed)
        except urllib.error.HTTPError as e:
            print("Request failed:", e.code)
            print(e.read().decode("utf-8", "ignore"))
            return "Oops, there was an error calling the Azure model."
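# Minimal standalone sanity check for the wrapper above (a sketch, not part of the app
# flow; assumes AZURE_ENDPOINT_URL and AZURE_API_KEY are set and that the endpoint
# returns JSON shaped like {"output": "<generated text>"} -- adjust the parsing in
# _call if your deployment responds differently):
#
#   llm = AzureEndpointLLM()
#   print(llm("Summarize retrieval-augmented generation in one sentence."))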
# Lightweight conversation store kept in Streamlit session state
class ChatHistory:
    def __init__(self):
        self.messages = []

    def add_user_message(self, message):
        self.messages.append({"role": "user", "content": message})

    def add_assistant_message(self, message, sources=None):
        self.messages.append({
            "role": "assistant",
            "content": message,
            "sources": sources if sources else []
        })

    def get_conversation_history(self, include_sources=False):
        if include_sources:
            return self.messages
        return [{"role": m["role"], "content": m["content"]} for m in self.messages]

    def get_messages_for_display(self):
        return self.messages

    def clear(self):
        self.messages = []


st.set_page_config(page_title="RAG Chat with Azure LLM", page_icon="💬", layout="wide")

# Initialize per-session state
if 'vector_store' not in st.session_state:
    st.session_state.vector_store = None
if 'document_processed' not in st.session_state:
    st.session_state.document_processed = False
if 'file_name' not in st.session_state:
    st.session_state.file_name = None
if 'document_text' not in st.session_state:
    st.session_state.document_text = ""
if 'chat_history' not in st.session_state:
    st.session_state.chat_history = ChatHistory()


def extract_text_from_document(document_file):
    file_type = document_file.name.split('.')[-1].lower()
    if file_type == 'txt':
        return document_file.getvalue().decode('utf-8')
    elif file_type == 'pdf':
        # Write the upload to a temp file so pdfplumber can open it by path
        with tempfile.NamedTemporaryFile(delete=False, suffix='.pdf') as tmp_file:
            tmp_file.write(document_file.getvalue())
            tmp_file_path = tmp_file.name
        text = ""
        try:
            with pdfplumber.open(tmp_file_path) as pdf:
                for page in pdf.pages:
                    page_text = page.extract_text()
                    if page_text:
                        text += page_text + "\n\n"
        except Exception as e:
            st.error(f"Error extracting text from PDF: {e}")
        finally:
            if os.path.exists(tmp_file_path):
                os.remove(tmp_file_path)
        return text
    else:
        st.error(f"Unsupported file type: {file_type}")
        return ""


def create_chunks(text):
    # Split the document into overlapping chunks for retrieval
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50, length_function=len)
    return text_splitter.split_text(text)


def create_vector_store(chunks):
    embeddings = HuggingFaceEmbeddings(
        model_name="sentence-transformers/all-mpnet-base-v2",
        model_kwargs={'device': 'cpu'}
    )
    return FAISS.from_texts(chunks, embeddings)


def retrieve_relevant_chunks(vector_store, query, k=3):
    if not vector_store:
        return []
    return vector_store.similarity_search(query, k=k)


def generate_rag_response(query, chat_history, vector_store):
    llm = AzureEndpointLLM()
    relevant_docs = retrieve_relevant_chunks(vector_store, query, k=3)
    if not relevant_docs:
        return "I couldn't find any relevant information in the document to answer your question.", []

    context = "\n\n".join([doc.page_content for doc in relevant_docs])

    conversation_history = ""
    for msg in chat_history.get_conversation_history():
        role = "User" if msg["role"] == "user" else "Assistant"
        conversation_history += f"{role}: {msg['content']}\n\n"

    prompt = f"""
You are a helpful assistant that provides accurate information based only on the given context and conversation history.
1. Use only the context below and the conversation history to answer the question.
2. If the answer is not in the context, reply with "I don't have enough information to answer this question."
3. Be friendly and helpful.
4. Maintain continuity with the conversation history.

Conversation History:
{conversation_history}

Context from document:
{context}

User's question: {query}

Answer:
"""

    response = llm(prompt)
    return response, relevant_docs


def process_user_message(user_message):
    st.session_state.chat_history.add_user_message(user_message)
    with st.spinner("Thinking..."):
        response, source_docs = generate_rag_response(
            user_message,
            st.session_state.chat_history,
            st.session_state.vector_store
        )
    sources = [{"id": i + 1, "content": doc.page_content} for i, doc in enumerate(source_docs)]
    st.session_state.chat_history.add_assistant_message(response, sources)
    return response, sources


st.title("💬 RAG Chat with Azure Model")
st.markdown("""
Upload a PDF or TXT document and chat about its content. This system uses:
- Document text extraction
- Text chunking and embedding
- Azure LLM for answering questions
- Memory to maintain conversation context
""")

with st.sidebar:
    st.header("Configuration")
    uploaded_file = st.file_uploader("Upload a document", type=['pdf', 'txt'])
    if st.button("Clear Chat History"):
        st.session_state.chat_history.clear()
        st.success("Chat history cleared!")
    st.markdown("**Using Azure-deployed model**")
    st.markdown("---")
    st.markdown("### About")
    st.markdown("""
This is a RAG Chat system that:
1. Processes PDF and TXT documents
2. Creates a vector database of document content
3. Maintains conversation history
4. Retrieves relevant information for user queries
5. Generates contextual answers using your Azure-deployed LLM
""")

if uploaded_file is not None:
    if st.session_state.file_name != uploaded_file.name:
        st.session_state.file_name = uploaded_file.name
        st.session_state.document_processed = False
    if not st.session_state.document_processed:
        with st.spinner(f"Processing {uploaded_file.name.split('.')[-1].upper()} file..."):
            text = extract_text_from_document(uploaded_file)
            st.session_state.document_text = text
            chunks = create_chunks(text)
            st.session_state.vector_store = create_vector_store(chunks)
            st.session_state.document_processed = True
            st.success(f"Document processed successfully: {uploaded_file.name}")
            num_chunks = len(chunks)
            avg_chunk_size = sum(len(chunk) for chunk in chunks) / num_chunks if num_chunks > 0 else 0
            st.info(f"Document processed into {num_chunks} chunks with average size of {avg_chunk_size:.0f} characters")

col1, col2 = st.columns([3, 1])

with col1:
    st.subheader("Chat")
    chat_container = st.container()
    with chat_container:
        for message in st.session_state.chat_history.get_messages_for_display():
            with st.chat_message(message["role"]):
                st.markdown(message["content"])
                if message["role"] == "assistant" and "sources" in message and message["sources"]:
                    with st.expander("View Sources"):
                        for source in message["sources"]:
                            st.markdown(f"**Source {source['id']}**")
                            st.text(source["content"])

    if st.session_state.document_processed:
        user_input = st.chat_input("Type your message here...")
        if user_input:
            with st.chat_message("user"):
                st.markdown(user_input)
            response, sources = process_user_message(user_input)
            with st.chat_message("assistant"):
                st.markdown(response)
                if sources:
                    with st.expander("View Sources"):
                        for source in sources:
                            st.markdown(f"**Source {source['id']}**")
                            st.text(source["content"])
    else:
        st.info("Please upload a document to start chatting")

with col2:
    if st.session_state.document_processed:
        st.subheader("Document Preview")
        with st.expander("View Document Text", expanded=False):
            st.text_area(
                "Extracted Text",
                st.session_state.document_text[:5000] + ("..." if len(st.session_state.document_text) > 5000 else ""),
                height=400
            )
    else:
        st.info("Upload a PDF or TXT document to get started")

if not st.session_state.document_processed:
    st.markdown("""
## Getting Started
1. **Upload a PDF or TXT document** using the file uploader in the sidebar
2. Wait for the document to be processed
3. Start chatting with the AI about the document
4. The chat remembers the conversation context
5. Clear the chat history using the button in the sidebar

The system uses an Azure-deployed LLM and maintains conversation memory.
""")

st.sidebar.info("""
**Using Azure ML Endpoint**

This application calls a large language model deployed on Azure.
""")
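# How to run locally (a sketch; assumes this script is saved as app.py and that
# streamlit, pdfplumber, langchain, langchain-community, faiss-cpu, and
# sentence-transformers are installed):
#
#   export AZURE_ENDPOINT_URL="<your Azure ML endpoint scoring URL>"
#   export AZURE_API_KEY="<your endpoint key>"
#   streamlit run app.py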