""" Entity Reranker Improves precision by re-scoring and reordering retrieval results """ from typing import List from .models import RetrievalResult from .normalizer import get_normalizer class EntityReranker: """ Reranks retrieval results to improve precision. Scoring factors: - Full name match (highest priority) - Name length ratio (prefer specific over generic) - Position/organization match - Penalize ambiguous short matches """ def __init__(self, debug: bool = False): self.debug = debug self.normalizer = get_normalizer() def rerank(self, query: str, results: List[RetrievalResult]) -> List[RetrievalResult]: """ Rerank results based on query-specific signals. Args: query: Original user query results: Retrieved results to rerank Returns: Reranked list of results """ if not results: return results # Normalize query for comparison query_norm = self.normalizer.normalize(query) query_stem = self.normalizer.stem(query_norm) query_parts = self.normalizer.extract_name_parts(query) # Score adjustments scored_results = [] for result in results: score = result.score entity = result.entity # Get entity's normalized form entity_norm = self.normalizer.normalize(entity.name) entity_stem = self.normalizer.stem(entity_norm) # Factor 1: Exact match bonus if query_norm == entity_norm or query_stem == entity_stem: score *= 10.0 if self.debug: print(f" [Rerank] Exact match: {entity.name} → ×10") # Factor 2: Query is substring of entity (specific query) elif query_norm in entity_norm or query_stem in entity_stem: # Length ratio bonus (longer query = more specific) ratio = len(query_norm) / len(entity_norm) bonus = 1.0 + (ratio * 3.0) # Up to 4x for long matches score *= bonus if self.debug: print(f" [Rerank] Substring match: {entity.name} → ×{bonus:.2f}") # Factor 3: Entity is substring of query (query contains name) elif entity_norm in query_norm or entity_stem in query_stem: score *= 3.0 if self.debug: print(f" [Rerank] Entity in query: {entity.name} → ×3") # Factor 4: First name match (important for Arabic names) if query_parts.get("first_name"): entity_parts = self.normalizer.extract_name_parts(entity.name) if query_parts["first_name"] == entity_parts.get("first_name"): score *= 1.5 if self.debug: print(f" [Rerank] First name match: {entity.name} → ×1.5") # Factor 5: Penalize very short matches (likely ambiguous) if len(query_norm) < 15 and len(entity_norm) > 40: # Short query matching long name - might be too generic penalty = 0.7 score *= penalty if self.debug: print(f" [Rerank] Short query penalty: {entity.name} → ×{penalty}") # Factor 6: Position/Organization match bonus query_lower = query.lower() if entity.primary_position and entity.primary_position.lower() in query_lower: score *= 2.0 if self.debug: print(f" [Rerank] Position match: {entity.primary_position} → ×2") if entity.primary_organization and entity.primary_organization.lower() in query_lower: score *= 1.5 if self.debug: print(f" [Rerank] Organization match: {entity.primary_organization} → ×1.5") # Create new result with adjusted score scored_results.append(RetrievalResult( entity=result.entity, score=score, match_type=result.match_type, matched_variant=result.matched_variant, normalized_query=result.normalized_query, )) # Sort by adjusted score scored_results.sort(key=lambda r: r.score, reverse=True) return scored_results def filter_ambiguous(self, results: List[RetrievalResult], threshold: float = 0.2) -> List[RetrievalResult]: """ Filter results where top scores are too close (ambiguous). Args: results: Ranked results threshold: Relative difference threshold (0.2 = 20%) Returns: Filtered results, or original if not ambiguous """ if len(results) < 2: return results top_score = results[0].score second_score = results[1].score # If scores are very close, might be ambiguous if top_score > 0 and (top_score - second_score) / top_score < threshold: if self.debug: print(f" [Rerank] Ambiguous results detected:") print(f" #1: {results[0].entity.name} (score: {top_score:.2f})") print(f" #2: {results[1].entity.name} (score: {second_score:.2f})") # Return all close results for further disambiguation close_results = [r for r in results if r.score >= top_score * (1 - threshold * 2)] return close_results return results # Convenience function def rerank_results(query: str, results: List[RetrievalResult], debug: bool = False) -> List[RetrievalResult]: """ Convenience function to rerank results. """ reranker = EntityReranker(debug=debug) return reranker.rerank(query, results)