adaptive_rag / test_crossencoder_reranking.py
lanny xu
modify reranker
db5bfaa
raw
history blame
8.37 kB
"""
ๆต‹่ฏ• CrossEncoder ้‡ๆŽ’ๅŠŸ่ƒฝ
ๅฏนๆฏ” Bi-Encoder vs CrossEncoder ็š„ๆ•ˆๆžœ
"""
from reranker import create_reranker, TFIDFReranker, BM25Reranker, SemanticReranker, CrossEncoderReranker
class MockDoc:
    """Mock document class exposing ``page_content`` and ``metadata`` attributes."""

    def __init__(self, content, metadata=None):
        """Store the document text and its metadata.

        Args:
            content: Value stored as ``page_content``.
            metadata: Optional metadata; any falsy value (``None``, an empty
                dict, ...) is replaced by a fresh empty dict.
        """
        if not metadata:
            # A new dict per instance, so documents never share metadata.
            metadata = {}
        self.metadata = metadata
        self.page_content = content
class MockEmbeddings:
    """Mock embeddings class (intended for the Semantic Reranker).

    Vectors are derived from character codes only, so they are deterministic
    and need no model. They carry no semantic meaning and are for tests only.
    """

    # Every returned vector has exactly this many components.
    DIMENSION = 10

    def embed_query(self, text):
        """Return a fixed-length character-level vector for ``text``.

        Args:
            text: String to embed; only its first ``DIMENSION`` characters
                are used.

        Returns:
            A list of ``DIMENSION`` floats: ``ord(c) / 100.0`` for each used
            character, zero-padded on the right when ``text`` is shorter.
        """
        vector = [ord(ch) / 100.0 for ch in text[:self.DIMENSION]]
        # Pad short inputs so query and document vectors always have the same
        # length; otherwise comparing a short text's vector with a longer
        # text's vector would hit mismatched dimensions.
        vector.extend([0.0] * (self.DIMENSION - len(vector)))
        return vector

    def embed_documents(self, texts):
        """Return one ``embed_query`` vector per entry of ``texts``, in order."""
        return [self.embed_query(text) for text in texts]
def create_test_documents():
    """Create the test document set.

    Returns:
        A new list of eight ``MockDoc`` objects, one per sample text below,
        in the order listed.
    """
    sample_texts = (
        "ไบบๅทฅๆ™บ่ƒฝๆ˜ฏ่ฎก็ฎ—ๆœบ็ง‘ๅญฆ็š„ไธ€ไธชๅˆ†ๆ”ฏ๏ผŒ่‡ดๅŠ›ไบŽๅˆ›ๅปบ่ƒฝๅคŸๆ‰ง่กŒ้€šๅธธ้œ€่ฆไบบ็ฑปๆ™บ่ƒฝ็š„ไปปๅŠก็š„็ณป็ปŸใ€‚",
        "ๆœบๅ™จๅญฆไน ๆ˜ฏไบบๅทฅๆ™บ่ƒฝ็š„ๅญ้ข†ๅŸŸ๏ผŒไธ“ๆณจไบŽ่ฎฉ่ฎก็ฎ—ๆœบไปŽๆ•ฐๆฎไธญๅญฆไน ๅนถๆ”น่ฟ›ใ€‚",
        "ๆทฑๅบฆๅญฆไน ไฝฟ็”จๅคšๅฑ‚็ฅž็ป็ฝ‘็ปœๆฅๅค„็†ๅคๆ‚็š„ๆ•ฐๆฎๆจกๅผ๏ผŒๆ˜ฏๆœบๅ™จๅญฆไน ็š„ไธ€็งๆ–นๆณ•ใ€‚",
        "่‡ช็„ถ่ฏญ่จ€ๅค„็†๏ผˆNLP๏ผ‰ๆ˜ฏไบบๅทฅๆ™บ่ƒฝ็š„ไธ€ไธชๅˆ†ๆ”ฏ๏ผŒๅค„็†่ฎก็ฎ—ๆœบไธŽไบบ็ฑป่ฏญ่จ€ไน‹้—ด็š„ไบคไบ’ใ€‚",
        "่ฎก็ฎ—ๆœบ่ง†่ง‰ๆ˜ฏไบบๅทฅๆ™บ่ƒฝ็š„ๅฆไธ€ไธช้‡่ฆ้ข†ๅŸŸ๏ผŒไฝฟๆœบๅ™จ่ƒฝๅคŸ็†่งฃๅ’Œ่งฃ้‡Š่ง†่ง‰ไฟกๆฏใ€‚",
        "ไปŠๅคฉๅคฉๆฐ”ๅพˆๅฅฝ๏ผŒ้€‚ๅˆๅ‡บๅŽปๆ•ฃๆญฅๅ’Œ่ฟๅŠจใ€‚",
        "Python ๆ˜ฏไธ€็ง้ซ˜็บง็ผ–็จ‹่ฏญ่จ€๏ผŒ็”ฑ Guido van Rossum ๅœจ 1991 ๅนดๅˆ›ๅปบใ€‚",
        "RAG๏ผˆๆฃ€็ดขๅขžๅผบ็”Ÿๆˆ๏ผ‰ๆ˜ฏไธ€็ง็ป“ๅˆไฟกๆฏๆฃ€็ดขๅ’Œๆ–‡ๆœฌ็”Ÿๆˆ็š„ๆŠ€ๆœฏใ€‚",
    )
    return [MockDoc(text) for text in sample_texts]
