Spaces:
Running
on
CPU Upgrade
| import gradio as gr | |
| import fasttext | |
| from huggingface_hub import hf_hub_download | |
| from collections import defaultdict | |
| from utils.consts import topics, topics_full | |
| from models.bayes import NaiveBayesMultiClass, get_tags_bayes | |
| from models.lstm import LSTMPipeline, get_tags_lstm | |
| from models.bert import get_bert_pipeline, get_tags_bert | |
# --- One-time model setup (runs at import time) ---

# Download the pretrained fastText English word vectors; the LSTM pipeline
# embeds input tokens with them.
embedding_model_path = hf_hub_download(
    repo_id="facebook/fasttext-en-vectors",
    filename="model.bin",
)
embedder = fasttext.load_model(embedding_model_path)

# Transformer (BERT) classifier pipeline.
bert = get_bert_pipeline()

n_topics = len(topics)
# Class-index → topic-code mapping for the LSTM head.
# Fix: size the mapping from `topics` itself instead of the hard-coded 24,
# so it cannot silently drop labels if the topic list changes.
id2label = {i: topics[i] for i in range(n_topics)}
lstm = LSTMPipeline(embedder=embedder, id2label=id2label, device=-1)  # device=-1 → CPU

# Naïve Bayes classifier restored from pre-trained weights on disk.
bayes = NaiveBayesMultiClass(topics)
bayes.load('weights/bayes/')
def expand(tags):
    """Return the given tags plus each tag's main-topic code, sorted.

    A subtopic code such as "1.2" contributes both itself and its main
    code "1". The main code is taken with ``split('.')`` rather than the
    first character (the previous ``i[:1]``), so multi-digit main codes
    (e.g. "10.3" → "10") are handled correctly and consistently with the
    grouping logic in ``format_as_markdown``.
    """
    with_primary = set()
    for tag in tags:
        with_primary.add(tag.split('.')[0])  # main-topic code
        with_primary.add(tag)
    return sorted(with_primary)
def format_as_markdown(predictions: dict) -> str:
    """Render ``{code: topic_name}`` predictions as nested Markdown.

    Predictions are grouped by main-topic code ("1.2" → group "1"),
    main groups are sorted numerically, and subtopics are listed as
    indented bullets under each main heading. Returns a placeholder
    string when there are no predictions.
    """
    if not predictions:
        return "_No topics detected._"

    # Group every (code, name) pair under its main-topic code.
    grouped = defaultdict(list)
    for code, topic in predictions.items():
        main = code.split('.')[0]
        grouped[main].append((code, topic))

    md = "### 📝 Predicted IGCSE Physics Topics\n"
    for main_code in sorted(grouped.keys(), key=lambda x: float(x)):
        # Fall back to the bare code when the main topic has no title.
        # Fix: the original default, f"{topics_full[main_code]}", was
        # evaluated eagerly and raised KeyError on a missing key,
        # defeating the purpose of .get().
        main_title = topics_full.get(main_code, main_code)
        md += f"\n#### {main_code}. {main_title}\n"
        subtopics = [st for st in grouped[main_code] if st[0] != main_code]
        if subtopics:
            # Numeric sort on the dotted code ("1.10" after "1.2").
            for code, name in sorted(subtopics, key=lambda x: [float(n) for n in x[0].split('.')]):
                indent = " " * (4 * (code.count('.') - 1))
                md += f"{indent}- **{code}**: {name}\n"
        else:
            # Only the main topic itself was predicted.
            md += f"- **{main_code}**: {main_title}\n"
    return md
def classify_text(classifier, text, threshold, output_format):
    """Run the selected model on *text* and shape the result for the UI.

    Returns a 3-tuple matching the click handler's outputs:
    (JSON value, visibility update for the JSON panel, update for the
    Markdown panel). Exactly one panel is shown, per *output_format*.
    """
    if classifier == 'Transformer':
        raw_tags = get_tags_bert(bert, text, threshold)
    elif classifier == 'LSTM':
        raw_tags = get_tags_lstm(lstm, text, threshold)
    elif classifier == 'CNN':
        raw_tags = []  # CNN path yields no tags here
    else:
        # Naïve Bayes — the confidence threshold is not applied on this path.
        raw_tags = get_tags_bayes(bayes, text)

    # Add main-topic codes, then keep only codes with a known title.
    expanded = expand(raw_tags)
    predictions = {tag: topics_full[tag] for tag in expanded if tag in topics_full}

    if output_format == "JSON":
        return predictions, gr.update(visible=True), gr.update(visible=False)
    md = format_as_markdown(predictions)
    return {}, gr.update(visible=False), gr.update(value=md, visible=True)
# --- Gradio UI wiring ---
with gr.Blocks(theme="default") as demo:
    gr.Markdown("# 🔬 IGCSE Physics Topic Classifier")
    gr.Markdown(
        "This model classifies IGCSE Physics questions or passages into syllabus topics. "
        "Adjust the confidence threshold and choose your preferred output format."
    )
    with gr.Row(equal_height=True):
        # Left column — Input
        with gr.Column(scale=1):
            classifier = gr.Radio(
                ["Naïve Bayes", "CNN", "LSTM", "Transformer"],
                value="Transformer",
                label="Processing model",
                info="Choose which model to use to process texts",
            )
            text_input = gr.Textbox(
                lines=8,
                placeholder="Enter a physics question or concept...",
                label="Input Text",
            )
            # Forwarded to the BERT/LSTM taggers; the Bayes path ignores it
            # (see classify_text).
            threshold = gr.Slider(0, 1, value=0.5, step=0.05,
                                  label="Confidence Threshold (not available for Naïve Bayes)")
            output_format = gr.Radio(
                ["Markdown", "JSON"],
                value="Markdown",
                label="Output Format",
                info="Choose how to display results",
            )
            classify_btn = gr.Button("Classify", variant="primary")
        # Right column — Output (dynamic)
        with gr.Column(scale=1):
            # Only one of these two panels is visible at a time;
            # classify_text toggles visibility based on the chosen format.
            json_output = gr.JSON(label="Predicted Topics (JSON)", visible=False)
            markdown_output = gr.Markdown(label="Predicted Topics (Markdown)", visible=True)
    # classify_text returns (json_value, json_visibility_update, markdown_update).
    # NOTE(review): json_output appears twice in `outputs` — once to receive the
    # value and once to receive the visibility update. Some Gradio versions
    # reject duplicate output components; confirm against the pinned Gradio
    # version, or fold both into a single gr.update(value=..., visible=...).
    classify_btn.click(
        fn=classify_text,
        inputs=[classifier, text_input, threshold, output_format],
        outputs=[json_output, json_output, markdown_output],
    )

if __name__ == "__main__":
    # share=True requests a public Gradio share link in addition to the
    # local server.
    demo.launch(share=True)