enigmaize committed on
Commit
675f8ad
·
verified ·
1 Parent(s): 5e2ba76

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +104 -17
app.py CHANGED
@@ -2,8 +2,11 @@ from transformers import AutoTokenizer, AutoModelForSequenceClassification
2
  from scipy.special import softmax
3
  import gradio as gr
4
  import torch
 
 
 
5
 
6
- model_name = "enigmaize/arxiv-nlp_project-scibert"
7
  tokenizer = AutoTokenizer.from_pretrained(model_name)
8
  model = AutoModelForSequenceClassification.from_pretrained(model_name)
9
 
@@ -11,6 +14,10 @@ model = AutoModelForSequenceClassification.from_pretrained(model_name)
11
  labels = ['math.AC', 'cs.CV', 'cs.AI', 'cs.SY', 'math.GR', 'cs.CE', 'cs.PL', 'cs.IT', 'cs.DS', 'cs.NE', 'math.ST']
12
 
13
  def classify_text(text):
 
 
 
 
14
  # Токенизация текста
15
  inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=512)
16
 
@@ -28,27 +35,107 @@ def classify_text(text):
28
  # Сортировка по вероятности (по убыванию)
29
  sorted_results = dict(sorted(results.items(), key=lambda item: item[1], reverse=True))
30
 
31
- return sorted_results
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
 
33
- # Описание интерфейса
34
- description = "Enter the abstract of a scientific paper, and the model will predict its arXiv category."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
 
36
  # Создание интерфейса Gradio
37
  interface = gr.Interface(
38
- fn=classify_text, # Функция, которая будет вызываться
39
- inputs=gr.Textbox(lines=10, placeholder="Paste abstract here...", label="Paper Abstract"), # Вход: текстовое поле
40
- outputs=gr.Label(num_top_classes=3, label="Predicted Categories"), # Выход: метки с вероятностями
41
- title="ArXiv Paper Classifier (SciBERT)",
42
- description=description,
 
 
 
 
 
 
 
 
 
43
  examples=[
44
- ["We propose a novel deep learning approach for image recognition using convolutional neural networks."],
45
- ["We analyze the computational complexity of algorithms for sorting and searching."],
46
- ["This paper presents a statistical method for analyzing the spread of infectious diseases in populations."]
47
- ] # Примеры для удобства
 
 
 
 
 
 
 
48
  )
49
 
50
- # Запуск интерфейса (это нужно для локального запуска, не для Spaces)
51
- # interface.launch()
52
-
53
- # Для Hugging Face Spaces, просто укажите интерфейс
54
  interface.launch()
 
2
  from scipy.special import softmax
3
  import gradio as gr
4
  import torch
5
+ import matplotlib.pyplot as plt
6
+ import io
7
+ import base64
8
 
9
+ model_name = "enigmaize/arxiv-nlp_project-scibert"
10
  tokenizer = AutoTokenizer.from_pretrained(model_name)
11
  model = AutoModelForSequenceClassification.from_pretrained(model_name)
12
 
 
14
  labels = ['math.AC', 'cs.CV', 'cs.AI', 'cs.SY', 'math.GR', 'cs.CE', 'cs.PL', 'cs.IT', 'cs.DS', 'cs.NE', 'math.ST']
15
 
16
  def classify_text(text):
17
+ if not text.strip():
18
+ # Возвращаем пустой результат, если текст пустой
19
+ return {label: 0.0 for label in labels}, None
20
+
21
  # Токенизация текста
22
  inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=512)
23
 
 
35
  # Сортировка по вероятности (по убыванию)
36
  sorted_results = dict(sorted(results.items(), key=lambda item: item[1], reverse=True))
37
 
