lanny xu committed on
Commit
a1ec589
·
1 Parent(s): fa26a24

Fix bug: replace EnsembleRetriever with a custom ensemble retriever

Browse files
Files changed (1) hide show
  1. document_processor.py +36 -1
document_processor.py CHANGED
@@ -48,6 +48,41 @@ import numpy as np
48
  from typing import List, Dict, Any, Optional, Union
49
 
50
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
  class DocumentProcessor:
52
  """文档处理器类,负责文档加载、处理和向量化"""
53
 
@@ -200,7 +235,7 @@ class DocumentProcessor:
200
  )
201
 
202
  # 创建集成检索器,结合向量检索和BM25检索
203
- self.ensemble_retriever = EnsembleRetriever(
204
  retrievers=[self.retriever, self.bm25_retriever],
205
  weights=[HYBRID_SEARCH_WEIGHTS["vector"], HYBRID_SEARCH_WEIGHTS["keyword"]]
206
  )
 
48
  from typing import List, Dict, Any, Optional, Union
49
 
50
 
51
class CustomEnsembleRetriever:
    """Ensemble retriever that fuses results from several retrievers.

    Intended to combine a vector retriever with a BM25 retriever. Results
    are merged with weighted Reciprocal Rank Fusion (RRF): a document at
    1-based rank ``r`` in the result list of retriever ``i`` contributes
    ``weights[i] / (r + c)`` to its fused score. Documents are identified
    by their ``page_content``, so the same text returned by more than one
    retriever accumulates score instead of being ranked by whichever
    retriever happens to be listed first.
    """

    def __init__(self, retrievers, weights, c=60):
        """Store the retrievers and their fusion weights.

        Args:
            retrievers: Sequence of objects exposing ``invoke(query)`` that
                returns an iterable of documents (objects with
                ``page_content`` and a ``metadata`` dict).
            weights: Sequence of numeric weights, one per retriever, in the
                same order as ``retrievers``.
            c: RRF smoothing constant added to every rank; larger values
                flatten the difference between high and low ranks.

        Raises:
            ValueError: If fewer weights than retrievers are supplied.
        """
        # Fail fast: a missing weight would otherwise only surface as an
        # IndexError in the middle of a query.
        if len(weights) < len(retrievers):
            raise ValueError(
                "weights must provide one entry per retriever "
                f"(got {len(weights)} weights for {len(retrievers)} retrievers)"
            )
        self.retrievers = retrievers
        self.weights = weights
        self.c = c

    def invoke(self, query):
        """Run every retriever on ``query`` and return the fused results.

        Args:
            query: Query passed unchanged to each retriever's ``invoke``.

        Returns:
            A list of unique documents (unique by ``page_content``) ordered
            by descending fused score. Ties keep first-seen order. Each
            returned document has ``metadata["retriever_index"]`` and
            ``metadata["retriever_weight"]`` set to the index and weight of
            the first retriever that returned it.
        """
        fused_scores = {}
        # Insertion-ordered mapping content -> first document seen with it;
        # the insertion order doubles as the tie-break order below.
        kept_docs = {}

        for index, retriever in enumerate(self.retrievers):
            weight = self.weights[index]
            seen_in_this_list = set()
            rank = 0
            for doc in retriever.invoke(query):
                content = doc.page_content
                # A retriever repeating the same text must not score twice
                # nor push later documents down the ranking.
                if content in seen_in_this_list:
                    continue
                seen_in_this_list.add(content)
                rank += 1
                fused_scores[content] = (
                    fused_scores.get(content, 0.0) + weight / (rank + self.c)
                )
                if content not in kept_docs:
                    # Record which retriever first produced the document.
                    doc.metadata["retriever_index"] = index
                    doc.metadata["retriever_weight"] = weight
                    kept_docs[content] = doc

        # sorted() is stable (also with reverse=True), so equal scores keep
        # the order in which the documents were first seen.
        ranked_contents = sorted(
            kept_docs, key=fused_scores.__getitem__, reverse=True
        )
        return [kept_docs[content] for content in ranked_contents]
84
+
85
+
86
  class DocumentProcessor:
87
  """文档处理器类,负责文档加载、处理和向量化"""
88
 
 
235
  )
236
 
237
  # 创建集成检索器,结合向量检索和BM25检索
238
+ self.ensemble_retriever = CustomEnsembleRetriever(
239
  retrievers=[self.retriever, self.bm25_retriever],
240
  weights=[HYBRID_SEARCH_WEIGHTS["vector"], HYBRID_SEARCH_WEIGHTS["keyword"]]
241
  )