ALSv commited on
Commit
a04766c
·
verified ·
1 Parent(s): 0bfd07b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +70 -50
app.py CHANGED
@@ -1,4 +1,3 @@
1
- # app.py
2
  import gradio as gr
3
  import torch
4
  import torch.nn as nn
@@ -10,6 +9,7 @@ import io
10
 
11
  # ---------------- CONFIG ----------------
12
  labels = ["Drawings", "Hentai", "Neutral", "Porn", "Sexy"]
 
13
 
14
  # ---------------- MODEL ----------------
15
  class Classifier(nn.Module):
@@ -30,10 +30,10 @@ class Classifier(nn.Module):
30
  return x
31
 
32
  preprocess = transforms.Compose([
33
- transforms.Resize((224, 224)),
34
  transforms.ToTensor(),
35
  transforms.Normalize(mean=[0.485,0.456,0.406],
36
- std =[0.229,0.224,0.225])
37
  ])
38
 
39
  model = Classifier()
@@ -41,62 +41,82 @@ model.load_state_dict(torch.load("classify_nsfw_v3.0.pth", map_location="cpu"))
41
  model.eval()
42
 
43
  # ---------------- FUNZIONE ----------------
44
- def predict(base64_input: str):
45
- if not base64_input:
46
- return "Nessun input fornito", {}
47
-
48
- if base64_input.startswith("data:image"):
49
- base64_input = base64_input.split(",", 1)[1]
50
-
51
- img_bytes = base64.b64decode(base64_input)
52
- img = Image.open(io.BytesIO(img_bytes)).convert("RGB")
53
-
54
- img_tensor = preprocess(img).unsqueeze(0)
55
- with torch.no_grad():
56
- logits = model(img_tensor)
57
- probs = torch.nn.functional.softmax(logits[0], dim=0)
58
-
59
- probs_dict = {labels[i]: float(probs[i]) for i in range(len(labels))}
60
- max_label = max(probs_dict, key=probs_dict.get)
61
-
62
- return max_label, probs_dict
63
-
64
- # ---------------- HELPER ----------------
65
- def image_to_base64(img: Image.Image):
66
- if img is None:
67
- return ""
68
- buf = io.BytesIO()
69
- img.save(buf, format="JPEG", quality=90)
70
- return "data:image/jpeg;base64," + base64.b64encode(buf.getvalue()).decode("utf-8")
 
 
 
 
71
 
72
  # ---------------- INTERFACCIA ----------------
73
- with gr.Blocks(title="NSFW Classifier") as demo:
74
- gr.Markdown("## 🎨 NSFW Image Classifier\nCarica un'immagine o incolla la stringa base64.\n\nAPI standard: **/api/predict**")
 
 
 
 
 
 
75
 
76
  with gr.Row():
77
  with gr.Column(scale=2):
 
78
  img_input = gr.Image(label="📷 Carica immagine", type="pil")
79
- base64_input = gr.Textbox(label="📤 Base64 (API)", lines=6)
80
-
 
 
 
81
  with gr.Row():
82
- analyze_btn = gr.Button("✨ Analizza")
83
- clear_btn = gr.Button("🔄 Pulisci")
84
 
85
  with gr.Column(scale=1):
86
- label_output = gr.Textbox(label="Classe predetta")
87
  result_display = gr.Label(label="Distribuzione probabilità", num_top_classes=len(labels))
88
 
89
- # immagine converte in base64 → textbox
90
- img_input.change(fn=image_to_base64, inputs=img_input, outputs=base64_input)
91
-
92
- # unico endpoint API standard
93
- analyze_btn.click(fn=predict,
94
- inputs=base64_input,
95
- outputs=[label_output, result_display],
96
- api_name="predict")
97
-
98
- clear_btn.click(fn=lambda: "", inputs=None, outputs=base64_input)
99
-
100
- # ---------------- AVVIO ----------------
 
 
 
 
 
 
101
  if __name__ == "__main__":
102
- demo.launch()
 
 
1
  import gradio as gr
2
  import torch
3
  import torch.nn as nn
 
9
 