38
+ # --- Создание диаграммы ---
39
+ top_k = 5
40
+ top_labels = list(sorted_results.keys())[:top_k]
41
+ top_probs = list(sorted_results.values())[:top_k]
42
+
43
+ fig, ax = plt.subplots(figsize=(8, 4))
44
+ bars = ax.barh(top_labels, top_probs, color=['#4c72b0', '#dd8452', '#55a868', '#c44e52', '#8172b3'])
45
+ ax.set_xlabel('Probability')
46
+ ax.set_title('Top 5 Predicted Categories')
47
+ ax.set_xlim(0, 1)
48
+
49
+ # Добавление числовых значений на барах
50
+ for bar, prob in zip(bars, top_probs):
51
+ width = bar.get_width()
52
+ ax.text(width, bar.get_y() + bar.get_height()/2, f'{prob:.3f}',
53
+ va='center', ha='left', fontsize=10)
54
+
55
+ plt.tight_layout()
56
+
57
+ # Сохраняем диаграмму в буфер
58
+ buf = io.BytesIO()
59
+ plt.savefig(buf, format='png')
60
+ buf.seek(0)
61
+ img_base64 = base64.b64encode(buf.read()).decode('utf-8')
62
+ plt.close(fig) # Закрываем фигуру, чтобы освободить память
63
 
64
+ chart_html = f'<img src="data:image/png;base64,{img_base64}" alt="Prediction Chart" style="width:100%;">'
65
+
66
+ return sorted_results, chart_html
67
+
68
+ # --- Custom CSS for page styling ---
69
+ custom_css = """
70
+ body {
71
+ background-color: #f0f4f8;
72
+ }
73
+ .gradio-container {
74
+ max-width: 900px;
75
+ margin: auto;
76
+ padding-top: 20px;
77
+ padding-bottom: 20px;
78
+ background: white;
79
+ border-radius: 10px;
80
+ box-shadow: 0 4px 8px rgba(0,0,0,0.1);
81
+ }
82
+ h1 {
83
+ color: #2c3e50;
84
+ text-align: center;
85
+ font-family: 'Arial', sans-serif;
86
+ }
87
+ h3 {
88
+ color: #34495e;
89
+ }
90
+ label {
91
+ font-weight: bold;
92
+ color: #2c3e50;
93
+ }
94
+ """
95
+
96
+ # --- HTML для информации о модели ---
97
+ model_info_html = """
98
+ <div style="background-color: #ecf0f1; padding: 15px; border-radius: 8px; margin-bottom: 20px;">
99
+ <h3>About the Model</h3>
100
+ <p>This classifier uses a <strong>SciBERT</strong> model fine-tuned on the <a href="https://huggingface.co/datasets/ccdv/arxiv-classification" target="_blank">arXiv Classification dataset</a>.</p>
101
+ <p>It predicts one of 11 categories related to Computer Science and Mathematics.</p>
102
+ <p>For best results, input the abstract of a scientific paper.</p>
103
+ </div>
104
+ """
105
+
106
+ # --- HTML для описания ---
107
+ description_html = """
108
+ <p style="font-size: 1.1em; text-align: center;">Enter the abstract of a scientific paper below, and the model will predict its arXiv category.</p>
109
+ """
110
 
111
  # Создание интерфейса Gradio
112
  interface = gr.Interface(
113
+ fn=classify_text,
114
+ inputs=gr.Textbox(
115
+ lines=10,
116
+ placeholder="Paste the abstract of a scientific paper here...",
117
+ label="Paper Abstract",
118
+ elem_classes="textbox_custom"
119
+ ),
120
+ outputs=[
121
+ gr.Label(num_top_classes=5, label="Prediction Probabilities"),
122
+ gr.HTML(label="Prediction Chart")
123
+ ],
124
+ title="🔬 ArXiv Paper Classifier (SciBERT)",
125
+ description=description_html,
126
+ article=model_info_html,
127
  examples=[
128
+ [
129
+ "We propose a novel deep learning approach for image recognition using convolutional neural networks. Our method achieves state-of-the-art performance on the ImageNet benchmark, surpassing previous results by a significant margin through architectural innovations and improved training procedures."
130
+ ],
131
+ [
132
+ "We analyze the computational complexity of algorithms for sorting and searching. Specifically, we present a new variant of merge sort that reduces the number of comparisons in the average case. We also discuss the implications for cache performance and practical implementations."
133
+ ],
134
+ [
135
+ "This paper presents a statistical method for analyzing the spread of infectious diseases in populations. Using a modified SIR model with time-dependent transmission rates, we simulate the effects of various intervention strategies on disease dynamics."
136
+ ]
137
+ ],
138
+ css=custom_css
139
  )
140
 
 
 
 
 
141
  interface.launch()