# xingzhi-lu-epq / app.py
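"""Gradio app that tags IGCSE Physics questions with syllabus topics.

A question or passage is classified by the selected model (Naïve Bayes, LSTM,
or a BERT transformer; the CNN option is currently a stub) and the predicted
topic codes are shown either as JSON or as grouped Markdown.
"""
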
import gradio as gr
import fasttext
from huggingface_hub import hf_hub_download
from collections import defaultdict
from utils.consts import topics, topics_full
from models.bayes import NaiveBayesMultiClass, get_tags_bayes
from models.lstm import LSTMPipeline, get_tags_lstm
from models.bert import get_bert_pipeline, get_tags_bert
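
# Pretrained English fastText word vectors, used to embed text for the LSTM pipeline.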
embedding_model_path = hf_hub_download(
repo_id="facebook/fasttext-en-vectors",
filename="model.bin"
)
embedder = fasttext.load_model(embedding_model_path)
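
# Instantiate the classifiers: a BERT pipeline, an LSTM over the fastText
# embeddings, and a Naïve Bayes model restored from saved weights.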
bert = get_bert_pipeline()
id2label = {i: topics[i] for i in range(24)}
lstm = LSTMPipeline(embedder=embedder, id2label=id2label, device=-1)
n_topics = len(topics)
bayes = NaiveBayesMultiClass(topics)
bayes.load('weights/bayes/')


def expand(tags):
    """Return the tags plus each tag's main-topic code (e.g. "1" for "1.2")."""
    with_primary = set()
    for tag in tags:
        with_primary.add(tag.split('.')[0])  # main-topic code, robust to multi-digit codes
        with_primary.add(tag)
    return sorted(with_primary)


def format_as_markdown(predictions: dict) -> str:
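    """Group predicted {code: topic} pairs by main topic and render them as nested Markdown."""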
if not predictions:
return "_No topics detected._"
grouped = defaultdict(list)
for code, topic in predictions.items():
main = code.split('.')[0]
grouped[main].append((code, topic))
md = "### 📝 Predicted IGCSE Physics Topics\n"
for main_code in sorted(grouped.keys(), key=lambda x: float(x)):
        main_title = topics_full.get(main_code, main_code)  # fall back to the bare code if it has no title
md += f"\n#### {main_code}. {main_title}\n"
subtopics = [st for st in grouped[main_code] if st[0] != main_code]
if subtopics:
for code, name in sorted(subtopics, key=lambda x: [float(n) for n in x[0].split('.')]):
indent = " " * (4 * (code.count('.') - 1))
md += f"{indent}- **{code}**: {name}\n"
else:
md += f"- **{main_code}**: {main_title}\n"
return md


def classify_text(classifier, text, threshold, output_format):
    """Run the selected classifier, then format its predicted topic tags."""
    if classifier == 'Transformer':
        tags = get_tags_bert(bert, text, threshold)
    elif classifier == 'CNN':
        tags = []  # CNN model is not wired up yet, so it yields no predictions
    elif classifier == 'LSTM':
        tags = get_tags_lstm(lstm, text, threshold)
    else:
        # Naïve Bayes; the confidence threshold does not apply here
        tags = get_tags_bayes(bayes, text)
tags = expand(tags)
predictions = {tag: topics_full[tag] for tag in tags if tag in topics_full}
    if output_format == "JSON":
        # Show the JSON pane and hide the Markdown pane
        return gr.update(value=predictions, visible=True), gr.update(visible=False)
    # Otherwise render grouped Markdown and hide the JSON pane
    md = format_as_markdown(predictions)
    return gr.update(value={}, visible=False), gr.update(value=md, visible=True)
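

# Gradio UI: model selector, input box, threshold slider, and two output panes
# (JSON / Markdown) toggled according to the chosen output format.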
with gr.Blocks(theme="default") as demo:
gr.Markdown("# 🔬 IGCSE Physics Topic Classifier")
    gr.Markdown(
        "This app classifies IGCSE Physics questions or passages into syllabus topics. "
        "Pick a classifier, adjust the confidence threshold, and choose your preferred output format."
    )
with gr.Row(equal_height=True):
# Left column — Input
with gr.Column(scale=1):
classifier = gr.Radio(
["Naïve Bayes", "CNN", "LSTM", "Transformer"],
value="Transformer",
label="Processing model",
info="Choose which model to use to process texts",
)
text_input = gr.Textbox(
lines=8,
placeholder="Enter a physics question or concept...",
label="Input Text",
)
            threshold = gr.Slider(
                0, 1, value=0.5, step=0.05,
                label="Confidence Threshold (not available for Naïve Bayes)",
            )
output_format = gr.Radio(
["Markdown", "JSON"],
value="Markdown",
label="Output Format",
info="Choose how to display results",
)
classify_btn = gr.Button("Classify", variant="primary")
# Right column — Output (dynamic)
with gr.Column(scale=1):
json_output = gr.JSON(label="Predicted Topics (JSON)", visible=False)
markdown_output = gr.Markdown(label="Predicted Topics (Markdown)", visible=True)
classify_btn.click(
fn=classify_text,
inputs=[classifier, text_input, threshold, output_format],
        outputs=[json_output, markdown_output],
)
if __name__ == "__main__":
demo.launch(share=True)