lanny xu commited on
Commit
ef805fe
·
1 Parent(s): 7787f0c

modifies bug

Browse files
Files changed (7) hide show
  1. .env +26 -1
  2. config.py +22 -0
  3. document_processor.py +314 -39
  4. main.py +6 -0
  5. requirements.txt +1 -0
  6. setup_and_run.py +2 -0
  7. workflow_nodes.py +32 -5
.env CHANGED
@@ -1,2 +1,27 @@
1
  TAVILY_API_KEY="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"
2
- # NOMIC_API_KEY="nk-kt4Tu3UdwFpIlDdxLcd9AK3a7cfdAKhoXvPbJ78oVlE"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  TAVILY_API_KEY="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"
2
+ # NOMIC_API_KEY="nk-kt4Tu3UdwFpIlDdxLcd9AK3a7cfdAKhoXvPbJ78oVlE"
3
+
4
+ # 混合检索配置
5
+ ENABLE_HYBRID_SEARCH=true
6
+ BM25_K1=1.5
7
+ BM25_B=0.75
8
+ ENSEMBLE_WEIGHTS=[0.5, 0.5]
9
+
10
+ # 查询扩展配置
11
+ ENABLE_QUERY_EXPANSION=true
12
+ QUERY_EXPANSION_MODEL="all-MiniLM-L6-v2"
13
+ QUERY_EXPANSION_TOP_K=5
14
+
15
+ # 多模态配置
16
+ ENABLE_MULTIMODAL=true
17
+ MULTIMODAL_MODEL="openai/clip-vit-base-patch32"
18
+ MULTIMODAL_IMAGE_MODEL="openai/clip-vit-base-patch32"
19
+
20
+ # GraphRAG配置
21
+ ENABLE_GRAPH_RAG=true
22
+ GRAPH_ENTITY_EXTRACTION_MODEL="llama2"
23
+ GRAPH_RELATION_EXTRACTION_MODEL="llama2"
24
+ GRAPH_COMMUNITY_DETECTION=true
25
+ GRAPH_COMMUNITY_ALGORITHM="louvain"
26
+ GRAPH_VISUALIZATION=true
27
+ GRAPH_LAYOUT="spring"
config.py CHANGED
@@ -75,6 +75,28 @@ GRAPHRAG_MAX_HOPS = 2 # 本地查询最大跳数
75
  GRAPHRAG_TOP_K_COMMUNITIES = 5 # 全局查询使用的社区数量
76
  GRAPHRAG_BATCH_SIZE = 10 # 实体提取批处理大小
77
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78
 
79
  def get_api_keys():
80
  """获取API密钥并返回字典"""
 
75
  GRAPHRAG_TOP_K_COMMUNITIES = 5 # 全局查询使用的社区数量
76
  GRAPHRAG_BATCH_SIZE = 10 # 实体提取批处理大小
77
 
78
+ # 混合检索策略配置
79
+ ENABLE_HYBRID_SEARCH = True # 是否启用混合检索策略
80
+ HYBRID_SEARCH_WEIGHTS = {"vector": 0.7, "keyword": 0.3} # 向量检索和关键词检索的权重
81
+ KEYWORD_SEARCH_K = 5 # 关键词检索返回的文档数量
82
+ BM25_K1 = 1.2 # BM25算法的k1参数
83
+ BM25_B = 0.75 # BM25算法的b参数
84
+
85
+ # 查询扩展优化配置
86
+ ENABLE_QUERY_EXPANSION = True # 是否启用查询扩展
87
+ QUERY_EXPANSION_MODEL = "mistral" # 用于查询扩展的模型
88
+ QUERY_EXPANSION_PROMPT = """请为以下查询生成3-5个相关的扩展查询,这些查询应该从不同角度探索原始查询的主题。
89
+ 原始查询: {query}
90
+ 扩展查询: """ # 查询扩展提示模板
91
+ MAX_EXPANDED_QUERIES = 3 # 最多使用的扩展查询数量
92
+
93
+ # 多模态支持配置
94
+ ENABLE_MULTIMODAL = True # 是否启用多模态支持
95
+ MULTIMODAL_IMAGE_MODEL = "openai/clip-vit-base-patch32" # 图像嵌入模型
96
+ SUPPORTED_IMAGE_FORMATS = ["jpg", "jpeg", "png", "gif", "bmp"] # 支持的图像格式
97
+ IMAGE_EMBEDDING_DIM = 512 # 图像嵌入维度
98
+ MULTIMODAL_WEIGHTS = {"text": 0.7, "image": 0.3} # 文本和图像检索的权重
99
+
100
 
