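"""Gradio demo for an IGCSE Physics topic classifier.

Loads a fastText embedder plus Naive Bayes, LSTM, and Transformer (BERT)
classifiers, then serves a small web UI that tags input text with IGCSE
Physics syllabus topic codes and shows the result as JSON or Markdown.
"""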
import gradio as gr
import fasttext
from huggingface_hub import hf_hub_download
from collections import defaultdict

from utils.consts import topics, topics_full

from models.bayes import NaiveBayesMultiClass, get_tags_bayes
from models.lstm import LSTMPipeline, get_tags_lstm
from models.bert import get_bert_pipeline, get_tags_bert

# Download and load the fastText English word vectors used by the LSTM pipeline.
embedding_model_path = hf_hub_download(
    repo_id="facebook/fasttext-en-vectors",
    filename="model.bin"
)
embedder = fasttext.load_model(embedding_model_path)

# Transformer-based classifier.
bert = get_bert_pipeline()

# Map class indices to topic codes for the LSTM pipeline's label decoding.
n_topics = len(topics)
id2label = {i: topics[i] for i in range(n_topics)}

lstm = LSTMPipeline(embedder=embedder, id2label=id2label, device=-1)

# Naive Bayes classifier with pre-trained weights stored in the repository.
bayes = NaiveBayesMultiClass(topics)
bayes.load('weights/bayes/')
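
# Each get_tags_* helper is expected to return an iterable of syllabus codes
# such as ["1.2", "4.3"]; `expand` and `format_as_markdown` below rely on that
# shape (an assumption based on how their results are consumed in this file).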


def expand(tags):
    """Add the main-topic code (e.g. '1' for '1.2') for every predicted subtopic."""
    with_primary = set()
    for tag in tags:
        with_primary.add(tag.split('.')[0])
        with_primary.add(tag)
    return sorted(with_primary)
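
# Illustrative example (codes assumed): expand(["1.2", "4.3"]) returns
# ["1", "1.2", "4", "4.3"], i.e. each subtopic is joined by its main-topic code.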


def format_as_markdown(predictions: dict) -> str:
    if not predictions:
        return "_No topics detected._"

    grouped = defaultdict(list)
    for code, topic in predictions.items():
        main = code.split('.')[0]
        grouped[main].append((code, topic))

    md = "### 📝 Predicted IGCSE Physics Topics\n"
    for main_code in sorted(grouped.keys(), key=float):
        # Fall back to the bare code if a main topic has no title in topics_full.
        main_title = topics_full.get(main_code, main_code)
        md += f"\n#### {main_code}. {main_title}\n"
        subtopics = [st for st in grouped[main_code] if st[0] != main_code]
        if subtopics:
            for code, name in sorted(subtopics, key=lambda x: [float(n) for n in x[0].split('.')]):
                indent = " " * (4 * (code.count('.') - 1))
                md += f"{indent}- **{code}**: {name}\n"
        else:
            md += f"- **{main_code}**: {main_title}\n"
    return md
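
# Illustrative output shape, assuming topics_full maps "1" to a main-topic
# title and "1.2" to a subtopic title:
#
#   ### 📝 Predicted IGCSE Physics Topics
#
#   #### 1. <main topic title>
#   - **1.2**: <subtopic title>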


def classify_text(classifier, text, threshold, output_format):
    if classifier == 'Transformer':
        tags = get_tags_bert(bert, text, threshold)
    elif classifier == 'CNN':
        tags = []  # CNN model is not wired up in this demo yet.
    elif classifier == 'LSTM':
        tags = get_tags_lstm(lstm, text, threshold)
    else:
        tags = get_tags_bayes(bayes, text)

    tags = expand(tags)
    predictions = {tag: topics_full[tag] for tag in tags if tag in topics_full}

    if output_format == "JSON":
        return gr.update(value=predictions, visible=True), gr.update(visible=False)
    md = format_as_markdown(predictions)
    return gr.update(visible=False), gr.update(value=md, visible=True)
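
# Illustrative direct call (bypassing the UI; assumes the model weights above
# loaded successfully). The two return values are Gradio update payloads for
# the JSON and Markdown output components respectively:
#   json_upd, md_upd = classify_text("Transformer", "State Hooke's law.", 0.5, "Markdown")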


with gr.Blocks(theme="default") as demo:
    gr.Markdown("# 🔬 IGCSE Physics Topic Classifier")
    gr.Markdown(
        "This app classifies IGCSE Physics questions or passages into syllabus topics "
        "using the selected model. Adjust the confidence threshold and choose your "
        "preferred output format."
    )

    with gr.Row(equal_height=True):
        # Left column — Input
        with gr.Column(scale=1):
            classifier = gr.Radio(
                ["Naïve Bayes", "CNN", "LSTM", "Transformer"],
                value="Transformer",
                label="Processing model",
                info="Choose which model is used to classify the text",
            )
            text_input = gr.Textbox(
                lines=8,
                placeholder="Enter a physics question or concept...",
                label="Input Text",
            )
            threshold = gr.Slider(0, 1, value=0.5, step=0.05,
                                  label="Confidence Threshold (not available for Naïve Bayes)")
            output_format = gr.Radio(
                ["Markdown", "JSON"],
                value="Markdown",
                label="Output Format",
                info="Choose how to display results",
            )
            classify_btn = gr.Button("Classify", variant="primary")

        # Right column — Output (dynamic)
        with gr.Column(scale=1):
            json_output = gr.JSON(label="Predicted Topics (JSON)", visible=False)
            markdown_output = gr.Markdown(label="Predicted Topics (Markdown)", visible=True)

    classify_btn.click(
        fn=classify_text,
        inputs=[classifier, text_input, threshold, output_format],
        outputs=[json_output, markdown_output],
    )

if __name__ == "__main__":
    demo.launch(share=True)