import streamlit as st
import os
import io
import json
import csv
import numpy as np
import faiss
import uuid
import time
import sys
from typing import List, Dict, Any
import base64
# === Hugging Face model packages (switched to InferenceClient) ===
try:
    from huggingface_hub import InferenceClient
except ImportError:
    st.error("Please check that all Hugging Face dependencies are installed: pip install huggingface-hub")

# === LangChain / RAG packages (unchanged) ===
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_core.documents import Document
from langchain_community.vectorstores import FAISS
from langchain_community.vectorstores.utils import DistanceStrategy
from langchain_community.docstore.in_memory import InMemoryDocstore
import pandas as pd

# Try to import pypdf (optional, only needed for PDF knowledge-base files)
try:
    import pypdf
except ImportError:
    pypdf = None
# --- Page configuration ---
st.set_page_config(page_title="Cybersecurity AI Assistant (Hugging Face RAG & IP Correlated Analysis)", page_icon="🛡️", layout="wide")
st.title("🛡️ LLM with FAISS RAG & IP Correlated Analysis (Inference Client)")
st.markdown("Enabled: **IndexFlatIP** + **L2 normalization** + **Hugging Face Inference Client (API)**. Supports JSON/CSV/TXT/**W3C logs** and runs **IP-correlated batch analysis**.")

# --- Streamlit session state initialization (hardened so every key has an initial value) ---
if 'execute_batch_analysis' not in st.session_state:
    st.session_state.execute_batch_analysis = False
if 'batch_results' not in st.session_state:
    st.session_state.batch_results = []  # always initialize as an empty list
if 'rag_current_file_key' not in st.session_state:
    st.session_state.rag_current_file_key = None
if 'batch_current_file_key' not in st.session_state:
    st.session_state.batch_current_file_key = None
if 'vector_store' not in st.session_state:
    st.session_state.vector_store = None
if 'json_data_for_batch' not in st.session_state:
    st.session_state.json_data_for_batch = None  # stays None until a file has been uploaded
# --- Available LLM models ---
MODEL_OPTIONS = {
    "OpenAI GPT-OSS 20B (Hugging Face)": "openai/gpt-oss-20b",
    "secgpt (Hugging Face)": "clouditera/secgpt",
    "RNJ (Hugging Face)": "EssentialAI/rnj-1-instruct",
    "Meta Llama 3.1 8B Instruct (Hugging Face)": "meta-llama/Llama-3.1-8B-Instruct",
    "Meta Llama 3.3 70B Instruct (Hugging Face)": "meta-llama/Llama-3.3-70B-Instruct",
    "fdtn-ai Foundation-Sec 8B Instruct (Hugging Face)": "fdtn-ai/Foundation-Sec-8B-Instruct",
    "Qwen (Hugging Face)": "Qwen/Qwen2.5-7B-Instruct",
    "Gemma 3 27B Instruct (Hugging Face)": "google/gemma-3-27b-it",
    "Kimi K2 Instruct (Hugging Face)": "moonshotai/Kimi-K2-Instruct-0905"
}

WINDOW_SIZE = 10  # maximum number of logs per correlated sequence (including the current log)
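# How the IP-correlated window is used further below (illustrative sketch; the IPs and
# indices are hypothetical): for each target log, up to WINDOW_SIZE - 1 earlier logs with
# the same client IP are prepended, in their original order, ahead of the target log.
# Example with WINDOW_SIZE = 10: if logs #3, #7 and #12 all come from 203.0.113.7, the
# sequence analyzed for log #12 is [#3, #7, #12]; a log whose IP field is missing or "-"
# is analyzed on its own.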
# === W3C log parser (new) ===
def parse_w3c_log(log_content: str) -> List[Dict[str, Any]]:
    """
    Parse the W3C Extended Log File Format (e.g. IIS logs), including the #Fields: header.
    Args:
        log_content (str): the log file content as a string.
    Returns:
        List[Dict[str, Any]]: the converted list of JSON objects.
    """
    lines = log_content.splitlines()
    field_names = None
    data_lines = []
    for line in lines:
        line = line.strip()
        if not line:
            continue
        if line.startswith("#Fields:"):
            # Field definition found, e.g. "#Fields: date time s-ip cs-method ..."
            field_names = line.split()[1:]  # skip the "#Fields:" token itself
        elif not line.startswith("#"):
            # This is an actual data line
            data_lines.append(line)
    if not field_names:
        # No #Fields header found: fall back to raw log-entry mode
        # st.warning("No W3C #Fields: header detected; falling back to raw log-entry mode.")
        return [{"raw_log_entry": line} for line in lines if line.strip() and not line.startswith("#")]
    json_data = []
    # Field names that should be converted to numbers (extend as needed; underscore variants)
    numeric_fields = ['sc_status', 'time_taken', 'bytes', 'resp_len', 'req_size']
    for data_line in data_lines:
        # W3C logs are space-delimited by default, so split on a single space
        values = data_line.split(' ')
        # Simple field-count check
        if len(values) != len(field_names):
            # Field count does not match: treat the whole line as a raw log entry
            json_data.append({"raw_log_entry": data_line})
            continue
        record = {}
        for key, value in zip(field_names, values):
            # Replace '-' in W3C field names with the Python-friendly '_'
            key = key.strip().replace('-', '_')
            value = value.strip() if value else ""
            # Numeric conversion
            if key in numeric_fields:
                try:
                    record[key] = int(value)
                except ValueError:
                    try:
                        record[key] = float(value)
                    except ValueError:
                        record[key] = value
            else:
                record[key] = value
        if record:
            json_data.append(record)
    return json_data
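# Illustrative sketch of parse_w3c_log (not part of the app flow; the field list and values
# below are hypothetical). Given the two input lines
#
#   #Fields: date time c-ip cs-method cs-uri-stem sc-status time-taken
#   2024-01-01 00:00:01 203.0.113.7 GET /login.aspx 200 15
#
# the parser returns one dict per data line, with '-' in the keys replaced by '_' and the
# numeric fields (sc_status, time_taken, ...) cast to int/float where possible:
#
#   [{"date": "2024-01-01", "time": "00:00:01", "c_ip": "203.0.113.7",
#     "cs_method": "GET", "cs_uri_stem": "/login.aspx", "sc_status": 200, "time_taken": 15}]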
# === Core file conversion (CSV/TXT -> JSON list), kept with minor adjustments ===
def convert_csv_txt_to_json_list(file_content: bytes, file_type: str) -> List[Dict[str, Any]]:
    """
    Convert CSV or TXT file content (assumed to be CSV with a header row) into a list of JSON objects.
    """
    log_content = file_content.decode("utf-8").strip()
    if not log_content:
        return []
    string_io = io.StringIO(log_content)
    # Try csv.DictReader, which automatically treats the first row as the keys
    try:
        reader = csv.DictReader(string_io)
    except Exception:
        # If that fails, fall back to one raw log entry per line
        return [{"raw_log_entry": line.strip()} for line in log_content.splitlines() if line.strip()]
    json_data = []
    if reader and reader.fieldnames:
        # Numeric column names the user is likely to use
        numeric_fields = ['sc-status', 'time-taken', 'bytes', 'resp-len', 'req-size', 'status_code', 'size', 'duration']
        for row in reader:
            record = {}
            for key, value in row.items():
                if key is None:
                    continue
                key = key.strip()
                value = value.strip() if value else ""
                # Numeric conversion
                if key in numeric_fields:
                    try:
                        record[key] = int(value)
                    except ValueError:
                        try:
                            record[key] = float(value)
                        except ValueError:
                            record[key] = value
                else:
                    record[key] = value
            if record:
                json_data.append(record)
    # Check again whether the result is empty; if so, the file is probably not standard CSV/JSON
    if not json_data:
        string_io.seek(0)
        lines = string_io.readlines()
        return [{"raw_log_entry": line.strip()} for line in lines if line.strip()]
    return json_data
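# Illustrative sketch of convert_csv_txt_to_json_list (hypothetical header and row): the CSV
#
#   c-ip,cs-method,sc-status
#   203.0.113.7,GET,404
#
# becomes [{"c-ip": "203.0.113.7", "cs-method": "GET", "sc-status": 404}]. Note that, unlike
# parse_w3c_log above, this path keeps hyphens in the keys, so its numeric_fields list uses
# the hyphenated spellings ('sc-status', 'time-taken', ...).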
# === File-type dispatcher (modified) ===
def convert_uploaded_file_to_json_list(uploaded_file) -> List[Dict[str, Any]]:
    """Convert the uploaded file content into a list of log JSON objects, based on the file type."""
    file_bytes = uploaded_file.getvalue()
    file_name_lower = uploaded_file.name.lower()
    # --- Case 1: JSON ---
    if file_name_lower.endswith('.json'):
        stringio = io.StringIO(file_bytes.decode("utf-8"))
        parsed_data = json.load(stringio)
        if isinstance(parsed_data, dict):
            # Handle lists wrapped in an 'alerts' or 'logs' key
            if 'alerts' in parsed_data and isinstance(parsed_data['alerts'], list):
                return parsed_data['alerts']
            elif 'logs' in parsed_data and isinstance(parsed_data['logs'], list):
                return parsed_data['logs']
            else:
                return [parsed_data]  # a single dict is treated as a single log
        elif isinstance(parsed_data, list):
            return parsed_data  # a list is returned as-is
        else:
            raise ValueError("Unsupported JSON file format (neither a list nor a dict).")
    # --- Cases 2, 3 & 4: CSV/TXT/LOG ---
    elif file_name_lower.endswith(('.csv', '.txt', '.log')):
        file_type = 'csv' if file_name_lower.endswith('.csv') else ('log' if file_name_lower.endswith('.log') else 'txt')
        if file_type == 'log':
            # For .log files, try the W3C parser
            log_content = file_bytes.decode("utf-8").strip()
            if not log_content:
                return []
            return parse_w3c_log(log_content)
        else:
            # CSV and TXT keep the original csv.DictReader logic
            return convert_csv_txt_to_json_list(file_bytes, file_type)
    else:
        raise ValueError("Unsupported file type.")
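# Dispatcher behaviour at a glance (illustrative, with hypothetical payloads):
#   {"alerts": [{...}, {...}]}  -> the inner list            (wrapped JSON)
#   {"c_ip": "203.0.113.7"}     -> [{"c_ip": "203.0.113.7"}] (single dict becomes one log)
#   [{...}, {...}]              -> returned unchanged        (JSON array)
#   *.log                       -> parse_w3c_log()           (W3C, with raw-entry fallback)
#   *.csv / *.txt               -> convert_csv_txt_to_json_list()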
# --- Sidebar settings (the 'type' parameter has been updated) ---
with st.sidebar:
    st.header("⚙️ Settings")
    # --- Model selector ---
    selected_model_name = st.selectbox(
        "Select LLM model",
        list(MODEL_OPTIONS.keys()),
        index=0  # default to the first entry
    )
    MODEL_ID = MODEL_OPTIONS[selected_model_name]  # update MODEL_ID
    if not os.environ.get("HF_TOKEN"):
        st.error("The **HF_TOKEN** environment variable is not set. Please set it and restart the application.")
    st.info(f"LLM model: **{MODEL_ID}** (Hugging Face Inference API)")
    st.warning("⚠️ **Note**: this model is called through the Inference API; make sure your HF token has access to it.")
    st.divider()
    st.subheader("📂 File upload")
    # === 1. Batch-analysis file (multiple formats supported) ===
    batch_uploaded_file = st.file_uploader(
        "1️⃣ Upload a **log file** (for batch analysis)",
        type=['json', 'csv', 'txt', 'log'],  # <--- 'log' added here
        key="batch_uploader",
        help="Supports JSON (array), CSV (with header), TXT/LOG (treated as W3C or generic logs)"
    )
    # === 2. RAG knowledge-base file ===
    rag_uploaded_file = st.file_uploader(
        "2️⃣ Upload a **RAG reference knowledge base** (logs/PDF/code, etc.)",
        type=['txt', 'py', 'log', 'csv', 'md', 'pdf'],  # <--- 'log' added here
        key="rag_uploader"
    )
    st.divider()
    st.subheader("💡 Batch-analysis instruction")
    analysis_prompt = st.text_area(
        "Instruction to run against each log",
        value="You are a security expert tasked with analyzing logs related to Initial Access, Establish Foothold & Reconnaissance, Lateral Movement, Targeting & Data Exfiltration, Malware Deployment & Execution and Ransom & Negotiation. Respond with a clear, structured analysis using the following mandatory sections: \n\n- Priority: Provide the overall priority level. (Answer High-risk detected!, Medium-risk detected!, or Normal-Behavior detected! only) \n- Explanation: If this log is not normal behavior, explain the potential impact and why this specific log requires attention. If not, **omit the explanation section**. \n- Action Plan: If this log is not normal behavior, What should be the immediate steps to address this specific log? If not, **omit the action plan section**.",
        height=200
    )
    st.markdown("This instruction is executed once, independently, for **every log entry** in the file (using the **IP-correlation window**).")
    if batch_uploaded_file:
        if st.button("🚀 Run batch analysis"):
            if not os.environ.get("HF_TOKEN"):
                st.error("Cannot run: the **HF_TOKEN** environment variable is not set.")
            else:
                # Make sure the file has been parsed successfully
                if st.session_state.get('json_data_for_batch'):
                    st.session_state.execute_batch_analysis = True
                else:
                    st.error("Please wait until the log file has finished parsing.")
    else:
        st.info("Upload a log file to enable the batch-analysis button.")
    st.divider()
    st.subheader("🔍 RAG retrieval settings")
    similarity_threshold = st.slider("📐 Cosine similarity threshold", 0.0, 1.0, 0.4, 0.01)
    st.divider()
    st.subheader("Model parameters")
    system_prompt = st.text_area("System Prompt", value="You are a Senior Security Analyst, named Ernest. You provide expert, authoritative, and concise advice on Information Security. Your analysis must be based strictly on the provided context.", height=100)
    max_output_tokens = st.slider("Max Output Tokens", 128, 4096, 2048, 128)
    temperature = st.slider("Temperature", 0.0, 1.0, 0.1, 0.1)
    top_p = st.slider("Top P", 0.1, 1.0, 0.95, 0.05)
    st.divider()
    if st.button("🗑️ Clear all records"):
        # Clear only dynamic state; keep HF_TOKEN
        for key in list(st.session_state.keys()):
            if key not in ['HF_TOKEN']:
                del st.session_state[key]
        st.rerun()
# --- Initialize the Hugging Face LLM client (updated: MODEL_ID is passed as a parameter) ---
# load_inference_client takes model_id as a parameter so Streamlit's caching can key on it.
@st.cache_resource  # assumed decorator: cache one client per model id, as the note above describes
def load_inference_client(model_id):
    if not os.environ.get("HF_TOKEN"):
        return None
    try:
        client = InferenceClient(model_id, token=os.environ.get("HF_TOKEN"))
        st.success(f"Hugging Face Inference Client **{model_id}** loaded successfully.")
        return client
    except Exception as e:
        st.error(f"Failed to load the Hugging Face Inference Client: {e}")
        return None

inference_client = None
if os.environ.get("HF_TOKEN"):
    with st.spinner(f"Connecting to Inference Client: {MODEL_ID}..."):
        # Pass the currently selected MODEL_ID
        inference_client = load_inference_client(MODEL_ID)
    if inference_client is None and os.environ.get("HF_TOKEN"):
        st.warning(f"Hugging Face Inference Client **{MODEL_ID}** could not be reached.")
elif not os.environ.get("HF_TOKEN"):
    st.error("Please set HF_TOKEN in the environment variables.")
# === Embedding model (unchanged) ===
def load_embedding_model():
    model_kwargs = {'device': 'cpu', 'trust_remote_code': True}
    encode_kwargs = {'normalize_embeddings': False}
    return HuggingFaceEmbeddings(model_name="BAAI/bge-large-zh-v1.5", model_kwargs=model_kwargs, encode_kwargs=encode_kwargs)
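# Note on the configuration above: normalize_embeddings is deliberately left False because the
# vectors are L2-normalized explicitly with faiss.normalize_L2() in process_file_to_faiss() and
# faiss_cosine_search_all() below, right before they are added to or searched against the
# IndexFlatIP index.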
| with st.spinner("正在載入 Embedding 模型..."): | |
| embedding_model = load_embedding_model() | |
| # === 建立向量庫 / Search 函數 (保持不變) === | |
def process_file_to_faiss(uploaded_file):
    text_content = ""
    try:
        if uploaded_file.type == "application/pdf":
            if pypdf:
                pdf_reader = pypdf.PdfReader(uploaded_file)
                for page in pdf_reader.pages:
                    text_content += (page.extract_text() or "") + "\n"  # extract_text() may return None
            else:
                return None, "PDF library missing"
        else:
            stringio = io.StringIO(uploaded_file.getvalue().decode("utf-8"))
            text_content = stringio.read()
        if not text_content.strip():
            return None, "File is empty"
        # Split the file content by line, one Document per line
        events = [line for line in text_content.splitlines() if line.strip()]
        docs = [Document(page_content=e) for e in events]
        if not docs:
            return None, "No documents created"
        # Embed and initialize FAISS (IndexFlatIP + L2 normalization)
        embeddings = embedding_model.embed_documents([d.page_content for d in docs])
        embeddings_np = np.array(embeddings).astype("float32")
        faiss.normalize_L2(embeddings_np)
        dimension = embeddings_np.shape[1]
        index = faiss.IndexFlatIP(dimension)  # inner-product index
        index.add(embeddings_np)
        doc_ids = [str(uuid.uuid4()) for _ in range(len(docs))]
        docstore = InMemoryDocstore({_id: doc for _id, doc in zip(doc_ids, docs)})
        index_to_docstore_id = {i: _id for i, _id in enumerate(doc_ids)}
        # Use the COSINE distance strategy; together with IndexFlatIP and L2 normalization this yields cosine similarity
        vector_store = FAISS(embedding_function=embedding_model, index=index, docstore=docstore, index_to_docstore_id=index_to_docstore_id, distance_strategy=DistanceStrategy.COSINE)
        return vector_store, f"{len(docs)} chunks created."
    except Exception as e:
        return None, f"Error: {str(e)}"
def faiss_cosine_search_all(vector_store, query, threshold):
    q_emb = embedding_model.embed_query(query)
    q_emb = np.array([q_emb]).astype("float32")
    faiss.normalize_L2(q_emb)
    index = vector_store.index
    D, I = index.search(q_emb, k=index.ntotal)
    selected = []
    # Cosine similarity = D (IndexFlatIP + L2 normalization)
    for score, idx in zip(D[0], I[0]):
        if idx == -1:
            continue
        # Keep only hits whose similarity clears the threshold
        if score >= threshold:
            doc_id = vector_store.index_to_docstore_id[idx]
            doc = vector_store.docstore.search(doc_id)
            selected.append((doc, score))
    selected.sort(key=lambda x: x[1], reverse=True)
    return selected
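# Search behaviour (sketch; the chunk texts are hypothetical): k=index.ntotal makes this an
# exhaustive scan over every stored chunk, so the scores in D are cosine similarities in [-1, 1].
# Everything at or above the sidebar threshold is returned as (Document, score) pairs sorted from
# most to least similar, e.g. [(Document("failed login from 203.0.113.7"), 0.83),
# (Document("password spray alert"), 0.52)] for a query about brute-force activity.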
# === Hugging Face single-log analysis (unchanged) ===
def generate_rag_response_hf_for_log(client, model_id, log_sequence_text, user_prompt, sys_prompt, vector_store, threshold, max_output_tokens, temperature, top_p):
    if client is None:
        return "ERROR: Client Error", ""
    context_text = ""
    # RAG retrieval
    if vector_store:
        selected = faiss_cosine_search_all(vector_store, log_sequence_text, threshold)
        if selected:
            # Keep only the 5 most relevant chunks
            retrieved_contents = [f"--- Reference Chunk (sim={score:.3f}) ---\n{doc.page_content}" for doc, score in selected[:5]]
            context_text = "\n".join(retrieved_contents)
    rag_instruction = f"""=== RETRIEVED REFERENCE CONTEXT (Cosine ≥ {threshold}) ===
{context_text if context_text else 'No relevant reference context found.'}
=== END REFERENCE CONTEXT ===

ANALYSIS INSTRUCTION: {user_prompt}

Based on the provided LOG SEQUENCE and REFERENCE CONTEXT, you must analyze the **entire sequence** to detect any continuous attack chains or evolving threats."""
    log_content_section = f"""=== CURRENT LOG SEQUENCE TO ANALYZE (Window Size: Max {WINDOW_SIZE} logs associated by IP) ===
{log_sequence_text}
=== END LOG SEQUENCE ==="""
    messages = [
        {"role": "system", "content": sys_prompt},
        {"role": "user", "content": f"{rag_instruction}\n\n{log_content_section}"}
    ]
    try:
        # Call the model via chat_completion
        response_stream = client.chat_completion(
            messages,
            max_tokens=max_output_tokens,
            temperature=temperature,
            top_p=top_p,
            stream=False  # non-streaming call
        )
        if response_stream and response_stream.choices:
            return response_stream.choices[0].message.content.strip(), context_text
        else:
            return "Format Error: Model returned empty response or invalid format.", context_text
    except Exception as e:
        return f"Model Error: {str(e)}", context_text
# =======================================================================
# === File handling (RAG file) - unchanged ===
if rag_uploaded_file:
    file_key = f"vs_{rag_uploaded_file.name}_{rag_uploaded_file.size}"
    if st.session_state.rag_current_file_key != file_key or 'vector_store' not in st.session_state:
        # Drop the old vector store to save memory
        if 'vector_store' in st.session_state:
            del st.session_state.vector_store
        with st.spinner(f"Building the RAG reference knowledge base ({rag_uploaded_file.name})..."):
            vs, msg = process_file_to_faiss(rag_uploaded_file)
            if vs:
                st.session_state.vector_store = vs
                st.session_state.rag_current_file_key = file_key
                st.toast(f"RAG reference knowledge base updated! {msg}", icon="✅")
            else:
                st.session_state.rag_current_file_key = None
                st.error(msg)
elif 'vector_store' in st.session_state:
    del st.session_state.vector_store
    del st.session_state.rag_current_file_key
    st.info("No RAG file is loaded; you can upload one to use as a reference knowledge base.")
# === File handling (batch-analysis file - **updated**) ===
if batch_uploaded_file:
    batch_file_key = f"batch_{batch_uploaded_file.name}_{batch_uploaded_file.size}"
    if st.session_state.batch_current_file_key != batch_file_key or 'json_data_for_batch' not in st.session_state:
        try:
            # Drop the old data
            if 'json_data_for_batch' in st.session_state:
                del st.session_state.json_data_for_batch
            if 'batch_results' in st.session_state:
                del st.session_state.batch_results
            # Use the new unified parsing function
            parsed_data = convert_uploaded_file_to_json_list(batch_uploaded_file)
            if not parsed_data:
                raise ValueError(f"Failed to load {batch_uploaded_file.name}, or the file is empty.")
            # Store the parsed data
            st.session_state.json_data_for_batch = parsed_data
            st.session_state.batch_current_file_key = batch_file_key
            st.toast(f"File parsed and converted into {len(parsed_data)} log entries.", icon="✅")
        except Exception as e:
            st.error(f"File parsing error: {e}")
            if 'json_data_for_batch' in st.session_state:
                del st.session_state.json_data_for_batch
            st.session_state.batch_current_file_key = None  # set to None to avoid a stale key
elif 'json_data_for_batch' in st.session_state:
    # The file was removed; clear the related data
    del st.session_state.json_data_for_batch
    del st.session_state.batch_current_file_key
    # Make sure batch_results is also cleared, to avoid 'NoneType' errors
    if "batch_results" in st.session_state:
        del st.session_state.batch_results
    st.info("No batch-analysis file is loaded; upload a log file to analyze it.")
# === Batch-analysis execution (modified to use the IP-correlation window) ===
if st.session_state.execute_batch_analysis and 'json_data_for_batch' in st.session_state and st.session_state.json_data_for_batch is not None:
    st.session_state.execute_batch_analysis = False
    start_time = time.time()
    # st.session_state.batch_results must be a list here (never None); start a fresh list for this run
    st.session_state.batch_results = []
    if inference_client is None:
        st.error("The client is not connected; cannot run.")
    else:
        logs_list = st.session_state.json_data_for_batch
        if logs_list:
            vs = st.session_state.get("vector_store", None)
            # Render each log entry as a JSON string for the LLM input
            formatted_logs = [json.dumps(log, indent=2, ensure_ascii=False) for log in logs_list]
            analysis_sequences = []
            # --- Core change: build log sequences correlated by IP ---
            for i in range(len(formatted_logs)):
                current_log_entry = logs_list[i]
                current_log_str = formatted_logs[i]
                # Try to extract an IP address from the current log entry (W3C fields first, then generic ones).
                # Adjust these keys to match your own log format.
                target_ip = current_log_entry.get('c_ip') or current_log_entry.get('c-ip') or current_log_entry.get('remote_addr') or current_log_entry.get('source_ip')
                sequence_text = []
                correlated_logs = []
                if target_ip and target_ip != "-":  # '-' is the W3C placeholder for an empty value
                    # Collect earlier logs with a matching IP, at most WINDOW_SIZE - 1 of them,
                    # scanning backwards from i-1 to 0
                    for j in range(i - 1, -1, -1):
                        prior_log_entry = logs_list[j]
                        prior_ip = prior_log_entry.get('c_ip') or prior_log_entry.get('c-ip') or prior_log_entry.get('remote_addr') or prior_log_entry.get('source_ip')
                        # Does the IP match?
                        if prior_ip == target_ip:
                            # Insert at the front to keep chronological order
                            correlated_logs.insert(0, formatted_logs[j])
                            # Cap the number of accumulated logs (excluding the current one)
                            if len(correlated_logs) >= WINDOW_SIZE - 1:
                                break
                    # 1. Add the correlated (earlier) logs
                    for log_str in correlated_logs:
                        sequence_text.append(f"--- Correlated Log Index (IP:{target_ip}) ---\n{log_str}")
                else:
                    # No IP found: analyze only the current log (the target log is still appended below)
                    st.warning(f"Log #{i+1} has no usable IP field; analyzing only the current log entry.")
                # 2. Add the current target log
                sequence_text.append(f"--- TARGET LOG TO ANALYZE (Index {i+1}) ---\n{current_log_str}")
                analysis_sequences.append({
                    "sequence_text": "\n\n".join(sequence_text),
                    "target_log_id": i + 1,
                    "original_log_entry": logs_list[i]
                })
            # --- LLM execution loop ---
            total_sequences = len(analysis_sequences)
            st.header(f"⚡ Batch analysis running (IP-correlation window $N={WINDOW_SIZE}$)...")
            progress_bar = st.progress(0, text=f"Preparing to process {total_sequences} sequences...")
            results_container = st.container()
            for i, seq_data in enumerate(analysis_sequences):
                log_id = seq_data["target_log_id"]
                progress_bar.progress((i + 1) / total_sequences, text=f"Processing {i + 1}/{total_sequences} (Log #{log_id})...")
                try:
                    response, retrieved_ctx = generate_rag_response_hf_for_log(
                        client=inference_client,
                        model_id=MODEL_ID,
                        log_sequence_text=seq_data["sequence_text"],
                        user_prompt=analysis_prompt,
                        sys_prompt=system_prompt,
                        vector_store=vs,
                        threshold=similarity_threshold,
                        max_output_tokens=max_output_tokens,
                        temperature=temperature,
                        top_p=top_p
                    )
                    item = {
                        "log_id": log_id,
                        "log_content": seq_data["original_log_entry"],
                        "sequence_analyzed": seq_data["sequence_text"],
                        "analysis_result": response,
                        "context": retrieved_ctx
                    }
                    st.session_state.batch_results.append(item)
                    csv_string = pd.json_normalize(item)
                    csv_string = csv_string.to_csv(header=True, index=False, encoding='utf-8')
                    with results_container:
                        # Render the LLM analysis result
                        is_high = 'high-risk detected!' in response.lower()
                        is_medium = 'medium-risk detected!' in response.lower()
                        if is_high:
                            b64 = base64.b64encode(csv_string.encode('utf-8'))
                            st.markdown(
                                f"""
                                <h3>
                                <a href="data:application/octet-stream;base64,{b64.decode()}" download="Log_{item['log_id']}.csv">Log #{item['log_id']} (HIGH RISK DETECTED)</a>
                                </h3>
                                """,
                                unsafe_allow_html=True
                            )
                            with st.expander("Sequence content (JSON format)"):
                                st.code(item["sequence_analyzed"], language='json')
                            st.error(item['analysis_result'])
                            st.markdown("---")
                            if item['context']:
                                with st.expander("Retrieved RAG chunks"):
                                    st.code(item['context'])
                except Exception as e:
                    st.error(f"Error Log {log_id}: {e}")
            end_time = time.time()
            progress_bar.empty()
            st.success(f"Done! Elapsed time: {end_time - start_time:.2f} seconds.")
        else:
            st.error("No valid logs could be extracted; please check the file format.")
# === Display results (history) - unchanged, but with hardened session-state checks ===
if st.session_state.get("batch_results") and isinstance(st.session_state.batch_results, list) and st.session_state.batch_results and not st.session_state.execute_batch_analysis:
    st.header("⚡ Historical analysis results")
    high_risk_data = []
    high_risk_items = []
    medium_risk_data = []
    medium_risk_items = []
    for item in st.session_state.batch_results:
        # Check whether analysis_result contains 'High-risk detected' (case-insensitive)
        is_high_risk = 'high-risk detected!' in item['analysis_result'].lower()
        is_medium_risk = 'medium-risk detected!' in item['analysis_result'].lower()
        if is_high_risk:
            high_risk_items.append(item)
            # --- Prepare the data for the CSV report ---
            log_content_str = json.dumps(item["log_content"], ensure_ascii=False)
            analysis_result_clean = item['analysis_result'].replace('\n', ' | ')
            high_risk_data.append({
                "Log_ID": item['log_id'],
                "Risk_Level": "HIGH_RISK",
                "Log_Content": log_content_str,
                "AI_Analysis_Result": analysis_result_clean
            })
    report_container = st.container()
    with report_container:
        # Show the download button for the High-Risk report (now CSV-based)
        if high_risk_items:
            st.success(f"✅ Detected {len(high_risk_items)} high-risk logs.")
            # --- Build the CSV content ---
            csv_output = io.StringIO()
            csv_output.write("Log_ID,Risk_Level,Log_Content,AI_Analysis_Result\n")
            def escape_csv(value):
                # Double every embedded quote, then wrap the whole field in quotes
                escaped = str(value).replace('"', '""')
                return f'"{escaped}"'
            for row in high_risk_data:
                line = ",".join([
                    str(row["Log_ID"]),
                    row["Risk_Level"],
                    escape_csv(row["Log_Content"]),
                    escape_csv(row["AI_Analysis_Result"])
                ]) + "\n"
                csv_output.write(line)
            csv_content = csv_output.getvalue()
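            # Escaping sketch (hypothetical value): escape_csv('He said "stop", twice') returns
            # '"He said ""stop"", twice"', i.e. embedded quotes are doubled and the field is wrapped
            # in quotes so commas survive the comma-joined CSV row built above.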
            # Show the download button for the CSV report
            def create_download_link(val, filename):
                # 1. Encode the input value as Base64
                b64 = base64.b64encode(val.encode('utf-8')).decode()
                # 2. Define the style
                # The style below mimics a button with a blue background
                button_style = """
                    display: inline-block;
                    padding: 10px 20px;
                    margin: 5px 0;
                    background-color: #007bff; /* blue background */
                    color: white;              /* white text */
                    text-align: center;
                    text-decoration: none;     /* no underline */
                    border-radius: 5px;        /* rounded corners */
                    border: none;
                    cursor: pointer;
                    font-weight: bold;
                """
                # 3. Assemble the HTML link
                # Note: be careful with quoting inside the f-string
                html_link = f"""
                <a
                    href="data:application/octet-stream;base64,{b64}"
                    download="{filename}.csv"
                    style="{button_style}"
                >
                    🔽 Download file (high risk)
                </a>
                """
                # Strip extra newlines and spaces so the link renders correctly in some environments
                return html_link.strip()
            download_url = create_download_link(csv_content, 'high_risk_report')
            st.markdown(download_url, unsafe_allow_html=True)
        if not high_risk_items and not medium_risk_items:
            st.info("👍 No logs flagged as High/Medium-risk were detected.")