def test_tfidf_reranking():
    """Test TF-IDF reranking.

    Reranks the sample documents with ``TFIDFReranker`` (``top_k=3``) and
    prints each returned ``(doc, score)`` pair with a 50-character preview.
    """
    divider = "=" * 60
    print("\n" + divider)
    print("๐Ÿ“Š ๆต‹่ฏ• TF-IDF ้‡ๆŽ’")
    print(divider)

    query = "ไป€ไนˆๆ˜ฏไบบๅทฅๆ™บ่ƒฝๅ’Œๆœบๅ™จๅญฆไน ๏ผŸ"
    documents = create_test_documents()
    tfidf = TFIDFReranker()
    ranked = tfidf.rerank(query, documents, top_k=3)

    print(f"\nๆŸฅ่ฏข: {query}")
    print("\nTF-IDF ้‡ๆŽ’็ป“ๆžœ:")
    for rank, (document, relevance) in enumerate(ranked, start=1):
        preview = document.page_content[:50]
        print(f"{rank}. ๅˆ†ๆ•ฐ: {relevance:.4f} | ๅ†…ๅฎน: {preview}...")
def test_bm25_reranking():
    """Test BM25 reranking.

    Reranks the sample documents with ``BM25Reranker`` (``top_k=3``) and
    prints each returned ``(doc, score)`` pair with a 50-character preview.
    """
    divider = "=" * 60
    print("\n" + divider)
    print("๐Ÿ“Š ๆต‹่ฏ• BM25 ้‡ๆŽ’")
    print(divider)

    query = "ไป€ไนˆๆ˜ฏไบบๅทฅๆ™บ่ƒฝๅ’Œๆœบๅ™จๅญฆไน ๏ผŸ"
    documents = create_test_documents()
    bm25 = BM25Reranker()
    ranked = bm25.rerank(query, documents, top_k=3)

    print(f"\nๆŸฅ่ฏข: {query}")
    print("\nBM25 ้‡ๆŽ’็ป“ๆžœ:")
    for rank, (document, relevance) in enumerate(ranked, start=1):
        preview = document.page_content[:50]
        print(f"{rank}. ๅˆ†ๆ•ฐ: {relevance:.4f} | ๅ†…ๅฎน: {preview}...")
def test_crossencoder_reranking():
    """Test CrossEncoder reranking.

    Builds a ``CrossEncoderReranker`` with the
    ``cross-encoder/ms-marco-MiniLM-L-6-v2`` model, reranks the sample
    documents (``top_k=3``) and prints the results.

    Returns:
        ``True`` when construction, reranking and printing all succeed;
        ``False`` when any of them raises, after printing the error and an
        install hint for ``sentence-transformers``.
    """
    divider = "=" * 60
    print("\n" + divider)
    print("๐ŸŒŸ ๆต‹่ฏ• CrossEncoder ้‡ๆŽ’๏ผˆๆŽจ่๏ผ‰")
    print(divider)

    query = "ไป€ไนˆๆ˜ฏไบบๅทฅๆ™บ่ƒฝๅ’Œๆœบๅ™จๅญฆไน ๏ผŸ"
    documents = create_test_documents()

    try:
        # Use a lightweight model.
        cross_encoder = CrossEncoderReranker(
            model_name="cross-encoder/ms-marco-MiniLM-L-6-v2"
        )
        ranked = cross_encoder.rerank(query, documents, top_k=3)

        print(f"\nๆŸฅ่ฏข: {query}")
        print("\nCrossEncoder ้‡ๆŽ’็ป“ๆžœ:")
        for rank, (document, relevance) in enumerate(ranked, start=1):
            preview = document.page_content[:50]
            print(f"{rank}. ๅˆ†ๆ•ฐ: {relevance:.4f} | ๅ†…ๅฎน: {preview}...")
    except Exception as error:
        # Any failure (construction, reranking or printing) counts as
        # "CrossEncoder unavailable" for the caller.
        print(f"\nโŒ CrossEncoder ๆต‹่ฏ•ๅคฑ่ดฅ: {error}")
        print("๐Ÿ’ก ๆ็คบ: ่ฏทๅ…ˆๅฎ‰่ฃ… sentence-transformers")
        print(" ๅ‘ฝไปค: pip install sentence-transformers")
        return False

    return True