101
  def get_api_keys():
102
  """获取API密钥并返回字典"""
document_processor.py CHANGED
@@ -11,16 +11,42 @@ except ImportError:
11
  from langchain_community.document_loaders import WebBaseLoader
12
  from langchain_community.vectorstores import Chroma
13
  from langchain_community.embeddings import HuggingFaceEmbeddings
 
 
14
 
15
  from config import (
16
  KNOWLEDGE_BASE_URLS,
17
  CHUNK_SIZE,
18
  CHUNK_OVERLAP,
19
  COLLECTION_NAME,
20
- EMBEDDING_MODEL
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
  )
22
  from reranker import create_reranker
23
 
 
 
 
 
 
 
 
24
 
25
  class DocumentProcessor:
26
  """文档处理器类,负责文档加载、处理和向量化"""
@@ -56,10 +82,20 @@ class DocumentProcessor:
56
 
57
  self.vectorstore = None
58
  self.retriever = None
 
 
59
 
60
  # 初始化重排器
61
  self.reranker = None
62
  self._setup_reranker()
 
 
 
 
 
 
 
 
63
 
64
  def _setup_reranker(self):
65
  """
@@ -86,6 +122,43 @@ class DocumentProcessor:
86
  print(f"⚠️ 重排器初始化完全失败: {e2}")
87
  print("⚠️ 将使用基础检索,不进行重排")
88
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
89
  def load_documents(self, urls=None):
90
  """从URL加载文档"""
91
  if urls is None:
@@ -113,6 +186,30 @@ class DocumentProcessor:
113
  embedding=self.embeddings,
114
  )
115
  self.retriever = self.vectorstore.as_retriever()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
116
  print("向量数据库创建完成")
117
  return self.vectorstore, self.retriever
118
 
@@ -133,31 +230,164 @@ class DocumentProcessor:
133
  # 返回doc_splits用于GraphRAG索引
134
  return vectorstore, retriever, doc_splits
135
 
136
- def enhanced_retrieve(self, query: str, top_k: int = 5, rerank_candidates: int = 20):
137
- """增强检索:先检索更多候选,然后重排"""
138
- if not self.retriever:
139
- print("⚠️ 检索器未初始化")
140
- return []
141
-
142
- # 1. 初始检索:获取更多候选文档 (使用 invoke 替代 get_relevant_documents)
143
- initial_docs = self.retriever.invoke(query)
144
-
145
- # 获取更多候选(如果可能)
146
- if hasattr(self.retriever, 'search_kwargs'):
147
- # 修改检索参数以获取更多结果
148
- original_k = self.retriever.search_kwargs.get('k', 4)
149
- self.retriever.search_kwargs['k'] = min(rerank_candidates, len(initial_docs))
150
- candidate_docs = self.retriever.invoke(query)
151
- self.retriever.search_kwargs['k'] = original_k # 恢复原设置
152
- else:
153
- candidate_docs = initial_docs
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
154
 
155
- print(f"初始检索获得 {len(candidate_docs)} 个候选文档")
 
156
 
157
- # 2. 重排(如果重排器可用)
158
- if self.reranker and len(candidate_docs) > top_k:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
159
  try:
160
- reranked_results = self.reranker.rerank(query, candidate_docs, top_k)
161
  final_docs = [doc for doc, score in reranked_results]
162
  scores = [score for doc, score in reranked_results]
163
 
@@ -167,40 +397,85 @@ class DocumentProcessor:
167
  return final_docs
168
  except Exception as e:
169
  print(f"⚠️ 重排失败: {e},使用原始检索结果")
170
- return candidate_docs[:top_k]
171
  else:
172
  # 不重排或候选数量不足
173
- return candidate_docs[:top_k]
174
 
175
- def compare_retrieval_methods(self, query: str, top_k: int = 5):
176
  """比较不同检索方法的效果"""
177
  if not self.retriever:
178
  return {}
179
 
 
 
 
 
 
180
  # 原始检索 (使用 invoke 替代 get_relevant_documents)
181
  original_docs = self.retriever.invoke(query)[:top_k]
 
 
 
 
 
 
 
182
 
183
- # 增强检索(带重排)
184
- enhanced_docs = self.enhanced_retrieve(query, top_k)
 
 
 
 
 
 
 
 
185
 
186
- return {
187
- 'query': query,
188
- 'original_retrieval': {
189
- 'count': len(original_docs),
 
190
  'documents': [{
191
  'content': doc.page_content[:200] + '...' if len(doc.page_content) > 200 else doc.page_content,
192
  'metadata': getattr(doc, 'metadata', {})
193
- } for doc in original_docs]
194
- },
195
- 'enhanced_retrieval': {
196
- 'count': len(enhanced_docs),
 
 
 
 
197
  'documents': [{
198
  'content': doc.page_content[:200] + '...' if len(doc.page_content) > 200 else doc.page_content,
199
  'metadata': getattr(doc, 'metadata', {})
200
- } for doc in enhanced_docs]
201
- },
202
- 'reranker_used': self.reranker is not None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
203
  }
 
 
204
 
205
  def format_docs(self, docs):
206
  """格式化文档用于生成"""
 
11
  from langchain_community.document_loaders import WebBaseLoader
12
  from langchain_community.vectorstores import Chroma
13
  from langchain_community.embeddings import HuggingFaceEmbeddings
14
+ from langchain.retrievers import BM25Retriever
15
+ from langchain_community.retrievers import EnsembleRetriever
16
 
17
  from config import (
18
  KNOWLEDGE_BASE_URLS,
19
  CHUNK_SIZE,
20
  CHUNK_OVERLAP,
21
  COLLECTION_NAME,
22
+ EMBEDDING_MODEL,
23
+ # 混合检索配置
24
+ ENABLE_HYBRID_SEARCH,
25
+ HYBRID_SEARCH_WEIGHTS,
26
+ KEYWORD_SEARCH_K,
27
+ BM25_K1,
28
+ BM25_B,
29
+ # 查询扩展配置
30
+ ENABLE_QUERY_EXPANSION,
31
+ QUERY_EXPANSION_MODEL,
32
+ QUERY_EXPANSION_PROMPT,
33
+ MAX_EXPANDED_QUERIES,
34
+ # 多模态配置
35
+ ENABLE_MULTIMODAL,
36
+ MULTIMODAL_IMAGE_MODEL,
37
+ SUPPORTED_IMAGE_FORMATS,
38
+ IMAGE_EMBEDDING_DIM,
39
+ MULTIMODAL_WEIGHTS
40
  )
41
  from reranker import create_reranker
42
 
43
+ # 多模态支持相关导入
44
+ import base64
45
+ import io
46
+ from PIL import Image
47
+ import numpy as np
48
+ from typing import List, Dict, Any, Optional, Union
49
+
50
 
51
  class DocumentProcessor:
52
  """文档处理器类,负责文档加载、处理和向量化"""
 
82
 
83
  self.vectorstore = None
84
  self.retriever = None
85
+ self.bm25_retriever = None # BM25检索器
86
+ self.ensemble_retriever = None # 集成检索器
87
 
88
  # 初始化重排器
89
  self.reranker = None
90
  self._setup_reranker()
91
+
92
+ # 初始化多模态支持
93
+ self.image_embeddings_model = None
94
+ self._setup_multimodal()
95
+
96
+ # 初始化查询扩展
97
+ self.query_expansion_model = None
98
+ self._setup_query_expansion()
99
 
100
  def _setup_reranker(self):
101
  """
 
