""" 路由器和评分器模块 包含查询路由、文档相关性评分、答案质量评分和幻觉检测 """ try: from langchain_core.prompts import PromptTemplate except ImportError: from langchain.prompts import PromptTemplate from langchain_community.chat_models import ChatOllama from langchain_core.output_parsers import JsonOutputParser, StrOutputParser from config import LOCAL_LLM class QueryRouter: """查询路由器,决定使用向量存储还是网络搜索""" def __init__(self): self.llm = ChatOllama(model=LOCAL_LLM, format="json", temperature=0) self.prompt = PromptTemplate( template="""你是一个专家,负责将用户问题路由到向量存储或网络搜索。 对于关于LLM智能体、提示工程和对抗性攻击的问题,使用向量存储。 你不需要严格匹配问题中与这些主题相关的关键词。 否则,使用网络搜索。根据问题给出二进制选择'web_search'或'vectorstore'。 返回一个只包含'datasource'键的JSON,不要前言或解释。 要路由的问题:{question}""", input_variables=["question"], ) self.router = self.prompt | self.llm | JsonOutputParser() def route(self, question: str) -> str: """路由问题到相应的数据源""" result = self.router.invoke({"question": question}) return result.get("datasource", "web_search") class DocumentGrader: """文档相关性评分器""" def __init__(self): self.llm = ChatOllama(model=LOCAL_LLM, format="json", temperature=0) self.prompt = PromptTemplate( template="""你是一个评分员,评估检索到的文档是否与用户问题相关。 如果文档包含与用户问题相关的关键词或语义,请给出'yes'分数。 给出二进制分数'yes'或'no',以表明文档是否与问题相关。 将二进制分数作为JSON提供,只包含'score'键,不要前言或解释。 检索到的文档: {document} 用户问题:{question}""", input_variables=["question", "document"], ) self.grader = self.prompt | self.llm | JsonOutputParser() def grade(self, question: str, document: str) -> str: """评估文档与问题的相关性""" result = self.grader.invoke({"question": question, "document": document}) return result.get("score", "no") class AnswerGrader: """答案质量评分器""" def __init__(self): self.llm = ChatOllama(model=LOCAL_LLM, format="json", temperature=0) self.prompt = PromptTemplate( template="""你是一个评分员,评估答案是否有助于解决问题。 这里是答案: \n ------- \n {generation} \n ------- \n 这里是问题:{question} 给出二进制分数'yes'或'no',表示答案是否有助于解决问题。 将二进制分数作为JSON提供,只包含'score'键,不要前言或解释。""", input_variables=["generation", "question"], ) self.grader = self.prompt | self.llm | JsonOutputParser() def grade(self, question: str, generation: str) -> str: """评估答案质量""" result = self.grader.invoke({"question": question, "generation": generation}) return result.get("score", "no") class HallucinationGrader: """ 幻觉检测器 - 使用专业模型(Vectara + NLI) 相比 LLM-as-a-Judge 方法: - 准确率从 60-75% 提升到 85-95% - 速度提升 5-10 倍 - 成本降低 90% """ def __init__(self, method: str = "hybrid"): """ 初始化幻觉检测器 Args: method: 'vectara', 'nli', 或 'hybrid' (推荐) """ # 尝试加载专业检测模型 try: from hallucination_detector import initialize_hallucination_detector self.detector = initialize_hallucination_detector(method=method) self.use_professional_detector = True print(f"✅ 使用专业幻觉检测器: {method}") except Exception as e: print(f"⚠️ 专业检测器加载失败,回退到 LLM 方法: {e}") self.use_professional_detector = False # 回退到原有的 LLM 方法 self.llm = ChatOllama(model=LOCAL_LLM, format="json", temperature=0) self.prompt = PromptTemplate( template="""你是一个评分员,评估LLM生成是否基于/支持一组检索到的事实。 给出二进制分数'yes'或'no'。'yes'意味着答案基于/支持文档。 将二进制分数作为JSON提供,只包含'score'键,不要前言或解释。 检索到的文档: {documents} LLM生成:{generation}""", input_variables=["generation", "documents"], ) self.grader = self.prompt | self.llm | JsonOutputParser() def grade(self, generation: str, documents) -> str: """ 检测生成内容是否存在幻觉 Args: generation: LLM 生成的内容 documents: 参考文档 Returns: "yes" 表示无幻觉,"no" 表示有幻觉 """ if self.use_professional_detector: # 使用专业检测器 return self.detector.grade(generation, documents) else: # 回退到 LLM 方法 result = self.grader.invoke({"generation": generation, "documents": documents}) return result.get("score", "no") class QueryRewriter: """查询重写器,优化查询以获得更好的检索结果""" def __init__(self): self.llm = ChatOllama(model=LOCAL_LLM, temperature=0) self.prompt = PromptTemplate( template="""你是一个问题重写器,将输入问题转换为更适合向量存储检索的更好版本。 查看初始问题并制定一个改进的问题。 这里是初始问题:\n\n {question}。改进的问题(无前言):\n """, input_variables=["question"], ) self.rewriter = self.prompt | self.llm | StrOutputParser() def rewrite(self, question: str) -> str: """重写查询以获得更好的检索效果""" print(f"---原始查询: {question}---") rewritten_query = self.rewriter.invoke({"question": question}) print(f"---重写查询: {rewritten_query}---") return rewritten_query def initialize_graders_and_router(): """初始化所有评分器和路由器""" # Load detection method from config try: from hallucination_config import HALLUCINATION_DETECTION_METHOD detection_method = HALLUCINATION_DETECTION_METHOD except ImportError: detection_method = "hybrid" # Default to hybrid query_router = QueryRouter() document_grader = DocumentGrader() answer_grader = AnswerGrader() hallucination_grader = HallucinationGrader(method=detection_method) query_rewriter = QueryRewriter() return { "query_router": query_router, "document_grader": document_grader, "answer_grader": answer_grader, "hallucination_grader": hallucination_grader, "query_rewriter": query_rewriter }