def test_factory_function():
    """Test the reranker factory function.

    For each of ``'tfidf'``, ``'bm25'`` and ``'crossencoder'``, builds a
    reranker through ``create_reranker``, reranks the sample documents
    (``top_k=2``) and prints the top hit. A failure for one type is printed
    and does not stop the remaining types.
    """
    divider = "=" * 60
    print("\n" + divider)
    print("๐Ÿญ ๆต‹่ฏ•้‡ๆŽ’ๅ™จๅทฅๅŽ‚ๅ‡ฝๆ•ฐ")
    print(divider)

    query = "ๆทฑๅบฆๅญฆไน ๅ’Œ็ฅž็ป็ฝ‘็ปœ"
    documents = create_test_documents()

    # CrossEncoder goes last, after the lexical rerankers.
    for kind in ("tfidf", "bm25", "crossencoder"):
        label = kind.upper()
        try:
            built = create_reranker(kind)
            ranked = built.rerank(query, documents, top_k=2)
            print(f"\nโœ… {label} ้‡ๆŽ’ๅ™จๅˆ›ๅปบๆˆๅŠŸ")
            best = ranked[0]
            print(f" Top 1: {best[1]:.4f} | {best[0].page_content[:40]}...")
        except Exception as error:
            print(f"\nโŒ {label} ้‡ๆŽ’ๅ™จๅคฑ่ดฅ: {error}")
def compare_all_methods():
    """Compare all reranking methods.

    Runs one query through ``TFIDFReranker``, ``BM25Reranker`` and, when it
    can be constructed, ``CrossEncoderReranker``. Prints the top 3
    ``(doc, score)`` results per method. A method that fails while reranking
    is reported and the comparison continues with the next one.
    """
    print("\n" + "=" * 60)
    print("โš–๏ธ ๅฏนๆฏ”ๆ‰€ๆœ‰้‡ๆŽ’ๆ–นๆณ•")
    print("=" * 60)

    query = "่งฃ้‡Šไธ€ไธ‹ไบบๅทฅๆ™บ่ƒฝใ€ๆœบๅ™จๅญฆไน ๅ’Œๆทฑๅบฆๅญฆไน ็š„ๅ…ณ็ณป"
    docs = create_test_documents()

    methods = {
        'TF-IDF': TFIDFReranker(),
        'BM25': BM25Reranker(),
    }

    # Try to add CrossEncoder.
    try:
        methods['CrossEncoder'] = CrossEncoderReranker()
    except Exception:
        # Only ordinary errors are swallowed; KeyboardInterrupt and
        # SystemExit still propagate, so Ctrl-C is not mistaken for
        # "CrossEncoder unavailable".
        print("\nโš ๏ธ CrossEncoder ไธๅฏ็”จ๏ผŒ่ทณ่ฟ‡")

    print(f"\nๆŸฅ่ฏข: {query}\n")

    for method_name, reranker in methods.items():
        try:
            results = reranker.rerank(query, docs, top_k=3)
            print(f"\n{'=' * 40}")
            print(f"{method_name} ้‡ๆŽ’็ป“ๆžœ:")
            print('=' * 40)
            for i, (doc, score) in enumerate(results, 1):
                print(f"{i}. [{score:.4f}] {doc.page_content[:60]}...")
        except Exception as e:
            print(f"\n{method_name} ๅคฑ่ดฅ: {e}")