122
  print(f"⚠️ 重排器初始化完全失败: {e2}")
123
  print("⚠️ 将使用基础检索,不进行重排")
124
 
125
+ def _setup_multimodal(self):
126
+ """设置多模态支持"""
127
+ if not ENABLE_MULTIMODAL:
128
+ print("⚠️ 多模态支持已禁用")
129
+ return
130
+
131
+ try:
132
+ print("🔧 正在初始化多模态支持...")
133
+ from transformers import CLIPProcessor, CLIPModel
134
+ import torch
135
+
136
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
137
+ self.image_embeddings_model = CLIPModel.from_pretrained(MULTIMODAL_IMAGE_MODEL).to(device)
138
+ self.image_processor = CLIPProcessor.from_pretrained(MULTIMODAL_IMAGE_MODEL)
139
+ print(f"✅ 多模态支持初始化成功 (设备: {device})")
140
+ except Exception as e:
141
+ print(f"⚠️ 多模态支持初始化失败: {e}")
142
+ print("⚠️ 将仅使用文本检索")
143
+ self.image_embeddings_model = None
144
+
145
+ def _setup_query_expansion(self):
146
+ """设置查询扩展"""
147
+ if not ENABLE_QUERY_EXPANSION:
148
+ print("⚠️ 查询扩展已禁用")
149
+ return
150
+
151
+ try:
152
+ print("🔧 正在初始化查询扩展...")
153
+ from langchain_community.llms import Ollama
154
+
155
+ self.query_expansion_model = Ollama(model=QUERY_EXPANSION_MODEL)
156
+ print(f"✅ 查询扩展初始化成功 (模型: {QUERY_EXPANSION_MODEL})")
157
+ except Exception as e:
158
+ print(f"⚠️ 查询扩展初始化失败: {e}")
159
+ print("⚠️ 将不使用查询扩展")
160
+ self.query_expansion_model = None
161
+
162
  def load_documents(self, urls=None):