10
  # ---------------- CONFIG ----------------
11
  labels = ["Drawings", "Hentai", "Neutral", "Porn", "Sexy"]
12
+ theme_color = "#6C5B7B"
13
 
14
  # ---------------- MODEL ----------------
15
  class Classifier(nn.Module):
 
30
  return x
31
 
32
  preprocess = transforms.Compose([
33
+ transforms.Resize((224,224)),
34
  transforms.ToTensor(),
35
  transforms.Normalize(mean=[0.485,0.456,0.406],
36
+ std=[0.229,0.224,0.225])
37
  ])
38
 
39
  model = Classifier()
 
41
  model.eval()
42
 
43
  # ---------------- FUNZIONE ----------------
44
+ def predict(image_input):
45
+ """
46
+ Supporta:
47
+ - PIL Image (UI web)
48
+ - stringa base64 (API)
49
+ """
50
+ try:
51
+ if isinstance(image_input, str):
52
+ if image_input.startswith("data:image"):
53
+ image_input = image_input.split(",",1)[1]
54
+ img_bytes = base64.b64decode(image_input)
55
+ img = Image.open(io.BytesIO(img_bytes)).convert("RGB")
56
+ else:
57
+ img = image_input.convert("RGB")
58
+
59
+ img_tensor = preprocess(img).unsqueeze(0)
60
+
61
+ with torch.no_grad():
62
+ logits = model(img_tensor)
63
+ probs = torch.nn.functional.softmax(logits[0], dim=0)
64
+
65
+ probs_dict = {labels[i]: float(probs[i]) for i in range(len(labels))}
66
+ max_label = max(probs_dict, key=probs_dict.get)
67
+
68
+ return max_label, probs_dict
69
+
70
+ except Exception as e:
71
+ return f"Error: {str(e)}", {}
72
+
73
+ def clear_all():
74
+ return "", ""
75
 
76
  # ---------------- INTERFACCIA ----------------
77
+ with gr.Blocks(title="NSFW Image Classifier") as demo:
78
+
79
+ gr.HTML(f"""
80
+ <div style="padding:10px; background:linear-gradient(135deg,#f8f9fa 0%,#e9ecef 100%); border-radius:10px;">
81
+ <h2 style="color:{theme_color};">🎨 NSFW Image Classifier</h2>
82
+ <p>Carica un'immagine o incolla la stringa base64 per analizzarla.</p>
83
+ </div>
84
+ """)
85
 
86
  with gr.Row():
87
  with gr.Column(scale=2):
88
+ # Input UI
89
  img_input = gr.Image(label="📷 Carica immagine", type="pil")
90
+ base64_input = gr.Textbox(
91
+ label="📤 Base64 dell'immagine (API)",
92
+ lines=6,
93
+ placeholder="Incolla qui la stringa base64..."
94
+ )
95
  with gr.Row():
96
+ submit_btn = gr.Button("✨ Analizza", variant="primary")
97
+ clear_btn = gr.Button("🔄 Pulisci", variant="secondary")
98
 
99
  with gr.Column(scale=1):
100
+ label_output = gr.Textbox(label="Classe predetta", interactive=False)
101
  result_display = gr.Label(label="Distribuzione probabilità", num_top_classes=len(labels))
102
 
103
+ # ---------------- Eventi UI ----------------
104
+ submit_btn.click(
105
+ fn=predict,
106
+ inputs=[img_input],
107
+ outputs=[label_output, result_display]
108
+ )
109
+ clear_btn.click(fn=clear_all, inputs=None, outputs=[img_input, base64_input])
110
+
111
+ # ---------------- Pulsante invisibile per API base64 ----------------
112
+ api_button = gr.Button(visible=False)
113
+ api_button.click(
114
+ fn=predict,
115
+ inputs=[base64_input],
116
+ outputs=[label_output, result_display],
117
+ api_name="predict" # espone /run/predict
118
+ )
119
+
120
+ # ---------------- LAUNCH ----------------
121
  if __name__ == "__main__":
122
+ demo.launch(server_name="0.0.0.0", server_port=7860, show_api=True)