def performance_comparison():
    """Print the performance comparison summary.

    The output is a fixed text block (method table, recommended
    configurations, current project configuration). Nothing is measured
    here; every figure in the text is hard-coded.
    """
    divider = "=" * 60
    summary = """
้‡ๆŽ’ๆ–นๆณ•ๅฏนๆฏ”๏ผš
โ”Œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”ฌโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”ฌโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”ฌโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”ฌโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”
โ”‚ ๆ–นๆณ• โ”‚ ๅ‡†็กฎ็އ โ”‚ ้€Ÿๅบฆ โ”‚ ๆˆๆœฌ โ”‚ ้€‚็”จๅœบๆ™ฏ โ”‚
โ”œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”ผโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”ผโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”ผโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”ผโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”ค
โ”‚ TF-IDF โ”‚ โญโญ โ”‚ โšกโšกโšก โ”‚ ๆžไฝŽ โ”‚ ๅ…ณ้”ฎ่ฏๅŒน้… โ”‚
โ”‚ BM25 โ”‚ โญโญโญ โ”‚ โšกโšกโšก โ”‚ ๆžไฝŽ โ”‚ ๆ–‡ๆœฌๆฃ€็ดข โ”‚
โ”‚ Bi-Encoder โ”‚ โญโญโญโญ โ”‚ โšกโšก โ”‚ ไฝŽ โ”‚ ่ฏญไน‰ๆฃ€็ดข โ”‚
โ”‚ CrossEncoder ๐ŸŒŸ โ”‚ โญโญโญโญโญโ”‚ โšก โ”‚ ไธญ โ”‚ ็ฒพๅ‡†้‡ๆŽ’ โ”‚
โ”‚ Hybrid โ”‚ โญโญโญโญ โ”‚ โšกโšก โ”‚ ไฝŽ โ”‚ ็ปผๅˆๅœบๆ™ฏ โ”‚
โ””โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”ดโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”ดโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”ดโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”ดโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”˜
ๆŽจ่้…็ฝฎ๏ผš
1๏ธโƒฃ ไธค้˜ถๆฎตๆฃ€็ดข๏ผšBi-Encoder (ๅฟซ้€Ÿๅฌๅ›ž) + CrossEncoder (็ฒพๅ‡†้‡ๆŽ’)
2๏ธโƒฃ ๅ‡†็กฎ็އไผ˜ๅ…ˆ๏ผš็บฏ CrossEncoder
3๏ธโƒฃ ้€Ÿๅบฆไผ˜ๅ…ˆ๏ผšBM25 ๆˆ– Hybrid
ๅฝ“ๅ‰้กน็›ฎ้…็ฝฎ๏ผš
โœ… ๅทฒๅˆ‡ๆขๅˆฐ CrossEncoder ้‡ๆŽ’
๐Ÿ“ˆ ๅ‡†็กฎ็އ้ข„ๆœŸๆๅ‡๏ผš15-20%
โšก ้€Ÿๅบฆ๏ผšๅ•ๆฌก้‡ๆŽ’ 20-100ms (Top 20 ๆ–‡ๆกฃ)
"""
    print("\n" + divider)
    print("โšก ๆ€ง่ƒฝไธŽๅ‡†็กฎๆ€งๅฏนๆฏ”")
    print(divider)
    print(summary)
if __name__ == "__main__":
print("\n๐Ÿš€ ๅผ€ๅง‹ๆต‹่ฏ• CrossEncoder ้‡ๆŽ’ๅŠŸ่ƒฝ...\n")
# 1. ๆต‹่ฏ• TF-IDF
test_tfidf_reranking()
# 2. ๆต‹่ฏ• BM25
test_bm25_reranking()
# 3. ๆต‹่ฏ• CrossEncoder (้‡็‚น)
crossencoder_available = test_crossencoder_reranking()
# 4. ๆต‹่ฏ•ๅทฅๅŽ‚ๅ‡ฝๆ•ฐ
test_factory_function()
# 5. ๅฏนๆฏ”ๆ‰€ๆœ‰ๆ–นๆณ•
compare_all_methods()
# 6. ๆ€ง่ƒฝๅฏนๆฏ”ๆ€ป็ป“
performance_comparison()
print("\n" + "=" * 60)
if crossencoder_available:
print("โœ… ๆ‰€ๆœ‰ๆต‹่ฏ•ๅฎŒๆˆ๏ผCrossEncoder ้‡ๆŽ’ๅทฒๅฐฑ็ปช")
else:
print("โš ๏ธ ๆต‹่ฏ•ๅฎŒๆˆ๏ผŒไฝ† CrossEncoder ไธๅฏ็”จ")
print(" ่ฏท่ฟ่กŒ: pip install sentence-transformers")
print("=" * 60 + "\n")