163
  """从URL加载文档"""
164
  if urls is None:
 
186
  embedding=self.embeddings,
187
  )
188
  self.retriever = self.vectorstore.as_retriever()
189
+
190
+ # 如果启用混合检索,创建BM25检索器和集成检索器
191
+ if ENABLE_HYBRID_SEARCH:
192
+ print("正在初始化混合检索...")
193
+ try:
194
+ # 创建BM25检索器
195
+ self.bm25_retriever = BM25Retriever.from_documents(
196
+ doc_splits,
197
+ k=KEYWORD_SEARCH_K,
198
+ k1=BM25_K1,
199
+ b=BM25_B
200
+ )
201
+
202
+ # 创建集成检索器,结合向量检索和BM25检索
203
+ self.ensemble_retriever = EnsembleRetriever(
204
+ retrievers=[self.retriever, self.bm25_retriever],
205
+ weights=[HYBRID_SEARCH_WEIGHTS["vector"], HYBRID_SEARCH_WEIGHTS["keyword"]]
206
+ )
207
+ print("✅ 混合检索初始化成功")
208
+ except Exception as e:
209
+ print(f"⚠️ 混合检索初始化失败: {e}")
210
+ print("⚠️ 将仅使用向量检索")
211
+ self.ensemble_retriever = None
212
+
213
  print("向量数据库创建完成")
214
  return self.vectorstore, self.retriever
215
 
 
230
  # 返回doc_splits用于GraphRAG索引
231
  return vectorstore, retriever, doc_splits
232
 
