Lxz20071231 commited on
Commit
da7367a
·
verified ·
1 Parent(s): 41ea31b

Initial upload

Browse files
app.py ADDED
@@ -0,0 +1,263 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from transformers import pipeline
3
+ from collections import defaultdict
4
+ import re
5
+ import string
6
+ from textblob import TextBlob
7
+ from bayes import NaiveBayesMultiClass
8
+ import fasttext
9
+ from huggingface_hub import hf_hub_download
10
+ from lstm import LSTMPipeline, LSTMMultiClassClassifier
11
+
12
+ topics = [
13
+ '1.1',
14
+ '1.2',
15
+ '1.3',
16
+ '1.4',
17
+ '1.5',
18
+ '1.6',
19
+ '1.7',
20
+ '1.8',
21
+ '2.1',
22
+ '2.2',
23
+ '2.3',
24
+ '3.1',
25
+ '3.2',
26
+ '3.3',
27
+ '3.4',
28
+ '4.1',
29
+ '4.2',
30
+ '4.3',
31
+ '4.4',
32
+ '4.5',
33
+ '5.1',
34
+ '5.2',
35
+ '6.1',
36
+ '6.2',
37
+ ]
38
+
39
+ topics_full = {
40
+ '1': 'Motion, forces and energy',
41
+ '1.1': 'Physical quantities and measurement techniques',
42
+ '1.2': 'Motion',
43
+ '1.3': 'Mass and weight',
44
+ '1.4': 'Density',
45
+ '1.5': 'Forces',
46
+ '1.6': 'Momentum',
47
+ '1.7': 'Energy, work and power',
48
+ '1.8': 'Pressure',
49
+ '2': 'Thermal physics',
50
+ '2.1': 'Kinetic particle model of matter',
51
+ '2.2': 'Thermal properties and temperature',
52
+ '2.3': 'Transfer of thermal energy',
53
+ '3': 'Waves',
54
+ '3.1': 'General properties of waves',
55
+ '3.2': 'Light',
56
+ '3.3': 'Electromagnetic spectrum',
57
+ '3.4': 'Sound',
58
+ '4': 'Electricity and magnetism',
59
+ '4.1': 'Simple phenomena of magnetism',
60
+ '4.2': 'Electrical quantities',
61
+ '4.3': 'Electric circuits',
62
+ '4.4': 'Electrical safety',
63
+ '4.5': 'Electromagnetic effects',
64
+ '5': 'Nuclear physics',
65
+ '5.1': 'The nuclear model of the atom',
66
+ '5.2': 'Radioactivity',
67
+ '6': 'Space physics',
68
+ '6.1': 'Earth and the Solar System',
69
+ '6.2': 'Stars and the Universe',
70
+ }
71
+
72
+ embedding_model_path = hf_hub_download(
73
+ repo_id="facebook/fasttext-en-vectors",
74
+ filename="model.bin"
75
+ )
76
+ embedder = fasttext.load_model(embedding_model_path)
77
+
78
+ stopword = ["i", "me", "my", "myself", "we", "our", "ours", "ourselves", "you", "your", "yours", "yourself",
79
+ "yourselves", "he", "him", "his", "himself", "she", "her", "hers", "herself", "it", "its", "itself", "they",
80
+ "them", "their", "theirs", "themselves", "what", "which", "who", "whom", "this", "that", "these", "those",
81
+ "am", "is", "are", "was", "were", "be", "been", "being", "have", "has", "had", "having", "do", "does",
82
+ "did", "doing", "a", "an", "the", "and", "but", "if", "or", "because", "as", "until", "while", "of", "at",
83
+ "by", "for", "with", "about", "against", "between", "into", "through", "during", "before", "after", "above",
84
+ "below", "to", "from", "up", "down", "in", "out", "on", "off", "over", "under", "again", "further", "then",
85
+ "once", "here", "there", "when", "where", "why", "how", "all", "any", "both", "each", "few", "more", "most",
86
+ "other", "some", "such", "no", "nor", "not", "only", "own", "same", "so", "than", "too", "very", "s", "t",
87
+ "can", "will", "just", "don", "should", "now"]
88
+ punctuations = string.punctuation
89
+
90
+
91
+ def to_lower(text: str) -> str:
92
+ return text.lower()
93
+
94
+
95
+ def remove_html_tags(text: str) -> str:
96
+ pattern = re.compile('<.*?>')
97
+ return pattern.sub(r'', text)
98
+
99
+
100
+ def remove_punctuations(text: str) -> str:
101
+ return text.translate(str.maketrans('', '', punctuations))
102
+
103
+
104
+ def correct_spellings(text: str) -> str:
105
+ return TextBlob(text).correct().string
106
+
107
+
108
+ def remove_stopwords(text: str) -> str:
109
+ return " ".join([word for word in text.split() if word not in stopword])
110
+
111
+
112
+ def clean(text: str) -> str:
113
+ return remove_stopwords(
114
+ correct_spellings(remove_punctuations(remove_html_tags(to_lower(text))))
115
+ )
116
+
117
+
118
+ bert = pipeline(
119
+ "text-classification",
120
+ model="Lxz20071231/igcse-physics-bert",
121
+ tokenizer="distilbert-base-uncased",
122
+ return_all_scores=True,
123
+ function_to_apply="sigmoid",
124
+ truncation=True
125
+ )
126
+
127
+ id2label = {i: topics[i] for i in range(24)}
128
+
129
+ lstm = LSTMPipeline(embedder=embedder, model=LSTMMultiClassClassifier.from_pretrained(
130
+ "Lxz20071231/igcse-physics-lstm"
131
+ ), id2label=id2label, device=-1)
132
+
133
+ n_topics = len(topics)
134
+
135
+ bayes = NaiveBayesMultiClass(topics)
136
+ bayes.load('bayes/')
137
+
138
+ def get_tags(probs, threshold = 0.5):
139
+ tags = []
140
+ for line in probs:
141
+ found = []
142
+ for p, label in zip(line, topics):
143
+ if p['score'] >= threshold:
144
+ found.append(label)
145
+ tags.append(found)
146
+ return tags
147
+
148
+ def get_tags_multiple_bert(texts, threshold=0.5):
149
+ output = bert(texts)
150
+ return get_tags(output, threshold)
151
+
152
+
153
+ def get_tags_bayes(text):
154
+ return bayes.predict(clean(text), True)
155
+
156
+
157
+ def get_tags_cnn(text, threshold=0.5):
158
+ return []
159
+
160
+
161
+ def get_tags_lstm(text, threshold=0.5):
162
+ return get_tags(lstm(text), threshold)[0]
163
+
164
+
165
+ def get_tags_bert(text, threshold=0.5):
166
+ tags = get_tags_multiple_bert([text], threshold)[0]
167
+ return tags
168
+
169
+ def expand(tags):
170
+ with_primary = set()
171
+ for i in tags:
172
+ with_primary.add(i[:1])
173
+ with_primary.add(i)
174
+ return sorted(list(with_primary))
175
+
176
+
177
+ def format_as_markdown(predictions: dict) -> str:
178
+ if not predictions:
179
+ return "_No topics detected._"
180
+
181
+ grouped = defaultdict(list)
182
+ for code, topic in predictions.items():
183
+ main = code.split('.')[0]
184
+ grouped[main].append((code, topic))
185
+
186
+ md = "### 📝 Predicted IGCSE Physics Topics\n"
187
+ for main_code in sorted(grouped.keys(), key=lambda x: float(x)):
188
+ main_title = topics_full.get(main_code, f"{topics_full[main_code]}")
189
+ md += f"\n#### {main_code}. {main_title}\n"
190
+ subtopics = [st for st in grouped[main_code] if st[0] != main_code]
191
+ if subtopics:
192
+ for code, name in sorted(subtopics, key=lambda x: [float(n) for n in x[0].split('.')]):
193
+ indent = " " * (4 * (code.count('.') - 1))
194
+ md += f"{indent}- **{code}**: {name}\n"
195
+ else:
196
+ md += f"- **{main_code}**: {main_title}\n"
197
+ return md
198
+
199
+
200
+ def classify_text(classifier, text, threshold, output_format):
201
+ if classifier == 'Transformer':
202
+ tags = get_tags_bert(text, threshold)
203
+ elif classifier == 'CNN':
204
+ tags = get_tags_cnn(text, threshold)
205
+ elif classifier == 'LSTM':
206
+ tags = get_tags_lstm(text, threshold)
207
+ else:
208
+ tags = get_tags_bayes(text)
209
+
210
+ tags = expand(tags)
211
+ predictions = {tag: topics_full[tag] for tag in tags if tag in topics_full}
212
+
213
+ if output_format == "JSON":
214
+ return predictions, gr.update(visible=True), gr.update(visible=False)
215
+ else:
216
+ md = format_as_markdown(predictions)
217
+ return {}, gr.update(visible=False), gr.update(value=md, visible=True)
218
+
219
+
220
+ with gr.Blocks(theme="default") as demo:
221
+ gr.Markdown("# 🔬 IGCSE Physics Topic Classifier")
222
+ gr.Markdown(
223
+ "This model classifies IGCSE Physics questions or passages into syllabus topics. "
224
+ "Adjust the confidence threshold and choose your preferred output format."
225
+ )
226
+
227
+ with gr.Row(equal_height=True):
228
+ # Left column — Input
229
+ with gr.Column(scale=1):
230
+ classifier = gr.Radio(
231
+ ["Naïve Bayes", "CNN", "LSTM", "Transformer"],
232
+ value="Transformer",
233
+ label="Processing model",
234
+ info="Choose which model to use to process texts",
235
+ )
236
+ text_input = gr.Textbox(
237
+ lines=8,
238
+ placeholder="Enter a physics question or concept...",
239
+ label="Input Text",
240
+ )
241
+ threshold = gr.Slider(0, 1, value=0.5, step=0.05,
242
+ label="Confidence Threshold (not available for Naïve Bayes)")
243
+ output_format = gr.Radio(
244
+ ["Markdown", "JSON"],
245
+ value="Markdown",
246
+ label="Output Format",
247
+ info="Choose how to display results",
248
+ )
249
+ classify_btn = gr.Button("Classify", variant="primary")
250
+
251
+ # Right column — Output (dynamic)
252
+ with gr.Column(scale=1):
253
+ json_output = gr.JSON(label="Predicted Topics (JSON)", visible=False)
254
+ markdown_output = gr.Markdown(label="Predicted Topics (Markdown)", visible=True)
255
+
256
+ classify_btn.click(
257
+ fn=classify_text,
258
+ inputs=[classifier, text_input, threshold, output_format],
259
+ outputs=[json_output, json_output, markdown_output],
260
+ )
261
+
262
+ if __name__ == "__main__":
263
+ demo.launch(share=True)
bayes.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from sklearn.feature_extraction.text import CountVectorizer
2
+ from sklearn.naive_bayes import MultinomialNB
3
+ from sklearn.preprocessing import MultiLabelBinarizer
4
+ import typing
5
+ import joblib
6
+
7
+
8
+ class NaiveBayesMultiClass(object):
9
+ def __init__(self, classes: typing.Iterable[str]):
10
+ self.classes = list(classes)
11
+ self.n_classes = len(self.classes)
12
+ self.enc = MultiLabelBinarizer()
13
+ self.enc.fit([classes])
14
+ self.vectorizer = CountVectorizer()
15
+ self.classifiers = []
16
+
17
+ def load(self, path: str):
18
+ self.vectorizer = joblib.load(f'{path}/vectorizer.joblib')
19
+ self.classifiers = [
20
+ joblib.load(f'{path}/class_{i}.joblib') for i in range(self.n_classes)
21
+ ]
22
+
23
+ def predict(self, X: typing.Iterable[str] | str, get_tags=False):
24
+ if type(X) == str:
25
+ return self.predict([X], get_tags)[0]
26
+ x = self.vectorizer.transform(X)
27
+ by_class = [self.classifiers[i].predict(x) for i in range(self.n_classes)]
28
+ ans = []
29
+
30
+ for i in range(len(X)):
31
+ y = []
32
+ for j, cls in enumerate(self.classes):
33
+ if get_tags:
34
+ if by_class[j][i]:
35
+ y.append(cls)
36
+ else:
37
+ y.append(by_class[j][i])
38
+ ans.append(y)
39
+ return ans
40
+
41
+ def __call__(self, *args, **kwargs):
42
+ return self.predict(*args, **kwargs)
43
+
44
+
bayes/class_0.joblib ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:dadf8b28353ce1fc4df5a786d28546d6cd9155f472f23e28ec129ad70bf9b814
3
+ size 106071
bayes/class_1.joblib ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:895519139e6ca6ab9597b39202826d91b9b6f75858e502865877dae49ba4bb5a
3
+ size 106071
bayes/class_10.joblib ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d5efe768fbf47abfcc41456ad387172be57cc58fa2304838c3162fdb679ea27c
3
+ size 106071
bayes/class_11.joblib ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0657dd6a86746f734ff16b3cbf50efa400401f5ec1c925dec12f5307adb9c571
3
+ size 106071
bayes/class_12.joblib ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:47be02b397b9052706209a7c134dd7b789b3e9c389b4fc8ce6ed3b9aa6142cef
3
+ size 106071
bayes/class_13.joblib ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:44516c260c8687b0650cd8b797d1a332e10b7415ba86aa9b3ffe11880b0366cf
3
+ size 106071
bayes/class_14.joblib ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:62e4f9daa2c9e0655acf08c53ff7a0ee902fc05f324f9131eebb0eef4d6fc4b2
3
+ size 106071
bayes/class_15.joblib ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2860e935c0e907a5a4840d74d28e9b6fc838d1c3d9f9c765b3a8332912b2b3ac
3
+ size 106071
bayes/class_16.joblib ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c270988af6b88e68c53ee40fd729c912930f7f58e4a42a51685ad4f4458c59c5
3
+ size 106071
bayes/class_17.joblib ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d10b2bf71998e57e27067c346bd366b919baa395750326f580bd11bc424bada7
3
+ size 106071
bayes/class_18.joblib ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8b00f72de18850b747a2b606f22dc77b9f45df6a0057341e743d8e64fd2edc96
3
+ size 106071
bayes/class_19.joblib ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e527c3f9fa747135b8d80befb0fa5bc72f4c5a98466e43d02d75f412e46463cf
3
+ size 106071
bayes/class_2.joblib ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4977176edcd2392ea5af799c21adef9a1fde0f1ee42ae31acf598c428c6b5368
3
+ size 106071
bayes/class_20.joblib ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:bc51fcc57f638b99f9abfed2c3136a89aea66cf8e4b235c05eef6d45437bda64
3
+ size 106071
bayes/class_21.joblib ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f30b1c7c1ce457f4d2b60a2f1542b2f7b4ee90d2b17abab0360c307929d4fb91
3
+ size 106071
bayes/class_22.joblib ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fa5690cacca1829378bd4de994ebc24f5f48b37655574a4e390b1d1a4d0cc302
3
+ size 106071
bayes/class_23.joblib ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:08ade54cff6c56c7c269495d041ae80972977ca0f1c6a95a6a6d40e482f2451a
3
+ size 106071
bayes/class_3.joblib ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1b08cad3484fa92c1cbd59cdf079ca2bd33876eb5d367b8fa62ecf1c6a76a28e
3
+ size 106071
bayes/class_4.joblib ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4bdbbd1154b46adf91095e02343f2541ded76ff26005d7e29ee6ab9ca083c70f
3
+ size 106071
bayes/class_5.joblib ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c32aab9bf1b26df655a17a0f69681d9782b21fad37309cd757d06754f2ebe07a
3
+ size 106071
bayes/class_6.joblib ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:bdec320a6f796a52bea798e490e7c345a5bd15cc0fd00c6c249441acbf8e0cd7
3
+ size 106071
bayes/class_7.joblib ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3d0f298d010d3f4aa9651988c3930ec9132c8c6925d092a3d463842262d7f7e1
3
+ size 106071
bayes/class_8.joblib ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7627b60bea4ef9eeb433173824ad222a94c847537e141708bc40bf044c6c4ab4
3
+ size 106071
bayes/class_9.joblib ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9e3245f68cc0e8a2927a44eccaf2a7910290a62ccf9a299a70f271b727fbd329
3
+ size 106071
bayes/vectorizer.joblib ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2147176e765ef659fc27a653da99d8c8e07e0a513449e94bd64628b4a27d7cee
3
+ size 40492
lstm.py ADDED
@@ -0,0 +1,174 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch.nn.utils.rnn import pad_sequence
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ from torch.nn.utils.rnn import pack_padded_sequence
6
+ from types import SimpleNamespace
7
+ from huggingface_hub import PyTorchModelHubMixin
8
+ from transformers import Pipeline
9
+
10
+
11
+ def get_words(model, text: str):
12
+ """
13
+ Break text into tokens using FastText's internal tokenizer.
14
+ """
15
+ lines = [model.get_line(line)[0] for line in text.split("\n")]
16
+ words = []
17
+ for line in lines:
18
+ for w in line[:-1]:
19
+ words.append(w)
20
+ return words
21
+
22
+
23
+ def get_vectors(model, text: str):
24
+ """
25
+ Convert text → list of embedding vectors.
26
+ """
27
+ words = get_words(model, text)
28
+ vectors = [model[w] for w in words]
29
+ return vectors
30
+
31
+
32
+ def get_tensor(model, text: str):
33
+ """
34
+ Convert text → (seq_len, embedding_dim) tensor
35
+ """
36
+ vectors = get_vectors(model, text)
37
+ if len(vectors) == 0:
38
+ # fallback for empty text
39
+ return torch.zeros(1, model.get_dimension())
40
+ return torch.tensor(vectors, dtype=torch.float)
41
+
42
+
43
+ def preprocess_batch(embedder, texts):
44
+ """
45
+ Convert a list of text strings into:
46
+ x_padded: (batch, seq_len, emb_dim)
47
+ lengths: (batch,)
48
+ Both sorted by sequence length (DESC) for pack_padded_sequence.
49
+ """
50
+
51
+ # Convert each text → tensor
52
+ seq_tensors = [get_tensor(embedder, t) for t in texts]
53
+
54
+ # Compute lengths BEFORE padding
55
+ lengths = torch.tensor([seq.size(0) for seq in seq_tensors], dtype=torch.long)
56
+
57
+ # Sort by length (DESC)
58
+ lengths_sorted, sort_idx = torch.sort(lengths, descending=True)
59
+ seq_tensors = [seq_tensors[i] for i in sort_idx]
60
+
61
+ # Pad to create (batch, max_seq_len, emb_dim)
62
+ x_padded = pad_sequence(seq_tensors, batch_first=True)
63
+
64
+ return x_padded, lengths_sorted
65
+
66
+
67
+ class LSTMMultiClassClassifier(nn.Module, PyTorchModelHubMixin):
68
+ def __init__(self, embedding_dim, hidden_dim, num_classes,
69
+ num_layers=1, bidirectional=True, dropout=0.5, **kwargs):
70
+ super().__init__()
71
+
72
+ # REQUIRED for HuggingFace Pipeline
73
+ self.device = torch.device("cpu")
74
+
75
+ # Save config
76
+ self.config = SimpleNamespace(
77
+ embedding_dim=embedding_dim,
78
+ hidden_dim=hidden_dim,
79
+ num_classes=num_classes,
80
+ num_layers=num_layers,
81
+ bidirectional=bidirectional,
82
+ dropout=dropout
83
+ )
84
+
85
+ self.embedding_dim = embedding_dim
86
+ self.hidden_dim = hidden_dim
87
+ self.num_layers = num_layers
88
+ self.bidirectional = bidirectional
89
+ self.dropout = dropout
90
+ self.num_classes = num_classes
91
+
92
+ self.lstm = nn.LSTM(
93
+ input_size=embedding_dim,
94
+ hidden_size=hidden_dim,
95
+ num_layers=num_layers,
96
+ batch_first=True,
97
+ dropout=dropout if num_layers > 1 else 0,
98
+ bidirectional=bidirectional
99
+ )
100
+
101
+ direction = 2 if bidirectional else 1
102
+ self.fc = nn.Sequential(
103
+ nn.Linear(hidden_dim * direction, 128),
104
+ nn.ReLU(),
105
+ nn.Linear(128, 128),
106
+ nn.ReLU(),
107
+ nn.Linear(128, num_classes)
108
+ )
109
+
110
+ @classmethod
111
+ def from_config(cls, config):
112
+ return cls(
113
+ embedding_dim=config.embedding_dim,
114
+ hidden_dim=config.hidden_dim,
115
+ num_classes=config.num_classes,
116
+ num_layers=config.num_layers,
117
+ bidirectional=config.bidirectional,
118
+ dropout=config.dropout
119
+ )
120
+
121
+ # REQUIRED for Transformers Pipeline (updates internal device)
122
+ def to(self, device):
123
+ super().to(device)
124
+ self.device = device
125
+ return self
126
+
127
+ def forward(self, x, lengths):
128
+ x = x.to(self.device)
129
+ lengths = lengths.to(self.device)
130
+
131
+ packed = pack_padded_sequence(
132
+ x, lengths.cpu(), batch_first=True, enforce_sorted=True
133
+ )
134
+ _, (h_n, _) = self.lstm(packed)
135
+
136
+ if self.bidirectional:
137
+ h = torch.cat((h_n[-2], h_n[-1]), dim=1)
138
+ else:
139
+ h = h_n[-1]
140
+
141
+ return self.fc(h)
142
+
143
+
144
+ class LSTMPipeline(Pipeline):
145
+ def __init__(self, id2label, embedder, **kwargs):
146
+ model = LSTMMultiClassClassifier.from_pretrained(
147
+ "Lxz20071231/igcse-physics-lstm"
148
+ )
149
+ super().__init__(model=model, tokenizer=None, **kwargs)
150
+ self.id2label = id2label
151
+ self.embedder = embedder
152
+
153
+ def preprocess(self, inputs):
154
+ if isinstance(inputs, str):
155
+ texts = [inputs]
156
+ else:
157
+ texts = list(inputs)
158
+ x, lengths = preprocess_batch(self.embedder, texts)
159
+ return {"x": x, "lengths": lengths}
160
+
161
+ def _forward(self, model_inputs):
162
+ x = model_inputs["x"]
163
+ lengths = model_inputs["lengths"]
164
+ with torch.no_grad():
165
+ logits = self.model(x, lengths)
166
+ return logits
167
+
168
+ def postprocess(self, logits):
169
+ probs = F.sigmoid(logits)
170
+
171
+ return probs
172
+
173
+ def _sanitize_parameters(self, **kwargs):
174
+ return {}, {}, {}
requirements.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ gradio
2
+ scikit-learn
3
+ nltk
4
+ textblob
5
+ fasttext
6
+ transformers>=4.46.0
7
+ huggingface-hub<1.0.0
8
+ torch>=2.1