233
+ def expand_query(self, query: str) -> List[str]:
234
+ """扩展查询,生成相关查询"""
235
+ if not self.query_expansion_model:
236
+ return [query]
237
+
238
+ try:
239
+ # 使用LLM生成扩展查询
240
+ prompt = QUERY_EXPANSION_PROMPT.format(query=query)
241
+ expanded_queries_text = self.query_expansion_model.invoke(prompt)
242
+
243
+ # 解析扩展查询
244
+ expanded_queries = [query] # 包含原始查询
245
+ for line in expanded_queries_text.strip().split('\n'):
246
+ line = line.strip()
247
+ if line and not line.startswith('#') and not line.startswith('//'):
248
+ # 移除可能的编号前缀
249
+ if line[0].isdigit() and '.' in line[:5]:
250
+ line = line.split('.', 1)[1].strip()
251
+ expanded_queries.append(line)
252
+
253
+ # 限制扩展查询数量
254
+ return expanded_queries[:MAX_EXPANDED_QUERIES + 1] # +1 因为包含原始查询
255
+ except Exception as e:
256
+ print(f"⚠️ 查询扩展失败: {e}")
257
+ return [query]
258
+
259
+ def encode_image(self, image_path: str) -> np.ndarray:
260
+ """编码图像为嵌入向量"""
261
+ if not self.image_embeddings_model:
262
+ raise ValueError("多模态支持未初始化")
263
+
264
+ try:
265
+ # 加载并处理图像
266
+ image = Image.open(image_path).convert('RGB')
267
+ inputs = self.image_processor(images=image, return_tensors="pt")
268
+
269
+ # 获取图像嵌入
270
+ with torch.no_grad():
271
+ image_features = self.image_embeddings_model.get_image_features(**inputs)
272
+ # 标准化嵌入向量
273
+ image_features = image_features / image_features.norm(p=2, dim=-1, keepdim=True)
274
+
275
+ return image_features.cpu().numpy().flatten()
276
+ except Exception as e:
277
+ print(f"⚠️ 图像编码失败: {e}")
278
+ raise
279
+
280
+ def multimodal_retrieve(self, query: str, image_paths: List[str] = None, top_k: int = 5) -> List:
281
+ """多模态检索,结合文本和图像"""
282
+ if not ENABLE_MULTIMODAL or not self.image_embeddings_model:
283
+ # 如果多模态未启用,回退到文本检索
284
+ return self.hybrid_retrieve(query, top_k) if ENABLE_HYBRID_SEARCH else self.retriever.invoke(query)[:top_k]
285
 
286
+ # 文本检索
287
+ text_docs = self.hybrid_retrieve(query, top_k) if ENABLE_HYBRID_SEARCH else self.retriever.invoke(query)[:top_k]
288
 
289
+ # 如果没有提供图像,直接返回文本检索结果
290
+ if not image_paths:
291
+ return text_docs
292
+
293
+ try:
294
+ # 图像检索
295
+ image_results = []
296
+ for image_path in image_paths:
297
+ # 检查文件格式
298
+ file_ext = image_path.split('.')[-1].lower()
299
+ if file_ext not in SUPPORTED_IMAGE_FORMATS:
300
+ print(f"⚠️ 不支持的图像格式: {file_ext}")
301
+ continue
302
+
303
+ # 编码图像
304
+ image_embedding = self.encode_image(image_path)
305
+
306
+ # 这里应该实现图像到文本的匹配逻辑
307
+ # 由于原始实现中没有图像数据库,我们简化处理
308
+ # 在实际应用中,应该有一个图像数据库和相应的检索逻辑
309
+
310
+ # 合并文本和图像结果(简化版本)
311
+ # 在实际应用中,应该有更复杂的融合逻辑
312
+ final_docs = text_docs # 简化版本,仅返回文本结果
313
+
314
+ print(f"✅ 多���态检索完成,返回 {len(final_docs)} 个结果")
315
+ return final_docs
316
+ except Exception as e:
317
+ print(f"⚠️ 多模态检索失败: {e}")
318
+ print("回退到文本检索")
319
+ return text_docs
320
+
321
+ def hybrid_retrieve(self, query: str, top_k: int = 5) -> List:
322
+ """混合检索,结合向量检索和关键词检索"""
323
+ if not ENABLE_HYBRID_SEARCH or not self.ensemble_retriever:
324
+ # 如果混合检索未启用,回退到向量检索
325
+ return self.retriever.invoke(query)[:top_k]
326
+
327
+ try:
328
+ # 使用集成检索器进行混合检索
329
+ results = self.ensemble_retriever.invoke(query)
330
+ return results[:top_k]
331
+ except Exception as e:
332
+ print(f"⚠️ 混合检索失败: {e}")
333
+ print("回退到向量检索")
334
+ return self.retriever.invoke(query)[:top_k]
335
+
336
+ def enhanced_retrieve(self, query: str, top_k: int = 5, rerank_candidates: int = 20,
337
+ image_paths: List[str] = None, use_query_expansion: bool = None):
338
+ """增强检索:先检索更多候选,然后重排,支持查询扩展和多模态
339
+
340
+ Args:
341
+ query: 查询字符串
342
+ top_k: 返回的文档数量
343
+ rerank_candidates: 重排前的候选文档数量
344
+ image_paths: 图像路径列表,用于多模态检索
345
+ use_query_expansion: 是否使用查询扩展,None表示使用配置默认值
346
+ """
347
+ # 确定是否使用查询扩展
348
+ if use_query_expansion is None:
349
+ use_query_expansion = ENABLE_QUERY_EXPANSION
350
+
351
+ # 如果启用查询扩展,生成扩展查询
352
+ if use_query_expansion:
353
+ expanded_queries = self.expand_query(query)
354
+ print(f"查询扩展: {len(expanded_queries)} 个查询")
355
+ else:
356
+ expanded_queries = [query]
357
+
358
+ # 多模态检索(如果提供了图像)
359
+ if image_paths and ENABLE_MULTIMODAL:
360
+ return self.multimodal_retrieve(query, image_paths, top_k)
361
+
362
+ # 混合检索或向量检索
363
+ all_candidate_docs = []
364
+ for expanded_query in expanded_queries:
365
+ if ENABLE_HYBRID_SEARCH:
366
+ # 使用混合检索
367
+ docs = self.hybrid_retrieve(expanded_query, rerank_candidates)
368
+ else:
369
+ # 使用向量检索
370
+ docs = self.retriever.invoke(expanded_query)
371
+ if len(docs) > rerank_candidates:
372
+ docs = docs[:rerank_candidates]
373
+
374
+ all_candidate_docs.extend(docs)
375
+
376
+ # 去重(基于文档内容)
377
+ unique_docs = []
378
+ seen_content = set()
379
+ for doc in all_candidate_docs:
380
+ content = doc.page_content
381
+ if content not in seen_content:
382
+ seen_content.add(content)
383
+ unique_docs.append(doc)
384
+
385
+ print(f"检索获得 {len(unique_docs)} 个候选文档")
386
+
387
+ # 重排(如果重排器可用)
388
+ if self.reranker and len(unique_docs) > top_k:
389
  try:
390
+ reranked_results = self.reranker.rerank(query, unique_docs, top_k)
391
  final_docs = [doc for doc, score in reranked_results]
392
  scores = [score for doc, score in reranked_results]
393
 
 
397
  return final_docs
398
  except Exception as e:
399
  print(f"⚠️ 重排失败: {e},使用原始检索结果")
400
+ return unique_docs[:top_k]
401
  else:
402
  # 不重排或候选数量不足
403
+ return unique_docs[:top_k]
404
 
405
+ def compare_retrieval_methods(self, query: str, top_k: int = 5, image_paths: List[str] = None):
406
  """比较不同检索方法的效果"""
407
  if not self.retriever:
408
  return {}
409
 
410
+ results = {
411
+ 'query': query,
412
+ 'image_paths': image_paths
413
+ }
414
+
415
  # 原始检索 (使用 invoke 替代 get_relevant_documents)
416
  original_docs = self.retriever.invoke(query)[:top_k]
417
+ results['vector_retrieval'] = {
418
+ 'count': len(original_docs),
419
+ 'documents': [{
420
+ 'content': doc.page_content[:200] + '...' if len(doc.page_content) > 200 else doc.page_content,
421
+ 'metadata': getattr(doc, 'metadata', {})
422
+ } for doc in original_docs]
423
+ }
424
 
425
+ # 混合检索(如果启用)
426
+ if ENABLE_HYBRID_SEARCH and self.ensemble_retriever:
427
+ hybrid_docs = self.hybrid_retrieve(query, top_k)
428
+ results['hybrid_retrieval'] = {
429
+ 'count': len(hybrid_docs),
430
+ 'documents': [{
431
+ 'content': doc.page_content[:200] + '...' if len(doc.page_content) > 200 else doc.page_content,
432
+ 'metadata': getattr(doc, 'metadata', {})
433
+ } for doc in hybrid_docs]
434
+ }
435
 
436
+ # 查询扩展检索(如果启用)
437
+ if ENABLE_QUERY_EXPANSION and self.query_expansion_model:
438
+ expanded_docs = self.enhanced_retrieve(query, top_k, use_query_expansion=True)
439
+ results['expanded_query_retrieval'] = {
440
+ 'count': len(expanded_docs),
441
  'documents': [{
442
  'content': doc.page_content[:200] + '...' if len(doc.page_content) > 200 else doc.page_content,
443
  'metadata': getattr(doc, 'metadata', {})
444
+ } for doc in expanded_docs]
445
+ }
446
+
447
+ # 多模态检索(如果启用且有图像)
448
+ if ENABLE_MULTIMODAL and image_paths:
449
+ multimodal_docs = self.multimodal_retrieve(query, image_paths, top_k)
450
+ results['multimodal_retrieval'] = {
451
+ 'count': len(multimodal_docs),
452
  'documents': [{
453
  'content': doc.page_content[:200] + '...' if len(doc.page_content) > 200 else doc.page_content,
454
  'metadata': getattr(doc, 'metadata', {})
455
+ } for doc in multimodal_docs]
456
+ }
457
+
458
+ # 增强检索(带重排)
459
+ enhanced_docs = self.enhanced_retrieve(query, top_k)
460
+ results['enhanced_retrieval'] = {
461
+ 'count': len(enhanced_docs),
462
+ 'documents': [{
463
+ 'content': doc.page_content[:200] + '...' if len(doc.page_content) > 200 else doc.page_content,
464
+ 'metadata': getattr(doc, 'metadata', {})
465
+ } for doc in enhanced_docs]
466
+ }
467
+
468
+ # 添加配置信息
469
+ results['configuration'] = {
470
+ 'hybrid_search_enabled': ENABLE_HYBRID_SEARCH,
471
+ 'query_expansion_enabled': ENABLE_QUERY_EXPANSION,
472
+ 'multimodal_enabled': ENABLE_MULTIMODAL,
473
+ 'reranker_used': self.reranker is not None,
474
+ 'hybrid_weights': HYBRID_SEARCH_WEIGHTS if ENABLE_HYBRID_SEARCH else None,
475
+ 'multimodal_weights': MULTIMODAL_WEIGHTS if ENABLE_MULTIMODAL else None
476
  }
477
+
478
+ return results
479
 
480
  def format_docs(self, docs):
481
  """格式化文档用于生成"""
main.py CHANGED
@@ -47,6 +47,12 @@ class AdaptiveRAGSystem:
47
 
48
  def _build_workflow(self):
49
  """构建工作流图"""
 
 
 
 
 
 
50
  workflow = StateGraph(GraphState)
51
 
52
  # 定义节点
 
47
 
48
  def _build_workflow(self):
49
  """构建工作流图"""
50
+ # 创建工作流节点实例,传递DocumentProcessor实例
51
+ self.workflow_nodes = WorkflowNodes(
52
+ doc_processor=self.doc_processor,
53
+ graders=self.graders
54
+ )
55
+
56
  workflow = StateGraph(GraphState)
57
 
58
  # 定义节点
requirements.txt CHANGED
@@ -20,6 +20,7 @@ transformers>=4.30.0
20
  tiktoken>=0.5.0
21
  beautifulsoup4>=4.12.0
22
  requests>=2.31.0
 
23
 
24
  # 幻觉检测
25
  sentence-transformers>=2.2.0 # NLI 模型支持
 
20
  tiktoken>=0.5.0
21
  beautifulsoup4>=4.12.0
22
  requests>=2.31.0
23
+ Pillow>=9.0.0 # 图像处理,支持多模态功能
24
 
25
  # 幻觉检测
26
  sentence-transformers>=2.2.0 # NLI 模型支持
setup_and_run.py CHANGED
@@ -44,6 +44,8 @@ def setup_environment():
44
  if current_dir not in sys.path:
45
  sys.path.insert(0, current_dir)
46
  print(f"\n ✅ 已添加到 Python 路径: {current_dir}")
 
 
47
 
48
  # ============================================================
49
  # 2. 运行 main_graphrag.py
 
44
  if current_dir not in sys.path:
45
  sys.path.insert(0, current_dir)
46
  print(f"\n ✅ 已添加到 Python 路径: {current_dir}")
47
+
48
+ print("\n 💡 注意: 新增的多模态功能需要Pillow库,请确保已安装")
49
 
50
  # ============================================================
51
  # 2. 运行 main_graphrag.py
workflow_nodes.py CHANGED
@@ -17,7 +17,8 @@ try:
17
  except ImportError:
18
  from langchain.prompts import PromptTemplate
19
 
20
- from config import LOCAL_LLM, WEB_SEARCH_RESULTS_COUNT
 
21
  from pprint import pprint
22
 
23
 
@@ -38,8 +39,9 @@ class GraphState(TypedDict):
38
  class WorkflowNodes:
39
  """工作流节点类,包含所有节点函数"""
40
 
41
- def __init__(self, retriever, graders):
42
- self.retriever = retriever
 
43
  self.graders = graders
44
 
45
  # 设置RAG链 - 使用本地提示模板
@@ -73,8 +75,33 @@ class WorkflowNodes:
73
  print("---检索---")
74
  question = state["question"]
75
 
76
- # 检索 (使用 invoke 替代 get_relevant_documents)
77
- documents = self.retriever.invoke(question)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78
  return {"documents": documents, "question": question}
79
 
80
  def generate(self, state):
 
17
  except ImportError:
18
  from langchain.prompts import PromptTemplate
19
 
20
+ from config import LOCAL_LLM, WEB_SEARCH_RESULTS_COUNT, ENABLE_HYBRID_SEARCH, ENABLE_QUERY_EXPANSION, ENABLE_MULTIMODAL
21
+ from document_processor import DocumentProcessor
22
  from pprint import pprint
23
 
24
 
 
39
  class WorkflowNodes:
40
  """工作流节点类,包含所有节点函数"""
41
 
42
+ def __init__(self, doc_processor, graders):
43
+ self.doc_processor = doc_processor # 接收DocumentProcessor实例
44
+ self.retriever = doc_processor.retriever # 保持向后兼容
45
  self.graders = graders
46
 
47
  # 设置RAG链 - 使用本地提示模板
 
75
  print("---检索---")
76
  question = state["question"]
77
 
78
+ # 使用增强检索方法,支持混合检索、查询扩展和多模态
79
+ try:
80
+ # 检查是否有图像路径(多模态检索)
81
+ image_paths = state.get("image_paths", None)
82
+
83
+ # 使用增强检索
84
+ documents = self.doc_processor.enhanced_retrieve(
85
+ question,
86
+ top_k=5,
87
+ rerank_candidates=20,
88
+ image_paths=image_paths,
89
+ use_query_expansion=ENABLE_QUERY_EXPANSION
90
+ )
91
+
92
+ # 记录使用的检索方法
93
+ if ENABLE_HYBRID_SEARCH:
94
+ print("---使用混合检索---")
95
+ if ENABLE_QUERY_EXPANSION:
96
+ print("---使用查询扩展---")
97
+ if image_paths and ENABLE_MULTIMODAL:
98
+ print("---使用多模态检索---")
99
+
100
+ except Exception as e:
101
+ print(f"⚠️ 增强检索失败: {e},回退到基本检索")
102
+ # 回退到基本检索
103
+ documents = self.retriever.invoke(question)
104
+
105
  return {"documents": documents, "question": question}
106
 
107
  def generate(self, state):